Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to various handlers. #9223

Merged
merged 11 commits into from
Jan 26, 2021
1 change: 1 addition & 0 deletions changelog.d/9223.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to handlers code.
14 changes: 14 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ files =
synapse/handlers/_base.py,
synapse/handlers/account_data.py,
synapse/handlers/account_validity.py,
synapse/handlers/acme.py,
synapse/handlers/acme_issuing_service.py,
synapse/handlers/admin.py,
synapse/handlers/appservice.py,
synapse/handlers/auth.py,
Expand All @@ -36,6 +38,7 @@ files =
synapse/handlers/directory.py,
synapse/handlers/events.py,
synapse/handlers/federation.py,
synapse/handlers/groups_local.py,
synapse/handlers/identity.py,
synapse/handlers/initial_sync.py,
synapse/handlers/message.py,
Expand All @@ -52,8 +55,13 @@ files =
synapse/handlers/room_member.py,
synapse/handlers/room_member_worker.py,
synapse/handlers/saml_handler.py,
synapse/handlers/search.py,
synapse/handlers/set_password.py,
synapse/handlers/sso.py,
synapse/handlers/state_deltas.py,
synapse/handlers/stats.py,
synapse/handlers/sync.py,
synapse/handlers/typing.py,
synapse/handlers/user_directory.py,
synapse/handlers/ui_auth,
synapse/http/client.py,
Expand Down Expand Up @@ -194,3 +202,9 @@ ignore_missing_imports = True

[mypy-hiredis]
ignore_missing_imports = True

[mypy-josepy.*]
ignore_missing_imports = True

[mypy-txacme.*]
ignore_missing_imports = True
12 changes: 7 additions & 5 deletions synapse/handlers/acme.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING

import twisted
import twisted.internet.error
Expand All @@ -22,6 +23,9 @@

from synapse.app import check_bind_error

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)

ACME_REGISTER_FAIL_ERROR = """
Expand All @@ -35,12 +39,12 @@


class AcmeHandler:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.reactor = hs.get_reactor()
self._acme_domain = hs.config.acme_domain

async def start_listening(self):
async def start_listening(self) -> None:
from synapse.handlers import acme_issuing_service

# Configure logging for txacme, if you need to debug
Expand Down Expand Up @@ -85,7 +89,7 @@ async def start_listening(self):
logger.error(ACME_REGISTER_FAIL_ERROR)
raise

async def provision_certificate(self):
async def provision_certificate(self) -> None:

logger.warning("Reprovisioning %s", self._acme_domain)

Expand All @@ -110,5 +114,3 @@ async def provision_certificate(self):
except Exception:
logger.exception("Failed saving!")
raise

return True
Copy link
Member Author

@clokep clokep Jan 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is just awaited and the return type is never used.

27 changes: 19 additions & 8 deletions synapse/handlers/acme_issuing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
imported conditionally.
"""
import logging
from typing import Dict, Iterable, List

import attr
import pem
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from josepy import JWKRSA
Expand All @@ -36,20 +38,27 @@
from zope.interface import implementer

from twisted.internet import defer
from twisted.internet.interfaces import IReactorTCP
from twisted.python.filepath import FilePath
from twisted.python.url import URL
from twisted.web.resource import IResource

logger = logging.getLogger(__name__)


def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
def create_issuing_service(
reactor: IReactorTCP,
acme_url: str,
account_key_file: str,
well_known_resource: IResource,
) -> AcmeIssuingService:
"""Create an ACME issuing service, and attach it to a web Resource

Args:
reactor: twisted reactor
acme_url (str): URL to use to request certificates
account_key_file (str): where to store the account key
well_known_resource (twisted.web.IResource): web resource for .well-known.
acme_url: URL to use to request certificates
account_key_file: where to store the account key
well_known_resource: web resource for .well-known.
we will attach a child resource for "acme-challenge".

Returns:
Expand Down Expand Up @@ -83,18 +92,20 @@ class ErsatzStore:
A store that only stores in memory.
"""

certs = attr.ib(default=attr.Factory(dict))
certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))

def store(self, server_name, pem_objects):
def store(
self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
) -> defer.Deferred:
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
return defer.succeed(None)


def load_or_create_client_key(key_file):
def load_or_create_client_key(key_file: str) -> JWKRSA:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should be writing stubs for libraries. Since we ignore the lib in mypy.ini it'll be treated as Any, which is distinctly unhelpful, and perhaps even worse than not having a type hint at all?

We do this in other places though, so happy to leave this as is and discuss the broader issue separately.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that would make sense to do as a separate step? 🤷 The ACME stuff in particular is deprecated I think, so maybe not worthwhile to write stubs there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, I just know that we sometimes get confused at why mypy isn't picking up obvious bugs because of this sort of thing, but that can be dealt with separately

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I filed #9228.

"""Load the ACME account key from a file, creating it if it does not exist.

Args:
key_file (str): name of the file to use as the account key
key_file: name of the file to use as the account key
"""
# this is based on txacme.endpoint.load_or_create_client_key, but doesn't
# hardcode the 'client.key' filename
Expand Down
83 changes: 42 additions & 41 deletions synapse/handlers/groups_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Set

from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import GroupID, get_domain_from_id
from synapse.types import GroupID, JsonDict, get_domain_from_id

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,7 +60,7 @@ async def f(self, group_id, *args, **kwargs):


class GroupsLocalWorkerHandler:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.room_list_handler = hs.get_room_list_handler()
Expand Down Expand Up @@ -84,7 +88,9 @@ def __init__(self, hs):
get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles")

async def get_group_summary(self, group_id, requester_user_id):
async def get_group_summary(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get the group summary for a group.

If the group is remote we check that the users have valid attestations.
Expand Down Expand Up @@ -137,14 +143,15 @@ async def get_group_summary(self, group_id, requester_user_id):

return res

async def get_users_in_group(self, group_id, requester_user_id):
async def get_users_in_group(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get users in a group
"""
if self.is_mine_id(group_id):
res = await self.groups_server_handler.get_users_in_group(
return await self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
)
return res

group_server_name = get_domain_from_id(group_id)

Expand Down Expand Up @@ -178,11 +185,11 @@ async def get_users_in_group(self, group_id, requester_user_id):

return res

async def get_joined_groups(self, user_id):
async def get_joined_groups(self, user_id: str) -> JsonDict:
group_ids = await self.store.get_joined_groups(user_id)
return {"groups": group_ids}

async def get_publicised_groups_for_user(self, user_id):
async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
if self.hs.is_mine_id(user_id):
result = await self.store.get_publicised_groups_for_user(user_id)

Expand All @@ -206,8 +213,10 @@ async def get_publicised_groups_for_user(self, user_id):
# TODO: Verify attestations
return {"groups": result}

async def bulk_get_publicised_groups(self, user_ids, proxy=True):
destinations = {}
async def bulk_get_publicised_groups(
self, user_ids: Iterable[str], proxy: bool = True
) -> JsonDict:
destinations = {} # type: Dict[str, Set[str]]
local_users = set()

for user_id in user_ids:
Expand All @@ -220,7 +229,7 @@ async def bulk_get_publicised_groups(self, user_ids, proxy=True):
raise SynapseError(400, "Some user_ids are not local")

results = {}
failed_results = []
failed_results = [] # type: List[str]
for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(
Expand All @@ -242,7 +251,7 @@ async def bulk_get_publicised_groups(self, user_ids, proxy=True):


class GroupsLocalHandler(GroupsLocalWorkerHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

# Ensure attestations get renewed
Expand Down Expand Up @@ -271,7 +280,9 @@ def __init__(self, hs):

set_group_join_policy = _create_rerouter("set_group_join_policy")

async def create_group(self, group_id, user_id, content):
async def create_group(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Create a group
"""

Expand All @@ -284,27 +295,7 @@ async def create_group(self, group_id, user_id, content):
local_attestation = None
remote_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation

content["user_profile"] = await self.profile_handler.get_profile(user_id)

try:
res = await self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content
)
Comment on lines -293 to -295
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TransportLayerClient does not have a create_group, since the groups code is to be replaced it seemed prudent to just fail quickly here.

except HttpResponseException as e:
raise e.to_synapse_error()
except RequestSendFailed:
raise SynapseError(502, "Failed to contact group server")

remote_attestation = res["attestation"]
await self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
server_name=get_domain_from_id(group_id),
)
raise SynapseError(400, "Unable to create remote groups")

is_publicised = content.get("publicise", False)
token = await self.store.register_user_group_membership(
Expand All @@ -320,7 +311,9 @@ async def create_group(self, group_id, user_id, content):

return res

async def join_group(self, group_id, user_id, content):
async def join_group(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Request to join a group
"""
if self.is_mine_id(group_id):
Expand Down Expand Up @@ -365,7 +358,9 @@ async def join_group(self, group_id, user_id, content):

return {}

async def accept_invite(self, group_id, user_id, content):
async def accept_invite(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
Expand Down Expand Up @@ -410,7 +405,9 @@ async def accept_invite(self, group_id, user_id, content):

return {}

async def invite(self, group_id, user_id, requester_user_id, config):
async def invite(
self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
) -> JsonDict:
"""Invite a user to a group
"""
content = {"requester_user_id": requester_user_id, "config": config}
Expand All @@ -434,7 +431,9 @@ async def invite(self, group_id, user_id, requester_user_id, config):

return res

async def on_invite(self, group_id, user_id, content):
async def on_invite(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""One of our users were invited to a group
"""
# TODO: Support auto join and rejection
Expand Down Expand Up @@ -465,8 +464,8 @@ async def on_invite(self, group_id, user_id, content):
return {"state": "invite", "user_profile": user_profile}

async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
"""Remove a user from a group
"""
if user_id == requester_user_id:
Expand Down Expand Up @@ -499,7 +498,9 @@ async def remove_user_from_group(

return res

async def user_removed_from_group(self, group_id, user_id, content):
async def user_removed_from_group(
self, group_id: str, user_id: str, content: JsonDict
) -> None:
"""One of our users was removed/kicked from a group
"""
# TODO: Check if user in group
Expand Down
Loading