Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unmarshalling processor enhancement #625

Merged
merged 1 commit into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions openapi_core/contrib/django/handlers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""OpenAPI core contrib django handlers module"""
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
Expand All @@ -14,6 +15,7 @@
from openapi_core.templating.paths.exceptions import PathNotFound
from openapi_core.templating.paths.exceptions import ServerNotFound
from openapi_core.templating.security.exceptions import SecurityNotFound
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult


class DjangoOpenAPIErrorsHandler:
Expand All @@ -25,18 +27,15 @@ class DjangoOpenAPIErrorsHandler:
MediaTypeNotFound: 415,
}

@classmethod
def handle(
cls,
def __call__(
self,
errors: Iterable[Exception],
req: HttpRequest,
resp: Optional[HttpResponse] = None,
) -> JsonResponse:
data_errors = [cls.format_openapi_error(err) for err in errors]
data_errors = [self.format_openapi_error(err) for err in errors]
data = {
"errors": data_errors,
}
data_error_max = max(data_errors, key=cls.get_error_status)
data_error_max = max(data_errors, key=self.get_error_status)
return JsonResponse(data, status=data_error_max["status"])

@classmethod
Expand All @@ -52,3 +51,15 @@ def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
@classmethod
def get_error_status(cls, error: Dict[str, Any]) -> str:
return str(error["status"])


class DjangoOpenAPIValidRequestHandler:
def __init__(self, req: HttpRequest, view: Callable[[Any], HttpResponse]):
self.req = req
self.view = view

def __call__(
self, request_unmarshal_result: RequestUnmarshalResult
) -> HttpResponse:
self.req.openapi = request_unmarshal_result
return self.view(self.req)
53 changes: 17 additions & 36 deletions openapi_core/contrib/django/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,24 @@

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.http import JsonResponse
from django.http.request import HttpRequest
from django.http.response import HttpResponse

from openapi_core.contrib.django.handlers import DjangoOpenAPIErrorsHandler
from openapi_core.contrib.django.handlers import (
DjangoOpenAPIValidRequestHandler,
)
from openapi_core.contrib.django.requests import DjangoOpenAPIRequest
from openapi_core.contrib.django.responses import DjangoOpenAPIResponse
from openapi_core.unmarshalling.processors import UnmarshallingProcessor
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult
from openapi_core.unmarshalling.response.datatypes import (
ResponseUnmarshalResult,
)


class DjangoOpenAPIMiddleware:
class DjangoOpenAPIMiddleware(
UnmarshallingProcessor[HttpRequest, HttpResponse]
):
request_cls = DjangoOpenAPIRequest
response_cls = DjangoOpenAPIResponse
valid_request_handler_cls = DjangoOpenAPIValidRequestHandler
errors_handler = DjangoOpenAPIErrorsHandler()

def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
Expand All @@ -31,40 +32,17 @@ def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
if hasattr(settings, "OPENAPI_RESPONSE_CLS"):
self.response_cls = settings.OPENAPI_RESPONSE_CLS

self.processor = UnmarshallingProcessor(settings.OPENAPI_SPEC)
super().__init__(settings.OPENAPI_SPEC)

def __call__(self, request: HttpRequest) -> HttpResponse:
openapi_request = self._get_openapi_request(request)
req_result = self.processor.process_request(openapi_request)
if req_result.errors:
response = self._handle_request_errors(req_result, request)
else:
request.openapi = req_result
response = self.get_response(request)

if self.response_cls is None:
return response
openapi_response = self._get_openapi_response(response)
resp_result = self.processor.process_response(
openapi_request, openapi_response
valid_request_handler = self.valid_request_handler_cls(
request, self.get_response
)
response = self.handle_request(
request, valid_request_handler, self.errors_handler
)
if resp_result.errors:
return self._handle_response_errors(resp_result, request, response)

return response

def _handle_request_errors(
self, request_result: RequestUnmarshalResult, req: HttpRequest
) -> JsonResponse:
return self.errors_handler.handle(request_result.errors, req, None)

def _handle_response_errors(
self,
response_result: ResponseUnmarshalResult,
req: HttpRequest,
resp: HttpResponse,
) -> JsonResponse:
return self.errors_handler.handle(response_result.errors, req, resp)
return self.handle_response(request, response, self.errors_handler)

def _get_openapi_request(
self, request: HttpRequest
Expand All @@ -76,3 +54,6 @@ def _get_openapi_response(
) -> DjangoOpenAPIResponse:
assert self.response_cls is not None
return self.response_cls(response)

def _validate_response(self) -> bool:
return self.response_cls is not None
35 changes: 25 additions & 10 deletions openapi_core/contrib/falcon/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from openapi_core.templating.paths.exceptions import PathNotFound
from openapi_core.templating.paths.exceptions import ServerNotFound
from openapi_core.templating.security.exceptions import SecurityNotFound
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult


class FalconOpenAPIErrorsHandler:
Expand All @@ -26,24 +27,26 @@ class FalconOpenAPIErrorsHandler:
MediaTypeNotFound: 415,
}

@classmethod
def handle(
cls, req: Request, resp: Response, errors: Iterable[Exception]
) -> None:
data_errors = [cls.format_openapi_error(err) for err in errors]
def __init__(self, req: Request, resp: Response):
self.req = req
self.resp = resp

def __call__(self, errors: Iterable[Exception]) -> Response:
data_errors = [self.format_openapi_error(err) for err in errors]
data = {
"errors": data_errors,
}
data_str = dumps(data)
data_error_max = max(data_errors, key=cls.get_error_status)
resp.content_type = MEDIA_JSON
resp.status = getattr(
data_error_max = max(data_errors, key=self.get_error_status)
self.resp.content_type = MEDIA_JSON
self.resp.status = getattr(
status_codes,
f"HTTP_{data_error_max['status']}",
status_codes.HTTP_400,
)
resp.text = data_str
resp.complete = True
self.resp.text = data_str
self.resp.complete = True
return self.resp

@classmethod
def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
Expand All @@ -58,3 +61,15 @@ def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
@classmethod
def get_error_status(cls, error: Dict[str, Any]) -> int:
return int(error["status"])


class FalconOpenAPIValidRequestHandler:
def __init__(self, req: Request, resp: Response):
self.req = req
self.resp = resp

def __call__(
self, request_unmarshal_result: RequestUnmarshalResult
) -> Response:
self.req.context.openapi = request_unmarshal_result
return self.resp
69 changes: 26 additions & 43 deletions openapi_core/contrib/falcon/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,24 @@
from falcon.response import Response

from openapi_core.contrib.falcon.handlers import FalconOpenAPIErrorsHandler
from openapi_core.contrib.falcon.handlers import (
FalconOpenAPIValidRequestHandler,
)
from openapi_core.contrib.falcon.requests import FalconOpenAPIRequest
from openapi_core.contrib.falcon.responses import FalconOpenAPIResponse
from openapi_core.spec import Spec
from openapi_core.unmarshalling.processors import UnmarshallingProcessor
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult
from openapi_core.unmarshalling.request.types import RequestUnmarshallerType
from openapi_core.unmarshalling.response.datatypes import (
ResponseUnmarshalResult,
)
from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType


class FalconOpenAPIMiddleware(UnmarshallingProcessor):
class FalconOpenAPIMiddleware(UnmarshallingProcessor[Request, Response]):
request_cls = FalconOpenAPIRequest
response_cls = FalconOpenAPIResponse
errors_handler = FalconOpenAPIErrorsHandler()
valid_request_handler_cls = FalconOpenAPIValidRequestHandler
errors_handler_cls: Type[
FalconOpenAPIErrorsHandler
] = FalconOpenAPIErrorsHandler

def __init__(
self,
Expand All @@ -31,7 +33,9 @@ def __init__(
response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None,
request_cls: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest,
response_cls: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse,
errors_handler: Optional[FalconOpenAPIErrorsHandler] = None,
errors_handler_cls: Type[
FalconOpenAPIErrorsHandler
] = FalconOpenAPIErrorsHandler,
**unmarshaller_kwargs: Any,
):
super().__init__(
Expand All @@ -42,7 +46,7 @@ def __init__(
)
self.request_cls = request_cls or self.request_cls
self.response_cls = response_cls or self.response_cls
self.errors_handler = errors_handler or self.errors_handler
self.errors_handler_cls = errors_handler_cls or self.errors_handler_cls

@classmethod
def from_spec(
Expand All @@ -52,7 +56,9 @@ def from_spec(
response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None,
request_cls: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest,
response_cls: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse,
errors_handler: Optional[FalconOpenAPIErrorsHandler] = None,
errors_handler_cls: Type[
FalconOpenAPIErrorsHandler
] = FalconOpenAPIErrorsHandler,
**unmarshaller_kwargs: Any,
) -> "FalconOpenAPIMiddleware":
return cls(
Expand All @@ -61,46 +67,20 @@ def from_spec(
response_unmarshaller_cls=response_unmarshaller_cls,
request_cls=request_cls,
response_cls=response_cls,
errors_handler=errors_handler,
errors_handler_cls=errors_handler_cls,
**unmarshaller_kwargs,
)

def process_request(self, req: Request, resp: Response) -> None: # type: ignore
openapi_req = self._get_openapi_request(req)
req.context.openapi = super().process_request(openapi_req)
if req.context.openapi.errors:
return self._handle_request_errors(req, resp, req.context.openapi)
def process_request(self, req: Request, resp: Response) -> None:
valid_handler = self.valid_request_handler_cls(req, resp)
errors_handler = self.errors_handler_cls(req, resp)
self.handle_request(req, valid_handler, errors_handler)

def process_response( # type: ignore
def process_response(
self, req: Request, resp: Response, resource: Any, req_succeeded: bool
) -> None:
if self.response_cls is None:
return resp
openapi_req = self._get_openapi_request(req)
openapi_resp = self._get_openapi_response(resp)
resp.context.openapi = super().process_response(
openapi_req, openapi_resp
)
if resp.context.openapi.errors:
return self._handle_response_errors(
req, resp, resp.context.openapi
)

def _handle_request_errors(
self,
req: Request,
resp: Response,
request_result: RequestUnmarshalResult,
) -> None:
return self.errors_handler.handle(req, resp, request_result.errors)

def _handle_response_errors(
self,
req: Request,
resp: Response,
response_result: ResponseUnmarshalResult,
) -> None:
return self.errors_handler.handle(req, resp, response_result.errors)
errors_handler = self.errors_handler_cls(req, resp)
self.handle_response(req, resp, errors_handler)

def _get_openapi_request(self, request: Request) -> FalconOpenAPIRequest:
return self.request_cls(request)
Expand All @@ -110,3 +90,6 @@ def _get_openapi_response(
) -> FalconOpenAPIResponse:
assert self.response_cls is not None
return self.response_cls(response)

def _validate_response(self) -> bool:
return self.response_cls is not None
Loading