Skip to content

Commit

Permalink
Use resolver in security middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
RobbeSneyders committed Jun 20, 2022
1 parent b561ecf commit 3f0a7d9
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 31 deletions.
16 changes: 8 additions & 8 deletions connexion/apis/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(
self,
specification: t.Union[pathlib.Path, str, dict],
base_path: t.Optional[str] = None,
resolver: t.Optional[Resolver] = None,
resolver_error_handler: t.Optional[t.Callable] = None,
arguments: t.Optional[dict] = None,
options: t.Optional[dict] = None,
*args,
Expand All @@ -48,6 +50,9 @@ def __init__(
:param specification: OpenAPI specification. Can be provided either as dict, or as path
to file.
:param base_path: Base path to host the API.
:param resolver: Callable that maps operationID to a function
:param resolver_error_handler: Callable that generates an Operation used for handling
ResolveErrors
:param arguments: Jinja arguments to resolve in specification.
:param options: New style options dictionary.
"""
Expand All @@ -70,6 +75,9 @@ def __init__(

self._set_base_path(base_path)

self.resolver_error_handler = resolver_error_handler
self.resolver = resolver or Resolver()

def _set_base_path(self, base_path: t.Optional[str] = None) -> None:
if base_path is not None:
# update spec to include user-provided base_path
Expand Down Expand Up @@ -121,26 +129,18 @@ class AbstractRoutingAPI(AbstractSpecAPI):
def __init__(
self,
*args,
resolver: t.Optional[Resolver] = None,
resolver_error_handler: t.Optional[t.Callable] = None,
debug: bool = False,
pass_context_arg_name: t.Optional[str] = None,
**kwargs
) -> None:
"""Minimal interface of an API, with only functionality related to routing.
:param resolver: Callable that maps operationID to a function
:param resolver_error_handler: Callable that generates an Operation used for handling
ResolveErrors
:param debug: Flag to run in debug mode
:param pass_context_arg_name: If not None URL request handling functions with an argument
matching this name will be passed the framework's request context.
"""
super().__init__(*args, **kwargs)
self.debug = debug
self.resolver_error_handler = resolver_error_handler

self.resolver = resolver or Resolver()

logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
self.pass_context_arg_name = pass_context_arg_name
Expand Down
57 changes: 38 additions & 19 deletions connexion/middleware/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
from starlette.types import ASGIApp, Receive, Scope, Send

from connexion.apis.abstract import AbstractSpecAPI
from connexion.exceptions import MissingMiddleware
from connexion.exceptions import MissingMiddleware, ProblemException
from connexion.http_facts import METHODS
from connexion.lifecycle import MiddlewareRequest
from connexion.middleware import AppMiddleware
from connexion.middleware.routing import ROUTING_CONTEXT
from connexion.operations import AbstractOperation
from connexion.resolver import ResolverError
from connexion.security import SecurityHandlerFactory
from connexion.spec import Specification

logger = logging.getLogger("connexion.middleware.security")

Expand Down Expand Up @@ -69,8 +72,6 @@ def __init__(
):
super().__init__(specification, *args, **kwargs)
self.security_handler_factory = SecurityHandlerFactory('context')
self.app_security = self.specification.security
self.security_schemes = self.specification.security_definitions

if auth_all_paths:
self.add_auth_on_not_found()
Expand All @@ -81,30 +82,36 @@ def __init__(

def add_auth_on_not_found(self):
"""Register a default SecurityOperation for routes that are not found."""
default_operation = self.make_operation()
default_operation = self.make_operation(self.specification)
self.operations = defaultdict(lambda: default_operation)

def add_paths(self):
paths = self.specification.get('paths', {})
for path, methods in paths.items():
for method, operation in methods.items():
for method in methods:
if method not in METHODS:
continue
operation_id = operation.get('operationId')
if operation_id:
self.operations[operation_id] = self.make_operation(operation)

def make_operation(self, operation_spec: dict = None):
security = self.app_security
if operation_spec:
security = operation_spec.get('security', self.app_security)

return SecurityOperation(
self.security_handler_factory,
security=security,
security_schemes=self.specification.security_definitions
try:
self.add_operation(path, method)
except ResolverError:
# ResolverErrors are either raised or handled in routing middleware.
pass

def add_operation(self, path: str, method: str) -> None:
operation_cls = self.specification.operation_cls
operation = operation_cls.from_spec(self.specification, self, path, method, self.resolver)
security_operation = self.make_operation(operation)
self._add_operation_internal(operation.operation_id, security_operation)

def make_operation(self, operation: t.Union[AbstractOperation, Specification]):
return SecurityOperation.from_operation(
operation,
security_handler_factory=self.security_handler_factory,
)

def _add_operation_internal(self, operation_id: str, operation: 'SecurityOperation'):
self.operations[operation_id] = operation


class SecurityOperation:

Expand All @@ -119,6 +126,18 @@ def __init__(
self.security_schemes = security_schemes
self.verification_fn = self._get_verification_fn()

@classmethod
def from_operation(
cls,
operation: AbstractOperation,
security_handler_factory: SecurityHandlerFactory
):
return cls(
security_handler_factory,
security=operation.security,
security_schemes=operation.security_schemes
)

def _get_verification_fn(self):
logger.debug('... Security: %s', self.security, extra=vars(self))
if not self.security:
Expand Down Expand Up @@ -234,5 +253,5 @@ async def __call__(self, request: MiddlewareRequest):
await self.verification_fn(request)


class MissingSecurityOperation(Exception):
class MissingSecurityOperation(ProblemException):
pass
12 changes: 11 additions & 1 deletion connexion/operations/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def user_provided_handler_function(important, stuff):
serious_business(stuff)
"""
def __init__(self, api, method, path, operation, resolver,
app_security=None, security_schemes=None,
validate_responses=False, strict_validation=False,
randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None,
Expand All @@ -57,7 +58,6 @@ def __init__(self, api, method, path, operation, resolver,
:param operation: swagger operation object
:type operation: dict
:param resolver: Callable that maps operationID to a function
:param app_produces: list of content types the application can return by default
:param app_security: list of security rules the application uses by default
:type app_security: list
:param security_schemes: `Security Definitions Object
Expand Down Expand Up @@ -85,6 +85,8 @@ def __init__(self, api, method, path, operation, resolver,
self._path = path
self._operation = operation
self._resolver = resolver
self._security = operation.get('security', app_security)
self._security_schemes = security_schemes
self._validate_responses = validate_responses
self._strict_validation = strict_validation
self._pythonic_params = pythonic_params
Expand Down Expand Up @@ -119,6 +121,14 @@ def path(self):
"""
return self._path

@property
def security(self):
return self._security

@property
def security_schemes(self):
return self._security_schemes

@property
def responses(self):
"""
Expand Down
10 changes: 10 additions & 0 deletions connexion/operations/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class OpenAPIOperation(AbstractOperation):
"""

def __init__(self, api, method, path, operation, resolver, path_parameters=None,
app_security=None, security_schemes=None,
components=None, validate_responses=False, strict_validation=False,
randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None):
Expand All @@ -44,6 +45,11 @@ def __init__(self, api, method, path, operation, resolver, path_parameters=None,
:param resolver: Callable that maps operationID to a function
:param path_parameters: Parameters defined in the path level
:type path_parameters: list
:param app_security: list of security rules the application uses by default
:type app_security: list
:param security_schemes: `Security Definitions Object
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#security-definitions-object>`_
:type security_schemes: dict
:param components: `Components Object
<https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.1.md#componentsObject>`_
:type components: dict
Expand Down Expand Up @@ -76,6 +82,8 @@ def __init__(self, api, method, path, operation, resolver, path_parameters=None,
path=path,
operation=operation,
resolver=resolver,
app_security=app_security,
security_schemes=security_schemes,
validate_responses=validate_responses,
strict_validation=strict_validation,
randomize_endpoint=randomize_endpoint,
Expand Down Expand Up @@ -116,6 +124,8 @@ def from_spec(cls, spec, api, path, method, resolver, *args, **kwargs):
spec.get_operation(path, method),
resolver=resolver,
path_parameters=spec.get_path_params(path),
app_security=spec.security,
security_schemes=spec.security_schemes,
components=spec.components,
*args,
**kwargs
Expand Down
12 changes: 11 additions & 1 deletion connexion/operations/swagger2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class Swagger2Operation(AbstractOperation):
"""

def __init__(self, api, method, path, operation, resolver, app_produces, app_consumes,
path_parameters=None, definitions=None, validate_responses=False,
path_parameters=None, app_security=None, security_schemes=None,
definitions=None, validate_responses=False,
strict_validation=False, randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None):
"""
Expand All @@ -47,6 +48,11 @@ def __init__(self, api, method, path, operation, resolver, app_produces, app_con
:type app_consumes: list
:param path_parameters: Parameters defined in the path level
:type path_parameters: list
:param app_security: list of security rules the application uses by default
:type app_security: list
:param security_schemes: `Security Definitions Object
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#security-definitions-object>`_
:type security_schemes: dict
:param definitions: `Definitions Object
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#definitionsObject>`_
:type definitions: dict
Expand Down Expand Up @@ -77,6 +83,8 @@ def __init__(self, api, method, path, operation, resolver, app_produces, app_con
path=path,
operation=operation,
resolver=resolver,
app_security=app_security,
security_schemes=security_schemes,
validate_responses=validate_responses,
strict_validation=strict_validation,
randomize_endpoint=randomize_endpoint,
Expand Down Expand Up @@ -112,6 +120,8 @@ def from_spec(cls, spec, api, path, method, resolver, *args, **kwargs):
path_parameters=spec.get_path_params(path),
app_produces=spec.produces,
app_consumes=spec.consumes,
app_security=spec.security,
security_schemes=spec.security_schemes,
definitions=spec.definitions,
*args,
**kwargs
Expand Down
4 changes: 2 additions & 2 deletions connexion/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def response_definitions(self):
return self._spec['responses']

@property
def security_definitions(self):
def security_schemes(self):
return self._spec.get('securityDefinitions', {})

@property
Expand Down Expand Up @@ -268,7 +268,7 @@ def _set_defaults(cls, spec):
spec.setdefault('components', {})

@property
def security_definitions(self):
def security_schemes(self):
return self._spec['components'].get('securitySchemes', {})

@property
Expand Down

0 comments on commit 3f0a7d9

Please sign in to comment.