diff --git a/changelog.d/13787.misc b/changelog.d/13787.misc new file mode 100644 index 000000000000..a9b93717f059 --- /dev/null +++ b/changelog.d/13787.misc @@ -0,0 +1 @@ +Optimise get rooms for user calls. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9c2c3a0e687c..bffeb08e49dd 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -268,11 +268,9 @@ async def get_user_ids_changed( possibly_left = possibly_changed | possibly_left # Double check if we still share rooms with the given user. - users_rooms = await self.store.get_rooms_for_users_with_stream_ordering( - possibly_left - ) + users_rooms = await self.store.get_rooms_for_users(possibly_left) for changed_user_id, entries in users_rooms.items(): - if any(e.room_id in room_ids for e in entries): + if any(rid in room_ids for rid in entries): possibly_left.discard(changed_user_id) else: possibly_joined.discard(changed_user_id) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b2a48c6d1077..14587d8b996d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1507,16 +1507,14 @@ async def _generate_sync_entry_for_device_list( since_token.device_list_key ) if changed_users is not None: - result = await self.store.get_rooms_for_users_with_stream_ordering( - changed_users - ) + result = await self.store.get_rooms_for_users(changed_users) for changed_user_id, entries in result.items(): # Check if the changed user shares any rooms with the user, # or if the changed user is the syncing user (as we always # want to include device list updates of their own devices). if user_id == changed_user_id or any( - e.room_id in joined_rooms for e in entries + rid in joined_rooms for rid in entries ): users_that_have_changed.add(changed_user_id) else: @@ -1550,13 +1548,9 @@ async def _generate_sync_entry_for_device_list( newly_left_users.update(left_users) # Remove any users that we still share a room with. - left_users_rooms = ( - await self.store.get_rooms_for_users_with_stream_ordering( - newly_left_users - ) - ) + left_users_rooms = await self.store.get_rooms_for_users(newly_left_users) for user_id, entries in left_users_rooms.items(): - if any(e.room_id in joined_rooms for e in entries): + if any(rid in joined_rooms for rid in entries): newly_left_users.discard(user_id) return DeviceListUpdates( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 5c88992a73dc..3787ede30691 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -90,6 +90,10 @@ def _invalidate_state_caches( self._attempt_to_invalidate_cache( "get_user_in_room_with_profile", (room_id, user_id) ) + self._attempt_to_invalidate_cache( + "get_rooms_for_user_with_stream_ordering", (user_id,) + ) + self._attempt_to_invalidate_cache("get_rooms_for_user", (user_id,)) # Purge other caches based on room state. self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 5cb1b56dfd47..2f562c6066ba 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -205,6 +205,7 @@ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: self.get_rooms_for_user_with_stream_ordering.invalidate( (data.state_key,) ) + self.get_rooms_for_user.invalidate((data.state_key,)) else: raise Exception("Unknown events stream row type %s" % (row.type,)) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 505c8d361165..96f28ab9ff63 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,7 +15,6 @@ import logging from typing import ( TYPE_CHECKING, - Callable, Collection, Dict, FrozenSet, @@ -690,117 +689,109 @@ def _get_rooms_for_user_with_stream_ordering_txn( for room_id, instance, stream_id in txn ) - @cachedList( - cached_method_name="get_rooms_for_user_with_stream_ordering", - list_name="user_ids", - ) - async def get_rooms_for_users_with_stream_ordering( + async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: - """A batched version of `get_rooms_for_user_with_stream_ordering`. - - Returns: - Map from user_id to set of rooms that is currently in. + ) -> Set[str]: + """Given a list of users return the set that the server still share a + room with. """ + + if not user_ids: + return set() + return await self.db_pool.runInteraction( - "get_rooms_for_users_with_stream_ordering", - self._get_rooms_for_users_with_stream_ordering_txn, + "get_users_server_still_shares_room_with", + self.get_users_server_still_shares_room_with_txn, user_ids, ) - def _get_rooms_for_users_with_stream_ordering_txn( - self, txn: LoggingTransaction, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: + def get_users_server_still_shares_room_with_txn( + self, + txn: LoggingTransaction, + user_ids: Collection[str], + ) -> Set[str]: + if not user_ids: + return set() + + sql = """ + SELECT state_key FROM current_state_events + WHERE + type = 'm.room.member' + AND membership = 'join' + AND %s + GROUP BY state_key + """ clause, args = make_in_list_sql_clause( - self.database_engine, - "c.state_key", - user_ids, + self.database_engine, "state_key", user_ids ) - if self._current_state_events_membership_up_to_date: - sql = f""" - SELECT c.state_key, room_id, e.instance_name, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND c.membership = ? - AND {clause} - """ - else: - sql = f""" - SELECT c.state_key, room_id, e.instance_name, e.stream_ordering - FROM current_state_events AS c - INNER JOIN room_memberships AS m USING (room_id, event_id) - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND m.membership = ? - AND {clause} - """ + txn.execute(sql % (clause,), args) - txn.execute(sql, [Membership.JOIN] + args) + return {row[0] for row in txn} - result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = { - user_id: set() for user_id in user_ids - } - for user_id, room_id, instance, stream_id in txn: - result[user_id].add( - GetRoomsForUserWithStreamOrdering( - room_id, PersistedEventPosition(instance, stream_id) - ) - ) - - return {user_id: frozenset(v) for user_id, v in result.items()} + @cached(max_entries=500000, iterable=True) + async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: + """Returns a set of room_ids the user is currently joined to. - async def get_users_server_still_shares_room_with( - self, user_ids: Collection[str] - ) -> Set[str]: - """Given a list of users return the set that the server still share a - room with. + If a remote user only returns rooms this server is currently + participating in. """ + rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate( + (user_id,), + None, + update_metrics=False, + ) + if rooms: + return frozenset(r.room_id for r in rooms) - if not user_ids: - return set() - - def _get_users_server_still_shares_room_with_txn( - txn: LoggingTransaction, - ) -> Set[str]: - sql = """ - SELECT state_key FROM current_state_events - WHERE - type = 'm.room.member' - AND membership = 'join' - AND %s - GROUP BY state_key - """ + room_ids = await self.db_pool.simple_select_onecol( + table="current_state_events", + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.JOIN, + "state_key": user_id, + }, + retcol="room_id", + desc="get_rooms_for_user", + ) - clause, args = make_in_list_sql_clause( - self.database_engine, "state_key", user_ids - ) + return frozenset(room_ids) - txn.execute(sql % (clause,), args) + @cachedList( + cached_method_name="get_rooms_for_user", + list_name="user_ids", + ) + async def get_rooms_for_users( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[str]]: + """A batched version of `get_rooms_for_user`. - return {row[0] for row in txn} + Returns: + Map from user_id to set of rooms that is currently in. + """ - return await self.db_pool.runInteraction( - "get_users_server_still_shares_room_with", - _get_users_server_still_shares_room_with_txn, + rows = await self.db_pool.simple_select_many_batch( + table="current_state_events", + column="state_key", + iterable=user_ids, + retcols=( + "state_key", + "room_id", + ), + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.JOIN, + }, + desc="get_rooms_for_users", ) - async def get_rooms_for_user( - self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None - ) -> FrozenSet[str]: - """Returns a set of room_ids the user is currently joined to. + user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids} - If a remote user only returns rooms this server is currently - participating in. - """ - rooms = await self.get_rooms_for_user_with_stream_ordering( - user_id, on_invalidate=on_invalidate - ) - return frozenset(r.room_id for r in rooms) + for row in rows: + user_rooms[row["state_key"]].add(row["room_id"]) + + return {key: frozenset(rooms) for key, rooms in user_rooms.items()} @cached(max_entries=10000) async def does_pair_of_users_share_a_room( diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index e3f38fbcc5ce..ab5c101eb708 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -159,6 +159,7 @@ def test_unknown_room_version(self): # Blow away caches (supported room versions can only change due to a restart). self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() + self.store.get_rooms_for_user.invalidate_all() self.get_success(self.store._get_event_cache.clear()) self.store._event_ref.clear()