Skip to content

Commit

Permalink
Unmarshalling processor enhancement
Browse files Browse the repository at this point in the history
  • Loading branch information
p1c2u committed Jul 22, 2023
1 parent d60be8c commit c1f90b6
Show file tree
Hide file tree
Showing 12 changed files with 280 additions and 180 deletions.
32 changes: 26 additions & 6 deletions openapi_core/contrib/django/handlers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""OpenAPI core contrib django handlers module"""
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Optional
Expand All @@ -14,6 +15,7 @@
from openapi_core.templating.paths.exceptions import PathNotFound
from openapi_core.templating.paths.exceptions import ServerNotFound
from openapi_core.templating.security.exceptions import SecurityNotFound
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult


class DjangoOpenAPIErrorsHandler:
Expand All @@ -25,19 +27,26 @@ class DjangoOpenAPIErrorsHandler:
MediaTypeNotFound: 415,
}

def __call__(
self,
errors: Iterable[Exception],
) -> JsonResponse:
data_errors = [self.format_openapi_error(err) for err in errors]
data = {
"errors": data_errors,
}
data_error_max = max(data_errors, key=self.get_error_status)
return JsonResponse(data, status=data_error_max["status"])

@classmethod
def handle(
cls,
errors: Iterable[Exception],
req: HttpRequest,
resp: Optional[HttpResponse] = None,
) -> JsonResponse:
data_errors = [cls.format_openapi_error(err) for err in errors]
data = {
"errors": data_errors,
}
data_error_max = max(data_errors, key=cls.get_error_status)
return JsonResponse(data, status=data_error_max["status"])
instance = cls()
return instance(errors)

@classmethod
def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
Expand All @@ -52,3 +61,14 @@ def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
@classmethod
def get_error_status(cls, error: Dict[str, Any]) -> str:
return str(error["status"])


class DjangoOpenAPIValidRequestHandler:

def __init__(self, req: HttpRequest, view: Callable[[Any], HttpResponse]):
self.req = req
self.view = view

def __call__(self, request_unmarshal_result: RequestUnmarshalResult) -> HttpResponse:
self.req.openapi = request_unmarshal_result
return self.view(self.req)
42 changes: 8 additions & 34 deletions openapi_core/contrib/django/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,20 @@

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.http import JsonResponse
from django.http.request import HttpRequest
from django.http.response import HttpResponse

from openapi_core.contrib.django.handlers import DjangoOpenAPIErrorsHandler
from openapi_core.contrib.django.handlers import DjangoOpenAPIValidRequestHandler
from openapi_core.contrib.django.requests import DjangoOpenAPIRequest
from openapi_core.contrib.django.responses import DjangoOpenAPIResponse
from openapi_core.unmarshalling.processors import UnmarshallingProcessor
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult
from openapi_core.unmarshalling.response.datatypes import (
ResponseUnmarshalResult,
)


class DjangoOpenAPIMiddleware:
class DjangoOpenAPIMiddleware(UnmarshallingProcessor[HttpRequest, HttpResponse]):
request_class = DjangoOpenAPIRequest
response_class = DjangoOpenAPIResponse
valid_request_handler_cls = DjangoOpenAPIValidRequestHandler
errors_handler = DjangoOpenAPIErrorsHandler()

def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
Expand All @@ -28,38 +25,15 @@ def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
if not hasattr(settings, "OPENAPI_SPEC"):
raise ImproperlyConfigured("OPENAPI_SPEC not defined in settings")

self.processor = UnmarshallingProcessor(settings.OPENAPI_SPEC)
super().__init__(settings.OPENAPI_SPEC)

def __call__(self, request: HttpRequest) -> HttpResponse:
openapi_request = self._get_openapi_request(request)
req_result = self.processor.process_request(openapi_request)
if req_result.errors:
response = self._handle_request_errors(req_result, request)
else:
request.openapi = req_result
response = self.get_response(request)
valid_request_handler = self.valid_request_handler_cls(request, self.get_response)
response = self.handle_request(request, valid_request_handler, self.errors_handler)

openapi_response = self._get_openapi_response(response)
resp_result = self.processor.process_response(
openapi_request, openapi_response
return self.handle_response(
request, response, self.errors_handler
)
if resp_result.errors:
return self._handle_response_errors(resp_result, request, response)

return response

def _handle_request_errors(
self, request_result: RequestUnmarshalResult, req: HttpRequest
) -> JsonResponse:
return self.errors_handler.handle(request_result.errors, req, None)

def _handle_response_errors(
self,
response_result: ResponseUnmarshalResult,
req: HttpRequest,
resp: HttpResponse,
) -> JsonResponse:
return self.errors_handler.handle(response_result.errors, req, resp)

def _get_openapi_request(
self, request: HttpRequest
Expand Down
39 changes: 29 additions & 10 deletions openapi_core/contrib/falcon/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from openapi_core.templating.paths.exceptions import PathNotFound
from openapi_core.templating.paths.exceptions import ServerNotFound
from openapi_core.templating.security.exceptions import SecurityNotFound
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult


class FalconOpenAPIErrorsHandler:
Expand All @@ -26,24 +27,31 @@ class FalconOpenAPIErrorsHandler:
MediaTypeNotFound: 415,
}

@classmethod
def handle(
cls, req: Request, resp: Response, errors: Iterable[Exception]
) -> None:
data_errors = [cls.format_openapi_error(err) for err in errors]
def __init__(self, req: Request, resp: Response):
self.req = req
self.resp = resp

def __call__(self, errors: Iterable[Exception]) -> Response:
data_errors = [self.format_openapi_error(err) for err in errors]
data = {
"errors": data_errors,
}
data_str = dumps(data)
data_error_max = max(data_errors, key=cls.get_error_status)
resp.content_type = MEDIA_JSON
resp.status = getattr(
data_error_max = max(data_errors, key=self.get_error_status)
self.resp.content_type = MEDIA_JSON
self.resp.status = getattr(
status_codes,
f"HTTP_{data_error_max['status']}",
status_codes.HTTP_400,
)
resp.text = data_str
resp.complete = True
self.resp.text = data_str
self.resp.complete = True
return self.resp

@classmethod
def handle(cls, req: Request, resp: Response, errors: Iterable[Exception]) -> Response:
instance = cls(req, resp)
return instance(errors)

@classmethod
def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
Expand All @@ -58,3 +66,14 @@ def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
@classmethod
def get_error_status(cls, error: Dict[str, Any]) -> int:
return int(error["status"])


class FalconOpenAPIValidRequestHandler:

def __init__(self, req: Request, resp: Response):
self.req = req
self.resp = resp

def __call__(self, request_unmarshal_result: RequestUnmarshalResult) -> Response:
self.req.context.openapi = request_unmarshal_result
return self.resp
85 changes: 31 additions & 54 deletions openapi_core/contrib/falcon/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,33 @@
from falcon.request import Request
from falcon.response import Response

from openapi_core.contrib.falcon.handlers import FalconOpenAPIErrorsHandler
from openapi_core.contrib.falcon.handlers import (
FalconOpenAPIErrorsHandler,
FalconOpenAPIValidRequestHandler,
)
from openapi_core.contrib.falcon.requests import FalconOpenAPIRequest
from openapi_core.contrib.falcon.responses import FalconOpenAPIResponse
from openapi_core.protocols import ErrorsHandler
from openapi_core.spec import Spec
from openapi_core.unmarshalling.processors import UnmarshallingProcessor
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult
from openapi_core.unmarshalling.request.types import RequestUnmarshallerType
from openapi_core.unmarshalling.response.datatypes import (
ResponseUnmarshalResult,
)
from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType


class FalconOpenAPIMiddleware(UnmarshallingProcessor):
request_class = FalconOpenAPIRequest
response_class = FalconOpenAPIResponse
errors_handler = FalconOpenAPIErrorsHandler()
class FalconOpenAPIMiddleware(UnmarshallingProcessor[Request, Response]):
request_cls = FalconOpenAPIRequest
response_cls = FalconOpenAPIResponse
valid_request_handler_cls = FalconOpenAPIValidRequestHandler
errors_handler_cls: Type[FalconOpenAPIErrorsHandler] = FalconOpenAPIErrorsHandler

def __init__(
self,
spec: Spec,
request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None,
response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None,
request_class: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest,
response_class: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse,
errors_handler: Optional[FalconOpenAPIErrorsHandler] = None,
request_cls: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest,
response_cls: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse,
errors_handler_cls: Type[FalconOpenAPIErrorsHandler] = FalconOpenAPIErrorsHandler,
**unmarshaller_kwargs: Any,
):
super().__init__(
Expand All @@ -40,70 +41,46 @@ def __init__(
response_unmarshaller_cls=response_unmarshaller_cls,
**unmarshaller_kwargs,
)
self.request_class = request_class or self.request_class
self.response_class = response_class or self.response_class
self.errors_handler = errors_handler or self.errors_handler
self.request_cls = request_cls or self.request_cls
self.response_cls = response_cls or self.response_cls
self.errors_handler_cls = errors_handler_cls or self.errors_handler_cls

@classmethod
def from_spec(
cls,
spec: Spec,
request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None,
response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None,
request_class: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest,
response_class: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse,
errors_handler: Optional[FalconOpenAPIErrorsHandler] = None,
request_cls: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest,
response_cls: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse,
errors_handler_cls: Type[FalconOpenAPIErrorsHandler] = FalconOpenAPIErrorsHandler,
**unmarshaller_kwargs: Any,
) -> "FalconOpenAPIMiddleware":
return cls(
spec,
request_unmarshaller_cls=request_unmarshaller_cls,
response_unmarshaller_cls=response_unmarshaller_cls,
request_class=request_class,
response_class=response_class,
errors_handler=errors_handler,
request_cls=request_cls,
response_cls=response_cls,
errors_handler_cls=errors_handler_cls,
**unmarshaller_kwargs,
)

def process_request(self, req: Request, resp: Response) -> None: # type: ignore
openapi_req = self._get_openapi_request(req)
req.context.openapi = super().process_request(openapi_req)
if req.context.openapi.errors:
return self._handle_request_errors(req, resp, req.context.openapi)
def process_request(self, req: Request, resp: Response) -> None:
valid_handler = self.valid_request_handler_cls(req, resp)
errors_handler = self.errors_handler_cls(req, resp)
self.handle_request(req, valid_handler, errors_handler)

def process_response( # type: ignore
def process_response(
self, req: Request, resp: Response, resource: Any, req_succeeded: bool
) -> None:
openapi_req = self._get_openapi_request(req)
openapi_resp = self._get_openapi_response(resp)
resp.context.openapi = super().process_response(
openapi_req, openapi_resp
)
if resp.context.openapi.errors:
return self._handle_response_errors(
req, resp, resp.context.openapi
)

def _handle_request_errors(
self,
req: Request,
resp: Response,
request_result: RequestUnmarshalResult,
) -> None:
return self.errors_handler.handle(req, resp, request_result.errors)

def _handle_response_errors(
self,
req: Request,
resp: Response,
response_result: ResponseUnmarshalResult,
) -> None:
return self.errors_handler.handle(req, resp, response_result.errors)
errors_handler = self.errors_handler_cls(req, resp)
self.handle_response(req, resp, errors_handler)

def _get_openapi_request(self, request: Request) -> FalconOpenAPIRequest:
return self.request_class(request)
return self.request_cls(request)

def _get_openapi_response(
self, response: Response
) -> FalconOpenAPIResponse:
return self.response_class(response)
return self.response_cls(response)
Loading

0 comments on commit c1f90b6

Please sign in to comment.