-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add type hints to various handlers. #9223
Changes from all commits
8a72405
a4434a7
02b5509
3f2562b
706060a
23b3a58
50927fc
3f34357
f67b6e6
7115c66
47c7c16
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Add type hints to handlers code. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,8 +22,10 @@ | |
imported conditionally. | ||
""" | ||
import logging | ||
from typing import Dict, Iterable, List | ||
|
||
import attr | ||
import pem | ||
from cryptography.hazmat.backends import default_backend | ||
from cryptography.hazmat.primitives import serialization | ||
from josepy import JWKRSA | ||
|
@@ -36,20 +38,27 @@ | |
from zope.interface import implementer | ||
|
||
from twisted.internet import defer | ||
from twisted.internet.interfaces import IReactorTCP | ||
from twisted.python.filepath import FilePath | ||
from twisted.python.url import URL | ||
from twisted.web.resource import IResource | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource): | ||
def create_issuing_service( | ||
reactor: IReactorTCP, | ||
acme_url: str, | ||
account_key_file: str, | ||
well_known_resource: IResource, | ||
) -> AcmeIssuingService: | ||
"""Create an ACME issuing service, and attach it to a web Resource | ||
|
||
Args: | ||
reactor: twisted reactor | ||
acme_url (str): URL to use to request certificates | ||
account_key_file (str): where to store the account key | ||
well_known_resource (twisted.web.IResource): web resource for .well-known. | ||
acme_url: URL to use to request certificates | ||
account_key_file: where to store the account key | ||
well_known_resource: web resource for .well-known. | ||
we will attach a child resource for "acme-challenge". | ||
|
||
Returns: | ||
|
@@ -83,18 +92,20 @@ class ErsatzStore: | |
A store that only stores in memory. | ||
""" | ||
|
||
certs = attr.ib(default=attr.Factory(dict)) | ||
certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict)) | ||
|
||
def store(self, server_name, pem_objects): | ||
def store( | ||
self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject] | ||
) -> defer.Deferred: | ||
self.certs[server_name] = [o.as_bytes() for o in pem_objects] | ||
return defer.succeed(None) | ||
|
||
|
||
def load_or_create_client_key(key_file): | ||
def load_or_create_client_key(key_file: str) -> JWKRSA: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if we should be writing stubs for libraries. Since we ignore the lib in mypy.ini it'll be treated as We do this in other places though, so happy to leave this as is and discuss the broader issue separately. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that would make sense to do as a separate step? 🤷 The ACME stuff in particular is deprecated I think, so maybe not worthwhile to write stubs there. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. True, I just know that we sometimes get confused at why mypy isn't picking up obvious bugs because of this sort of thing, but that can be dealt with separately There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I filed #9228. |
||
"""Load the ACME account key from a file, creating it if it does not exist. | ||
|
||
Args: | ||
key_file (str): name of the file to use as the account key | ||
key_file: name of the file to use as the account key | ||
""" | ||
# this is based on txacme.endpoint.load_or_create_client_key, but doesn't | ||
# hardcode the 'client.key' filename | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,9 +15,13 @@ | |
# limitations under the License. | ||
|
||
import logging | ||
from typing import TYPE_CHECKING, Dict, Iterable, List, Set | ||
|
||
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError | ||
from synapse.types import GroupID, get_domain_from_id | ||
from synapse.types import GroupID, JsonDict, get_domain_from_id | ||
|
||
if TYPE_CHECKING: | ||
from synapse.app.homeserver import HomeServer | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -56,7 +60,7 @@ async def f(self, group_id, *args, **kwargs): | |
|
||
|
||
class GroupsLocalWorkerHandler: | ||
def __init__(self, hs): | ||
def __init__(self, hs: "HomeServer"): | ||
self.hs = hs | ||
self.store = hs.get_datastore() | ||
self.room_list_handler = hs.get_room_list_handler() | ||
|
@@ -84,7 +88,9 @@ def __init__(self, hs): | |
get_group_role = _create_rerouter("get_group_role") | ||
get_group_roles = _create_rerouter("get_group_roles") | ||
|
||
async def get_group_summary(self, group_id, requester_user_id): | ||
async def get_group_summary( | ||
self, group_id: str, requester_user_id: str | ||
) -> JsonDict: | ||
"""Get the group summary for a group. | ||
|
||
If the group is remote we check that the users have valid attestations. | ||
|
@@ -137,14 +143,15 @@ async def get_group_summary(self, group_id, requester_user_id): | |
|
||
return res | ||
|
||
async def get_users_in_group(self, group_id, requester_user_id): | ||
async def get_users_in_group( | ||
self, group_id: str, requester_user_id: str | ||
) -> JsonDict: | ||
"""Get users in a group | ||
""" | ||
if self.is_mine_id(group_id): | ||
res = await self.groups_server_handler.get_users_in_group( | ||
return await self.groups_server_handler.get_users_in_group( | ||
group_id, requester_user_id | ||
) | ||
return res | ||
|
||
group_server_name = get_domain_from_id(group_id) | ||
|
||
|
@@ -178,11 +185,11 @@ async def get_users_in_group(self, group_id, requester_user_id): | |
|
||
return res | ||
|
||
async def get_joined_groups(self, user_id): | ||
async def get_joined_groups(self, user_id: str) -> JsonDict: | ||
group_ids = await self.store.get_joined_groups(user_id) | ||
return {"groups": group_ids} | ||
|
||
async def get_publicised_groups_for_user(self, user_id): | ||
async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict: | ||
if self.hs.is_mine_id(user_id): | ||
result = await self.store.get_publicised_groups_for_user(user_id) | ||
|
||
|
@@ -206,8 +213,10 @@ async def get_publicised_groups_for_user(self, user_id): | |
# TODO: Verify attestations | ||
return {"groups": result} | ||
|
||
async def bulk_get_publicised_groups(self, user_ids, proxy=True): | ||
destinations = {} | ||
async def bulk_get_publicised_groups( | ||
self, user_ids: Iterable[str], proxy: bool = True | ||
) -> JsonDict: | ||
destinations = {} # type: Dict[str, Set[str]] | ||
local_users = set() | ||
|
||
for user_id in user_ids: | ||
|
@@ -220,7 +229,7 @@ async def bulk_get_publicised_groups(self, user_ids, proxy=True): | |
raise SynapseError(400, "Some user_ids are not local") | ||
|
||
results = {} | ||
failed_results = [] | ||
failed_results = [] # type: List[str] | ||
for destination, dest_user_ids in destinations.items(): | ||
try: | ||
r = await self.transport_client.bulk_get_publicised_groups( | ||
|
@@ -242,7 +251,7 @@ async def bulk_get_publicised_groups(self, user_ids, proxy=True): | |
|
||
|
||
class GroupsLocalHandler(GroupsLocalWorkerHandler): | ||
def __init__(self, hs): | ||
def __init__(self, hs: "HomeServer"): | ||
super().__init__(hs) | ||
|
||
# Ensure attestations get renewed | ||
|
@@ -271,7 +280,9 @@ def __init__(self, hs): | |
|
||
set_group_join_policy = _create_rerouter("set_group_join_policy") | ||
|
||
async def create_group(self, group_id, user_id, content): | ||
async def create_group( | ||
self, group_id: str, user_id: str, content: JsonDict | ||
) -> JsonDict: | ||
"""Create a group | ||
""" | ||
|
||
|
@@ -284,27 +295,7 @@ async def create_group(self, group_id, user_id, content): | |
local_attestation = None | ||
remote_attestation = None | ||
else: | ||
local_attestation = self.attestations.create_attestation(group_id, user_id) | ||
content["attestation"] = local_attestation | ||
|
||
content["user_profile"] = await self.profile_handler.get_profile(user_id) | ||
|
||
try: | ||
res = await self.transport_client.create_group( | ||
get_domain_from_id(group_id), group_id, user_id, content | ||
) | ||
Comment on lines
-293
to
-295
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
except HttpResponseException as e: | ||
raise e.to_synapse_error() | ||
except RequestSendFailed: | ||
raise SynapseError(502, "Failed to contact group server") | ||
|
||
remote_attestation = res["attestation"] | ||
await self.attestations.verify_attestation( | ||
remote_attestation, | ||
group_id=group_id, | ||
user_id=user_id, | ||
server_name=get_domain_from_id(group_id), | ||
) | ||
raise SynapseError(400, "Unable to create remote groups") | ||
|
||
is_publicised = content.get("publicise", False) | ||
token = await self.store.register_user_group_membership( | ||
|
@@ -320,7 +311,9 @@ async def create_group(self, group_id, user_id, content): | |
|
||
return res | ||
|
||
async def join_group(self, group_id, user_id, content): | ||
async def join_group( | ||
self, group_id: str, user_id: str, content: JsonDict | ||
) -> JsonDict: | ||
"""Request to join a group | ||
""" | ||
if self.is_mine_id(group_id): | ||
|
@@ -365,7 +358,9 @@ async def join_group(self, group_id, user_id, content): | |
|
||
return {} | ||
|
||
async def accept_invite(self, group_id, user_id, content): | ||
async def accept_invite( | ||
self, group_id: str, user_id: str, content: JsonDict | ||
) -> JsonDict: | ||
"""Accept an invite to a group | ||
""" | ||
if self.is_mine_id(group_id): | ||
|
@@ -410,7 +405,9 @@ async def accept_invite(self, group_id, user_id, content): | |
|
||
return {} | ||
|
||
async def invite(self, group_id, user_id, requester_user_id, config): | ||
async def invite( | ||
self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict | ||
) -> JsonDict: | ||
"""Invite a user to a group | ||
""" | ||
content = {"requester_user_id": requester_user_id, "config": config} | ||
|
@@ -434,7 +431,9 @@ async def invite(self, group_id, user_id, requester_user_id, config): | |
|
||
return res | ||
|
||
async def on_invite(self, group_id, user_id, content): | ||
async def on_invite( | ||
self, group_id: str, user_id: str, content: JsonDict | ||
) -> JsonDict: | ||
"""One of our users were invited to a group | ||
""" | ||
# TODO: Support auto join and rejection | ||
|
@@ -465,8 +464,8 @@ async def on_invite(self, group_id, user_id, content): | |
return {"state": "invite", "user_profile": user_profile} | ||
|
||
async def remove_user_from_group( | ||
self, group_id, user_id, requester_user_id, content | ||
): | ||
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict | ||
) -> JsonDict: | ||
"""Remove a user from a group | ||
""" | ||
if user_id == requester_user_id: | ||
|
@@ -499,7 +498,9 @@ async def remove_user_from_group( | |
|
||
return res | ||
|
||
async def user_removed_from_group(self, group_id, user_id, content): | ||
async def user_removed_from_group( | ||
self, group_id: str, user_id: str, content: JsonDict | ||
) -> None: | ||
"""One of our users was removed/kicked from a group | ||
""" | ||
# TODO: Check if user in group | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method is just awaited and the return type is never used.