Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Faster room joins: avoid blocking when pulling events with missing pr…
Browse files Browse the repository at this point in the history
…evs (#13355)

Avoid blocking on full state in `_resolve_state_at_missing_prevs` and
return a new flag indicating whether the resolved state is partial.
Thread that flag around so that it makes it into the event context.

Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
  • Loading branch information
squahtx and richvdh committed Jul 26, 2022
1 parent 8b60329 commit 335ebb2
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 33 deletions.
1 change: 1 addition & 0 deletions changelog.d/13355.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Faster room joins: avoid blocking when pulling events with partially missing prev events.
116 changes: 92 additions & 24 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,19 @@ async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
)

try:
await self._process_received_pdu(origin, pdu, state_ids=None)
await self._process_received_pdu(
origin, pdu, state_ids=None, partial_state=None
)
except PartialStateConflictError:
# The room was un-partial stated while we were processing the PDU.
# Try once more, with full state this time.
logger.info(
"Room %s was un-partial stated while processing the PDU, trying again.",
room_id,
)
await self._process_received_pdu(origin, pdu, state_ids=None)
await self._process_received_pdu(
origin, pdu, state_ids=None, partial_state=None
)

async def on_send_membership_event(
self, origin: str, event: EventBase
Expand Down Expand Up @@ -534,14 +538,36 @@ async def update_state_for_partial_state_event(
#
# This is the same operation as we do when we receive a regular event
# over federation.
state_ids = await self._resolve_state_at_missing_prevs(destination, event)

# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
destination, event
)
if context.partial_state:

# There are three possible cases for (state_ids, partial_state):
# * `state_ids` and `partial_state` are both `None` if we had all the
# prev_events. The prev_events may or may not have partial state and
# we won't know until we compute the event context.
# * `state_ids` is not `None` and `partial_state` is `False` if we were
# missing some prev_events (but we have full state for any we did
# have). We calculated the full state after the prev_events.
# * `state_ids` is not `None` and `partial_state` is `True` if we were
# missing some, but not all, prev_events. At least one of the
# prev_events we did have had partial state, so we calculated a partial
# state after the prev_events.

context = None
if state_ids is not None and partial_state:
# the state after the prev events is still partial. We can't de-partial
# state the event, so don't bother building the event context.
pass
else:
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
partial_state=partial_state,
)

if context is None or context.partial_state:
# this can happen if some or all of the event's prev_events still have
# partial state - ie, an event has an earlier stream_ordering than one
# or more of its prev_events, so we de-partial-state it before its
Expand Down Expand Up @@ -806,14 +832,39 @@ async def _process_pulled_event(
return

try:
state_ids = await self._resolve_state_at_missing_prevs(origin, event)
# TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
# not return partial state
# https://github.com/matrix-org/synapse/issues/13002
try:
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
origin, event
)
await self._process_received_pdu(
origin,
event,
state_ids=state_ids,
partial_state=partial_state,
backfilled=backfilled,
)
except PartialStateConflictError:
# The room was un-partial stated while we were processing the event.
# Try once more, with full state this time.
state_ids, partial_state = await self._resolve_state_at_missing_prevs(
origin, event
)

await self._process_received_pdu(
origin, event, state_ids=state_ids, backfilled=backfilled
)
# We ought to have full state now, barring some unlikely race where we left and
# rejoned the room in the background.
if state_ids is not None and partial_state:
raise AssertionError(
f"Event {event.event_id} still has a partial resolved state "
f"after room {event.room_id} was un-partial stated"
)

await self._process_received_pdu(
origin,
event,
state_ids=state_ids,
partial_state=partial_state,
backfilled=backfilled,
)
except FederationError as e:
if e.code == 403:
logger.warning("Pulled event %s failed history check.", event_id)
Expand All @@ -822,7 +873,7 @@ async def _process_pulled_event(

async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase
) -> Optional[StateMap[str]]:
) -> Tuple[Optional[StateMap[str]], Optional[bool]]:
"""Calculate the state at an event with missing prev_events.
This is used when we have pulled a batch of events from a remote server, and
Expand All @@ -849,8 +900,10 @@ async def _resolve_state_at_missing_prevs(
event: an event to check for missing prevs.
Returns:
if we already had all the prev events, `None`. Otherwise, returns
the event ids of the state at `event`.
if we already had all the prev events, `None, None`. Otherwise, returns a
tuple containing:
* the event ids of the state at `event`.
* a boolean indicating whether the state may be partial.
Raises:
FederationError if we fail to get the state from the remote server after any
Expand All @@ -864,7 +917,7 @@ async def _resolve_state_at_missing_prevs(
missing_prevs = prevs - seen

if not missing_prevs:
return None
return None, None

logger.info(
"Event %s is missing prev_events %s: calculating state for a "
Expand All @@ -876,9 +929,15 @@ async def _resolve_state_at_missing_prevs(
# resolve them to find the correct state at the current event.

try:
# Determine whether we may be about to retrieve partial state
# Events may be un-partial stated right after we compute the partial state
# flag, but that's okay, as long as the flag errs on the conservative side.
partial_state_flags = await self._store.get_partial_state_events(seen)
partial_state = any(partial_state_flags.values())

# Get the state of the events we know about
ours = await self._state_storage_controller.get_state_groups_ids(
room_id, seen
room_id, seen, await_full_state=False
)

# state_maps is a list of mappings from (type, state_key) to event_id
Expand Down Expand Up @@ -924,7 +983,7 @@ async def _resolve_state_at_missing_prevs(
"We can't get valid state history.",
affected=event_id,
)
return state_map
return state_map, partial_state

async def _get_state_ids_after_missing_prev_event(
self,
Expand Down Expand Up @@ -1094,6 +1153,7 @@ async def _process_received_pdu(
origin: str,
event: EventBase,
state_ids: Optional[StateMap[str]],
partial_state: Optional[bool],
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
Expand All @@ -1117,21 +1177,29 @@ async def _process_received_pdu(
state_ids: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event. Must not be partial state.
event
partial_state:
`True` if `state_ids` is partial and omits non-critical membership
events.
`False` if `state_ids` is the full state.
`None` if `state_ids` is not provided. In this case, the flag will be
calculated based on `event`'s prev events.
backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.)
PartialStateConflictError: if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should retry
exactly once in this case. Will never be raised if `state_ids` is provided.
exactly once in this case.
"""
logger.debug("Processing event: %s", event)
assert not event.internal_metadata.outlier

context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
partial_state=partial_state,
)
try:
await self._check_event_auth(origin, event, context)
Expand Down
4 changes: 4 additions & 0 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,10 @@ async def create_new_client_event(
context = await self.state.compute_event_context(
event,
state_ids_before_event=state_map_for_event,
# TODO(faster_joins): check how MSC2716 works and whether we can have
# partial state here
# https://github.com/matrix-org/synapse/issues/13003
partial_state=False,
)
else:
context = await self.state.compute_event_context(event)
Expand Down
18 changes: 12 additions & 6 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ async def compute_event_context(
self,
event: EventBase,
state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: bool = False,
partial_state: Optional[bool] = None,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.
Expand All @@ -270,8 +270,12 @@ async def compute_event_context(
it can't be calculated from existing events. This is normally
only specified when receiving an event from federation where we
don't have the prev events, e.g. when backfilling.
partial_state: True if `state_ids_before_event` is partial and omits
non-critical membership events
partial_state:
`True` if `state_ids_before_event` is partial and omits non-critical
membership events.
`False` if `state_ids_before_event` is the full state.
`None` when `state_ids_before_event` is not provided. In this case, the
flag will be calculated based on `event`'s prev events.
Returns:
The event context.
"""
Expand All @@ -298,12 +302,14 @@ async def compute_event_context(
)
)

# the partial_state flag must be provided
assert partial_state is not None
else:
# otherwise, we'll need to resolve the state across the prev_events.

# partial_state should not be set explicitly in this case:
# we work it out dynamically
assert not partial_state
assert partial_state is None

# if any of the prev-events have partial state, so do we.
# (This is slightly racy - the prev-events might get fixed up before we use
Expand All @@ -313,13 +319,13 @@ async def compute_event_context(
incomplete_prev_events = await self.store.get_partial_state_events(
prev_event_ids
)
if any(incomplete_prev_events.values()):
partial_state = any(incomplete_prev_events.values())
if partial_state:
logger.debug(
"New/incoming event %s refers to prev_events %s with partial state",
event.event_id,
[k for (k, v) in incomplete_prev_events.items() if v],
)
partial_state = True

logger.debug("calling resolve_state_groups from compute_event_context")
# we've already taken into account partial state, so no need to wait for
Expand Down
8 changes: 6 additions & 2 deletions synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,15 @@ async def get_state_group_delta(
return state_group_delta.prev_group, state_group_delta.delta_ids

async def get_state_groups_ids(
self, _room_id: str, event_ids: Collection[str]
self, _room_id: str, event_ids: Collection[str], await_full_state: bool = True
) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id: id of the room for these events
event_ids: ids of the events
await_full_state: if `True`, will block if we do not yet have complete
state at these events.
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
Expand All @@ -100,7 +102,9 @@ async def get_state_groups_ids(
if not event_ids:
return {}

event_to_groups = await self.get_state_group_for_events(event_ids)
event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
Expand Down
1 change: 1 addition & 0 deletions tests/handlers/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def test_backfill_with_many_backward_extremities(self) -> None:
state_ids={
(e.type, e.state_key): e.event_id for e in current_state
},
partial_state=False,
)
)

Expand Down
7 changes: 6 additions & 1 deletion tests/storage/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ def prepare(self, reactor, clock, homeserver):
def persist_event(self, event, state=None):
"""Persist the event, with optional state"""
context = self.get_success(
self.state.compute_event_context(event, state_ids_before_event=state)
self.state.compute_event_context(
event,
state_ids_before_event=state,
partial_state=None if state is None else False,
)
)
self.get_success(self._persistence.persist_event(event, context))

Expand Down Expand Up @@ -148,6 +152,7 @@ def test_do_not_prune_gap_if_state_different(self):
self.state.compute_event_context(
remote_event_2,
state_ids_before_event=state_before_gap,
partial_state=False,
)
)

Expand Down
2 changes: 2 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def test_annotate_with_old_message(self):
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
partial_state=False,
)
)

Expand Down Expand Up @@ -492,6 +493,7 @@ def test_annotate_with_old_state(self):
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
partial_state=False,
)
)

Expand Down

0 comments on commit 335ebb2

Please sign in to comment.