Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use resolver in security middleware #1553

Merged
merged 2 commits into from
Jun 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions connexion/apis/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
self,
specification: t.Union[pathlib.Path, str, dict],
base_path: t.Optional[str] = None,
resolver: t.Optional[Resolver] = None,
arguments: t.Optional[dict] = None,
options: t.Optional[dict] = None,
*args,
Expand All @@ -48,6 +49,9 @@ def __init__(
:param specification: OpenAPI specification. Can be provided either as dict, or as path
to file.
:param base_path: Base path to host the API.
:param resolver: Callable that maps operationID to a function
:param resolver_error_handler: Callable that generates an Operation used for handling
ResolveErrors
RobbeSneyders marked this conversation as resolved.
Show resolved Hide resolved
:param arguments: Jinja arguments to resolve in specification.
:param options: New style options dictionary.
"""
Expand All @@ -70,6 +74,8 @@ def __init__(

self._set_base_path(base_path)

self.resolver = resolver or Resolver()

def _set_base_path(self, base_path: t.Optional[str] = None) -> None:
if base_path is not None:
# update spec to include user-provided base_path
Expand Down Expand Up @@ -121,17 +127,13 @@ class AbstractRoutingAPI(AbstractSpecAPI):
def __init__(
self,
*args,
resolver: t.Optional[Resolver] = None,
resolver_error_handler: t.Optional[t.Callable] = None,
debug: bool = False,
pass_context_arg_name: t.Optional[str] = None,
**kwargs
) -> None:
"""Minimal interface of an API, with only functionality related to routing.

:param resolver: Callable that maps operationID to a function
:param resolver_error_handler: Callable that generates an Operation used for handling
ResolveErrors
:param debug: Flag to run in debug mode
:param pass_context_arg_name: If not None URL request handling functions with an argument
matching this name will be passed the framework's request context.
Expand All @@ -140,8 +142,6 @@ def __init__(
self.debug = debug
self.resolver_error_handler = resolver_error_handler

self.resolver = resolver or Resolver()

logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
self.pass_context_arg_name = pass_context_arg_name

Expand Down
7 changes: 6 additions & 1 deletion connexion/middleware/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from connexion.apis import AbstractRoutingAPI
from connexion.exceptions import NotFoundProblem
from connexion.middleware import AppMiddleware
from connexion.operations import AbstractOperation
from connexion.resolver import Resolver

ROUTING_CONTEXT = 'connexion_routing'
Expand Down Expand Up @@ -91,7 +92,7 @@ def __init__(
def add_operation(self, path: str, method: str) -> None:
operation_cls = self.specification.operation_cls
operation = operation_cls.from_spec(self.specification, self, path, method, self.resolver)
routing_operation = RoutingOperation(operation.operation_id, next_app=self.next_app)
routing_operation = RoutingOperation.from_operation(operation, next_app=self.next_app)
self._add_operation_internal(method, path, routing_operation)

def _add_operation_internal(self, method: str, path: str, operation: 'RoutingOperation') -> None:
Expand All @@ -104,6 +105,10 @@ def __init__(self, operation_id: t.Optional[str], next_app: ASGIApp) -> None:
self.operation_id = operation_id
self.next_app = next_app

@classmethod
def from_operation(cls, operation: AbstractOperation, next_app: ASGIApp):
return cls(operation.operation_id, next_app)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Attach operation to scope and pass it to the next app"""
original_scope = _scope.get()
Expand Down
57 changes: 38 additions & 19 deletions connexion/middleware/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
from starlette.types import ASGIApp, Receive, Scope, Send

from connexion.apis.abstract import AbstractSpecAPI
from connexion.exceptions import MissingMiddleware
from connexion.exceptions import MissingMiddleware, ProblemException
from connexion.http_facts import METHODS
from connexion.lifecycle import MiddlewareRequest
from connexion.middleware import AppMiddleware
from connexion.middleware.routing import ROUTING_CONTEXT
from connexion.operations import AbstractOperation
from connexion.resolver import ResolverError
from connexion.security import SecurityHandlerFactory
from connexion.spec import Specification

logger = logging.getLogger("connexion.middleware.security")

Expand Down Expand Up @@ -69,8 +72,6 @@ def __init__(
):
super().__init__(specification, *args, **kwargs)
self.security_handler_factory = SecurityHandlerFactory('context')
self.app_security = self.specification.security
self.security_schemes = self.specification.security_definitions

if auth_all_paths:
self.add_auth_on_not_found()
Expand All @@ -81,30 +82,36 @@ def __init__(

def add_auth_on_not_found(self):
"""Register a default SecurityOperation for routes that are not found."""
default_operation = self.make_operation()
default_operation = self.make_operation(self.specification)
self.operations = defaultdict(lambda: default_operation)

def add_paths(self):
paths = self.specification.get('paths', {})
for path, methods in paths.items():
for method, operation in methods.items():
for method in methods:
if method not in METHODS:
continue
operation_id = operation.get('operationId')
if operation_id:
self.operations[operation_id] = self.make_operation(operation)

def make_operation(self, operation_spec: dict = None):
security = self.app_security
if operation_spec:
security = operation_spec.get('security', self.app_security)

return SecurityOperation(
self.security_handler_factory,
security=security,
security_schemes=self.specification.security_definitions
try:
self.add_operation(path, method)
except ResolverError:
# ResolverErrors are either raised or handled in routing middleware.
pass

def add_operation(self, path: str, method: str) -> None:
operation_cls = self.specification.operation_cls
operation = operation_cls.from_spec(self.specification, self, path, method, self.resolver)
security_operation = self.make_operation(operation)
self._add_operation_internal(operation.operation_id, security_operation)

def make_operation(self, operation: t.Union[AbstractOperation, Specification]):
return SecurityOperation.from_operation(
operation,
security_handler_factory=self.security_handler_factory,
)

def _add_operation_internal(self, operation_id: str, operation: 'SecurityOperation'):
self.operations[operation_id] = operation


class SecurityOperation:

Expand All @@ -119,6 +126,18 @@ def __init__(
self.security_schemes = security_schemes
self.verification_fn = self._get_verification_fn()

@classmethod
def from_operation(
cls,
operation: AbstractOperation,
security_handler_factory: SecurityHandlerFactory
):
return cls(
security_handler_factory,
security=operation.security,
security_schemes=operation.security_schemes
)

def _get_verification_fn(self):
logger.debug('... Security: %s', self.security, extra=vars(self))
if not self.security:
Expand Down Expand Up @@ -234,5 +253,5 @@ async def __call__(self, request: MiddlewareRequest):
await self.verification_fn(request)


class MissingSecurityOperation(Exception):
class MissingSecurityOperation(ProblemException):
pass
12 changes: 11 additions & 1 deletion connexion/operations/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def user_provided_handler_function(important, stuff):
serious_business(stuff)
"""
def __init__(self, api, method, path, operation, resolver,
app_security=None, security_schemes=None,
validate_responses=False, strict_validation=False,
randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None,
Expand All @@ -57,7 +58,6 @@ def __init__(self, api, method, path, operation, resolver,
:param operation: swagger operation object
:type operation: dict
:param resolver: Callable that maps operationID to a function
:param app_produces: list of content types the application can return by default
:param app_security: list of security rules the application uses by default
:type app_security: list
:param security_schemes: `Security Definitions Object
Expand Down Expand Up @@ -85,6 +85,8 @@ def __init__(self, api, method, path, operation, resolver,
self._path = path
self._operation = operation
self._resolver = resolver
self._security = operation.get('security', app_security)
self._security_schemes = security_schemes
self._validate_responses = validate_responses
self._strict_validation = strict_validation
self._pythonic_params = pythonic_params
Expand Down Expand Up @@ -119,6 +121,14 @@ def path(self):
"""
return self._path

@property
def security(self):
return self._security

@property
def security_schemes(self):
return self._security_schemes

@property
def responses(self):
"""
Expand Down
10 changes: 10 additions & 0 deletions connexion/operations/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class OpenAPIOperation(AbstractOperation):
"""

def __init__(self, api, method, path, operation, resolver, path_parameters=None,
app_security=None, security_schemes=None,
components=None, validate_responses=False, strict_validation=False,
randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None):
Expand All @@ -44,6 +45,11 @@ def __init__(self, api, method, path, operation, resolver, path_parameters=None,
:param resolver: Callable that maps operationID to a function
:param path_parameters: Parameters defined in the path level
:type path_parameters: list
:param app_security: list of security rules the application uses by default
:type app_security: list
:param security_schemes: `Security Definitions Object
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#security-definitions-object>`_
:type security_schemes: dict
:param components: `Components Object
<https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.1.md#componentsObject>`_
:type components: dict
Expand Down Expand Up @@ -76,6 +82,8 @@ def __init__(self, api, method, path, operation, resolver, path_parameters=None,
path=path,
operation=operation,
resolver=resolver,
app_security=app_security,
security_schemes=security_schemes,
validate_responses=validate_responses,
strict_validation=strict_validation,
randomize_endpoint=randomize_endpoint,
Expand Down Expand Up @@ -116,6 +124,8 @@ def from_spec(cls, spec, api, path, method, resolver, *args, **kwargs):
spec.get_operation(path, method),
resolver=resolver,
path_parameters=spec.get_path_params(path),
app_security=spec.security,
security_schemes=spec.security_schemes,
components=spec.components,
*args,
**kwargs
Expand Down
12 changes: 11 additions & 1 deletion connexion/operations/swagger2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class Swagger2Operation(AbstractOperation):
"""

def __init__(self, api, method, path, operation, resolver, app_produces, app_consumes,
path_parameters=None, definitions=None, validate_responses=False,
path_parameters=None, app_security=None, security_schemes=None,
definitions=None, validate_responses=False,
strict_validation=False, randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None):
"""
Expand All @@ -47,6 +48,11 @@ def __init__(self, api, method, path, operation, resolver, app_produces, app_con
:type app_consumes: list
:param path_parameters: Parameters defined in the path level
:type path_parameters: list
:param app_security: list of security rules the application uses by default
:type app_security: list
:param security_schemes: `Security Definitions Object
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#security-definitions-object>`_
:type security_schemes: dict
:param definitions: `Definitions Object
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#definitionsObject>`_
:type definitions: dict
Expand Down Expand Up @@ -77,6 +83,8 @@ def __init__(self, api, method, path, operation, resolver, app_produces, app_con
path=path,
operation=operation,
resolver=resolver,
app_security=app_security,
security_schemes=security_schemes,
validate_responses=validate_responses,
strict_validation=strict_validation,
randomize_endpoint=randomize_endpoint,
Expand Down Expand Up @@ -112,6 +120,8 @@ def from_spec(cls, spec, api, path, method, resolver, *args, **kwargs):
path_parameters=spec.get_path_params(path),
app_produces=spec.produces,
app_consumes=spec.consumes,
app_security=spec.security,
security_schemes=spec.security_schemes,
definitions=spec.definitions,
*args,
**kwargs
Expand Down
4 changes: 2 additions & 2 deletions connexion/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def response_definitions(self):
return self._spec['responses']

@property
def security_definitions(self):
def security_schemes(self):
return self._spec.get('securityDefinitions', {})

@property
Expand Down Expand Up @@ -268,7 +268,7 @@ def _set_defaults(cls, spec):
spec.setdefault('components', {})

@property
def security_definitions(self):
def security_schemes(self):
return self._spec['components'].get('securitySchemes', {})

@property
Expand Down