Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract JSON request body validation to middleware #1588

Merged
merged 6 commits into from
Sep 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 3 additions & 33 deletions connexion/decorators/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,11 @@
from jsonschema.validators import extend
from werkzeug.datastructures import FileStorage

from ..exceptions import (
BadRequestProblem,
ExtraParameterProblem,
UnsupportedMediaTypeProblem,
)
from ..exceptions import BadRequestProblem, ExtraParameterProblem
from ..http_facts import FORM_CONTENT_TYPES
from ..json_schema import Draft4RequestValidator, Draft4ResponseValidator
from ..lifecycle import ConnexionResponse
from ..utils import all_json, boolean, is_json_mimetype, is_null, is_nullable
from ..utils import boolean, is_null, is_nullable

logger = logging.getLogger("connexion.decorators.validation")

Expand Down Expand Up @@ -141,33 +137,7 @@ def __call__(self, function):

@functools.wraps(function)
def wrapper(request):
if all_json(self.consumes):
data = request.json

empty_body = not (request.body or request.form or request.files)
if data is None and not empty_body and not self.is_null_value_valid:
try:
ctype_is_json = is_json_mimetype(
request.headers.get("Content-Type", "")
)
except ValueError:
ctype_is_json = False

if ctype_is_json:
# Content-Type is json but actual body was not parsed
raise BadRequestProblem(detail="Request body is not valid JSON")
else:
# the body has contents that were not parsed as JSON
raise UnsupportedMediaTypeProblem(
detail="Invalid Content-type ({content_type}), expected JSON data".format(
content_type=request.headers.get("Content-Type", "")
)
)

logger.debug("%s validating schema...", request.url)
if data is not None or not self.has_default:
self.validate_schema(data, request.url)
elif self.consumes[0] in FORM_CONTENT_TYPES:
if self.consumes[0] in FORM_CONTENT_TYPES:
data = dict(request.form.items()) or (
request.body if len(request.body) > 0 else {}
)
Expand Down
2 changes: 2 additions & 0 deletions connexion/middleware/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from connexion.middleware.routing import RoutingMiddleware
from connexion.middleware.security import SecurityMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
from connexion.middleware.validation import ValidationMiddleware


class ConnexionMiddleware:
Expand All @@ -17,6 +18,7 @@ class ConnexionMiddleware:
SwaggerUIMiddleware,
RoutingMiddleware,
SecurityMiddleware,
ValidationMiddleware,
]

def __init__(
Expand Down
6 changes: 1 addition & 5 deletions connexion/middleware/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from starlette.types import ASGIApp, Receive, Scope, Send

from connexion.apis import AbstractRoutingAPI
from connexion.exceptions import NotFoundProblem
from connexion.middleware import AppMiddleware
from connexion.operations import AbstractOperation
from connexion.resolver import Resolver
Expand Down Expand Up @@ -61,10 +60,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

# Needs to be set so starlette router throws exceptions instead of returning error responses
scope["app"] = self
try:
await self.router(scope, receive, send)
except ValueError:
raise NotFoundProblem
await self.router(scope, receive, send)


class RoutingAPI(AbstractRoutingAPI):
Expand Down
1 change: 1 addition & 0 deletions connexion/middleware/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def from_operation(
operation: t.Union[AbstractOperation, Specification],
security_handler_factory: SecurityHandlerFactory,
):
# TODO: Turn Operation class into OperationSpec and use as init argument instead
return cls(
security_handler_factory,
security=operation.security,
Expand Down
232 changes: 232 additions & 0 deletions connexion/middleware/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
"""
Validation Middleware.
"""
import logging
import pathlib
import typing as t

from starlette.types import ASGIApp, Receive, Scope, Send

from connexion.apis.abstract import AbstractSpecAPI
from connexion.decorators.uri_parsing import AbstractURIParser
from connexion.exceptions import MissingMiddleware, UnsupportedMediaTypeProblem
from connexion.http_facts import METHODS
from connexion.middleware import AppMiddleware
from connexion.middleware.routing import ROUTING_CONTEXT
from connexion.operations import AbstractOperation
from connexion.resolver import ResolverError
from connexion.utils import is_nullable
from connexion.validators import JSONBodyValidator

from ..decorators.response import ResponseValidator
from ..decorators.validation import ParameterValidator

logger = logging.getLogger("connexion.middleware.validation")

VALIDATOR_MAP = {
"parameter": ParameterValidator,
"body": {"application/json": JSONBodyValidator},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have been struggling here with the distinction of parameter and body as there currently isn't a clear split: form parameters are validated by the ParameterValidator but also by the RequestBodyValidator. The request body also contains parameters, at least in Swagger 2 (link). In OpenAPI there is a clearer distinction between "parameters" and the "body" (link).

Seeing this, I think a way forward could be to move all form-data validation into a separate body validator as well. This also simplifies things with respect to accessing only the ASGI scope vs. also needing the send and receive callables. Wdyt?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think that makes sense. Just let me rephrase to make sure I understand:

  • Move form validation into a separate body validator so it has the full ASGI interface (scope, send, receive) available and can decide if / how to consume the stream
  • Parameter validators will only have the scope available, possibly via a Request object which does not provide access to the stream.

Correct?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, indeed. This is to limit the number of places in which the body stream can be accessed.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👋 does this code mean that the JSONBodyValidator is string-literally matched? I ask because I was thinking about a case where an American business is insisting that a partner use their mm/dd/yyyy format for dates. They already use Connexion, and I'd sooner specify a separate accept & content-type for their weird expected output, and have everything else use nice, readable, standards-compliant ISO 8601 as per https://datatracker.ietf.org/doc/html/rfc3339#section-5.6 and https://swagger.io/docs/specification/data-models/data-types/#string.

"response": ResponseValidator,
}


class ValidationMiddleware(AppMiddleware):
    """ASGI middleware that dispatches HTTP requests to per-operation validators.

    One ``ValidationAPI`` is kept per registered specification, keyed by its
    base path. The routing information stored in the ASGI scope under
    ``ROUTING_CONTEXT`` selects which API and operation handle a request.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Store the wrapped application and start with no registered APIs.

        :param app: The next ASGI application in the stack.
        """
        self.app = app
        self.apis: t.Dict[str, ValidationAPI] = {}

    def add_api(
        self, specification: t.Union[pathlib.Path, str, dict], **kwargs
    ) -> None:
        """Build a ``ValidationAPI`` for a specification and register it.

        :param specification: The specification, forwarded to ``ValidationAPI``.
        :param kwargs: Extra keyword arguments forwarded to ``ValidationAPI``.
        """
        validation_api = ValidationAPI(specification, next_app=self.app, **kwargs)
        self.apis[validation_api.base_path] = validation_api

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Validate an HTTP request, or pass the call through untouched.

        Non-HTTP scopes, requests without an ``api_base_path`` in the routing
        context and operations without an id go straight to the wrapped app.

        :raises MissingMiddleware: If the scope holds no routing context.
        :raises MissingValidationOperation: If the routed operation id is not
            known to the selected API.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        try:
            routing_info = scope["extensions"][ROUTING_CONTEXT]
        except KeyError:
            raise MissingMiddleware(
                "Could not find routing information in scope. Please make sure "
                "you have a routing middleware registered upstream. "
            )

        base_path = routing_info.get("api_base_path")
        if not base_path:
            # Request was not routed to a registered API: nothing to validate.
            await self.app(scope, receive, send)
            return

        api = self.apis[base_path]
        operation_id = routing_info.get("operation_id")
        try:
            operation = api.operations[operation_id]
        except KeyError as err:
            if operation_id is not None:
                raise MissingValidationOperation(
                    "Encountered unknown operation_id."
                ) from err
            logger.debug("Skipping validation check for operation without id.")
            await self.app(scope, receive, send)
            return

        return await operation(scope, receive, send)


class ValidationAPI(AbstractSpecAPI):
    """Per-specification registry of ``ValidationOperation`` objects.

    On construction, every path/method pair of the specification whose method
    is listed in ``METHODS`` is turned into a ``ValidationOperation`` and
    stored in ``operations`` under its operation id.
    """

    def __init__(
        self,
        specification: t.Union[pathlib.Path, str, dict],
        *args,
        next_app: ASGIApp,
        validate_responses=False,
        strict_validation=False,
        validator_map=None,
        uri_parser_class=None,
        **kwargs,
    ):
        """Store the validation settings and register all operations.

        :param specification: The specification, forwarded to the base class.
        :param next_app: ASGI application handed to every created operation.
        :param validate_responses: Stored and forwarded to every operation.
        :param strict_validation: Stored and forwarded to every operation.
        :param validator_map: Validator overrides forwarded to every operation.
        :param uri_parser_class: URI parser class forwarded to every operation.
        """
        super().__init__(specification, *args, **kwargs)

        self.next_app = next_app
        self.validator_map = validator_map
        self.validate_responses = validate_responses
        self.strict_validation = strict_validation
        self.uri_parser_class = uri_parser_class

        logger.debug("Validate Responses: %s", str(validate_responses))
        logger.debug("Strict Request Validation: %s", str(strict_validation))

        # Must exist before add_paths() starts filling it.
        self.operations: t.Dict[str, ValidationOperation] = {}
        self.add_paths()

    def add_paths(self):
        """Register an operation for every supported method of every path."""
        for path, methods in self.specification.get("paths", {}).items():
            # Keys that are not HTTP methods (not in METHODS) are skipped.
            for method in (name for name in methods if name in METHODS):
                try:
                    self.add_operation(path, method)
                except ResolverError:
                    # ResolverErrors are either raised or handled in routing middleware.
                    pass

    def add_operation(self, path: str, method: str) -> None:
        """Create the validation operation for ``path``/``method`` and store it.

        :param path: Path key from the specification.
        :param method: Method key under that path.
        """
        spec_operation = self.specification.operation_cls.from_spec(
            self.specification, self, path, method, self.resolver
        )
        self._add_operation_internal(
            spec_operation.operation_id, self.make_operation(spec_operation)
        )

    def make_operation(self, operation: AbstractOperation):
        """Wrap ``operation`` in a ``ValidationOperation`` using this API's settings."""
        settings = {
            "validate_responses": self.validate_responses,
            "strict_validation": self.strict_validation,
            "validator_map": self.validator_map,
            "uri_parser_class": self.uri_parser_class,
        }
        return ValidationOperation(operation, self.next_app, **settings)

    def _add_operation_internal(
        self, operation_id: str, operation: "ValidationOperation"
    ):
        """Store ``operation`` under ``operation_id``."""
        self.operations[operation_id] = operation


class ValidationOperation:
    """ASGI callable that validates the request of a single operation.

    The content type of the request is checked against the media types the
    operation consumes. If a body validator is registered for that media type,
    the call is handed to it; otherwise it is forwarded to ``next_app``.
    """

    def __init__(
        self,
        operation: AbstractOperation,
        next_app: ASGIApp,
        validate_responses: bool = False,
        strict_validation: bool = False,
        validator_map: t.Optional[dict] = None,
        uri_parser_class: t.Optional[AbstractURIParser] = None,
    ) -> None:
        """Store the operation, the downstream app and the validation settings.

        :param operation: Operation whose ``consumes``, ``body_schema`` and
            ``body_definition`` drive the validation.
        :param next_app: ASGI application called after (or instead of) validation.
        :param validate_responses: Stored on the instance; not used in this class yet.
        :param strict_validation: Stored on the instance; not used in this class yet.
        :param validator_map: Entries that override the defaults in
            ``VALIDATOR_MAP`` for this operation only.
        :param uri_parser_class: Stored on the instance; not used in this class yet.
        """
        self._operation = operation
        self.next_app = next_app
        self.validate_responses = validate_responses
        self.strict_validation = strict_validation
        # Work on a copy: updating VALIDATOR_MAP in place would leak this
        # operation's overrides into every other operation and into the
        # module-level defaults.
        self._validator_map = dict(VALIDATOR_MAP)
        self._validator_map.update(validator_map or {})
        self.uri_parser_class = uri_parser_class

    def extract_content_type(
        self, headers: t.Iterable[t.Tuple[bytes, bytes]]
    ) -> t.Tuple[str, str]:
        """Extract the mime type and encoding from the content type headers.

        Only the first ``Content-Type`` header is considered. Whitespace around
        the mime type and its parameters is ignored, the ``charset`` parameter
        name is matched case-insensitively and a quoted charset value is
        unquoted. The encoding defaults to ``utf-8`` when no non-empty charset
        is given.

        :param headers: Headers from the ASGI scope, as ``(name, value)`` byte pairs

        :return: A tuple of mime type, encoding
        """
        encoding = "utf-8"
        for key, value in headers:
            # Headers can always be decoded using latin-1:
            # https://stackoverflow.com/a/27357138/4098821
            if key.decode("latin-1").lower() != "content-type":
                continue

            content_type = value.decode("latin-1")
            mime_type, _, parameters = content_type.partition(";")
            mime_type = mime_type.strip()

            for parameter in parameters.split(";"):
                name, _, raw_value = parameter.partition("=")
                # Strip so that the usual "; charset=utf-8" form (with a
                # space after the semicolon) is recognised.
                charset = raw_value.strip().strip('"')
                if name.strip().lower() == "charset" and charset:
                    encoding = charset
            break
        else:
            # Content-type header is not required. Take a best guess.
            mime_type = self._operation.consumes[0]

        return mime_type, encoding

    def validate_mime_type(self, mime_type: str) -> None:
        """Validate the mime type against the spec.

        The comparison with the operation's ``consumes`` list is
        case-insensitive.

        :param mime_type: mime type from content type header

        :raises UnsupportedMediaTypeProblem: If the operation does not consume
            ``mime_type``.
        """
        if mime_type.lower() not in [c.lower() for c in self._operation.consumes]:
            raise UnsupportedMediaTypeProblem(
                detail=f"Invalid Content-type ({mime_type}), "
                f"expected {self._operation.consumes}"
            )

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        """Validate the request and hand it on.

        :return: The result of the body validator when one is registered for
            the request's mime type, otherwise ``None`` after calling
            ``next_app``.
        :raises UnsupportedMediaTypeProblem: If the mime type is not consumed
            by the operation.
        """
        headers = scope["headers"]
        mime_type, encoding = self.extract_content_type(headers)
        self.validate_mime_type(mime_type)

        # TODO: Validate parameters

        # Validate body
        try:
            body_validators = self._validator_map["body"]
            # validate_mime_type() accepts any casing, so the lookup must too;
            # otherwise e.g. "Application/JSON" would silently skip validation.
            # An exact match wins so mixed-case registrations keep working.
            if mime_type in body_validators:
                body_validator = body_validators[mime_type]  # type: ignore
            else:
                body_validator = body_validators[mime_type.lower()]  # type: ignore
        except KeyError:
            logger.info(
                "Skipping validation. No validator registered for content type: "
                "%s.",
                mime_type,
            )
        else:
            # The validator receives next_app and is then called as the ASGI
            # app itself; presumably it forwards the request once the body
            # has been validated -- confirm against the validator class.
            validator = body_validator(
                self.next_app,
                schema=self._operation.body_schema,
                nullable=is_nullable(self._operation.body_definition),
                encoding=encoding,
            )
            return await validator(scope, receive, send)

        await self.next_app(scope, receive, send)


class MissingValidationOperation(Exception):
    """Raised when a routed operation id has no registered validation operation."""
2 changes: 1 addition & 1 deletion connexion/operations/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,12 +465,12 @@ def __validation_decorators(self):
:rtype: types.FunctionType
"""
ParameterValidator = self.validator_map["parameter"]
RequestBodyValidator = self.validator_map["body"]
if self.parameters:
yield ParameterValidator(
self.parameters, self.api, strict_validation=self.strict_validation
)
if self.body_schema:
# TODO: temporarily hardcoded, remove RequestBodyValidator completely
yield RequestBodyValidator(
self.body_schema,
self.consumes,
Expand Down
Loading