From 120e6217e560d1165e37c63a7f6cd0084c000a08 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Apr 2023 14:09:05 -0400 Subject: [PATCH 1/8] Add support for claiming multiple OTKs at once. --- synapse/appservice/api.py | 26 +++++-- synapse/federation/federation_server.py | 2 +- synapse/handlers/appservice.py | 14 ++-- synapse/handlers/e2e_keys.py | 28 ++++--- synapse/rest/client/keys.py | 36 +++++++-- .../storage/databases/main/end_to_end_keys.py | 77 +++++++++++-------- tests/appservice/test_api.py | 11 ++- tests/handlers/test_e2e_keys.py | 32 ++++---- 8 files changed, 143 insertions(+), 83 deletions(-) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 86ddb1bb289e..6d7f2792dd0d 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -442,8 +442,10 @@ async def push_bulk( return False async def claim_client_keys( - self, service: "ApplicationService", query: List[Tuple[str, str, str]] - ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: + self, service: "ApplicationService", query: List[Tuple[str, str, str, int]] + ) -> Tuple[ + Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] + ]: """Claim one time keys from an application service. Note that any error (including a timeout) is treated as the application @@ -469,7 +471,8 @@ async def claim_client_keys( # Create the expected payload shape. body: Dict[str, Dict[str, List[str]]] = {} - for user_id, device, algorithm in query: + for user_id, device, algorithm, _count in query: + # Note that only a single OTK can be claimed this way. body.setdefault(user_id, {}).setdefault(device, []).append(algorithm) uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim" @@ -493,11 +496,18 @@ async def claim_client_keys( # or if some are still missing. # # TODO This places a lot of faith in the response shape being correct. 
- missing = [ - (user_id, device, algorithm) - for user_id, device, algorithm in query - if algorithm not in response.get(user_id, {}).get(device, []) - ] + missing = [] + for user_id, device, algorithm, count in query: + # The number of keys responded for this algorithm. + response_count = sum( + key_id.startswith(f"{algorithm}:") + for key_id in response.get(user_id, {}).get(device, {}) + ) + count -= response_count + # If the appservice responds with fewer keys than requested, then + # consider the request unfulfilled. + if count > 0: + missing.append((user_id, device, algorithm, count)) return response, missing diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index c618f3d7a6cd..8cd0ab50d2ed 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1010,7 +1010,7 @@ async def on_claim_client_keys( query = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): - query.append((user_id, device_id, algorithm)) + query.append((user_id, device_id, algorithm, 1)) log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) results = await self._e2e_keys_handler.claim_local_one_time_keys( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 4ca2bc04203b..6429545c98d5 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -841,8 +841,10 @@ async def _check_user_exists(self, user_id: str) -> bool: return True async def claim_e2e_one_time_keys( - self, query: Iterable[Tuple[str, str, str]] - ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: + self, query: Iterable[Tuple[str, str, str, int]] + ) -> Tuple[ + Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] + ]: """Claim one time keys from application services. 
Users which are exclusively owned by an application service are sent a @@ -863,18 +865,18 @@ async def claim_e2e_one_time_keys( services = self.store.get_app_services() # Partition the users by appservice. - query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {} + query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {} missing = [] - for user_id, device, algorithm in query: + for user_id, device, algorithm, count in query: if not self.store.get_if_app_services_interested_in_user(user_id): - missing.append((user_id, device, algorithm)) + missing.append((user_id, device, algorithm, count)) continue # Find the associated appservice. for service in services: if service.is_exclusive_user(user_id): query_by_appservice.setdefault(service.id, []).append( - (user_id, device, algorithm) + (user_id, device, algorithm, count) ) continue diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d1ab95126c0b..58de53a51318 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -564,7 +564,7 @@ async def on_federation_query_client_keys( async def claim_local_one_time_keys( self, - local_query: List[Tuple[str, str, str]], + local_query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool, ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]: """Claim one time keys for local users. @@ -607,7 +607,7 @@ async def claim_local_one_time_keys( # from the appservice for that user ID / device ID. If it is found, # check if any of the keys match the requested algorithm & are a # fallback key. - for user_id, device_id, algorithm in local_query: + for user_id, device_id, algorithm, _count in local_query: # Check if the appservice responded for this query. 
as_result = appservice_results.get(user_id, {}).get(device_id, {}) found_otk = False @@ -630,13 +630,17 @@ async def claim_local_one_time_keys( .get(device_id, {}) .keys() ) + # Note that it doesn't make sense to request more than 1 fallback key + # per (user_id, device_id, algorithm). fallback_query.append((user_id, device_id, algorithm, mark_as_used)) else: # All fallback keys get marked as used. fallback_query = [ + # Note that it doesn't make sense to request more than 1 fallback key + # per (user_id, device_id, algorithm). (user_id, device_id, algorithm, True) - for user_id, device_id, algorithm in not_found + for user_id, device_id, algorithm, count in not_found ] # For each user that does not have a one-time keys available, see if @@ -650,21 +654,27 @@ async def claim_local_one_time_keys( @trace async def claim_one_time_keys( self, - query: Dict[str, Dict[str, Dict[str, str]]], + query: Dict[str, Dict[str, Dict[str, int]]], timeout: Optional[int], always_include_fallback_keys: bool, ) -> JsonDict: - local_query: List[Tuple[str, str, str]] = [] + local_query: List[Tuple[str, str, str, int]] = [] remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} - for user_id, one_time_keys in query.get("one_time_keys", {}).items(): + for user_id, one_time_keys in query.items(): # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): - for device_id, algorithm in one_time_keys.items(): - local_query.append((user_id, device_id, algorithm)) + for device_id, algorithms in one_time_keys.items(): + for algorithm, count in algorithms.items(): + local_query.append((user_id, device_id, algorithm, count)) else: domain = get_domain_from_id(user_id) - remote_queries.setdefault(domain, {})[user_id] = one_time_keys + # TODO Support passing the count to remote destinations. 
+ for device_id, algorithms in one_time_keys.items(): + if algorithms: + remote_queries.setdefault(domain, {})[user_id] = { + device_id: next(iter(algorithms)) + } set_tag("local_key_query", str(local_query)) set_tag("remote_key_query", str(remote_queries)) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 2a2509410961..cceffde7dbdc 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -16,7 +16,7 @@ import logging import re -from typing import TYPE_CHECKING, Any, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from synapse.api.errors import InvalidAPICallError, SynapseError from synapse.http.server import HttpServer @@ -289,16 +289,41 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) + + # Map the legacy request to the new request format. + query: Dict[str, Dict[str, Dict[str, int]]] = {} + for user_id, one_time_keys in body.get("one_time_keys", {}).items(): + for device_id, algorithm in one_time_keys.items(): + query.setdefault(user_id, {})[device_id] = {algorithm: 1} + result = await self.e2e_keys_handler.claim_one_time_keys( - body, timeout, always_include_fallback_keys=False + query, timeout, always_include_fallback_keys=False ) return 200, result class UnstableOneTimeKeyServlet(RestServlet): """ - Identical to the stable endpoint (OneTimeKeyServlet) except it always includes - fallback keys in the response. + Identical to the stable endpoint (OneTimeKeyServlet) except it allows for + querying for multiple OTKs at once and always includes fallback keys in the + response. 
+ + POST /keys/claim HTTP/1.1 + { + "one_time_keys": { + "": { + "": { + "": + } } } } + + HTTP/1.1 200 OK + { + "one_time_keys": { + "": { + "": { + ":": "" + } } } } + """ PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")] @@ -313,8 +338,9 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) + query = body.get("one_time_keys", {}) result = await self.e2e_keys_handler.claim_one_time_keys( - body, timeout, always_include_fallback_keys=True + query, timeout, always_include_fallback_keys=True ) return 200, result diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1a4ae55304ba..4bc391f21316 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1027,8 +1027,10 @@ def get_device_stream_token(self) -> int: ... async def claim_e2e_one_time_keys( - self, query_list: Iterable[Tuple[str, str, str]] - ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: + self, query_list: Iterable[Tuple[str, str, str, int]] + ) -> Tuple[ + Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] + ]: """Take a list of one time keys out of the database. Args: @@ -1043,8 +1045,12 @@ async def claim_e2e_one_time_keys( @trace def _claim_e2e_one_time_key_simple( - txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str - ) -> Optional[Tuple[str, str]]: + txn: LoggingTransaction, + user_id: str, + device_id: str, + algorithm: str, + count: int, + ) -> List[Tuple[str, str]]: """Claim OTK for device for DBs that don't support RETURNING. Returns: @@ -1055,36 +1061,41 @@ def _claim_e2e_one_time_key_simple( sql = """ SELECT key_id, key_json FROM e2e_one_time_keys_json WHERE user_id = ? 
AND device_id = ? AND algorithm = ? - LIMIT 1 + LIMIT ? """ - txn.execute(sql, (user_id, device_id, algorithm)) - otk_row = txn.fetchone() - if otk_row is None: - return None + txn.execute(sql, (user_id, device_id, algorithm, count)) + otk_rows = list(txn) + if not otk_rows: + return [] - key_id, key_json = otk_row - - self.db_pool.simple_delete_one_txn( + self.db_pool.simple_delete_many_txn( txn, table="e2e_one_time_keys_json", + column="key_id", + values=[otk_row[0] for otk_row in otk_rows], keyvalues={ "user_id": user_id, "device_id": device_id, "algorithm": algorithm, - "key_id": key_id, }, ) self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - return f"{algorithm}:{key_id}", key_json + return [ + (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows + ] @trace def _claim_e2e_one_time_key_returning( - txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str - ) -> Optional[Tuple[str, str]]: + txn: LoggingTransaction, + user_id: str, + device_id: str, + algorithm: str, + count: int, + ) -> List[Tuple[str, str]]: """Claim OTK for device for DBs that support RETURNING. Returns: @@ -1099,28 +1110,30 @@ def _claim_e2e_one_time_key_returning( AND key_id IN ( SELECT key_id FROM e2e_one_time_keys_json WHERE user_id = ? AND device_id = ? AND algorithm = ? - LIMIT 1 + LIMIT ? 
) RETURNING key_id, key_json """ txn.execute( - sql, (user_id, device_id, algorithm, user_id, device_id, algorithm) + sql, + (user_id, device_id, algorithm, user_id, device_id, algorithm, count), ) - otk_row = txn.fetchone() - if otk_row is None: - return None + otk_rows = list(txn) + if not otk_rows: + return [] self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - key_id, key_json = otk_row - return f"{algorithm}:{key_id}", key_json + return [ + (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows + ] results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} - missing: List[Tuple[str, str, str]] = [] - for user_id, device_id, algorithm in query_list: + missing: List[Tuple[str, str, str, int]] = [] + for user_id, device_id, algorithm, count in query_list: if self.database_engine.supports_returning: # If we support RETURNING clause we can use a single query that # allows us to use autocommit mode. @@ -1130,21 +1143,25 @@ def _claim_e2e_one_time_key_returning( _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple db_autocommit = False - claim_row = await self.db_pool.runInteraction( + claim_rows = await self.db_pool.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_key, user_id, device_id, algorithm, + count, db_autocommit=db_autocommit, ) - if claim_row: + if claim_rows: device_results = results.setdefault(user_id, {}).setdefault( device_id, {} ) - device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) - else: - missing.append((user_id, device_id, algorithm)) + for claim_row in claim_rows: + device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) + # Did we get enough OTKs? 
+ count -= len(claim_rows) + if count: + missing.append((user_id, device_id, algorithm, count)) return results, missing diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 7deb923a280d..15fce165b611 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -195,11 +195,11 @@ async def post_json_get_json( MISSING_KEYS = [ # Known user, known device, missing algorithm. - ("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"), + ("@alice:example.org", "DEVICE_2", "xyz", 1), # Known user, missing device. - ("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"), + ("@alice:example.org", "DEVICE_3", "signed_curve25519", 1), # Unknown user. - ("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"), + ("@bob:example.org", "DEVICE_4", "signed_curve25519", 1), ] claimed_keys, missing = self.get_success( @@ -207,9 +207,8 @@ async def post_json_get_json( self.service, [ # Found devices - ("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"), - ("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"), - ("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"), + ("@alice:example.org", "DEVICE_1", "signed_curve25519", 1), + ("@alice:example.org", "DEVICE_2", "signed_curve25519", 1), ] + MISSING_KEYS, ) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 18edebd652fc..72d05840613e 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -160,7 +160,7 @@ def test_claim_one_time_key(self) -> None: res2 = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, + {local_user: {device_id: {"alg1": 1}}}, timeout=None, always_include_fallback_keys=False, ) @@ -205,7 +205,7 @@ def test_fallback_key(self) -> None: # key claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, + {local_user: {device_id: {"alg1": 1}}}, 
timeout=None, always_include_fallback_keys=False, ) @@ -224,7 +224,7 @@ def test_fallback_key(self) -> None: # claiming an OTK again should return the same fallback key claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, + {local_user: {device_id: {"alg1": 1}}}, timeout=None, always_include_fallback_keys=False, ) @@ -273,7 +273,7 @@ def test_fallback_key(self) -> None: claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, + {local_user: {device_id: {"alg1": 1}}}, timeout=None, always_include_fallback_keys=False, ) @@ -285,7 +285,7 @@ def test_fallback_key(self) -> None: claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, + {local_user: {device_id: {"alg1": 1}}}, timeout=None, always_include_fallback_keys=False, ) @@ -306,7 +306,7 @@ def test_fallback_key(self) -> None: claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, + {local_user: {device_id: {"alg1": 1}}}, timeout=None, always_include_fallback_keys=False, ) @@ -347,7 +347,7 @@ def test_fallback_key_always_returned(self) -> None: # return both. claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, + {local_user: {device_id: {"alg1": 1}}}, timeout=None, always_include_fallback_keys=True, ) @@ -369,7 +369,7 @@ def test_fallback_key_always_returned(self) -> None: # Claiming an OTK again should return only the fallback key. claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id: "alg1"}}}, + {local_user: {device_id: {"alg1": 1}}}, timeout=None, always_include_fallback_keys=True, ) @@ -1052,7 +1052,7 @@ def test_query_appservice(self) -> None: # Setup a response, but only for device 2. 
self.appservice_api.claim_client_keys.return_value = make_awaitable( - ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")]) + ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)]) ) # we shouldn't have any unused fallback keys yet @@ -1079,11 +1079,7 @@ def test_query_appservice(self) -> None: # query the fallback keys. claim_res = self.get_success( self.handler.claim_one_time_keys( - { - "one_time_keys": { - local_user: {device_id_1: "alg1", device_id_2: "alg1"} - } - }, + {local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}}, timeout=None, always_include_fallback_keys=False, ) @@ -1128,7 +1124,7 @@ def test_query_appservice_with_fallback(self) -> None: # Claim OTKs, which will ask the appservice and do nothing else. claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, + {local_user: {device_id_1: {"alg1": 1}}}, timeout=None, always_include_fallback_keys=True, ) @@ -1172,7 +1168,7 @@ def test_query_appservice_with_fallback(self) -> None: # uploaded fallback key. claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, + {local_user: {device_id_1: {"alg1": 1}}}, timeout=None, always_include_fallback_keys=True, ) @@ -1205,7 +1201,7 @@ def test_query_appservice_with_fallback(self) -> None: # Claim OTKs, which will return information only from the database. claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, + {local_user: {device_id_1: {"alg1": 1}}}, timeout=None, always_include_fallback_keys=True, ) @@ -1232,7 +1228,7 @@ def test_query_appservice_with_fallback(self) -> None: # Claim OTKs, which will return only the fallback key from the database. 
claim_res = self.get_success( self.handler.claim_one_time_keys( - {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, + {local_user: {device_id_1: {"alg1": 1}}}, timeout=None, always_include_fallback_keys=True, ) From 0d9593fe9185593f40b9dcee6e61b307c295881a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Apr 2023 14:25:33 -0400 Subject: [PATCH 2/8] Limit the number of local queues that will be dispersed. --- synapse/handlers/e2e_keys.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 58de53a51318..092f1910b3b6 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -581,6 +581,12 @@ async def claim_local_one_time_keys( An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. """ + # Cap the number of OTKs that can be claimed at once to avoid abuse. + local_query = [ + (user_id, device_id, algorithm, min(count, 5)) + for user_id, device_id, algorithm, count in local_query + ] + otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query) # If the application services have not provided any keys via the C-S From 384c9b77c400276ce4a6e82d8a0e9900bf169da1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Apr 2023 15:13:11 -0400 Subject: [PATCH 3/8] Attempt calling out over federation for multiple OTK claims. 
--- synapse/federation/federation_client.py | 40 ++++++++++++++++++++-- synapse/federation/transport/client.py | 45 ++++++++++++++++++++++++- synapse/handlers/e2e_keys.py | 11 ++---- 3 files changed, 85 insertions(+), 11 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index ba34573d466d..c3e03aec8d96 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -235,7 +235,10 @@ async def query_user_devices( ) async def claim_client_keys( - self, destination: str, content: JsonDict, timeout: Optional[int] + self, + destination: str, + content: Dict[str, Dict[str, Dict[str, int]]], + timeout: Optional[int], ) -> JsonDict: """Claims one-time keys for a device hosted on a remote server. @@ -247,8 +250,41 @@ async def claim_client_keys( The JSON object from the response """ sent_queries_counter.labels("client_one_time_keys").inc() + + # Convert the query with counts into a legacy query and check if attempting + # to claim more than 1 OTK. + legacy_content: Dict[str, Dict[str, str]] = {} + use_unstable = False + for user_id, one_time_keys in content.items(): + for device_id, algorithms in one_time_keys.items(): + if any(count > 1 for count in algorithms.values()): + use_unstable = True + if algorithms: + # Choose the first algorithm only. + legacy_content.setdefault(user_id, {})[device_id] = next( + iter(algorithms) + ) + + if use_unstable: + try: + return await self.transport_layer.claim_client_keys_unstable( + destination, content, timeout + ) + except HttpResponseException as e: + # If an error is received that is due to an unrecognised endpoint, + # fallback to the v1 endpoint. Otherwise, consider it a legitimate error + # and raise. 
+ if not is_unknown_endpoint(e): + raise + + logger.debug( + "Couldn't claim client keys with the unstable API, falling back to the v1 API" + ) + else: + logger.debug("Skipping unstable claim client keys API") + return await self.transport_layer.claim_client_keys( - destination, content, timeout + destination, legacy_content, timeout ) @trace diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index bedbd23dedee..26df867c7065 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -669,7 +669,50 @@ async def claim_client_keys( path = _create_v1_path("/user/keys/claim") return await self.client.post_json( - destination=destination, path=path, data=query_content, timeout=timeout + destination=destination, + path=path, + data={"one_time_keys": query_content}, + timeout=timeout, + ) + + async def claim_client_keys_unstable( + self, destination: str, query_content: JsonDict, timeout: Optional[int] + ) -> JsonDict: + """Claim one-time keys for a list of devices hosted on a remote server. + + Request: + { + "one_time_keys": { + "": { + "": {"": } + } + } + } + + Response: + { + "device_keys": { + "": { + "": { + ":": "" + } + } + } + } + + Args: + destination: The server to query. + query_content: The user ids to query. + Returns: + A dict containing the one-time keys. 
+ """ + path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim") + + return await self.client.post_json( + destination=destination, + path=path, + data={"one_time_keys": query_content}, + timeout=timeout, ) async def get_missing_events( diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 092f1910b3b6..24741b667bb9 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -665,7 +665,7 @@ async def claim_one_time_keys( always_include_fallback_keys: bool, ) -> JsonDict: local_query: List[Tuple[str, str, str, int]] = [] - remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} + remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {} for user_id, one_time_keys in query.items(): # we use UserID.from_string to catch invalid user ids @@ -675,12 +675,7 @@ async def claim_one_time_keys( local_query.append((user_id, device_id, algorithm, count)) else: domain = get_domain_from_id(user_id) - # TODO Support passing the count to remote destinations. - for device_id, algorithms in one_time_keys.items(): - if algorithms: - remote_queries.setdefault(domain, {})[user_id] = { - device_id: next(iter(algorithms)) - } + remote_queries.setdefault(domain, {})[user_id] = one_time_keys set_tag("local_key_query", str(local_query)) set_tag("remote_key_query", str(remote_queries)) @@ -708,7 +703,7 @@ async def claim_client_keys(destination: str) -> None: device_keys = remote_queries[destination] try: remote_result = await self.federation.claim_client_keys( - destination, {"one_time_keys": device_keys}, timeout=timeout + destination, device_keys, timeout=timeout ) for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: From abedabd98348bc1c91dbebcbdc1196fc7127cdb0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Apr 2023 15:23:59 -0400 Subject: [PATCH 4/8] Add a federation endpoint to handle multiple key requests. 
--- synapse/federation/federation_server.py | 7 +----- .../federation/transport/server/federation.py | 23 +++++++++++++++---- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8cd0ab50d2ed..ca43c7bfc0d1 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1005,13 +1005,8 @@ async def on_query_user_devices( @trace async def on_claim_client_keys( - self, origin: str, content: JsonDict, always_include_fallback_keys: bool + self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool ) -> Dict[str, Any]: - query = [] - for user_id, device_keys in content.get("one_time_keys", {}).items(): - for device_id, algorithm in device_keys.items(): - query.append((user_id, device_id, algorithm, 1)) - log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) results = await self._e2e_keys_handler.claim_local_one_time_keys( query, always_include_fallback_keys=always_include_fallback_keys diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index e2340d70d509..5bf0629b7fa0 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -577,16 +577,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet): async def on_POST( self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] ) -> Tuple[int, JsonDict]: + # Flatten the request query. 
+ key_query: List[Tuple[str, str, str, int]] = [] + for user_id, device_keys in content.get("one_time_keys", {}).items(): + for device_id, algorithm in device_keys.items(): + key_query.append((user_id, device_id, algorithm, 1)) + response = await self.handler.on_claim_client_keys( - origin, content, always_include_fallback_keys=False + key_query, always_include_fallback_keys=False ) return 200, response class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet): """ - Identical to the stable endpoint (FederationClientKeysClaimServlet) except it - always includes fallback keys in the response. + Identical to the stable endpoint (FederationClientKeysClaimServlet) except + it allows for querying for multiple OTKs at once and always includes fallback + keys in the response. """ PREFIX = FEDERATION_UNSTABLE_PREFIX @@ -596,8 +603,15 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet): async def on_POST( self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] ) -> Tuple[int, JsonDict]: + # Flatten the request query. 
+ key_query: List[Tuple[str, str, str, int]] = [] + for user_id, device_keys in content.get("one_time_keys", {}).items(): + for device_id, algorithms in device_keys.items(): + for algorithm, count in algorithms.items(): + key_query.append((user_id, device_id, algorithm, count)) + response = await self.handler.on_claim_client_keys( - origin, content, always_include_fallback_keys=True + key_query, always_include_fallback_keys=True ) return 200, response @@ -805,6 +819,7 @@ async def on_POST( FederationClientKeysQueryServlet, FederationUserDevicesQueryServlet, FederationClientKeysClaimServlet, + FederationUnstableClientKeysClaimServlet, FederationThirdPartyInviteExchangeServlet, On3pidBindServlet, FederationVersionServlet, From 2d6b903a66bd80a97ece73754a595bda28aba513 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Apr 2023 15:26:49 -0400 Subject: [PATCH 5/8] Newsfragment --- changelog.d/15468.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/15468.misc diff --git a/changelog.d/15468.misc b/changelog.d/15468.misc new file mode 100644 index 000000000000..e0a94f36fdf4 --- /dev/null +++ b/changelog.d/15468.misc @@ -0,0 +1 @@ +Support claiming more than one OTK at a time. From bb9081eeb72c8d0c9b9cf5e9aadd99795ebca5d1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 Apr 2023 13:42:44 -0400 Subject: [PATCH 6/8] Use a flat list of algorithms instead of a map. --- synapse/appservice/api.py | 7 +++-- synapse/federation/federation_client.py | 29 ++++++++++++------- .../federation/transport/server/federation.py | 8 +++-- synapse/rest/client/keys.py | 16 ++++++---- 4 files changed, 39 insertions(+), 21 deletions(-) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6d7f2792dd0d..34e9796b8c6e 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -471,9 +471,10 @@ async def claim_client_keys( # Create the expected payload shape. 
body: Dict[str, Dict[str, List[str]]] = {} - for user_id, device, algorithm, _count in query: - # Note that only a single OTK can be claimed this way. - body.setdefault(user_id, {}).setdefault(device, []).append(algorithm) + for user_id, device, algorithm, count in query: + body.setdefault(user_id, {}).setdefault(device, []).extend( + [algorithm] * count + ) uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim" try: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c3e03aec8d96..dee8e957a4a4 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -237,7 +237,7 @@ async def query_user_devices( async def claim_client_keys( self, destination: str, - content: Dict[str, Dict[str, Dict[str, int]]], + query: Dict[str, Dict[str, Dict[str, int]]], timeout: Optional[int], ) -> JsonDict: """Claims one-time keys for a device hosted on a remote server. @@ -251,24 +251,33 @@ async def claim_client_keys( """ sent_queries_counter.labels("client_one_time_keys").inc() - # Convert the query with counts into a legacy query and check if attempting - # to claim more than 1 OTK. - legacy_content: Dict[str, Dict[str, str]] = {} + # Convert the query with counts into a stable and unstable query and check + # if attempting to claim more than 1 OTK. + content: Dict[str, Dict[str, str]] = {} + unstable_content: Dict[str, Dict[str, List[str]]] = {} use_unstable = False - for user_id, one_time_keys in content.items(): + for user_id, one_time_keys in query.items(): for device_id, algorithms in one_time_keys.items(): if any(count > 1 for count in algorithms.values()): use_unstable = True if algorithms: - # Choose the first algorithm only. - legacy_content.setdefault(user_id, {})[device_id] = next( - iter(algorithms) + # Choose the first algorithm only for the stable query. 
+ content.setdefault(user_id, {})[device_id] = next(iter(algorithms)) + # Flatten the map of algorithm -> count to a list repeating + # each algorithm count times for the unstable query. + unstable_content.setdefault(user_id, {})[device_id] = list( + itertools.chain( + *( + itertools.repeat(algorithm, count) + for algorithm, count in algorithms.items() + ) + ) ) if use_unstable: try: return await self.transport_layer.claim_client_keys_unstable( - destination, content, timeout + destination, unstable_content, timeout ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, @@ -284,7 +293,7 @@ async def claim_client_keys( logger.debug("Skipping unstable claim client keys API") return await self.transport_layer.claim_client_keys( - destination, legacy_content, timeout + destination, content, timeout ) @trace diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 5bf0629b7fa0..36b0362504f5 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from collections import Counter from typing import ( TYPE_CHECKING, Dict, @@ -577,7 +578,7 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet): async def on_POST( self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] ) -> Tuple[int, JsonDict]: - # Flatten the request query. + # Generate a count for each algorithm, which is hard-coded to 1. 
key_query: List[Tuple[str, str, str, int]] = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): @@ -603,11 +604,12 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet): async def on_POST( self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] ) -> Tuple[int, JsonDict]: - # Flatten the request query. + # Generate a count for each algorithm. key_query: List[Tuple[str, str, str, int]] = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithms in device_keys.items(): - for algorithm, count in algorithms.items(): + counts = Counter(algorithms) + for algorithm, count in counts.items(): key_query.append((user_id, device_id, algorithm, count)) response = await self.handler.on_claim_client_keys( diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index cceffde7dbdc..9bbab5e6241e 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -16,6 +16,7 @@ import logging import re +from collections import Counter from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from synapse.api.errors import InvalidAPICallError, SynapseError @@ -290,7 +291,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - # Map the legacy request to the new request format. + # Generate a count for each algorithm, which is hard-coded to 1. query: Dict[str, Dict[str, Dict[str, int]]] = {} for user_id, one_time_keys in body.get("one_time_keys", {}).items(): for device_id, algorithm in one_time_keys.items(): @@ -312,9 +313,8 @@ class UnstableOneTimeKeyServlet(RestServlet): { "one_time_keys": { "<user_id>": { - "<device_id>": { - "<algorithm>": <count> - } } } } + "<device_id>": ["<algorithm>", ...]
+ } } } HTTP/1.1 200 OK { @@ -338,7 +338,13 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - query = body.get("one_time_keys", {}) + + # Generate a count for each algorithm. + query: Dict[str, Dict[str, Dict[str, int]]] = {} + for user_id, one_time_keys in body.get("one_time_keys", {}).items(): + for device_id, algorithms in one_time_keys.items(): + query.setdefault(user_id, {})[device_id] = Counter(algorithms) + result = await self.e2e_keys_handler.claim_one_time_keys( query, timeout, always_include_fallback_keys=True ) From 72c35ca557b5d58d44b53b18c97854901fb1414e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 Apr 2023 10:10:42 -0400 Subject: [PATCH 7/8] Fix incorrect comments. --- synapse/federation/transport/client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 26df867c7065..bc70b94f6820 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -650,10 +650,10 @@ async def claim_client_keys( Response: { - "device_keys": { + "one_time_keys": { "<user_id>": { "<device_id>": { - "<algorithm>:<key_id>": "<key_base64>" + "<algorithm>:<key_id>": <OTK JSON> } } } @@ -691,10 +691,10 @@ async def claim_client_keys_unstable( Response: { - "device_keys": { + "one_time_keys": { "<user_id>": { "<device_id>": { - "<algorithm>:<key_id>": "<key_base64>" + "<algorithm>:<key_id>": <OTK JSON> } } } From da9db7beba4705c26d1d3e32e974e3047020c7ac Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Apr 2023 11:46:28 -0400 Subject: [PATCH 8/8] Clarify comments. 
--- synapse/appservice/api.py | 4 +++- synapse/federation/federation_client.py | 8 +++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 34e9796b8c6e..024098e9cbb0 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -499,7 +499,9 @@ async def claim_client_keys( # TODO This places a lot of faith in the response shape being correct. missing = [] for user_id, device, algorithm, count in query: - # The number of keys responded for this algorithm. + # Count the number of keys in the response for this algorithm by + # checking which key IDs start with the algorithm. This uses that + # True == 1 in Python to generate a count. response_count = sum( key_id.startswith(f"{algorithm}:") for key_id in response.get(user_id, {}).get(device, {}) ) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index dee8e957a4a4..0b2d1a78f7b5 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -261,10 +261,12 @@ async def claim_client_keys( if any(count > 1 for count in algorithms.values()): use_unstable = True if algorithms: - # Choose the first algorithm only for the stable query. + # For the stable query, choose only the first algorithm. content.setdefault(user_id, {})[device_id] = next(iter(algorithms)) - # Flatten the map of algorithm -> count to a list repeating - # each algorithm count times for the unstable query. + # For the unstable query, repeat each algorithm by count, then + # splat those into chain to get a flattened list of all algorithms. + # + # Converts from {"algo1": 2, "algo2": 1} to ["algo1", "algo1", "algo2"]. unstable_content.setdefault(user_id, {})[device_id] = list( itertools.chain( *(