Skip to content

Commit

Permalink
Reduce device lists replication traffic. (#17333)
Browse files Browse the repository at this point in the history
Reduce the replication traffic of device lists, by not sending every
destination that needs to be sent the device list update over
replication. Instead a "hosts to send to have been calculated"
notification over replication, and then federation senders read the
destinations from the DB.

For non federation senders this should heavily reduce the impact of a
user in many large rooms changing a device.
  • Loading branch information
erikjohnston authored Jun 24, 2024
1 parent 700d2cc commit cf711ac
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 48 deletions.
1 change: 1 addition & 0 deletions changelog.d/17333.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Handle device lists notifications for large accounts more efficiently in worker mode.
19 changes: 12 additions & 7 deletions synapse/replication/tcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,19 @@ async def on_rdata(
"""
all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME:
if any(row.entity.startswith("@") and not row.is_signature for row in rows):
if any(not row.is_signature and not row.hosts_calculated for row in rows):
prev_token = self.store.get_device_stream_token()
all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token
)
self.store.device_lists_in_rooms_have_changed(all_room_ids, token)

# If we're sending federation we need to update the device lists
# outbound pokes stream change cache with updated hosts.
if self.send_handler and any(row.hosts_calculated for row in rows):
hosts = await self.store.get_destinations_for_device(token)
self.store.device_lists_outbound_pokes_have_changed(hosts, token)

self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
Expand Down Expand Up @@ -433,12 +439,11 @@ async def process_replication_rows(
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
hosts = {
row.entity
for row in rows
if not row.entity.startswith("@") and not row.is_signature
}
await self.federation_sender.send_device_messages(hosts, immediate=False)
if any(row.hosts_calculated for row in rows):
hosts = await self.store.get_destinations_for_device(token)
await self.federation_sender.send_device_messages(
hosts, immediate=False
)

elif stream_name == ToDeviceStream.NAME:
# The to_device stream includes stuff to be pushed to both local
Expand Down
12 changes: 8 additions & 4 deletions synapse/replication/tcp/streams/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,10 +549,14 @@ class DeviceListsStream(_StreamFromIdGen):

@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow:
entity: str
user_id: str
# Indicates that a user has signed their own device with their user-signing key
is_signature: bool

# Indicates if this is a notification that we've calculated the hosts we
# need to send the update to.
hosts_calculated: bool

NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow

Expand Down Expand Up @@ -594,13 +598,13 @@ async def _update_function(
upper_limit_token = min(upper_limit_token, signatures_to_token)

device_updates = [
(stream_id, (entity, False))
for stream_id, (entity,) in device_updates
(stream_id, (entity, False, hosts))
for stream_id, (entity, hosts) in device_updates
if stream_id <= upper_limit_token
]

signatures_updates = [
(stream_id, (entity, True))
(stream_id, (entity, True, False))
for stream_id, (entity,) in signatures_updates
if stream_id <= upper_limit_token
]
Expand Down
93 changes: 58 additions & 35 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,22 +164,24 @@ def __init__(
prefilled_cache=user_signature_stream_prefill,
)

(
device_list_federation_prefill,
device_list_federation_list_id,
) = self.db_pool.get_cache_dict(
db_conn,
"device_lists_outbound_pokes",
entity_column="destination",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache",
device_list_federation_list_id,
prefilled_cache=device_list_federation_prefill,
)
self._device_list_federation_stream_cache = None
if hs.should_send_federation():
(
device_list_federation_prefill,
device_list_federation_list_id,
) = self.db_pool.get_cache_dict(
db_conn,
"device_lists_outbound_pokes",
entity_column="destination",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache",
device_list_federation_list_id,
prefilled_cache=device_list_federation_prefill,
)

if hs.config.worker.run_background_tasks:
self._clock.looping_call(
Expand Down Expand Up @@ -207,23 +209,30 @@ def _invalidate_caches_for_devices(
) -> None:
for row in rows:
if row.is_signature:
self._user_signature_stream_cache.entity_has_changed(row.entity, token)
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
continue

# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
if row.entity.startswith("@"):
self._device_list_stream_cache.entity_has_changed(row.entity, token)
self.get_cached_devices_for_user.invalidate((row.entity,))
self._get_cached_user_device.invalidate((row.entity,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))

else:
self._device_list_federation_stream_cache.entity_has_changed(
row.entity, token
if not row.hosts_calculated:
self._device_list_stream_cache.entity_has_changed(row.user_id, token)
self.get_cached_devices_for_user.invalidate((row.user_id,))
self._get_cached_user_device.invalidate((row.user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate(
(row.user_id,)
)

def device_lists_outbound_pokes_have_changed(
self, destinations: StrCollection, token: int
) -> None:
assert self._device_list_federation_stream_cache is not None

for destination in destinations:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)

def device_lists_in_rooms_have_changed(
self, room_ids: StrCollection, token: int
) -> None:
Expand Down Expand Up @@ -363,6 +372,11 @@ async def get_device_updates_by_remote(
EDU contents.
"""
now_stream_id = self.get_device_stream_token()
if from_stream_id == now_stream_id:
return now_stream_id, []

if self._device_list_federation_stream_cache is None:
raise Exception("Func can only be used on federation senders")

has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
Expand Down Expand Up @@ -1018,10 +1032,10 @@ def _get_all_device_list_changes_for_remotes(
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
SELECT stream_id, entity FROM (
SELECT stream_id, user_id AS entity FROM device_lists_stream
SELECT stream_id, user_id, hosts FROM (
SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
UNION ALL
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
Expand Down Expand Up @@ -1577,6 +1591,14 @@ def get_device_list_changes_in_room_txn(
get_device_list_changes_in_room_txn,
)

async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
return await self.db_pool.simple_select_onecol(
table="device_lists_outbound_pokes",
keyvalues={"stream_id": stream_id},
retcol="destination",
desc="get_destinations_for_device",
)


class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
Expand Down Expand Up @@ -2112,12 +2134,13 @@ def _add_device_outbound_poke_to_stream_txn(
stream_ids: List[int],
context: Optional[Dict[str, str]],
) -> None:
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_ids[-1],
)
if self._device_list_federation_stream_cache:
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_ids[-1],
)

now = self._clock.time_msec()
stream_id_iterator = iter(stream_ids)
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def process_replication_rows(
if stream_name == DeviceListsStream.NAME:
for row in rows:
assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
if row.entity.startswith("@"):
if not row.hosts_calculated:
self._get_e2e_device_keys_for_federation_query_inner.invalidate(
(row.entity,)
(row.user_id,)
)

super().process_replication_rows(stream_name, instance_name, token, rows)
Expand Down
8 changes: 8 additions & 0 deletions tests/storage/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ class DeviceStoreTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main

def default_config(self) -> JsonDict:
config = super().default_config()

# We 'enable' federation otherwise `get_device_updates_by_remote` will
# throw an exception.
config["federation_sender_instances"] = ["master"]
return config

def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
"""Add a device list change for the given device to
`device_lists_outbound_pokes` table.
Expand Down

0 comments on commit cf711ac

Please sign in to comment.