Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor resolve_state_groups_for_events to not pull out full state when no state resolution happens. #12775

Merged
merged 5 commits into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12775.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens.
35 changes: 19 additions & 16 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,6 @@ async def compute_event_context(
#
# first of all, figure out the state before the event
#

if old_state:
# if we're given the state before the event, then we use that
state_ids_before_event: StateMap[str] = {
Expand Down Expand Up @@ -419,33 +418,37 @@ async def resolve_state_groups_for_events(
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)

# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
state_groups_ids = await self.state_store.get_state_groups_ids(
room_id, event_ids
)

if len(state_groups_ids) == 0:
return _StateCacheEntry(state={}, state_group=None)
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
state_groups = await self.state_store.get_state_group_for_events(event_ids)

prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
state_group_ids = state_groups.values()

# check if each event has same state group id, if so there's no state to resolve
state_group_ids_set = set(state_group_ids)
if len(state_group_ids_set) == 1:
(state_group_id,) = state_group_ids_set
state = await self.state_store.get_state_for_groups(state_group_ids_set)
prev_group, delta_ids = await self.state_store.get_state_group_delta(
state_group_id
)
return _StateCacheEntry(
state=state_list,
state_group=name,
state=state[state_group_id],
state_group=state_group_id,
prev_group=prev_group,
delta_ids=delta_ids,
)
elif len(state_group_ids_set) == 0:
return _StateCacheEntry(state={}, state_group=None)

room_version = await self.store.get_room_version_id(room_id)

state_to_resolve = await self.state_store.get_state_for_groups(
state_group_ids_set
)

result = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
state_groups_ids,
state_to_resolve,
None,
state_res_store=StateResolutionStore(self.store),
)
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _get_state_for_group_using_cache(
group: int,
state_filter: StateFilter,
) -> Tuple[MutableStateMap[str], bool]:
"""Checks if group is in cache. See `_get_state_for_groups`
"""Checks if group is in cache. See `get_state_for_groups`

Args:
cache: the state group cache to use
Expand Down
12 changes: 6 additions & 6 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ async def get_state_groups_ids(
if not event_ids:
return {}

event_to_groups = await self._get_state_group_for_events(event_ids)
event_to_groups = await self.get_state_group_for_events(event_ids)

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
Expand All @@ -602,7 +602,7 @@ async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
Returns:
Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = await self._get_state_for_groups((state_group,))
group_to_state = await self.get_state_for_groups((state_group,))

return group_to_state[state_group]

Expand Down Expand Up @@ -675,7 +675,7 @@ async def get_state_for_events(
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self._get_state_group_for_events(event_ids)
event_to_groups = await self.get_state_group_for_events(event_ids)

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
Expand Down Expand Up @@ -716,7 +716,7 @@ async def get_state_ids_for_events(
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self._get_state_group_for_events(event_ids)
event_to_groups = await self.get_state_group_for_events(event_ids)

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
Expand Down Expand Up @@ -774,7 +774,7 @@ async def get_state_ids_for_event(
)
return state_map[event_id]

def _get_state_for_groups(
def get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
Expand All @@ -792,7 +792,7 @@ def _get_state_for_groups(
groups, state_filter or StateFilter.all()
)

async def _get_state_group_for_events(
async def get_state_group_for_events(
self,
event_ids: Collection[str],
await_full_state: bool = True,
Expand Down
13 changes: 13 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,19 @@ def register_event_id_state_group(self, event_id, state_group):
async def get_room_version_id(self, room_id):
return RoomVersions.V1.identifier

async def get_state_group_for_events(self, event_ids):
res = {}
for event in event_ids:
res[event] = self._event_to_state_group[event]
return res

async def get_state_for_groups(self, groups):
res = {}
for group in groups:
state = self._group_to_state[group]
res[group] = state
return res


class DictObj(dict):
def __init__(self, **kwargs):
Expand Down