From b6373b07dc4fba7a62d1b425a8eb9582ae9775e8 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Tue, 17 May 2022 17:14:39 -0700 Subject: [PATCH 1/5] refactor resolve_state_groups_for_events to use _get_state_group_for_events --- synapse/state/__init__.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 54e41d537584..b1a89ca98dd4 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -288,7 +288,6 @@ async def compute_event_context( # # first of all, figure out the state before the event # - if old_state: # if we're given the state before the event, then we use that state_ids_before_event: StateMap[str] = { @@ -419,33 +418,38 @@ async def resolve_state_groups_for_events( """ logger.debug("resolve_state_groups event_ids %s", event_ids) - # map from state group id to the state in that state group (where - # 'state' is a map from state key to event id) - # dict[int, dict[(str, str), str]] - state_groups_ids = await self.state_store.get_state_groups_ids( - room_id, event_ids - ) - - if len(state_groups_ids) == 0: - return _StateCacheEntry(state={}, state_group=None) - elif len(state_groups_ids) == 1: - name, state_list = list(state_groups_ids.items()).pop() + state_groups = await self.state_store._get_state_group_for_events(event_ids) - prev_group, delta_ids = await self.state_store.get_state_group_delta(name) + state_group_ids = state_groups.values() + # check if each event has same state group id, if so there's no state to resolve + state_group_ids_set = set(state_group_ids) + if len(state_group_ids_set) == 1: + (state_group_id,) = state_group_ids_set + state = await self.state_store._get_state_for_groups(state_group_ids_set) + prev_group, delta_ids = await self.state_store.get_state_group_delta( + state_group_id + ) return _StateCacheEntry( - state=state_list, - state_group=name, + state=state[state_group_id], + state_group=state_group_id, prev_group=prev_group, delta_ids=delta_ids, ) + elif len(state_group_ids_set) == 0: + return _StateCacheEntry(state={}, state_group=None) room_version = await self.store.get_room_version_id(room_id) + state_group_ids_list = list(state_group_ids) + state_to_resolve = await self.state_store._get_state_for_groups( + state_group_ids_list + ) + result = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, - state_groups_ids, + state_to_resolve, None, state_res_store=StateResolutionStore(self.store), ) From c3388e6f1fd47c1fa7a097a556493222bf913d2b Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Tue, 17 May 2022 17:15:16 -0700 Subject: [PATCH 2/5] add _get_state_group_for_events and _get_state_for_groups to dummy store --- tests/test_state.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_state.py b/tests/test_state.py index 651ec1c7d4bd..26f8ad526a2e 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -129,6 +129,19 @@ def register_event_id_state_group(self, event_id, state_group): async def get_room_version_id(self, room_id): return RoomVersions.V1.identifier + async def _get_state_group_for_events(self, event_ids): + res = {} + for event in event_ids: + res[event] = self._event_to_state_group[event] + return res + + async def _get_state_for_groups(self, groups): + res = {} + for group in groups: + state = self._group_to_state[group] + res[group] = state + return res + class DictObj(dict): def __init__(self, **kwargs): From 265f766eff32b24aeb894dea5ee507d1c819efa1 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Tue, 17 May 2022 17:24:19 -0700 Subject: [PATCH 3/5] newsfragment --- changelog.d/12775.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/12775.misc diff --git a/changelog.d/12775.misc b/changelog.d/12775.misc new file mode 100644 index 000000000000..5648c75ec766 --- /dev/null +++ b/changelog.d/12775.misc @@ -0,0 +1 @@ +Refactor resolve_state_groups_for_events to not pull out full state when no state resolution happens. \ No newline at end of file From 0fb5ee16348f9edfa4dcb12392770e369e53c707 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 18 May 2022 09:21:13 -0700 Subject: [PATCH 4/5] make state_store._get_state_group_for_events and _get_state_for_groups public --- synapse/state/__init__.py | 6 +++--- synapse/storage/databases/state/store.py | 2 +- synapse/storage/state.py | 12 ++++++------ tests/test_state.py | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index b1a89ca98dd4..a308f1b96b78 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -418,7 +418,7 @@ async def resolve_state_groups_for_events( """ logger.debug("resolve_state_groups event_ids %s", event_ids) - state_groups = await self.state_store._get_state_group_for_events(event_ids) + state_groups = await self.state_store.get_state_group_for_events(event_ids) state_group_ids = state_groups.values() @@ -426,7 +426,7 @@ async def resolve_state_groups_for_events( state_group_ids_set = set(state_group_ids) if len(state_group_ids_set) == 1: (state_group_id,) = state_group_ids_set - state = await self.state_store._get_state_for_groups(state_group_ids_set) + state = await self.state_store.get_state_for_groups(state_group_ids_set) prev_group, delta_ids = await self.state_store.get_state_group_delta( state_group_id ) @@ -442,7 +442,7 @@ async def resolve_state_groups_for_events( room_version = await self.store.get_room_version_id(room_id) state_group_ids_list = list(state_group_ids) - state_to_resolve = await self.state_store._get_state_for_groups( + state_to_resolve = await self.state_store.get_state_for_groups( state_group_ids_list ) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 7614d76ac646..609a2b88bfbf 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -189,7 +189,7 @@ def _get_state_for_group_using_cache( group: int, state_filter: StateFilter, ) -> Tuple[MutableStateMap[str], bool]: - """Checks if group is in cache. See `_get_state_for_groups` + """Checks if group is in cache. See `get_state_for_groups` Args: cache: the state group cache to use diff --git a/synapse/storage/state.py b/synapse/storage/state.py index d4a1bd4f9d7d..a6c60de50434 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -586,7 +586,7 @@ async def get_state_groups_ids( if not event_ids: return {} - event_to_groups = await self._get_state_group_for_events(event_ids) + event_to_groups = await self.get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups(groups) @@ -602,7 +602,7 @@ async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: Returns: Resolves to a map of (type, state_key) -> event_id """ - group_to_state = await self._get_state_for_groups((state_group,)) + group_to_state = await self.get_state_for_groups((state_group,)) return group_to_state[state_group] @@ -675,7 +675,7 @@ async def get_state_for_events( RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - event_to_groups = await self._get_state_group_for_events(event_ids) + event_to_groups = await self.get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( @@ -716,7 +716,7 @@ async def get_state_ids_for_events( RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - event_to_groups = await self._get_state_group_for_events(event_ids) + event_to_groups = await self.get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( @@ -774,7 +774,7 @@ async def get_state_ids_for_event( ) return state_map[event_id] - def _get_state_for_groups( + def get_state_for_groups( self, groups: Iterable[int], state_filter: Optional[StateFilter] = None ) -> Awaitable[Dict[int, MutableStateMap[str]]]: """Gets the state at each of a list of state groups, optionally @@ -792,7 +792,7 @@ def _get_state_for_groups( groups, state_filter or StateFilter.all() ) - async def _get_state_group_for_events( + async def get_state_group_for_events( self, event_ids: Collection[str], await_full_state: bool = True, diff --git a/tests/test_state.py b/tests/test_state.py index 26f8ad526a2e..74a8ce6096b9 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -129,13 +129,13 @@ def register_event_id_state_group(self, event_id, state_group): async def get_room_version_id(self, room_id): return RoomVersions.V1.identifier - async def _get_state_group_for_events(self, event_ids): + async def get_state_group_for_events(self, event_ids): res = {} for event in event_ids: res[event] = self._event_to_state_group[event] return res - async def _get_state_for_groups(self, groups): + async def get_state_for_groups(self, groups): res = {} for group in groups: state = self._group_to_state[group] From acbb4dc1cc04e97235001bdd584c9378cda49ee7 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 18 May 2022 09:23:40 -0700 Subject: [PATCH 5/5] fix newsfragment + pass set instead of list --- changelog.d/12775.misc | 2 +- synapse/state/__init__.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/changelog.d/12775.misc b/changelog.d/12775.misc index 5648c75ec766..eac326cde3a7 100644 --- a/changelog.d/12775.misc +++ b/changelog.d/12775.misc @@ -1 +1 @@ -Refactor resolve_state_groups_for_events to not pull out full state when no state resolution happens. \ No newline at end of file +Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens. \ No newline at end of file diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a308f1b96b78..ec11f46c4c5a 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -441,9 +441,8 @@ async def resolve_state_groups_for_events( room_version = await self.store.get_room_version_id(room_id) - state_group_ids_list = list(state_group_ids) state_to_resolve = await self.state_store.get_state_for_groups( - state_group_ids_list + state_group_ids_set ) result = await self._state_resolution_handler.resolve_state_groups(