Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Handle local device list updates during partial join #13934

Merged
merged 2 commits into from
Sep 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13934.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Correctly handle sending local device list updates to remote servers during a partial join.
84 changes: 82 additions & 2 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,10 +762,90 @@ async def handle_room_un_partial_stated(self, room_id: str) -> None:
gone from partial to full state.
"""

# We defer to the device list updater implementation as we're on the
# right worker.
# We defer to the device list updater to handle pending remote device
# list updates.
await self.device_list_updater.handle_room_un_partial_stated(room_id)

# Replay local updates.
(
join_event_id,
device_lists_stream_id,
) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state(
room_id
)

# Get the local device list changes that have happened in the room since
# we started joining. If there are no updates there's nothing left to do.
changes = await self.store.get_device_list_changes_in_room(
room_id, device_lists_stream_id
)
local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)}
if not local_changes:
return

# Note: We have persisted the full state at this point, we just haven't
# cleared the `partial_room` flag.
join_state_ids = await self._state_storage.get_state_ids_for_event(
join_event_id, await_full_state=False
)
current_state_ids = await self.store.get_partial_current_state_ids(room_id)

# Now we need to work out all servers that might have been in the room
# at any point during our join.

# First we look for any membership states that have changed between the
# initial join and now...
all_keys = set(join_state_ids)
all_keys.update(current_state_ids)

potentially_changed_hosts = set()
for etype, state_key in all_keys:
if etype != EventTypes.Member:
continue

prev = join_state_ids.get((etype, state_key))
current = current_state_ids.get((etype, state_key))

if prev != current:
potentially_changed_hosts.add(get_domain_from_id(state_key))

# ... then we add all the hosts that are currently joined to the room...
current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id)
potentially_changed_hosts.update(current_hosts_in_room)

# ... and finally we remove any hosts that we were told about, as we
# will have sent device list updates to those hosts when they happened.
known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
room_id
)
potentially_changed_hosts.difference_update(known_hosts_at_join)

potentially_changed_hosts.discard(self.server_name)

if not potentially_changed_hosts:
# Nothing to do.
return

logger.info(
"Found %d changed hosts to send device list updates to",
len(potentially_changed_hosts),
)

for user_id, device_id in local_changes:
await self.store.add_device_list_outbound_pokes(
user_id=user_id,
device_id=device_id,
room_id=room_id,
stream_id=None,
hosts=potentially_changed_hosts,
context=None,
)

# Notify things that device lists need to be sent out.
self.notifier.notify_replication()
for host in potentially_changed_hosts:
self.federation_sender.send_device_messages(host, immediate=False)


def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
Expand Down
55 changes: 42 additions & 13 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,33 @@ def _get_device_list_changes_in_rooms_txn(

return changes

async def get_device_list_changes_in_room(
self, room_id: str, min_stream_id: int
) -> Collection[Tuple[str, str]]:
"""Get all device list changes that happened in the room since the given
stream ID.

Returns:
Collection of user ID/device ID tuples of all devices that have
changed
"""

sql = """
SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room
WHERE room_id = ? AND stream_id > ?
"""

def get_device_list_changes_in_room_txn(
txn: LoggingTransaction,
) -> Collection[Tuple[str, str]]:
txn.execute(sql, (room_id, min_stream_id))
return cast(Collection[Tuple[str, str]], txn.fetchall())

return await self.db_pool.runInteraction(
"get_device_list_changes_in_room",
get_device_list_changes_in_room_txn,
)


class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
Expand Down Expand Up @@ -1946,14 +1973,15 @@ async def add_device_list_outbound_pokes(
user_id: str,
device_id: str,
room_id: str,
stream_id: int,
stream_id: Optional[int],
hosts: Collection[str],
context: Optional[Dict[str, str]],
) -> None:
"""Queue the device update to be sent to the given set of hosts,
calculated from the room ID.

Marks the associated row in `device_lists_changes_in_room` as handled.
Marks the associated row in `device_lists_changes_in_room` as handled,
if `stream_id` is provided.
"""

def add_device_list_outbound_pokes_txn(
Expand All @@ -1969,17 +1997,18 @@ def add_device_list_outbound_pokes_txn(
context=context,
)

self.db_pool.simple_update_txn(
txn,
table="device_lists_changes_in_room",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"stream_id": stream_id,
"room_id": room_id,
},
updatevalues={"converted_to_destinations": True},
)
if stream_id:
self.db_pool.simple_update_txn(
txn,
table="device_lists_changes_in_room",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"stream_id": stream_id,
"room_id": room_id,
},
updatevalues={"converted_to_destinations": True},
)

if not hosts:
# If there are no hosts then we don't try and generate stream IDs.
Expand Down
16 changes: 16 additions & 0 deletions synapse/storage/databases/main/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,22 @@ async def is_partial_state_room(self, room_id: str) -> bool:

return entry is not None

async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
self, room_id: str
) -> Tuple[str, int]:
"""Get the event ID of the initial join that started the partial
join, and the device list stream ID at the point we started the partial
join.
"""

result = await self.db_pool.simple_select_one(
table="partial_state_rooms",
keyvalues={"room_id": room_id},
retcols=("join_event_id", "device_lists_stream_id"),
desc="get_join_event_id_for_partial_state",
)
return result["join_event_id"], result["device_lists_stream_id"]


class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
Expand Down