From b75537beaf841089f9f07c9dbed04a7a420a8b1f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 3 Apr 2019 18:10:24 +0100 Subject: [PATCH 1/3] Store key validity time in the storage layer This is a first step to checking that the key is valid at the required moment. The idea here is that, rather than passing VerifyKey objects in and out of the storage layer, we instead pass FetchKeyResult objects, which simply wrap the VerifyKey and add a valid_until_ts field. --- changelog.d/5237.misc | 1 + synapse/crypto/keyring.py | 47 +++++++++++++------ synapse/storage/keys.py | 31 ++++++++---- .../delta/54/add_validity_to_server_keys.sql | 23 +++++++++ tests/crypto/test_keyring.py | 22 +++++---- tests/storage/test_keys.py | 44 +++++++++++------ 6 files changed, 122 insertions(+), 46 deletions(-) create mode 100644 changelog.d/5237.misc create mode 100644 synapse/storage/schema/delta/54/add_validity_to_server_keys.sql diff --git a/changelog.d/5237.misc b/changelog.d/5237.misc new file mode 100644 index 000000000000..f4fe3b821bf6 --- /dev/null +++ b/changelog.d/5237.misc @@ -0,0 +1 @@ +Store key validity time in the storage layer. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 9d629b2238d4..14a27288fd4c 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -20,7 +20,6 @@ from six import raise_from from six.moves import urllib -import nacl.signing from signedjson.key import ( decode_verify_key_bytes, encode_verify_key_base64, @@ -43,6 +42,7 @@ RequestSendFailed, SynapseError, ) +from synapse.storage.keys import FetchKeyResult from synapse.util import logcontext, unwrapFirstError from synapse.util.logcontext import ( LoggingContext, @@ -307,11 +307,15 @@ def do_iterations(): # complete this VerifyKeyRequest. result_keys = results.get(server_name, {}) for key_id in verify_request.key_ids: - key = result_keys.get(key_id) - if key: + fetch_key_result = result_keys.get(key_id) + if fetch_key_result: with PreserveLoggingContext(): verify_request.deferred.callback( - (server_name, key_id, key) + ( + server_name, + key_id, + fetch_key_result.verify_key, + ) ) break else: @@ -348,12 +352,12 @@ def on_err(err): def get_keys_from_store(self, server_name_and_key_ids): """ Args: - server_name_and_key_ids (iterable(Tuple[str, iterable[str]]): + server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]): list of (server_name, iterable[key_id]) tuples to fetch keys for Returns: - Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from - server_name -> key_id -> VerifyKey + Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: + map from server_name -> key_id -> FetchKeyResult """ keys_to_fetch = ( (server_name, key_id) @@ -430,6 +434,18 @@ def get_keys_from_server(self, server_name_and_key_ids): def get_server_verify_key_v2_indirect( self, server_names_and_key_ids, perspective_name, perspective_keys ): + """ + Args: + server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]): + list of (server_name, iterable[key_id]) tuples to fetch keys for + perspective_name (str): name of the notary server to query for the keys + perspective_keys (dict[str, VerifyKey]): map of key_id->key for the + notary server + + Returns: + Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map + from server_name -> key_id -> FetchKeyResult + """ # TODO(mark): Set the minimum_valid_until_ts to that needed by # the events being validated or the current time if validating # an incoming request. @@ -506,7 +522,7 @@ def get_server_verify_key_v2_indirect( @defer.inlineCallbacks def get_server_verify_key_v2_direct(self, server_name, key_ids): - keys = {} # type: dict[str, nacl.signing.VerifyKey] + keys = {} # type: dict[str, FetchKeyResult] for requested_key_id in key_ids: if requested_key_id in keys: @@ -583,9 +599,9 @@ def process_v2_response( actually in the response Returns: - Deferred[dict[str, nacl.signing.VerifyKey]]: - map from key_id to key object + Deferred[dict[str, FetchKeyResult]]: map from key_id to result object """ + ts_valid_until_ms = response_json[u"valid_until_ts"] # start by extracting the keys from the response, since they may be required # to validate the signature on the response. @@ -595,7 +611,9 @@ def process_v2_response( key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) verify_key = decode_verify_key_bytes(key_id, key_bytes) - verify_keys[key_id] = verify_key + verify_keys[key_id] = FetchKeyResult( + verify_key=verify_key, valid_until_ts=ts_valid_until_ms + ) # TODO: improve this signature checking server_name = response_json["server_name"] @@ -606,7 +624,7 @@ def process_v2_response( ) verify_signed_json( - response_json, server_name, verify_keys[key_id] + response_json, server_name, verify_keys[key_id].verify_key ) for key_id, key_data in response_json["old_verify_keys"].items(): @@ -614,7 +632,9 @@ def process_v2_response( key_base64 = key_data["key"] key_bytes = decode_base64(key_base64) verify_key = decode_verify_key_bytes(key_id, key_bytes) - verify_keys[key_id] = verify_key + verify_keys[key_id] = FetchKeyResult( + verify_key=verify_key, valid_until_ts=key_data["expired_ts"] + ) # re-sign the json with our own key, so that it is ready if we are asked to # give it out as a notary server @@ -623,7 +643,6 @@ def process_v2_response( ) signed_key_json_bytes = encode_canonical_json(signed_key_json) - ts_valid_until_ms = signed_key_json[u"valid_until_ts"] # for reasons I don't quite understand, we store this json for the key ids we # requested, as well as those we got. diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 3c5f52009b3f..5300720dbb87 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -19,6 +19,7 @@ import six +import attr from signedjson.key import decode_verify_key_bytes from synapse.util import batch_iter @@ -36,6 +37,12 @@ db_binary_type = memoryview +@attr.s(slots=True, frozen=True) +class FetchKeyResult(object): + verify_key = attr.ib() # VerifyKey: the key itself + valid_until_ts = attr.ib() # int: how long we can use this key for + + class KeyStore(SQLBaseStore): """Persistence for signature verification keys """ @@ -54,8 +61,8 @@ def get_server_verify_keys(self, server_name_and_key_ids): iterable of (server_name, key-id) tuples to fetch keys for Returns: - Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]: - map from (server_name, key_id) -> VerifyKey, or None if the key is + Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: + map from (server_name, key_id) -> FetchKeyResult, or None if the key is unknown """ keys = {} @@ -65,17 +72,19 @@ def _get_keys(txn, batch): # batch_iter always returns tuples so it's safe to do len(batch) sql = ( - "SELECT server_name, key_id, verify_key FROM server_signature_keys " - "WHERE 1=0" + "SELECT server_name, key_id, verify_key, ts_valid_until_ms " + "FROM server_signature_keys WHERE 1=0" ) + " OR (server_name=? AND key_id=?)" * len(batch) txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) for row in txn: - server_name, key_id, key_bytes = row - keys[(server_name, key_id)] = decode_verify_key_bytes( - key_id, bytes(key_bytes) + server_name, key_id, key_bytes, ts_valid_until_ms = row + res = FetchKeyResult( + verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), + valid_until_ts=ts_valid_until_ms, ) + keys[(server_name, key_id)] = res def _txn(txn): for batch in batch_iter(server_name_and_key_ids, 50): @@ -89,20 +98,21 @@ def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): Args: from_server (str): Where the verification keys were looked up ts_added_ms (int): The time to record that the key was added - verify_keys (iterable[tuple[str, str, nacl.signing.VerifyKey]]): + verify_keys (iterable[tuple[str, str, FetchKeyResult]]): keys to be stored. Each entry is a triplet of (server_name, key_id, key). """ key_values = [] value_values = [] invalidations = [] - for server_name, key_id, verify_key in verify_keys: + for server_name, key_id, fetch_result in verify_keys: key_values.append((server_name, key_id)) value_values.append( ( from_server, ts_added_ms, - db_binary_type(verify_key.encode()), + fetch_result.valid_until_ts, + db_binary_type(fetch_result.verify_key.encode()), ) ) # invalidate takes a tuple corresponding to the params of @@ -125,6 +135,7 @@ def _invalidate(res): value_names=( "from_server", "ts_added_ms", + "ts_valid_until_ms", "verify_key", ), value_values=value_values, diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql new file mode 100644 index 000000000000..c01aa9d2d90b --- /dev/null +++ b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql @@ -0,0 +1,23 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* When we can use this key until, before we have to refresh it. */ +ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT; + +UPDATE server_signature_keys SET ts_valid_until_ms = ( + SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE + skj.server_name = server_signature_keys.server_name AND + skj.key_id = server_signature_keys.key_id +); diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index bcffe53a9187..83de32b05d33 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError from synapse.crypto import keyring from synapse.crypto.keyring import KeyLookupError +from synapse.storage.keys import FetchKeyResult from synapse.util import logcontext from synapse.util.logcontext import LoggingContext @@ -201,7 +202,7 @@ def test_verify_json_for_server(self): ( "server9", key1_id, - signedjson.key.get_verify_key(key1), + FetchKeyResult(signedjson.key.get_verify_key(key1), 1000), ), ], ) @@ -251,9 +252,10 @@ def get_json(destination, path, **kwargs): server_name_and_key_ids = [(SERVER_NAME, ("key1",))] keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids)) k = keys[SERVER_NAME][testverifykey_id] - self.assertEqual(k, testverifykey) - self.assertEqual(k.alg, "ed25519") - self.assertEqual(k.version, "ver1") + self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) + self.assertEqual(k.verify_key, testverifykey) + self.assertEqual(k.verify_key.alg, "ed25519") + self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) @@ -321,9 +323,10 @@ def post_json(destination, path, data, **kwargs): keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids)) self.assertIn(SERVER_NAME, keys) k = keys[SERVER_NAME][testverifykey_id] - self.assertEqual(k, testverifykey) - self.assertEqual(k.alg, "ed25519") - self.assertEqual(k.version, "ver1") + self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) + self.assertEqual(k.verify_key, testverifykey) + self.assertEqual(k.verify_key.alg, "ed25519") + self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated lookup_triplet = (SERVER_NAME, testverifykey_id, None) @@ -346,7 +349,10 @@ def post_json(destination, path, data, **kwargs): @defer.inlineCallbacks def run_in_context(f, *args, **kwargs): - with LoggingContext("testctx"): + with LoggingContext("testctx") as ctx: + # we set the "request" prop to make it easier to follow what's going on in the + # logs. + ctx.request = "testctx" rv = yield f(*args, **kwargs) defer.returnValue(rv) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 71ad7aee32fc..e07ff0120173 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -17,6 +17,8 @@ from twisted.internet.defer import Deferred +from synapse.storage.keys import FetchKeyResult + import tests.unittest KEY_1 = signedjson.key.decode_verify_key_base64( @@ -37,8 +39,8 @@ def test_get_server_verify_keys(self): "from_server", 10, [ - ("server1", key_id_1, KEY_1), - ("server1", key_id_2, KEY_2), + ("server1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("server1", key_id_2, FetchKeyResult(KEY_2, 200)), ], ) self.get_success(d) @@ -50,13 +52,15 @@ def test_get_server_verify_keys(self): self.assertEqual(len(res.keys()), 3) res1 = res[("server1", key_id_1)] - self.assertEqual(res1, KEY_1) - self.assertEqual(res1.version, "key1") + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.verify_key.version, "key1") + self.assertEqual(res1.valid_until_ts, 100) res2 = res[("server1", key_id_2)] - self.assertEqual(res2, KEY_2) + self.assertEqual(res2.verify_key, KEY_2) # version comes from the ID it was stored with - self.assertEqual(res2.version, "KEY_ID_2") + self.assertEqual(res2.verify_key.version, "KEY_ID_2") + self.assertEqual(res2.valid_until_ts, 200) # non-existent result gives None self.assertIsNone(res[("server1", "ed25519:key3")]) @@ -73,8 +77,8 @@ def test_cache(self): "from_server", 0, [ - ("srv1", key_id_1, KEY_1), - ("srv1", key_id_2, KEY_2), + ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)), ], ) self.get_success(d) @@ -82,26 +86,38 @@ def test_cache(self): d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(d) self.assertEqual(len(res.keys()), 2) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) - self.assertEqual(res[("srv1", key_id_2)], KEY_2) + + res1 = res[("srv1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("srv1", key_id_2)] + self.assertEqual(res2.verify_key, KEY_2) + self.assertEqual(res2.valid_until_ts, 200) # we should be able to look up the same thing again without a db hit res = store.get_server_verify_keys([("srv1", key_id_1)]) if isinstance(res, Deferred): res = self.successResultOf(res) self.assertEqual(len(res.keys()), 1) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) + self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) new_key_2 = signedjson.key.get_verify_key( signedjson.key.generate_signing_key("key2") ) d = store.store_server_verify_keys( - "from_server", 10, [("srv1", key_id_2, new_key_2)] + "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))] ) self.get_success(d) d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(d) self.assertEqual(len(res.keys()), 2) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) - self.assertEqual(res[("srv1", key_id_2)], new_key_2) + + res1 = res[("srv1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("srv1", key_id_2)] + self.assertEqual(res2.verify_key, new_key_2) + self.assertEqual(res2.valid_until_ts, 300) From 895b79ac2ece74500fb8a4ea158a6aec2adc0856 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 9 Apr 2019 18:28:17 +0100 Subject: [PATCH 2/3] Factor out KeyFetchers from KeyRing Rather than have three methods which have to have the same interface, factor out a separate interface which is provided by three implementations. I find it easier to grok the code this way. --- changelog.d/5244.misc | 1 + synapse/crypto/keyring.py | 315 +++++++++++++++++++---------------- tests/crypto/test_keyring.py | 34 +++- 3 files changed, 204 insertions(+), 146 deletions(-) create mode 100644 changelog.d/5244.misc diff --git a/changelog.d/5244.misc b/changelog.d/5244.misc new file mode 100644 index 000000000000..9cc1fb869de0 --- /dev/null +++ b/changelog.d/5244.misc @@ -0,0 +1 @@ +Refactor synapse.crypto.keyring to use a KeyFetcher interface. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 14a27288fd4c..eaf41b983c11 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -80,12 +80,13 @@ class KeyLookupError(ValueError): class Keyring(object): def __init__(self, hs): - self.store = hs.get_datastore() self.clock = hs.get_clock() - self.client = hs.get_http_client() - self.config = hs.get_config() - self.perspective_servers = self.config.perspectives - self.hs = hs + + self._key_fetchers = ( + StoreKeyFetcher(hs), + PerspectivesKeyFetcher(hs), + ServerKeyFetcher(hs), + ) # map from server name to Deferred. Has an entry for each server with # an ongoing key download; the Deferred completes once the download @@ -271,13 +272,6 @@ def _get_server_verify_keys(self, verify_requests): verify_requests (list[VerifyKeyRequest]): list of verify requests """ - # These are functions that produce keys given a list of key ids - key_fetch_fns = ( - self.get_keys_from_store, # First try the local store - self.get_keys_from_perspectives, # Then try via perspectives - self.get_keys_from_server, # Then try directly - ) - @defer.inlineCallbacks def do_iterations(): with Measure(self.clock, "get_server_verify_keys"): @@ -288,8 +282,8 @@ def do_iterations(): verify_request.key_ids ) - for fn in key_fetch_fns: - results = yield fn(missing_keys.items()) + for f in self._key_fetchers: + results = yield f.get_keys(missing_keys.items()) # We now need to figure out which verify requests we have keys # for and which we don't @@ -348,8 +342,9 @@ def on_err(err): run_in_background(do_iterations).addErrback(on_err) - @defer.inlineCallbacks - def get_keys_from_store(self, server_name_and_key_ids): + +class KeyFetcher(object): + def get_keys(self, server_name_and_key_ids): """ Args: server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]): @@ -359,6 +354,18 @@ def get_keys_from_store(self, server_name_and_key_ids): Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: map from server_name -> key_id -> FetchKeyResult """ + raise NotImplementedError + + +class StoreKeyFetcher(KeyFetcher): + """KeyFetcher impl which fetches keys from our data store""" + + def __init__(self, hs): + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def get_keys(self, server_name_and_key_ids): + """see KeyFetcher.get_keys""" keys_to_fetch = ( (server_name, key_id) for server_name, key_ids in server_name_and_key_ids @@ -370,8 +377,127 @@ def get_keys_from_store(self, server_name_and_key_ids): keys.setdefault(server_name, {})[key_id] = key defer.returnValue(keys) + +class BaseV2KeyFetcher(object): + def __init__(self, hs): + self.store = hs.get_datastore() + self.config = hs.get_config() + + @defer.inlineCallbacks + def process_v2_response( + self, from_server, response_json, time_added_ms, requested_ids=[] + ): + """Parse a 'Server Keys' structure from the result of a /key request + + This is used to parse either the entirety of the response from + GET /_matrix/key/v2/server, or a single entry from the list returned by + POST /_matrix/key/v2/query. + + Checks that each signature in the response that claims to come from the origin + server is valid. (Does not check that there actually is such a signature, for + some reason.) + + Stores the json in server_keys_json so that it can be used for future responses + to /_matrix/key/v2/query. + + Args: + from_server (str): the name of the server producing this result: either + the origin server for a /_matrix/key/v2/server request, or the notary + for a /_matrix/key/v2/query. + + response_json (dict): the json-decoded Server Keys response object + + time_added_ms (int): the timestamp to record in server_keys_json + + requested_ids (iterable[str]): a list of the key IDs that were requested. + We will store the json for these key ids as well as any that are + actually in the response + + Returns: + Deferred[dict[str, FetchKeyResult]]: map from key_id to result object + """ + ts_valid_until_ms = response_json[u"valid_until_ts"] + + # start by extracting the keys from the response, since they may be required + # to validate the signature on the response. + verify_keys = {} + for key_id, key_data in response_json["verify_keys"].items(): + if is_signing_algorithm_supported(key_id): + key_base64 = key_data["key"] + key_bytes = decode_base64(key_base64) + verify_key = decode_verify_key_bytes(key_id, key_bytes) + verify_keys[key_id] = FetchKeyResult( + verify_key=verify_key, valid_until_ts=ts_valid_until_ms + ) + + # TODO: improve this signature checking + server_name = response_json["server_name"] + for key_id in response_json["signatures"].get(server_name, {}): + if key_id not in verify_keys: + raise KeyLookupError( + "Key response must include verification keys for all signatures" + ) + + verify_signed_json( + response_json, server_name, verify_keys[key_id].verify_key + ) + + for key_id, key_data in response_json["old_verify_keys"].items(): + if is_signing_algorithm_supported(key_id): + key_base64 = key_data["key"] + key_bytes = decode_base64(key_base64) + verify_key = decode_verify_key_bytes(key_id, key_bytes) + verify_keys[key_id] = FetchKeyResult( + verify_key=verify_key, valid_until_ts=key_data["expired_ts"] + ) + + # re-sign the json with our own key, so that it is ready if we are asked to + # give it out as a notary server + signed_key_json = sign_json( + response_json, self.config.server_name, self.config.signing_key[0] + ) + + signed_key_json_bytes = encode_canonical_json(signed_key_json) + + # for reasons I don't quite understand, we store this json for the key ids we + # requested, as well as those we got. + updated_key_ids = set(requested_ids) + updated_key_ids.update(verify_keys) + + yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self.store.store_server_keys_json, + server_name=server_name, + key_id=key_id, + from_server=from_server, + ts_now_ms=time_added_ms, + ts_expires_ms=ts_valid_until_ms, + key_json_bytes=signed_key_json_bytes, + ) + for key_id in updated_key_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) + + defer.returnValue(verify_keys) + + +class PerspectivesKeyFetcher(BaseV2KeyFetcher): + """KeyFetcher impl which fetches keys from the "perspectives" servers""" + + def __init__(self, hs): + super(PerspectivesKeyFetcher, self).__init__(hs) + self.clock = hs.get_clock() + self.client = hs.get_http_client() + self.perspective_servers = self.config.perspectives + @defer.inlineCallbacks - def get_keys_from_perspectives(self, server_name_and_key_ids): + def get_keys(self, server_name_and_key_ids): + """see KeyFetcher.get_keys""" + @defer.inlineCallbacks def get_key(perspective_name, perspective_keys): try: @@ -408,28 +534,6 @@ def get_key(perspective_name, perspective_keys): defer.returnValue(union_of_keys) - @defer.inlineCallbacks - def get_keys_from_server(self, server_name_and_key_ids): - results = yield logcontext.make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.get_server_verify_key_v2_direct, server_name, key_ids - ) - for server_name, key_ids in server_name_and_key_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - ) - - merged = {} - for result in results: - merged.update(result) - - defer.returnValue( - {server_name: keys for server_name, keys in merged.items() if keys} - ) - @defer.inlineCallbacks def get_server_verify_key_v2_indirect( self, server_names_and_key_ids, perspective_name, perspective_keys @@ -520,6 +624,38 @@ def get_server_verify_key_v2_indirect( defer.returnValue(keys) + +class ServerKeyFetcher(BaseV2KeyFetcher): + """KeyFetcher impl which fetches keys from the origin servers""" + + def __init__(self, hs): + super(ServerKeyFetcher, self).__init__(hs) + self.clock = hs.get_clock() + self.client = hs.get_http_client() + + @defer.inlineCallbacks + def get_keys(self, server_name_and_key_ids): + """see KeyFetcher.get_keys""" + results = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self.get_server_verify_key_v2_direct, server_name, key_ids + ) + for server_name, key_ids in server_name_and_key_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) + + merged = {} + for result in results: + merged.update(result) + + defer.returnValue( + {server_name: keys for server_name, keys in merged.items() if keys} + ) + @defer.inlineCallbacks def get_server_verify_key_v2_direct(self, server_name, key_ids): keys = {} # type: dict[str, FetchKeyResult] @@ -568,107 +704,6 @@ def get_server_verify_key_v2_direct(self, server_name, key_ids): defer.returnValue({server_name: keys}) - @defer.inlineCallbacks - def process_v2_response( - self, from_server, response_json, time_added_ms, requested_ids=[] - ): - """Parse a 'Server Keys' structure from the result of a /key request - - This is used to parse either the entirety of the response from - GET /_matrix/key/v2/server, or a single entry from the list returned by - POST /_matrix/key/v2/query. - - Checks that each signature in the response that claims to come from the origin - server is valid. (Does not check that there actually is such a signature, for - some reason.) - - Stores the json in server_keys_json so that it can be used for future responses - to /_matrix/key/v2/query. - - Args: - from_server (str): the name of the server producing this result: either - the origin server for a /_matrix/key/v2/server request, or the notary - for a /_matrix/key/v2/query. - - response_json (dict): the json-decoded Server Keys response object - - time_added_ms (int): the timestamp to record in server_keys_json - - requested_ids (iterable[str]): a list of the key IDs that were requested. - We will store the json for these key ids as well as any that are - actually in the response - - Returns: - Deferred[dict[str, FetchKeyResult]]: map from key_id to result object - """ - ts_valid_until_ms = response_json[u"valid_until_ts"] - - # start by extracting the keys from the response, since they may be required - # to validate the signature on the response. - verify_keys = {} - for key_id, key_data in response_json["verify_keys"].items(): - if is_signing_algorithm_supported(key_id): - key_base64 = key_data["key"] - key_bytes = decode_base64(key_base64) - verify_key = decode_verify_key_bytes(key_id, key_bytes) - verify_keys[key_id] = FetchKeyResult( - verify_key=verify_key, valid_until_ts=ts_valid_until_ms - ) - - # TODO: improve this signature checking - server_name = response_json["server_name"] - for key_id in response_json["signatures"].get(server_name, {}): - if key_id not in verify_keys: - raise KeyLookupError( - "Key response must include verification keys for all signatures" - ) - - verify_signed_json( - response_json, server_name, verify_keys[key_id].verify_key - ) - - for key_id, key_data in response_json["old_verify_keys"].items(): - if is_signing_algorithm_supported(key_id): - key_base64 = key_data["key"] - key_bytes = decode_base64(key_base64) - verify_key = decode_verify_key_bytes(key_id, key_bytes) - verify_keys[key_id] = FetchKeyResult( - verify_key=verify_key, valid_until_ts=key_data["expired_ts"] - ) - - # re-sign the json with our own key, so that it is ready if we are asked to - # give it out as a notary server - signed_key_json = sign_json( - response_json, self.config.server_name, self.config.signing_key[0] - ) - - signed_key_json_bytes = encode_canonical_json(signed_key_json) - - # for reasons I don't quite understand, we store this json for the key ids we - # requested, as well as those we got. - updated_key_ids = set(requested_ids) - updated_key_ids.update(verify_keys) - - yield logcontext.make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.store.store_server_keys_json, - server_name=server_name, - key_id=key_id, - from_server=from_server, - ts_now_ms=time_added_ms, - ts_expires_ms=ts_valid_until_ms, - key_json_bytes=signed_key_json_bytes, - ) - for key_id in updated_key_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - ) - - defer.returnValue(verify_keys) - @defer.inlineCallbacks def _handle_key_deferred(verify_request): diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 83de32b05d33..de61bad15d27 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -24,7 +24,11 @@ from synapse.api.errors import SynapseError from synapse.crypto import keyring -from synapse.crypto.keyring import KeyLookupError +from synapse.crypto.keyring import ( + KeyLookupError, + PerspectivesKeyFetcher, + ServerKeyFetcher, +) from synapse.storage.keys import FetchKeyResult from synapse.util import logcontext from synapse.util.logcontext import LoggingContext @@ -218,12 +222,19 @@ def test_verify_json_for_server(self): self.assertFalse(d.called) self.get_success(d) + +class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + self.http_client = Mock() + hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client) + return hs + def test_get_keys_from_server(self): # arbitrarily advance the clock a bit self.reactor.advance(100) SERVER_NAME = "server2" - kr = keyring.Keyring(self.hs) + fetcher = ServerKeyFetcher(self.hs) testkey = signedjson.key.generate_signing_key("ver1") testverifykey = signedjson.key.get_verify_key(testkey) testverifykey_id = "ed25519:ver1" @@ -250,7 +261,7 @@ def get_json(destination, path, **kwargs): self.http_client.get_json.side_effect = get_json server_name_and_key_ids = [(SERVER_NAME, ("key1",))] - keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids)) + keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) k = keys[SERVER_NAME][testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.verify_key, testverifykey) @@ -278,15 +289,26 @@ def get_json(destination, path, **kwargs): # change the server name: it should cause a rejection response["server_name"] = "OTHER_SERVER" self.get_failure( - kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError + fetcher.get_keys(server_name_and_key_ids), KeyLookupError ) + +class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + self.mock_perspective_server = MockPerspectiveServer() + self.http_client = Mock() + hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client) + keys = self.mock_perspective_server.get_verify_keys() + hs.config.perspectives = {self.mock_perspective_server.server_name: keys} + return hs + def test_get_keys_from_perspectives(self): # arbitrarily advance the clock a bit self.reactor.advance(100) + fetcher = PerspectivesKeyFetcher(self.hs) + SERVER_NAME = "server2" - kr = keyring.Keyring(self.hs) testkey = signedjson.key.generate_signing_key("ver1") testverifykey = signedjson.key.get_verify_key(testkey) testverifykey_id = "ed25519:ver1" @@ -320,7 +342,7 @@ def post_json(destination, path, data, **kwargs): self.http_client.post_json.side_effect = post_json server_name_and_key_ids = [(SERVER_NAME, ("key1",))] - keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids)) + keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) self.assertIn(SERVER_NAME, keys) k = keys[SERVER_NAME][testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) From ec24108cc2e937f49908df4c78f5cee9f81e0834 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 23 May 2019 14:52:13 +0100 Subject: [PATCH 3/3] Fix remote_key_resource --- synapse/rest/key/v2/remote_key_resource.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index eb8782aa6e1a..21c3c807b9d4 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -20,7 +20,7 @@ from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, SynapseError -from synapse.crypto.keyring import KeyLookupError +from synapse.crypto.keyring import KeyLookupError, ServerKeyFetcher from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler from synapse.http.servlet import parse_integer, parse_json_object_from_request @@ -89,7 +89,7 @@ class RemoteKey(Resource): isLeaf = True def __init__(self, hs): - self.keyring = hs.get_keyring() + self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() self.federation_domain_whitelist = hs.config.federation_domain_whitelist @@ -217,7 +217,7 @@ def query_keys(self, request, query, query_remote_on_cache_miss=False): if cache_misses and query_remote_on_cache_miss: for server_name, key_ids in cache_misses.items(): try: - yield self.keyring.get_server_verify_key_v2_direct( + yield self.fetcher.get_server_verify_key_v2_direct( server_name, key_ids ) except KeyLookupError as e: