Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix typing issues found with mypy.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Mar 17, 2020
1 parent 7581d30 commit 05e331d
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 28 deletions.
1 change: 1 addition & 0 deletions changelog.d/7089.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix typing issue in federation client found with mypy.
24 changes: 9 additions & 15 deletions synapse/federation/federation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@

from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
)
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.crypto.event_signing import check_event_content_hash
from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
Expand All @@ -55,13 +51,15 @@ def __init__(self, hs):
self.store = hs.get_datastore()
self._clock = hs.get_clock()

def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
def _check_sigs_and_hash(
self, room_version: RoomVersion, pdu: EventBase
) -> Deferred:
return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0]
)

def _check_sigs_and_hashes(
self, room_version: str, pdus: List[EventBase]
self, room_version: RoomVersion, pdus: List[EventBase]
) -> List[Deferred]:
"""Checks that each of the received events is correctly signed by the
sending server.
Expand Down Expand Up @@ -146,7 +144,7 @@ class PduToCheckSig(


def _check_sigs_on_pdus(
keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
) -> List[Deferred]:
"""Check that the given events are correctly signed
Expand Down Expand Up @@ -191,10 +189,6 @@ def _check_sigs_on_pdus(
for p in pdus
]

v = KNOWN_ROOM_VERSIONS.get(room_version)
if not v:
raise RuntimeError("Unrecognized room version %s" % (room_version,))

# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
Expand All @@ -204,7 +198,7 @@ def _check_sigs_on_pdus(
(
p.sender_domain,
p.redacted_pdu_json,
p.pdu.origin_server_ts if v.enforce_key_validity else 0,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_sender
Expand All @@ -227,7 +221,7 @@ def sender_err(e, pdu_to_check):
# event id's domain (normally only the case for joins/leaves), and add additional
# checks. Only do this if the room version has a concept of event ID domain
# (ie, the room version uses old-style non-hash event IDs).
if v.event_format == EventFormatVersions.V1:
if room_version.event_format == EventFormatVersions.V1:
pdus_to_check_event_id = [
p
for p in pdus_to_check
Expand All @@ -239,7 +233,7 @@ def sender_err(e, pdu_to_check):
(
get_domain_from_id(p.pdu.event_id),
p.redacted_pdu_json,
p.pdu.origin_server_ts if v.enforce_key_validity else 0,
p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_event_id
Expand Down
16 changes: 7 additions & 9 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,7 @@ async def get_pdu(
pdu = pdu_list[0]

# Check signatures are correct.
signed_pdu = await self._check_sigs_and_hash(
room_version.identifier, pdu
)
signed_pdu = await self._check_sigs_and_hash(room_version, pdu)

break

Expand Down Expand Up @@ -350,7 +348,7 @@ async def _check_sigs_and_hash_and_fetch(
self,
origin: str,
pdus: List[EventBase],
room_version: str,
room_version: RoomVersion,
outlier: bool = False,
include_none: bool = False,
) -> List[EventBase]:
Expand Down Expand Up @@ -396,7 +394,7 @@ def handle_check_result(pdu: EventBase, deferred: Deferred):
self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version, # type: ignore
room_version=room_version,
outlier=outlier,
timeout=10000,
)
Expand Down Expand Up @@ -434,7 +432,7 @@ async def get_event_auth(self, destination, room_id, event_id):
]

signed_auth = await self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True, room_version=room_version.identifier
destination, auth_chain, outlier=True, room_version=room_version
)

signed_auth.sort(key=lambda e: e.depth)
Expand Down Expand Up @@ -661,7 +659,7 @@ async def send_request(destination) -> Dict[str, Any]:
destination,
list(pdus.values()),
outlier=True,
room_version=room_version.identifier,
room_version=room_version,
)

valid_pdus_map = {p.event_id: p for p in valid_pdus}
Expand Down Expand Up @@ -756,7 +754,7 @@ async def send_invite(
pdu = event_from_pdu_json(pdu_dict, room_version)

# Check signatures are correct.
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)

# FIXME: We should handle signature failures more gracefully.

Expand Down Expand Up @@ -948,7 +946,7 @@ async def get_missing_events(
]

signed_events = await self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False, room_version=room_version.identifier
destination, events, outlier=False, room_version=room_version
)
except HttpResponseException as e:
if not e.code == 400:
Expand Down
8 changes: 4 additions & 4 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ async def on_invite_request(
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, pdu.room_id)
pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)
ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
time_now = self._clock.time_msec()
return {"event": ret_pdu.get_pdu_json(time_now)}
Expand All @@ -425,7 +425,7 @@ async def on_send_join_request(self, origin, content, room_id):

logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)

pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)

res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
Expand Down Expand Up @@ -455,7 +455,7 @@ async def on_send_leave_request(self, origin, content, room_id):

logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)

pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
pdu = await self._check_sigs_and_hash(room_version, pdu)

await self.handler.on_send_leave_request(origin, pdu)
return {}
Expand Down Expand Up @@ -611,7 +611,7 @@ async def _handle_received_pdu(self, origin, pdu):
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)

# We've already checked that we know the room version by this point
room_version = await self.store.get_room_version_id(pdu.room_id)
room_version = await self.store.get_room_version(pdu.room_id)

# Check signature.
try:
Expand Down

0 comments on commit 05e331d

Please sign in to comment.