diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e8394809efb4..5955410524c9 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -134,10 +134,10 @@ def default_config(self) -> Dict[str, Any]: return config def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.fake_provider = FakeOidcServer(clock=clock, issuer=ISSUER) + self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER) hs = self.setup_test_homeserver() - self.hs_patcher = self.fake_provider.patch_homeserver(hs=hs) + self.hs_patcher = self.fake_server.patch_homeserver(hs=hs) self.hs_patcher.start() self.handler = hs.get_oidc_handler() @@ -163,16 +163,16 @@ def tearDown(self) -> None: def reset_mocks(self): """Reset all the Mocks.""" - self.fake_provider.reset_mocks() + self.fake_server.reset_mocks() self.render_error.reset_mock() self.complete_sso_login.reset_mock() def metadata_edit(self, values): """Modify the result that will be returned by the well-known query""" - metadata = self.fake_provider.get_metadata() + metadata = self.fake_server.get_metadata() metadata.update(values) - return patch.object(self.fake_provider, "get_metadata", return_value=metadata) + return patch.object(self.fake_server, "get_metadata", return_value=metadata) def start_authorization( self, @@ -185,7 +185,7 @@ def start_authorization( nonce = random_string(10) state = random_string(10) - code, grant = self.fake_provider.start_authorization( + code, grant = self.fake_server.start_authorization( userinfo=userinfo, scope=scope, client_id=self.provider._client_auth.client_id, @@ -218,48 +218,48 @@ def test_discovery(self) -> None: """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) - self.fake_provider.get_metadata_handler.assert_called_once() + self.fake_server.get_metadata_handler.assert_called_once() - self.assertEqual(metadata.issuer, self.fake_provider.issuer) + self.assertEqual(metadata.issuer, self.fake_server.issuer) self.assertEqual( metadata.authorization_endpoint, - self.fake_provider.authorization_endpoint, + self.fake_server.authorization_endpoint, ) - self.assertEqual(metadata.token_endpoint, self.fake_provider.token_endpoint) - self.assertEqual(metadata.jwks_uri, self.fake_provider.jwks_uri) + self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint) + self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri) # It seems like authlib does not have that defined in its metadata models self.assertEqual( metadata.get("userinfo_endpoint"), - self.fake_provider.userinfo_endpoint, + self.fake_server.userinfo_endpoint, ) # subsequent calls should be cached self.reset_mocks() self.get_success(self.provider.load_metadata()) - self.fake_provider.get_metadata_handler.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) - self.fake_provider.get_metadata_handler.assert_not_called() + self.fake_server.get_metadata_handler.assert_not_called() @override_config({"oidc_config": DEFAULT_CONFIG}) def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) - self.fake_provider.get_jwks_handler.assert_called_once() - self.assertEqual(jwks, self.fake_provider.get_jwks()) + self.fake_server.get_jwks_handler.assert_called_once() + self.assertEqual(jwks, self.fake_server.get_jwks()) # subsequent calls should be cached… self.reset_mocks() self.get_success(self.provider.load_jwks()) - self.fake_provider.get_jwks_handler.assert_not_called() + self.fake_server.get_jwks_handler.assert_not_called() # …unless forced self.reset_mocks() self.get_success(self.provider.load_jwks(force=True)) - self.fake_provider.get_jwks_handler.assert_called_once() + self.fake_server.get_jwks_handler.assert_called_once() with self.metadata_edit({"jwks_uri": None}): # If we don't do this, the load_metadata call will throw because of the @@ -369,7 +369,7 @@ def test_redirect_request(self) -> None: self.provider.handle_redirect_request(req, b"http://client/redirect") ) ) - auth_endpoint = urlparse(self.fake_provider.authorization_endpoint) + auth_endpoint = urlparse(self.fake_server.authorization_endpoint) self.assertEqual(url.scheme, auth_endpoint.scheme) self.assertEqual(url.netloc, auth_endpoint.netloc) @@ -456,8 +456,8 @@ def test_callback(self) -> None: new_user=True, auth_provider_session_id=None, ) - self.fake_provider.post_token_handler.assert_called_once() - self.fake_provider.get_userinfo_handler.assert_not_called() + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors @@ -472,7 +472,7 @@ def test_callback(self) -> None: # Handle ID token errors request, _ = self.start_authorization(userinfo) - with self.fake_provider.id_token_override({"iss": "https://bad.issuer/"}): + with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}): self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") @@ -493,8 +493,8 @@ def test_callback(self) -> None: new_user=False, auth_provider_session_id=None, ) - self.fake_provider.post_token_handler.assert_called_once() - self.fake_provider.get_userinfo_handler.assert_called_once() + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() self.reset_mocks() @@ -514,18 +514,18 @@ def test_callback(self) -> None: new_user=False, auth_provider_session_id=grant.sid, ) - self.fake_provider.post_token_handler.assert_called_once() - self.fake_provider.get_userinfo_handler.assert_called_once() + self.fake_server.post_token_handler.assert_called_once() + self.fake_server.get_userinfo_handler.assert_called_once() self.render_error.assert_not_called() # Handle userinfo fetching error request, _ = self.start_authorization(userinfo) - with self.fake_provider.buggy_endpoint(userinfo=True): + with self.fake_server.buggy_endpoint(userinfo=True): self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") request, _ = self.start_authorization(userinfo) - with self.fake_provider.buggy_endpoint(token=True): + with self.fake_server.buggy_endpoint(token=True): self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("server_error") @@ -582,17 +582,17 @@ def test_exchange_code(self) -> None: "access_token": "aabbcc", } - self.fake_provider.post_token_handler.side_effect = None - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( payload=token ) code = "code" ret = self.get_success(self.provider._exchange_code(code)) - kwargs = self.fake_provider.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(ret, token) self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) args = parse_qs(kwargs["data"].decode("utf-8")) self.assertEqual(args["grant_type"], ["authorization_code"]) @@ -602,7 +602,7 @@ def test_exchange_code(self) -> None: self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) # Test error handling - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.return_value = FakeResponse.json( code=400, payload={"error": "foo", "error_description": "bar"} ) from synapse.handlers.oidc import OidcError @@ -612,14 +612,14 @@ def test_exchange_code(self) -> None: self.assertEqual(exc.value.error_description, "bar") # Internal server error with no JSON body - self.fake_provider.post_token_handler.return_value = FakeResponse( + self.fake_server.post_token_handler.return_value = FakeResponse( code=500, body=b"Not JSON" ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.return_value = FakeResponse.json( code=500, payload={"error": "internal_server_error"} ) @@ -627,14 +627,14 @@ def test_exchange_code(self) -> None: self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.return_value = FakeResponse.json( code=400, payload={} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.return_value = FakeResponse.json( code=200, payload={"error": "some_error"} ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) @@ -664,8 +664,8 @@ def test_exchange_code_jwt_key(self) -> None: "access_token": "aabbcc", } - self.fake_provider.post_token_handler.side_effect = None - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( payload=token ) code = "code" @@ -679,9 +679,9 @@ def test_exchange_code_jwt_key(self) -> None: self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.fake_provider.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # the client secret provided to the should be a jwt which can be checked with # the public key @@ -720,8 +720,8 @@ def test_exchange_code_no_auth(self) -> None: "access_token": "aabbcc", } - self.fake_provider.post_token_handler.side_effect = None - self.fake_provider.post_token_handler.return_value = FakeResponse.json( + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( payload=token ) code = "code" @@ -730,9 +730,9 @@ def test_exchange_code_no_auth(self) -> None: self.assertEqual(ret, token) # the request should have hit the token endpoint - kwargs = self.fake_provider.request.call_args[1] + kwargs = self.fake_server.request.call_args[1] self.assertEqual(kwargs["method"], "POST") - self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint) + self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint) # check the POSTed data args = parse_qs(kwargs["data"].decode("utf-8")) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 74a01104f672..ebf653d018f6 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -465,11 +465,11 @@ def test_ui_auth_via_sso(self) -> None: * checking that the original operation succeeds """ - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() # log the user in remote_user_id = UserID.from_string(self.user).localpart - login_resp, _ = self.helper.login_via_oidc(fake_oidc_provider, remote_user_id) + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id) self.assertEqual(login_resp["user_id"], self.user) # initiate a UI Auth process by attempting to delete the device @@ -484,7 +484,7 @@ def test_ui_auth_via_sso(self) -> None: # run the UIA-via-SSO flow session_id = channel.json_body["session"] channel, _ = self.helper.auth_via_oidc( - fake_oidc_provider, {"sub": remote_user_id}, ui_auth_session_id=session_id + fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id ) # that should serve a confirmation page @@ -501,8 +501,8 @@ def test_ui_auth_via_sso(self) -> None: @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_does_not_offer_password_for_sso_user(self) -> None: - fake_oidc_provider = self.helper.fake_oidc_server() - login_resp, _ = self.helper.login_via_oidc(fake_oidc_provider, "username") + fake_oidc_server = self.helper.fake_oidc_server() + login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username") user_tok = login_resp["access_token"] device_id = login_resp["device_id"] @@ -525,9 +525,9 @@ def test_does_not_offer_sso_for_password_user(self) -> None: @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self) -> None: """A user that had a password and then logged in with SSO should get both flows""" - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() login_resp, _ = self.helper.login_via_oidc( - fake_oidc_provider, UserID.from_string(self.user).localpart + fake_oidc_server, UserID.from_string(self.user).localpart ) self.assertEqual(login_resp["user_id"], self.user) @@ -546,11 +546,11 @@ def test_offers_both_flows_for_upgraded_user(self) -> None: def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: """If the user tries to authenticate with the wrong SSO user, they get an error""" - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() # log the user in login_resp, _ = self.helper.login_via_oidc( - fake_oidc_provider, UserID.from_string(self.user).localpart + fake_oidc_server, UserID.from_string(self.user).localpart ) self.assertEqual(login_resp["user_id"], self.user) @@ -565,7 +565,7 @@ def test_ui_auth_fails_for_incorrect_sso_user(self) -> None: # do the OIDC auth, but auth as the wrong user channel, _ = self.helper.auth_via_oidc( - fake_oidc_provider, {"sub": "wrong_user"}, ui_auth_session_id=session_id + fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id ) # that should return a failure message @@ -595,9 +595,9 @@ def test_sso_not_approved(self) -> None: """Tests that if we register a user via SSO while requiring approval for new accounts, we still raise the correct error before logging the user in. """ - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() login_resp, _ = self.helper.login_via_oidc( - fake_oidc_provider, "username", expected_status=403 + fake_oidc_server, "username", expected_status=403 ) self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 51e62b650bac..ff5baa9f0a78 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -612,9 +612,9 @@ def test_multi_sso_redirect_to_saml(self) -> None: def test_login_via_oidc(self) -> None: """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() - with fake_oidc_provider.patch_homeserver(hs=self.hs): + with fake_oidc_server.patch_homeserver(hs=self.hs): # pick the default OIDC provider channel = self.make_request( "GET", @@ -629,7 +629,7 @@ def test_login_via_oidc(self) -> None: oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, fake_oidc_provider.authorization_endpoint) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) # ... and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") @@ -647,7 +647,7 @@ def test_login_via_oidc(self) -> None: ) channel, _ = self.helper.complete_oidc_auth( - fake_oidc_provider, oidc_uri, cookies, {"sub": "user1"} + fake_oidc_server, oidc_uri, cookies, {"sub": "user1"} ) # that should serve a confirmation page @@ -698,9 +698,9 @@ def test_client_idp_redirect_to_unknown(self) -> None: def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() - with fake_oidc_provider.patch_homeserver(hs=self.hs): + with fake_oidc_server.patch_homeserver(hs=self.hs): channel = self._make_sso_redirect_request("oidc") self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") @@ -709,7 +709,7 @@ def test_client_idp_redirect_to_oidc(self) -> None: oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, fake_oidc_provider.authorization_endpoint) + self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint) def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel: """Send a request to /_matrix/client/r0/login/sso/redirect @@ -1288,11 +1288,11 @@ def create_resource_dict(self) -> Dict[str, Resource]: def test_username_picker(self) -> None: """Test the happy path of a username picker flow.""" - fake_oidc_provider = self.helper.fake_oidc_server() + fake_oidc_server = self.helper.fake_oidc_server() # do the start of the login flow channel, _ = self.helper.auth_via_oidc( - fake_oidc_provider, + fake_oidc_server, {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL, ) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index c47b63fff45d..e62ebcc6a5a3 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -114,7 +114,7 @@ class FakeResponse: # type: ignore[misc] attribute, and didn't support deliverBody until recently. """ - verison: Tuple[bytes, int, int] = (b"HTTP", 1, 1) + version: Tuple[bytes, int, int] = (b"HTTP", 1, 1) # HTTP response code code: int = 200 @@ -122,7 +122,7 @@ class FakeResponse: # type: ignore[misc] # body of the response body: bytes = b"" - headers: Headers = Headers() + headers: Headers = attr.Factory(Headers) @property def phrase(self): diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py index 914446101717..de134bbc893b 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py @@ -54,7 +54,7 @@ class FakeOidcServer: def __init__(self, clock: Clock, issuer: str): from authlib.jose import ECKey, KeySet - self.clock = clock + self._clock = clock self.issuer = issuer self.request = Mock(side_effect=self._request) @@ -64,14 +64,14 @@ def __init__(self, clock: Clock, issuer: str): self.post_token_handler = Mock(side_effect=self._post_token_handler) # A code -> grant mapping - self.authorization_grants: Dict[str, FakeAuthorizationGrant] = {} + self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {} # An access token -> grant mapping - self.sessions: Dict[str, FakeAuthorizationGrant] = {} + self._sessions: Dict[str, FakeAuthorizationGrant] = {} # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for # signing JWTs. ECDSA keys are really quick to generate compared to RSA. - self.key = ECKey.generate_key(crv="P-256", is_private=True) - self.jwks = KeySet([ECKey.import_key(self.key.as_pem(is_private=False))]) + self._key = ECKey.generate_key(crv="P-256", is_private=True) + self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))]) self._id_token_overrides: Dict[str, Any] = {} @@ -127,11 +127,11 @@ def get_metadata(self) -> dict: } def get_jwks(self) -> dict: - return self.jwks.as_dict() + return self._jwks.as_dict() def get_userinfo(self, access_token: str) -> Optional[dict]: """Given an access token, get the userinfo of the associated session.""" - session = self.sessions.get(access_token, None) + session = self._sessions.get(access_token, None) if session is None: return None return session.userinfo @@ -143,10 +143,10 @@ def _sign(self, payload: dict) -> str: kid = self.get_jwks()["keys"][0]["kid"] protected = {"alg": "ES256", "kid": kid} json_payload = json.dumps(payload) - return jws.serialize_compact(protected, json_payload, self.key).decode("utf-8") + return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: - now = self.clock.time() + now = self._clock.time() id_token = { **grant.userinfo, "iss": self.issuer, @@ -193,17 +193,17 @@ def start_authorization( client_id=client_id, sid=sid, ) - self.authorization_grants[code] = grant + self._authorization_grants[code] = grant return code, grant def exchange_code(self, code: str) -> Optional[Dict[str, Any]]: - grant = self.authorization_grants.pop(code, None) + grant = self._authorization_grants.pop(code, None) if grant is None: return None access_token = random_string(10) - self.sessions[access_token] = grant + self._sessions[access_token] = grant token = { "token_type": "Bearer",