Skip to content

Commit

Permalink
refactor: Clean up authentication code; add tests for OIDC
Browse files Browse the repository at this point in the history
  • Loading branch information
MoritzWeber0 committed Jul 30, 2024
1 parent 4ca31f3 commit fd0c755
Show file tree
Hide file tree
Showing 47 changed files with 1,691 additions and 441 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ helm-deploy:
--set docker.registry.sessions=$(CAPELLACOLLAB_SESSIONS_REGISTRY) \
--set docker.tag=$(DOCKER_TAG) \
--set mocks.oauth=True \
--set authentication.claimMapping.username=sub \
--set authentication.endpoints.authorization=https://localhost/default/authorize \
--set backend.authentication.claimMapping.username=sub \
--set backend.authentication.endpoints.authorization=https://localhost/default/authorize \
--set development=$(DEVELOPMENT_MODE) \
--set cluster.ingressClassName=traefik \
--set cluster.ingressNamespace=kube-system \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright DB InfraGO AG and contributors
# SPDX-License-Identifier: Apache-2.0

"""Commit message
"""Add column for Jupyter token
Revision ID: f3d2dedd7906
Revises: 4df9c82766e2
Expand Down
10 changes: 5 additions & 5 deletions backend/capellacollab/core/authentication/api_key_cookie.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from capellacollab.config import config

from . import exceptions, oidc_provider
from . import exceptions, oidc

log = logging.getLogger(__name__)

Expand All @@ -21,7 +21,7 @@
class JWTConfig:
_jwks_client = None

def __init__(self, oidc_config: oidc_provider.AbstractOIDCProviderConfig):
def __init__(self, oidc_config: oidc.OIDCProviderConfig):
self.oidc_config = oidc_config

if JWTConfig._jwks_client is None:
Expand All @@ -32,10 +32,10 @@ def __init__(self, oidc_config: oidc_provider.AbstractOIDCProviderConfig):


class JWTAPIKeyCookie(security.APIKeyCookie):
def __init__(self, oidc_config: oidc_provider.AbstractOIDCProviderConfig):
def __init__(self):
super().__init__(name="id_token", auto_error=True)
self.oidc_config = oidc_config
self.jwt_config = JWTConfig(oidc_config)
self.oidc_config = oidc.get_cached_oidc_config()
self.jwt_config = JWTConfig(self.oidc_config)

async def __call__(self, request: fastapi.Request) -> str:
token: str | None = await super().__call__(request)
Expand Down
25 changes: 2 additions & 23 deletions backend/capellacollab/core/authentication/injectables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from __future__ import annotations

import dataclasses
import functools
import logging

import fastapi
Expand All @@ -24,28 +23,11 @@
from capellacollab.users import exceptions as users_exceptions
from capellacollab.users import models as users_models

from . import exceptions, oidc_provider
from . import exceptions

logger = logging.getLogger(__name__)


@functools.lru_cache
def get_cached_oidc_config() -> oidc_provider.AbstractOIDCProviderConfig:
return oidc_provider.WellKnownOIDCProviderConfig()


async def get_oidc_config() -> oidc_provider.AbstractOIDCProviderConfig:
return get_cached_oidc_config()


async def get_oidc_provider(
oidc_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
get_oidc_config
),
) -> oidc_provider.AbstractOIDCProvider:
return oidc_provider.OIDCProvider(oidc_config)


class OpenAPIFakeBase(security_base.SecurityBase):
"""Fake class to display the authentication methods in the OpenAPI docs
Expand Down Expand Up @@ -79,13 +61,10 @@ class OpenAPIPersonalAccessToken(OpenAPIFakeBase):

async def get_username(
request: fastapi.Request,
oidc_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
get_oidc_config
),
_unused1=fastapi.Depends(OpenAPIPersonalAccessToken()),
) -> str:
if request.cookies.get("id_token"):
username = await api_key_cookie.JWTAPIKeyCookie(oidc_config)(request)
username = await api_key_cookie.JWTAPIKeyCookie()(request)
return username

authorization = request.headers.get("Authorization")
Expand Down
7 changes: 7 additions & 0 deletions backend/capellacollab/core/authentication/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,10 @@ class TokenRequest(core_pydantic.BaseModel):
code: str
nonce: str
code_verifier: str


class AuthorizationResponse(core_pydantic.BaseModel):
auth_url: str
state: str
nonce: str
code_verifier: str
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright DB InfraGO AG and contributors
# SPDX-License-Identifier: Apache-2.0

import abc
import functools
import logging
import typing as t

Expand All @@ -10,46 +10,12 @@

from capellacollab.config import config

from . import exceptions
from . import exceptions, models

logger = logging.getLogger(__name__)


class AbstractOIDCProviderConfig(abc.ABC):
@abc.abstractmethod
def get_authorization_endpoint(self) -> str:
pass

@abc.abstractmethod
def get_token_endpoint(self) -> str:
pass

@abc.abstractmethod
def get_jwks_uri(self) -> str:
pass

@abc.abstractmethod
def get_supported_signing_algorithms(self) -> list[str]:
pass

@abc.abstractmethod
def get_issuer(self) -> str:
pass

@abc.abstractmethod
def get_scopes(self) -> list[str]:
pass

@abc.abstractmethod
def get_client_id(self) -> str:
pass

@abc.abstractmethod
def get_client_secret(self) -> str:
pass


class WellKnownOIDCProviderConfig(AbstractOIDCProviderConfig):
class OIDCProviderConfig:
def __init__(self):
self.well_known_uri = config.authentication.endpoints.well_known
self.well_known = self._fetch_well_known_configuration()
Expand Down Expand Up @@ -92,32 +58,16 @@ def get_client_id(self) -> str:
return config.authentication.client.id


class AbstractOIDCProvider(abc.ABC):
def __init__(self, oidc_config: AbstractOIDCProviderConfig):
self.oidc_config = oidc_config
@functools.lru_cache
def get_cached_oidc_config() -> OIDCProviderConfig:
return OIDCProviderConfig()

@abc.abstractmethod
def get_authorization_url_with_parameters(
self,
) -> t.Tuple[str, str, str, str]:
pass

@abc.abstractmethod
def exchange_code_for_tokens(
self, authorization_code: str, code_verifier: str
) -> dict[str, t.Any]:
pass

@abc.abstractmethod
def refresh_token(self, _refresh_token: str) -> dict[str, t.Any]:
pass


class OIDCProvider(AbstractOIDCProvider):
class OIDCProvider:
CODE_CHALLENGE_METHOD = "S256"

def __init__(self, oidc_config: AbstractOIDCProviderConfig):
super().__init__(oidc_config)
def __init__(self):
self.oidc_config = get_cached_oidc_config()
self.web_client: oauth2.WebApplicationClient = (
oauth2.WebApplicationClient(
client_id=self.oidc_config.get_client_id()
Expand All @@ -126,10 +76,10 @@ def __init__(self, oidc_config: AbstractOIDCProviderConfig):

def get_authorization_url_with_parameters(
self,
) -> t.Tuple[str, str, str, str]:
) -> models.AuthorizationResponse:
state = common.generate_token()

nonce = common.generate_nonce()

code_verifier = self.web_client.create_code_verifier(length=43)
code_challenge = self.web_client.create_code_challenge(
code_verifier, OIDCProvider.CODE_CHALLENGE_METHOD
Expand All @@ -145,7 +95,12 @@ def get_authorization_url_with_parameters(
code_challenge_method=OIDCProvider.CODE_CHALLENGE_METHOD,
)

return (auth_url, state, nonce, code_verifier)
return models.AuthorizationResponse(
auth_url=auth_url,
state=state,
nonce=nonce,
code_verifier=code_verifier,
)

def exchange_code_for_tokens(
self, authorization_code: str, code_verifier: str
Expand Down
87 changes: 35 additions & 52 deletions backend/capellacollab/core/authentication/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,100 +13,73 @@
from capellacollab.users import crud as users_crud
from capellacollab.users import models as users_models

from . import api_key_cookie, exceptions, injectables, models, oidc_provider
from . import api_key_cookie, exceptions, models, oidc

router = fastapi.APIRouter()
router = fastapi.APIRouter(tags=["Authentication"])


@router.get("", name="Get the authorization URL for the OAuth Server")
async def get_redirect_url(
@router.get("")
async def get_authorization_url(
response: fastapi.Response,
provider: oidc_provider.AbstractOIDCProvider = fastapi.Depends(
injectables.get_oidc_provider
),
) -> dict[str, str]:
auth_url, state, nonce, code_verifier = (
provider.get_authorization_url_with_parameters()
) -> models.AuthorizationResponse:
authorization_response = (
oidc.OIDCProvider().get_authorization_url_with_parameters()
)
delete_token_cookies(response)

return {
"auth_url": auth_url,
"state": state,
"nonce": nonce,
"code_verifier": code_verifier,
}
return authorization_response


@router.post("/tokens", name="Create the identity token")
async def api_get_token(
@router.post("/tokens")
async def get_identity_token(
token_request: models.TokenRequest,
response: fastapi.Response,
db: orm.Session = fastapi.Depends(database.get_db),
provider: oidc_provider.AbstractOIDCProvider = fastapi.Depends(
injectables.get_oidc_provider
),
oidc_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
injectables.get_oidc_config
),
):
tokens = provider.exchange_code_for_tokens(
tokens = oidc.OIDCProvider().exchange_code_for_tokens(
token_request.code, token_request.code_verifier
)

validated_id_token = validate_id_token(
tokens["id_token"], oidc_config, None
)
validated_id_token = validate_id_token(tokens["id_token"], None)
user = create_or_update_user(db, validated_id_token)

update_token_cookies(
response, tokens["id_token"], tokens.get("refresh_token", None), user
)


@router.put("/tokens", name="Refresh the identity token")
async def api_refresh_token(
@router.put("/tokens")
async def refresh_identity_token(
response: fastapi.Response,
refresh_token: t.Annotated[str | None, fastapi.Cookie()] = None,
db: orm.Session = fastapi.Depends(database.get_db),
provider: oidc_provider.AbstractOIDCProvider = fastapi.Depends(
injectables.get_oidc_provider
),
oidc_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
injectables.get_oidc_config
),
):
) -> None:
if refresh_token is None or refresh_token == "":
raise exceptions.RefreshTokenCookieMissingError()

tokens = provider.refresh_token(refresh_token)
tokens = oidc.OIDCProvider().refresh_token(refresh_token)

validated_id_token = validate_id_token(
tokens["id_token"], oidc_config, None
)
validated_id_token = validate_id_token(tokens["id_token"], None)
user = create_or_update_user(db, validated_id_token)

update_token_cookies(
response, tokens["id_token"], tokens.get("refresh_token", None), user
)


@router.delete("/tokens", name="Remove the token (log out)")
@router.delete("/tokens")
async def logout(response: fastapi.Response):
delete_token_cookies(response)
return None


@router.get("/tokens", name="Validate the token")
async def validate_token(
@router.get("/tokens")
async def validate_jwt_token(
request: fastapi.Request,
scope: users_models.Role | None = None,
db: orm.Session = fastapi.Depends(database.get_db),
oidc_config: oidc_provider.AbstractOIDCProviderConfig = fastapi.Depends(
injectables.get_oidc_config
),
):
username = await api_key_cookie.JWTAPIKeyCookie(oidc_config)(request)
username = await api_key_cookie.JWTAPIKeyCookie()(request)
if scope and scope.ADMIN:
auth_injectables.RoleVerification(
required_role=users_models.Role.ADMIN
Expand All @@ -116,12 +89,12 @@ async def validate_token(

def validate_id_token(
id_token: str,
oidc_config: oidc_provider.AbstractOIDCProviderConfig,
nonce: str | None,
) -> dict[str, str]:
validated_id_token = api_key_cookie.JWTAPIKeyCookie(
oidc_config
).validate_token(id_token)
validated_id_token = api_key_cookie.JWTAPIKeyCookie().validate_token(
id_token
)
oidc_config = oidc.get_cached_oidc_config()

if nonce and not hmac.compare_digest(validated_id_token["nonce"], nonce):
raise exceptions.NonceMismatchError()
Expand All @@ -146,6 +119,16 @@ def create_or_update_user(
user = users_crud.create_user(db, username, idp_identifier, email)
events_crud.create_user_creation_event(db, user)

if user.email != email:
user = users_crud.update_user(
db, user, users_models.PatchUser(email=email)
)

if user.name != username:
user = users_crud.update_user(
db, user, users_models.PatchUser(name=username)
)

users_crud.update_last_login(db, user)

return user
Expand Down
Loading

0 comments on commit fd0c755

Please sign in to comment.