Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Speed up chain cover calculation #9176

Merged
merged 6 commits into from
Jan 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9176.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Speed up chain cover calculation when persisting a batch of state events at once.
199 changes: 144 additions & 55 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,9 @@ def _persist_event_auth_chain_txn(
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
)

@staticmethod
@classmethod
def _add_chain_cover_index(
cls,
txn,
db_pool: DatabasePool,
event_to_room_id: Dict[str, str],
Expand Down Expand Up @@ -614,60 +615,17 @@ def _add_chain_cover_index(
if not events_to_calc_chain_id_for:
return

# We now calculate the chain IDs/sequence numbers for the events. We
# do this by looking at the chain ID and sequence number of any auth
# event with the same type/state_key and incrementing the sequence
# number by one. If there was no match or the chain ID/sequence
# number is already taken we generate a new chain.
#
# We need to do this in a topologically sorted order as we want to
# generate chain IDs/sequence numbers of an event's auth events
# before the event itself.
chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
existing_chain_id = None
for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map[auth_id]
break

new_chain_tuple = None
if existing_chain_id:
# We found a chain ID/sequence number candidate, check its
# not already taken.
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1
if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
already_allocated = db_pool.simple_select_one_onecol_txn(
txn,
table="event_auth_chains",
keyvalues={
"chain_id": proposed_new_id,
"sequence_number": proposed_new_seq,
},
retcol="event_id",
allow_none=True,
)
if already_allocated:
# Mark it as already allocated so we don't need to hit
# the DB again.
chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
else:
new_chain_tuple = (
proposed_new_id,
proposed_new_seq,
)

if not new_chain_tuple:
new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)

chains_tuples_allocated.add(new_chain_tuple)

chain_map[event_id] = new_chain_tuple
new_chain_tuples[event_id] = new_chain_tuple
# Allocate chain ID/sequence numbers to each new event.
new_chain_tuples = cls._allocate_chain_ids(
txn,
db_pool,
event_to_room_id,
event_to_types,
event_to_auth_chain,
events_to_calc_chain_id_for,
chain_map,
)
chain_map.update(new_chain_tuples)

db_pool.simple_insert_many_txn(
txn,
Expand Down Expand Up @@ -794,6 +752,137 @@ def _add_chain_cover_index(
],
)

@staticmethod
def _allocate_chain_ids(
txn,
db_pool: DatabasePool,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
"""Allocates, but does not persist, chain ID/sequence numbers for the
events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
for info on args)
"""

# We now calculate the chain IDs/sequence numbers for the events. We do
# this by looking at the chain ID and sequence number of any auth event
# with the same type/state_key and incrementing the sequence number by
# one. If there was no match or the chain ID/sequence number is already
# taken we generate a new chain.
#
# We try to reduce the number of times that we hit the database by
# batching up calls, to make this more efficient when persisting large
# numbers of state events (e.g. during joins).
#
# We do this by:
# 1. Calculating for each event which auth event will be used to
# inherit the chain ID, i.e. converting the auth chain graph to a
# tree that we can allocate chains on. We also keep track of which
# existing chain IDs have been referenced.
# 2. Fetching the max allocated sequence number for each referenced
# existing chain ID, generating a map from chain ID to the max
# allocated sequence number.
# 3. Iterating over the tree and allocating a chain ID/seq no. to the
# new event, by incrementing the sequence number from the
# referenced event's chain ID/seq no. and checking that the
# incremented sequence number hasn't already been allocated (by
# looking in the map generated in the previous step). We generate a
# new chain if the sequence number has already been allocated.
#

existing_chains = set() # type: Set[int]
tree = [] # type: List[Tuple[str, Optional[str]]]

# We need to do this in a topologically sorted order as we want to
# generate chain IDs/sequence numbers of an event's auth events before
# the event itself.
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map.get(auth_id)
if existing_chain_id:
existing_chains.add(existing_chain_id[0])

tree.append((event_id, auth_id))
break
else:
tree.append((event_id, None))

# Fetch the current max sequence number for each existing referenced chain.
sql = """
SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
WHERE %s
GROUP BY chain_id
"""
clause, args = make_in_list_sql_clause(
db_pool.engine, "chain_id", existing_chains
)
txn.execute(sql % (clause,), args)

chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]

# Allocate the new events chain ID/sequence numbers.
#
# To reduce the number of calls to the database we don't allocate a
# chain ID number in the loop, instead we use a temporary `object()` for
# each new chain ID. Once we've done the loop we generate the necessary
# number of new chain IDs in one call, replacing all temporary
# objects with real allocated chain IDs.

unallocated_chain_ids = set() # type: Set[object]
new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
for event_id, auth_event_id in tree:
# If we reference an auth_event_id we fetch the allocated chain ID,
# either from the existing `chain_map` or the newly generated
# `new_chain_tuples` map.
existing_chain_id = None
if auth_event_id:
clokep marked this conversation as resolved.
Show resolved Hide resolved
existing_chain_id = new_chain_tuples.get(auth_event_id)
if not existing_chain_id:
existing_chain_id = chain_map[auth_event_id]

new_chain_tuple = None # type: Optional[Tuple[Any, int]]
if existing_chain_id:
# We found a chain ID/sequence number candidate, check its
# not already taken.
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1

if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
new_chain_tuple = (
proposed_new_id,
proposed_new_seq,
)

# If we need to start a new chain we allocate a temporary chain ID.
if not new_chain_tuple:
new_chain_tuple = (object(), 1)
unallocated_chain_ids.add(new_chain_tuple[0])

new_chain_tuples[event_id] = new_chain_tuple
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]

# Generate new chain IDs for all unallocated chain IDs.
newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
txn, len(unallocated_chain_ids)
)

# Map from potentially temporary chain ID to real chain ID
chain_id_to_allocated_map = dict(
zip(unallocated_chain_ids, newly_allocated_chain_ids)
) # type: Dict[Any, int]
chain_id_to_allocated_map.update((c, c) for c in existing_chains)

return {
event_id: (chain_id_to_allocated_map[chain_id], seq)
for event_id, (chain_id, seq) in new_chain_tuples.items()
}

def _persist_transaction_ids_txn(
self,
txn: LoggingTransaction,
Expand Down
16 changes: 16 additions & 0 deletions synapse/storage/util/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ def get_next_id_txn(self, txn: Cursor) -> int:
"""Gets the next ID in the sequence"""
...

@abc.abstractmethod
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
"""Get the next `n` IDs in the sequence"""
...

@abc.abstractmethod
def check_consistency(
self,
Expand Down Expand Up @@ -174,6 +179,17 @@ def get_next_id_txn(self, txn: Cursor) -> int:
self._current_max_id += 1
return self._current_max_id

def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
with self._lock:
if self._current_max_id is None:
assert self._callback is not None
self._current_max_id = self._callback(txn)
self._callback = None

first_id = self._current_max_id + 1
self._current_max_id += n
return [first_id + i for i in range(n)]

def check_consistency(
self, db_conn: Connection, table: str, id_column: str, positive: bool = True
):
Expand Down