Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add cancellation support to ReadWriteLock #12120

Merged
merged 16 commits into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12120.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for cancellation to `ReadWriteLock`.
8 changes: 4 additions & 4 deletions synapse/handlers/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ async def _purge_history(
"""
self._purges_in_progress_by_room.add(room_id)
try:
with await self.pagination_lock.write(room_id):
async with self.pagination_lock.write(room_id):
await self.storage.purge_events.purge_history(
room_id, token, delete_local_events
)
Expand Down Expand Up @@ -406,7 +406,7 @@ async def purge_room(self, room_id: str, force: bool = False) -> None:
room_id: room to be purged
force: set true to skip checking for joined users.
"""
with await self.pagination_lock.write(room_id):
async with self.pagination_lock.write(room_id):
# first check that we have no users in this room
if not force:
joined = await self.store.is_host_joined(room_id, self._server_name)
Expand Down Expand Up @@ -448,7 +448,7 @@ async def get_messages(

room_token = from_token.room_key

with await self.pagination_lock.read(room_id):
async with self.pagination_lock.read(room_id):
(
membership,
member_event_id,
Expand Down Expand Up @@ -615,7 +615,7 @@ async def _shutdown_and_purge_room(

self._purges_in_progress_by_room.add(room_id)
try:
with await self.pagination_lock.write(room_id):
async with self.pagination_lock.write(room_id):
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
self._delete_by_id[
delete_id
Expand Down
128 changes: 97 additions & 31 deletions synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import inspect
import itertools
import logging
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Collection,
Expand All @@ -40,7 +41,7 @@
)

import attr
from typing_extensions import ContextManager, Literal
from typing_extensions import AsyncContextManager, Literal

from twisted.internet import defer
from twisted.internet.defer import CancelledError
Expand Down Expand Up @@ -491,7 +492,7 @@ class ReadWriteLock:

Example:

with await read_write_lock.read("test_key"):
async with read_write_lock.read("test_key"):
# do some work
"""

Expand All @@ -514,22 +515,24 @@ def __init__(self) -> None:
# Latest writer queued
self.key_to_current_writer: Dict[str, defer.Deferred] = {}

async def read(self, key: str) -> ContextManager:
new_defer: "defer.Deferred[None]" = defer.Deferred()
def read(self, key: str) -> AsyncContextManager:
@asynccontextmanager
async def _ctx_manager() -> AsyncIterator[None]:
new_defer: "defer.Deferred[None]" = defer.Deferred()

curr_readers = self.key_to_current_readers.setdefault(key, set())
curr_writer = self.key_to_current_writer.get(key, None)
curr_readers = self.key_to_current_readers.setdefault(key, set())
curr_writer = self.key_to_current_writer.get(key, None)

curr_readers.add(new_defer)
curr_readers.add(new_defer)

# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
if curr_writer:
await make_deferred_yieldable(curr_writer)

@contextmanager
def _ctx_manager() -> Iterator[None]:
try:
# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
# May raise a `CancelledError` if the `Deferred` wrapping us is
# cancelled. The `Deferred` we are waiting on must not be cancelled,
# since we do not own it.
if curr_writer:
await make_deferred_yieldable(stop_cancellation(curr_writer))
yield
finally:
with PreserveLoggingContext():
Expand All @@ -538,29 +541,36 @@ def _ctx_manager() -> Iterator[None]:

return _ctx_manager()

async def write(self, key: str) -> ContextManager:
new_defer: "defer.Deferred[None]" = defer.Deferred()
def write(self, key: str) -> AsyncContextManager:
@asynccontextmanager
async def _ctx_manager() -> AsyncIterator[None]:
new_defer: "defer.Deferred[None]" = defer.Deferred()

curr_readers = self.key_to_current_readers.get(key, set())
curr_writer = self.key_to_current_writer.get(key, None)
curr_readers = self.key_to_current_readers.get(key, set())
curr_writer = self.key_to_current_writer.get(key, None)

# We wait on all latest readers and writer.
to_wait_on = list(curr_readers)
if curr_writer:
to_wait_on.append(curr_writer)
# We wait on all latest readers and writer.
to_wait_on = list(curr_readers)
if curr_writer:
to_wait_on.append(curr_writer)

# We can clear the list of current readers since the new writer waits
# for them to finish.
curr_readers.clear()
self.key_to_current_writer[key] = new_defer
# We can clear the list of current readers since `new_defer` waits
# for them to finish.
curr_readers.clear()
self.key_to_current_writer[key] = new_defer

await make_deferred_yieldable(defer.gatherResults(to_wait_on))

@contextmanager
def _ctx_manager() -> Iterator[None]:
to_wait_on_defer = defer.gatherResults(to_wait_on)
try:
# Wait for all current readers and the latest writer to finish.
# May raise a `CancelledError` immediately after the wait if the
# `Deferred` wrapping us is cancelled. We must only release the lock
# once we have acquired it, hence the delay.
squahtx marked this conversation as resolved.
Show resolved Hide resolved
await make_deferred_yieldable(
delay_cancellation(to_wait_on_defer, all=True)
)
yield
finally:
# Release the lock.
with PreserveLoggingContext():
new_defer.callback(None)
# `self.key_to_current_writer[key]` may be missing if there was another
Expand Down Expand Up @@ -695,3 +705,59 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
new_deferred: defer.Deferred[T] = defer.Deferred()
deferred.chainDeferred(new_deferred)
return new_deferred


def delay_cancellation(deferred: "defer.Deferred[T]", all: bool) -> "defer.Deferred[T]":
    """Delay cancellation of a `Deferred` until it resolves.

    Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
    resolve with a `CancelledError` until the original `Deferred` resolves.

    Args:
        deferred: The `Deferred` to protect against cancellation. Must not follow the
            Synapse logcontext rules if `all` is `False`.
        all: `True` to delay multiple cancellations. `False` to delay only the first
            cancellation.
            NOTE(review): this parameter shadows the `all` builtin; consider renaming
            (e.g. `all_cancellations`) in a follow-up, since it is keyword-callable.

    Returns:
        A new `Deferred`, which will contain the result of the original `Deferred`.
        The new `Deferred` will not propagate cancellation through to the original.
        When cancelled, the new `Deferred` will wait until the original `Deferred`
        resolves before failing with a `CancelledError`.

        The new `Deferred` will only follow the Synapse logcontext rules if `all` is
        `True` and `deferred` follows the Synapse logcontext rules. Otherwise the new
        `Deferred` should be wrapped with `make_deferred_yieldable`.
    """

    def cancel_errback(failure: Failure) -> Union[Failure, "defer.Deferred[T]"]:
        """Insert another `Deferred` into the chain to delay cancellation.

        Called when the original `Deferred` resolves or the new `Deferred` is
        cancelled.

        Args:
            failure: The failure propagating down the chain. Anything other than
                a `CancelledError` is re-raised untouched by `trap` below.

        Returns:
            Either the original `failure` (when the `CancelledError` genuinely came
            from `deferred`) or a fresh `Deferred` that withholds the
            `CancelledError` until `deferred` resolves.
        """
        # Only intercept cancellation; any other failure passes straight through
        # to later errbacks (`trap` re-raises non-matching failures).
        failure.trap(CancelledError)

        if deferred.called and not deferred.paused:
            # The `CancelledError` came from the original `Deferred`. Pass it through.
            return failure

        # Construct another `Deferred` that will only fail with the `CancelledError`
        # once the original `Deferred` resolves.
        # Returning a `Deferred` from an errback makes the outer `Deferred` wait on
        # it, which is what delays delivery of the `CancelledError`.
        delay_deferred: "defer.Deferred[T]" = defer.Deferred()
        deferred.chainDeferred(delay_deferred)

        if all:
            # Intercept cancellations recursively. Each cancellation will cause another
            # `Deferred` to be inserted into the chain.
            delay_deferred.addErrback(cancel_errback)

        # Override the result with the `CancelledError`.
        # Whatever `deferred` eventually produced, the caller that cancelled must
        # still observe a `CancelledError`, not the original result.
        delay_deferred.addBoth(lambda _: failure)

        return delay_deferred

    # Chain a fresh `Deferred` off the original so that cancelling `new_deferred`
    # does not cancel `deferred` (chaining does not propagate cancellation upward
    # here; `cancel_errback` intercepts the resulting `CancelledError`).
    new_deferred: "defer.Deferred[T]" = defer.Deferred()
    deferred.chainDeferred(new_deferred)
    new_deferred.addErrback(cancel_errback)
    return new_deferred
Loading