Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
  • Loading branch information
sandhose and clokep committed Oct 25, 2022
1 parent 7c981b4 commit 3233581
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 80 deletions.
90 changes: 45 additions & 45 deletions tests/handlers/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ def default_config(self) -> Dict[str, Any]:
return config

def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.fake_provider = FakeOidcServer(clock=clock, issuer=ISSUER)
self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER)

hs = self.setup_test_homeserver()
self.hs_patcher = self.fake_provider.patch_homeserver(hs=hs)
self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
self.hs_patcher.start()

self.handler = hs.get_oidc_handler()
Expand All @@ -163,16 +163,16 @@ def tearDown(self) -> None:

def reset_mocks(self):
"""Reset all the Mocks."""
self.fake_provider.reset_mocks()
self.fake_server.reset_mocks()
self.render_error.reset_mock()
self.complete_sso_login.reset_mock()

def metadata_edit(self, values):
"""Modify the result that will be returned by the well-known query"""

metadata = self.fake_provider.get_metadata()
metadata = self.fake_server.get_metadata()
metadata.update(values)
return patch.object(self.fake_provider, "get_metadata", return_value=metadata)
return patch.object(self.fake_server, "get_metadata", return_value=metadata)

def start_authorization(
self,
Expand All @@ -185,7 +185,7 @@ def start_authorization(
nonce = random_string(10)
state = random_string(10)

code, grant = self.fake_provider.start_authorization(
code, grant = self.fake_server.start_authorization(
userinfo=userinfo,
scope=scope,
client_id=self.provider._client_auth.client_id,
Expand Down Expand Up @@ -218,48 +218,48 @@ def test_discovery(self) -> None:
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
metadata = self.get_success(self.provider.load_metadata())
self.fake_provider.get_metadata_handler.assert_called_once()
self.fake_server.get_metadata_handler.assert_called_once()

self.assertEqual(metadata.issuer, self.fake_provider.issuer)
self.assertEqual(metadata.issuer, self.fake_server.issuer)
self.assertEqual(
metadata.authorization_endpoint,
self.fake_provider.authorization_endpoint,
self.fake_server.authorization_endpoint,
)
self.assertEqual(metadata.token_endpoint, self.fake_provider.token_endpoint)
self.assertEqual(metadata.jwks_uri, self.fake_provider.jwks_uri)
self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint)
self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri)
# It seems like authlib does not have that defined in its metadata models
self.assertEqual(
metadata.get("userinfo_endpoint"),
self.fake_provider.userinfo_endpoint,
self.fake_server.userinfo_endpoint,
)

# subsequent calls should be cached
self.reset_mocks()
self.get_success(self.provider.load_metadata())
self.fake_provider.get_metadata_handler.assert_not_called()
self.fake_server.get_metadata_handler.assert_not_called()

@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
self.fake_provider.get_metadata_handler.assert_not_called()
self.fake_server.get_metadata_handler.assert_not_called()

@override_config({"oidc_config": DEFAULT_CONFIG})
def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
self.fake_provider.get_jwks_handler.assert_called_once()
self.assertEqual(jwks, self.fake_provider.get_jwks())
self.fake_server.get_jwks_handler.assert_called_once()
self.assertEqual(jwks, self.fake_server.get_jwks())

# subsequent calls should be cached…
self.reset_mocks()
self.get_success(self.provider.load_jwks())
self.fake_provider.get_jwks_handler.assert_not_called()
self.fake_server.get_jwks_handler.assert_not_called()

# …unless forced
self.reset_mocks()
self.get_success(self.provider.load_jwks(force=True))
self.fake_provider.get_jwks_handler.assert_called_once()
self.fake_server.get_jwks_handler.assert_called_once()

with self.metadata_edit({"jwks_uri": None}):
# If we don't do this, the load_metadata call will throw because of the
Expand Down Expand Up @@ -369,7 +369,7 @@ def test_redirect_request(self) -> None:
self.provider.handle_redirect_request(req, b"http://client/redirect")
)
)
auth_endpoint = urlparse(self.fake_provider.authorization_endpoint)
auth_endpoint = urlparse(self.fake_server.authorization_endpoint)

self.assertEqual(url.scheme, auth_endpoint.scheme)
self.assertEqual(url.netloc, auth_endpoint.netloc)
Expand Down Expand Up @@ -456,8 +456,8 @@ def test_callback(self) -> None:
new_user=True,
auth_provider_session_id=None,
)
self.fake_provider.post_token_handler.assert_called_once()
self.fake_provider.get_userinfo_handler.assert_not_called()
self.fake_server.post_token_handler.assert_called_once()
self.fake_server.get_userinfo_handler.assert_not_called()
self.render_error.assert_not_called()

# Handle mapping errors
Expand All @@ -472,7 +472,7 @@ def test_callback(self) -> None:

# Handle ID token errors
request, _ = self.start_authorization(userinfo)
with self.fake_provider.id_token_override({"iss": "https://bad.issuer/"}):
with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")

Expand All @@ -493,8 +493,8 @@ def test_callback(self) -> None:
new_user=False,
auth_provider_session_id=None,
)
self.fake_provider.post_token_handler.assert_called_once()
self.fake_provider.get_userinfo_handler.assert_called_once()
self.fake_server.post_token_handler.assert_called_once()
self.fake_server.get_userinfo_handler.assert_called_once()
self.render_error.assert_not_called()

self.reset_mocks()
Expand All @@ -514,18 +514,18 @@ def test_callback(self) -> None:
new_user=False,
auth_provider_session_id=grant.sid,
)
self.fake_provider.post_token_handler.assert_called_once()
self.fake_provider.get_userinfo_handler.assert_called_once()
self.fake_server.post_token_handler.assert_called_once()
self.fake_server.get_userinfo_handler.assert_called_once()
self.render_error.assert_not_called()

# Handle userinfo fetching error
request, _ = self.start_authorization(userinfo)
with self.fake_provider.buggy_endpoint(userinfo=True):
with self.fake_server.buggy_endpoint(userinfo=True):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")

request, _ = self.start_authorization(userinfo)
with self.fake_provider.buggy_endpoint(token=True):
with self.fake_server.buggy_endpoint(token=True):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("server_error")

Expand Down Expand Up @@ -582,17 +582,17 @@ def test_exchange_code(self) -> None:
"access_token": "aabbcc",
}

self.fake_provider.post_token_handler.side_effect = None
self.fake_provider.post_token_handler.return_value = FakeResponse.json(
self.fake_server.post_token_handler.side_effect = None
self.fake_server.post_token_handler.return_value = FakeResponse.json(
payload=token
)
code = "code"
ret = self.get_success(self.provider._exchange_code(code))
kwargs = self.fake_provider.request.call_args[1]
kwargs = self.fake_server.request.call_args[1]

self.assertEqual(ret, token)
self.assertEqual(kwargs["method"], "POST")
self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint)
self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)

args = parse_qs(kwargs["data"].decode("utf-8"))
self.assertEqual(args["grant_type"], ["authorization_code"])
Expand All @@ -602,7 +602,7 @@ def test_exchange_code(self) -> None:
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])

# Test error handling
self.fake_provider.post_token_handler.return_value = FakeResponse.json(
self.fake_server.post_token_handler.return_value = FakeResponse.json(
code=400, payload={"error": "foo", "error_description": "bar"}
)
from synapse.handlers.oidc import OidcError
Expand All @@ -612,29 +612,29 @@ def test_exchange_code(self) -> None:
self.assertEqual(exc.value.error_description, "bar")

# Internal server error with no JSON body
self.fake_provider.post_token_handler.return_value = FakeResponse(
self.fake_server.post_token_handler.return_value = FakeResponse(
code=500, body=b"Not JSON"
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")

# Internal server error with JSON body
self.fake_provider.post_token_handler.return_value = FakeResponse.json(
self.fake_server.post_token_handler.return_value = FakeResponse.json(
code=500, payload={"error": "internal_server_error"}
)

exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "internal_server_error")

# 4xx error without "error" field
self.fake_provider.post_token_handler.return_value = FakeResponse.json(
self.fake_server.post_token_handler.return_value = FakeResponse.json(
code=400, payload={}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")

# 2xx error with "error" field
self.fake_provider.post_token_handler.return_value = FakeResponse.json(
self.fake_server.post_token_handler.return_value = FakeResponse.json(
code=200, payload={"error": "some_error"}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
Expand Down Expand Up @@ -664,8 +664,8 @@ def test_exchange_code_jwt_key(self) -> None:
"access_token": "aabbcc",
}

self.fake_provider.post_token_handler.side_effect = None
self.fake_provider.post_token_handler.return_value = FakeResponse.json(
self.fake_server.post_token_handler.side_effect = None
self.fake_server.post_token_handler.return_value = FakeResponse.json(
payload=token
)
code = "code"
Expand All @@ -679,9 +679,9 @@ def test_exchange_code_jwt_key(self) -> None:
self.assertEqual(ret, token)

# the request should have hit the token endpoint
kwargs = self.fake_provider.request.call_args[1]
kwargs = self.fake_server.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint)
self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)

# the client secret provided to the should be a jwt which can be checked with
# the public key
Expand Down Expand Up @@ -720,8 +720,8 @@ def test_exchange_code_no_auth(self) -> None:
"access_token": "aabbcc",
}

self.fake_provider.post_token_handler.side_effect = None
self.fake_provider.post_token_handler.return_value = FakeResponse.json(
self.fake_server.post_token_handler.side_effect = None
self.fake_server.post_token_handler.return_value = FakeResponse.json(
payload=token
)
code = "code"
Expand All @@ -730,9 +730,9 @@ def test_exchange_code_no_auth(self) -> None:
self.assertEqual(ret, token)

# the request should have hit the token endpoint
kwargs = self.fake_provider.request.call_args[1]
kwargs = self.fake_server.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
self.assertEqual(kwargs["uri"], self.fake_provider.token_endpoint)
self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)

# check the POSTed data
args = parse_qs(kwargs["data"].decode("utf-8"))
Expand Down
24 changes: 12 additions & 12 deletions tests/rest/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,11 +465,11 @@ def test_ui_auth_via_sso(self) -> None:
* checking that the original operation succeeds
"""

fake_oidc_provider = self.helper.fake_oidc_server()
fake_oidc_server = self.helper.fake_oidc_server()

# log the user in
remote_user_id = UserID.from_string(self.user).localpart
login_resp, _ = self.helper.login_via_oidc(fake_oidc_provider, remote_user_id)
login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id)
self.assertEqual(login_resp["user_id"], self.user)

# initiate a UI Auth process by attempting to delete the device
Expand All @@ -484,7 +484,7 @@ def test_ui_auth_via_sso(self) -> None:
# run the UIA-via-SSO flow
session_id = channel.json_body["session"]
channel, _ = self.helper.auth_via_oidc(
fake_oidc_provider, {"sub": remote_user_id}, ui_auth_session_id=session_id
fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id
)

# that should serve a confirmation page
Expand All @@ -501,8 +501,8 @@ def test_ui_auth_via_sso(self) -> None:
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_does_not_offer_password_for_sso_user(self) -> None:
fake_oidc_provider = self.helper.fake_oidc_server()
login_resp, _ = self.helper.login_via_oidc(fake_oidc_provider, "username")
fake_oidc_server = self.helper.fake_oidc_server()
login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username")
user_tok = login_resp["access_token"]
device_id = login_resp["device_id"]

Expand All @@ -525,9 +525,9 @@ def test_does_not_offer_sso_for_password_user(self) -> None:
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_offers_both_flows_for_upgraded_user(self) -> None:
"""A user that had a password and then logged in with SSO should get both flows"""
fake_oidc_provider = self.helper.fake_oidc_server()
fake_oidc_server = self.helper.fake_oidc_server()
login_resp, _ = self.helper.login_via_oidc(
fake_oidc_provider, UserID.from_string(self.user).localpart
fake_oidc_server, UserID.from_string(self.user).localpart
)
self.assertEqual(login_resp["user_id"], self.user)

Expand All @@ -546,11 +546,11 @@ def test_offers_both_flows_for_upgraded_user(self) -> None:
def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
"""If the user tries to authenticate with the wrong SSO user, they get an error"""

fake_oidc_provider = self.helper.fake_oidc_server()
fake_oidc_server = self.helper.fake_oidc_server()

# log the user in
login_resp, _ = self.helper.login_via_oidc(
fake_oidc_provider, UserID.from_string(self.user).localpart
fake_oidc_server, UserID.from_string(self.user).localpart
)
self.assertEqual(login_resp["user_id"], self.user)

Expand All @@ -565,7 +565,7 @@ def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:

# do the OIDC auth, but auth as the wrong user
channel, _ = self.helper.auth_via_oidc(
fake_oidc_provider, {"sub": "wrong_user"}, ui_auth_session_id=session_id
fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id
)

# that should return a failure message
Expand Down Expand Up @@ -595,9 +595,9 @@ def test_sso_not_approved(self) -> None:
"""Tests that if we register a user via SSO while requiring approval for new
accounts, we still raise the correct error before logging the user in.
"""
fake_oidc_provider = self.helper.fake_oidc_server()
fake_oidc_server = self.helper.fake_oidc_server()
login_resp, _ = self.helper.login_via_oidc(
fake_oidc_provider, "username", expected_status=403
fake_oidc_server, "username", expected_status=403
)

self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL)
Expand Down
Loading

0 comments on commit 3233581

Please sign in to comment.