Skip to content

Commit

Permalink
Create AbstractRequestBodyValidator class
Browse files Browse the repository at this point in the history
  • Loading branch information
RobbeSneyders committed Feb 27, 2023
1 parent 969c146 commit 3a2b000
Show file tree
Hide file tree
Showing 11 changed files with 283 additions and 220 deletions.
7 changes: 7 additions & 0 deletions connexion/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ def _do_resolve(node):
return res


def format_error_with_path(exception: ValidationError) -> str:
"""Format a `ValidationError` with path to error."""
error_path = ".".join(str(item) for item in exception.path)
error_path_msg = f" - '{error_path}'" if error_path else ""
return error_path_msg


def allow_nullable(validation_fn: t.Callable) -> t.Callable:
"""Extend an existing validation function, so it allows nullable values to be null."""

Expand Down
2 changes: 1 addition & 1 deletion connexion/middleware/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,5 @@ def common_error_handler(_request: StarletteRequest, exc: Exception) -> Response

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# Needs to be set so starlette router throws exceptions instead of returning error responses
scope["app"] = self
scope["app"] = "connexion"
await super().__call__(scope, receive, send)
9 changes: 4 additions & 5 deletions connexion/middleware/request_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def validate_mime_type(self, mime_type: str) -> None:
)

async def __call__(self, scope: Scope, receive: Receive, send: Send):
receive_fn = receive
next_app = self.next_app

# Validate parameters & headers
uri_parser_class = self._operation._uri_parser_class
Expand Down Expand Up @@ -100,8 +100,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
)
else:
validator = body_validator(
scope,
receive,
self.next_app,
schema=schema,
required=self._operation.request_body.get("required", False),
nullable=utils.is_nullable(
Expand All @@ -113,9 +112,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
self._operation.parameters, self._operation.body_definition()
),
)
receive_fn = await validator.wrapped_receive()
next_app = validator

await self.next_app(scope, receive_fn, send)
await next_app(scope, receive, send)


class RequestValidationAPI(RoutedAPI[RequestValidationOperation]):
Expand Down
5 changes: 3 additions & 2 deletions connexion/operations/swagger2.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def _transform_form(self, form_parameters: t.List[dict]) -> dict:

default = param.get("default")
if default is not None:
prop["default"] = default
defaults[param["name"]] = default

nullable = param.get("x-nullable")
Expand Down Expand Up @@ -320,11 +321,11 @@ def _transform_form(self, form_parameters: t.List[dict]) -> dict:
"schema": {
"type": "object",
"properties": properties,
"default": defaults,
"required": required,
}
}

if defaults:
definition["schema"]["default"] = defaults
if encoding:
definition["encoding"] = encoding

Expand Down
1 change: 1 addition & 0 deletions connexion/validators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from connexion.datastructures import MediaTypeDict

from .abstract import AbstractRequestBodyValidator # NOQA
from .form_data import FormDataValidator, MultiPartFormDataValidator
from .json import DefaultsJSONRequestBodyValidator # NOQA
from .json import (
Expand Down
165 changes: 165 additions & 0 deletions connexion/validators/abstract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""
This module defines a Validator interface with base functionality that can be subclassed
for custom validators provided to the RequestValidationMiddleware.
"""
import copy
import json
import typing as t

from starlette.datastructures import Headers, MutableHeaders
from starlette.types import ASGIApp, Receive, Scope, Send

from connexion.exceptions import BadRequestProblem
from connexion.utils import is_null


class AbstractRequestBodyValidator:
"""
Validator interface with base functionality that can be subclassed for custom validators.
.. note: Validators load the whole body into memory, which can be a problem for large payloads.
"""

MUTABLE_VALIDATION = False
"""
Whether mutations to the body during validation should be transmitted via the receive channel.
Note that this does not apply to the substitution of a missing body with the default body, which always
updates the receive channel.
"""
MAX_MESSAGE_LENGTH = 256000
"""Maximum message length that will be sent via the receive channel for mutated bodies."""

def __init__(
self,
next_app: ASGIApp,
*,
schema: dict,
required: bool = False,
nullable: bool = False,
encoding: str,
strict_validation: bool,
**kwargs,
):
"""
:param next_app: Next ASGI App to call
:param schema: Schema of operation to validate
:param required: Whether RequestBody is required
:param nullable: Whether RequestBody is nullable
:param encoding: Encoding of body (passed via Content-Type header)
:param kwargs: Additional arguments for subclasses
:param strict_validation: Whether to allow parameters not defined in the spec
"""
self._next_app = next_app
self._schema = schema
self._nullable = nullable
self._required = required
self._encoding = encoding
self._strict_validation = strict_validation

def _validate_no_body(
self, scope: Scope, receive: Receive
) -> t.Tuple[Scope, Receive]:
"""
Validate missing body. This happens separately since the `receive` channel is never
called for requests without a body.
"""
body = self._schema.get("default")

if body is None and self._required:
raise BadRequestProblem("RequestBody is required")

return self._update_for_body(scope, receive, body=body)

def _update_for_body(
self, scope: Scope, receive: Receive, *, body: t.Any
) -> t.Tuple[Scope, Receive]:
"""Update the scope and receive channel for the body to transmit."""
if body is None:
return scope, receive

bytes_body = json.dumps(body).encode(self._encoding)

# Update the content-length header
new_scope = copy.deepcopy(scope)
headers = MutableHeaders(scope=new_scope)
headers["content-length"] = str(len(bytes_body))

# Wrap in new receive channel
messages = (
{
"type": "http.request",
"body": bytes_body[i : i + self.MAX_MESSAGE_LENGTH],
"more_body": i + self.MAX_MESSAGE_LENGTH < len(bytes_body),
}
for i in range(0, len(bytes_body), self.MAX_MESSAGE_LENGTH)
)

wrapped_receive = self._wrap_receive(receive, messages=messages)

return new_scope, wrapped_receive

@staticmethod
def _wrap_receive(
receive: Receive, *, messages: t.Iterable[t.MutableMapping]
) -> Receive:
""" "Wrap the receive channel to play messages first."""

async def wrapped_receive() -> t.MutableMapping[str, t.Any]:
for message in messages:
return message
return await receive()

return wrapped_receive

async def _parse(
self, stream: t.AsyncGenerator[bytes, None], scope: Scope
) -> t.Any:
"""Parse the incoming stream."""

def _validate(self, body: t.Any) -> t.Optional[dict]:
"""
Validate the body.
:raises: :class:`connexion.exceptions.BadRequestProblem`
"""

async def _parse_and_validate(
self, scope: Scope, receive: Receive
) -> t.Tuple[Scope, Receive]:
"""
Parse the incoming receive channel and validate the contents. If `MUTABLE_VALIDATION` is `true`, returns the
updated `scope` and `receive` channel, otherwise returns the original `scope` and `receive` channel.
"""
messages = []

async def stream() -> t.AsyncGenerator[bytes, None]:
more_body = True
while more_body:
message = await receive()
messages.append(message)
more_body = message.get("more_body", False)
yield message.get("body", b"")
yield b""

body = await self._parse(stream(), scope=scope)

if not (self._nullable and is_null(body)):
self._validate(body)

if self.MUTABLE_VALIDATION:
scope, receive = self._update_for_body(scope, receive, body=body)
else:
receive = self._wrap_receive(receive, messages=messages)

return scope, receive

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# Check for missing body first.
# If the content-length header is 0, the receive channel is never called, so validation needs to happen here.
headers = Headers(scope=scope)
if not int(headers.get("content-length", 0)):
scope, receive = self._validate_no_body(scope, receive)

scope, receive = await self._parse_and_validate(scope, receive)

await self._next_app(scope, receive, send)
Loading

0 comments on commit 3a2b000

Please sign in to comment.