diff --git a/connexion/json_schema.py b/connexion/json_schema.py index 1cf26aa96..6ab214fa7 100644 --- a/connexion/json_schema.py +++ b/connexion/json_schema.py @@ -107,6 +107,13 @@ def _do_resolve(node): return res +def format_error_with_path(exception: ValidationError) -> str: + """Format a `ValidationError` with path to error.""" + error_path = ".".join(str(item) for item in exception.path) + error_path_msg = f" - '{error_path}'" if error_path else "" + return error_path_msg + + def allow_nullable(validation_fn: t.Callable) -> t.Callable: """Extend an existing validation function, so it allows nullable values to be null.""" diff --git a/connexion/middleware/exceptions.py b/connexion/middleware/exceptions.py index 9ae5c44c9..c8f0c5114 100644 --- a/connexion/middleware/exceptions.py +++ b/connexion/middleware/exceptions.py @@ -65,5 +65,5 @@ def common_error_handler(_request: StarletteRequest, exc: Exception) -> Response async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # Needs to be set so starlette router throws exceptions instead of returning error responses - scope["app"] = self + scope["app"] = "connexion" await super().__call__(scope, receive, send) diff --git a/connexion/middleware/request_validation.py b/connexion/middleware/request_validation.py index 36cd1a582..65aafd9af 100644 --- a/connexion/middleware/request_validation.py +++ b/connexion/middleware/request_validation.py @@ -68,7 +68,7 @@ def validate_mime_type(self, mime_type: str) -> None: ) async def __call__(self, scope: Scope, receive: Receive, send: Send): - receive_fn = receive + next_app = self.next_app # Validate parameters & headers uri_parser_class = self._operation._uri_parser_class @@ -100,8 +100,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): ) else: validator = body_validator( - scope, - receive, + self.next_app, schema=schema, required=self._operation.request_body.get("required", False), nullable=utils.is_nullable( @@ -113,9 +112,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): self._operation.parameters, self._operation.body_definition() ), ) - receive_fn = await validator.wrapped_receive() + next_app = validator - await self.next_app(scope, receive_fn, send) + await next_app(scope, receive, send) class RequestValidationAPI(RoutedAPI[RequestValidationOperation]): diff --git a/connexion/operations/swagger2.py b/connexion/operations/swagger2.py index 10d986d95..722b0435d 100644 --- a/connexion/operations/swagger2.py +++ b/connexion/operations/swagger2.py @@ -291,6 +291,7 @@ def _transform_form(self, form_parameters: t.List[dict]) -> dict: default = param.get("default") if default is not None: + prop["default"] = default defaults[param["name"]] = default nullable = param.get("x-nullable") @@ -320,11 +321,11 @@ def _transform_form(self, form_parameters: t.List[dict]) -> dict: "schema": { "type": "object", "properties": properties, - "default": defaults, "required": required, } } - + if defaults: + definition["schema"]["default"] = defaults if encoding: definition["encoding"] = encoding diff --git a/connexion/validators/__init__.py b/connexion/validators/__init__.py index f527409e0..87646036e 100644 --- a/connexion/validators/__init__.py +++ b/connexion/validators/__init__.py @@ -1,5 +1,6 @@ from connexion.datastructures import MediaTypeDict +from .abstract import AbstractRequestBodyValidator # NOQA from .form_data import FormDataValidator, MultiPartFormDataValidator from .json import DefaultsJSONRequestBodyValidator # NOQA from .json import ( diff --git a/connexion/validators/abstract.py b/connexion/validators/abstract.py new file mode 100644 index 000000000..650b6e5b3 --- /dev/null +++ b/connexion/validators/abstract.py @@ -0,0 +1,165 @@ +""" +This module defines a Validator interface with base functionality that can be subclassed +for custom validators provided to the RequestValidationMiddleware. +""" +import copy +import json +import typing as t + +from starlette.datastructures import Headers, MutableHeaders +from starlette.types import ASGIApp, Receive, Scope, Send + +from connexion.exceptions import BadRequestProblem +from connexion.utils import is_null + + +class AbstractRequestBodyValidator: + """ + Validator interface with base functionality that can be subclassed for custom validators. + + .. note: Validators load the whole body into memory, which can be a problem for large payloads. + """ + + MUTABLE_VALIDATION = False + """ + Whether mutations to the body during validation should be transmitted via the receive channel. + Note that this does not apply to the substitution of a missing body with the default body, which always + updates the receive channel. + """ + MAX_MESSAGE_LENGTH = 256000 + """Maximum message length that will be sent via the receive channel for mutated bodies.""" + + def __init__( + self, + next_app: ASGIApp, + *, + schema: dict, + required: bool = False, + nullable: bool = False, + encoding: str, + strict_validation: bool, + **kwargs, + ): + """ + :param next_app: Next ASGI App to call + :param schema: Schema of operation to validate + :param required: Whether RequestBody is required + :param nullable: Whether RequestBody is nullable + :param encoding: Encoding of body (passed via Content-Type header) + :param kwargs: Additional arguments for subclasses + :param strict_validation: Whether to allow parameters not defined in the spec + """ + self._next_app = next_app + self._schema = schema + self._nullable = nullable + self._required = required + self._encoding = encoding + self._strict_validation = strict_validation + + def _validate_no_body( + self, scope: Scope, receive: Receive + ) -> t.Tuple[Scope, Receive]: + """ + Validate missing body. This happens separately since the `receive` channel is never + called for requests without a body. + """ + body = self._schema.get("default") + + if body is None and self._required: + raise BadRequestProblem("RequestBody is required") + + return self._update_for_body(scope, receive, body=body) + + def _update_for_body( + self, scope: Scope, receive: Receive, *, body: t.Any + ) -> t.Tuple[Scope, Receive]: + """Update the scope and receive channel for the body to transmit.""" + if body is None: + return scope, receive + + bytes_body = json.dumps(body).encode(self._encoding) + + # Update the content-length header + new_scope = copy.deepcopy(scope) + headers = MutableHeaders(scope=new_scope) + headers["content-length"] = str(len(bytes_body)) + + # Wrap in new receive channel + messages = ( + { + "type": "http.request", + "body": bytes_body[i : i + self.MAX_MESSAGE_LENGTH], + "more_body": i + self.MAX_MESSAGE_LENGTH < len(bytes_body), + } + for i in range(0, len(bytes_body), self.MAX_MESSAGE_LENGTH) + ) + + wrapped_receive = self._wrap_receive(receive, messages=messages) + + return new_scope, wrapped_receive + + @staticmethod + def _wrap_receive( + receive: Receive, *, messages: t.Iterable[t.MutableMapping] + ) -> Receive: + """ "Wrap the receive channel to play messages first.""" + + async def wrapped_receive() -> t.MutableMapping[str, t.Any]: + for message in messages: + return message + return await receive() + + return wrapped_receive + + async def _parse( + self, stream: t.AsyncGenerator[bytes, None], scope: Scope + ) -> t.Any: + """Parse the incoming stream.""" + + def _validate(self, body: t.Any) -> t.Optional[dict]: + """ + Validate the body. + + :raises: :class:`connexion.exceptions.BadRequestProblem` + """ + + async def _parse_and_validate( + self, scope: Scope, receive: Receive + ) -> t.Tuple[Scope, Receive]: + """ + Parse the incoming receive channel and validate the contents. If `MUTABLE_VALIDATION` is `true`, returns the + updated `scope` and `receive` channel, otherwise returns the original `scope` and `receive` channel. + """ + messages = [] + + async def stream() -> t.AsyncGenerator[bytes, None]: + more_body = True + while more_body: + message = await receive() + messages.append(message) + more_body = message.get("more_body", False) + yield message.get("body", b"") + yield b"" + + body = await self._parse(stream(), scope=scope) + + if not (self._nullable and is_null(body)): + self._validate(body) + + if self.MUTABLE_VALIDATION: + scope, receive = self._update_for_body(scope, receive, body=body) + else: + receive = self._wrap_receive(receive, messages=messages) + + return scope, receive + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + # Check for missing body first. + # If the content-length header is 0, the receive channel is never called, so validation needs to happen here. + headers = Headers(scope=scope) + if not int(headers.get("content-length", 0)): + scope, receive = self._validate_no_body(scope, receive) + + scope, receive = await self._parse_and_validate(scope, receive) + + await self._next_app(scope, receive, send) diff --git a/connexion/validators/form_data.py b/connexion/validators/form_data.py index 65d209846..dbb8f5477 100644 --- a/connexion/validators/form_data.py +++ b/connexion/validators/form_data.py @@ -1,85 +1,59 @@ import logging import typing as t -from jsonschema import Draft4Validator, ValidationError, draft4_format_checker -from starlette.datastructures import FormData, Headers, UploadFile +from jsonschema import ValidationError, draft4_format_checker +from starlette.datastructures import Headers, UploadFile from starlette.formparsers import FormParser, MultiPartParser -from starlette.types import Receive, Scope +from starlette.types import ASGIApp, Scope from connexion.exceptions import BadRequestProblem, ExtraParameterProblem -from connexion.json_schema import Draft4RequestValidator +from connexion.json_schema import Draft4RequestValidator, format_error_with_path from connexion.uri_parsing import AbstractURIParser -from connexion.utils import is_null +from connexion.validators import AbstractRequestBodyValidator logger = logging.getLogger("connexion.validators.form_data") -class FormDataValidator: +class FormDataValidator(AbstractRequestBodyValidator): """Request body validator for form content types.""" def __init__( self, - scope: Scope, - receive: Receive, + next_app: ASGIApp, *, schema: dict, - validator: t.Type[Draft4Validator] = None, required=False, nullable=False, encoding: str, - uri_parser: t.Optional[AbstractURIParser] = None, strict_validation: bool, + uri_parser: t.Optional[AbstractURIParser] = None, ) -> None: - self._scope = scope - self._receive = receive - self.schema = schema - self.has_default = schema.get("default", False) - self.nullable = nullable - self.required = required - validator_cls = validator or Draft4RequestValidator - self.validator = validator_cls(schema, format_checker=draft4_format_checker) - self.uri_parser = uri_parser - self.encoding = encoding - self._messages: t.List[t.MutableMapping[str, t.Any]] = [] - self.headers = Headers(scope=scope) - self.strict_validation = strict_validation - self.check_empty() + super().__init__( + next_app, + schema=schema, + required=required, + nullable=nullable, + encoding=encoding, + strict_validation=strict_validation, + ) + self._uri_parser = uri_parser + + @property + def _validator(self): + return Draft4RequestValidator( + self._schema, format_checker=draft4_format_checker + ) @property - def form_parser_cls(self): + def _form_parser_cls(self): return FormParser - def check_empty(self): - """`receive` is never called if body is empty, so we need to check this case at - initialization.""" - if not int(self.headers.get("content-length", 0)): - # TODO: default should be passed along and content-length updated - if self.schema.get("default"): - self.validate(self.schema.get("default")) - elif self.required: # RequestBody itself is required - raise BadRequestProblem("RequestBody is required") - elif self.schema.get("required", []): # Required top level properties - self._validate({}) - - @classmethod - def _error_path_message(cls, exception): - error_path = ".".join(str(item) for item in exception.path) - error_path_msg = f" - '{error_path}'" if error_path else "" - return error_path_msg + async def _parse(self, stream: t.AsyncGenerator[bytes, None], scope: Scope) -> dict: + headers = Headers(scope=scope) + form_parser = self._form_parser_cls(headers, stream) + data = await form_parser.parse() - def _validate(self, data: dict) -> None: - try: - self.validator.validate(data) - except ValidationError as exception: - error_path_msg = self._error_path_message(exception=exception) - logger.error( - f"Validation error: {exception.message}{error_path_msg}", - extra={"validator": "body"}, - ) - raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}") - - def _parse(self, data: FormData) -> dict: - if self.uri_parser is not None: + if self._uri_parser is not None: # Don't parse file_data form_data = {} file_data = {} @@ -90,7 +64,7 @@ def _parse(self, data: FormData) -> dict: # Replace files with empty strings for validation file_data[k] = "" - data = self.uri_parser.resolve_form(form_data) + data = self._uri_parser.resolve_form(form_data) # Add the files again data.update(file_data) else: @@ -98,45 +72,29 @@ def _parse(self, data: FormData) -> dict: return data - def _validate_strictly(self, data: FormData) -> None: + def _validate(self, data: dict) -> None: + if self._strict_validation: + self._validate_params_strictly(data) + + try: + self._validator.validate(data) + except ValidationError as exception: + error_path_msg = format_error_with_path(exception=exception) + logger.error( + f"Validation error: {exception.message}{error_path_msg}", + extra={"validator": "body"}, + ) + raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}") + + def _validate_params_strictly(self, data: dict) -> None: form_params = data.keys() - spec_params = self.schema.get("properties", {}).keys() + spec_params = self._schema.get("properties", {}).keys() errors = set(form_params).difference(set(spec_params)) if errors: raise ExtraParameterProblem(param_type="formData", extra_params=errors) - def validate(self, data: FormData) -> None: - if self.strict_validation: - self._validate_strictly(data) - - data = self._parse(data) - self._validate(data) - - async def wrapped_receive(self) -> Receive: - async def stream() -> t.AsyncGenerator[bytes, None]: - more_body = True - while more_body: - message = await self._receive() - self._messages.append(message) - more_body = message.get("more_body", False) - yield message.get("body", b"") - yield b"" - - form_parser = self.form_parser_cls(self.headers, stream()) - form = await form_parser.parse() - - if form and not (self.nullable and is_null(form)): - self.validate(form) - - async def receive() -> t.MutableMapping[str, t.Any]: - while self._messages: - return self._messages.pop(0) - return await self._receive() - - return receive - class MultiPartFormDataValidator(FormDataValidator): @property - def form_parser_cls(self): + def _form_parser_cls(self): return MultiPartParser diff --git a/connexion/validators/json.py b/connexion/validators/json.py index b23594f35..98e0563f6 100644 --- a/connexion/validators/json.py +++ b/connexion/validators/json.py @@ -4,107 +4,86 @@ import jsonschema from jsonschema import Draft4Validator, ValidationError, draft4_format_checker -from starlette.datastructures import Headers -from starlette.types import Receive, Scope, Send +from starlette.types import ASGIApp, Scope, Send from connexion.exceptions import BadRequestProblem, NonConformingResponseBody -from connexion.json_schema import Draft4RequestValidator, Draft4ResponseValidator +from connexion.json_schema import ( + Draft4RequestValidator, + Draft4ResponseValidator, + format_error_with_path, +) from connexion.utils import is_null +from connexion.validators import AbstractRequestBodyValidator logger = logging.getLogger("connexion.validators.json") -class JSONRequestBodyValidator: +class JSONRequestBodyValidator(AbstractRequestBodyValidator): """Request body validator for json content types.""" def __init__( self, - scope: Scope, - receive: Receive, + next_app: ASGIApp, *, schema: dict, - validator: t.Type[Draft4Validator] = Draft4RequestValidator, required=False, nullable=False, encoding: str, + strict_validation: bool, **kwargs, ) -> None: - self._scope = scope - self._receive = receive - self.schema = schema - self.has_default = schema.get("default", False) - self.nullable = nullable - self.required = required - self.validator = validator(schema, format_checker=draft4_format_checker) - self.encoding = encoding - self.headers = Headers(scope=scope) - self.check_empty() - - def check_empty(self): - """receive` is never called if body is empty, so we need to check this case at - initialization.""" - if not int(self.headers.get("content-length", 0)): - # TODO: default should be passed along and content-length updated - if self.schema.get("default"): - self.validate(self.schema.get("default")) - elif self.required: # RequestBody itself is required - raise BadRequestProblem("RequestBody is required") - elif self.schema.get("required", []): # Required top level properties - self.validate({}) + super().__init__( + next_app, + schema=schema, + required=required, + nullable=nullable, + encoding=encoding, + strict_validation=strict_validation, + ) - @classmethod - def _error_path_message(cls, exception): - error_path = ".".join(str(item) for item in exception.path) - error_path_msg = f" - '{error_path}'" if error_path else "" - return error_path_msg + @property + def _validator(self): + return Draft4RequestValidator( + self._schema, format_checker=draft4_format_checker + ) + + async def _parse( + self, stream: t.AsyncGenerator[bytes, None], scope: Scope + ) -> t.Union[dict, str]: + bytes_body = b"".join([message async for message in stream]) + body = bytes_body.decode(self._encoding) + + if self._nullable and is_null(body): + return body - def validate(self, body: dict): try: - self.validator.validate(body) + return json.loads(body) + except json.decoder.JSONDecodeError as e: + raise BadRequestProblem(detail=str(e)) + + def _validate(self, body: dict) -> None: + try: + return self._validator.validate(body) except ValidationError as exception: - error_path_msg = self._error_path_message(exception=exception) + error_path_msg = format_error_with_path(exception=exception) logger.error( f"Validation error: {exception.message}{error_path_msg}", extra={"validator": "body"}, ) raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}") - def parse(self, body: str) -> dict: - try: - return json.loads(body) - except json.decoder.JSONDecodeError as e: - raise BadRequestProblem(str(e)) - - async def wrapped_receive(self) -> Receive: - more_body = True - messages = [] - while more_body: - message = await self._receive() - messages.append(message) - more_body = message.get("more_body", False) - - bytes_body = b"".join([message.get("body", b"") for message in messages]) - decoded_body = bytes_body.decode(self.encoding) - - if decoded_body and not (self.nullable and is_null(decoded_body)): - body = self.parse(decoded_body) - self.validate(body) - - async def receive() -> t.MutableMapping[str, t.Any]: - while messages: - return messages.pop(0) - return await self._receive() - - return receive - class DefaultsJSONRequestBodyValidator(JSONRequestBodyValidator): """Request body validator for json content types which fills in default values. This Validator intercepts the body, makes changes to it, and replays it for the next ASGI application.""" - def __init__(self, *args, **kwargs): - defaults_validator = self.extend_with_set_default(Draft4RequestValidator) - super().__init__(*args, validator=defaults_validator, **kwargs) + MUTABLE_VALIDATION = True + """This validator might mutate to the body.""" + + @property + def _validator(self): + validator_cls = self.extend_with_set_default(Draft4RequestValidator) + return validator_cls(self._schema, format_checker=draft4_format_checker) # via https://python-jsonschema.readthedocs.io/ @staticmethod @@ -122,58 +101,6 @@ def set_defaults(validator, properties, instance, schema): validator_class, {"properties": set_defaults} ) - async def read_body(self) -> t.Tuple[str, int]: - """Read the body from the receive channel. - - :return: A tuple (body, max_length) where max_length is the length of the largest message. - """ - more_body = True - max_length = 256000 - messages = [] - while more_body: - message = await self._receive() - max_length = max(max_length, len(message.get("body", b""))) - messages.append(message) - more_body = message.get("more_body", False) - - bytes_body = b"".join([message.get("body", b"") for message in messages]) - - return bytes_body.decode(self.encoding), max_length - - async def wrapped_receive(self) -> Receive: - """Receive channel to pass on to next ASGI application.""" - decoded_body, max_length = await self.read_body() - - # Validate the body if not null - if decoded_body and not (self.nullable and is_null(decoded_body)): - body = self.parse(decoded_body) - del decoded_body - self.validate(body) - str_body = json.dumps(body) - else: - str_body = decoded_body - - bytes_body = str_body.encode(self.encoding) - del str_body - - # Recreate ASGI messages from validated body so changes made by the validator are propagated - messages = [ - { - "type": "http.request", - "body": bytes_body[i : i + max_length], - "more_body": i + max_length < len(bytes_body), - } - for i in range(0, len(bytes_body), max_length) - ] - del bytes_body - - async def receive() -> t.MutableMapping[str, t.Any]: - while messages: - return messages.pop(0) - return await self._receive() - - return receive - class JSONResponseBodyValidator: """Response body validator for json content types.""" diff --git a/tests/api/test_parameters.py b/tests/api/test_parameters.py index 16e699a4f..6fcb5e277 100644 --- a/tests/api/test_parameters.py +++ b/tests/api/test_parameters.py @@ -314,7 +314,10 @@ def test_mixed_formdata(simple_app): def test_formdata_file_upload_bad_request(simple_app): app_client = simple_app.test_client() - resp = app_client.post("/v1.0/test-formData-file-upload") + resp = app_client.post( + "/v1.0/test-formData-file-upload", + headers={"Content-Type": b"multipart/form-data; boundary=-"}, + ) assert resp.status_code == 400 assert resp.json()["detail"] in [ "Missing formdata parameter 'fileData'", diff --git a/tests/test_json_validation.py b/tests/test_json_validation.py index f2d2c58f6..d915ee116 100644 --- a/tests/test_json_validation.py +++ b/tests/test_json_validation.py @@ -28,8 +28,9 @@ def validate_type(validator, types, instance, schema): MinLengthRequestValidator = extend(Draft4RequestValidator, {"type": validate_type}) class MyJSONBodyValidator(JSONRequestBodyValidator): - def __init__(self, *args, **kwargs): - super().__init__(*args, validator=MinLengthRequestValidator, **kwargs) + @property + def _validator(self): + return MinLengthRequestValidator(self._schema) validator_map = {"body": {"application/json": MyJSONBodyValidator}} diff --git a/tests/test_operation2.py b/tests/test_operation2.py index 59779357e..5b1fc99c8 100644 --- a/tests/test_operation2.py +++ b/tests/test_operation2.py @@ -749,6 +749,7 @@ def test_form_transformation(api): "param": { "type": "string", "format": "email", + "default": "foo@bar.com", }, "array_param": { "type": "array",