Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move JSON response body validation to middleware #1591

Merged
merged 10 commits into from
Oct 3, 2022
135 changes: 0 additions & 135 deletions connexion/decorators/response.py

This file was deleted.

25 changes: 1 addition & 24 deletions connexion/decorators/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ..exceptions import BadRequestProblem, ExtraParameterProblem
from ..http_facts import FORM_CONTENT_TYPES
from ..json_schema import Draft4RequestValidator, Draft4ResponseValidator
from ..json_schema import Draft4RequestValidator
from ..lifecycle import ConnexionResponse
from ..utils import boolean, is_null, is_nullable

Expand Down Expand Up @@ -196,29 +196,6 @@ def validate_schema(self, data: dict, url: str) -> t.Optional[ConnexionResponse]
return None


class ResponseBodyValidator:
def __init__(self, schema, validator=None):
"""
:param schema: The schema of the response body
:param validator: Validator class that should be used to validate passed data
against API schema. Default is Draft4ResponseValidator.
:type validator: jsonschema.IValidator
"""
ValidatorClass = validator or Draft4ResponseValidator
self.validator = ValidatorClass(schema, format_checker=draft4_format_checker)

def validate_schema(self, data: dict, url: str) -> t.Optional[ConnexionResponse]:
try:
self.validator.validate(data)
except ValidationError as exception:
logger.error(
f"{url} validation error: {exception}", extra={"validator": "response"}
)
raise exception

return None


class ParameterValidator:
def __init__(self, parameters, api, strict_validation=False):
"""
Expand Down
6 changes: 4 additions & 2 deletions connexion/middleware/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

from connexion.middleware.abstract import AppMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.request_validation import RequestValidationMiddleware
from connexion.middleware.response_validation import ResponseValidationMiddleware
from connexion.middleware.routing import RoutingMiddleware
from connexion.middleware.security import SecurityMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
from connexion.middleware.validation import ValidationMiddleware


class ConnexionMiddleware:
Expand All @@ -18,7 +19,8 @@ class ConnexionMiddleware:
SwaggerUIMiddleware,
RoutingMiddleware,
SecurityMiddleware,
ValidationMiddleware,
RequestValidationMiddleware,
ResponseValidationMiddleware,
]

def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,71 +6,51 @@

from starlette.types import ASGIApp, Receive, Scope, Send

from connexion import utils
from connexion.decorators.uri_parsing import AbstractURIParser
from connexion.exceptions import UnsupportedMediaTypeProblem
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.operations import AbstractOperation
from connexion.utils import is_nullable
from connexion.validators import JSONBodyValidator

from ..decorators.response import ResponseValidator
from ..decorators.validation import ParameterValidator
from connexion.validators import VALIDATOR_MAP

logger = logging.getLogger("connexion.middleware.validation")

VALIDATOR_MAP = {
"parameter": ParameterValidator,
"body": {"application/json": JSONBodyValidator},
"response": ResponseValidator,
}


class ValidationOperation:
class RequestValidationOperation:
def __init__(
self,
next_app: ASGIApp,
*,
operation: AbstractOperation,
validate_responses: bool = False,
strict_validation: bool = False,
validator_map: t.Optional[dict] = None,
uri_parser_class: t.Optional[AbstractURIParser] = None,
) -> None:
self.next_app = next_app
self._operation = operation
self.validate_responses = validate_responses
self.strict_validation = strict_validation
self._validator_map = VALIDATOR_MAP
self._validator_map.update(validator_map or {})
self.uri_parser_class = uri_parser_class

def extract_content_type(self, headers: dict) -> t.Tuple[str, str]:
def extract_content_type(
self, headers: t.List[t.Tuple[bytes, bytes]]
) -> t.Tuple[str, str]:
"""Extract the mime type and encoding from the content type headers.

:param headers: Header dict from ASGI scope
:param headers: Headers from ASGI scope

:return: A tuple of mime type, encoding
"""
encoding = "utf-8"
for key, value in headers:
# Headers can always be decoded using latin-1:
# https://stackoverflow.com/a/27357138/4098821
key = key.decode("latin-1")
if key.lower() == "content-type":
content_type = value.decode("latin-1")
if ";" in content_type:
mime_type, parameters = content_type.split(";", maxsplit=1)

prefix = "charset="
for parameter in parameters.split(";"):
if parameter.startswith(prefix):
encoding = parameter[len(prefix) :]
else:
mime_type = content_type
break
else:
mime_type, encoding = utils.extract_content_type(headers)
if mime_type is None:
# Content-type header is not required. Take a best guess.
mime_type = self._operation.consumes[0]
try:
mime_type = self._operation.consumes[0]
except IndexError:
mime_type = "application/octet-stream"
if encoding is None:
encoding = "utf-8"

return mime_type, encoding

Expand All @@ -86,6 +66,8 @@ def validate_mime_type(self, mime_type: str) -> None:
)

async def __call__(self, scope: Scope, receive: Receive, send: Send):
receive_fn = receive

headers = scope["headers"]
mime_type, encoding = self.extract_content_type(headers)
self.validate_mime_type(mime_type)
Expand All @@ -102,25 +84,25 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
)
else:
validator = body_validator(
self.next_app,
scope,
receive,
schema=self._operation.body_schema,
nullable=is_nullable(self._operation.body_definition),
nullable=utils.is_nullable(self._operation.body_definition),
encoding=encoding,
)
return await validator(scope, receive, send)
receive_fn = validator.receive

await self.next_app(scope, receive, send)
await self.next_app(scope, receive_fn, send)


class ValidationAPI(RoutedAPI[ValidationOperation]):
class RequestValidationAPI(RoutedAPI[RequestValidationOperation]):
"""Validation API."""

operation_cls = ValidationOperation
operation_cls = RequestValidationOperation

def __init__(
self,
*args,
validate_responses=False,
strict_validation=False,
validator_map=None,
uri_parser_class=None,
Expand All @@ -129,31 +111,29 @@ def __init__(
super().__init__(*args, **kwargs)
self.validator_map = validator_map

logger.debug("Validate Responses: %s", str(validate_responses))
self.validate_responses = validate_responses

logger.debug("Strict Request Validation: %s", str(strict_validation))
self.strict_validation = strict_validation

self.uri_parser_class = uri_parser_class

self.add_paths()

def make_operation(self, operation: AbstractOperation) -> ValidationOperation:
return ValidationOperation(
def make_operation(
self, operation: AbstractOperation
) -> RequestValidationOperation:
return RequestValidationOperation(
self.next_app,
operation=operation,
validate_responses=self.validate_responses,
strict_validation=self.strict_validation,
validator_map=self.validator_map,
uri_parser_class=self.uri_parser_class,
)


class ValidationMiddleware(RoutedMiddleware[ValidationAPI]):
class RequestValidationMiddleware(RoutedMiddleware[RequestValidationAPI]):
"""Middleware for validating requests according to the API contract."""

api_cls = ValidationAPI
api_cls = RequestValidationAPI


class MissingValidationOperation(Exception):
Expand Down
Loading