Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Thread through instance name to replication client
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Apr 30, 2020
1 parent 37f6823 commit 91d30dc
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 18 deletions.
8 changes: 7 additions & 1 deletion synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class ReplicationEndpoint(object):
must call `register` to register the path with the HTTP server.
Requests can be sent by calling the client returned by `make_client`.
Requests are sent to master process by default, but can be sent to other
named processes by specifying an `instance_name` keyword argument.
Attributes:
NAME (str): A name for the endpoint, added to the path as well as used
Expand Down Expand Up @@ -135,7 +137,11 @@ def make_client(cls, hs):

@trace(opname="outgoing_replication_request")
@defer.inlineCallbacks
def send_request(**kwargs):
def send_request(instance_name="master", **kwargs):
# Currently we only support sending requests to master process.
if instance_name != "master":
raise Exception("Unknown instance")

data = yield cls._serialize_payload(**kwargs)

url_args = [
Expand Down
50 changes: 37 additions & 13 deletions synapse/replication/tcp/streams/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import logging
from collections import namedtuple
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
from typing import Any, Awaitable, Callable, List, Optional, Tuple

import attr

Expand Down Expand Up @@ -53,6 +53,7 @@
#
# The arguments are:
#
# * instance_name: the writer of the stream
# * from_token: the previous stream token: the starting point for fetching the
# updates
# * to_token: the new stream token: the point to get updates up to
Expand All @@ -62,7 +63,7 @@
# If there are more updates available, it should set `limited` in the result, and
# it will be called again to get the next batch.
#
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]


class Stream(object):
Expand Down Expand Up @@ -93,6 +94,7 @@ def parse_row(cls, row: StreamRow):

def __init__(
self,
local_instance_name: str,
current_token_function: Callable[[], Token],
update_function: UpdateFunction,
):
Expand All @@ -108,9 +110,11 @@ def __init__(
stream tokens. See the UpdateFunction type definition for more info.
Args:
local_instance_name: The instance name of the current process
current_token_function: callback to get the current token, as above
update_function: callback go get stream updates, as above
"""
self.local_instance_name = local_instance_name
self.current_token = current_token_function
self.update_function = update_function

Expand All @@ -135,14 +139,14 @@ async def get_updates(self) -> StreamUpdateResult:
"""
current_token = self.current_token()
updates, current_token, limited = await self.get_updates_since(
self.last_token, current_token
self.local_instance_name, self.last_token, current_token
)
self.last_token = current_token

return updates, current_token, limited

async def get_updates_since(
self, from_token: Token, upto_token: Token
self, instance_name: str, from_token: Token, upto_token: Token
) -> StreamUpdateResult:
"""Like get_updates except allows specifying from when we should
stream updates
Expand All @@ -160,19 +164,19 @@ async def get_updates_since(
return [], upto_token, False

updates, upto_token, limited = await self.update_function(
from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
)
return updates, upto_token, limited


def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""

async def update_function(from_token, upto_token, limit):
async def update_function(instance_name, from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
Expand All @@ -194,10 +198,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
client = ReplicationGetStreamUpdates.make_client(hs)

async def update_function(
from_token: int, upto_token: int, limit: int
instance_name: str, from_token: int, upto_token: int, limit: int
) -> StreamUpdateResult:
result = await client(
stream_name=stream_name, from_token=from_token, upto_token=upto_token,
instance_name=instance_name,
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,
)
return result["updates"], result["upto_token"], result["limited"]

Expand Down Expand Up @@ -227,6 +234,7 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_backfill_token,
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
Expand Down Expand Up @@ -262,7 +270,9 @@ def __init__(self, hs):
# Query master process
update_function = make_http_update_function(hs, self.NAME)

super().__init__(store.get_current_presence_token, update_function)
super().__init__(
hs.get_instance_name(), store.get_current_presence_token, update_function
)


class TypingStream(Stream):
Expand All @@ -285,7 +295,9 @@ def __init__(self, hs):
# Query master process
update_function = make_http_update_function(hs, self.NAME)

super().__init__(typing_handler.get_current_token, update_function)
super().__init__(
hs.get_instance_name(), typing_handler.get_current_token, update_function
)


class ReceiptsStream(Stream):
Expand All @@ -306,6 +318,7 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_receipt_stream_id,
db_query_to_update_function(store.get_all_updated_receipts),
)
Expand All @@ -323,14 +336,16 @@ class PushRulesStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super(PushRulesStream, self).__init__(
self._current_token, self._update_function
hs.get_instance_name(), self._current_token, self._update_function
)

def _current_token(self) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token

async def _update_function(self, from_token: Token, to_token: Token, limit: int):
async def _update_function(
self, instance_name: str, from_token: Token, to_token: Token, limit: int
):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)

limited = False
Expand All @@ -357,6 +372,7 @@ def __init__(self, hs):
store = hs.get_datastore()

super().__init__(
hs.get_instance_name(),
store.get_pushers_stream_token,
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
Expand Down Expand Up @@ -388,6 +404,7 @@ class CachesStreamRow:
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches),
)
Expand All @@ -413,6 +430,7 @@ class PublicRoomsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_public_room_stream_id,
db_query_to_update_function(store.get_all_new_public_rooms),
)
Expand All @@ -433,6 +451,7 @@ class DeviceListsStreamRow:
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
Expand All @@ -450,6 +469,7 @@ class ToDeviceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_to_device_stream_token,
db_query_to_update_function(store.get_all_new_device_messages),
)
Expand All @@ -469,6 +489,7 @@ class TagAccountDataStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_account_data_stream_id,
db_query_to_update_function(store.get_all_updated_tags),
)
Expand All @@ -488,6 +509,7 @@ class AccountDataStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self.store.get_max_account_data_stream_id,
db_query_to_update_function(self._update_function),
)
Expand Down Expand Up @@ -518,6 +540,7 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_group_stream_token,
db_query_to_update_function(store.get_all_groups_changes),
)
Expand All @@ -535,6 +558,7 @@ class UserSignatureStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
Expand Down
10 changes: 8 additions & 2 deletions synapse/replication/tcp/streams/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,17 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
super().__init__(
self._store.get_current_events_token, self._update_function,
hs.get_instance_name(),
self._store.get_current_events_token,
self._update_function,
)

async def _update_function(
self, from_token: Token, current_token: Token, target_row_count: int
self,
instance_name: str,
from_token: Token,
current_token: Token,
target_row_count: int,
) -> StreamUpdateResult:

# the events stream merges together three separate sources:
Expand Down
4 changes: 2 additions & 2 deletions synapse/replication/tcp/streams/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __init__(self, hs):
current_token = lambda: 0
update_function = self._stub_update_function

super().__init__(current_token, update_function)
super().__init__(hs.get_instance_name(), current_token, update_function)

@staticmethod
async def _stub_update_function(from_token, upto_token, limit):
async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False

0 comments on commit 91d30dc

Please sign in to comment.