From 6c4292d67ffaea30270b9c6e4e6335885b776532 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 Mar 2020 14:05:53 +0000 Subject: [PATCH] Pass instance name through to rdata --- synapse/app/generic_worker.py | 10 ++++------ synapse/replication/http/streams.py | 4 +++- synapse/replication/tcp/client.py | 4 +++- synapse/replication/tcp/handler.py | 19 ++++++++++++++----- tests/replication/tcp/streams/_base.py | 4 ++-- .../replication/tcp/streams/test_receipts.py | 4 ++-- tests/replication/tcp/streams/test_typing.py | 4 ++-- 7 files changed, 30 insertions(+), 19 deletions(-) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index d125327f08e6..c05d82d7489d 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -650,11 +650,9 @@ def __init__(self, hs): else: self.send_handler = None - async def on_rdata(self, stream_name, token, rows): - await super(GenericWorkerReplicationHandler, self).on_rdata( - stream_name, token, rows - ) - await self.process_and_notify(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) + await self.process_and_notify(stream_name, instance_name, token, rows) def get_streams_to_replicate(self): args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate() @@ -663,7 +661,7 @@ def get_streams_to_replicate(self): args.update(self.send_handler.stream_positions()) return args - async def process_and_notify(self, stream_name, token, rows): + async def process_and_notify(self, stream_name, instance_name, token, rows): try: if self.send_handler: await self.send_handler.process_replication_rows( diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index f35cebc710e0..0459f582bfc7 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): def __init__(self, hs): super().__init__(hs) + self._instance_name = hs.get_instance_name() + # We pull the streams from the replication steamer (if we try and make # them ourselves we end up in an import loop). self.streams = hs.get_replication_streamer().get_streams() @@ -67,7 +69,7 @@ async def _handle_request(self, request, stream_name): upto_token = parse_integer(request, "upto_token", required=True) updates, upto_token, limited = await stream.get_updates_since( - from_token, upto_token + self._instance_name, from_token, upto_token ) return ( diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 2d07b8b2d0de..9f9921005342 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -86,7 +86,9 @@ class ReplicationDataHandler: def __init__(self, store: BaseSlavedStore): self.store = store - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 6f7054d5af15..f6e15825c75e 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -278,9 +278,11 @@ async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): # Check if this is the last of a batch of updates rows = self._pending_batches.pop(stream_name, []) rows.append(row) - await self.on_rdata(stream_name, cmd.token, rows) + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. Args: @@ -290,7 +292,9 @@ async def on_rdata(self, stream_name: str, token: int, rows: list): Stream.parse_row. """ logger.debug("Received rdata %s -> %s", stream_name, token) - await self._replication_data_handler.on_rdata(stream_name, token, rows) + await self._replication_data_handler.on_rdata( + stream_name, instance_name, token, rows + ) async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: @@ -333,7 +337,9 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): updates, current_token, missing_updates, - ) = await stream.get_updates_since(current_token, cmd.token) + ) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token + ) # TODO: add some tests for this @@ -342,7 +348,10 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): for token, rows in _batch_updates(updates): await self.on_rdata( - cmd.stream_name, token, [stream.parse_row(row) for row in rows], + cmd.stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], ) # We've now caught up to position sent to us, notify handler. diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 83e16cfe3d08..5aa27d89f526 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -187,8 +187,8 @@ def __init__(self, store: BaseSlavedStore): def get_streams_to_replicate(self): return self.stream_positions - async def on_rdata(self, stream_name, token, rows): - await super().on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) for r in rows: self.received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index c122b8589c06..6d246232e0c7 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -44,7 +44,7 @@ def test_receipt(self): # there should be one RDATA command self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow @@ -74,7 +74,7 @@ def test_receipt(self): # We should now have caught up and get the missing data self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(token, 3) self.assertEqual(1, len(rdata_rows)) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index 4d354a9db847..397a23960661 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -50,7 +50,7 @@ def test_typing(self): self.assert_request_is_get_repl_stream_updates(request, "typing") self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: TypingStream.TypingStreamRow @@ -77,7 +77,7 @@ def test_typing(self): self.assertEqual(int(request.args[b"from_token"][0]), token) self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0]