Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Rip out auth-event reconciliation code #12943

Merged
merged 14 commits into from
Jul 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12943.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove code which incorrectly attempted to reconcile state with remote servers when processing incoming events.
277 changes: 82 additions & 195 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import itertools
import logging
from http import HTTPStatus
Expand Down Expand Up @@ -347,7 +348,7 @@ async def on_send_membership_event(
event.internal_metadata.send_on_behalf_of = origin

context = await self._state_handler.compute_event_context(event)
context = await self._check_event_auth(origin, event, context)
await self._check_event_auth(origin, event, context)
if context.rejected:
raise SynapseError(
403, f"{event.membership} event was rejected", Codes.FORBIDDEN
Expand Down Expand Up @@ -485,7 +486,7 @@ async def process_remote_join(
partial_state=partial_state,
)

context = await self._check_event_auth(origin, event, context)
await self._check_event_auth(origin, event, context)
if context.rejected:
raise SynapseError(400, "Join event was rejected")

Expand Down Expand Up @@ -1116,11 +1117,7 @@ async def _process_received_pdu(
state_ids_before_event=state_ids,
)
try:
context = await self._check_event_auth(
origin,
event,
context,
)
await self._check_event_auth(origin, event, context)
except AuthError as e:
# This happens only if we couldn't find the auth events. We'll already have
# logged a warning, so now we just convert to a FederationError.
Expand Down Expand Up @@ -1495,11 +1492,8 @@ async def prep(event: EventBase) -> None:
)

async def _check_event_auth(
self,
origin: str,
event: EventBase,
context: EventContext,
) -> EventContext:
self, origin: str, event: EventBase, context: EventContext
) -> None:
"""
Checks whether an event should be rejected (for failing auth checks).

Expand All @@ -1509,9 +1503,6 @@ async def _check_event_auth(
context:
The event context.

Returns:
The updated context object.

Raises:
AuthError if we were unable to find copies of the event's auth events.
(Most other failures just cause us to set `context.rejected`.)
Expand All @@ -1526,7 +1517,7 @@ async def _check_event_auth(
logger.warning("While validating received event %r: %s", event, e)
# TODO: use a different rejected reason here?
context.rejected = RejectedReason.AUTH_ERROR
return context
return

# next, check that we have all of the event's auth events.
#
Expand All @@ -1538,6 +1529,9 @@ async def _check_event_auth(
)

# ... and check that the event passes auth at those auth events.
# https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
# 4. Passes authorization rules based on the event’s auth events,
# otherwise it is rejected.
try:
await check_state_independent_auth_rules(self._store, event)
check_state_dependent_auth_rules(event, claimed_auth_events)
Expand All @@ -1546,55 +1540,90 @@ async def _check_event_auth(
"While checking auth of %r against auth_events: %s", event, e
)
context.rejected = RejectedReason.AUTH_ERROR
return context
return

# now check the auth rules pass against the room state before the event
# https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
# 5. Passes authorization rules based on the state before the event,
# otherwise it is rejected.
#
# ... however, if we only have partial state for the room, then there is a good
# chance that we'll be missing some of the state needed to auth the new event.
# So, we state-resolve the auth events that we are given against the state that
# we know about, which ensures things like bans are applied. (Note that we'll
# already have checked we have all the auth events, in
# _load_or_fetch_auth_events_for_event above)
if context.partial_state:
room_version = await self._store.get_room_version_id(event.room_id)

local_state_id_map = await context.get_prev_state_ids()
claimed_auth_events_id_map = {
(ev.type, ev.state_key): ev.event_id for ev in claimed_auth_events
}

state_for_auth_id_map = (
await self._state_resolution_handler.resolve_events_with_store(
event.room_id,
room_version,
[local_state_id_map, claimed_auth_events_id_map],
event_map=None,
state_res_store=StateResolutionStore(self._store),
)
)
else:
event_types = event_auth.auth_types_for_event(event.room_version, event)
state_for_auth_id_map = await context.get_prev_state_ids(
StateFilter.from_types(event_types)
)

# now check auth against what we think the auth events *should* be.
event_types = event_auth.auth_types_for_event(event.room_version, event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types(event_types)
calculated_auth_event_ids = self._event_auth_handler.compute_auth_events(
event, state_for_auth_id_map, for_verification=True
)

auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
# if those are the same, we're done here.
if collections.Counter(event.auth_event_ids()) == collections.Counter(
calculated_auth_event_ids
):
return

# otherwise, re-run the auth checks based on what we calculated.
calculated_auth_events = await self._store.get_events_as_list(
calculated_auth_event_ids
)
auth_events_x = await self._store.get_events(auth_events_ids)

# log the differences

claimed_auth_event_map = {(e.type, e.state_key): e for e in claimed_auth_events}
calculated_auth_event_map = {
(e.type, e.state_key): e for e in auth_events_x.values()
(e.type, e.state_key): e for e in calculated_auth_events
}
logger.info(
"event's auth_events are different to our calculated auth_events. "
"Claimed but not calculated: %s. Calculated but not claimed: %s",
[
ev
for k, ev in claimed_auth_event_map.items()
if k not in calculated_auth_event_map
or calculated_auth_event_map[k].event_id != ev.event_id
],
[
ev
for k, ev in calculated_auth_event_map.items()
if k not in claimed_auth_event_map
or claimed_auth_event_map[k].event_id != ev.event_id
],
)

try:
updated_auth_events = await self._update_auth_events_for_auth(
check_state_dependent_auth_rules(event, calculated_auth_events)
except AuthError as e:
logger.warning(
"While checking auth of %r against room state before the event: %s",
event,
calculated_auth_event_map=calculated_auth_event_map,
)
except Exception:
# We don't really mind if the above fails, so lets not fail
# processing if it does. However, it really shouldn't fail so
# let's still log as an exception since we'll still want to fix
# any bugs.
logger.exception(
"Failed to double check auth events for %s with remote. "
"Ignoring failure and continuing processing of event.",
event.event_id,
)
updated_auth_events = None

if updated_auth_events:
context = await self._update_context_for_auth_events(
event, context, updated_auth_events
e,
)
auth_events_for_auth = updated_auth_events
else:
auth_events_for_auth = calculated_auth_event_map

try:
check_state_dependent_auth_rules(event, auth_events_for_auth.values())
except AuthError as e:
logger.warning("Failed auth resolution for %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR

return context

async def _maybe_kick_guest_users(self, event: EventBase) -> None:
if event.type != EventTypes.GuestAccess:
return
Expand Down Expand Up @@ -1704,93 +1733,6 @@ async def _check_for_soft_fail(
soft_failed_event_counter.inc()
event.internal_metadata.soft_failed = True

async def _update_auth_events_for_auth(
self,
event: EventBase,
calculated_auth_event_map: StateMap[EventBase],
) -> Optional[StateMap[EventBase]]:
"""Helper for _check_event_auth. See there for docs.

Checks whether a given event has the expected auth events. If it
doesn't then we talk to the remote server to compare state to see if
we can come to a consensus (e.g. if one server missed some valid
state).

This attempts to resolve any potential divergence of state between
servers, but is not essential and so failures should not block further
processing of the event.

Args:
event:

calculated_auth_event_map:
Our calculated auth_events based on the state of the room
at the event's position in the DAG.

Returns:
updated auth event map, or None if no changes are needed.

"""
assert not event.internal_metadata.outlier

# check for events which are in the event's claimed auth_events, but not
# in our calculated event map.
event_auth_events = set(event.auth_event_ids())
different_auth = event_auth_events.difference(
e.event_id for e in calculated_auth_event_map.values()
)

if not different_auth:
return None

logger.info(
"auth_events refers to events which are not in our calculated auth "
"chain: %s",
different_auth,
)

# XXX: currently this checks for redactions but I'm not convinced that is
# necessary?
different_events = await self._store.get_events_as_list(different_auth)

# double-check they're all in the same room - we should already have checked
# this but it doesn't hurt to check again.
for d in different_events:
assert (
d.room_id == event.room_id
), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room"

# now we state-resolve between our own idea of the auth events, and the remote's
# idea of them.

local_state = calculated_auth_event_map.values()
remote_auth_events = dict(calculated_auth_event_map)
remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
remote_state = remote_auth_events.values()

room_version = await self._store.get_room_version_id(event.room_id)
new_state = await self._state_handler.resolve_events(
room_version, (local_state, remote_state), event
)
different_state = {
(d.type, d.state_key): d
for d in new_state.values()
if calculated_auth_event_map.get((d.type, d.state_key)) != d
}
if not different_state:
logger.info("State res returned no new state")
return None

logger.info(
"After state res: updating auth_events with new state %s",
different_state.values(),
)

# take a copy of calculated_auth_event_map before we modify it.
auth_events = dict(calculated_auth_event_map)
auth_events.update(different_state)
return auth_events

async def _load_or_fetch_auth_events_for_event(
self, destination: str, event: EventBase
) -> Collection[EventBase]:
Expand Down Expand Up @@ -1888,61 +1830,6 @@ async def _get_remote_auth_chain_for_event(

await self._auth_and_persist_outliers(room_id, remote_auth_events)

async def _update_context_for_auth_events(
self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
) -> EventContext:
"""Update the state_ids in an event context after auth event resolution,
storing the changes as a new state group.

Args:
event: The event we're handling the context for

context: initial event context

auth_events: Events to update in the event context.

Returns:
new event context
"""
# exclude the state key of the new event from the current_state in the context.
if event.is_state():
event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
else:
event_key = None
state_updates = {
k: a.event_id for k, a in auth_events.items() if k != event_key
}

current_state_ids = await context.get_current_state_ids()
current_state_ids = dict(current_state_ids) # type: ignore

current_state_ids.update(state_updates)

prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = dict(prev_state_ids)

prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})

# create a new state group as a delta from the existing one.
prev_group = context.state_group
state_group = await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
delta_ids=state_updates,
current_state_ids=current_state_ids,
)

return EventContext.with_state(
storage=self._storage_controllers,
state_group=state_group,
state_group_before_event=context.state_group_before_event,
state_delta_due_to_event=state_updates,
prev_group=prev_group,
delta_ids=state_updates,
partial_state=context.partial_state,
)

async def _run_push_actions_and_persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> None:
Expand Down
Loading