Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve get auth chain difference algorithm. #7095

Merged
merged 9 commits into from
Mar 18, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7095.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Attempt to improve performance of state res v2 algorithm.
28 changes: 8 additions & 20 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,28 +662,16 @@ def get_events(self, event_ids, allow_rejected=False):
allow_rejected=allow_rejected,
)

def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]):
"""Gets the full auth chain for a set of events (including rejected
events).

Includes the given event IDs in the result.

Note that:
1. All events must be state events.
2. For v1 rooms this may not have the full auth chain in the
presence of rejected events

Args:
event_ids: The event IDs of the events to fetch the auth chain for.
Must be state events.
ignore_events: Set of events to exclude from the returned auth
chain.
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).

This equivalent to fetching the full auth chain for each set of state
and returning the events that don't appear in each and every auth
chain.

Returns:
Deferred[list[str]]: List of event IDs of the auth chain.
Deferred[Set[str]]
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
"""

return self.store.get_auth_chain_ids(
event_ids, include_given=True, ignore_events=ignore_events,
)
return self.store.get_auth_chain_difference(state_sets)
41 changes: 16 additions & 25 deletions synapse/state/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,36 +227,27 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
Returns:
Deferred[set[str]]: Set of event IDs
"""
common = set(itervalues(state_sets[0])).intersection(
*(itervalues(s) for s in state_sets[1:])
)

auth_sets = []
for state_set in state_sets:
auth_ids = {
eid
for key, eid in iteritems(state_set)
if (
key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
or key
in (
(EventTypes.PowerLevels, ""),
(EventTypes.Create, ""),
(EventTypes.JoinRules, ""),
auth_sets.append(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these aren't really auth sets as I understand the term? consider a rename?

Also, a comment to explain why we're looking at these event types specifically.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is a colossal hack that basically notices that we re-add all the conflicted state later, and only care about events that are "auth" events (i.e. types that are required to auth other events). Given its a bit of a mind bend to understand that this is OK, and the vast majority of state is memberships, I think its probably better to just remove this filtering entirely.

{
eid
for key, eid in iteritems(state_set)
if (
key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
or key
in (
(EventTypes.PowerLevels, ""),
(EventTypes.Create, ""),
(EventTypes.JoinRules, ""),
)
)
)
and eid not in common
}

auth_chain = yield state_res_store.get_auth_chain(auth_ids, common)
auth_ids.update(auth_chain)

auth_sets.append(auth_ids)
}
)

intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
union = set().union(*auth_sets)
difference = yield state_res_store.get_auth_chain_difference(auth_sets)

return union - intersection
return difference


def _seperate(state_sets):
Expand Down
145 changes: 144 additions & 1 deletion synapse/storage/data_stores/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import itertools
import logging
from typing import List, Optional, Set
from typing import Dict, List, Optional, Set

from six.moves.queue import Empty, PriorityQueue

Expand Down Expand Up @@ -103,6 +103,149 @@ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):

return list(results)

def get_auth_chain_difference(self, state_sets: List[Set[str]]):
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).

This equivalent to fetching the full auth chain for each set of state
and returning the events that don't appear in each and every auth
chain.

Returns:
Deferred[Set[str]]
"""

return self.db.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
)

def _get_auth_chain_difference_txn(
self, txn, state_sets: List[Set[str]]
) -> Set[str]:

# Algorithm Description
# ~~~~~~~~~~~~~~~~~~~~~
#
# The idea here is to basically walk the auth graph of each state set in
# tandem, keeping track of which auth events are reachable by each state
# set. If we reach an auth event we've already visited (via a different
# state set) then we mark that auth event and all ancestors as reachable
# by the state set. This requires that we keep track of the auth chains
# in memory.
#
# Doing it in a such a way means that we can stop early if all auth
# events we're currently walking are reachable by all state sets.
#
# *Note*: We can't stop walking an event's auth chain if it is reachable
# by all state sets. This is because other auth chains we're walking
# might be reachable only via the original auth chain. For example,
# given the following auth chain:
#
# A -> C -> D -> E
# / /
# B -´---------´
#
# and state sets {A} and {B} then walking the auth chains of A and B
# would immediately show that C is reachable by both. However, if we
# stopped at C then we'd only reach E via the auth chain of B and so E
# would errornously get included in the returned difference.
#
# The other thing that we do is limit the number of auth chains we walk
# at once, due to practical limits (i.e. we can only query the database
# with a limited set of parameters). We pick the auth chains we walk
# each iteration based on their depth, in the hope that events with a
# lower depth are likely reachable by those with higher depths.
#
# We could use any ordering that we believe would give a rough
# topological ordering, e.g. origin server timestamp. If the ordering
# chosen is not topological then the algorithm still produces the right
# result, but perhaps a bit more inefficiently. This is why it is safe
# to use "depth" here.

initial_events = set(state_sets[0]).union(*state_sets[1:])

# Dict from events in auth chains to which sets *cannot* reach them.
# I.e. if the set is empty then all sets can reach the event.
event_to_missing_sets = {
event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
for event_id in initial_events
}

# We need to get the depth of the initial events for sorting purposes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we trust this value of depth?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope! Though hopefully as explained above we're using it as a hint, and so worst case if its wrong we basically pull the full auth chains out of the DB as we're currently doing

rows = self.db.simple_select_many_txn(
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
txn,
table="events",
column="event_id",
iterable=initial_events,
keyvalues={},
retcols=("event_id", "depth"),
)

# The sorted list of events we should walk the auth chain off.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
search = sorted((row["depth"], row["event_id"]) for row in rows)
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

# Map from event to its auth events
event_to_auth_events = {} # type: Dict[str, Set[str]]

base_sql = """
SELECT depth, a.event_id, auth_id
FROM event_auth AS a
INNER JOIN events AS e ON (e.event_id = a.auth_id)
WHERE
"""

while search:
# Check whether all our current walks are reachable by all state
# sets. If so we can bail.
if all(not event_to_missing_sets[eid] for _, eid in search):
break

# Fetch the auth events and their depths of the N last events we're
# currently walking
search, chunk = search[:-100], search[-100:]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vaguely wondering if a heap would be more efficient for search. let's not rewrite it now though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Me too, but heapq seems to only let you pop items one at a time I think? Which sounds inefficient

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed. we might have to roll our own impl, which may or may not work out as a winner over a sorted list.

clause, args = make_in_list_sql_clause(
txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
)
txn.execute(base_sql + clause, args)

for depth, event_id, auth_id in txn:
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
event_to_auth_events.setdefault(event_id, set()).add(auth_id)

if auth_id not in event_to_missing_sets:
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
# First time we're seeing this event, so we add it to the
# queue of things to fetch.
search.append((depth, auth_id))
else:
# We've previously seen this event, so look up its auth
# events and recursively mark all ancestors as reachable
# by the current events state set.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
a_ids = event_to_auth_events.get(auth_id)
while a_ids:
new_aids = set()
for a_id in a_ids:
event_to_missing_sets[a_id].intersection_update(
event_to_missing_sets[event_id]
)

b = event_to_auth_events.get(a_id)
if b:
new_aids.update(b)

a_ids = new_aids

# Mark that the auth event is reachable by the approriate sets.
sets = event_to_missing_sets.setdefault(
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
auth_id, set(range(len(state_sets)))
)
sets.intersection_update(event_to_missing_sets[event_id])

search.sort()

# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}

def get_oldest_events_in_room(self, room_id):
return self.db.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
Expand Down
13 changes: 8 additions & 5 deletions tests/state/test_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ def get_events(self, event_ids, allow_rejected=False):

return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}

def get_auth_chain(self, event_ids, ignore_events):
def _get_auth_chain(self, event_ids):
"""Gets the full auth chain for a set of events (including rejected
events).

Expand All @@ -617,9 +617,6 @@ def get_auth_chain(self, event_ids, ignore_events):
Args:
event_ids (list): The event IDs of the events to fetch the auth
chain for. Must be state events.
ignore_events: Set of events to exclude from the returned auth
chain.

Returns:
Deferred[list[str]]: List of event IDs of the auth chain.
"""
Expand All @@ -629,7 +626,7 @@ def get_auth_chain(self, event_ids, ignore_events):
stack = list(event_ids)
while stack:
event_id = stack.pop()
if event_id in result or event_id in ignore_events:
if event_id in result:
continue

result.add(event_id)
Expand All @@ -639,3 +636,9 @@ def get_auth_chain(self, event_ids, ignore_events):
stack.append(aid)

return list(result)

def get_auth_chain_difference(self, auth_sets):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]

common = set(chains[0]).intersection(*chains[1:])
return set(chains[0]).union(*chains[1:]) - common
Loading