diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 19bd43bc6afb..130846c2f92e 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -21,8 +21,10 @@ from prometheus_client import Counter +from synapse.metrics import LaterGauge from synapse.replication.tcp.client import ReplicationClientFactory from synapse.replication.tcp.commands import ( + ClearUserSyncsCommand, Command, FederationAckCommand, InvalidateCacheCommand, @@ -30,6 +32,7 @@ RdataCommand, RemoteServerUpCommand, RemovePusherCommand, + ReplicateCommand, SyncCommand, UserIpCommand, UserSyncCommand, @@ -44,6 +47,13 @@ inbound_rdata_count = Counter( "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] ) +user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") +federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") +remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") +invalidate_cache_counter = Counter( + "synapse_replication_tcp_resource_invalidate_cache", "" +) +user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") class ReplicationCommandHandler: @@ -52,6 +62,8 @@ class ReplicationCommandHandler: def __init__(self, hs): self.replication_data_handler = hs.get_replication_data_handler() + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() # Set of streams that we're currently catching up with. @@ -71,7 +83,41 @@ def __init__(self, hs): self.factory = None # type: Optional[ReplicationClientFactory] # The currently connected connections. - self.connections = [] + self.connections = [] # type: List[Any] + + LaterGauge( + "synapse_replication_tcp_resource_total_connections", + "", + [], + lambda: len(self.connections), + ) + + LaterGauge( + "synapse_replication_tcp_resource_connections_per_stream", + "", + ["stream_name"], + lambda: { + (stream_name,): len( + [ + conn + for conn in self.connections + if stream_name in conn.replication_streams + ] + ) + for stream_name in self.streams + }, + ) + + self.is_master = hs.config.worker_app is None + + self.federation_sender = None + if self.is_master and not hs.config.send_federation: + self.federation_sender = hs.get_federation_sender() + + self._server_notices_sender = None + if self.is_master: + self._server_notices_sender = hs.get_server_notices_sender() + self.notifier.add_remote_server_up_callback(self.send_remote_server_up) def start_replication(self, hs): """Helper method to start a replication connection to the remote server @@ -83,6 +129,73 @@ def start_replication(self, hs): port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self.factory) + async def on_REPLICATE(self, cmd: ReplicateCommand): + # We only want to announce positions by the writer of the streams. + # Currently this is just the master process. + if not self.is_master: + return + + if not self.connections: + raise Exception("Not connected") + + for stream_name, stream in self.streams.items(): + current_token = stream.current_token() + self.send_command(PositionCommand(stream_name, current_token)) + + async def on_USER_SYNC(self, cmd: UserSyncCommand): + user_sync_counter.inc() + + if self.is_master: + await self.presence_handler.update_external_syncs_row( + cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + ) + + async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand): + if self.is_master: + await self.presence_handler.update_external_syncs_clear(cmd.instance_id) + + async def on_FEDERATION_ACK(self, cmd: FederationAckCommand): + federation_ack_counter.inc() + + if self.federation_sender: + self.federation_sender.federation_ack(cmd.token) + + async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand): + remove_pusher_counter.inc() + + if self.is_master: + await self.store.delete_pusher_by_app_id_pushkey_user_id( + app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id + ) + + self.notifier.on_new_replication_data() + + async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand): + invalidate_cache_counter.inc() + + if self.is_master: + # We invalidate the cache locally, but then also stream that to other + # workers. + await self.store.invalidate_cache_and_stream( + cmd.cache_func, tuple(cmd.keys) + ) + + async def on_USER_IP(self, cmd: UserIpCommand): + user_ip_cache_counter.inc() + + if self.is_master: + await self.store.insert_client_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, + cmd.last_seen, + ) + + if self._server_notices_sender: + await self._server_notices_sender.on_user_ip(cmd.user_id) + async def on_RDATA(self, cmd: RdataCommand): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -167,6 +280,8 @@ async def on_SYNC(self, cmd: SyncCommand): async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): """"Called when get a new REMOTE_SERVER_UP command.""" + if self.is_master: + self.notifier.notify_remote_server_up(cmd.data) def get_currently_syncing_users(self): """Get the list of currently syncing users (if any). This is called @@ -249,3 +364,10 @@ def send_user_ip( def send_remote_server_up(self, server: str): self.send_command(RemoteServerUpCommand(server)) + + def stream_update(self, stream_name: str, token: str, data: Any): + """Called when a new update is available to stream to clients. + + We need to check if the client is interested in the stream or not + """ + self.send_command(RdataCommand(stream_name, token, data)) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index ccf24a62ca33..d98e9e2b69ac 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -70,13 +70,8 @@ ErrorCommand, NameCommand, PingCommand, - PositionCommand, - RdataCommand, - RemoteServerUpCommand, ReplicateCommand, ServerCommand, - SyncCommand, - UserSyncCommand, ) from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.types import Collection @@ -136,8 +131,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): max_line_buffer = 10000 - def __init__(self, clock): + def __init__(self, clock, handler): self.clock = clock + self.handler = handler self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 @@ -177,6 +173,8 @@ def connectionMade(self): # can time us out. self.send_command(PingCommand(self.clock.time_msec())) + self.handler.new_connection(self) + def send_ping(self): """Periodically sends a ping and checks if we should close the connection due to the other side timing out. @@ -250,8 +248,23 @@ async def handle_command(self, cmd: Command): Args: cmd: received command """ - handler = getattr(self, "on_%s" % (cmd.NAME,)) - await handler(cmd) + handled = False + + # First call any command handlers on this instance. These are for TCP + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. + cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) def close(self): logger.warning("[%s] Closing connection", self.id()) @@ -380,6 +393,8 @@ def on_connection_closed(self): self.state = ConnectionStates.CLOSED self.pending_commands = [] + self.handler.lost_connection(self) + if self.transport: self.transport.unregisterProducer() @@ -406,74 +421,19 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS - def __init__(self, server_name, clock, streamer): - BaseReplicationStreamProtocol.__init__(self, clock) # Old style class + def __init__(self, server_name, clock, handler): + BaseReplicationStreamProtocol.__init__(self, clock, handler) # Old style class self.server_name = server_name - self.streamer = streamer def connectionMade(self): self.send_command(ServerCommand(self.server_name)) BaseReplicationStreamProtocol.connectionMade(self) - self.streamer.new_connection(self) async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - async def on_USER_SYNC(self, cmd): - await self.streamer.on_user_sync( - cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms - ) - - async def on_CLEAR_USER_SYNC(self, cmd): - await self.streamer.on_clear_user_syncs(cmd.instance_id) - - async def on_REPLICATE(self, cmd): - # Subscribe to all streams we're publishing to. - for stream_name in self.streamer.streams_by_name: - current_token = self.streamer.get_stream_token(stream_name) - self.send_command(PositionCommand(stream_name, current_token)) - - async def on_FEDERATION_ACK(self, cmd): - self.streamer.federation_ack(cmd.token) - - async def on_REMOVE_PUSHER(self, cmd): - await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - - async def on_INVALIDATE_CACHE(self, cmd): - await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.streamer.on_remote_server_up(cmd.data) - - async def on_USER_IP(self, cmd): - self.streamer.on_user_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) - - def stream_update(self, stream_name, token, data): - """Called when a new update is available to stream to clients. - - We need to check if the client is interested in the stream or not - """ - self.send_command(RdataCommand(stream_name, token, data)) - - def send_sync(self, data): - self.send_command(SyncCommand(data)) - - def send_remote_server_up(self, server: str): - self.send_command(RemoteServerUpCommand(server)) - - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.streamer.lost_connection(self) - class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): """ @@ -552,13 +512,12 @@ def __init__( clock: Clock, command_handler, ): - BaseReplicationStreamProtocol.__init__(self, clock) + BaseReplicationStreamProtocol.__init__(self, clock, command_handler) self.instance_id = hs.get_instance_id() self.client_name = client_name self.server_name = server_name - self.handler = command_handler self.streams = { stream.NAME: stream(hs) for stream in STREAMS_MAP.values() @@ -578,42 +537,6 @@ def connectionMade(self): # Once we've connected subscribe to the necessary streams self.replicate() - # Tell the server if we have any users currently syncing (should only - # happen on synchrotrons) - currently_syncing = self.handler.get_currently_syncing_users() - now = self.clock.time_msec() - for user_id in currently_syncing: - self.send_command(UserSyncCommand(self.instance_id, user_id, True, now)) - - # We've now finished connecting to so inform the client handler - self.handler.new_connection(self) - - async def handle_command(self, cmd: Command): - """Handle a command we have received over the replication stream. - - By default delegates to on_, which should return an awaitable. - - Args: - cmd: received command - """ - handled = False - - # First call any command handlers on this instance. These are for TCP - # specific handling. - cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(cmd) - handled = True - - # Then call out to the handler. - cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(cmd) - handled = True - - if not handled: - logger.warning("Unhandled command: %r", cmd) - async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) @@ -626,10 +549,6 @@ def replicate(self): self.send_command(ReplicateCommand()) - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.handler.lost_connection(self) - # The following simply registers metrics for the replication connections diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 8b6067e20df4..cc0322258748 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -17,7 +17,7 @@ import logging import random -from typing import Any, Dict, List +from typing import Dict from six import itervalues @@ -25,24 +25,14 @@ from twisted.internet.protocol import Factory -from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.metrics import Measure, measure_func - -from .protocol import ServerReplicationStreamProtocol -from .streams import STREAMS_MAP, Stream -from .streams.federation import FederationStream +from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol +from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream +from synapse.util.metrics import Measure stream_updates_counter = Counter( "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] ) -user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") -federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") -remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter( - "synapse_replication_tcp_resource_invalidate_cache", "" -) -user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -52,13 +42,18 @@ class ReplicationStreamProtocolFactory(Factory): """ def __init__(self, hs): - self.streamer = hs.get_replication_streamer() + self.handler = hs.get_tcp_replication() self.clock = hs.get_clock() self.server_name = hs.config.server_name + self.hs = hs + + # Ensure the replication streamer is started if we register a + # replication server endpoint. + hs.get_replication_streamer() def buildProtocol(self, addr): return ServerReplicationStreamProtocol( - self.server_name, self.clock, self.streamer + self.server_name, self.clock, self.handler ) @@ -78,16 +73,6 @@ def __init__(self, hs): self._replication_torture_level = hs.config.replication_torture_level - # Current connections. - self.connections = [] # type: List[ServerReplicationStreamProtocol] - - LaterGauge( - "synapse_replication_tcp_resource_total_connections", - "", - [], - lambda: len(self.connections), - ) - # List of streams that clients can subscribe to. # We only support federation stream if federation sending hase been # disabled on the master. @@ -99,39 +84,17 @@ def __init__(self, hs): self.streams_by_name = {stream.NAME: stream for stream in self.streams} - LaterGauge( - "synapse_replication_tcp_resource_connections_per_stream", - "", - ["stream_name"], - lambda: { - (stream_name,): len( - [ - conn - for conn in self.connections - if stream_name in conn.replication_streams - ] - ) - for stream_name in self.streams_by_name - }, - ) - self.federation_sender = None if not hs.config.send_federation: self.federation_sender = hs.get_federation_sender() self.notifier.add_replication_callback(self.on_notifier_poke) - self.notifier.add_remote_server_up_callback(self.send_remote_server_up) # Keeps track of whether we are currently checking for updates self.is_looping = False self.pending_updates = False - hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown) - - def on_shutdown(self): - # close all connections on shutdown - for conn in self.connections: - conn.send_error("server shutting down") + self.client = hs.get_tcp_replication() def get_streams(self) -> Dict[str, Stream]: """Get a mapp from stream name to stream instance. @@ -145,7 +108,7 @@ def on_notifier_poke(self): This should get called each time new data is available, even if it is currently being executed, so that nothing gets missed """ - if not self.connections: + if not self.client.connected(): # Don't bother if nothing is listening. We still need to advance # the stream tokens otherwise they'll fall beihind forever for stream in self.streams: @@ -202,9 +165,7 @@ async def _run_notifier_loop(self): raise logger.debug( - "Sending %d updates to %d connections", - len(updates), - len(self.connections), + "Sending %d updates", len(updates), ) if updates: @@ -220,112 +181,17 @@ async def _run_notifier_loop(self): # token. See RdataCommand for more details. batched_updates = _batch_updates(updates) - for conn in self.connections: - for token, row in batched_updates: - try: - conn.stream_update(stream.NAME, token, row) - except Exception: - logger.exception("Failed to replicate") + for token, row in batched_updates: + try: + self.client.stream_update(stream.NAME, token, row) + except Exception: + logger.exception("Failed to replicate") logger.debug("No more pending updates, breaking poke loop") finally: self.pending_updates = False self.is_looping = False - def get_stream_token(self, stream_name): - """For a given stream get all updates since token. This is called when - a client first subscribes to a stream. - """ - stream = self.streams_by_name.get(stream_name, None) - if not stream: - raise Exception("unknown stream %s", stream_name) - - return stream.current_token() - - @measure_func("repl.federation_ack") - def federation_ack(self, token): - """We've received an ack for federation stream from a client. - """ - federation_ack_counter.inc() - if self.federation_sender: - self.federation_sender.federation_ack(token) - - @measure_func("repl.on_user_sync") - async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): - """A client has started/stopped syncing on a worker. - """ - user_sync_counter.inc() - await self.presence_handler.update_external_syncs_row( - instance_id, user_id, is_syncing, last_sync_ms - ) - - async def on_clear_user_syncs(self, instance_id): - """A replication client wants us to drop all their UserSync data. - """ - await self.presence_handler.update_external_syncs_clear(instance_id) - - @measure_func("repl.on_remove_pusher") - async def on_remove_pusher(self, app_id, push_key, user_id): - """A client has asked us to remove a pusher - """ - remove_pusher_counter.inc() - await self.store.delete_pusher_by_app_id_pushkey_user_id( - app_id=app_id, pushkey=push_key, user_id=user_id - ) - - self.notifier.on_new_replication_data() - - @measure_func("repl.on_invalidate_cache") - async def on_invalidate_cache(self, cache_func: str, keys: List[Any]): - """The client has asked us to invalidate a cache - """ - invalidate_cache_counter.inc() - - # We invalidate the cache locally, but then also stream that to other - # workers. - await self.store.invalidate_cache_and_stream(cache_func, tuple(keys)) - - @measure_func("repl.on_user_ip") - async def on_user_ip( - self, user_id, access_token, ip, user_agent, device_id, last_seen - ): - """The client saw a user request - """ - user_ip_cache_counter.inc() - await self.store.insert_client_ip( - user_id, access_token, ip, user_agent, device_id, last_seen - ) - await self._server_notices_sender.on_user_ip(user_id) - - @measure_func("repl.on_remote_server_up") - def on_remote_server_up(self, server: str): - self.notifier.notify_remote_server_up(server) - - def send_remote_server_up(self, server: str): - for conn in self.connections: - conn.send_remote_server_up(server) - - def send_sync_to_all_connections(self, data): - """Sends a SYNC command to all clients. - - Used in tests. - """ - for conn in self.connections: - conn.send_sync(data) - - def new_connection(self, connection): - """A new client connection has been established - """ - self.connections.append(connection) - - def lost_connection(self, connection): - """A client connection has been lost - """ - try: - self.connections.remove(connection) - except ValueError: - pass - def _batch_updates(updates): """Takes a list of updates of form [(token, row)] and sets the token to