Skip to content

Commit

Permalink
Extract JSON request body validation to middleware (#1588)
Browse files Browse the repository at this point in the history
* Set up code skeleton for validation middleware

* Add more boilerplate code

* WIP

* Add ASGI JSONBodyValidator

* Revert example changes

* Remove incorrect content type test

Co-authored-by: Ruwan <ruwanlambrichts@gmail.com>
  • Loading branch information
RobbeSneyders and Ruwann authored Sep 18, 2022
1 parent e4b7827 commit fb071ea
Show file tree
Hide file tree
Showing 14 changed files with 2,539 additions and 77 deletions.
36 changes: 3 additions & 33 deletions connexion/decorators/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,11 @@
from jsonschema.validators import extend
from werkzeug.datastructures import FileStorage

from ..exceptions import (
BadRequestProblem,
ExtraParameterProblem,
UnsupportedMediaTypeProblem,
)
from ..exceptions import BadRequestProblem, ExtraParameterProblem
from ..http_facts import FORM_CONTENT_TYPES
from ..json_schema import Draft4RequestValidator, Draft4ResponseValidator
from ..lifecycle import ConnexionResponse
from ..utils import all_json, boolean, is_json_mimetype, is_null, is_nullable
from ..utils import boolean, is_null, is_nullable

logger = logging.getLogger("connexion.decorators.validation")

Expand Down Expand Up @@ -141,33 +137,7 @@ def __call__(self, function):

@functools.wraps(function)
def wrapper(request):
if all_json(self.consumes):
data = request.json

empty_body = not (request.body or request.form or request.files)
if data is None and not empty_body and not self.is_null_value_valid:
try:
ctype_is_json = is_json_mimetype(
request.headers.get("Content-Type", "")
)
except ValueError:
ctype_is_json = False

if ctype_is_json:
# Content-Type is json but actual body was not parsed
raise BadRequestProblem(detail="Request body is not valid JSON")
else:
# the body has contents that were not parsed as JSON
raise UnsupportedMediaTypeProblem(
detail="Invalid Content-type ({content_type}), expected JSON data".format(
content_type=request.headers.get("Content-Type", "")
)
)

logger.debug("%s validating schema...", request.url)
if data is not None or not self.has_default:
self.validate_schema(data, request.url)
elif self.consumes[0] in FORM_CONTENT_TYPES:
if self.consumes[0] in FORM_CONTENT_TYPES:
data = dict(request.form.items()) or (
request.body if len(request.body) > 0 else {}
)
Expand Down
2 changes: 2 additions & 0 deletions connexion/middleware/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from connexion.middleware.routing import RoutingMiddleware
from connexion.middleware.security import SecurityMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
from connexion.middleware.validation import ValidationMiddleware


class ConnexionMiddleware:
Expand All @@ -17,6 +18,7 @@ class ConnexionMiddleware:
SwaggerUIMiddleware,
RoutingMiddleware,
SecurityMiddleware,
ValidationMiddleware,
]

def __init__(
Expand Down
6 changes: 1 addition & 5 deletions connexion/middleware/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from starlette.types import ASGIApp, Receive, Scope, Send

from connexion.apis import AbstractRoutingAPI
from connexion.exceptions import NotFoundProblem
from connexion.middleware import AppMiddleware
from connexion.operations import AbstractOperation
from connexion.resolver import Resolver
Expand Down Expand Up @@ -61,10 +60,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

# Needs to be set so starlette router throws exceptions instead of returning error responses
scope["app"] = self
try:
await self.router(scope, receive, send)
except ValueError:
raise NotFoundProblem
await self.router(scope, receive, send)


class RoutingAPI(AbstractRoutingAPI):
Expand Down
1 change: 1 addition & 0 deletions connexion/middleware/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def from_operation(
operation: t.Union[AbstractOperation, Specification],
security_handler_factory: SecurityHandlerFactory,
):
# TODO: Turn Operation class into OperationSpec and use as init argument instead
return cls(
security_handler_factory,
security=operation.security,
Expand Down
232 changes: 232 additions & 0 deletions connexion/middleware/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
"""
Validation Middleware.
"""
import logging
import pathlib
import typing as t

from starlette.types import ASGIApp, Receive, Scope, Send

from connexion.apis.abstract import AbstractSpecAPI
from connexion.decorators.uri_parsing import AbstractURIParser
from connexion.exceptions import MissingMiddleware, UnsupportedMediaTypeProblem
from connexion.http_facts import METHODS
from connexion.middleware import AppMiddleware
from connexion.middleware.routing import ROUTING_CONTEXT
from connexion.operations import AbstractOperation
from connexion.resolver import ResolverError
from connexion.utils import is_nullable
from connexion.validators import JSONBodyValidator

from ..decorators.response import ResponseValidator
from ..decorators.validation import ParameterValidator

logger = logging.getLogger("connexion.middleware.validation")

# Default validators used by ValidationOperation, keyed by what they validate.
# The "body" entry is itself a mapping from mime type to the body validator
# class that is looked up for a request with that content type.
VALIDATOR_MAP = {
    "parameter": ParameterValidator,
    "body": {"application/json": JSONBodyValidator},
    "response": ResponseValidator,
}


class ValidationMiddleware(AppMiddleware):
    """Middleware for validating requests according to the API contract.

    One ``ValidationAPI`` is kept per registered specification, keyed by its
    base path. For each HTTP request, the routing context placed in the ASGI
    scope by an upstream routing middleware selects the API and operation
    that should validate the request.
    """

    def __init__(self, app: ASGIApp) -> None:
        """
        :param app: The next ASGI application in the middleware stack.
        """
        self.app = app
        self.apis: t.Dict[str, ValidationAPI] = {}

    def add_api(
        self, specification: t.Union[pathlib.Path, str, dict], **kwargs
    ) -> None:
        """Register a specification whose operations should be validated.

        :param specification: The specification, forwarded to ``ValidationAPI``.
        :param kwargs: Extra keyword arguments forwarded to ``ValidationAPI``.
        """
        api = ValidationAPI(specification, next_app=self.app, **kwargs)
        self.apis[api.base_path] = api

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Validate the request when a matching operation is registered.

        Non-HTTP scopes, requests that match no registered API and requests
        for operations without an id are passed to the next app unchanged.

        :raises MissingMiddleware: If no routing context is found in the scope.
        :raises MissingValidationOperation: If the routing context names an
            operation id that is unknown to the matched API.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        try:
            connexion_context = scope["extensions"][ROUTING_CONTEXT]
        except KeyError:
            raise MissingMiddleware(
                "Could not find routing information in scope. Please make sure "
                "you have a routing middleware registered upstream. "
            )

        api_base_path = connexion_context.get("api_base_path")
        # Compare against None instead of relying on truthiness: an empty
        # string is a usable key in ``self.apis`` (presumably the base path of
        # an API served from the root - confirm against AbstractSpecAPI) and
        # must not cause validation to be skipped. The membership test avoids
        # an unhandled KeyError for base paths that were never registered here.
        if api_base_path is not None and api_base_path in self.apis:
            api = self.apis[api_base_path]
            operation_id = connexion_context.get("operation_id")
            try:
                operation = api.operations[operation_id]
            except KeyError as e:
                if operation_id is None:
                    logger.debug("Skipping validation check for operation without id.")
                    await self.app(scope, receive, send)
                    return
                else:
                    raise MissingValidationOperation(
                        "Encountered unknown operation_id."
                    ) from e
            else:
                return await operation(scope, receive, send)

        await self.app(scope, receive, send)


class ValidationAPI(AbstractSpecAPI):
    """API wrapper that holds one ``ValidationOperation`` per spec operation.

    The operations are collected in ``self.operations``, keyed by operation
    id, while the instance is being constructed.
    """

    def __init__(
        self,
        specification: t.Union[pathlib.Path, str, dict],
        *args,
        next_app: ASGIApp,
        validate_responses=False,
        strict_validation=False,
        validator_map=None,
        uri_parser_class=None,
        **kwargs,
    ):
        """
        :param specification: The specification, forwarded to the parent class.
        :param next_app: ASGI app handed to every created operation.
        :param validate_responses: Stored and forwarded to each operation.
        :param strict_validation: Stored and forwarded to each operation.
        :param validator_map: Optional validator overrides, forwarded to each
            operation.
        :param uri_parser_class: Optional URI parser class, forwarded to each
            operation.
        """
        super().__init__(specification, *args, **kwargs)

        self.next_app = next_app
        self.validator_map = validator_map
        self.uri_parser_class = uri_parser_class

        logger.debug("Validate Responses: %s", str(validate_responses))
        self.validate_responses = validate_responses

        logger.debug("Strict Request Validation: %s", str(strict_validation))
        self.strict_validation = strict_validation

        # The registry has to exist before add_paths() starts filling it.
        self.operations: t.Dict[str, ValidationOperation] = {}
        self.add_paths()

    def add_paths(self):
        """Register an operation for every HTTP method of every spec path."""
        for path, methods in self.specification.get("paths", {}).items():
            for method in methods:
                # Path items may carry keys that are not HTTP methods.
                if method in METHODS:
                    try:
                        self.add_operation(path, method)
                    except ResolverError:
                        # ResolverErrors are either raised or handled in routing middleware.
                        pass

    def add_operation(self, path: str, method: str) -> None:
        """Build the operation for ``path``/``method`` and register it."""
        spec = self.specification
        operation = spec.operation_cls.from_spec(
            spec, self, path, method, self.resolver
        )
        wrapped = self.make_operation(operation)
        self._add_operation_internal(operation.operation_id, wrapped)

    def make_operation(self, operation: AbstractOperation):
        """Wrap ``operation`` in a ``ValidationOperation`` using this API's settings."""
        settings = dict(
            validate_responses=self.validate_responses,
            strict_validation=self.strict_validation,
            validator_map=self.validator_map,
            uri_parser_class=self.uri_parser_class,
        )
        return ValidationOperation(operation, self.next_app, **settings)

    def _add_operation_internal(
        self, operation_id: str, operation: "ValidationOperation"
    ):
        """Store ``operation`` in the registry under ``operation_id``."""
        self.operations[operation_id] = operation


class ValidationOperation:
    """ASGI callable that validates a request for one operation.

    The request's mime type is checked against the operation's ``consumes``
    list. If a body validator is registered for that mime type, the request
    is handed to it; otherwise it goes straight to the next app.
    """

    def __init__(
        self,
        operation: AbstractOperation,
        next_app: ASGIApp,
        validate_responses: bool = False,
        strict_validation: bool = False,
        validator_map: t.Optional[dict] = None,
        uri_parser_class: t.Optional[AbstractURIParser] = None,
    ) -> None:
        """
        :param operation: The operation whose contract is enforced.
        :param next_app: ASGI app called after (or instead of) validation.
        :param validate_responses: Stored on the instance.
        :param strict_validation: Stored on the instance.
        :param validator_map: Optional overrides merged on top of
            ``VALIDATOR_MAP``.
        :param uri_parser_class: Stored on the instance.
        """
        self._operation = operation
        self.next_app = next_app
        self.validate_responses = validate_responses
        self.strict_validation = strict_validation
        # Work on a copy: updating VALIDATOR_MAP in place would leak one
        # operation's custom validators into every other operation.
        self._validator_map = dict(VALIDATOR_MAP)
        self._validator_map.update(validator_map or {})
        self.uri_parser_class = uri_parser_class

    def extract_content_type(
        self, headers: t.Iterable[t.Tuple[bytes, bytes]]
    ) -> t.Tuple[str, str]:
        """Extract the mime type and encoding from the content type headers.

        Whitespace around the mime type and its parameters is ignored, the
        ``charset`` parameter name is matched case-insensitively and
        surrounding double quotes are removed from its value.

        :param headers: Header (name, value) byte pairs from the ASGI scope
        :return: A tuple of mime type, encoding
        """
        encoding = "utf-8"
        for key, value in headers:
            # Headers can always be decoded using latin-1:
            # https://stackoverflow.com/a/27357138/4098821
            if key.decode("latin-1").lower() != "content-type":
                continue

            mime_type, _, parameters = value.decode("latin-1").partition(";")
            mime_type = mime_type.strip()
            for parameter in parameters.split(";"):
                name, _, raw_value = parameter.partition("=")
                charset = raw_value.strip().strip('"')
                # Parameters usually follow "; " so they must be stripped
                # before comparing, otherwise the charset is never found.
                if name.strip().lower() == "charset" and charset:
                    encoding = charset
            break
        else:
            # Content-type header is not required. Take a best guess.
            mime_type = self._operation.consumes[0]

        return mime_type, encoding

    def validate_mime_type(self, mime_type: str) -> None:
        """Validate the mime type against the spec.

        The comparison is case-insensitive.

        :param mime_type: mime type from content type header
        :raises UnsupportedMediaTypeProblem: If the operation does not list
            the mime type in ``consumes``.
        """
        if mime_type.lower() not in [c.lower() for c in self._operation.consumes]:
            raise UnsupportedMediaTypeProblem(
                detail=f"Invalid Content-type ({mime_type}), "
                f"expected {self._operation.consumes}"
            )

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        """Validate the request and pass it on.

        :raises UnsupportedMediaTypeProblem: If the request's mime type is not
            accepted by the operation.
        """
        headers = scope["headers"]
        mime_type, encoding = self.extract_content_type(headers)
        self.validate_mime_type(mime_type)

        # TODO: Validate parameters

        # Validate body
        try:
            body_validators = self._validator_map["body"]  # type: ignore
            try:
                body_validator = body_validators[mime_type]
            except KeyError:
                # validate_mime_type() accepts any casing, so retry with the
                # lowercase form; otherwise e.g. "Application/JSON" would
                # silently bypass body validation.
                body_validator = body_validators[mime_type.lower()]
        except KeyError:
            logger.info(
                "Skipping validation. No validator registered for content type: "
                "%s.",
                mime_type,
            )
        else:
            validator = body_validator(
                self.next_app,
                schema=self._operation.body_schema,
                nullable=is_nullable(self._operation.body_definition),
                encoding=encoding,
            )
            return await validator(scope, receive, send)

        await self.next_app(scope, receive, send)


class MissingValidationOperation(Exception):
    """Raised when an operation id has no registered validation operation."""
2 changes: 1 addition & 1 deletion connexion/operations/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,12 +465,12 @@ def __validation_decorators(self):
:rtype: types.FunctionType
"""
ParameterValidator = self.validator_map["parameter"]
RequestBodyValidator = self.validator_map["body"]
if self.parameters:
yield ParameterValidator(
self.parameters, self.api, strict_validation=self.strict_validation
)
if self.body_schema:
# TODO: temporarily hardcoded, remove RequestBodyValidator completely
yield RequestBodyValidator(
self.body_schema,
self.consumes,
Expand Down
Loading

0 comments on commit fb071ea

Please sign in to comment.