Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Always notify replication when a stream advances #14877

Merged
merged 5 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14877.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Always notify replication when a stream advances automatically.
4 changes: 4 additions & 0 deletions synapse/_scripts/synapse_port_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
make_deferred_yieldable,
run_in_background,
)
from synapse.notifier import ReplicationNotifier
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import PushRuleStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
Expand Down Expand Up @@ -260,6 +261,9 @@ def get_instance_name(self) -> str:
def should_send_federation(self) -> bool:
    """Always False: this stand-in homeserver never sends federation traffic.

    NOTE(review): enclosing class is not visible here — presumably the port
    script's mock homeserver; confirm against synapse_port_db.
    """
    return False

def get_replication_notifier(self) -> ReplicationNotifier:
    """Return a ReplicationNotifier for stream id generators to notify.

    A fresh instance is constructed on every call (unlike the real
    HomeServer, which caches one), so no callbacks registered elsewhere
    are shared — effectively a no-op notifier for the port script.
    """
    return ReplicationNotifier()


class Porter:
def __init__(
Expand Down
31 changes: 26 additions & 5 deletions synapse/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.pending_new_room_events: List[_PendingRoomEventEntry] = []

# Called when there are new things to stream over replication
self.replication_callbacks: List[Callable[[], None]] = []
self._replication_notifier = hs.get_replication_notifier()
self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = []

self._federation_client = hs.get_federation_http_client()
Expand Down Expand Up @@ -279,7 +278,7 @@ def add_replication_callback(self, cb: Callable[[], None]) -> None:
it needs to do any asynchronous work, a background thread should be started and
wrapped with run_as_background_process.
"""
self.replication_callbacks.append(cb)
self._replication_notifier.add_replication_callback(cb)

def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None:
"""Add a callback that will be called when a user joins a room.
Expand Down Expand Up @@ -741,8 +740,7 @@ def _user_joined_room(self, user_id: str, room_id: str) -> None:

def notify_replication(self) -> None:
"""Notify any replication listeners that there's a new event"""
for cb in self.replication_callbacks:
cb()
self._replication_notifier.notify_replication()

def notify_user_joined_room(self, event_id: str, room_id: str) -> None:
for cb in self._new_join_in_room_callbacks:
Expand All @@ -759,3 +757,26 @@ def notify_remote_server_up(self, server: str) -> None:
# Tell the federation client about the fact the server is back up, so
# that any in flight requests can be immediately retried.
self._federation_client.wake_destination(server)


@attr.s(auto_attribs=True)
class ReplicationNotifier:
    """Tracks callbacks for things that need to know about stream changes.

    This is separate from the notifier to avoid circular dependencies.
    """

    # Callbacks invoked (synchronously, in registration order) whenever
    # notify_replication() is called.
    _replication_callbacks: List[Callable[[], None]] = attr.Factory(list)

    def add_replication_callback(self, cb: Callable[[], None]) -> None:
        """Add a callback that will be called when some new data is available.

        Callback is not given any arguments. It should *not* return a Deferred - if
        it needs to do any asynchronous work, a background thread should be started and
        wrapped with run_as_background_process.
        """
        self._replication_callbacks.append(cb)

    def notify_replication(self) -> None:
        """Notify any replication listeners that there's a new event."""
        for cb in self._replication_callbacks:
            cb()
6 changes: 5 additions & 1 deletion synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
from synapse.module_api import ModuleApi
from synapse.notifier import Notifier
from synapse.notifier import Notifier, ReplicationNotifier
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
Expand Down Expand Up @@ -389,6 +389,10 @@ def get_federation_server(self) -> FederationServer:
def get_notifier(self) -> Notifier:
return Notifier(self)

@cache_in_self
def get_replication_notifier(self) -> ReplicationNotifier:
return ReplicationNotifier()

@cache_in_self
def get_auth(self) -> Auth:
return Auth(self)
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="account_data",
instance_name=self._instance_name,
tables=[
Expand All @@ -95,6 +96,7 @@ def __init__(
# SQLite).
self._account_data_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
Expand Down
1 change: 1 addition & 0 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
notifier=hs.get_replication_notifier(),
stream_name="caches",
instance_name=hs.get_instance_name(),
tables=[
Expand Down
3 changes: 2 additions & 1 deletion synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(
MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="to_device",
instance_name=self._instance_name,
tables=[("device_inbox", "instance_name", "stream_id")],
Expand All @@ -101,7 +102,7 @@ def __init__(
else:
self._can_write_to_device = True
self._device_inbox_id_gen = StreamIdGenerator(
db_conn, "device_inbox", "stream_id"
db_conn, hs.get_replication_notifier(), "device_inbox", "stream_id"
)

max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
Expand Down
1 change: 1 addition & 0 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(
# class below that is used on the main process.
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"device_lists_stream",
"stream_id",
extra_tables=[
Expand Down
5 changes: 4 additions & 1 deletion synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,7 +1181,10 @@ def __init__(
super().__init__(database, db_conn, hs)

self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
db_conn,
hs.get_replication_notifier(),
"e2e_cross_signing_keys",
"stream_id",
)

async def set_e2e_device_keys(
Expand Down
10 changes: 9 additions & 1 deletion synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def __init__(
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="events",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
Expand All @@ -200,6 +201,7 @@ def __init__(
self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="backfill",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
Expand All @@ -217,12 +219,14 @@ def __init__(
# SQLite).
self._stream_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"events",
"stream_ordering",
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"events",
"stream_ordering",
step=-1,
Expand Down Expand Up @@ -300,6 +304,7 @@ def get_chain_id_txn(txn: Cursor) -> int:
self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_event_stream",
instance_name=hs.get_instance_name(),
tables=[
Expand All @@ -311,7 +316,10 @@ def get_chain_id_txn(txn: Cursor) -> int:
)
else:
self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
db_conn, "un_partial_stated_event_stream", "stream_id"
db_conn,
hs.get_replication_notifier(),
"un_partial_stated_event_stream",
"stream_id",
)

def get_un_partial_stated_events_token(self) -> int:
Expand Down
3 changes: 2 additions & 1 deletion synapse/storage/databases/main/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
self._presence_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="presence_stream",
instance_name=self._instance_name,
tables=[("presence_stream", "instance_name", "stream_id")],
Expand All @@ -85,7 +86,7 @@ def __init__(
)
else:
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
)

self.hs = hs
Expand Down
1 change: 1 addition & 0 deletions synapse/storage/databases/main/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
# class below that is used on the main process.
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"push_rules_stream",
"stream_id",
is_writer=hs.config.worker.worker_app is None,
Expand Down
1 change: 1 addition & 0 deletions synapse/storage/databases/main/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
# class below that is used on the main process.
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
self._receipts_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="receipts",
instance_name=self._instance_name,
tables=[("receipts_linearized", "instance_name", "stream_id")],
Expand All @@ -91,6 +92,7 @@ def __init__(
# SQLite).
self._receipts_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"receipts_linearized",
"stream_id",
is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
Expand Down
6 changes: 5 additions & 1 deletion synapse/storage/databases/main/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __init__(
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_room_stream",
instance_name=self._instance_name,
tables=[
Expand All @@ -137,7 +138,10 @@ def __init__(
)
else:
self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
db_conn, "un_partial_stated_room_stream", "stream_id"
db_conn,
hs.get_replication_notifier(),
"un_partial_stated_room_stream",
"stream_id",
)

async def store_room(
Expand Down
26 changes: 24 additions & 2 deletions synapse/storage/util/id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from contextlib import contextmanager
from types import TracebackType
from typing import (
TYPE_CHECKING,
AsyncContextManager,
ContextManager,
Dict,
Expand Down Expand Up @@ -49,6 +50,9 @@
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator

if TYPE_CHECKING:
from synapse.notifier import ReplicationNotifier

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -182,6 +186,7 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
def __init__(
self,
db_conn: LoggingDatabaseConnection,
notifier: "ReplicationNotifier",
table: str,
column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
Expand All @@ -205,6 +210,8 @@ def __init__(
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()

self._notifier = notifier

def advance(self, instance_name: str, new_id: int) -> None:
# Advance should never be called on a writer instance, only over replication
if self._is_writer:
Expand All @@ -227,6 +234,8 @@ def manager() -> Generator[int, None, None]:
with self._lock:
self._unfinished_ids.pop(next_id)

self._notifier.notify_replication()

return _AsyncCtxManagerWrapper(manager())

def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
Expand All @@ -250,6 +259,8 @@ def manager() -> Generator[Sequence[int], None, None]:
for next_id in next_ids:
self._unfinished_ids.pop(next_id)

self._notifier.notify_replication()

return _AsyncCtxManagerWrapper(manager())

def get_current_token(self) -> int:
Expand Down Expand Up @@ -296,6 +307,7 @@ def __init__(
self,
db_conn: LoggingDatabaseConnection,
db: DatabasePool,
notifier: "ReplicationNotifier",
stream_name: str,
instance_name: str,
tables: List[Tuple[str, str, str]],
Expand All @@ -304,6 +316,7 @@ def __init__(
positive: bool = True,
) -> None:
self._db = db
self._notifier = notifier
self._stream_name = stream_name
self._instance_name = instance_name
self._positive = positive
Expand Down Expand Up @@ -535,7 +548,9 @@ def get_next(self) -> AsyncContextManager[int]:
# Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
# controls the return type. If `None` or omitted, the context manager yields
# a single integer stream_id; otherwise it yields a list of stream_ids.
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
return cast(
AsyncContextManager[int], _MultiWriterCtxManager(self, self._notifier)
)

def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
# If we have a list of instances that are allowed to write to this
Expand All @@ -544,7 +559,10 @@ def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
raise Exception("Tried to allocate stream ID on non-writer")

# Cast safety: see get_next.
return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
return cast(
AsyncContextManager[List[int]],
_MultiWriterCtxManager(self, self._notifier, n),
)

def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
Expand All @@ -563,6 +581,7 @@ def get_next_txn(self, txn: LoggingTransaction) -> int:

txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
txn.call_after(self._notifier.notify_replication)

# Update the `stream_positions` table with newly updated stream
# ID (unless self._writers is not set in which case we don't
Expand Down Expand Up @@ -787,6 +806,7 @@ class _MultiWriterCtxManager:
"""Async context manager returned by MultiWriterIdGenerator"""

id_gen: MultiWriterIdGenerator
notifier: "ReplicationNotifier"
multiple_ids: Optional[int] = None
stream_ids: List[int] = attr.Factory(list)

Expand Down Expand Up @@ -814,6 +834,8 @@ async def __aexit__(
for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i)

self.notifier.notify_replication()

if exc_type is not None:
return False

Expand Down
Loading