Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add form data validator for validation middleware #1595

Merged
merged 5 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions connexion/datastructures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from fnmatch import fnmatch


class MediaTypeDict(dict):
"""
A dictionary where keys can be either media types or media type ranges. When fetching a
value from the dictionary, the provided key is checked against the ranges. The most specific
key is chosen as prescribed by the OpenAPI spec, with `type/*` being preferred above
`*/subtype`.
"""

def __getitem__(self, item):
# Sort keys in order of specificity
for key in sorted(self, key=lambda k: ("*" not in k, k), reverse=True):
if fnmatch(item, key):
return super().__getitem__(key)
raise super().__getitem__(item)

def get(self, item, default=None):
try:
return self[item]
except KeyError:
return default

def __contains__(self, item):
try:
self[item]
except KeyError:
return False
else:
return True
151 changes: 6 additions & 145 deletions connexion/decorators/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,18 @@
import copy
import functools
import logging
import typing as t

from jsonschema import Draft4Validator, ValidationError
from jsonschema.validators import extend
from werkzeug.datastructures import FileStorage

from ..exceptions import BadRequestProblem, ExtraParameterProblem
from ..http_facts import FORM_CONTENT_TYPES
from ..json_schema import Draft4RequestValidator
from ..lifecycle import ConnexionResponse
from ..utils import boolean, is_null, is_nullable

logger = logging.getLogger("connexion.decorators.validation")

TYPE_MAP = {"integer": int, "number": float, "boolean": boolean, "object": dict}

try:
draft4_format_checker = Draft4Validator.FORMAT_CHECKER
draft4_format_checker = Draft4Validator.FORMAT_CHECKER # type: ignore
except AttributeError: # jsonschema < 4.5.0
from jsonschema import draft4_format_checker

Expand Down Expand Up @@ -101,106 +95,6 @@ def validate_parameter_list(request_params, spec_params):
return request_params.difference(spec_params)


class RequestBodyValidator:
def __init__(
self,
schema,
consumes,
api,
is_null_value_valid=False,
validator=None,
strict_validation=False,
):
"""
:param schema: The schema of the request body
:param consumes: The list of content types the operation consumes
:param is_null_value_valid: Flag to indicate if null is accepted as valid value.
:param validator: Validator class that should be used to validate passed data
against API schema. Default is jsonschema.Draft4Validator.
:type validator: jsonschema.IValidator
:param strict_validation: Flag indicating if parameters not in spec are allowed
"""
self.consumes = consumes
self.schema = schema
self.has_default = schema.get("default", False)
self.is_null_value_valid = is_null_value_valid
validatorClass = validator or Draft4RequestValidator
self.validator = validatorClass(schema, format_checker=draft4_format_checker)
self.api = api
self.strict_validation = strict_validation

def validate_formdata_parameter_list(self, request):
request_params = request.form.keys()
spec_params = self.schema.get("properties", {}).keys()
return validate_parameter_list(request_params, spec_params)

def __call__(self, function):
"""
:type function: types.FunctionType
:rtype: types.FunctionType
"""

@functools.wraps(function)
def wrapper(request):
if self.consumes[0] in FORM_CONTENT_TYPES:
data = dict(request.form.items()) or (
request.body if len(request.body) > 0 else {}
)
data.update(
dict.fromkeys(request.files, "")
) # validator expects string..
logger.debug("%s validating schema...", request.url)

if self.strict_validation:
formdata_errors = self.validate_formdata_parameter_list(request)
if formdata_errors:
raise ExtraParameterProblem(formdata_errors, [])

if data:
props = self.schema.get("properties", {})
errs = []
for k, param_defn in props.items():
if k in data:
try:
data[k] = coerce_type(
param_defn, data[k], "requestBody", k
)
except TypeValidationError as e:
errs += [str(e)]
print(errs)
if errs:
raise BadRequestProblem(detail=errs)

self.validate_schema(data, request.url)

response = function(request)
return response

return wrapper

@classmethod
def _error_path_message(cls, exception):
error_path = ".".join(str(item) for item in exception.path)
error_path_msg = f" - '{error_path}'" if error_path else ""
return error_path_msg

def validate_schema(self, data: dict, url: str) -> t.Optional[ConnexionResponse]:
if self.is_null_value_valid and is_null(data):
return None

try:
self.validator.validate(data)
except ValidationError as exception:
error_path_msg = self._error_path_message(exception=exception)
logger.error(
f"{str(url)} validation error: {exception.message}{error_path_msg}",
extra={"validator": "body"},
)
raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}")

return None


class ParameterValidator:
def __init__(self, parameters, api, strict_validation=False):
"""
Expand Down Expand Up @@ -231,20 +125,9 @@ def validate_parameter(parameter_type, value, param, param_name=None):
if "required" in param:
del param["required"]
try:
if parameter_type == "formdata" and param.get("type") == "file":
extend(
Draft4Validator,
type_checker=Draft4Validator.TYPE_CHECKER.redefine(
"file",
lambda checker, instance: isinstance(instance, FileStorage),
),
)(param, format_checker=draft4_format_checker).validate(
converted_value
)
else:
Draft4Validator(
param, format_checker=draft4_format_checker
).validate(converted_value)
Draft4Validator(param, format_checker=draft4_format_checker).validate(
converted_value
)
except ValidationError as exception:
debug_msg = (
"Error while converting value {converted_value} from param "
Expand All @@ -269,14 +152,6 @@ def validate_query_parameter_list(self, request):
spec_params = [x["name"] for x in self.parameters.get("query", [])]
return validate_parameter_list(request_params, spec_params)

def validate_formdata_parameter_list(self, request):
request_params = request.form.keys()
if "formData" in self.parameters: # Swagger 2:
spec_params = [x["name"] for x in self.parameters["formData"]]
else: # OAS 3
return set()
return validate_parameter_list(request_params, spec_params)

def validate_query_parameter(self, param, request):
"""
Validate a single query parameter (request.args in Flask)
Expand All @@ -299,14 +174,6 @@ def validate_cookie_parameter(self, param, request):
val = request.cookies.get(param["name"])
return self.validate_parameter("cookie", val, param)

def validate_formdata_parameter(self, param_name, param, request):
if param.get("type") == "file" or param.get("format") == "binary":
val = request.files.get(param_name)
else:
val = request.form.get(param_name)

return self.validate_parameter("formdata", val, param)

def __call__(self, function):
"""
:type function: types.FunctionType
Expand All @@ -319,10 +186,9 @@ def wrapper(request):

if self.strict_validation:
query_errors = self.validate_query_parameter_list(request)
formdata_errors = self.validate_formdata_parameter_list(request)

if formdata_errors or query_errors:
raise ExtraParameterProblem(formdata_errors, query_errors)
if query_errors:
raise ExtraParameterProblem([], query_errors)

for param in self.parameters.get("query", []):
error = self.validate_query_parameter(param, request)
Expand All @@ -344,11 +210,6 @@ def wrapper(request):
if error:
raise BadRequestProblem(detail=error)

for param in self.parameters.get("formData", []):
error = self.validate_formdata_parameter(param["name"], param, request)
if error:
raise BadRequestProblem(detail=error)

return function(request)

return wrapper
47 changes: 29 additions & 18 deletions connexion/middleware/request_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from starlette.types import ASGIApp, Receive, Scope, Send

from connexion import utils
from connexion.datastructures import MediaTypeDict
from connexion.decorators.uri_parsing import AbstractURIParser
from connexion.exceptions import UnsupportedMediaTypeProblem
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
Expand All @@ -29,7 +30,7 @@ def __init__(
self.next_app = next_app
self._operation = operation
self.strict_validation = strict_validation
self._validator_map = VALIDATOR_MAP
self._validator_map = VALIDATOR_MAP.copy()
self._validator_map.update(validator_map or {})
self.uri_parser_class = uri_parser_class

Expand Down Expand Up @@ -59,7 +60,11 @@ def validate_mime_type(self, mime_type: str) -> None:

:param mime_type: mime type from content type header
"""
if mime_type.lower() not in [c.lower() for c in self._operation.consumes]:
# Convert to MediaTypeDict to handle media-ranges
media_type_dict = MediaTypeDict(
[(c.lower(), None) for c in self._operation.consumes]
)
if mime_type.lower() not in media_type_dict:
raise UnsupportedMediaTypeProblem(
detail=f"Invalid Content-type ({mime_type}), "
f"expected {self._operation.consumes}"
Expand All @@ -75,22 +80,28 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
# TODO: Validate parameters

# Validate body
try:
body_validator = self._validator_map["body"][mime_type] # type: ignore
except KeyError:
logging.info(
f"Skipping validation. No validator registered for content type: "
f"{mime_type}."
)
else:
validator = body_validator(
scope,
receive,
schema=self._operation.body_schema,
nullable=utils.is_nullable(self._operation.body_definition),
encoding=encoding,
)
receive_fn = validator.receive
schema = self._operation.body_schema(mime_type)
if schema:
try:
body_validator = self._validator_map["body"][mime_type] # type: ignore
except KeyError:
logging.info(
f"Skipping validation. No validator registered for content type: "
f"{mime_type}."
)
else:
validator = body_validator(
scope,
receive,
schema=schema,
nullable=utils.is_nullable(
self._operation.body_definition(mime_type)
),
encoding=encoding,
strict_validation=self.strict_validation,
uri_parser=self._operation._uri_parsing_decorator,
)
receive_fn = await validator.wrapped_receive()

await self.next_app(scope, receive_fn, send)

Expand Down
2 changes: 1 addition & 1 deletion connexion/middleware/response_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
) -> None:
self.next_app = next_app
self._operation = operation
self._validator_map = VALIDATOR_MAP
self._validator_map = VALIDATOR_MAP.copy()
self._validator_map.update(validator_map or {})

def extract_content_type(
Expand Down
22 changes: 5 additions & 17 deletions connexion/operations/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@
from ..decorators.decorator import RequestResponseDecorator
from ..decorators.parameter import parameter_to_arg
from ..decorators.produces import BaseSerializer, Produces
from ..decorators.validation import ParameterValidator, RequestBodyValidator
from ..utils import all_json, is_nullable
from ..decorators.validation import ParameterValidator
from ..utils import all_json

logger = logging.getLogger("connexion.operations.abstract")

DEFAULT_MIMETYPE = "application/json"

VALIDATOR_MAP = {
"parameter": ParameterValidator,
"body": RequestBodyValidator,
}


Expand Down Expand Up @@ -281,16 +280,14 @@ def consumes(self):
Content-Types that the operation consumes
"""

@property
@abc.abstractmethod
def body_schema(self):
def body_schema(self, content_type: str = None) -> dict:
"""
The body schema definition for this operation.
"""

@property
@abc.abstractmethod
def body_definition(self):
def body_definition(self, content_type: str = None) -> dict:
"""
The body definition for this operation.
:rtype: dict
Expand Down Expand Up @@ -372,7 +369,7 @@ def _uri_parsing_decorator(self):
Returns a decorator that parses request data and handles things like
array types, and duplicate parameter definitions.
"""
return self._uri_parser_class(self.parameters, self.body_definition)
return self._uri_parser_class(self.parameters, self.body_definition())

@property
def function(self):
Expand Down Expand Up @@ -455,15 +452,6 @@ def __validation_decorators(self):
yield ParameterValidator(
self.parameters, self.api, strict_validation=self.strict_validation
)
if self.body_schema:
# TODO: temporarily hardcoded, remove RequestBodyValidator completely
yield RequestBodyValidator(
self.body_schema,
self.consumes,
self.api,
is_nullable(self.body_definition),
strict_validation=self.strict_validation,
)

def json_loads(self, data):
"""
Expand Down
Loading