diff --git a/.gitignore b/.gitignore index d5bbec0..0110069 100644 --- a/.gitignore +++ b/.gitignore @@ -145,3 +145,6 @@ cython_debug/ # Logs *.log + +# IntelliJ Idea based IDE +.idea diff --git a/asynction/__init__.py b/asynction/__init__.py index 3d55f2f..674a667 100644 --- a/asynction/__init__.py +++ b/asynction/__init__.py @@ -7,9 +7,11 @@ "PayloadValidationException", "BindingsValidationException", "MessageAckValidationException", + "SecurityInfo", ] from asynction.exceptions import * +from asynction.security import SecurityInfo from asynction.server import AsynctionSocketIO try: diff --git a/asynction/exceptions.py b/asynction/exceptions.py index f2361f1..66f8a2f 100644 --- a/asynction/exceptions.py +++ b/asynction/exceptions.py @@ -38,3 +38,12 @@ class MessageAckValidationException(ValidationException): """ pass + + +class SecurityException(AsynctionException, ConnectionRefusedError): + """ + Raised when an incoming connection fails to meet the requirements of + any of the specified security schemes. + """ + + pass diff --git a/asynction/mock_server.py b/asynction/mock_server.py index a057b13..84ef837 100644 --- a/asynction/mock_server.py +++ b/asynction/mock_server.py @@ -31,12 +31,15 @@ from hypothesis_jsonschema import from_schema from hypothesis_jsonschema._from_schema import STRING_FORMATS +from asynction.security import security_handler_factory from asynction.server import AsynctionSocketIO +from asynction.server import _noop_handler from asynction.types import AsyncApiSpec from asynction.types import ErrorHandler from asynction.types import JSONMapping from asynction.types import JSONSchema from asynction.types import Message +from asynction.types import SecurityRequirement from asynction.validation import bindings_validator_factory from asynction.validation import publish_message_validator_factory @@ -112,10 +115,6 @@ def task_scheduler( sleep() -def _noop_handler(*args, **kwargs) -> None: - return None - - class MockAsynctionSocketIO(AsynctionSocketIO): """Inherits the :class:`AsynctionSocketIO` class.""" @@ -210,7 +209,9 @@ def from_spec( ) def _register_handlers( - self, default_error_handler: Optional[ErrorHandler] = None + self, + server_security: Sequence[SecurityRequirement] = (), + default_error_handler: Optional[ErrorHandler] = None, ) -> None: for namespace, channel in self.spec.channels.items(): if channel.publish is not None: @@ -240,7 +241,16 @@ def _register_handlers( with_bindings_validation = bindings_validator_factory(channel.bindings) connect_handler = with_bindings_validation(connect_handler) - self.on_event("connect", connect_handler, namespace) + if server_security: + # create a security handler wrapper + with_security = security_handler_factory( + server_security, self.spec.components.security_schemes + ) + # apply security + connect_handler = with_security(connect_handler) + + if connect_handler is not _noop_handler: + self.on_event("connect", connect_handler, namespace) if default_error_handler is not None: self.on_error_default(default_error_handler) diff --git a/asynction/security.py b/asynction/security.py new file mode 100644 index 0000000..1b7903f --- /dev/null +++ b/asynction/security.py @@ -0,0 +1,386 @@ +import base64 +from functools import partial +from functools import wraps +from typing import Any +from typing import Callable +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Tuple + +from flask import Request +from flask import request as current_flask_request +from typing_extensions import TypedDict + +from asynction.exceptions import SecurityException +from asynction.types import ApiKeyLocation +from asynction.types import HTTPAuthenticationScheme +from asynction.types import SecurityRequirement +from asynction.types import SecurityScheme +from asynction.types import SecuritySchemesType +from asynction.utils import load_handler + + +class SecurityInfo(TypedDict, total=False): + """Security handler function response type. + + One of scopes, scope and one of sub, uid must be present + + Subclass this type to add extra fields to a security handler response + """ + + scopes: Sequence[str] + scope: str + sub: Any + uid: Any + + +TokenInfoFunc = Callable[[str], SecurityInfo] +BasicInfoFunc = Callable[[str, str, Optional[Sequence[str]]], SecurityInfo] +BearerInfoFunc = Callable[[str, Optional[Sequence[str]], Optional[str]], SecurityInfo] +APIKeyInfoFunc = Callable[[str, Optional[Sequence[str]]], SecurityInfo] +ScopeValidateFunc = Callable[[Sequence[str], Sequence[str]], bool] +InternalSecurityCheckResponse = Optional[SecurityInfo] +InternalSecurityCheck = Callable[[Request], InternalSecurityCheckResponse] +SecurityCheck = Callable[[Request], SecurityInfo] + +InternalSecurityRequirement = Tuple[str, Sequence[str]] +SecurityCheckFactory = Callable[ + [InternalSecurityRequirement, SecurityScheme], Optional[InternalSecurityCheck] +] + + +def unpack_security_requirement( + requirement: SecurityRequirement, +) -> InternalSecurityRequirement: + return next(iter(requirement.items())) + + +def unpack_security_requirements( + requirements: Sequence[SecurityRequirement], +) -> Sequence[InternalSecurityRequirement]: + return list(map(unpack_security_requirement, requirements)) + + +def extract_auth_header(request: Request) -> Optional[Tuple[str, str]]: + authorization = request.headers.get("Authorization") + + if not authorization: + return None + try: + lhs, rhs = authorization.split(None, 1) + if not lhs or not rhs: + raise SecurityException( + "invalid Authorization header" + " expected: " + f" found {authorization}" + ) + return lhs, rhs + except ValueError as err: + raise SecurityException from err + + +def validate_basic( + request: Request, basic_info_func: BasicInfoFunc, required_scopes: Sequence[str] +) -> Optional[SecurityInfo]: + auth = extract_auth_header(request) + if not auth: + return None + + auth_type, user_pass = auth + + if HTTPAuthenticationScheme(auth_type.lower()) != HTTPAuthenticationScheme.BASIC: + return None + + try: + username, password = base64.b64decode(user_pass).decode("latin1").split(":", 1) + except Exception as err: + raise SecurityException from err + + if not username or not password: + raise SecurityException + + token_info = basic_info_func(username, password, required_scopes) + if token_info is None: + raise SecurityException + + return token_info + + +def validate_oauth2_authorization_header( + request: Request, token_info_func: TokenInfoFunc +) -> Optional[SecurityInfo]: + """Check that the provided request contains a properly formatted Authorization + header and invokes the token_info_func on the token inside of the header. + """ + auth = extract_auth_header(request) + if not auth: + return None + + auth_type, token = auth + + if auth_type.lower() != "bearer": + return None + + token_info = token_info_func(token) + if token_info is None: + raise SecurityException + + return token_info + + +def validate_bearer( + request: Request, + bearer_info_func: BearerInfoFunc, + required_scopes: Sequence[str], + bearer_format: Optional[str] = None, +) -> Optional[SecurityInfo]: + """ + Adapted from: https://github.com/zalando/connexion/blob/main/connexion/security/security_handler_factory.py#L221 # noqa: 501 + """ + auth = extract_auth_header(request) + if not auth: + return None + + auth_type, token = auth + + if HTTPAuthenticationScheme(auth_type.lower()) != HTTPAuthenticationScheme.BEARER: + return None + + token_info = bearer_info_func(token, required_scopes, bearer_format) + if token_info is None: + raise SecurityException + + return token_info + + +def validate_scopes( + required_scopes: Sequence[str], token_scopes: Sequence[str] +) -> bool: + """Validates that all require scopes are present in the token scopes""" + missing_scopes = set(required_scopes) - set(token_scopes) + if missing_scopes: + raise SecurityException(f"Missing required scopes: {missing_scopes}") + + return not missing_scopes + + +def load_scope_validate_func(scheme: SecurityScheme) -> ScopeValidateFunc: + scope_validate_func = None + if scheme.x_scope_validate_func: + try: + scope_validate_func = load_handler(scheme.x_scope_validate_func) + except (AttributeError, ValueError) as err: + raise SecurityException from err + + if not scope_validate_func: + scope_validate_func = validate_scopes + + return scope_validate_func + + +def load_basic_info_func(scheme: SecurityScheme) -> BasicInfoFunc: + if not scheme.x_basic_info_func: + raise SecurityException("Missing basic info func") + try: + return load_handler(scheme.x_basic_info_func) + except (AttributeError, ValueError) as err: + raise SecurityException from err + + +def load_token_info_func(scheme: SecurityScheme) -> TokenInfoFunc: + if not scheme.x_token_info_func: + raise SecurityException("Missing token info function") + try: + return load_handler(scheme.x_token_info_func) + except (AttributeError, ValueError) as err: + raise SecurityException from err + + +def load_api_key_info_func(scheme: SecurityScheme) -> APIKeyInfoFunc: + if not scheme.x_api_key_info_func: + raise SecurityException("Missing API Key info function") + try: + return load_handler(scheme.x_api_key_info_func) + except (AttributeError, ValueError) as err: + raise SecurityException from err + + +def load_bearer_info_func(scheme: SecurityScheme) -> BearerInfoFunc: + if not scheme.x_bearer_info_func: + raise SecurityException("Missing Bearer info function") + try: + return load_handler(scheme.x_bearer_info_func) + except (AttributeError, ValueError) as err: + raise SecurityException from err + + +def validate_token_info( + token_info: InternalSecurityCheckResponse, + scope_validate_func: ScopeValidateFunc, + required_scopes: Sequence[str], +) -> InternalSecurityCheckResponse: + scopes: Optional[Sequence[str]] = None + if not token_info: + return None + + if "scopes" in token_info: + scopes = token_info.get("scopes") + elif "scope" in token_info: + scope = token_info.get("scope") + if isinstance(scope, str): + scopes = scope.split() + else: + raise ValueError("'scope' should be a string") + + if not scopes: + raise ValueError("missing scopes in token info") + + if not scope_validate_func(required_scopes, scopes): + raise SecurityException( + f"Invalid scopes: required: {required_scopes}, provided: {scopes}" + ) + + return token_info + + +def build_http_security_check( + requirement: InternalSecurityRequirement, scheme: SecurityScheme +) -> Optional[InternalSecurityCheck]: + _, required_scopes = requirement + + if scheme.scheme == HTTPAuthenticationScheme.BASIC: + basic_info_func = load_basic_info_func(scheme) + + return partial( + validate_basic, + basic_info_func=basic_info_func, + required_scopes=required_scopes, + ) + elif scheme.scheme == HTTPAuthenticationScheme.BEARER: + bearer_info_func = load_bearer_info_func(scheme) + bearer_format = scheme.bearer_format + + return partial( + validate_bearer, + bearer_info_func=bearer_info_func, + required_scopes=required_scopes, + bearer_format=bearer_format, + ) + else: + return None + + +def build_http_api_key_security_check( + requirement: InternalSecurityRequirement, scheme: SecurityScheme +) -> Optional[InternalSecurityCheck]: + api_key_info_func = load_api_key_info_func(scheme) + _, required_scopes = requirement + + def http_api_key_security_check(request: Request) -> InternalSecurityCheckResponse: + api_key = None + requests_dict = { + ApiKeyLocation.QUERY: request.args, + ApiKeyLocation.HEADER: request.headers, + ApiKeyLocation.COOKIE: request.cookies, + } + try: + # mypy insists this is checked + if scheme.in_ is not None and scheme.name is not None: + api_key = requests_dict[scheme.in_][scheme.name] + except KeyError: + return None + + if api_key is None: + return None + + return api_key_info_func(api_key, required_scopes) + + return http_api_key_security_check + + +def build_oauth2_security_check( + requirement: InternalSecurityRequirement, scheme: SecurityScheme +) -> Optional[InternalSecurityCheck]: + token_info_func = load_token_info_func(scheme) + scope_validate_func = load_scope_validate_func(scheme) + + _, required_scopes = requirement + + def oauth2_security_check(request: Request) -> InternalSecurityCheckResponse: + token_info = validate_oauth2_authorization_header(request, token_info_func) + + return validate_token_info(token_info, scope_validate_func, required_scopes) + + return oauth2_security_check + + +# Dispatch table mapping SecuritySchemesType to security check builder +_BUILDER_DISPATCH: Mapping[SecuritySchemesType, SecurityCheckFactory] = { + SecuritySchemesType.HTTP: build_http_security_check, + SecuritySchemesType.OAUTH2: build_oauth2_security_check, + SecuritySchemesType.HTTP_API_KEY: build_http_api_key_security_check, +} + + +def build_security_handler( + security: Sequence[InternalSecurityRequirement], + security_schemes: Mapping[str, SecurityScheme], +) -> SecurityCheck: + # build a list of security validators based on the provided security schemes + security_checks: List[InternalSecurityCheck] = [] + + for requirement in security: + requirement_name, _ = requirement + scheme = security_schemes[requirement_name] + builder = _BUILDER_DISPATCH.get(scheme.type) + if not builder: + continue + check = builder(requirement, scheme) + if not check: + continue + security_checks.append(check) + + def security_handler(request: Request) -> SecurityInfo: + + # apply the security schemes in the order listed in the API file + for check in security_checks: + + # if a security check fails if will raise the appropriate exception + # if the security check passes it will return a dict of kwargs to pass to the handler # noqa: 501 + # if the check is not applicable based on lack provided argument the check will return None indicating # noqa: 501 + # that the next (if any) check should be run. + security_args = check(request) + if security_args: + return security_args + + raise SecurityException("No checks passed") + + return security_handler + + +def security_handler_factory( + security: Sequence[SecurityRequirement], + security_schemes: Mapping[str, SecurityScheme], +) -> Callable: + """ + Build a security handler decorator based on security object and securitySchemes provided in the API file. # noqa: 501 + """ + unpacked_security = unpack_security_requirements(security) + security_handler = build_security_handler(unpacked_security, security_schemes) + + def decorator(handler: Callable): + if handler is None: + raise SecurityException("invalid or missing handler") + + @wraps(handler) + def handler_with_security(*args, **kwargs): + # match the args that connexion passes to handlers after a security check + token_info = security_handler(current_flask_request) + user = token_info.get("sub", token_info.get("uid")) + return handler(*args, user=user, token_info=token_info, **kwargs) + + return handler_with_security + + return decorator diff --git a/asynction/server.py b/asynction/server.py index 761d6c5..ffc4270 100644 --- a/asynction/server.py +++ b/asynction/server.py @@ -2,12 +2,9 @@ The :class:`AsynctionSocketIO` server is essentially a ``flask_socketio.SocketIO`` server with an additional factory classmethod. """ - from functools import singledispatch -from importlib import import_module from pathlib import Path from typing import Any -from typing import Callable from typing import Optional from typing import Sequence from urllib.parse import urlparse @@ -19,12 +16,15 @@ from asynction.exceptions import ValidationException from asynction.playground_docs import make_docs_blueprint +from asynction.security import security_handler_factory from asynction.types import GLOBAL_NAMESPACE from asynction.types import AsyncApiSpec from asynction.types import ChannelBindings from asynction.types import ChannelHandlers from asynction.types import ErrorHandler from asynction.types import JSONMapping +from asynction.types import SecurityRequirement +from asynction.utils import load_handler from asynction.validation import bindings_validator_factory from asynction.validation import callback_validator_factory from asynction.validation import publish_message_validator_factory @@ -77,11 +77,8 @@ def load_spec(spec_path: Path) -> AsyncApiSpec: return AsyncApiSpec.from_dict(raw_resolved) -def load_handler(handler_id: str) -> Callable: - *module_path_elements, object_name = handler_id.split(".") - module = import_module(".".join(module_path_elements)) - - return getattr(module, object_name) +def _noop_handler(*args, **kwargs) -> None: + return None class AsynctionSocketIO(SocketIO): @@ -159,7 +156,7 @@ def from_spec( """ spec = load_spec(spec_path=spec_path) - + server_security: Sequence[SecurityRequirement] = [] if ( server_name is not None and kwargs.get("path") is None @@ -176,35 +173,55 @@ def from_spec( if url_parse_result.path: kwargs["path"] = url_parse_result.path + server_security = server.security + asio = cls(spec, validation, docs, app, **kwargs) - asio._register_handlers(default_error_handler) + asio._register_handlers(server_security, default_error_handler) return asio def _register_namespace_handlers( self, namespace: str, - channel_handlers: ChannelHandlers, + channel_handlers: Optional[ChannelHandlers], channel_bindings: Optional[ChannelBindings], + server_security: Sequence[SecurityRequirement], ) -> None: - if channel_handlers.connect is not None: - handler = load_handler(channel_handlers.connect) + on_connect = _noop_handler + + # if a connection handler is defined then load it + if channel_handlers and channel_handlers.connect is not None: + on_connect = load_handler(channel_handlers.connect) if self.validation: with_bindings_validation = bindings_validator_factory(channel_bindings) - handler = with_bindings_validation(handler) + on_connect = with_bindings_validation(on_connect) - self.on_event("connect", handler, namespace) + if server_security: + # create a security handler wrapper + with_security = security_handler_factory( + server_security, self.spec.components.security_schemes + ) + # apply security + on_connect = with_security(on_connect) - if channel_handlers.disconnect is not None: - handler = load_handler(channel_handlers.disconnect) - self.on_event("disconnect", handler, namespace) + # if no user defined connection handler was specified + # or no security scheme was required then on_connect should still be None + if on_connect is not _noop_handler: + self.on_event("connect", on_connect, namespace) - if channel_handlers.error is not None: - handler = load_handler(channel_handlers.error) - self.on_error(namespace)(handler) + if channel_handlers: + if channel_handlers.disconnect is not None: + handler = load_handler(channel_handlers.disconnect) + self.on_event("disconnect", handler, namespace) + + if channel_handlers.error is not None: + handler = load_handler(channel_handlers.error) + self.on_error(namespace)(handler) def _register_handlers( - self, default_error_handler: Optional[ErrorHandler] = None + self, + server_security: Sequence[SecurityRequirement] = (), + default_error_handler: Optional[ErrorHandler] = None, ) -> None: for namespace, channel in self.spec.channels.items(): if channel.publish is not None: @@ -220,10 +237,12 @@ def _register_handlers( self.on_event(message.name, handler, namespace) - if channel.x_handlers is not None: - self._register_namespace_handlers( - namespace, channel.x_handlers, channel.bindings - ) + self._register_namespace_handlers( + namespace, + channel.x_handlers, + channel.bindings, + server_security=server_security, + ) if default_error_handler is not None: self.on_error_default(default_error_handler) diff --git a/asynction/types.py b/asynction/types.py index e658851..8f97209 100644 --- a/asynction/types.py +++ b/asynction/types.py @@ -1,9 +1,11 @@ from dataclasses import asdict from dataclasses import dataclass from dataclasses import field +from dataclasses import fields from enum import Enum from typing import Any from typing import Callable +from typing import Iterator from typing import Mapping from typing import Optional from typing import Sequence @@ -13,11 +15,243 @@ from svarog import register_forge from svarog.types import Forge +GLOBAL_NAMESPACE = "/" + JSONMappingValue = Any JSONMapping = Mapping[str, JSONMappingValue] JSONSchema = JSONMapping -GLOBAL_NAMESPACE = "/" + +class HTTPAuthenticationScheme(Enum): + """ + https://www.iana.org/assignments/http-authschemes/http-authschemes.xhtml + """ + + BASIC = "basic" + DIGEST = "digest" + BEARER = "bearer" + + +class OAuth2FlowType(Enum): + """ + https://www.asyncapi.com/docs/specifications/v2.2.0#oauthFlowsObject + """ + + IMPLICIT = "implicit" + PASSWORD = "password" + CLIENT_CREDENTIALS = "clientCredentials" + AUTHORIZATION_CODE = "authorizationCode" + + +@dataclass +class OAuth2Flow: + """ + https://www.asyncapi.com/docs/specifications/v2.2.0#oauthFlowObject + """ + + scopes: Mapping[str, str] + authorization_url: Optional[str] = None + token_url: Optional[str] = None + refresh_url: Optional[str] = None + + @staticmethod + def forge( + type_: Type["OAuth2Flow"], data: JSONMapping, forge: Forge + ) -> "OAuth2Flow": + return type_( + scopes=forge(type_.__annotations__["scopes"], data.get("scopes")), + authorization_url=forge( + type_.__annotations__["authorization_url"], data.get("authorizationUrl") + ), + token_url=forge(type_.__annotations__["token_url"], data.get("tokenUrl")), + refresh_url=forge( + type_.__annotations__["refresh_url"], data.get("refreshUrl") + ), + ) + + +register_forge(OAuth2Flow, OAuth2Flow.forge) + + +@dataclass +class OAuth2Flows: + implicit: Optional[OAuth2Flow] = None + password: Optional[OAuth2Flow] = None + client_credentials: Optional[OAuth2Flow] = None + authorization_code: Optional[OAuth2Flow] = None + + @staticmethod + def forge( + type_: Type["OAuth2Flows"], data: JSONMapping, forge: Forge + ) -> "OAuth2Flows": + return type_( + implicit=forge(type_.__annotations__["implicit"], data.get("implicit")), + password=forge(type_.__annotations__["password"], data.get("password")), + client_credentials=forge( + type_.__annotations__["client_credentials"], + data.get("clientCredentials"), + ), + authorization_code=forge( + type_.__annotations__["authorization_code"], + data.get("authorizationCode"), + ), + ) + + def __post_init__(self): + if self.implicit is not None and self.implicit.authorization_url is None: + raise ValueError("Implicit OAuth flow is missing Authorization URL") + elif self.password is not None and self.password.token_url is None: + raise ValueError("Password OAuth flow is missing Token URL") + elif ( + self.client_credentials is not None + and self.client_credentials.token_url is None + ): + raise ValueError("Client Credentials OAuth flow is missing Token URL") + elif ( + self.authorization_code is not None + and self.authorization_code.token_url is None + ): + raise ValueError("Authorization code OAuth flow is missing Token URL") + + def supported_scopes(self) -> Iterator[str]: + for f in fields(self): + flow = getattr(self, f.name) + if flow: + for scope in flow.scopes: + yield scope + + +register_forge(OAuth2Flows, OAuth2Flows.forge) + + +class SecuritySchemesType(Enum): + """ + https://www.asyncapi.com/docs/specifications/v2.2.0#securitySchemeObjectType + """ + + USER_PASSWORD = "userPassword" + API_KEY = "apiKey" + X509 = "X509" + SYMMETRIC_ENCRYPTION = "symmetricEncryption" + ASYMMETRIC_ENCRYPTION = "asymmetricEncryption" + HTTP_API_KEY = "httpApiKey" + HTTP = "http" + OAUTH2 = "oauth2" + OPENID_CONNECT = "openIdConnect" + PLAIN = "plain" + SCRAM_SHA256 = "scramSha256" + SCRAM_SHA512 = "scramSha512" + GSSAPI = "gssapi" + + +class ApiKeyLocation(Enum): + """ + https://www.asyncapi.com/docs/specifications/v2.2.0#securitySchemeObject + """ + + USER = "user" + PASSWORD = "password" + QUERY = "query" + HEADER = "header" + COOKIE = "cookie" + + +@dataclass +class SecurityScheme: + """ + https://www.asyncapi.com/docs/specifications/v2.2.0#securitySchemeObject + """ + + type: SecuritySchemesType + description: Optional[str] = None + name: Optional[str] = None # Required for httpApiKey + in_: Optional[ApiKeyLocation] = None # Required for httpApiKey | apiKey + scheme: Optional[HTTPAuthenticationScheme] = None # Required for http + bearer_format: Optional[str] = None # Optional for http ("bearer") + flows: Optional[OAuth2Flows] = None # Required for oauth2 + open_id_connect_url: Optional[str] = None # Required for openIdConnect + + x_basic_info_func: Optional[str] = None # Required for http(basic) + x_bearer_info_func: Optional[str] = None # Required for http(bearer) + x_token_info_func: Optional[str] = None # Required for oauth2 + x_api_key_info_func: Optional[str] = None # Required for apiKey + x_scope_validate_func: Optional[str] = None # Optional for oauth2 + + def __post_init__(self): + if not self.flows and self.type in [ + SecuritySchemesType.OAUTH2, + SecuritySchemesType.OPENID_CONNECT, + ]: + raise ValueError( + "flows field should be be defined " f"for {self.type} security schemes" + ) + + if self.type is SecuritySchemesType.HTTP: + # NOTE bearer_format is optional for HTTP + if not self.scheme: + raise ValueError(f"scheme is required for {self.type} security schemes") + + if self.type is SecuritySchemesType.HTTP_API_KEY: + options = [ + ApiKeyLocation.QUERY, + ApiKeyLocation.HEADER, + ApiKeyLocation.COOKIE, + ] + if not self.in_ or self.in_ not in options: + raise ValueError( + f'"in" field must be one of {options} ' + f"for {self.type} security schemes" + ) + if not self.name: + raise ValueError(f'"name" is required for {self.type} security schemes') + + # TODO include validation for other types + + @staticmethod + def forge( + type_: Type["SecurityScheme"], data: JSONMapping, forge: Forge + ) -> "SecurityScheme": + return type_( + type=forge(type_.__annotations__["type"], data.get("type")), + description=forge( + type_.__annotations__["description"], data.get("description") + ), + name=forge(type_.__annotations__["name"], data.get("name")), + in_=forge(type_.__annotations__["in_"], data.get("in")), + scheme=forge(type_.__annotations__["scheme"], data.get("scheme")), + bearer_format=forge( + type_.__annotations__["bearer_format"], data.get("bearerFormat") + ), + flows=forge(type_.__annotations__["flows"], data.get("flows")), + open_id_connect_url=forge( + type_.__annotations__["open_id_connect_url"], + data.get("openIdConnectUrl"), + ), + x_basic_info_func=forge( + type_.__annotations__["x_basic_info_func"], data.get("x-basicInfoFunc") + ), + x_bearer_info_func=forge( + type_.__annotations__["x_bearer_info_func"], + data.get("x-bearerInfoFunc"), + ), + x_token_info_func=forge( + type_.__annotations__["x_token_info_func"], data.get("x-tokenInfoFunc") + ), + x_api_key_info_func=forge( + type_.__annotations__["x_api_key_info_func"], + data.get("x-apiKeyInfoFunc"), + ), + x_scope_validate_func=forge( + type_.__annotations__["x_scope_validate_func"], + data.get("x-scopeValidateFunc"), + ), + ) + + +register_forge(SecurityScheme, SecurityScheme.forge) + + +SecurityRequirement = Mapping[str, Sequence[str]] @dataclass @@ -178,6 +412,7 @@ class Server: url: str protocol: ServerProtocol + security: Sequence[SecurityRequirement] = field(default_factory=list) @dataclass @@ -189,6 +424,27 @@ class Info: description: Optional[str] = None +@dataclass +class Components: + """https://www.asyncapi.com/docs/specifications/v2.2.0#componentsObject""" + + security_schemes: Mapping[str, SecurityScheme] = field(default_factory=dict) + + @staticmethod + def forge( + type_: Type["Components"], data: JSONMapping, forge: Forge + ) -> "Components": + return type_( + security_schemes=forge( + type_.__annotations__["security_schemes"], + data.get("securitySchemes", dict()), + ) + ) + + +register_forge(Components, Components.forge) + + @dataclass class AsyncApiSpec: """https://www.asyncapi.com/docs/specifications/2.2.0#A2SObject""" @@ -197,6 +453,47 @@ class AsyncApiSpec: channels: Mapping[str, Channel] info: Info servers: Mapping[str, Server] = field(default_factory=dict) + components: Components = field(default_factory=Components) + + def __post_init__(self): + for server_name, server in self.servers.items(): + for security_req in server.security: + (security_scheme_name, scopes), *other = security_req.items() + + if other: + raise ValueError( + f"{server_name} contains invalid " + f"security requirement: {security_req}" + ) + + security_scheme = self.components.security_schemes.get( + security_scheme_name + ) + if security_scheme is None: + raise ValueError( + f"{security_scheme_name} referenced within '{server_name}'" + " server does not exist in components/securitySchemes" + ) + + if scopes: + if security_scheme.type not in [ + SecuritySchemesType.OAUTH2, + SecuritySchemesType.OPENID_CONNECT, + ]: + raise ValueError( + "Scopes MUST be an empty array for " + f"{security_scheme.type} security requirements" + ) + + if security_scheme.type is SecuritySchemesType.OAUTH2: + supported_scopes = security_scheme.flows.supported_scopes() + + for scope in scopes: + if scope not in supported_scopes: + raise ValueError( + f"OAuth2 scope {scope} is not defined within " + f"the {security_scheme_name} security scheme" + ) @staticmethod def from_dict(data: JSONMapping) -> "AsyncApiSpec": diff --git a/asynction/utils.py b/asynction/utils.py new file mode 100644 index 0000000..d28850d --- /dev/null +++ b/asynction/utils.py @@ -0,0 +1,9 @@ +from importlib import import_module +from typing import Callable + + +def load_handler(handler_id: str) -> Callable: + *module_path_elements, object_name = handler_id.split(".") + module = import_module(".".join(module_path_elements)) + + return getattr(module, object_name) diff --git a/asynction/validation.py b/asynction/validation.py index 90d2316..ed0eb02 100644 --- a/asynction/validation.py +++ b/asynction/validation.py @@ -93,10 +93,10 @@ def publish_message_validator_factory(message: Message) -> Callable: def decorator(handler: Callable): @wraps(handler) - def handler_with_validation(*args): + def handler_with_validation(*args, **kwargs): validate_payload(args, message.payload) - ack = handler(*args) + ack = handler(*args, **kwargs) if ack is not None and message.x_ack is not None: jsonschema_validate_ack(ack, message.x_ack.args) @@ -147,9 +147,9 @@ def validate_request_bindings( def bindings_validator_factory(bindings: Optional[ChannelBindings]) -> Callable: def decorator(handler: Callable): @wraps(handler) - def handler_with_validation(*args): + def handler_with_validation(*args, **kwargs): validate_request_bindings(current_flask_request, bindings) - return handler(*args) + return handler(*args, **kwargs) return handler_with_validation diff --git a/docs/index.rst b/docs/index.rst index fa763bd..9fe070e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -39,6 +39,7 @@ Exceptions .. autoexception:: asynction.PayloadValidationException .. autoexception:: asynction.BindingsValidationException .. autoexception:: asynction.MessageAckValidationException +.. autoexception:: asynction.SecurityException Indices and tables ================== diff --git a/requirements.txt b/requirements.txt index 765c756..d9a3033 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ Flask>=0.9 jsonschema~=4.0 PyYAML~=6.0 svarog>=0.1.6,<2.0.0 +typing-extensions~=4.0.0 \ No newline at end of file diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 6047853..9cfcf49 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -6,10 +6,14 @@ class FixturePaths(NamedTuple): simple: Path echo: Path simple_with_servers: Path + security: Path + security_oauth2: Path paths = FixturePaths( simple=Path(__file__).parent.joinpath("simple.yml"), echo=Path(__file__).parent.joinpath("echo.yml"), simple_with_servers=Path(__file__).parent.joinpath("simple_with_servers.yml"), + security=Path(__file__).parent.joinpath("security.yaml"), + security_oauth2=Path(__file__).parent.joinpath("security_oauth2.yaml"), ) diff --git a/tests/fixtures/handlers.py b/tests/fixtures/handlers.py index 0d3e664..9d128e5 100644 --- a/tests/fixtures/handlers.py +++ b/tests/fixtures/handlers.py @@ -1,9 +1,14 @@ +import base64 from typing import Any +from typing import Mapping +from typing import Optional +from typing import Sequence from flask import request from flask_socketio import emit from typing_extensions import TypedDict +from asynction import SecurityInfo from asynction.exceptions import ValidationException @@ -52,3 +57,83 @@ def authenticated_connect() -> None: def echo_failed_validation(e: Exception) -> None: if isinstance(e, ValidationException): emit("echo errors", "Incoming message failed validation") + + +def basic_info( + username: str, password: str, required_scopes: Optional[Sequence[str]] = None +) -> SecurityInfo: + if username != "username" or password != "password": + raise ConnectionRefusedError("Invalid username or password") + + scopes = list(required_scopes) if required_scopes else [] + return SecurityInfo(sub=username, scopes=scopes) + + +def basic_info_bad(*args, **kwargs) -> Optional[SecurityInfo]: + return None + + +def bearer_info( + token: str, + required_scopes: Optional[Sequence[str]] = None, + bearer_format: Optional[str] = None, +) -> SecurityInfo: + username, password = base64.b64decode(token).decode().split(":") + if username != "username" or password != "password" or bearer_format != "test": + raise ConnectionRefusedError("Invalid username or password") + + scopes = list(required_scopes) if required_scopes else [] + return SecurityInfo(uid=username, scopes=scopes) + + +def bearer_info_bad(*args, **kwargs) -> Optional[SecurityInfo]: + return None + + +def api_key_info( + token: str, required_scopes: Optional[Sequence[str]] = None +) -> SecurityInfo: + username, password = base64.b64decode(token).decode().split(":") + if username != "username" or password != "password": + raise ConnectionRefusedError("Invalid username or password") + + scopes = list(required_scopes) if required_scopes else [] + return SecurityInfo(sub=username, scopes=scopes) + + +def api_key_info_bad(*args, **kwargs) -> Optional[SecurityInfo]: + return None + + +def token_info(token: str) -> SecurityInfo: + username, password = base64.b64decode(token).decode().split(":") + if username != "username" or password != "password": + raise ConnectionRefusedError("Invalid username or password") + + return SecurityInfo(sub=username, scopes=["a", "b"]) + + +def token_info_alternate(token: str) -> SecurityInfo: + username, password = base64.b64decode(token).decode().split(":") + if username != "username" or password != "password": + raise ConnectionRefusedError("Invalid username or password") + + return SecurityInfo(uid=username, scope="a b") + + +def token_info_alternate_invalid(token: str) -> Mapping: + username, password = base64.b64decode(token).decode().split(":") + if username != "username" or password != "password": + raise ConnectionRefusedError("Invalid username or password") + + # using a dict instead of the SecurityInfo to force bad values through for the tests + return dict(uid=username, scope=1) + + +def token_info_bad(*args, **kwargs) -> Optional[SecurityInfo]: + return None + + +def token_info_missing_required(*args, **kwargs) -> Optional[Mapping]: + # using a dict instead of the SecurityInfo to force bad values through for the tests + return dict(something="") diff --git a/tests/fixtures/security.yaml b/tests/fixtures/security.yaml new file mode 100644 index 0000000..4b9e45c --- /dev/null +++ b/tests/fixtures/security.yaml @@ -0,0 +1,39 @@ +asyncapi: 2.2.0 +info: + title: Test + version: 1.0.0 +servers: + test: + protocol: wss + url: 127.0.0.1/socket.io + security: + - basic: [] + - bearer: [] + - apiKey: [] +channels: + /: + subscribe: + message: + $ref: "#/components/messages/Test" +components: + messages: + Test: + name: test + payload: + type: string + + securitySchemes: + basic: + type: http + scheme: basic + x-basicInfoFunc: tests.fixtures.handlers.basic_info + bearer: + type: http + scheme: bearer + bearerFormat: test + x-bearerInfoFunc: tests.fixtures.handlers.bearer_info + apiKey: + type: httpApiKey + in: query + name: api_key + x-apiKeyInfoFunc: tests.fixtures.handlers.api_key_info diff --git a/tests/fixtures/security_oauth2.yaml b/tests/fixtures/security_oauth2.yaml new file mode 100644 index 0000000..9e7a142 --- /dev/null +++ b/tests/fixtures/security_oauth2.yaml @@ -0,0 +1,32 @@ +asyncapi: 2.2.0 +info: + title: Test + version: 1.0.0 +servers: + test: + protocol: wss + url: 127.0.0.1/socket.io + security: + - oauth2: ["a"] +channels: + /: + subscribe: + message: + $ref: "#/components/messages/Test" +components: + messages: + Test: + name: test + payload: + type: string + + securitySchemes: + oauth2: + type: oauth2 + flows: + implicit: + authorizationUrl: test + scopes: + a: "Test A" + b: "Test B" + x-tokenInfoFunc: tests.fixtures.handlers.token_info diff --git a/tests/integration/test_server.py b/tests/integration/test_server.py index 98d2b06..0763a95 100644 --- a/tests/integration/test_server.py +++ b/tests/integration/test_server.py @@ -1,3 +1,4 @@ +import base64 from enum import Enum from typing import Callable @@ -313,3 +314,152 @@ def test_docs_raw_specification_endpoint( with fixture_paths.simple.open() as f: assert resolve_references(yaml.safe_load(f.read())) == resp.json + + +@pytest.mark.parametrize( + argnames="factory_fixture", + argvalues=[ + FactoryFixture.ASYNCTION_SOCKET_IO, + FactoryFixture.MOCK_ASYNCTION_SOCKET_IO, + ], + ids=["server", "mock_server"], +) +def test_client_fails_to_connect_with_no_auth( + factory_fixture: FactoryFixture, + flask_app: Flask, + fixture_paths: FixturePaths, + request: pytest.FixtureRequest, +): + server_factory: AsynctionFactory = request.getfixturevalue(factory_fixture.value) + + socketio_server = server_factory( + spec_path=fixture_paths.security, server_name="test" + ) + flask_test_client = flask_app.test_client() + + with pytest.raises(ConnectionRefusedError): + socketio_test_client = socketio_server.test_client( + flask_app, flask_test_client=flask_test_client + ) + + assert socketio_test_client.is_connected() is False + + +@pytest.mark.parametrize( + argnames="factory_fixture", + argvalues=[ + FactoryFixture.ASYNCTION_SOCKET_IO, + FactoryFixture.MOCK_ASYNCTION_SOCKET_IO, + ], + ids=["server", "mock_server"], +) +def test_client_connects_with_http_basic_auth( + factory_fixture: FactoryFixture, + flask_app: Flask, + fixture_paths: FixturePaths, + request: pytest.FixtureRequest, +): + server_factory: AsynctionFactory = request.getfixturevalue(factory_fixture.value) + + socketio_server = server_factory( + spec_path=fixture_paths.security, server_name="test" + ) + flask_test_client = flask_app.test_client() + + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"basic {basic_auth}"} + socketio_test_client = socketio_server.test_client( + flask_app, flask_test_client=flask_test_client, headers=headers + ) + + assert socketio_test_client.is_connected() is True + + +@pytest.mark.parametrize( + argnames="factory_fixture", + argvalues=[ + FactoryFixture.ASYNCTION_SOCKET_IO, + FactoryFixture.MOCK_ASYNCTION_SOCKET_IO, + ], + ids=["server", "mock_server"], +) +def test_client_connects_with_http_bearer_auth( + factory_fixture: FactoryFixture, + flask_app: Flask, + fixture_paths: FixturePaths, + request: pytest.FixtureRequest, +): + server_factory: AsynctionFactory = request.getfixturevalue(factory_fixture.value) + + socketio_server = server_factory( + spec_path=fixture_paths.security, server_name="test" + ) + flask_test_client = flask_app.test_client() + + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"bearer {basic_auth}"} + socketio_test_client = socketio_server.test_client( + flask_app, flask_test_client=flask_test_client, headers=headers + ) + + assert socketio_test_client.is_connected() is True + + +@pytest.mark.parametrize( + argnames="factory_fixture", + argvalues=[ + FactoryFixture.ASYNCTION_SOCKET_IO, + FactoryFixture.MOCK_ASYNCTION_SOCKET_IO, + ], + ids=["server", "mock_server"], +) +def test_client_connects_with_http_api_key_auth( + factory_fixture: FactoryFixture, + flask_app: Flask, + fixture_paths: FixturePaths, + request: pytest.FixtureRequest, +): + server_factory: AsynctionFactory = request.getfixturevalue(factory_fixture.value) + + socketio_server = server_factory( + spec_path=fixture_paths.security, server_name="test" + ) + flask_test_client = flask_app.test_client() + + basic_auth = base64.b64encode("username:password".encode()).decode() + query = f"api_key={basic_auth}" + socketio_test_client = socketio_server.test_client( + flask_app, flask_test_client=flask_test_client, query_string=query + ) + + assert socketio_test_client.is_connected() is True + + +@pytest.mark.parametrize( + argnames="factory_fixture", + argvalues=[ + FactoryFixture.ASYNCTION_SOCKET_IO, + FactoryFixture.MOCK_ASYNCTION_SOCKET_IO, + ], + ids=["server", "mock_server"], +) +def test_client_connects_with_oauth2( + factory_fixture: FactoryFixture, + flask_app: Flask, + fixture_paths: FixturePaths, + request: pytest.FixtureRequest, +): + server_factory: AsynctionFactory = request.getfixturevalue(factory_fixture.value) + + socketio_server = server_factory( + spec_path=fixture_paths.security_oauth2, server_name="test" + ) + flask_test_client = flask_app.test_client() + + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"bearer {basic_auth}"} + socketio_test_client = socketio_server.test_client( + flask_app, flask_test_client=flask_test_client, headers=headers + ) + + assert socketio_test_client.is_connected() is True diff --git a/tests/unit/test_mock_server.py b/tests/unit/test_mock_server.py index d378888..0cf3210 100644 --- a/tests/unit/test_mock_server.py +++ b/tests/unit/test_mock_server.py @@ -21,22 +21,31 @@ from asynction import PayloadValidationException from asynction.exceptions import BindingsValidationException +from asynction.exceptions import SecurityException from asynction.mock_server import MockAsynctionSocketIO -from asynction.mock_server import _noop_handler from asynction.mock_server import generate_fake_data_from_schema from asynction.mock_server import make_faker_formats from asynction.mock_server import task_runner from asynction.mock_server import task_scheduler from asynction.server import AsynctionSocketIO +from asynction.server import _noop_handler +from asynction.types import GLOBAL_NAMESPACE from asynction.types import AsyncApiSpec from asynction.types import Channel from asynction.types import ChannelBindings +from asynction.types import ChannelHandlers +from asynction.types import Components from asynction.types import ErrorHandler +from asynction.types import HTTPAuthenticationScheme from asynction.types import Info from asynction.types import Message from asynction.types import MessageAck from asynction.types import OneOfMessages from asynction.types import Operation +from asynction.types import SecurityScheme +from asynction.types import SecuritySchemesType +from asynction.types import Server +from asynction.types import ServerProtocol from asynction.types import WebSocketsChannelBindings from tests.fixtures import FixturePaths from tests.utils import deep_unwrap @@ -326,6 +335,41 @@ def test_register_handlers_registers_connection_handler_with_bindings_validation handler_with_validation() +def test_register_namespace_handlers_emits_security_validator_if_security_enabled(): + channel_handlers = ChannelHandlers(connect="tests.fixtures.handlers.connect") + spec = AsyncApiSpec( + asyncapi="2.2.0", + info=Info("test", "1.0.0"), + servers={ + "test": Server("https://localhost/", ServerProtocol.WSS, [{"basic": []}]) + }, + channels={GLOBAL_NAMESPACE: Channel(x_handlers=channel_handlers)}, + components=Components( + security_schemes={ + "basic": SecurityScheme( + type=SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info", + ) + } + ), + ) + + server = new_mock_asynction_socket_io(spec) + server._register_handlers(server_security=server.spec.servers.get("test").security) + event_name, registered_handler, _ = server.handlers[0] + assert event_name == "connect" + handler_with_security = deep_unwrap(registered_handler, depth=1) + actual_handler = deep_unwrap(handler_with_security) + + with Flask(__name__).test_client() as c: + c.post() # Inject invalid POST request + actual_handler() + with pytest.raises(SecurityException): + handler_with_security() # handler raises security exception + assert True + + @pytest.mark.parametrize( argnames=["optional_error_handler"], argvalues=[[lambda _: None], [None]], @@ -338,7 +382,7 @@ def test_register_handlers_registers_default_error_handler( AsyncApiSpec(asyncapi=faker.pystr(), info=server_info, channels={}) ) - server._register_handlers(optional_error_handler) + server._register_handlers(default_error_handler=optional_error_handler) assert server.default_exception_handler == optional_error_handler diff --git a/tests/unit/test_security.py b/tests/unit/test_security.py new file mode 100644 index 0000000..07a3374 --- /dev/null +++ b/tests/unit/test_security.py @@ -0,0 +1,1056 @@ +import base64 +from unittest.mock import Mock + +import pytest +from flask import Flask +from flask import request as current_flask_request + +from asynction.exceptions import SecurityException +from asynction.security import build_http_api_key_security_check +from asynction.security import build_http_security_check +from asynction.security import build_oauth2_security_check +from asynction.security import build_security_handler +from asynction.security import extract_auth_header +from asynction.security import load_api_key_info_func +from asynction.security import load_basic_info_func +from asynction.security import load_bearer_info_func +from asynction.security import load_scope_validate_func +from asynction.security import load_token_info_func +from asynction.security import security_handler_factory +from asynction.types import ApiKeyLocation +from asynction.types import HTTPAuthenticationScheme +from asynction.types import OAuth2Flow +from asynction.types import OAuth2Flows +from asynction.types import SecurityScheme +from asynction.types import SecuritySchemesType +from tests.fixtures import handlers + + +def test_extract_auth_header(): + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"basic {basic_auth}"} + c.post(headers=headers) + extract_auth_header(current_flask_request) + + +def test_extract_auth_header_fails_missing_header(): + with Flask(__name__).test_client() as c: + c.post() + assert extract_auth_header(current_flask_request) is None + + +def test_extract_auth_header_fails_invalid_header(): + with Flask(__name__).test_client() as c: + headers = {"Authorization": "invalid"} + c.post(headers=headers) + with pytest.raises(SecurityException): + extract_auth_header(current_flask_request) + + +def test_load_basic_info_func(): + scheme = SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info", + ) + assert load_basic_info_func(scheme) == handlers.basic_info + + +def test_load_bearer_info_func(): + scheme = SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BEARER, + x_bearer_info_func="tests.fixtures.handlers.bearer_info", + ) + assert load_bearer_info_func(scheme) == handlers.bearer_info + + +def test_load_bearer_info_func_fails(): + scheme = SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BEARER, + x_basic_info_func="tests.fixtures.handlers.bearer_info", + ) + scheme.x_bearer_info_func = "" + with pytest.raises(SecurityException): + load_bearer_info_func(scheme) + + +def test_load_basic_info_func_fails(): + scheme = SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info", + ) + # set to empty string after validations + scheme.x_basic_info_func = "" + with pytest.raises(SecurityException): + load_basic_info_func(scheme) + + +def test_load_api_key_info_func(): + scheme = SecurityScheme( + SecuritySchemesType.HTTP_API_KEY, + name="api_key", + in_=ApiKeyLocation.QUERY, + x_api_key_info_func="tests.fixtures.handlers.api_key_info", + ) + assert load_api_key_info_func(scheme) == handlers.api_key_info + + +def test_load_api_key_info_func_fails(): + scheme = SecurityScheme( + SecuritySchemesType.HTTP_API_KEY, + name="api_key", + in_=ApiKeyLocation.QUERY, + x_api_key_info_func="tests.fixtures.handlers.api_key_info", + ) + scheme.x_api_key_info_func = "" + with pytest.raises(SecurityException): + load_api_key_info_func(scheme) + + +def test_load_token_info_func(): + scheme = SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows(implicit=OAuth2Flow(authorization_url="", scopes={"a": "a"})), + x_token_info_func="tests.fixtures.handlers.token_info", + ) + assert load_token_info_func(scheme) == handlers.token_info + + +def test_load_token_info_func_fails(): + scheme = SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows(implicit=OAuth2Flow(authorization_url="", scopes={"a": "a"})), + x_token_info_func="tests.fixtures.handlers.token_info", + ) + scheme.x_token_info_func = "" + with pytest.raises(SecurityException): + load_token_info_func(scheme) + + +def test_load_scope_validate_func_fails(): + scheme = SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows(implicit=OAuth2Flow(authorization_url="", scopes={"a": "a"})), + x_token_info_func="tests.fixtures.handlers.token_info", + x_scope_validate_func="invalid", + ) + with pytest.raises(SecurityException): + load_scope_validate_func(scheme) + + +def test_build_basic_http_security_check(): + requirement = ("test", []) + scheme = SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info", + ) + check = build_http_security_check(requirement, scheme) + assert callable(check) + + +def test_build_bearer_http_security_check(): + requirement = ("test", []) + scheme = SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BEARER, + x_bearer_info_func="tests.fixtures.handlers.bearer_info", + ) + check = build_http_security_check(requirement, scheme) + assert callable(check) + + +def test_build_http_api_key_security_check(): + requirement = ("test", []) + scheme = SecurityScheme( + SecuritySchemesType.HTTP_API_KEY, + name="api_key", + in_=ApiKeyLocation.QUERY, + x_api_key_info_func="tests.fixtures.handlers.api_key_info", + ) + check = build_http_api_key_security_check(requirement, scheme) + assert callable(check) + + +def test_build_http_api_key_security_scheme_fails_without_name(): + with pytest.raises(ValueError): + SecurityScheme( + SecuritySchemesType.HTTP_API_KEY, + in_=ApiKeyLocation.QUERY, + x_api_key_info_func="tests.fixtures.handlers.api_key_info", + ) + + +def test_build_oauth2_security_check(): + requirement = ("test", []) + scheme = SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows(implicit=OAuth2Flow(authorization_url="", scopes={"a": "a"})), + x_token_info_func="tests.fixtures.handlers.token_info", + ) + check = build_oauth2_security_check(requirement, scheme) + assert callable(check) + + +def test_build_security_check_list(): + requirements = [ + ("basic", []), + ("bearer", []), + ("api_key", []), + ("oauth2", ["a"]), + ] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info", + ), + bearer=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BEARER, + x_bearer_info_func="tests.fixtures.handlers.bearer_info", + ), + api_key=SecurityScheme( + SecuritySchemesType.HTTP_API_KEY, + name="api_key", + in_=ApiKeyLocation.QUERY, + x_api_key_info_func="tests.fixtures.handlers.api_key_info", + ), + oauth2=SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows( + implicit=OAuth2Flow(authorization_url="", scopes={"a": "a"}) + ), + x_token_info_func="tests.fixtures.handlers.token_info", + ), + ) + + check = build_security_handler(requirements, schemes) + assert check + assert callable(check) + + +def test_build_security_handler_with_invalid_handler(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info", + ) + ) + factory = security_handler_factory(requirements, schemes) + with pytest.raises(SecurityException): + factory(None) + + +def test_http_basic_works(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"basic {basic_auth}"} + c.post(headers=headers) + handler_with_security() + mock_ack.assert_called_once() + + +def test_http_basic_fails(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:wrong".encode()).decode() + headers = {"Authorization": f"basic {basic_auth}"} + c.post(headers=headers) + with pytest.raises(ConnectionRefusedError): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_http_basic_fails_missing_basic_info(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info_fake", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + with pytest.raises(SecurityException): + factory = security_handler_factory(requirements, schemes) + factory(on_connect) + mock_ack.assert_not_called() + + +def test_http_basic_fails_because_basic_info_returns_none(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info_bad", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"basic {basic_auth}"} + c.post(headers=headers) + with pytest.raises(ConnectionRefusedError): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_http_bearer_works(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BEARER, + bearer_format="test", + x_bearer_info_func="tests.fixtures.handlers.bearer_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"bearer {basic_auth}"} + c.post(headers=headers) + handler_with_security() + + mock_ack.assert_called_once() + + +def test_http_bearer_fails_with_no_auth_header(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BEARER, + bearer_format="test", + x_bearer_info_func="tests.fixtures.handlers.bearer_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + c.post() + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_http_bearer_fails_with_not_bearer(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BEARER, + bearer_format="test", + x_bearer_info_func="tests.fixtures.handlers.bearer_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"not_bearer {basic_auth}"} + c.post(headers=headers) + with pytest.raises(ValueError): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_http_bearer_fails_with_basic(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BEARER, + bearer_format="test", + x_bearer_info_func="tests.fixtures.handlers.bearer_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"basic {basic_auth}"} + c.post(headers=headers) + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_http_bearer_fails_with_invalid_header_format(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BEARER, + bearer_format="test", + x_bearer_info_func="tests.fixtures.handlers.bearer_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"{basic_auth}"} + c.post(headers=headers) + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_http_bearer_fails_bad_bearer_info_func(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BEARER, + bearer_format="test", + x_bearer_info_func="tests.fixtures.handlers.bearer_info_bad", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"bearer {basic_auth}"} + c.post(headers=headers) + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_http_bearer_fails_bearer_info_func_not_found(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BEARER, + bearer_format="test", + x_bearer_info_func="tests.fixtures.handlers.bearer_not_found", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + with pytest.raises(SecurityException): + factory = security_handler_factory(requirements, schemes) + factory(on_connect) + + mock_ack.assert_not_called() + + +def test_http_api_key_works_header(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP_API_KEY, + name="api_key", + in_=ApiKeyLocation.HEADER, + x_api_key_info_func="tests.fixtures.handlers.api_key_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"api_key": f"{basic_auth}"} + c.post(headers=headers) + handler_with_security() + + mock_ack.assert_called_once() + + +def test_http_api_key_works_query(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP_API_KEY, + name="api_key", + in_=ApiKeyLocation.QUERY, + x_api_key_info_func="tests.fixtures.handlers.api_key_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + c.post(f"/?api_key={basic_auth}") + handler_with_security() + + mock_ack.assert_called_once() + + +def test_http_api_key_works_cookie(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP_API_KEY, + name="api_key", + in_=ApiKeyLocation.COOKIE, + x_api_key_info_func="tests.fixtures.handlers.api_key_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + c.set_cookie("test", "api_key", basic_auth) + c.post() + handler_with_security() + + mock_ack.assert_called_once() + + +def test_http_api_key_fails_missing_api_key_info_func(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP_API_KEY, + name="api_key", + in_=ApiKeyLocation.COOKIE, + x_api_key_info_func="tests.fixtures.handlers.api_key_info_fake", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + with pytest.raises(SecurityException): + factory = security_handler_factory(requirements, schemes) + factory(on_connect) + + mock_ack.assert_not_called() + + +def test_http_api_key_fails_missing_cookie(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP_API_KEY, + name="api_key", + in_=ApiKeyLocation.COOKIE, + x_api_key_info_func="tests.fixtures.handlers.api_key_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + c.set_cookie("test", "wrong", "value") + + c.post() + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_http_api_fails_bad_api_key_info_func(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP_API_KEY, + name="api_key", + in_=ApiKeyLocation.HEADER, + x_api_key_info_func="tests.fixtures.handlers.api_key_info_bad", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"api_key": f"{basic_auth}"} + c.post(headers=headers) + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_oauth2_works(): + requirements = [{"basic": ["a"]}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows( + implicit=OAuth2Flow(authorization_url="https://test", scopes={"a": "A"}) + ), + x_token_info_func="tests.fixtures.handlers.token_info", + x_scope_validate_func="asynction.security.validate_scopes", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"bearer {basic_auth}"} + c.post(headers=headers) + handler_with_security() + + mock_ack.assert_called_once() + + +def test_oauth2_works_alternate(): + requirements = [{"basic": ["a"]}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows( + implicit=OAuth2Flow(authorization_url="https://test", scopes={"a": "A"}) + ), + x_token_info_func="tests.fixtures.handlers.token_info_alternate", + x_scope_validate_func="asynction.security.validate_scopes", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"bearer {basic_auth}"} + c.post(headers=headers) + handler_with_security() + + mock_ack.assert_called_once() + + +def test_oauth2_fails_missing_token_info_func(): + requirements = [{"basic": ["a"]}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows( + implicit=OAuth2Flow(authorization_url="https://test", scopes={"a": "A"}) + ), + x_token_info_func="tests.fixtures.handlers.token_info_fake", + x_scope_validate_func="asynction.security.validate_scopes", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + with pytest.raises(SecurityException): + factory = security_handler_factory(requirements, schemes) + factory(on_connect) + + mock_ack.assert_not_called() + + +def test_oauth2_fails_missing_scopes(): + requirements = [{"basic": ["z"]}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows( + implicit=OAuth2Flow(authorization_url="https://test", scopes={"z": "Z"}) + ), + x_token_info_func="tests.fixtures.handlers.token_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"bearer {basic_auth}"} + c.post(headers=headers) + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_oauth2_fails_bad_token_info_func(): + requirements = [{"basic": ["a"]}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows( + implicit=OAuth2Flow(authorization_url="https://test", scopes={"a": "A"}) + ), + x_token_info_func="tests.fixtures.handlers.token_info_bad", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"bearer {basic_auth}"} + c.post(headers=headers) + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_oauth2_fails_bad_scope_type(): + requirements = [{"basic": ["a"]}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows( + implicit=OAuth2Flow(authorization_url="https://test", scopes={"a": "A"}) + ), + x_token_info_func="tests.fixtures.handlers.token_info_alternate_invalid", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"bearer {basic_auth}"} + c.post(headers=headers) + with pytest.raises(ValueError): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_oauth2_fails_token_info_missing_required(): + requirements = [{"basic": ["a"]}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows( + implicit=OAuth2Flow(authorization_url="https://test", scopes={"a": "A"}) + ), + x_token_info_func="tests.fixtures.handlers.token_info_missing_required", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"bearer {basic_auth}"} + c.post(headers=headers) + with pytest.raises(ValueError): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_oauth2_fails_missing_auth_header(): + requirements = [{"basic": ["a"]}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows( + implicit=OAuth2Flow(authorization_url="https://test", scopes={"a": "A"}) + ), + x_token_info_func="tests.fixtures.handlers.token_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + c.post() + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_oauth2_fails_invalid_header_format(): + requirements = [{"basic": ["a"]}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.OAUTH2, + flows=OAuth2Flows( + implicit=OAuth2Flow(authorization_url="https://test", scopes={"a": "A"}) + ), + x_token_info_func="tests.fixtures.handlers.token_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"not_bearer {basic_auth}"} + c.post(headers=headers) + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_http_basic_missing_auth_header(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + c.post() + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_http_basic_invalid_auth_header(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"basic{basic_auth}"} + c.post(headers=headers) + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_http_basic_invalid_basic_auth_format(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:".encode()).decode() + headers = {"Authorization": f"basic {basic_auth}"} + c.post(headers=headers) + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() + + +def test_http_basic_invalid_basic_auth_scheme(): + requirements = [{"basic": []}] + schemes = dict( + basic=SecurityScheme( + SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info", + ) + ) + + mock_ack = Mock() + + def on_connect(*args, **kwargs): + mock_ack() + + factory = security_handler_factory(requirements, schemes) + handler_with_security = factory(on_connect) + with Flask(__name__).test_client() as c: + basic_auth = base64.b64encode("username:password".encode()).decode() + headers = {"Authorization": f"bearer {basic_auth}"} # expects basic + c.post(headers=headers) + with pytest.raises(SecurityException): + handler_with_security() + + mock_ack.assert_not_called() diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py index 30f1009..1f08af4 100644 --- a/tests/unit/test_server.py +++ b/tests/unit/test_server.py @@ -7,6 +7,7 @@ from asynction.exceptions import MessageAckValidationException from asynction.exceptions import PayloadValidationException +from asynction.exceptions import SecurityException from asynction.exceptions import ValidationException from asynction.server import AsynctionSocketIO from asynction.server import SocketIO @@ -18,12 +19,18 @@ from asynction.types import Channel from asynction.types import ChannelBindings from asynction.types import ChannelHandlers +from asynction.types import Components from asynction.types import ErrorHandler +from asynction.types import HTTPAuthenticationScheme from asynction.types import Info from asynction.types import Message from asynction.types import MessageAck from asynction.types import OneOfMessages from asynction.types import Operation +from asynction.types import SecurityScheme +from asynction.types import SecuritySchemesType +from asynction.types import Server +from asynction.types import ServerProtocol from asynction.types import WebSocketsChannelBindings from tests.fixtures import FixturePaths from tests.fixtures.handlers import connect @@ -431,7 +438,7 @@ def test_register_handlers_registers_default_error_handler( None, ) - server._register_handlers(optional_error_handler) + server._register_handlers(default_error_handler=optional_error_handler) assert server.default_exception_handler == optional_error_handler @@ -445,7 +452,7 @@ def test_register_namespace_handlers_wraps_bindings_validator_if_validation_enab server = AsynctionSocketIO(mock.Mock(), True, True, None) server._register_namespace_handlers( - GLOBAL_NAMESPACE, channel_handlers, channel_bindings + GLOBAL_NAMESPACE, channel_handlers, channel_bindings, [] ) event_name, registered_handler, _ = server.handlers[0] assert event_name == "connect" @@ -469,7 +476,7 @@ def test_register_namespace_handlers_omits_bindings_validator_if_validation_disa server = AsynctionSocketIO(mock.Mock(), False, True, None) server._register_namespace_handlers( - GLOBAL_NAMESPACE, channel_handlers, channel_bindings + GLOBAL_NAMESPACE, channel_handlers, channel_bindings, [] ) event_name, registered_handler, _ = server.handlers[0] assert event_name == "connect" @@ -483,6 +490,46 @@ def test_register_namespace_handlers_omits_bindings_validator_if_validation_disa assert True +def test_register_namespace_handlers_emits_security_validator_if_security_enabled(): + channel_handlers = ChannelHandlers(connect="tests.fixtures.handlers.connect") + spec = AsyncApiSpec( + asyncapi="2.2.0", + info=Info("test", "1.0.0"), + servers={ + "test": Server("https://localhost/", ServerProtocol.WSS, [{"basic": []}]) + }, + channels={GLOBAL_NAMESPACE: Channel(x_handlers=channel_handlers)}, + components=Components( + security_schemes={ + "basic": SecurityScheme( + type=SecuritySchemesType.HTTP, + scheme=HTTPAuthenticationScheme.BASIC, + x_basic_info_func="tests.fixtures.handlers.basic_info", + ) + } + ), + ) + + server = AsynctionSocketIO(spec, False, True, None) + server._register_namespace_handlers( + GLOBAL_NAMESPACE, + channel_handlers, + None, + server.spec.servers.get("test").security, + ) + event_name, registered_handler, _ = server.handlers[0] + assert event_name == "connect" + handler_with_security = deep_unwrap(registered_handler, depth=1) + actual_handler = deep_unwrap(handler_with_security) + + with Flask(__name__).test_client() as c: + c.post() # Inject invalid POST request + actual_handler() + with pytest.raises(SecurityException): + handler_with_security() # handler raises security exception + assert True + + def test_emit_event_with_non_existent_namespace_raises_validation_exc( server_info: Info, faker: Faker ): diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 37382df..030bfa0 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -3,13 +3,18 @@ from svarog import forge from asynction.types import GLOBAL_NAMESPACE +from asynction.types import ApiKeyLocation from asynction.types import AsyncApiSpec from asynction.types import Channel from asynction.types import ChannelBindings from asynction.types import ChannelHandlers from asynction.types import Message +from asynction.types import OAuth2Flow +from asynction.types import OAuth2Flows from asynction.types import OneOfMessages from asynction.types import Operation +from asynction.types import SecurityScheme +from asynction.types import SecuritySchemesType def test_message_deserialisation(faker: Faker): @@ -208,9 +213,230 @@ def test_async_api_spec_from_and_to_dict(faker: Faker): }, } }, - "servers": {"development": {"url": "localhost", "protocol": "ws"}}, + "servers": { + "development": { + "url": "localhost", + "protocol": "ws", + "security": [{"test": []}], + } + }, + "components": { + "securitySchemes": { + "test": {"type": "http", "scheme": "basic"}, + "test2": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}, + "testApiKey": {"type": "httpApiKey", "name": "test", "in": "header"}, + "oauth2": { + "type": "oauth2", + "flows": { + "implicit": { + "authorizationUrl": "https://localhost:12345", + "refreshUrl": "https://localhost:12345/refresh", + "scopes": {"a": "A", "b": "B"}, + } + }, + }, + } + }, } spec = AsyncApiSpec.from_dict(data) assert isinstance(spec, AsyncApiSpec) assert spec.to_dict() == data + + +def test_oauth2_implicit_flow_validation(): + scopes = {"a": "A", "b": "B"} + # authorization_url is required for implicit flow + flow = OAuth2Flow(scopes=scopes, authorization_url=None) + + with pytest.raises(ValueError): + OAuth2Flows(implicit=flow) + + +def test_oauth2_password_flow_validation(): + scopes = {"a": "A", "b": "B"} + # token_url is required for password flow + flow = OAuth2Flow(scopes=scopes, token_url=None) + + with pytest.raises(ValueError): + OAuth2Flows(password=flow) + + +def test_oauth2_client_credentials_flow_validation(): + scopes = {"a": "A", "b": "B"} + # token_url is required for client_credentials flow + flow = OAuth2Flow(scopes=scopes, token_url=None) + + with pytest.raises(ValueError): + OAuth2Flows(client_credentials=flow) + + +def test_oauth2_authorization_code_flow_validation(): + scopes = {"a": "A", "b": "B"} + # token_url is required for authorization_code flow + flow = OAuth2Flow(scopes=scopes, token_url=None) + + with pytest.raises(ValueError): + OAuth2Flows(authorization_code=flow) + + +def test_security_scheme_validation(): + with pytest.raises(ValueError): + # missing flows + SecurityScheme(type=SecuritySchemesType.OAUTH2) + + with pytest.raises(ValueError): + # missing flows + SecurityScheme(type=SecuritySchemesType.OPENID_CONNECT) + + with pytest.raises(ValueError): + # missing scheme + SecurityScheme(type=SecuritySchemesType.HTTP) + + with pytest.raises(ValueError): + # missing in + SecurityScheme(type=SecuritySchemesType.HTTP_API_KEY) + with pytest.raises(ValueError): + # missing name + SecurityScheme(type=SecuritySchemesType.HTTP_API_KEY, in_=ApiKeyLocation.HEADER) + + +def test_asyncapi_spec_validation_invalid_security_requirement(faker: Faker): + data = { + "asyncapi": "2.2.0", + "info": { + "title": faker.sentence(), + "version": faker.pystr(), + "description": faker.sentence(), + }, + "channels": {}, + "servers": { + "development": { + "url": "localhost", + "protocol": "ws", + "security": [{"test": [], "invalid": "A"}], + } + }, + "components": { + "securitySchemes": { + "test": {"type": "http", "scheme": "basic"}, + "test2": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}, + "testApiKey": {"type": "httpApiKey", "name": "test", "in": "header"}, + "oauth2": { + "type": "oauth2", + "flows": { + "implicit": { + "authorizationUrl": "https://localhost:12345", + "refreshUrl": "https://localhost:12345/refresh", + "scopes": {"a": "A", "b": "B"}, + } + }, + }, + } + }, + } + with pytest.raises(ValueError): + # missing security scheme + AsyncApiSpec.from_dict(data) + + +def test_asyncapi_spec_validation_invalid_security_requirement_scopes(faker: Faker): + data = { + "asyncapi": "2.2.0", + "info": { + "title": faker.sentence(), + "version": faker.pystr(), + "description": faker.sentence(), + }, + "channels": {}, + "servers": { + "development": { + "url": "localhost", + "protocol": "ws", + "security": [{"test": ["a"]}], + } + }, + "components": { + "securitySchemes": { + "test": {"type": "http", "scheme": "basic"}, + "test2": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}, + "testApiKey": {"type": "httpApiKey", "name": "test", "in": "header"}, + "oauth2": { + "type": "oauth2", + "flows": { + "implicit": { + "authorizationUrl": "https://localhost:12345", + "refreshUrl": "https://localhost:12345/refresh", + "scopes": {"a": "A", "b": "B"}, + } + }, + }, + } + }, + } + with pytest.raises(ValueError): + # missing security scheme + AsyncApiSpec.from_dict(data) + + +def test_asyncapi_spec_validation_invalid_security_requirement_undefined_scopes( + faker: Faker, +): + data = { + "asyncapi": "2.2.0", + "info": { + "title": faker.sentence(), + "version": faker.pystr(), + "description": faker.sentence(), + }, + "channels": {}, + "servers": { + "development": { + "url": "localhost", + "protocol": "ws", + "security": [{"oauth2": ["undefined"]}], + } + }, + "components": { + "securitySchemes": { + "test": {"type": "http", "scheme": "basic"}, + "test2": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}, + "testApiKey": {"type": "httpApiKey", "name": "test", "in": "header"}, + "oauth2": { + "type": "oauth2", + "flows": { + "implicit": { + "authorizationUrl": "https://localhost:12345", + "refreshUrl": "https://localhost:12345/refresh", + "scopes": {"a": "A", "b": "B"}, + } + }, + }, + } + }, + } + with pytest.raises(ValueError): + # missing security scheme + AsyncApiSpec.from_dict(data) + + +def test_asyncapi_spec_validation_missing_security_scheme(faker: Faker): + data = { + "asyncapi": "2.2.0", + "info": { + "title": faker.sentence(), + "version": faker.pystr(), + "description": faker.sentence(), + }, + "channels": {}, + "servers": { + "development": { + "url": "localhost", + "protocol": "ws", + "security": [{"test": []}], + } + }, + } + with pytest.raises(ValueError): + # missing security scheme + AsyncApiSpec.from_dict(data)