Skip to content

Commit

Permalink
Add defaults enforcing validator
Browse files Browse the repository at this point in the history
  • Loading branch information
RobbeSneyders committed Jan 26, 2023
1 parent edb0381 commit b64d494
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 52 deletions.
9 changes: 6 additions & 3 deletions connexion/apps/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def fn(self) -> t.Callable:
async def __call__(
self, scope: Scope, receive: Receive, send: Send
) -> StarletteResponse:
return await self.fn(scope=scope, receive=receive, send=send)
response = await self.fn()
return await response(scope, receive, send)


class AsyncApi(RoutedAPI[AsyncOperation]):
Expand All @@ -80,15 +81,17 @@ def make_operation(self, operation: AbstractOperation) -> AsyncOperation:


class AsyncMiddlewareApp(RoutedMiddleware[AsyncApi]):

api_cls = AsyncApi

def __init__(self) -> None:
self.apis: t.Dict[str, AsyncApi] = {}
self.operations: t.Dict[str, AsyncOperation] = {}
self.router = Router()
super().__init__(self.router)

def add_api(self, *args, **kwargs):
api = AsyncApi(*args, **kwargs)
self.apis[api.base_path] = api
api = super().add_api(*args, **kwargs)
self.router.mount(api.base_path, api.router)
return api

Expand Down
5 changes: 1 addition & 4 deletions connexion/decorators/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,7 @@ async def wrapper(*args, **kwargs):
uri_parser=self.uri_parser, scope=scope, receive=receive
)
decorated_function = self.decorate(function)
response = decorated_function(request)
while asyncio.iscoroutine(response):
response = await response
return response
return await decorated_function(request)

return wrapper

Expand Down
7 changes: 3 additions & 4 deletions connexion/middleware/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class SpecMiddleware(abc.ABC):
@abc.abstractmethod
def add_api(
self, specification: t.Union[pathlib.Path, str, dict], **kwargs
) -> None:
) -> t.Any:
"""
Register an API represented by a single OpenAPI specification on this middleware.
Multiple APIs can be registered on a single middleware.
Expand Down Expand Up @@ -246,11 +246,10 @@ def __init__(self, app: ASGIApp) -> None:
self.app = app
self.apis: t.Dict[str, API] = {}

def add_api(
self, specification: t.Union[pathlib.Path, str, dict], **kwargs
) -> None:
def add_api(self, specification: t.Union[pathlib.Path, str, dict], **kwargs) -> API:
api = self.api_cls(specification, next_app=self.app, **kwargs)
self.apis[api.base_path] = api
return api

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Fetches the operation related to the request and calls it."""
Expand Down
1 change: 1 addition & 0 deletions connexion/validators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from connexion.datastructures import MediaTypeDict

from .form_data import FormDataValidator, MultiPartFormDataValidator
from .json import DefaultsJSONRequestBodyValidator # NOQA
from .json import (
JSONRequestBodyValidator,
JSONResponseBodyValidator,
Expand Down
97 changes: 86 additions & 11 deletions connexion/validators/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import typing as t

import jsonschema
from jsonschema import Draft4Validator, ValidationError, draft4_format_checker
from starlette.types import Receive, Scope, Send

Expand Down Expand Up @@ -33,7 +34,6 @@ def __init__(
self.nullable = nullable
self.validator = validator(schema, format_checker=draft4_format_checker)
self.encoding = encoding
self._messages: t.List[t.MutableMapping[str, t.Any]] = []

@classmethod
def _error_path_message(cls, exception):
Expand All @@ -52,30 +52,107 @@ def validate(self, body: dict):
)
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")

@staticmethod
def parse(body: str) -> dict:
def parse(self, body: str) -> dict:
try:
return json.loads(body)
except json.decoder.JSONDecodeError as e:
raise BadRequestProblem(str(e))

async def wrapped_receive(self) -> Receive:
more_body = True
messages = []
while more_body:
message = await self._receive()
self._messages.append(message)
messages.append(message)
more_body = message.get("more_body", False)

bytes_body = b"".join([message.get("body", b"") for message in self._messages])
bytes_body = b"".join([message.get("body", b"") for message in messages])
decoded_body = bytes_body.decode(self.encoding)

if decoded_body and not (self.nullable and is_null(decoded_body)):
body = self.parse(decoded_body)
self.validate(body)

async def receive() -> t.MutableMapping[str, t.Any]:
while self._messages:
return self._messages.pop(0)
while messages:
return messages.pop(0)
return await self._receive()

return receive


class DefaultsJSONRequestBodyValidator(JSONRequestBodyValidator):
"""Request body validator for json content types which fills in default values. This Validator
intercepts the body, makes changes to it, and replays it for the next ASGI application."""

def __init__(self, *args, **kwargs):
defaults_validator = self.extend_with_set_default(Draft4RequestValidator)
super().__init__(*args, validator=defaults_validator, **kwargs)

# via https://python-jsonschema.readthedocs.io/
@staticmethod
def extend_with_set_default(validator_class):
validate_properties = validator_class.VALIDATORS["properties"]

def set_defaults(validator, properties, instance, schema):
for property, subschema in properties.items():
if "default" in subschema:
instance.setdefault(property, subschema["default"])

yield from validate_properties(validator, properties, instance, schema)

return jsonschema.validators.extend(
validator_class, {"properties": set_defaults}
)

async def read_body(self) -> t.Tuple[str, int]:
"""Read the body from the receive channel.
:return: A tuple (body, max_length) where max_length is the length of the largest message.
"""
more_body = True
max_length = 256000
messages = []
while more_body:
message = await self._receive()
max_length = max(max_length, len(message.get("body", b"")))
messages.append(message)
more_body = message.get("more_body", False)

bytes_body = b"".join([message.get("body", b"") for message in messages])

return bytes_body.decode(self.encoding), max_length

async def wrapped_receive(self) -> Receive:
"""Receive channel to pass on to next ASGI application."""
decoded_body, max_length = await self.read_body()

# Validate the body if not null
if decoded_body and not (self.nullable and is_null(decoded_body)):
body = self.parse(decoded_body)
del decoded_body
self.validate(body)
str_body = json.dumps(body)
else:
str_body = decoded_body

bytes_body = str_body.encode(self.encoding)
del str_body

# Recreate ASGI messages from validated body so changes made by the validator are propagated
messages = [
{
"type": "http.request",
"body": bytes_body[i : i + max_length],
"more_body": i + max_length < len(bytes_body),
}
for i in range(0, len(bytes_body), max_length)
]
del bytes_body

async def receive() -> t.MutableMapping[str, t.Any]:
while messages:
return messages.pop(0)
return await self._receive()

return receive
Expand Down Expand Up @@ -122,8 +199,7 @@ def validate(self, body: dict):
message=f"{exception.message}{error_path_msg}"
)

@staticmethod
def parse(body: str) -> dict:
def parse(self, body: str) -> dict:
try:
return json.loads(body)
except json.decoder.JSONDecodeError as e:
Expand All @@ -147,8 +223,7 @@ async def send(self, message: t.MutableMapping[str, t.Any]) -> None:


class TextResponseBodyValidator(JSONResponseBodyValidator):
@staticmethod
def parse(body: str) -> str: # type: ignore
def parse(self, body: str) -> str: # type: ignore
try:
return json.loads(body)
except json.decoder.JSONDecodeError:
Expand Down
32 changes: 4 additions & 28 deletions examples/enforcedefaults/app.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,18 @@
from pathlib import Path

import connexion
import jsonschema
from connexion.json_schema import Draft4RequestValidator
from connexion.validators import JSONRequestBodyValidator
from connexion.validators import DefaultsJSONRequestBodyValidator


# TODO: should work as sync endpoint when parameter decorator is fixed
async def echo(data):
def echo(data):
return data


# via https://python-jsonschema.readthedocs.io/
def extend_with_set_default(validator_class):
validate_properties = validator_class.VALIDATORS["properties"]

def set_defaults(validator, properties, instance, schema):
for property, subschema in properties.items():
if "default" in subschema:
instance.setdefault(property, subschema["default"])

yield from validate_properties(validator, properties, instance, schema)

return jsonschema.validators.extend(validator_class, {"properties": set_defaults})


DefaultsEnforcingDraft4Validator = extend_with_set_default(Draft4RequestValidator)


class DefaultsEnforcingRequestBodyValidator(JSONRequestBodyValidator):
def __init__(self, *args, **kwargs):
super().__init__(*args, validator=DefaultsEnforcingDraft4Validator, **kwargs)


validator_map = {"body": {"application/json": DefaultsEnforcingRequestBodyValidator}}
validator_map = {"body": {"application/json": DefaultsJSONRequestBodyValidator}}


app = connexion.AsyncApp(__name__, specification_dir="spec")
app.add_api("openapi.yaml", validator_map=validator_map)
app.add_api("swagger.yaml", validator_map=validator_map)


Expand Down
46 changes: 46 additions & 0 deletions examples/enforcedefaults/spec/openapi.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
openapi: 3.0.0
info:
version: '1'
title: Custom Validator Example
servers:
- url: '/openapi'

paths:
/echo:
post:
description: Echo passed data
operationId: app.echo
requestBody:
x-body-name: data
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/Data'
responses:
'200':
description: Data with defaults filled in by validator
content:
application/json:
schema:
$ref: '#/components/schemas/Data'
default:
description: Unexpected error
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
components:
schemas:
Data:
type: object
properties:
outer-object:
type: object
default: {}
properties:
inner-object:
type: string
default: foo
Error:
type: string
2 changes: 1 addition & 1 deletion examples/enforcedefaults/spec/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ swagger: '2.0'
info:
version: '1'
title: Custom Validator Example
basePath: '/v1'
basePath: '/swagger'
consumes:
- application/json
produces:
Expand Down
2 changes: 1 addition & 1 deletion examples/sqlalchemy/spec/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ paths:
'201':
description: New pet created
requestBody:
x-body-name: pet
content:
application/json:
schema:
x-body-name: pet
$ref: '#/components/schemas/Pet'
delete:
tags:
Expand Down

0 comments on commit b64d494

Please sign in to comment.