Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Pass instance name through to rdata
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Apr 30, 2020
1 parent 91d30dc commit 6c4292d
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 19 deletions.
10 changes: 4 additions & 6 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,11 +650,9 @@ def __init__(self, hs):
else:
self.send_handler = None

async def on_rdata(self, stream_name, token, rows):
await super(GenericWorkerReplicationHandler, self).on_rdata(
stream_name, token, rows
)
await self.process_and_notify(stream_name, token, rows)
async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, instance_name, token, rows)
await self.process_and_notify(stream_name, instance_name, token, rows)

def get_streams_to_replicate(self):
args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
Expand All @@ -663,7 +661,7 @@ def get_streams_to_replicate(self):
args.update(self.send_handler.stream_positions())
return args

async def process_and_notify(self, stream_name, token, rows):
async def process_and_notify(self, stream_name, instance_name, token, rows):
try:
if self.send_handler:
await self.send_handler.process_replication_rows(
Expand Down
4 changes: 3 additions & 1 deletion synapse/replication/http/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
def __init__(self, hs):
super().__init__(hs)

self._instance_name = hs.get_instance_name()

# We pull the streams from the replication steamer (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_replication_streamer().get_streams()
Expand All @@ -67,7 +69,7 @@ async def _handle_request(self, request, stream_name):
upto_token = parse_integer(request, "upto_token", required=True)

updates, upto_token, limited = await stream.get_updates_since(
from_token, upto_token
self._instance_name, from_token, upto_token
)

return (
Expand Down
4 changes: 3 additions & 1 deletion synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ class ReplicationDataHandler:
def __init__(self, store: BaseSlavedStore):
self.store = store

async def on_rdata(self, stream_name: str, token: int, rows: list):
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
Expand Down
19 changes: 14 additions & 5 deletions synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,11 @@ async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
await self.on_rdata(stream_name, cmd.token, rows)
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)

async def on_rdata(self, stream_name: str, token: int, rows: list):
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token.
Args:
Expand All @@ -290,7 +292,9 @@ async def on_rdata(self, stream_name: str, token: int, rows: list):
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
await self._replication_data_handler.on_rdata(stream_name, token, rows)
await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)

async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
Expand Down Expand Up @@ -333,7 +337,9 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
updates,
current_token,
missing_updates,
) = await stream.get_updates_since(current_token, cmd.token)
) = await stream.get_updates_since(
cmd.instance_name, current_token, cmd.token
)

# TODO: add some tests for this

Expand All @@ -342,7 +348,10 @@ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):

for token, rows in _batch_updates(updates):
await self.on_rdata(
cmd.stream_name, token, [stream.parse_row(row) for row in rows],
cmd.stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
)

# We've now caught up to position sent to us, notify handler.
Expand Down
4 changes: 2 additions & 2 deletions tests/replication/tcp/streams/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ def __init__(self, store: BaseSlavedStore):
def get_streams_to_replicate(self):
return self.stream_positions

async def on_rdata(self, stream_name, token, rows):
await super().on_rdata(stream_name, token, rows)
async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, instance_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))

Expand Down
4 changes: 2 additions & 2 deletions tests/replication/tcp/streams/test_receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_receipt(self):

# there should be one RDATA command
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_receipt(self):

# We should now have caught up and get the missing data
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(token, 3)
self.assertEqual(1, len(rdata_rows))
Expand Down
4 changes: 2 additions & 2 deletions tests/replication/tcp/streams/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_typing(self):
self.assert_request_is_get_repl_stream_updates(request, "typing")

self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
Expand All @@ -77,7 +77,7 @@ def test_typing(self):
self.assertEqual(int(request.args[b"from_token"][0]), token)

self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
Expand Down

0 comments on commit 6c4292d

Please sign in to comment.