From b9fee28a3a3cff09604831bf17d87ac0ce35cbbd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 May 2021 15:22:15 -0400 Subject: [PATCH 1/7] Add types to TransportLayerServer. --- synapse/federation/transport/server.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index bed47f8abd5d..eaacee53e74e 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -15,7 +15,7 @@ import functools import logging import re -from typing import Container, Mapping, Optional, Sequence, Tuple, Type +from typing import Container, List, Mapping, Optional, Sequence, Tuple, Type import synapse from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH @@ -56,15 +56,15 @@ class TransportLayerServer(JsonResource): """Handles incoming federation HTTP requests""" - def __init__(self, hs, servlet_groups=None): + def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None): """Initialize the TransportLayerServer Will by default register all servlets. For custom behaviour, pass in a list of servlet_groups to register. Args: - hs (synapse.server.HomeServer): homeserver - servlet_groups (list[str], optional): List of servlet groups to register. + hs: homeserver + servlet_groups: List of servlet groups to register. Defaults to ``DEFAULT_SERVLET_GROUPS``. """ self.hs = hs @@ -78,7 +78,7 @@ def __init__(self, hs, servlet_groups=None): self.register_servlets() - def register_servlets(self): + def register_servlets(self) -> None: register_servlets( self.hs, resource=self, @@ -91,14 +91,10 @@ def register_servlets(self): class AuthenticationError(SynapseError): """There was a problem authenticating the request""" - pass - class NoAuthenticationError(AuthenticationError): """The request had no authentication information""" - pass - class Authenticator: def __init__(self, hs: HomeServer): From de0073fea5aac04910adbfedb5759f6a476004ae Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 8 Jun 2021 14:41:14 -0400 Subject: [PATCH 2/7] Use parse_strings_from_args in more places. --- synapse/federation/transport/server.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index eaacee53e74e..80dbd4a1e7c5 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -525,10 +525,8 @@ async def on_GET(self, origin, _content, query, room_id, user_id): Returns: Tuple[int, object]: (response code, response object) """ - versions = query.get(b"ver") - if versions is not None: - supported_versions = [v.decode("utf-8") for v in versions] - else: + supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") + if supported_versions is None: supported_versions = ["1"] content = await self.handler.on_make_join_request( @@ -759,14 +757,14 @@ class OpenIdUserInfo(BaseFederationServerServlet): REQUIRE_AUTH = False async def on_GET(self, origin, content, query): - token = query.get(b"access_token", [None])[0] + token = parse_string_from_args(query, "access_token") if token is None: return ( 401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}, ) - user_id = await self.handler.on_openid_userinfo(token.decode("ascii")) + user_id = await self.handler.on_openid_userinfo(token) if user_id is None: return ( From d1f2b9f4c5a900f6f3b24cd9d80da45a51494ca7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 8 Jun 2021 14:23:35 -0400 Subject: [PATCH 3/7] Add type hints to on_FOO methods. --- synapse/federation/transport/server.py | 503 +++++++++++++++++++++---- 1 file changed, 430 insertions(+), 73 deletions(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 80dbd4a1e7c5..42d3473f40b6 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -15,7 +15,9 @@ import functools import logging import re -from typing import Container, List, Mapping, Optional, Sequence, Tuple, Type +from typing import Container, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union + +from typing_extensions import Literal import synapse from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH @@ -406,13 +408,18 @@ class FederationSendServlet(BaseFederationServerServlet): RATELIMIT = False # This is when someone is trying to send us a bunch of data. - async def on_PUT(self, origin, content, query, transaction_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + transaction_id: str, + ) -> Tuple[int, JsonDict]: """Called on PUT /send// Args: - request (twisted.web.http.Request): The HTTP request. - transaction_id (str): The transaction_id associated with this - request. This is *not* None. + transaction_id: The transaction_id associated with this request. This + is *not* None. Returns: Tuple of `(code, response)`, where @@ -457,7 +464,13 @@ class FederationEventServlet(BaseFederationServerServlet): PATH = "/event/(?P[^/]*)/?" # This is when someone asks for a data item for a given server data_id pair. - async def on_GET(self, origin, content, query, event_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + event_id: str, + ) -> Tuple[int, Union[JsonDict, str]]: return await self.handler.on_pdu_request(origin, event_id) @@ -465,7 +478,13 @@ class FederationStateV1Servlet(BaseFederationServerServlet): PATH = "/state/(?P[^/]*)/?" # This is when someone asks for all data for a given room. - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_room_state_request( origin, room_id, @@ -476,7 +495,13 @@ async def on_GET(self, origin, content, query, room_id): class FederationStateIdsServlet(BaseFederationServerServlet): PATH = "/state_ids/(?P[^/]*)/?" - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_state_ids_request( origin, room_id, @@ -487,7 +512,13 @@ async def on_GET(self, origin, content, query, room_id): class FederationBackfillServlet(BaseFederationServerServlet): PATH = "/backfill/(?P[^/]*)/?" - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: versions = [x.decode("ascii") for x in query[b"v"]] limit = parse_integer_from_args(query, "limit", None) @@ -501,7 +532,13 @@ class FederationQueryServlet(BaseFederationServerServlet): PATH = "/query/(?P[^/]*)" # This is when we receive a server-server Query - async def on_GET(self, origin, content, query, query_type): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + query_type: str, + ) -> Tuple[int, JsonDict]: args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()} args["origin"] = origin return await self.handler.on_query_request(query_type, args) @@ -510,20 +547,27 @@ async def on_GET(self, origin, content, query, query_type): class FederationMakeJoinServlet(BaseFederationServerServlet): PATH = "/make_join/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, _content, query, room_id, user_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: """ Args: - origin (unicode): The authenticated server_name of the calling server + origin: The authenticated server_name of the calling server - _content (None): (GETs don't have bodies) + content: (GETs don't have bodies) - query (dict[bytes, list[bytes]]): Query params from the request. + query: Query params from the request. - **kwargs (dict[unicode, unicode]): the dict mapping keys to path - components as specified in the path match regexp. + **kwargs: the dict mapping keys to path components as specified in + the path match regexp. Returns: - Tuple[int, object]: (response code, response object) + Tuple of (response code, response object) """ supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") if supported_versions is None: @@ -538,7 +582,14 @@ async def on_GET(self, origin, _content, query, room_id, user_id): class FederationMakeLeaveServlet(BaseFederationServerServlet): PATH = "/make_leave/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, content, query, room_id, user_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: content = await self.handler.on_make_leave_request(origin, room_id, user_id) return 200, content @@ -546,7 +597,14 @@ async def on_GET(self, origin, content, query, room_id, user_id): class FederationV1SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: content = await self.handler.on_send_leave_request(origin, content) return 200, (200, content) @@ -556,7 +614,14 @@ class FederationV2SendLeaveServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: content = await self.handler.on_send_leave_request(origin, content) return 200, content @@ -564,7 +629,14 @@ async def on_PUT(self, origin, content, query, room_id, event_id): class FederationMakeKnockServlet(BaseFederationServerServlet): PATH = "/make_knock/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, content, query, room_id, user_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: try: # Retrieve the room versions the remote homeserver claims to support supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") @@ -580,7 +652,14 @@ async def on_GET(self, origin, content, query, room_id, user_id): class FederationV1SendKnockServlet(BaseFederationServerServlet): PATH = "/send_knock/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: content = await self.handler.on_send_knock_request(origin, content, room_id) return 200, content @@ -588,14 +667,28 @@ async def on_PUT(self, origin, content, query, room_id, event_id): class FederationEventAuthServlet(BaseFederationServerServlet): PATH = "/event_auth/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, content, query, room_id, event_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_event_auth(origin, room_id, event_id) class FederationV1SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content content = await self.handler.on_send_join_request(origin, content) @@ -607,7 +700,14 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content content = await self.handler.on_send_join_request(origin, content) @@ -617,7 +717,14 @@ async def on_PUT(self, origin, content, query, room_id, event_id): class FederationV1InviteServlet(BaseFederationServerServlet): PATH = "/invite/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: # We don't get a room version, so we have to assume its EITHER v1 or # v2. This is "fine" as the only difference between V1 and V2 is the # state resolution algorithm, and we don't use that for processing @@ -636,7 +743,14 @@ class FederationV2InviteServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content @@ -658,7 +772,13 @@ async def on_PUT(self, origin, content, query, room_id, event_id): class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): PATH = "/exchange_third_party_invite/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: await self.handler.on_exchange_third_party_invite_request(content) return 200, {} @@ -666,21 +786,31 @@ async def on_PUT(self, origin, content, query, room_id): class FederationClientKeysQueryServlet(BaseFederationServerServlet): PATH = "/user/keys/query" - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: return await self.handler.on_query_client_keys(origin, content) class FederationUserDevicesQueryServlet(BaseFederationServerServlet): PATH = "/user/devices/(?P[^/]*)" - async def on_GET(self, origin, content, query, user_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + user_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_query_user_devices(origin, user_id) class FederationClientKeysClaimServlet(BaseFederationServerServlet): PATH = "/user/keys/claim" - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: response = await self.handler.on_claim_client_keys(origin, content) return 200, response @@ -689,7 +819,13 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet): # TODO(paul): Why does this path alone end with "/?" optional? PATH = "/get_missing_events/(?P[^/]*)/?" - async def on_POST(self, origin, content, query, room_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: limit = int(content.get("limit", 10)) earliest_events = content.get("earliest_events", []) latest_events = content.get("latest_events", []) @@ -710,7 +846,9 @@ class On3pidBindServlet(BaseFederationServerServlet): REQUIRE_AUTH = False - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: if "invites" in content: last_exception = None for invite in content["invites"]: @@ -756,7 +894,12 @@ class OpenIdUserInfo(BaseFederationServerServlet): REQUIRE_AUTH = False - async def on_GET(self, origin, content, query): + async def on_GET( + self, + origin: Optional[str], + content: Literal[None], + query: Dict[bytes, List[bytes]], + ) -> Tuple[int, JsonDict]: token = parse_string_from_args(query, "access_token") if token is None: return ( @@ -823,7 +966,9 @@ def __init__( self.handler = hs.get_room_list_handler() self.allow_access = allow_access - async def on_GET(self, origin, content, query): + async def on_GET( + self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: if not self.allow_access: raise FederationDeniedError(origin) @@ -852,7 +997,9 @@ async def on_GET(self, origin, content, query): ) return 200, data - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: # This implements MSC2197 (Search Filtering over Federation) if not self.allow_access: raise FederationDeniedError(origin) @@ -898,7 +1045,12 @@ class FederationVersionServlet(BaseFederationServlet): REQUIRE_AUTH = False - async def on_GET(self, origin, content, query): + async def on_GET( + self, + origin: Optional[str], + content: Literal[None], + query: Dict[bytes, List[bytes]], + ) -> Tuple[int, JsonDict]: return ( 200, {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, @@ -927,7 +1079,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/profile" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -936,7 +1094,13 @@ async def on_GET(self, origin, content, query, group_id): return 200, new_content - async def on_POST(self, origin, content, query, group_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -951,7 +1115,13 @@ async def on_POST(self, origin, content, query, group_id): class FederationGroupsSummaryServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/summary" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -966,7 +1136,13 @@ class FederationGroupsRoomsServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/rooms" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -981,7 +1157,14 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" - async def on_POST(self, origin, content, query, group_id, room_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -992,7 +1175,14 @@ async def on_POST(self, origin, content, query, group_id, room_id): return 200, new_content - async def on_DELETE(self, origin, content, query, group_id, room_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1012,7 +1202,15 @@ class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): "/config/(?P[^/]*)" ) - async def on_POST(self, origin, content, query, group_id, room_id, config_key): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + config_key: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1029,7 +1227,13 @@ class FederationGroupsUsersServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1044,7 +1248,13 @@ class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/invited_users" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1061,7 +1271,14 @@ class FederationGroupsInviteServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1078,7 +1295,14 @@ class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") @@ -1092,7 +1316,14 @@ class FederationGroupsJoinServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") @@ -1106,7 +1337,14 @@ class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1140,7 +1378,14 @@ class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") @@ -1158,7 +1403,14 @@ class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, None]: if get_domain_from_id(group_id) != origin: raise SynapseError(403, "user_id doesn't match origin") @@ -1166,11 +1418,9 @@ async def on_POST(self, origin, content, query, group_id, user_id): self.handler, GroupsLocalHandler ), "Workers cannot handle group removals." - new_content = await self.handler.user_removed_from_group( - group_id, user_id, content - ) + await self.handler.user_removed_from_group(group_id, user_id, content) - return 200, new_content + return 200, None class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): @@ -1188,7 +1438,14 @@ def __init__( super().__init__(hs, authenticator, ratelimiter, server_name) self.handler = hs.get_groups_attestation_renewer() - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: # We don't need to check auth here as we check the attestation signatures new_content = await self.handler.on_renew_attestation( @@ -1212,7 +1469,15 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): "/rooms/(?P[^/]*)" ) - async def on_POST(self, origin, content, query, group_id, category_id, room_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1240,7 +1505,15 @@ async def on_POST(self, origin, content, query, group_id, category_id, room_id): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, category_id, room_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1260,7 +1533,13 @@ class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/categories/?" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1275,7 +1554,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/categories/(?P[^/]+)" - async def on_GET(self, origin, content, query, group_id, category_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1286,7 +1572,14 @@ async def on_GET(self, origin, content, query, group_id, category_id): return 200, resp - async def on_POST(self, origin, content, query, group_id, category_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1308,7 +1601,14 @@ async def on_POST(self, origin, content, query, group_id, category_id): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, category_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1328,7 +1628,13 @@ class FederationGroupsRolesServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/roles/?" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1343,7 +1649,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/roles/(?P[^/]+)" - async def on_GET(self, origin, content, query, group_id, role_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1352,7 +1665,14 @@ async def on_GET(self, origin, content, query, group_id, role_id): return 200, resp - async def on_POST(self, origin, content, query, group_id, role_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1376,7 +1696,14 @@ async def on_POST(self, origin, content, query, group_id, role_id): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, role_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1405,7 +1732,15 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): "/users/(?P[^/]*)" ) - async def on_POST(self, origin, content, query, group_id, role_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1431,7 +1766,15 @@ async def on_POST(self, origin, content, query, group_id, role_id, user_id): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, role_id, user_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1451,7 +1794,9 @@ class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet): PATH = "/get_groups_publicised" - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: resp = await self.handler.bulk_get_publicised_groups( content["user_ids"], proxy=False ) @@ -1464,7 +1809,13 @@ class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/settings/m.join_policy" - async def on_PUT(self, origin, content, query, group_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1493,7 +1844,7 @@ def __init__( async def on_GET( self, origin: str, - content: JsonDict, + content: Literal[None], query: Mapping[bytes, Sequence[bytes]], room_id: str, ) -> Tuple[int, JsonDict]: @@ -1565,7 +1916,13 @@ def __init__( super().__init__(hs, authenticator, ratelimiter, server_name) self._store = self.hs.get_datastore() - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: is_public = await self._store.is_room_world_readable_or_publicly_joinable( room_id ) From 5403a7285ae9811f8058014378ce1af400295c80 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Jun 2021 14:01:34 -0400 Subject: [PATCH 4/7] Avoid changing variable types. --- synapse/federation/transport/server.py | 56 +++++++++++++++----------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 42d3473f40b6..dc9cd7cac10a 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -15,7 +15,17 @@ import functools import logging import re -from typing import Container, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import ( + Container, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) from typing_extensions import Literal @@ -573,10 +583,10 @@ async def on_GET( if supported_versions is None: supported_versions = ["1"] - content = await self.handler.on_make_join_request( + result = await self.handler.on_make_join_request( origin, room_id, user_id, supported_versions=supported_versions ) - return 200, content + return 200, result class FederationMakeLeaveServlet(BaseFederationServerServlet): @@ -590,8 +600,8 @@ async def on_GET( room_id: str, user_id: str, ) -> Tuple[int, JsonDict]: - content = await self.handler.on_make_leave_request(origin, room_id, user_id) - return 200, content + result = await self.handler.on_make_leave_request(origin, room_id, user_id) + return 200, result class FederationV1SendLeaveServlet(BaseFederationServerServlet): @@ -605,8 +615,8 @@ async def on_PUT( room_id: str, event_id: str, ) -> Tuple[int, Tuple[int, JsonDict]]: - content = await self.handler.on_send_leave_request(origin, content) - return 200, (200, content) + result = await self.handler.on_send_leave_request(origin, content) + return 200, (200, result) class FederationV2SendLeaveServlet(BaseFederationServerServlet): @@ -622,8 +632,8 @@ async def on_PUT( room_id: str, event_id: str, ) -> Tuple[int, JsonDict]: - content = await self.handler.on_send_leave_request(origin, content) - return 200, content + result = await self.handler.on_send_leave_request(origin, content) + return 200, result class FederationMakeKnockServlet(BaseFederationServerServlet): @@ -643,10 +653,10 @@ async def on_GET( except KeyError: raise SynapseError(400, "Missing required query parameter 'ver'") - content = await self.handler.on_make_knock_request( + result = await self.handler.on_make_knock_request( origin, room_id, user_id, supported_versions=supported_versions ) - return 200, content + return 200, result class FederationV1SendKnockServlet(BaseFederationServerServlet): @@ -660,8 +670,8 @@ async def on_PUT( room_id: str, event_id: str, ) -> Tuple[int, JsonDict]: - content = await self.handler.on_send_knock_request(origin, content, room_id) - return 200, content + result = await self.handler.on_send_knock_request(origin, content, room_id) + return 200, result class FederationEventAuthServlet(BaseFederationServerServlet): @@ -691,8 +701,8 @@ async def on_PUT( ) -> Tuple[int, Tuple[int, JsonDict]]: # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content) - return 200, (200, content) + result = await self.handler.on_send_join_request(origin, content) + return 200, (200, result) class FederationV2SendJoinServlet(BaseFederationServerServlet): @@ -710,8 +720,8 @@ async def on_PUT( ) -> Tuple[int, JsonDict]: # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content) - return 200, content + result = await self.handler.on_send_join_request(origin, content) + return 200, result class FederationV1InviteServlet(BaseFederationServerServlet): @@ -729,13 +739,13 @@ async def on_PUT( # v2. This is "fine" as the only difference between V1 and V2 is the # state resolution algorithm, and we don't use that for processing # invites - content = await self.handler.on_invite_request( + result = await self.handler.on_invite_request( origin, content, room_version_id=RoomVersions.V1.identifier ) # V1 federation API is defined to return a content of `[200, {...}]` # due to a historical bug. - return 200, (200, content) + return 200, (200, result) class FederationV2InviteServlet(BaseFederationServerServlet): @@ -763,10 +773,10 @@ async def on_PUT( event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state - content = await self.handler.on_invite_request( + result = await self.handler.on_invite_request( origin, event, room_version_id=room_version ) - return 200, content + return 200, result class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): @@ -830,7 +840,7 @@ async def on_POST( earliest_events = content.get("earliest_events", []) latest_events = content.get("latest_events", []) - content = await self.handler.on_get_missing_events( + result = await self.handler.on_get_missing_events( origin, room_id=room_id, earliest_events=earliest_events, @@ -838,7 +848,7 @@ async def on_POST( limit=limit, ) - return 200, content + return 200, result class On3pidBindServlet(BaseFederationServerServlet): From 60336cd440da587419004c9212c166451a9aa7ca Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Jun 2021 14:04:46 -0400 Subject: [PATCH 5/7] Newsfragment --- changelog.d/10213.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/10213.misc diff --git a/changelog.d/10213.misc b/changelog.d/10213.misc new file mode 100644 index 000000000000..9adb0fbd02d3 --- /dev/null +++ b/changelog.d/10213.misc @@ -0,0 +1 @@ +Add type hints to the federation servlets. From ed8f2149356db7de82038cfedb6dbdf743eb8571 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Jun 2021 11:09:57 +0100 Subject: [PATCH 6/7] Fix mypy --- synapse/federation/transport/server.py | 4 ++- synapse/http/servlet.py | 50 +++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index dc9cd7cac10a..ba277cb832a8 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -649,7 +649,9 @@ async def on_GET( ) -> Tuple[int, JsonDict]: try: # Retrieve the room versions the remote homeserver claims to support - supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") + supported_versions = parse_strings_from_args( + query, "ver", required=True, encoding="utf-8" + ) except KeyError: raise SynapseError(400, "Missing required query parameter 'ver'") diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index fda8da21b79c..6ba2ce1e53a2 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -109,12 +109,22 @@ def parse_boolean_from_args(args, name, default=None, required=False): return default +@overload +def parse_bytes_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[bytes] = None, +) -> Optional[bytes]: + ... + + @overload def parse_bytes_from_args( args: Dict[bytes, List[bytes]], name: str, default: Literal[None] = None, - required: Literal[True] = True, + *, + required: Literal[True], ) -> bytes: ... @@ -197,7 +207,12 @@ def parse_string( """ args = request.args # type: Dict[bytes, List[bytes]] # type: ignore return parse_string_from_args( - args, name, default, required, allowed_values, encoding + args, + name, + default, + required=required, + allowed_values=allowed_values, + encoding=encoding, ) @@ -227,7 +242,20 @@ def parse_strings_from_args( args: Dict[bytes, List[bytes]], name: str, default: Optional[List[str]] = None, - required: Literal[True] = True, + *, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> Optional[List[str]]: + ... + + +@overload +def parse_strings_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[List[str]] = None, + *, + required: Literal[True], allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", ) -> List[str]: @@ -239,6 +267,7 @@ def parse_strings_from_args( args: Dict[bytes, List[bytes]], name: str, default: Optional[List[str]] = None, + *, required: bool = False, allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", @@ -299,7 +328,20 @@ def parse_string_from_args( args: Dict[bytes, List[bytes]], name: str, default: Optional[str] = None, - required: Literal[True] = True, + *, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> Optional[str]: + ... + + +@overload +def parse_string_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[str] = None, + *, + required: Literal[True], allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", ) -> str: From a27ef9ec07a45022328ecc74da25214e23452af4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 25 Jun 2021 08:23:41 -0400 Subject: [PATCH 7/7] Remove unnecessary try-except. --- synapse/federation/transport/server.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index ba277cb832a8..559bec8ace9f 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -647,13 +647,10 @@ async def on_GET( room_id: str, user_id: str, ) -> Tuple[int, JsonDict]: - try: - # Retrieve the room versions the remote homeserver claims to support - supported_versions = parse_strings_from_args( - query, "ver", required=True, encoding="utf-8" - ) - except KeyError: - raise SynapseError(400, "Missing required query parameter 'ver'") + # Retrieve the room versions the remote homeserver claims to support + supported_versions = parse_strings_from_args( + query, "ver", required=True, encoding="utf-8" + ) result = await self.handler.on_make_knock_request( origin, room_id, user_id, supported_versions=supported_versions