Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve the synapse.api.auth.Auth mock used in unit tests. #13809

Merged
merged 10 commits into from
Sep 21, 2022
1 change: 1 addition & 0 deletions changelog.d/13809.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve the `synapse.api.auth.Auth` mock used in unit tests.
42 changes: 16 additions & 26 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
TypeVar,
Union,
)
from unittest.mock import Mock, patch
from unittest.mock import patch

import canonicaljson
import signedjson.key
Expand Down Expand Up @@ -300,46 +300,36 @@ def setUp(self) -> None:
if hasattr(self, "user_id"):
if self.hijack_auth:
assert self.helper.auth_user_id is not None
token = "some_fake_token"

# We need a valid token ID to satisfy foreign key constraints.
token_id = self.get_success(
self.hs.get_datastores().main.add_access_token_to_user(
self.helper.auth_user_id,
"some_fake_token",
token,
None,
None,
)
)

async def get_user_by_access_token(
token: Optional[str] = None, allow_guest: bool = False
) -> JsonDict:
assert self.helper.auth_user_id is not None
return {
"user": UserID.from_string(self.helper.auth_user_id),
"token_id": token_id,
"is_guest": False,
}

async def get_user_by_req(
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
) -> Requester:
# This has to be a function and not just a Mock, because
# `self.helper.auth_user_id` is temporarily reassigned in some tests
async def get_requester(*args, **kwargs) -> Requester:
assert self.helper.auth_user_id is not None
return create_requester(
UserID.from_string(self.helper.auth_user_id),
token_id,
False,
False,
None,
user_id=UserID.from_string(self.helper.auth_user_id),
access_token_id=token_id,
)

# Type ignore: mypy doesn't like us assigning to methods.
self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment]
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment]
return_value="1234"
self.hs.get_auth().get_user_by_req = get_requester # type: ignore[assignment]
self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[assignment]

patch.object(
self.hs.get_auth(),
"get_access_token_from_request",
return_value=token,
autospec=True,
)

if self.needs_threadpool:
Expand Down