Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Misc typing fixes for tests, part 2 of N #11330

Merged
merged 13 commits into from
Nov 16, 2021
1 change: 1 addition & 0 deletions changelog.d/11330.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type annotations in Synapse's test suite.
9 changes: 6 additions & 3 deletions tests/handlers/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,23 +193,26 @@ def test_mau_limits_when_disabled(self):

@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self):
self.store.count_monthly_users = Mock(
# Type ignore: mypy doesn't like us assigning to methods.
self.store.count_monthly_users = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User"))

@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self):
self.store.get_monthly_active_count = Mock(
# Type ignore: mypy doesn't like us assigning to methods.
self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError,
)

self.store.get_monthly_active_count = Mock(
# Type ignore: mypy doesn't like us assigning to methods.
self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
self.get_failure(
Expand Down
51 changes: 42 additions & 9 deletions tests/rest/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
MutableMapping,
Optional,
Tuple,
Union,
overload,
)
from unittest.mock import patch

import attr
from typing_extensions import Literal

from twisted.web.resource import Resource
from twisted.web.server import Site
Expand All @@ -55,6 +56,32 @@ class RestHelper:
site = attr.ib(type=Site)
auth_user_id = attr.ib()

@overload
def create_room_as(
self,
room_creator: Optional[str] = ...,
is_public: Optional[bool] = ...,
room_version: Optional[str] = ...,
tok: Optional[str] = ...,
expect_code: Literal[200] = ...,
extra_content: Optional[Dict] = ...,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
) -> str:
...

@overload
def create_room_as(
self,
room_creator: Optional[str] = ...,
is_public: Optional[bool] = ...,
room_version: Optional[str] = ...,
tok: Optional[str] = ...,
expect_code: int = ...,
extra_content: Optional[Dict] = ...,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
) -> Optional[str]:
...

def create_room_as(
self,
room_creator: Optional[str] = None,
Expand All @@ -64,7 +91,7 @@ def create_room_as(
expect_code: int = 200,
extra_content: Optional[Dict] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
) -> str:
) -> Optional[str]:
"""
Create a room.

Expand Down Expand Up @@ -107,6 +134,8 @@ def create_room_as(

if expect_code == 200:
return channel.json_body["room_id"]
else:
return None

def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
Expand Down Expand Up @@ -176,7 +205,7 @@ def change_membership(
extra_data: Optional[dict] = None,
tok: Optional[str] = None,
expect_code: int = 200,
expect_errcode: str = None,
expect_errcode: Optional[str] = None,
) -> None:
"""
Send a membership state event into a room.
Expand Down Expand Up @@ -260,9 +289,7 @@ def send_event(
txn_id=None,
tok=None,
expect_code=200,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
Expand Down Expand Up @@ -509,7 +536,7 @@ def auth_via_oidc(
went.
"""

cookies = {}
cookies: Dict[str, str] = {}

# if we're doing a ui auth, hit the ui auth redirect endpoint
if ui_auth_session_id:
Expand Down Expand Up @@ -631,7 +658,13 @@ def initiate_sso_login(

# hit the redirect url again with the right Host header, which should now issue
# a cookie and redirect to the SSO provider.
location = channel.headers.getRawHeaders("Location")[0]
def get_location(channel: FakeChannel) -> str:
location_values = channel.headers.getRawHeaders("Location")
clokep marked this conversation as resolved.
Show resolved Hide resolved
# Keep mypy happy by asserting that location_values is nonempty
assert location_values
return location_values[0]

location = get_location(channel)
parts = urllib.parse.urlsplit(location)
channel = make_request(
self.hs.get_reactor(),
Expand All @@ -645,7 +678,7 @@ def initiate_sso_login(

assert channel.code == 302
channel.extract_cookies(cookies)
return channel.headers.getRawHeaders("Location")[0]
return get_location(channel)

def initiate_sso_ui_auth(
self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
Expand Down
3 changes: 2 additions & 1 deletion tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
MutableMapping,
Optional,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -226,7 +227,7 @@ def make_request(
path: Union[bytes, str],
content: Union[bytes, str, JsonDict] = b"",
access_token: Optional[str] = None,
request: Request = SynapseRequest,
request: Type[Request] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
Expand Down
31 changes: 15 additions & 16 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
from twisted.web.resource import Resource
from twisted.web.server import Request

from synapse import events
from synapse.api.constants import EventTypes, Membership
Expand Down Expand Up @@ -95,16 +96,13 @@ def new(*args, **kwargs):
return _around


T = TypeVar("T")


class TestCase(unittest.TestCase):
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
attributes on both itself and its individual test methods, to override the
root logger's logging level while that test (case|method) runs."""

def __init__(self, methodName, *args, **kwargs):
super().__init__(methodName, *args, **kwargs)
def __init__(self, methodName: str):
super().__init__(methodName)

method = getattr(self, methodName)

Expand Down Expand Up @@ -220,16 +218,16 @@ class HomeserverTestCase(TestCase):
Attributes:
servlets: List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked.
hijack_auth (bool): Whether to hijack auth to return the user specified
hijack_auth: Whether to hijack auth to return the user specified
in user_id.
"""

hijack_auth = True
needs_threadpool = False
hijack_auth: ClassVar[bool] = True
needs_threadpool: ClassVar[bool] = False
servlets: ClassVar[List[RegisterServletsFunc]] = []

def __init__(self, methodName, *args, **kwargs):
super().__init__(methodName, *args, **kwargs)
def __init__(self, methodName: str):
super().__init__(methodName)

# see if we have any additional config for this test
method = getattr(self, methodName)
Expand Down Expand Up @@ -301,9 +299,10 @@ async def get_user_by_req(request, allow_guest=False, rights="access"):
None,
)

self.hs.get_auth().get_user_by_req = get_user_by_req
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
self.hs.get_auth().get_access_token_from_request = Mock(
# Type ignore: mypy doesn't like us assigning to methods.
self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment]
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment]
return_value="1234"
)

Expand Down Expand Up @@ -417,7 +416,7 @@ def make_request(
path: Union[bytes, str],
content: Union[bytes, str, JsonDict] = b"",
access_token: Optional[str] = None,
request: Type[T] = SynapseRequest,
request: Type[Request] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
Expand Down Expand Up @@ -596,7 +595,7 @@ def register_user(
nonce_str += b"\x00notadmin"

want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
want_mac = want_mac.hexdigest()
want_mac_digest = want_mac.hexdigest()

body = json.dumps(
{
Expand All @@ -605,7 +604,7 @@ def register_user(
"displayname": displayname,
"password": password,
"admin": admin,
"mac": want_mac,
"mac": want_mac_digest,
"inhibit_login": True,
}
)
Expand Down