From 88ba6f9fb08654c7e5428b70e55c599dd95802e9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 16 Aug 2023 15:13:24 +0100 Subject: [PATCH 1/9] Deduplicate query desc --- synapse/storage/databases/main/keys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index cea32a034a4d..3847acd08dd6 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -295,5 +295,5 @@ def _get_server_keys_json_txn( return results return await self.db_pool.runInteraction( - "get_server_keys_json", _get_server_keys_json_txn + "get_server_keys_json_for_remote", _get_server_keys_json_txn ) From a127d927905759ddc8939cacd3700e3c8b3d6292 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 16 Aug 2023 15:13:31 +0100 Subject: [PATCH 2/9] Remove unused param --- synapse/rest/key/v2/remote_key_resource.py | 4 ++-- synapse/storage/databases/main/keys.py | 23 ++++++++++------------ tests/crypto/test_keyring.py | 18 ++++++++--------- 3 files changed, 21 insertions(+), 24 deletions(-) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 8f3865d41233..a9140894ed33 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -162,7 +162,7 @@ async def query_keys( if not key_ids: key_ids = (None,) for key_id in key_ids: - store_queries.append((server_name, key_id, None)) + store_queries.append((server_name, key_id)) cached = await self.store.get_server_keys_json_for_remote(store_queries) @@ -173,7 +173,7 @@ async def query_keys( # Map server_name->key_id->int. Note that the value of the int is unused. # XXX: why don't we just use a set? cache_misses: Dict[str, Dict[str, int]] = {} - for (server_name, key_id, _), key_results in cached.items(): + for (server_name, key_id), key_results in cached.items(): results = [(result["ts_added_ms"], result) for result in key_results] if key_id is None: diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 3847acd08dd6..2130da0c51c2 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -254,31 +254,28 @@ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: return await self.db_pool.runInteraction("get_server_keys_json", _txn) async def get_server_keys_json_for_remote( - self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] - ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: + self, server_keys: Iterable[Tuple[str, Optional[str]]] + ) -> Dict[Tuple[str, Optional[str]], List[Dict[str, Any]]]: """Retrieve the key json for a list of server_keys and key ids. - If no keys are found for a given server, key_id and source then - that server, key_id, and source triplet entry will be an empty list. - The JSON is returned as a byte array so that it can be efficiently - used in an HTTP response. + If no keys are found for a given server and key_id then that server and + key_id tuple entry will be an empty list. The JSON is returned as a byte + array so that it can be efficiently used in an HTTP response. Args: - server_keys: List of (server_name, key_id, source) triplets. + server_keys: List of (server_name, key_id) tuples. Returns: - A mapping from (server_name, key_id, source) triplets to a list of dicts + A mapping from (server_name, key_id) tuples to a list of dicts """ def _get_server_keys_json_txn( txn: LoggingTransaction, - ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]: + ) -> Dict[Tuple[str, Optional[str]], List[Dict[str, Any]]]: results = {} - for server_name, key_id, from_server in server_keys: + for server_name, key_id in server_keys: keyvalues = {"server_name": server_name} if key_id is not None: keyvalues["key_id"] = key_id - if from_server is not None: - keyvalues["from_server"] = from_server rows = self.db_pool.simple_select_list_txn( txn, "server_keys_json", @@ -291,7 +288,7 @@ def _get_server_keys_json_txn( "key_json", ), ) - results[(server_name, key_id, from_server)] = rows + results[(server_name, key_id)] = rows return results return await self.db_pool.runInteraction( diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index fdfd4f911d6b..16ebbe9abdb7 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -456,13 +456,13 @@ async def get_json(destination: str, path: str, **kwargs: Any) -> JsonDict: self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated - lookup_triplet = (SERVER_NAME, testverifykey_id, None) + lookup_tuple = (SERVER_NAME, testverifykey_id) key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json_for_remote( - [lookup_triplet] + [lookup_tuple] ) ) - res_keys = key_json[lookup_triplet] + res_keys = key_json[lookup_tuple] self.assertEqual(len(res_keys), 1) res = res_keys[0] self.assertEqual(res["key_id"], testverifykey_id) @@ -576,13 +576,13 @@ def test_get_keys_from_perspectives(self) -> None: self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated - lookup_triplet = (SERVER_NAME, testverifykey_id, None) + lookup_tuple = (SERVER_NAME, testverifykey_id) key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json_for_remote( - [lookup_triplet] + [lookup_tuple] ) ) - res_keys = key_json[lookup_triplet] + res_keys = key_json[lookup_tuple] self.assertEqual(len(res_keys), 1) res = res_keys[0] self.assertEqual(res["key_id"], testverifykey_id) @@ -699,13 +699,13 @@ def test_get_perspectives_own_key(self) -> None: self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated - lookup_triplet = (SERVER_NAME, testverifykey_id, None) + lookup_tuple = (SERVER_NAME, testverifykey_id) key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json_for_remote( - [lookup_triplet] + [lookup_tuple] ) ) - res_keys = key_json[lookup_triplet] + res_keys = key_json[lookup_tuple] self.assertEqual(len(res_keys), 1) res = res_keys[0] self.assertEqual(res["key_id"], testverifykey_id) From cb1568bbb97d7ab0feaf24ad7346e974faf4a366 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 16 Aug 2023 15:29:41 +0100 Subject: [PATCH 3/9] Refactor to return a proper type --- synapse/rest/key/v2/remote_key_resource.py | 11 ++++---- synapse/storage/databases/main/keys.py | 19 ++++++++++---- synapse/storage/keys.py | 7 +++++ tests/crypto/test_keyring.py | 30 +++++++--------------- 4 files changed, 35 insertions(+), 32 deletions(-) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index a9140894ed33..2a8ac77a7dff 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -174,14 +174,13 @@ async def query_keys( # XXX: why don't we just use a set? cache_misses: Dict[str, Dict[str, int]] = {} for (server_name, key_id), key_results in cached.items(): - results = [(result["ts_added_ms"], result) for result in key_results] + results = [(result.added_ts, result) for result in key_results] if key_id is None: # all keys were requested. Just return what we have without worrying # about validity for _, result in results: - # Cast to bytes since postgresql returns a memoryview. - json_results.add(bytes(result["key_json"])) + json_results.add(result.key_json) continue miss = False @@ -189,7 +188,7 @@ async def query_keys( miss = True else: ts_added_ms, most_recent_result = max(results) - ts_valid_until_ms = most_recent_result["ts_valid_until_ms"] + ts_valid_until_ms = most_recent_result.valid_until_ts req_key = query.get(server_name, {}).get(key_id, {}) req_valid_until = req_key.get("minimum_valid_until_ts") if req_valid_until is not None: @@ -235,8 +234,8 @@ async def query_keys( ts_valid_until_ms, time_now_ms, ) - # Cast to bytes since postgresql returns a memoryview. - json_results.add(bytes(most_recent_result["key_json"])) + + json_results.add(most_recent_result.key_json) if miss and query_remote_on_cache_miss: # only bother attempting to fetch keys from servers on our whitelist diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 2130da0c51c2..3a1db4f18a9b 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,14 +16,14 @@ import itertools import json import logging -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import Dict, Iterable, List, Mapping, Optional, Tuple from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 from synapse.storage._base import SQLBaseStore from synapse.storage.database import LoggingTransaction -from synapse.storage.keys import FetchKeyResult +from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -255,7 +255,7 @@ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: async def get_server_keys_json_for_remote( self, server_keys: Iterable[Tuple[str, Optional[str]]] - ) -> Dict[Tuple[str, Optional[str]], List[Dict[str, Any]]]: + ) -> Dict[Tuple[str, Optional[str]], List[FetchKeyResultForRemote]]: """Retrieve the key json for a list of server_keys and key ids. If no keys are found for a given server and key_id then that server and key_id tuple entry will be an empty list. The JSON is returned as a byte @@ -270,12 +270,13 @@ async def get_server_keys_json_for_remote( def _get_server_keys_json_txn( txn: LoggingTransaction, - ) -> Dict[Tuple[str, Optional[str]], List[Dict[str, Any]]]: + ) -> Dict[Tuple[str, Optional[str]], List[FetchKeyResultForRemote]]: results = {} for server_name, key_id in server_keys: keyvalues = {"server_name": server_name} if key_id is not None: keyvalues["key_id"] = key_id + rows = self.db_pool.simple_select_list_txn( txn, "server_keys_json", @@ -288,7 +289,15 @@ def _get_server_keys_json_txn( "key_json", ), ) - results[(server_name, key_id)] = rows + results[(server_name, key_id)] = [ + FetchKeyResultForRemote( + # Cast to bytes since postgresql returns a memoryview. + key_json=bytes(row["key_json"]), + valid_until_ts=row["ts_valid_until_ms"], + added_ts=row["ts_added_ms"], + ) + for row in rows + ] return results return await self.db_pool.runInteraction( diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 71584f3f744b..24f16d20145c 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -25,3 +25,10 @@ class FetchKeyResult: verify_key: VerifyKey # the key itself valid_until_ts: int # how long we can use this key for + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class FetchKeyResultForRemote: + key_json: bytes # the full key JSON + valid_until_ts: int # how long we can use this key for + added_ts: int # When we added this key diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 16ebbe9abdb7..33fb565e0f93 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -465,15 +465,11 @@ async def get_json(destination: str, path: str, **kwargs: Any) -> JsonDict: res_keys = key_json[lookup_tuple] self.assertEqual(len(res_keys), 1) res = res_keys[0] - self.assertEqual(res["key_id"], testverifykey_id) - self.assertEqual(res["from_server"], SERVER_NAME) - self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) - self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) + self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) + self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) # we expect it to be encoded as canonical json *before* it hits the db - self.assertEqual( - bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) - ) + self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response)) # change the server name: the result should be ignored response["server_name"] = "OTHER_SERVER" @@ -585,14 +581,10 @@ def test_get_keys_from_perspectives(self) -> None: res_keys = key_json[lookup_tuple] self.assertEqual(len(res_keys), 1) res = res_keys[0] - self.assertEqual(res["key_id"], testverifykey_id) - self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) - self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) - self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) + self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) + self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) - self.assertEqual( - bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) - ) + self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response)) def test_get_multiple_keys_from_perspectives(self) -> None: """Check that we can correctly request multiple keys for the same server""" @@ -708,14 +700,10 @@ def test_get_perspectives_own_key(self) -> None: res_keys = key_json[lookup_tuple] self.assertEqual(len(res_keys), 1) res = res_keys[0] - self.assertEqual(res["key_id"], testverifykey_id) - self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) - self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) - self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) + self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) + self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) - self.assertEqual( - bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) - ) + self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response)) def test_invalid_perspectives_responses(self) -> None: """Check that invalid responses from the perspectives server are rejected""" From 907c80354288cb46183b500a2f410a30c87fd37f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 16 Aug 2023 16:16:11 +0100 Subject: [PATCH 4/9] Return only the most recent key for each key ID --- synapse/rest/key/v2/remote_key_resource.py | 18 +++++----- synapse/storage/databases/main/keys.py | 39 ++++++++++++---------- tests/crypto/test_keyring.py | 18 +++++----- 3 files changed, 38 insertions(+), 37 deletions(-) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 2a8ac77a7dff..c034e02ca396 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -173,22 +173,20 @@ async def query_keys( # Map server_name->key_id->int. Note that the value of the int is unused. # XXX: why don't we just use a set? cache_misses: Dict[str, Dict[str, int]] = {} - for (server_name, key_id), key_results in cached.items(): - results = [(result.added_ts, result) for result in key_results] - - if key_id is None: + for (server_name, key_id), key_result in cached.items(): + if not query[server_name]: # all keys were requested. Just return what we have without worrying # about validity - for _, result in results: - json_results.add(result.key_json) + if key_result: + json_results.add(key_result.key_json) continue miss = False - if not results: + if key_result is None: miss = True else: - ts_added_ms, most_recent_result = max(results) - ts_valid_until_ms = most_recent_result.valid_until_ts + ts_added_ms = key_result.added_ts + ts_valid_until_ms = key_result.valid_until_ts req_key = query.get(server_name, {}).get(key_id, {}) req_valid_until = req_key.get("minimum_valid_until_ts") if req_valid_until is not None: @@ -235,7 +233,7 @@ async def query_keys( time_now_ms, ) - json_results.add(most_recent_result.key_json) + json_results.add(key_result.key_json) if miss and query_remote_on_cache_miss: # only bother attempting to fetch keys from servers on our whitelist diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 3a1db4f18a9b..a7adcd46e4ec 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,7 +16,7 @@ import itertools import json import logging -from typing import Dict, Iterable, List, Mapping, Optional, Tuple +from typing import Dict, Iterable, Mapping, Optional, Tuple from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 @@ -255,23 +255,22 @@ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: async def get_server_keys_json_for_remote( self, server_keys: Iterable[Tuple[str, Optional[str]]] - ) -> Dict[Tuple[str, Optional[str]], List[FetchKeyResultForRemote]]: + ) -> Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]]: """Retrieve the key json for a list of server_keys and key ids. - If no keys are found for a given server and key_id then that server and - key_id tuple entry will be an empty list. The JSON is returned as a byte - array so that it can be efficiently used in an HTTP response. Args: - server_keys: List of (server_name, key_id) tuples. + server_keys: List of (server_name, key_id) tuples. If `key_id` is + None, returns all keys that are in the DB. Returns: - A mapping from (server_name, key_id) tuples to a list of dicts + A mapping from (server_name, key_id) tuples to the key data, or + None if we don't have that key_id stored """ def _get_server_keys_json_txn( txn: LoggingTransaction, - ) -> Dict[Tuple[str, Optional[str]], List[FetchKeyResultForRemote]]: - results = {} + ) -> Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]]: + results: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {} for server_name, key_id in server_keys: keyvalues = {"server_name": server_name} if key_id is not None: @@ -289,15 +288,19 @@ def _get_server_keys_json_txn( "key_json", ), ) - results[(server_name, key_id)] = [ - FetchKeyResultForRemote( - # Cast to bytes since postgresql returns a memoryview. - key_json=bytes(row["key_json"]), - valid_until_ts=row["ts_valid_until_ms"], - added_ts=row["ts_added_ms"], - ) - for row in rows - ] + + if not rows: + continue + + row = max(rows, key=lambda r: r["ts_added_ms"]) + + results[(server_name, row["key_id"])] = FetchKeyResultForRemote( + # Cast to bytes since postgresql returns a memoryview. + key_json=bytes(row["key_json"]), + valid_until_ts=row["ts_valid_until_ms"], + added_ts=row["ts_added_ms"], + ) + return results return await self.db_pool.runInteraction( diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 33fb565e0f93..9088c157ca60 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -462,9 +462,9 @@ async def get_json(destination: str, path: str, **kwargs: Any) -> JsonDict: [lookup_tuple] ) ) - res_keys = key_json[lookup_tuple] - self.assertEqual(len(res_keys), 1) - res = res_keys[0] + res = key_json[lookup_tuple] + self.assertIsNotNone(res) + assert res is not None self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) @@ -578,9 +578,9 @@ def test_get_keys_from_perspectives(self) -> None: [lookup_tuple] ) ) - res_keys = key_json[lookup_tuple] - self.assertEqual(len(res_keys), 1) - res = res_keys[0] + res = key_json[lookup_tuple] + self.assertIsNotNone(res) + assert res is not None self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) @@ -697,9 +697,9 @@ def test_get_perspectives_own_key(self) -> None: [lookup_tuple] ) ) - res_keys = key_json[lookup_tuple] - self.assertEqual(len(res_keys), 1) - res = res_keys[0] + res = key_json[lookup_tuple] + self.assertIsNotNone(res) + assert res is not None self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) From b38d8d672647a0174f6d5f4eaaf5b314a2fdee2f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 16 Aug 2023 16:07:44 +0100 Subject: [PATCH 5/9] Split up the functions This isn't as inefficient as it sounds, as we only ever really get a single server in each request generally. --- synapse/rest/key/v2/remote_key_resource.py | 23 +++-- synapse/storage/databases/main/keys.py | 108 ++++++++++++--------- tests/crypto/test_keyring.py | 17 ++-- 3 files changed, 86 insertions(+), 62 deletions(-) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index c034e02ca396..e223c35e94cb 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -14,7 +14,7 @@ import logging import re -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple from signedjson.sign import sign_json @@ -27,6 +27,7 @@ parse_integer, parse_json_object_from_request, ) +from synapse.storage.keys import FetchKeyResultForRemote from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import yieldable_gather_results @@ -157,14 +158,22 @@ async def query_keys( ) -> JsonDict: logger.info("Handling query for keys %r", query) - store_queries = [] + cached: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {} for server_name, key_ids in query.items(): - if not key_ids: - key_ids = (None,) - for key_id in key_ids: - store_queries.append((server_name, key_id)) + if key_ids: + results: Mapping[ + str, Optional[FetchKeyResultForRemote] + ] = await self.store.get_server_keys_json_for_remote( + server_name, key_ids + ) + else: + results = await self.store.get_all_server_keys_json_for_remote( + server_name + ) - cached = await self.store.get_server_keys_json_for_remote(store_queries) + cached.update( + ((server_name, key_id), res) for key_id, res in results.items() + ) json_results: Set[bytes] = set() diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index a7adcd46e4ec..7b0305bffaf8 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -22,7 +22,6 @@ from unpaddedbase64 import decode_base64 from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList @@ -254,55 +253,74 @@ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: return await self.db_pool.runInteraction("get_server_keys_json", _txn) async def get_server_keys_json_for_remote( - self, server_keys: Iterable[Tuple[str, Optional[str]]] - ) -> Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]]: - """Retrieve the key json for a list of server_keys and key ids. + self, server_name: str, key_ids: Iterable[str] + ) -> Dict[str, Optional[FetchKeyResultForRemote]]: + """Fetch the cached keys for the given server/key IDs. - Args: - server_keys: List of (server_name, key_id) tuples. If `key_id` is - None, returns all keys that are in the DB. - - Returns: - A mapping from (server_name, key_id) tuples to the key data, or - None if we don't have that key_id stored + If we have multiple entries for a given key ID, returns the most recent. """ + rows = await self.db_pool.simple_select_many_batch( + table="server_keys_json", + column="key_id", + iterable=key_ids, + keyvalues={"server_name": server_name}, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + desc="get_server_keys_json_for_remote", + ) - def _get_server_keys_json_txn( - txn: LoggingTransaction, - ) -> Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]]: - results: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {} - for server_name, key_id in server_keys: - keyvalues = {"server_name": server_name} - if key_id is not None: - keyvalues["key_id"] = key_id - - rows = self.db_pool.simple_select_list_txn( - txn, - "server_keys_json", - keyvalues=keyvalues, - retcols=( - "key_id", - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "key_json", - ), - ) - - if not rows: - continue + if not rows: + return {} - row = max(rows, key=lambda r: r["ts_added_ms"]) + rows.sort(key=lambda r: r["ts_added_ms"]) - results[(server_name, row["key_id"])] = FetchKeyResultForRemote( - # Cast to bytes since postgresql returns a memoryview. - key_json=bytes(row["key_json"]), - valid_until_ts=row["ts_valid_until_ms"], - added_ts=row["ts_added_ms"], - ) + return { + row["key_id"]: FetchKeyResultForRemote( + # Cast to bytes since postgresql returns a memoryview. + key_json=bytes(row["key_json"]), + valid_until_ts=row["ts_valid_until_ms"], + added_ts=row["ts_added_ms"], + ) + for row in rows + } - return results + async def get_all_server_keys_json_for_remote( + self, + server_name: str, + ) -> Dict[str, FetchKeyResultForRemote]: + """Fetch the cached keys for the given server. - return await self.db_pool.runInteraction( - "get_server_keys_json_for_remote", _get_server_keys_json_txn + If we have multiple entries for a given key ID, returns the most recent. + """ + rows = await self.db_pool.simple_select_list( + table="server_keys_json", + keyvalues={"server_name": server_name}, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + desc="get_server_keys_json_for_remote", ) + + if not rows: + return {} + + rows.sort(key=lambda r: r["ts_added_ms"]) + + return { + row["key_id"]: FetchKeyResultForRemote( + # Cast to bytes since postgresql returns a memoryview. + key_json=bytes(row["key_json"]), + valid_until_ts=row["ts_valid_until_ms"], + added_ts=row["ts_added_ms"], + ) + for row in rows + } diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 9088c157ca60..b7bbd40deea2 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -455,14 +455,13 @@ async def get_json(destination: str, path: str, **kwargs: Any) -> JsonDict: self.assertEqual(k.verify_key.alg, "ed25519") self.assertEqual(k.verify_key.version, "ver1") - # check that the perspectives store is correctly updated - lookup_tuple = (SERVER_NAME, testverifykey_id) + # check that the perspectives store is correctly updated= key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json_for_remote( - [lookup_tuple] + SERVER_NAME, [testverifykey_id] ) ) - res = key_json[lookup_tuple] + res = key_json[testverifykey_id] self.assertIsNotNone(res) assert res is not None self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) @@ -572,13 +571,12 @@ def test_get_keys_from_perspectives(self) -> None: self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated - lookup_tuple = (SERVER_NAME, testverifykey_id) key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json_for_remote( - [lookup_tuple] + SERVER_NAME, [testverifykey_id] ) ) - res = key_json[lookup_tuple] + res = key_json[testverifykey_id] self.assertIsNotNone(res) assert res is not None self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) @@ -691,13 +689,12 @@ def test_get_perspectives_own_key(self) -> None: self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated - lookup_tuple = (SERVER_NAME, testverifykey_id) key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json_for_remote( - [lookup_tuple] + SERVER_NAME, [testverifykey_id] ) ) - res = key_json[lookup_tuple] + res = key_json[testverifykey_id] self.assertIsNotNone(res) assert res is not None self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) From 7d69047d631b6254d487d0f82302280b4c43a3ff Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 16 Aug 2023 16:12:10 +0100 Subject: [PATCH 6/9] Add caching for fetching remote server keys --- synapse/storage/databases/main/keys.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 7b0305bffaf8..f6d8f103a551 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -21,7 +21,7 @@ from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 -from synapse.storage._base import SQLBaseStore +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList @@ -33,7 +33,7 @@ db_binary_type = memoryview -class KeyStore(SQLBaseStore): +class KeyStore(CacheInvalidationWorkerStore): """Persistence for signature verification keys""" @cached() @@ -187,7 +187,12 @@ async def store_server_keys_json( # invalidate takes a tuple corresponding to the params of # _get_server_keys_json. _get_server_keys_json only takes one # param, which is itself the 2-tuple (server_name, key_id). - self._get_server_keys_json.invalidate(((server_name, key_id),)) + await self.invalidate_cache_and_stream( + "_get_server_keys_json", ((server_name, key_id),) + ) + await self.invalidate_cache_and_stream( + "get_server_key_json_for_remote", (server_name, key_id) + ) @cached() def _get_server_keys_json( @@ -252,6 +257,17 @@ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: return await self.db_pool.runInteraction("get_server_keys_json", _txn) + @cached() + def get_server_key_json_for_remote( + self, + server_name: str, + key_id: str, + ) -> Optional[FetchKeyResultForRemote]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_server_key_json_for_remote", list_name="key_ids" + ) async def get_server_keys_json_for_remote( self, server_name: str, key_ids: Iterable[str] ) -> Dict[str, Optional[FetchKeyResultForRemote]]: From ca6b0bbbe62142c083bd4325663ac44fa1c065d9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 16 Aug 2023 16:15:03 +0100 Subject: [PATCH 7/9] Newsfile --- changelog.d/16123.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/16123.misc diff --git a/changelog.d/16123.misc b/changelog.d/16123.misc new file mode 100644 index 000000000000..b7c6b7c2f201 --- /dev/null +++ b/changelog.d/16123.misc @@ -0,0 +1 @@ +Add cache to `get_server_keys_json_for_remote`. From 07c1e2e1a0c5d2a48a9feebe5a03d4dfb2c90c50 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 18 Aug 2023 09:38:36 +0100 Subject: [PATCH 8/9] Update tests/crypto/test_keyring.py Co-authored-by: David Robertson --- tests/crypto/test_keyring.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index b7bbd40deea2..2be341ac7b84 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -455,7 +455,7 @@ async def get_json(destination: str, path: str, **kwargs: Any) -> JsonDict: self.assertEqual(k.verify_key.alg, "ed25519") self.assertEqual(k.verify_key.version, "ver1") - # check that the perspectives store is correctly updated= + # check that the perspectives store is correctly updated key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json_for_remote( SERVER_NAME, [testverifykey_id] From ed7b06365803b9f1ff9bf935df4e772d7c9c5771 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 18 Aug 2023 09:38:50 +0100 Subject: [PATCH 9/9] Review comments --- synapse/rest/key/v2/remote_key_resource.py | 6 +++--- synapse/storage/databases/main/keys.py | 1 + synapse/storage/keys.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index e223c35e94cb..981fd1f58a68 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -158,7 +158,7 @@ async def query_keys( ) -> JsonDict: logger.info("Handling query for keys %r", query) - cached: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {} + server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {} for server_name, key_ids in query.items(): if key_ids: results: Mapping[ @@ -171,7 +171,7 @@ async def query_keys( server_name ) - cached.update( + server_keys.update( ((server_name, key_id), res) for key_id, res in results.items() ) @@ -182,7 +182,7 @@ async def query_keys( # Map server_name->key_id->int. Note that the value of the int is unused. # XXX: why don't we just use a set? cache_misses: Dict[str, Dict[str, int]] = {} - for (server_name, key_id), key_result in cached.items(): + for (server_name, key_id), key_result in server_keys.items(): if not query[server_name]: # all keys were requested. Just return what we have without worrying # about validity diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index f6d8f103a551..a3b4744855d7 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -293,6 +293,7 @@ async def get_server_keys_json_for_remote( if not rows: return {} + # We sort the rows so that the most recently added entry is picked up. rows.sort(key=lambda r: r["ts_added_ms"]) return { diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 24f16d20145c..e74b2269d216 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -30,5 +30,5 @@ class FetchKeyResult: @attr.s(slots=True, frozen=True, auto_attribs=True) class FetchKeyResultForRemote: key_json: bytes # the full key JSON - valid_until_ts: int # how long we can use this key for - added_ts: int # When we added this key + valid_until_ts: int # how long we can use this key for, in milliseconds. + added_ts: int # When we added this key, in milliseconds.