Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add cancellation support to ReadWriteLock #12120

Merged
merged 16 commits into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions synapse/handlers/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ async def _purge_history(
"""
self._purges_in_progress_by_room.add(room_id)
try:
with await self.pagination_lock.write(room_id):
async with self.pagination_lock.write(room_id):
await self.storage.purge_events.purge_history(
room_id, token, delete_local_events
)
Expand Down Expand Up @@ -405,7 +405,7 @@ async def purge_room(self, room_id: str, force: bool = False) -> None:
room_id: room to be purged
force: set true to skip checking for joined users.
"""
with await self.pagination_lock.write(room_id):
async with self.pagination_lock.write(room_id):
# first check that we have no users in this room
if not force:
joined = await self.store.is_host_joined(room_id, self._server_name)
Expand Down Expand Up @@ -447,7 +447,7 @@ async def get_messages(

room_token = from_token.room_key

with await self.pagination_lock.read(room_id):
async with self.pagination_lock.read(room_id):
(
membership,
member_event_id,
Expand Down Expand Up @@ -612,7 +612,7 @@ async def _shutdown_and_purge_room(

self._purges_in_progress_by_room.add(room_id)
try:
with await self.pagination_lock.write(room_id):
async with self.pagination_lock.write(room_id):
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
self._delete_by_id[
delete_id
Expand Down
31 changes: 15 additions & 16 deletions synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import inspect
import itertools
import logging
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Collection,
Expand All @@ -40,7 +41,7 @@
)

import attr
from typing_extensions import ContextManager
from typing_extensions import AsyncContextManager

from twisted.internet import defer
from twisted.internet.defer import CancelledError
Expand Down Expand Up @@ -483,7 +484,7 @@ class ReadWriteLock:

Example:

with await read_write_lock.read("test_key"):
async with read_write_lock.read("test_key"):
# do some work
"""

Expand All @@ -506,22 +507,21 @@ def __init__(self) -> None:
# Latest writer queued
self.key_to_current_writer: Dict[str, defer.Deferred] = {}

async def read(self, key: str) -> ContextManager:
def read(self, key: str) -> AsyncContextManager:
new_defer: "defer.Deferred[None]" = defer.Deferred()

curr_readers = self.key_to_current_readers.setdefault(key, set())
curr_writer = self.key_to_current_writer.get(key, None)

curr_readers.add(new_defer)

# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
if curr_writer:
await make_deferred_yieldable(curr_writer)

@contextmanager
def _ctx_manager() -> Iterator[None]:
@asynccontextmanager
async def _ctx_manager() -> AsyncIterator[None]:
try:
# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
if curr_writer:
await make_deferred_yieldable(curr_writer)
yield
finally:
with PreserveLoggingContext():
Expand All @@ -530,7 +530,7 @@ def _ctx_manager() -> Iterator[None]:

return _ctx_manager()

async def write(self, key: str) -> ContextManager:
def write(self, key: str) -> AsyncContextManager:
new_defer: "defer.Deferred[None]" = defer.Deferred()

curr_readers = self.key_to_current_readers.get(key, set())
Expand All @@ -546,11 +546,10 @@ async def write(self, key: str) -> ContextManager:
curr_readers.clear()
self.key_to_current_writer[key] = new_defer

await make_deferred_yieldable(defer.gatherResults(to_wait_on))

@contextmanager
def _ctx_manager() -> Iterator[None]:
@asynccontextmanager
async def _ctx_manager() -> AsyncIterator[None]:
try:
await make_deferred_yieldable(defer.gatherResults(to_wait_on))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we should, for strict correctness, do all the to_wait_on calculation etc inside the ctx_manager. Otherwise it's possible for someone to do:

ctx = self.pagination_lock.write(room_id)
await something_asynchronous()
with await ctx:
    ...

... which is a stupid thing to do, but still.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm all in favour of removing footguns. I'lk move the setup of read() and write() into the context manager.

yield
finally:
with PreserveLoggingContext():
Expand Down
95 changes: 55 additions & 40 deletions tests/util/test_rwlock.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import AsyncContextManager, Callable, Tuple

from twisted.internet import defer
from twisted.internet.defer import Deferred

Expand All @@ -32,58 +34,71 @@ def _assert_called_before_not_after(self, lst, first_false):

def test_rwlock(self):
rwlock = ReadWriteLock()
key = "key"

def start_reader_or_writer(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a bit confusing that this is slightly different to the class-level _start_reader_or_writer. Can we combine them, even if it means returning three deferreds rather than two?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, let's combine them.

read_or_write: Callable[[str], AsyncContextManager]
) -> Tuple["Deferred[None]", "Deferred[None]"]:
acquired_d: "Deferred[None]" = Deferred()
release_d: "Deferred[None]" = Deferred()

async def action():
async with read_or_write(key):
acquired_d.callback(None)
await release_d

key = object()
defer.ensureDeferred(action())
return acquired_d, release_d

ds = [
rwlock.read(key), # 0
rwlock.read(key), # 1
rwlock.write(key), # 2
rwlock.write(key), # 3
rwlock.read(key), # 4
rwlock.read(key), # 5
rwlock.write(key), # 6
start_reader_or_writer(rwlock.read), # 0
start_reader_or_writer(rwlock.read), # 1
start_reader_or_writer(rwlock.write), # 2
start_reader_or_writer(rwlock.write), # 3
start_reader_or_writer(rwlock.read), # 4
start_reader_or_writer(rwlock.read), # 5
start_reader_or_writer(rwlock.write), # 6
]
ds = [defer.ensureDeferred(d) for d in ds]
# `Deferred`s that resolve when each reader or writer acquires the lock.
acquired_ds = [acquired_d for acquired_d, _release_d in ds]
# `Deferred`s that will trigger the release of locks when resolved.
release_ds = [release_d for _acquired_d, release_d in ds]

self._assert_called_before_not_after(ds, 2)
self._assert_called_before_not_after(acquired_ds, 2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's hard to follow what _assert_called_before_not_after is doing. Assuming you've figured it out, could you add a comment/docstring/type annotation/etc or two?

Actually generally this test could do with a few comments, for example:

Suggested change
self._assert_called_before_not_after(acquired_ds, 2)
# we should have acquired the locks for the first two readers, but nothing else.
self._assert_called_before_not_after(acquired_ds, 2)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found this test case hard to understand too. I'll try to make it more readable.


with ds[0].result:
self._assert_called_before_not_after(ds, 2)
self._assert_called_before_not_after(ds, 2)
self._assert_called_before_not_after(acquired_ds, 2)
release_ds[0].callback(None)
self._assert_called_before_not_after(acquired_ds, 2)

with ds[1].result:
self._assert_called_before_not_after(ds, 2)
self._assert_called_before_not_after(ds, 3)
self._assert_called_before_not_after(acquired_ds, 2)
release_ds[1].callback(None)
self._assert_called_before_not_after(acquired_ds, 3)

with ds[2].result:
self._assert_called_before_not_after(ds, 3)
self._assert_called_before_not_after(ds, 4)
self._assert_called_before_not_after(acquired_ds, 3)
release_ds[2].callback(None)
self._assert_called_before_not_after(acquired_ds, 4)

with ds[3].result:
self._assert_called_before_not_after(ds, 4)
self._assert_called_before_not_after(ds, 6)
self._assert_called_before_not_after(acquired_ds, 4)
release_ds[3].callback(None)
self._assert_called_before_not_after(acquired_ds, 6)

with ds[5].result:
self._assert_called_before_not_after(ds, 6)
self._assert_called_before_not_after(ds, 6)
self._assert_called_before_not_after(acquired_ds, 6)
release_ds[5].callback(None)
self._assert_called_before_not_after(acquired_ds, 6)

with ds[4].result:
self._assert_called_before_not_after(ds, 6)
self._assert_called_before_not_after(ds, 7)
self._assert_called_before_not_after(acquired_ds, 6)
release_ds[4].callback(None)
self._assert_called_before_not_after(acquired_ds, 7)

with ds[6].result:
pass
release_ds[6].callback(None)

d = defer.ensureDeferred(rwlock.write(key))
self.assertTrue(d.called)
with d.result:
pass
acquired_d, release_d = start_reader_or_writer(rwlock.write)
self.assertTrue(acquired_d.called)
release_d.callback(None)

d = defer.ensureDeferred(rwlock.read(key))
self.assertTrue(d.called)
with d.result:
pass
acquired_d, release_d = start_reader_or_writer(rwlock.read)
self.assertTrue(acquired_d.called)
release_d.callback(None)

def test_lock_handoff_to_nonblocking_writer(self):
"""Test a writer handing the lock to another writer that completes instantly."""
Expand All @@ -93,11 +108,11 @@ def test_lock_handoff_to_nonblocking_writer(self):
unblock: "Deferred[None]" = Deferred()

async def blocking_write():
with await rwlock.write(key):
async with rwlock.write(key):
await unblock

async def nonblocking_write():
with await rwlock.write(key):
async with rwlock.write(key):
pass

d1 = defer.ensureDeferred(blocking_write())
Expand Down