Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add cache to get_server_keys_json_for_remote #16123

Merged
merged 9 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16123.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add cache to `get_server_keys_json_for_remote`.
44 changes: 25 additions & 19 deletions synapse/rest/key/v2/remote_key_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
import re
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple

from signedjson.sign import sign_json

Expand All @@ -27,6 +27,7 @@
parse_integer,
parse_json_object_from_request,
)
from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
Expand Down Expand Up @@ -157,14 +158,22 @@ async def query_keys(
) -> JsonDict:
logger.info("Handling query for keys %r", query)

store_queries = []
server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
for server_name, key_ids in query.items():
if not key_ids:
key_ids = (None,)
for key_id in key_ids:
store_queries.append((server_name, key_id, None))
if key_ids:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
results: Mapping[
str, Optional[FetchKeyResultForRemote]
] = await self.store.get_server_keys_json_for_remote(
server_name, key_ids
)
else:
results = await self.store.get_all_server_keys_json_for_remote(
server_name
)

cached = await self.store.get_server_keys_json_for_remote(store_queries)
server_keys.update(
((server_name, key_id), res) for key_id, res in results.items()
)

json_results: Set[bytes] = set()

Expand All @@ -173,23 +182,20 @@ async def query_keys(
# Map server_name->key_id->int. Note that the value of the int is unused.
# XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), key_results in cached.items():
results = [(result["ts_added_ms"], result) for result in key_results]

if key_id is None:
for (server_name, key_id), key_result in server_keys.items():
if not query[server_name]:
# all keys were requested. Just return what we have without worrying
# about validity
for _, result in results:
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
if key_result:
json_results.add(key_result.key_json)
continue

miss = False
if not results:
if key_result is None:
miss = True
else:
ts_added_ms, most_recent_result = max(results)
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
ts_added_ms = key_result.added_ts
ts_valid_until_ms = key_result.valid_until_ts
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
if req_valid_until is not None:
Expand Down Expand Up @@ -235,8 +241,8 @@ async def query_keys(
ts_valid_until_ms,
time_now_ms,
)
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))

json_results.add(key_result.key_json)

if miss and query_remote_on_cache_miss:
# only bother attempting to fetch keys from servers on our whitelist
Expand Down
132 changes: 88 additions & 44 deletions synapse/storage/databases/main/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
import itertools
import json
import logging
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import Dict, Iterable, Mapping, Optional, Tuple

from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.keys import FetchKeyResult
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
Expand All @@ -34,7 +33,7 @@
db_binary_type = memoryview


class KeyStore(SQLBaseStore):
class KeyStore(CacheInvalidationWorkerStore):
"""Persistence for signature verification keys"""

@cached()
Expand Down Expand Up @@ -188,7 +187,12 @@ async def store_server_keys_json(
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
self._get_server_keys_json.invalidate(((server_name, key_id),))
await self.invalidate_cache_and_stream(
"_get_server_keys_json", ((server_name, key_id),)
)
await self.invalidate_cache_and_stream(
"get_server_key_json_for_remote", (server_name, key_id)
)

@cached()
def _get_server_keys_json(
Expand Down Expand Up @@ -253,47 +257,87 @@ def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:

return await self.db_pool.runInteraction("get_server_keys_json", _txn)

async def get_server_keys_json_for_remote(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
"""Retrieve the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
The JSON is returned as a byte array so that it can be efficiently
used in an HTTP response.
@cached()
def get_server_key_json_for_remote(
self,
server_name: str,
key_id: str,
) -> Optional[FetchKeyResultForRemote]:
raise NotImplementedError()

Args:
server_keys: List of (server_name, key_id, source) triplets.
@cachedList(
cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
)
async def get_server_keys_json_for_remote(
self, server_name: str, key_ids: Iterable[str]
) -> Dict[str, Optional[FetchKeyResultForRemote]]:
"""Fetch the cached keys for the given server/key IDs.

Returns:
A mapping from (server_name, key_id, source) triplets to a list of dicts
If we have multiple entries for a given key ID, returns the most recent.
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"""
rows = await self.db_pool.simple_select_many_batch(
table="server_keys_json",
column="key_id",
iterable=key_ids,
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
desc="get_server_keys_json_for_remote",
)

def _get_server_keys_json_txn(
txn: LoggingTransaction,
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
results = {}
for server_name, key_id, from_server in server_keys:
keyvalues = {"server_name": server_name}
if key_id is not None:
keyvalues["key_id"] = key_id
if from_server is not None:
keyvalues["from_server"] = from_server
rows = self.db_pool.simple_select_list_txn(
txn,
"server_keys_json",
keyvalues=keyvalues,
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
)
results[(server_name, key_id, from_server)] = rows
return results
if not rows:
return {}

# We sort the rows so that the most recently added entry is picked up.
rows.sort(key=lambda r: r["ts_added_ms"])

return {
row["key_id"]: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]),
valid_until_ts=row["ts_valid_until_ms"],
added_ts=row["ts_added_ms"],
)
for row in rows
}

return await self.db_pool.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn
async def get_all_server_keys_json_for_remote(
self,
server_name: str,
) -> Dict[str, FetchKeyResultForRemote]:
"""Fetch the cached keys for the given server.

If we have multiple entries for a given key ID, returns the most recent.
"""
rows = await self.db_pool.simple_select_list(
table="server_keys_json",
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
desc="get_server_keys_json_for_remote",
)

if not rows:
return {}

rows.sort(key=lambda r: r["ts_added_ms"])

return {
row["key_id"]: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]),
valid_until_ts=row["ts_valid_until_ms"],
added_ts=row["ts_added_ms"],
)
for row in rows
}
7 changes: 7 additions & 0 deletions synapse/storage/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,10 @@
class FetchKeyResult:
verify_key: VerifyKey # the key itself
valid_until_ts: int # how long we can use this key for


@attr.s(slots=True, frozen=True, auto_attribs=True)
class FetchKeyResultForRemote:
key_json: bytes # the full key JSON
valid_until_ts: int # how long we can use this key for, in milliseconds.
added_ts: int # When we added this key, in milliseconds.
61 changes: 23 additions & 38 deletions tests/crypto/test_keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,24 +456,19 @@ async def get_json(destination: str, path: str, **kwargs: Any) -> JsonDict:
self.assertEqual(k.verify_key.version, "ver1")

# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
SERVER_NAME, [testverifykey_id]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], SERVER_NAME)
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
res = key_json[testverifykey_id]
self.assertIsNotNone(res)
assert res is not None
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)

# we expect it to be encoded as canonical json *before* it hits the db
self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))

# change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER"
Expand Down Expand Up @@ -576,23 +571,18 @@ def test_get_keys_from_perspectives(self) -> None:
self.assertEqual(k.verify_key.version, "ver1")

# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
SERVER_NAME, [testverifykey_id]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)

self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
res = key_json[testverifykey_id]
self.assertIsNotNone(res)
assert res is not None
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)

self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))

def test_get_multiple_keys_from_perspectives(self) -> None:
"""Check that we can correctly request multiple keys for the same server"""
Expand Down Expand Up @@ -699,23 +689,18 @@ def test_get_perspectives_own_key(self) -> None:
self.assertEqual(k.verify_key.version, "ver1")

# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
SERVER_NAME, [testverifykey_id]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)

self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
res = key_json[testverifykey_id]
self.assertIsNotNone(res)
assert res is not None
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)

self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))

def test_invalid_perspectives_responses(self) -> None:
"""Check that invalid responses from the perspectives server are rejected"""
Expand Down
Loading