diff --git a/connexion/apis/flask_api.py b/connexion/apis/flask_api.py index 69864d6dd..115daed3b 100644 --- a/connexion/apis/flask_api.py +++ b/connexion/apis/flask_api.py @@ -9,9 +9,8 @@ from flask import Response as FlaskResponse from connexion.apis.abstract import AbstractAPI -from connexion.decorators import SyncDecorator +from connexion.decorators import FlaskDecorator from connexion.frameworks import flask as flask_utils -from connexion.frameworks.flask import Flask as FlaskFramework from connexion.jsonifier import Jsonifier from connexion.operations import AbstractOperation from connexion.uri_parsing import AbstractURIParser @@ -91,12 +90,7 @@ def from_operation( @property def fn(self) -> t.Callable: - decorator = SyncDecorator( - self._operation, - uri_parser_cls=self.uri_parser_class, - framework=FlaskFramework, - parameter=True, - response=True, + decorator = FlaskDecorator( pythonic_params=self.pythonic_params, jsonifier=self.api.jsonifier, ) diff --git a/connexion/apps/async_app.py b/connexion/apps/async_app.py index cd1db9fa5..dd280e3e2 100644 --- a/connexion/apps/async_app.py +++ b/connexion/apps/async_app.py @@ -15,9 +15,8 @@ from connexion.apis.abstract import AbstractAPI from connexion.apps.abstract import AbstractApp -from connexion.decorators import AsyncDecorator +from connexion.decorators import StarletteDecorator from connexion.exceptions import MissingMiddleware, ProblemException -from connexion.frameworks.starlette import Starlette as StarletteFramework from connexion.middleware.main import ConnexionMiddleware from connexion.middleware.routing import ROUTING_CONTEXT from connexion.operations import AbstractOperation @@ -192,12 +191,7 @@ def from_operation( @property def fn(self) -> t.Callable: - decorator = AsyncDecorator( - self._operation, - uri_parser_cls=self._operation.uri_parser_class, - framework=StarletteFramework, - parameter=True, - response=True, + decorator = StarletteDecorator( pythonic_params=self.pythonic_params, jsonifier=self.api.jsonifier, ) diff --git a/connexion/context.py b/connexion/context.py index 519da947b..9f10aced0 100644 --- a/connexion/context.py +++ b/connexion/context.py @@ -1,12 +1,24 @@ from contextvars import ContextVar -from starlette.types import Scope +from starlette.types import Receive, Scope +from werkzeug.local import LocalProxy + +from connexion.operations import AbstractOperation + +UNBOUND_MESSAGE = ( + "Working outside of operation context. Make sure your app is wrapped in a " + "ContextMiddleware and you're processing a request while accessing the context." +) -_scope: ContextVar[Scope] = ContextVar("SCOPE") +_context: ContextVar[dict] = ContextVar("CONTEXT") +context = LocalProxy(_context, unbound_message=UNBOUND_MESSAGE) -def __getattr__(name): - if name == "scope": - return _scope.get() - if name == "context": - return _scope.get().get("extensions", {}).get("connexion_context", {}) +_operation: ContextVar[AbstractOperation] = ContextVar("OPERATION") +operation = LocalProxy(_operation, unbound_message=UNBOUND_MESSAGE) + +_receive: ContextVar[Receive] = ContextVar("RECEIVE") +receive = LocalProxy(_receive, unbound_message=UNBOUND_MESSAGE) + +_scope: ContextVar[Scope] = ContextVar("SCOPE") +scope = LocalProxy(_scope, unbound_message=UNBOUND_MESSAGE) diff --git a/connexion/decorators/__init__.py b/connexion/decorators/__init__.py index d6716ba55..187274535 100644 --- a/connexion/decorators/__init__.py +++ b/connexion/decorators/__init__.py @@ -1,4 +1,4 @@ """ This module defines decorators which Connexion uses to wrap user provided view functions. """ -from .main import AsyncDecorator, SyncDecorator # noqa +from .main import FlaskDecorator, StarletteDecorator # noqa diff --git a/connexion/decorators/main.py b/connexion/decorators/main.py index f2b02a917..3b565ff72 100644 --- a/connexion/decorators/main.py +++ b/connexion/decorators/main.py @@ -1,16 +1,17 @@ import abc import asyncio import functools +import json import typing as t from asgiref.sync import async_to_sync from starlette.concurrency import run_in_threadpool +from connexion.context import operation, receive, scope from connexion.decorators.parameter import ( AsyncParameterDecorator, BaseParameterDecorator, SyncParameterDecorator, - inspect_function_arguments, ) from connexion.decorators.response import ( AsyncResponseDecorator, @@ -18,39 +19,28 @@ SyncResponseDecorator, ) from connexion.frameworks.abstract import Framework -from connexion.operations import AbstractOperation +from connexion.frameworks.flask import Flask as FlaskFramework +from connexion.frameworks.starlette import Starlette as StarletteFramework from connexion.uri_parsing import AbstractURIParser class BaseDecorator: """Base class for connexion decorators.""" + framework: t.Type[Framework] + def __init__( self, - operation_spec: AbstractOperation, *, - uri_parser_cls: t.Type[AbstractURIParser], - framework: t.Type[Framework], - parameter: bool, - response: bool, pythonic_params: bool = False, - jsonifier, + uri_parser_class: AbstractURIParser = None, + jsonifier=json, ) -> None: - self.operation_spec = operation_spec - self.uri_parser = uri_parser_cls( - operation_spec.parameters, operation_spec.body_definition() - ) - self.framework = framework - self.produces = self.operation_spec.produces - self.parameter = parameter - self.response = response self.pythonic_params = pythonic_params + self.uri_parser_class = uri_parser_class self.jsonifier = jsonifier - if self.parameter: - self.arguments, self.has_kwargs = inspect_function_arguments( - operation_spec.function - ) + self.arguments, self.has_kwargs = None, None @property @abc.abstractmethod @@ -68,27 +58,25 @@ def _sync_async_decorator(self) -> t.Callable[[t.Callable], t.Callable]: """Decorator to translate between sync and async functions.""" raise NotImplementedError + @property + def uri_parser(self): + uri_parser_class = self.uri_parser_class or operation.uri_parser_class + return uri_parser_class(operation.parameters, operation.body_definition()) + def decorate(self, function: t.Callable) -> t.Callable: """Decorate a function with decorators based on the operation.""" function = self._sync_async_decorator(function) - if self.parameter: - parameter_decorator = self._parameter_decorator_cls( - self.operation_spec, - get_body_fn=self.framework.get_body, - arguments=self.arguments, - has_kwargs=self.has_kwargs, - pythonic_params=self.pythonic_params, - ) - function = parameter_decorator(function) + parameter_decorator = self._parameter_decorator_cls( + pythonic_params=self.pythonic_params, + ) + function = parameter_decorator(function) - if self.response: - response_decorator = self._response_decorator_cls( - self.operation_spec, - framework=self.framework, - jsonifier=self.jsonifier, - ) - function = response_decorator(function) + response_decorator = self._response_decorator_cls( + framework=self.framework, + jsonifier=self.jsonifier, + ) + function = response_decorator(function) return function @@ -97,7 +85,13 @@ def __call__(self, function: t.Callable) -> t.Callable: raise NotImplementedError -class SyncDecorator(BaseDecorator): +class FlaskDecorator(BaseDecorator): + """Decorator for usage with Flask. The parameter decorator works with a Flask request, + and provides Flask datastructures to the view function. The response decorator returns + a Flask response""" + + framework = FlaskFramework + @property def _parameter_decorator_cls(self) -> t.Type[SyncParameterDecorator]: return SyncParameterDecorator @@ -123,25 +117,33 @@ def wrapper(*args, **kwargs) -> t.Callable: def __call__(self, function: t.Callable) -> t.Callable: @functools.wraps(function) def wrapper(*args, **kwargs): - # TODO: move into parameter decorator? - connexion_request = self.framework.get_request( - *args, uri_parser=self.uri_parser, **kwargs - ) - + request = self.framework.get_request(uri_parser=self.uri_parser) decorated_function = self.decorate(function) - return decorated_function(connexion_request) + return decorated_function(request) return wrapper -class AsyncDecorator(BaseDecorator): +class ASGIDecorator(BaseDecorator): + """Decorator for usage with ASGI apps. The parameter decorator works with a Starlette request, + and provides Starlette datastructures to the view function. This works for any ASGI app, since + we get the request via the connexion context provided by ASGI middleware. + + This decorator does not parse responses, but passes them directly to the ASGI App.""" + + framework = StarletteFramework + @property def _parameter_decorator_cls(self) -> t.Type[AsyncParameterDecorator]: return AsyncParameterDecorator @property - def _response_decorator_cls(self) -> t.Type[AsyncResponseDecorator]: - return AsyncResponseDecorator + def _response_decorator_cls(self) -> t.Type[BaseResponseDecorator]: + class NoResponseDecorator(BaseResponseDecorator): + def __call__(self, function: t.Callable) -> t.Callable: + return lambda request: function(request) + + return NoResponseDecorator @property def _sync_async_decorator(self) -> t.Callable[[t.Callable], t.Callable]: @@ -160,15 +162,24 @@ async def wrapper(*args, **kwargs): def __call__(self, function: t.Callable) -> t.Callable: @functools.wraps(function) async def wrapper(*args, **kwargs): - # TODO: move into parameter decorator? - connexion_request = self.framework.get_request( - *args, uri_parser=self.uri_parser, **kwargs + request = self.framework.get_request( + uri_parser=self.uri_parser, scope=scope, receive=receive ) - decorated_function = self.decorate(function) - response = decorated_function(connexion_request) + response = decorated_function(request) while asyncio.iscoroutine(response): response = await response return response return wrapper + + +class StarletteDecorator(ASGIDecorator): + """Decorator for usage with Connexion or Starlette apps. The parameter decorator works with a + Starlette request, and provides Starlette datastructures to the view function. + + The response decorator returns Starlette responses.""" + + @property + def _response_decorator_cls(self) -> t.Type[AsyncResponseDecorator]: + return AsyncResponseDecorator diff --git a/connexion/decorators/parameter.py b/connexion/decorators/parameter.py index 468f5381b..84cc1fa4e 100644 --- a/connexion/decorators/parameter.py +++ b/connexion/decorators/parameter.py @@ -14,6 +14,9 @@ import inflection +from connexion.context import context, operation +from connexion.frameworks.flask import Flask as FlaskFramework +from connexion.frameworks.starlette import Starlette as StarletteFramework from connexion.http_facts import FORM_CONTENT_TYPES from connexion.lifecycle import ConnexionRequest, MiddlewareRequest from connexion.operations import AbstractOperation, Swagger2Operation @@ -27,30 +30,26 @@ class BaseParameterDecorator: def __init__( self, - operation: AbstractOperation, *, - get_body_fn: t.Callable, - arguments: t.List[str], - has_kwargs: bool, pythonic_params: bool = False, ) -> None: - self.operation = operation - self.get_body_fn = get_body_fn - self.arguments = arguments - self.has_kwargs = has_kwargs self.sanitize_fn = pythonic if pythonic_params else sanitized def _maybe_get_body( - self, request: t.Union[ConnexionRequest, MiddlewareRequest] + self, + request: t.Union[ConnexionRequest, MiddlewareRequest], + *, + arguments: t.List[str], + has_kwargs: bool, ) -> t.Any: - body_name = self.sanitize_fn(self.operation.body_name(request.content_type)) + body_name = self.sanitize_fn(operation.body_name(request.content_type)) # Pass form contents separately for Swagger2 for backward compatibility with # Connexion 2 Checking for body_name is not enough - if (body_name in self.arguments or self.has_kwargs) or ( + if (body_name in arguments or has_kwargs) or ( request.mimetype in FORM_CONTENT_TYPES - and isinstance(self.operation, Swagger2Operation) + and isinstance(operation, Swagger2Operation) ): - return self.get_body_fn(request) + return request.get_body() else: return None @@ -60,17 +59,24 @@ def __call__(self, function: t.Callable) -> t.Callable: class SyncParameterDecorator(BaseParameterDecorator): + + framework = FlaskFramework + def __call__(self, function: t.Callable) -> t.Callable: + unwrapped_function = unwrap_decorators(function) + arguments, has_kwargs = inspect_function_arguments(unwrapped_function) + @functools.wraps(function) - def wrapper(request: t.Union[ConnexionRequest, MiddlewareRequest]) -> t.Any: - request_body = self._maybe_get_body(request) + def wrapper(request: ConnexionRequest) -> t.Any: + request_body = self._maybe_get_body( + request, arguments=arguments, has_kwargs=has_kwargs + ) kwargs = prep_kwargs( request, - operation=self.operation, request_body=request_body, - arguments=self.arguments, - has_kwargs=self.has_kwargs, + arguments=arguments, + has_kwargs=has_kwargs, sanitize=self.sanitize_fn, ) @@ -80,22 +86,27 @@ def wrapper(request: t.Union[ConnexionRequest, MiddlewareRequest]) -> t.Any: class AsyncParameterDecorator(BaseParameterDecorator): + + framework = StarletteFramework + def __call__(self, function: t.Callable) -> t.Callable: + unwrapped_function = unwrap_decorators(function) + arguments, has_kwargs = inspect_function_arguments(unwrapped_function) + @functools.wraps(function) - async def wrapper( - request: t.Union[ConnexionRequest, MiddlewareRequest] - ) -> t.Any: - request_body = self._maybe_get_body(request) + async def wrapper(request: MiddlewareRequest) -> t.Any: + request_body = self._maybe_get_body( + request, arguments=arguments, has_kwargs=has_kwargs + ) while asyncio.iscoroutine(request_body): request_body = await request_body kwargs = prep_kwargs( request, - operation=self.operation, request_body=request_body, - arguments=self.arguments, - has_kwargs=self.has_kwargs, + arguments=arguments, + has_kwargs=has_kwargs, sanitize=self.sanitize_fn, ) @@ -107,7 +118,6 @@ async def wrapper( def prep_kwargs( request: t.Union[ConnexionRequest, MiddlewareRequest], *, - operation: AbstractOperation, request_body: t.Any, arguments: t.List[str], has_kwargs: bool, @@ -129,18 +139,25 @@ def prep_kwargs( kwargs = {sanitize(k): v for k, v in kwargs.items()} # add context info (e.g. from security decorator) - for key, value in request.context.items(): + for key, value in context.items(): if has_kwargs or key in arguments: kwargs[key] = value else: logger.debug("Context parameter '%s' not in function arguments", key) # attempt to provide the request context to the function if CONTEXT_NAME in arguments: - kwargs[CONTEXT_NAME] = request.context + kwargs[CONTEXT_NAME] = context return kwargs +def unwrap_decorators(function: t.Callable) -> t.Callable: + """Unwrap decorators to return the original function.""" + while hasattr(function, "__wrapped__"): + function = function.__wrapped__ # type: ignore + return function + + def inspect_function_arguments(function: t.Callable) -> t.Tuple[t.List[str], bool]: """ Returns the list of variables names of a function and if it diff --git a/connexion/decorators/response.py b/connexion/decorators/response.py index 082556fdf..1553cedfc 100644 --- a/connexion/decorators/response.py +++ b/connexion/decorators/response.py @@ -6,21 +6,18 @@ import typing as t from enum import Enum +from connexion.context import operation from connexion.datastructures import NoContent from connexion.exceptions import NonConformingResponseHeaders from connexion.frameworks.abstract import Framework from connexion.lifecycle import ConnexionResponse, MiddlewareResponse -from connexion.operations import AbstractOperation from connexion.utils import is_json_mimetype logger = logging.getLogger(__name__) class BaseResponseDecorator: - def __init__( - self, operation: AbstractOperation, *, framework: t.Type[Framework], jsonifier - ): - self.operation = operation + def __init__(self, *, framework: t.Type[Framework], jsonifier): self.framework = framework self.jsonifier = jsonifier @@ -39,7 +36,8 @@ def build_framework_response(self, handler_response): data, content_type=content_type, status_code=status_code, headers=headers ) - def _deduct_content_type(self, data: t.Any, headers: dict) -> str: + @staticmethod + def _deduct_content_type(data: t.Any, headers: dict) -> str: """Deduct the response content type from the returned data, headers and operation spec. :param data: Response data @@ -52,7 +50,7 @@ def _deduct_content_type(self, data: t.Any, headers: dict) -> str: content_type = headers.get("Content-Type") # TODO: don't default - produces = list(set(self.operation.produces)) + produces = list(set(operation.produces)) if data is not None and not produces: produces = ["application/json"] @@ -60,7 +58,7 @@ def _deduct_content_type(self, data: t.Any, headers: dict) -> str: if content_type not in produces: raise NonConformingResponseHeaders( f"Returned content type ({content_type}) is not defined in operation spec " - f"({self.operation.produces})." + f"({operation.produces})." ) else: if not produces: @@ -153,13 +151,13 @@ def _unpack_handler_response( class SyncResponseDecorator(BaseResponseDecorator): def __call__(self, function: t.Callable) -> t.Callable: @functools.wraps(function) - def wrapper(request): + def wrapper(*args, **kwargs): """ This method converts a handler response to a framework response. The handler response can be a ConnexionResponse, a framework response, a tuple or an object. """ - handler_response = function(request) + handler_response = function(*args, **kwargs) if self.framework.is_framework_response(handler_response): return handler_response elif isinstance(handler_response, (ConnexionResponse, MiddlewareResponse)): @@ -173,13 +171,13 @@ def wrapper(request): class AsyncResponseDecorator(BaseResponseDecorator): def __call__(self, function: t.Callable) -> t.Callable: @functools.wraps(function) - async def wrapper(request): + async def wrapper(*args, **kwargs): """ This method converts a handler response to a framework response. The handler response can be a ConnexionResponse, a framework response, a tuple or an object. """ - handler_response = await function(request) + handler_response = await function(*args, **kwargs) if self.framework.is_framework_response(handler_response): return handler_response elif isinstance(handler_response, (ConnexionResponse, MiddlewareResponse)): diff --git a/connexion/frameworks/flask.py b/connexion/frameworks/flask.py index 5cd5d1b77..d6cbc7b53 100644 --- a/connexion/frameworks/flask.py +++ b/connexion/frameworks/flask.py @@ -11,10 +11,8 @@ import werkzeug from connexion.frameworks.abstract import Framework -from connexion.http_facts import FORM_CONTENT_TYPES from connexion.lifecycle import ConnexionRequest from connexion.uri_parsing import AbstractURIParser -from connexion.utils import is_json_mimetype class Flask(Framework): @@ -58,16 +56,6 @@ def build_response( def get_request(*, uri_parser: AbstractURIParser, **kwargs) -> ConnexionRequest: # type: ignore return ConnexionRequest(flask.request, uri_parser=uri_parser) - @staticmethod - def get_body(request): - if is_json_mimetype(request.content_type): - return request.get_json(silent=True) - elif request.mimetype in FORM_CONTENT_TYPES: - return request.form - else: - # Return explicit None instead of empty bytestring so it is handled as null downstream - return request.get_data() or None - PATH_PARAMETER = re.compile(r"\{([^}]*)\}") diff --git a/connexion/frameworks/starlette.py b/connexion/frameworks/starlette.py index 89eeb9501..3121de6d1 100644 --- a/connexion/frameworks/starlette.py +++ b/connexion/frameworks/starlette.py @@ -8,9 +8,7 @@ from starlette.types import Receive, Scope from connexion.frameworks.abstract import Framework -from connexion.http_facts import FORM_CONTENT_TYPES from connexion.lifecycle import MiddlewareRequest, MiddlewareResponse -from connexion.utils import is_json_mimetype class Starlette(Framework): @@ -54,16 +52,6 @@ def build_response( def get_request(*, scope: Scope, receive: Receive, **kwargs) -> MiddlewareRequest: # type: ignore return MiddlewareRequest(scope, receive) - @staticmethod - async def get_body(request): - if is_json_mimetype(request.content_type): - return await request.json() - elif request.mimetype in FORM_CONTENT_TYPES: - return await request.form() - else: - # Return explicit None instead of empty bytestring so it is handled as null downstream - return await request.data() or None - PATH_PARAMETER = re.compile(r"\{([^}]*)\}") PATH_PARAMETER_CONVERTERS = {"integer": "int", "number": "float", "path": "path"} diff --git a/connexion/lifecycle.py b/connexion/lifecycle.py index f19b213f7..ebf120312 100644 --- a/connexion/lifecycle.py +++ b/connexion/lifecycle.py @@ -9,6 +9,9 @@ from starlette.requests import Request as StarletteRequest from starlette.responses import StreamingResponse as StarletteStreamingResponse +from connexion.http_facts import FORM_CONTENT_TYPES +from connexion.utils import is_json_mimetype + class ConnexionRequest: def __init__(self, flask_request: FlaskRequest, uri_parser=None): @@ -41,6 +44,16 @@ def form(self): form_data = self.uri_parser.resolve_form(form) return form_data + def get_body(self): + """Get body based on content type""" + if is_json_mimetype(self.content_type): + return self.get_json(silent=True) + elif self.mimetype in FORM_CONTENT_TYPES: + return self.form + else: + # Return explicit None instead of empty bytestring so it is handled as null downstream + return self.get_data() or None + def __getattr__(self, item): return getattr(self._flask_request, item) @@ -98,6 +111,15 @@ def files(self): # TODO: separate files? return {} + async def get_body(self): + if is_json_mimetype(self.content_type): + return await self.json() + elif self.mimetype in FORM_CONTENT_TYPES: + return await self.form() + else: + # Return explicit None instead of empty bytestring so it is handled as null downstream + return await self.data() or None + class MiddlewareResponse(StarletteStreamingResponse): """Wraps starlette StreamingResponse so it can easily be extended.""" diff --git a/connexion/middleware/abstract.py b/connexion/middleware/abstract.py index 4dc1ccddc..0247ba6a8 100644 --- a/connexion/middleware/abstract.py +++ b/connexion/middleware/abstract.py @@ -40,10 +40,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class RoutedAPI(AbstractSpecAPI, t.Generic[OP]): - - operation_cls: t.Type[OP] - """The operation this middleware uses, which should implement the RoutingOperation protocol.""" - def __init__( self, specification: t.Union[pathlib.Path, str, dict], @@ -70,7 +66,12 @@ def add_paths(self) -> None: def add_operation(self, path: str, method: str) -> None: operation_spec_cls = self.specification.operation_cls operation = operation_spec_cls.from_spec( - self.specification, self, path, method, self.resolver + self.specification, + self, + path, + method, + self.resolver, + uri_parser_class=self.options.uri_parser_class, ) routed_operation = self.make_operation(operation) self.operations[operation.operation_id] = routed_operation diff --git a/connexion/middleware/context.py b/connexion/middleware/context.py index 978ac8cfe..71a1febf7 100644 --- a/connexion/middleware/context.py +++ b/connexion/middleware/context.py @@ -2,13 +2,39 @@ middleware stack, so it exposes the scope passed to the application""" from starlette.types import ASGIApp, Receive, Scope, Send -from connexion.context import _scope +from connexion.context import _context, _operation, _receive, _scope +from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware +from connexion.operations import AbstractOperation -class ContextMiddleware: - def __init__(self, app: ASGIApp) -> None: - self.app = app +class ContextOperation: + def __init__( + self, + next_app: ASGIApp, + *, + operation: AbstractOperation, + ) -> None: + self.next_app = next_app + self.operation = operation async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + _context.set(scope.get("extensions", {}).get("connexion_context", {})) + _operation.set(self.operation) + _receive.set(receive) _scope.set(scope) - await self.app(scope, receive, send) + await self.next_app(scope, receive, send) + + +class ContextAPI(RoutedAPI[ContextOperation]): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.add_paths() + + def make_operation(self, operation: AbstractOperation) -> ContextOperation: + return ContextOperation(self.next_app, operation=operation) + + +class ContextMiddleware(RoutedMiddleware[ContextAPI]): + """Middleware to expose operation specific context to application.""" + + api_cls = ContextAPI diff --git a/connexion/middleware/request_validation.py b/connexion/middleware/request_validation.py index be1741ac0..b15d44f88 100644 --- a/connexion/middleware/request_validation.py +++ b/connexion/middleware/request_validation.py @@ -120,8 +120,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): class RequestValidationAPI(RoutedAPI[RequestValidationOperation]): """Validation API.""" - operation_cls = RequestValidationOperation - def __init__( self, *args, diff --git a/connexion/middleware/response_validation.py b/connexion/middleware/response_validation.py index 9e5a57efb..761858c89 100644 --- a/connexion/middleware/response_validation.py +++ b/connexion/middleware/response_validation.py @@ -125,8 +125,6 @@ async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None: class ResponseValidationAPI(RoutedAPI[ResponseValidationOperation]): """Validation API.""" - operation_cls = ResponseValidationOperation - def __init__( self, *args, diff --git a/connexion/middleware/routing.py b/connexion/middleware/routing.py index beedb02d9..fe0643eb2 100644 --- a/connexion/middleware/routing.py +++ b/connexion/middleware/routing.py @@ -68,12 +68,18 @@ def __init__( resolver=resolver, resolver_error_handler=resolver_error_handler, debug=debug, + **kwargs, ) def add_operation(self, path: str, method: str) -> None: operation_cls = self.specification.operation_cls operation = operation_cls.from_spec( - self.specification, self, path, method, self.resolver + self.specification, + self, + path, + method, + self.resolver, + uri_parser_class=self.options.uri_parser_class, ) routing_operation = RoutingOperation.from_operation( operation, next_app=self.next_app diff --git a/connexion/middleware/security.py b/connexion/middleware/security.py index 376f5286b..043aa6b2d 100644 --- a/connexion/middleware/security.py +++ b/connexion/middleware/security.py @@ -205,9 +205,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class SecurityAPI(RoutedAPI[SecurityOperation]): - - operation_cls = SecurityOperation - def __init__(self, *args, auth_all_paths: bool = False, **kwargs): super().__init__(*args, **kwargs) diff --git a/connexion/problem.py b/connexion/problem.py index e749facd1..8d0488b50 100644 --- a/connexion/problem.py +++ b/connexion/problem.py @@ -4,8 +4,6 @@ to communicate distinct "problem types" to non-human consumers. """ -from .lifecycle import ConnexionResponse - def problem(status, title, detail, type=None, instance=None, headers=None, ext=None): """ @@ -33,6 +31,8 @@ def problem(status, title, detail, type=None, instance=None, headers=None, ext=N :return: error response :rtype: ConnexionResponse """ + from .lifecycle import ConnexionResponse # prevent circular import + if not type: type = "about:blank" diff --git a/connexion/testing.py b/connexion/testing.py new file mode 100644 index 000000000..20a4335bf --- /dev/null +++ b/connexion/testing.py @@ -0,0 +1,72 @@ +import contextvars +import typing as t +from unittest.mock import MagicMock + +from starlette.types import Receive, Scope + +from connexion.context import _context, _operation, _receive, _scope +from connexion.operations import AbstractOperation + + +class TestContext: + __test__ = False # Pytest + + def __init__( + self, + *, + context: dict = None, + operation: AbstractOperation = None, + receive: Receive = None, + scope: Scope = None, + ) -> None: + self.context = context if context is not None else self.build_context() + self.operation = operation if operation is not None else self.build_operation() + self.receive = receive if receive is not None else self.build_receive() + self.scope = scope if scope is not None else self.build_scope() + + self.tokens: t.Dict[str, contextvars.Token] = {} + + def __enter__(self) -> None: + self.tokens["context"] = _context.set(self.context) + self.tokens["operation"] = _operation.set(self.operation) + self.tokens["receive"] = _receive.set(self.receive) + self.tokens["scope"] = _scope.set(self.scope) + return + + def __exit__(self, type, value, traceback): + _context.reset(self.tokens["context"]) + _operation.reset(self.tokens["operation"]) + _receive.reset(self.tokens["receive"]) + _scope.reset(self.tokens["scope"]) + return False + + @staticmethod + def build_context() -> dict: + return {} + + @staticmethod + def build_operation() -> AbstractOperation: + return MagicMock(name="operation") + + @staticmethod + def build_receive() -> Receive: + async def receive() -> t.MutableMapping[str, t.Any]: + return { + "type": "http.request", + "body": b"", + } + + return receive + + @staticmethod + def build_scope(**kwargs) -> Scope: + scope = { + "type": "http", + "query_string": b"", + "headers": [(b"Content-Type", b"application/octet-stream")], + } + + for key, value in kwargs.items(): + scope[key] = value + + return scope diff --git a/tests/decorators/test_parameter.py b/tests/decorators/test_parameter.py index 778804f36..03324632a 100644 --- a/tests/decorators/test_parameter.py +++ b/tests/decorators/test_parameter.py @@ -3,14 +3,13 @@ from connexion.decorators.parameter import ( AsyncParameterDecorator, SyncParameterDecorator, - inspect_function_arguments, pythonic, ) +from connexion.testing import TestContext def test_sync_injection(): request = MagicMock(name="request") - request.query_params = {} request.path_params = {"p1": "123"} func = MagicMock() @@ -18,25 +17,18 @@ def test_sync_injection(): def handler(**kwargs): func(**kwargs) - def get_body_fn(_request): - return {} - operation = MagicMock(name="operation") operation.body_name = lambda _: "body" - arguments, has_kwargs = inspect_function_arguments(handler) - - parameter_decorator = SyncParameterDecorator( - operation, get_body_fn=get_body_fn, arguments=arguments, has_kwargs=has_kwargs - ) - decorated_handler = parameter_decorator(handler) - decorated_handler(request) + with TestContext(operation=operation): + parameter_decorator = SyncParameterDecorator() + decorated_handler = parameter_decorator(handler) + decorated_handler(request) func.assert_called_with(p1="123") async def test_async_injection(): request = MagicMock(name="request") - request.query_params = {} request.path_params = {"p1": "123"} func = MagicMock() @@ -44,74 +36,56 @@ async def test_async_injection(): async def handler(**kwargs): func(**kwargs) - def get_body_fn(_request): - return {} - operation = MagicMock(name="operation") operation.body_name = lambda _: "body" - arguments, has_kwargs = inspect_function_arguments(handler) - - parameter_decorator = AsyncParameterDecorator( - operation, get_body_fn=get_body_fn, arguments=arguments, has_kwargs=has_kwargs - ) - decorated_handler = parameter_decorator(handler) - await decorated_handler(request) + with TestContext(operation=operation): + parameter_decorator = AsyncParameterDecorator() + decorated_handler = parameter_decorator(handler) + await decorated_handler(request) func.assert_called_with(p1="123") def test_sync_injection_with_context(): request = MagicMock(name="request") - request.query_params = {} request.path_params = {"p1": "123"} - request.context = {} func = MagicMock() def handler(context_, **kwargs): func(context_, **kwargs) - def get_body_fn(_request): - return {} + context = {"test": "success"} operation = MagicMock(name="operation") operation.body_name = lambda _: "body" - arguments, has_kwargs = inspect_function_arguments(handler) - - parameter_decorator = SyncParameterDecorator( - operation, get_body_fn=get_body_fn, arguments=arguments, has_kwargs=has_kwargs - ) - decorated_handler = parameter_decorator(handler) - decorated_handler(request) - func.assert_called_with(request.context, p1="123") + with TestContext(context=context, operation=operation): + parameter_decorator = SyncParameterDecorator() + decorated_handler = parameter_decorator(handler) + decorated_handler(request) + func.assert_called_with(context, p1="123", test="success") async def test_async_injection_with_context(): request = MagicMock(name="request") - request.query_params = {} request.path_params = {"p1": "123"} - request.context = {} func = MagicMock() async def handler(context_, **kwargs): func(context_, **kwargs) - def get_body_fn(_request): - return {} + context = {"test": "success"} operation = MagicMock(name="operation") operation.body_name = lambda _: "body" - arguments, has_kwargs = inspect_function_arguments(handler) - - parameter_decorator = AsyncParameterDecorator( - operation, get_body_fn=get_body_fn, arguments=arguments, has_kwargs=has_kwargs - ) - decorated_handler = parameter_decorator(handler) - await decorated_handler(request) - func.assert_called_with(request.context, p1="123") + with TestContext(context=context, operation=operation): + parameter_decorator = AsyncParameterDecorator() + decorated_handler = parameter_decorator(handler) + await decorated_handler(request) + func.assert_called_with(context, p1="123", test="success") def test_pythonic_params():