Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Clean-up some receipts code #12888

Merged
merged 5 commits into from
May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12888.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor receipt linearization code.
89 changes: 47 additions & 42 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def process_replication_rows(

return super().process_replication_rows(stream_name, instance_name, token, rows)

def insert_linearized_receipt_txn(
def _insert_linearized_receipt_txn(
self,
txn: LoggingTransaction,
room_id: str,
Expand Down Expand Up @@ -686,6 +686,44 @@ def insert_linearized_receipt_txn(

return rx_ts

def _graph_to_linear(
self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
) -> str:
"""
Generate a linearized event from a list of events (i.e. a list of forward
extremities in the room).

This should allow for calculation of the correct read receipt even if
servers have different event ordering.

Args:
txn: The transaction
room_id: The room ID the events are in.
event_ids: The list of event IDs to linearize.

Returns:
The linearized event ID.
"""
# TODO: Make this better.
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)

sql = """
SELECT event_id WHERE room_id = ? AND stream_ordering IN (
SELECT max(stream_ordering) WHERE %s
)
""" % (
clause,
)

txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))

async def insert_receipt(
self,
room_id: str,
Expand All @@ -712,35 +750,14 @@ async def insert_receipt(
linearized_event_id = event_ids[0]
else:
# we need to points in graph -> linearized form.
# TODO: Make this better.
def graph_to_linear(txn: LoggingTransaction) -> str:
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)

sql = """
SELECT event_id WHERE room_id = ? AND stream_ordering IN (
SELECT max(stream_ordering) WHERE %s
)
""" % (
clause,
)

txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))

linearized_event_id = await self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear
"insert_receipt_conv", self._graph_to_linear, room_id, event_ids
)

async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
self._insert_linearized_receipt_txn,
room_id,
receipt_type,
user_id,
Expand All @@ -761,33 +778,21 @@ def graph_to_linear(txn: LoggingTransaction) -> str:
now - event_ts,
)

await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)

max_persisted_id = self._receipts_id_gen.get_current_token()

return stream_id, max_persisted_id

async def insert_graph_receipt(
self,
room_id: str,
receipt_type: str,
user_id: str,
event_ids: List[str],
data: JsonDict,
) -> None:
assert self._can_write_to_receipts

await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
self._insert_graph_receipt_txn,
room_id,
receipt_type,
user_id,
event_ids,
data,
)

def insert_graph_receipt_txn(
max_persisted_id = self._receipts_id_gen.get_current_token()

return stream_id, max_persisted_id

def _insert_graph_receipt_txn(
self,
txn: LoggingTransaction,
room_id: str,
Expand Down