Skip to content

Commit

Permalink
Create AbstractResponseBodyValidator class
Browse files Browse the repository at this point in the history
  • Loading branch information
RobbeSneyders committed Mar 1, 2023
1 parent 9bc5d59 commit 840517e
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 62 deletions.
10 changes: 3 additions & 7 deletions connexion/middleware/response_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,8 @@ def validate_required_headers(
raise NonConformingResponseHeaders(detail=msg)

async def __call__(self, scope: Scope, receive: Receive, send: Send):

send_fn = send

async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None:
nonlocal send_fn
nonlocal send

if message["type"] == "http.response.start":
status = str(message["status"])
Expand All @@ -107,16 +104,15 @@ async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None:
else:
validator = body_validator(
scope,
send,
schema=self._operation.response_schema(status, mime_type),
nullable=utils.is_nullable(
self._operation.response_definition(status, mime_type)
),
encoding=encoding,
)
send_fn = validator.send
send = validator.wrap_send(send)

return await send_fn(message)
return await send(message)

await self.next_app(scope, receive, wrapped_send)

Expand Down
5 changes: 4 additions & 1 deletion connexion/validators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from connexion.datastructures import MediaTypeDict

from .abstract import AbstractRequestBodyValidator # NOQA
from .abstract import ( # NOQA
AbstractRequestBodyValidator,
AbstractResponseBodyValidator,
)
from .form_data import FormDataValidator, MultiPartFormDataValidator
from .json import DefaultsJSONRequestBodyValidator # NOQA
from .json import (
Expand Down
57 changes: 56 additions & 1 deletion connexion/validators/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import typing as t

from starlette.datastructures import Headers, MutableHeaders
from starlette.types import Receive, Scope
from starlette.types import Receive, Scope, Send

from connexion.exceptions import BadRequestProblem
from connexion.utils import is_null
Expand Down Expand Up @@ -149,3 +149,58 @@ async def stream() -> t.AsyncGenerator[bytes, None]:
receive = self._insert_messages(receive, messages=messages)

return receive


class AbstractResponseBodyValidator:
"""
Validator interface with base functionality that can be subclassed for custom validators.
.. note: Validators load the whole body into memory, which can be a problem for large payloads.
"""

def __init__(
self,
scope: Scope,
*,
schema: dict,
nullable: bool = False,
encoding: str,
) -> None:
self._scope = scope
self._schema = schema
self._nullable = nullable
self._encoding = encoding

def _parse(self, body: t.Generator[bytes, None, None]) -> t.Any:
"""Parse the body."""

def _validate(self, body: t.Any) -> t.Optional[dict]:
"""
Validate the body.
:raises: :class:`connexion.exceptions.NonConformingResponse`
"""

def wrap_send(self, send: Send) -> Send:
"""Wrap the provided send channel with response body validation"""

messages = []

async def send_(message: t.MutableMapping[str, t.Any]) -> None:
messages.append(message)

if message["type"] == "http.response.start" or message.get(
"more_body", False
):
return

stream = (message.get("body", b"") for message in messages)
body = self._parse(stream)

if body is not None and not (self._nullable and is_null(body)):
self._validate(body)

while messages:
await send(messages.pop(0))

return send_
81 changes: 28 additions & 53 deletions connexion/validators/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jsonschema
from jsonschema import Draft4Validator, ValidationError, draft4_format_checker
from starlette.types import Scope, Send
from starlette.types import Scope

from connexion.exceptions import BadRequestProblem, NonConformingResponseBody
from connexion.json_schema import (
Expand All @@ -13,7 +13,10 @@
format_error_with_path,
)
from connexion.utils import is_null
from connexion.validators import AbstractRequestBodyValidator
from connexion.validators import (
AbstractRequestBodyValidator,
AbstractResponseBodyValidator,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -100,39 +103,31 @@ def set_defaults(validator, properties, instance, schema):
)


class JSONResponseBodyValidator:
class JSONResponseBodyValidator(AbstractResponseBodyValidator):
"""Response body validator for json content types."""

def __init__(
self,
scope: Scope,
send: Send,
*,
schema: dict,
validator: t.Type[Draft4Validator] = Draft4ResponseValidator,
nullable=False,
encoding: str,
) -> None:
self._scope = scope
self._send = send
self.schema = schema
self.has_default = schema.get("default", False)
self.nullable = nullable
self.validator = validator(schema, format_checker=draft4_format_checker)
self.encoding = encoding
self._messages: t.List[t.MutableMapping[str, t.Any]] = []

@classmethod
def _error_path_message(cls, exception):
error_path = ".".join(str(item) for item in exception.path)
error_path_msg = f" - '{error_path}'" if error_path else ""
return error_path_msg

def validate(self, body: dict):
@property
def validator(self) -> Draft4Validator:
return Draft4ResponseValidator(
self._schema, format_checker=draft4_format_checker
)

def _parse(self, stream: t.Generator[bytes, None, None]) -> t.Any:
body = b"".join(stream).decode(self._encoding)

if not body:
return body

try:
return json.loads(body)
except json.decoder.JSONDecodeError as e:
raise NonConformingResponseBody(str(e))

def _validate(self, body: dict):
try:
self.validator.validate(body)
except ValidationError as exception:
error_path_msg = self._error_path_message(exception=exception)
error_path_msg = format_error_with_path(exception=exception)
logger.error(
f"Validation error: {exception.message}{error_path_msg}",
extra={"validator": "body"},
Expand All @@ -141,31 +136,11 @@ def validate(self, body: dict):
detail=f"Response body does not conform to specification. {exception.message}{error_path_msg}"
)

def parse(self, body: str) -> dict:
try:
return json.loads(body)
except json.decoder.JSONDecodeError as e:
raise NonConformingResponseBody(str(e))

async def send(self, message: t.MutableMapping[str, t.Any]) -> None:
self._messages.append(message)

if message["type"] == "http.response.start" or message.get("more_body", False):
return

bytes_body = b"".join([message.get("body", b"") for message in self._messages])
decoded_body = bytes_body.decode(self.encoding)

if decoded_body and not (self.nullable and is_null(decoded_body)):
body = self.parse(decoded_body)
self.validate(body)

while self._messages:
await self._send(self._messages.pop(0))


class TextResponseBodyValidator(JSONResponseBodyValidator):
def parse(self, body: str) -> str: # type: ignore
def _parse(self, stream: t.Generator[bytes, None, None]) -> str: # type: ignore
body = b"".join(stream).decode(self._encoding)

try:
return json.loads(body)
except json.decoder.JSONDecodeError:
Expand Down

0 comments on commit 840517e

Please sign in to comment.