From 0dc676f5ac08c8eecad79a77e1cd8c68973e7f0f Mon Sep 17 00:00:00 2001 From: Antoine Merino Date: Fri, 14 Jun 2024 13:17:48 +0200 Subject: [PATCH] Add backward compatibility with Pydantic V1 base models --- flask_pydantic/converters.py | 33 +- flask_pydantic/core.py | 78 ++-- tests/pydantic_v1/__init__.py | 0 tests/pydantic_v1/conftest.py | 56 +++ tests/pydantic_v1/func/__init__.py | 0 tests/pydantic_v1/func/test_app.py | 450 +++++++++++++++++++++ tests/pydantic_v1/unit/__init__.py | 0 tests/pydantic_v1/unit/test_core.py | 587 ++++++++++++++++++++++++++++ 8 files changed, 1172 insertions(+), 32 deletions(-) create mode 100644 tests/pydantic_v1/__init__.py create mode 100644 tests/pydantic_v1/conftest.py create mode 100644 tests/pydantic_v1/func/__init__.py create mode 100644 tests/pydantic_v1/func/test_app.py create mode 100644 tests/pydantic_v1/unit/__init__.py create mode 100644 tests/pydantic_v1/unit/test_core.py diff --git a/flask_pydantic/converters.py b/flask_pydantic/converters.py index fa7e24a..5335f12 100644 --- a/flask_pydantic/converters.py +++ b/flask_pydantic/converters.py @@ -6,8 +6,11 @@ from typing_extensions import get_args, get_origin from pydantic import BaseModel +from pydantic.v1 import BaseModel as V1BaseModel from werkzeug.datastructures import ImmutableMultiDict +V1OrV2BaseModel = Union[BaseModel, V1BaseModel] + def _is_list(type_: Type) -> bool: origin = get_origin(type_) @@ -19,7 +22,7 @@ def _is_list(type_: Type) -> bool: def convert_query_params( - query_params: ImmutableMultiDict, model: Type[BaseModel] + query_params: ImmutableMultiDict, model: Type[V1OrV2BaseModel] ) -> dict: """ group query parameters into lists if model defines them @@ -28,12 +31,22 @@ def convert_query_params( :param model: query parameter's model :return: resulting parameters """ - return { - **query_params.to_dict(), - **{ - key: value - for key, value in query_params.to_dict(flat=False).items() - if key in model.model_fields - and _is_list(model.model_fields[key].annotation) - }, - } + if issubclass(model, BaseModel): + return { + **query_params.to_dict(), + **{ + key: value + for key, value in query_params.to_dict(flat=False).items() + if key in model.model_fields + and _is_list(model.model_fields[key].annotation) + }, + } + else: + return { + **query_params.to_dict(), + **{ + key: value + for key, value in query_params.to_dict(flat=False).items() + if key in model.__fields__ and model.__fields__[key].is_complex() + }, + } diff --git a/flask_pydantic/core.py b/flask_pydantic/core.py index 68bbfe0..f3e6fe8 100644 --- a/flask_pydantic/core.py +++ b/flask_pydantic/core.py @@ -3,7 +3,9 @@ from flask import Response, current_app, jsonify, make_response, request from pydantic import BaseModel, ValidationError, TypeAdapter, RootModel - +from pydantic.v1 import BaseModel as V1BaseModel +from pydantic.v1.error_wrappers import ValidationError as V1ValidationError +from pydantic.v1.tools import parse_obj_as from .converters import convert_query_params from .exceptions import ( InvalidIterableOfModelsException, @@ -17,9 +19,19 @@ except ImportError: pass +V1OrV2BaseModel = Union[BaseModel, V1BaseModel] + + +def _model_dump_json(model: V1OrV2BaseModel, **kwargs): + """Adapter to dump a model to json, whether it's a Pydantic V1 or V2 model.""" + if isinstance(model, BaseModel): + return model.model_dump_json(**kwargs) + else: + return model.json(**kwargs) + def make_json_response( - content: Union[BaseModel, Iterable[BaseModel]], + content: Union[V1OrV2BaseModel, Iterable[V1OrV2BaseModel]], status_code: int, by_alias: bool, exclude_none: bool = False, @@ -27,9 +39,9 @@ def make_json_response( ) -> Response: """serializes model, creates JSON response with given status code""" if many: - js = f"[{', '.join([model.model_dump_json(exclude_none=exclude_none, by_alias=by_alias) for model in content])}]" + js = f"[{', '.join([_model_dump_json(model, exclude_none=exclude_none, by_alias=by_alias) for model in content])}]" else: - js = content.model_dump_json(exclude_none=exclude_none, by_alias=by_alias) + js = _model_dump_json(content, exclude_none=exclude_none, by_alias=by_alias) response = make_response(js, status_code) response.mimetype = "application/json" return response @@ -45,12 +57,14 @@ def unsupported_media_type_response(request_cont_type: str) -> Response: def is_iterable_of_models(content: Any) -> bool: try: - return all(isinstance(obj, BaseModel) for obj in content) + return all(isinstance(obj, V1OrV2BaseModel) for obj in content) except TypeError: return False -def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel]: +def validate_many_models( + model: Type[V1OrV2BaseModel], content: Any +) -> List[V1OrV2BaseModel]: try: return [model(**fields) for fields in content] except TypeError: @@ -63,7 +77,7 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel } ] raise ManyModelValidationError(err) - except ValidationError as ve: + except (ValidationError, V1ValidationError) as ve: raise ManyModelValidationError(ve.errors()) @@ -74,9 +88,13 @@ def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]: if name in {"query", "body", "form", "return"}: continue try: - adapter = TypeAdapter(type_) - validated[name] = adapter.validate_python(kwargs.get(name)) - except ValidationError as e: + if not isinstance(type_, V1BaseModel): + adapter = TypeAdapter(type_) + validated[name] = adapter.validate_python(kwargs.get(name)) + else: + value = parse_obj_as(type_, kwargs.get(name)) + validated[name] = value + except (ValidationError, V1ValidationError) as e: err = e.errors()[0] err["loc"] = [name] errors.append(err) @@ -92,15 +110,15 @@ def get_body_dict(**params): def validate( - body: Optional[Type[BaseModel]] = None, - query: Optional[Type[BaseModel]] = None, + body: Optional[Type[V1OrV2BaseModel]] = None, + query: Optional[Type[V1OrV2BaseModel]] = None, on_success_status: int = 200, exclude_none: bool = False, response_many: bool = False, request_body_many: bool = False, response_by_alias: bool = False, get_json_params: Optional[dict] = None, - form: Optional[Type[BaseModel]] = None, + form: Optional[Type[V1OrV2BaseModel]] = None, ): """ Decorator for route methods which will validate query, body and form parameters @@ -175,16 +193,24 @@ def wrapper(*args, **kwargs): query_params = convert_query_params(request.args, query_model) try: q = query_model(**query_params) - except ValidationError as ve: + except (ValidationError, V1ValidationError) as ve: err["query_params"] = ve.errors() body_in_kwargs = func.__annotations__.get("body") body_model = body_in_kwargs or body if body_model: body_params = get_body_dict(**(get_json_params or {})) - if issubclass(body_model, RootModel): + if ( + issubclass(body_model, V1BaseModel) + and "__root__" in body_model.__fields__ + ): + try: + b = body_model(__root__=body_params).__root__ + except (ValidationError, V1ValidationError) as ve: + err["body_params"] = ve.errors() + elif issubclass(body_model, RootModel): try: b = body_model(body_params) - except ValidationError as ve: + except (ValidationError, V1ValidationError) as ve: err["body_params"] = ve.errors() elif request_body_many: try: @@ -201,16 +227,24 @@ def wrapper(*args, **kwargs): return unsupported_media_type_response(content_type) else: raise JsonBodyParsingError() - except ValidationError as ve: + except (ValidationError, V1ValidationError) as ve: err["body_params"] = ve.errors() form_in_kwargs = func.__annotations__.get("form") form_model = form_in_kwargs or form if form_model: form_params = request.form - if issubclass(form_model, RootModel): + if ( + isinstance(form, V1BaseModel) + and "__root__" in form_model.__fields__ + ): + try: + f = form_model(form_params) + except (ValidationError, V1ValidationError) as ve: + err["form_params"] = ve.errors() + elif issubclass(form_model, RootModel): try: f = form_model(form_params) - except ValidationError as ve: + except (ValidationError, V1ValidationError) as ve: err["form_params"] = ve.errors() else: try: @@ -222,7 +256,7 @@ def wrapper(*args, **kwargs): return unsupported_media_type_response(content_type) else: raise JsonBodyParsingError - except ValidationError as ve: + except (ValidationError, V1ValidationError) as ve: err["form_params"] = ve.errors() request.query_params = q request.body_params = b @@ -260,7 +294,7 @@ def wrapper(*args, **kwargs): else: raise InvalidIterableOfModelsException(res) - if isinstance(res, BaseModel): + if isinstance(res, V1OrV2BaseModel): return make_json_response( res, on_success_status, @@ -271,7 +305,7 @@ def wrapper(*args, **kwargs): if ( isinstance(res, tuple) and len(res) in [2, 3] - and isinstance(res[0], BaseModel) + and isinstance(res[0], V1OrV2BaseModel) ): headers = None status = on_success_status diff --git a/tests/pydantic_v1/__init__.py b/tests/pydantic_v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/pydantic_v1/conftest.py b/tests/pydantic_v1/conftest.py new file mode 100644 index 0000000..af4a502 --- /dev/null +++ b/tests/pydantic_v1/conftest.py @@ -0,0 +1,56 @@ +""" +Specific confest.py file for testing behavior with Pydantic V1. + +The fixtures below override the confest.py's fixtures for this module only. +""" +from typing import List, Optional, Type + +import pytest + +from pydantic.v1 import BaseModel + + +@pytest.fixture +def query_model() -> Type[BaseModel]: + class Query(BaseModel): + limit: int = 2 + min_views: Optional[int] = None + + return Query + + +@pytest.fixture +def body_model() -> Type[BaseModel]: + class Body(BaseModel): + search_term: str + exclude: Optional[str] = None + + return Body + + +@pytest.fixture +def form_model() -> Type[BaseModel]: + class Form(BaseModel): + search_term: str + exclude: Optional[str] = None + + return Form + + +@pytest.fixture +def post_model() -> Type[BaseModel]: + class Post(BaseModel): + title: str + text: str + views: int + + return Post + + +@pytest.fixture +def response_model(post_model: BaseModel) -> Type[BaseModel]: + class Response(BaseModel): + results: List[post_model] + count: int + + return Response diff --git a/tests/pydantic_v1/func/__init__.py b/tests/pydantic_v1/func/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/pydantic_v1/func/test_app.py b/tests/pydantic_v1/func/test_app.py new file mode 100644 index 0000000..a8ec201 --- /dev/null +++ b/tests/pydantic_v1/func/test_app.py @@ -0,0 +1,450 @@ +from ...util import assert_matches +import re +from typing import List, Optional + +import pytest +from flask import jsonify, request +from flask_pydantic import validate, ValidationError +from pydantic.v1 import BaseModel + + +class ArrayModel(BaseModel): + arr1: List[str] + arr2: Optional[List[int]] = None + + +@pytest.fixture +def app_with_array_route(app): + @app.route("/arr", methods=["GET"]) + @validate(query=ArrayModel, exclude_none=True) + def pass_array(): + return ArrayModel( + arr1=request.query_params.arr1, arr2=request.query_params.arr2 + ) + + +@pytest.fixture +def app_with_optional_body(app): + class Body(BaseModel): + param: str + + @app.route("/no_params", methods=["POST"]) + @validate() + def no_params(body: Body): + return body + + @app.route("/silent", methods=["POST"]) + @validate(get_json_params={"silent": True}) + def silent(body: Body): + return body + + +@pytest.fixture +def app_raise_on_validation_error(app): + app.config["FLASK_PYDANTIC_VALIDATION_ERROR_RAISE"] = True + + def validation_error(error: ValidationError): + return ( + jsonify( + { + "title": "validation error", + "body": error.body_params, + } + ), + 422, + ) + + app.register_error_handler(ValidationError, validation_error) + + class Body(BaseModel): + param: str + + @app.route("/silent", methods=["POST"]) + @validate(get_json_params={"silent": True}) + def silent(body: Body): + return body + + +@pytest.fixture +def app_with_int_path_param_route(app): + class IdObj(BaseModel): + id: int + + @app.route("/path_param//", methods=["GET"]) + @validate() + def int_path_param(obj_id: int): + return IdObj(id=obj_id) + + +@pytest.fixture +def app_with_untyped_path_param_route(app): + class IdObj(BaseModel): + id: str + + @app.route("/path_param//", methods=["GET"]) + @validate() + def int_path_param(obj_id): + return IdObj(id=obj_id) + + +@pytest.fixture +def app_with_custom_root_type(app): + class Person(BaseModel): + name: str + age: Optional[int] = None + + class PersonBulk(BaseModel): + __root__: List[Person] + + def __len__(self): + return len(self.root) + + @app.route("/root_type", methods=["POST"]) + @validate() + def root_type(body: PersonBulk): + return {"number": len(body)} + + +@pytest.fixture +def app_with_custom_headers(app): + @app.route("/custom_headers", methods=["GET"]) + @validate() + def custom_headers(): + return {"test": 1}, {"CUSTOM_HEADER": "UNIQUE"} + + +@pytest.fixture +def app_with_custom_headers_status(app): + @app.route("/custom_headers_status", methods=["GET"]) + @validate() + def custom_headers(): + return {"test": 1}, 201, {"CUSTOM_HEADER": "UNIQUE"} + + +@pytest.fixture +def app_with_camel_route(app): + def to_camel(x: str) -> str: + first, *rest = x.split("_") + return "".join([first] + [x.capitalize() for x in rest]) + + class RequestModel(BaseModel): + x: int + y: int + + class ResultModel(BaseModel): + result_of_addition: int + result_of_multiplication: int + + class Config: + alias_generator = to_camel + allow_population_by_field_name = True + + @app.route("/compute", methods=["GET"]) + @validate(response_by_alias=True) + def compute(query: RequestModel): + return ResultModel( + result_of_addition=query.x + query.y, + result_of_multiplication=query.x * query.y, + ) + + +test_cases = [ + pytest.param( + "?limit=limit", + {"search_term": "text"}, + 400, + { + "validation_error": { + "query_params": [ + { + "loc": ["limit"], + "msg": "value is not a valid integer", + "type": "type_error.integer", + } + ] + } + }, + id="invalid limit", + ), + pytest.param( + "?limit=2", + {}, + 400, + { + "validation_error": { + "body_params": [ + { + "loc": ["search_term"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } + }, + id="missing required body parameter", + ), + pytest.param( + "?limit=1&min_views=2", + {"search_term": "text"}, + 200, + {"count": 2, "results": [{"title": "2", "text": "another text", "views": 2}]}, + id="valid parameters", + ), + pytest.param( + "", + {"search_term": "text"}, + 200, + { + "count": 3, + "results": [ + {"title": "title 1", "text": "random text", "views": 1}, + {"title": "2", "text": "another text", "views": 2}, + ], + }, + id="valid params, no query", + ), +] + +form_test_cases = [ + pytest.param( + "?limit=2", + {}, + 400, + { + "validation_error": { + "form_params": [ + { + "loc": ["search_term"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } + }, + id="missing required form parameter", + ), + pytest.param( + "?limit=1&min_views=2", + {"search_term": "text"}, + 200, + {"count": 2, "results": [{"title": "2", "text": "another text", "views": 2}]}, + id="valid parameters", + ), + pytest.param( + "", + {"search_term": "text"}, + 200, + { + "count": 3, + "results": [ + {"title": "title 1", "text": "random text", "views": 1}, + {"title": "2", "text": "another text", "views": 2}, + ], + }, + id="valid params, no query", + ), +] + + +class TestSimple: + @pytest.mark.parametrize("query,body,expected_status,expected_response", test_cases) + def test_post(self, client, query, body, expected_status, expected_response): + response = client.post(f"/search{query}", json=body) + assert_matches(expected_response, response.json) + assert response.status_code == expected_status + + @pytest.mark.parametrize("query,body,expected_status,expected_response", test_cases) + def test_post_kwargs(self, client, query, body, expected_status, expected_response): + response = client.post(f"/search/kwargs{query}", json=body) + assert_matches(expected_response, response.json) + assert response.status_code == expected_status + + @pytest.mark.parametrize( + "query,form,expected_status,expected_response", form_test_cases + ) + def test_post_kwargs_form( + self, client, query, form, expected_status, expected_response + ): + response = client.post( + f"/search/form/kwargs{query}", + data=form, + ) + assert_matches(expected_response, response.json) + assert response.status_code == expected_status + + def test_error_status_code(self, app, mocker, client): + mocker.patch.dict( + app.config, {"FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE": 422} + ) + response = client.post("/search?limit=2", json={}) + assert response.status_code == 422 + + +@pytest.mark.usefixtures("app_with_custom_root_type") +def test_custom_root_types(client): + response = client.post( + "/root_type", + json=[{"name": "Joshua Bardwell", "age": 46}, {"name": "Andrew Cambden"}], + ) + assert response.json == {"number": 2} + + +@pytest.mark.usefixtures("app_with_custom_headers") +def test_custom_headers(client): + response = client.get("/custom_headers") + assert response.json == {"test": 1} + assert response.status_code == 200 + assert response.headers.get("CUSTOM_HEADER") == "UNIQUE" + + +@pytest.mark.usefixtures("app_with_custom_headers_status") +def test_custom_headers(client): + response = client.get("/custom_headers_status") + assert response.json == {"test": 1} + assert response.status_code == 201 + assert response.headers.get("CUSTOM_HEADER") == "UNIQUE" + + +@pytest.mark.usefixtures("app_with_array_route") +class TestArrayQueryParam: + def test_no_param_raises(self, client): + response = client.get("/arr") + assert_matches( + { + "validation_error": { + "query_params": [ + { + "loc": ["arr1"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } + }, + response.json, + ) + + def test_correctly_returns_first_arr(self, client): + response = client.get("/arr?arr1=first&arr1=second") + assert response.json == {"arr1": ["first", "second"]} + + def test_correctly_returns_first_arr_one_element(self, client): + response = client.get("/arr?arr1=first") + assert response.json == {"arr1": ["first"]} + + def test_correctly_returns_both_arrays(self, client): + response = client.get("/arr?arr1=first&arr1=second&arr2=1&arr2=10") + assert response.json == {"arr1": ["first", "second"], "arr2": [1, 10]} + + +aliases_test_cases = [ + pytest.param(1, 2, {"resultOfMultiplication": 2, "resultOfAddition": 3}), + pytest.param(10, 20, {"resultOfMultiplication": 200, "resultOfAddition": 30}), + pytest.param(999, 0, {"resultOfMultiplication": 0, "resultOfAddition": 999}), +] + + +@pytest.mark.usefixtures("app_with_camel_route") +@pytest.mark.parametrize("x,y,expected_result", aliases_test_cases) +def test_aliases(x, y, expected_result, client): + response = client.get(f"/compute?x={x}&y={y}") + assert_matches(expected_result, response.json) + + +@pytest.mark.usefixtures("app_with_int_path_param_route") +class TestPathIntParameter: + def test_correct_param_passes(self, client): + id_ = 12 + expected_response = {"id": id_} + response = client.get(f"/path_param/{id_}/") + assert_matches(expected_response, response.json) + + def test_string_parameter(self, client): + expected_response = { + "validation_error": { + "path_params": [ + { + "input": "not_an_int", + "loc": ["obj_id"], + "msg": "Input should be a valid integer, unable to parse string as an integer", + "type": "int_parsing", + "url": re.compile( + r"https://errors\.pydantic\.dev/.*/v/int_parsing" + ), + } + ] + } + } + response = client.get("/path_param/not_an_int/") + + assert_matches(expected_response, response.json) + assert response.status_code == 400 + + +@pytest.mark.usefixtures("app_with_untyped_path_param_route") +class TestPathUnannotatedParameter: + def test_int_str_param_passes(self, client): + id_ = 12 + expected_response = {"id": str(id_)} + response = client.get(f"/path_param/{id_}/") + + assert_matches(expected_response, response.json) + + def test_str_param_passes(self, client): + id_ = "twelve" + expected_response = {"id": id_} + response = client.get(f"/path_param/{id_}/") + + assert_matches(expected_response, response.json) + + +@pytest.mark.usefixtures("app_with_optional_body") +class TestGetJsonParams: + def test_empty_body_fails(self, client): + response = client.post( + "/no_params", headers={"Content-Type": "application/json"} + ) + + assert response.status_code == 400 + assert ( + "failed to decode json object: expecting value: line 1 column 1 (char 0)" + in response.text.lower() + ) + + def test_silent(self, client): + response = client.post("/silent", headers={"Content-Type": "application/json"}) + + assert_matches( + { + "validation_error": { + "body_params": [ + { + "loc": ["param"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } + }, + response.json, + ) + assert response.status_code == 400 + + +@pytest.mark.usefixtures("app_raise_on_validation_error") +class TestCustomResponse: + def test_silent(self, client): + response = client.post("/silent", headers={"Content-Type": "application/json"}) + + assert response.json["title"] == "validation error" + assert_matches( + [ + { + "loc": ["param"], + "msg": "field required", + "type": "value_error.missing", + } + ], + response.json["body"], + ) + assert response.status_code == 422 diff --git a/tests/pydantic_v1/unit/__init__.py b/tests/pydantic_v1/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/pydantic_v1/unit/test_core.py b/tests/pydantic_v1/unit/test_core.py new file mode 100644 index 0000000..074ad7c --- /dev/null +++ b/tests/pydantic_v1/unit/test_core.py @@ -0,0 +1,587 @@ +from typing import Any, List, NamedTuple, Optional, Type, Union +from ...util import assert_matches + +import pytest +from flask import jsonify +from flask_pydantic import validate, ValidationError +from flask_pydantic.core import convert_query_params, is_iterable_of_models +from flask_pydantic.exceptions import ( + InvalidIterableOfModelsException, + JsonBodyParsingError, +) +from pydantic.v1 import BaseModel +from werkzeug.datastructures import ImmutableMultiDict + + +class ValidateParams(NamedTuple): + body_model: Optional[Type[BaseModel]] = None + query_model: Optional[Type[BaseModel]] = None + form_model: Optional[Type[BaseModel]] = None + response_model: Type[BaseModel] = None + on_success_status: int = 200 + request_query: ImmutableMultiDict = ImmutableMultiDict({}) + request_body: Union[dict, List[dict]] = {} + request_form: ImmutableMultiDict = ImmutableMultiDict({}) + expected_response_body: Optional[dict] = None + expected_status_code: int = 200 + exclude_none: bool = False + response_many: bool = False + request_body_many: bool = False + + +class ResponseModel(BaseModel): + q1: int + q2: str + b1: float + b2: Optional[str] = None + + +class QueryModel(BaseModel): + q1: int + q2: str = "default" + + +class RequestBodyModel(BaseModel): + b1: float + b2: Optional[str] = None + + +class FormModel(BaseModel): + f1: int + f2: str = None + + +class RequestBodyModelRoot(BaseModel): + __root__: Union[str, RequestBodyModel] + + +validate_test_cases = [ + pytest.param( + ValidateParams( + request_body={"b1": 1.4}, + request_query=ImmutableMultiDict({"q1": 1}), + request_form=ImmutableMultiDict({"f1": 1}), + form_model=FormModel, + expected_response_body={"q1": 1, "q2": "default", "b1": 1.4, "b2": None}, + response_model=ResponseModel, + query_model=QueryModel, + body_model=RequestBodyModel, + ), + id="simple valid example with default values", + ), + pytest.param( + ValidateParams( + request_body={"b1": 1.4}, + request_query=ImmutableMultiDict({"q1": 1}), + request_form=ImmutableMultiDict({"f1": 1}), + form_model=FormModel, + expected_response_body={"q1": 1, "q2": "default", "b1": 1.4}, + response_model=ResponseModel, + query_model=QueryModel, + body_model=RequestBodyModel, + exclude_none=True, + ), + id="simple valid example with default values, exclude none", + ), + pytest.param( + ValidateParams( + query_model=QueryModel, + expected_response_body={ + "validation_error": { + "query_params": [ + { + "loc": ["q1"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } + }, + expected_status_code=400, + ), + id="invalid query param", + ), + pytest.param( + ValidateParams( + body_model=RequestBodyModel, + expected_response_body={ + "validation_error": { + "body_params": [ + { + "loc": ["root"], + "msg": "is not an array of objects", + "type": "type_error.array", + } + ] + } + }, + request_body={"b1": 3.14, "b2": "str"}, + expected_status_code=400, + request_body_many=True, + ), + id="`request_body_many=True` but in request body is a single object", + ), + pytest.param( + ValidateParams( + expected_response_body={ + "validation_error": { + "body_params": [ + { + "loc": ["b1"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } + }, + body_model=RequestBodyModel, + expected_status_code=400, + ), + id="invalid body param", + ), + pytest.param( + ValidateParams( + expected_response_body={ + "validation_error": { + "body_params": [ + { + "loc": ["b1"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } + }, + body_model=RequestBodyModel, + expected_status_code=400, + request_body=[{}], + request_body_many=True, + ), + id="invalid body param in many-object request body", + ), + pytest.param( + ValidateParams( + form_model=FormModel, + expected_response_body={ + "validation_error": { + "form_params": [ + { + "loc": ["f1"], + "msg": "field required", + "type": "value_error.missing", + } + ] + } + }, + expected_status_code=400, + ), + id="invalid form param", + ), +] + + +class TestValidate: + @pytest.mark.parametrize("parameters", validate_test_cases) + def test_validate(self, mocker, request_ctx, parameters: ValidateParams): + mock_request = mocker.patch.object(request_ctx, "request") + mock_request.args = parameters.request_query + mock_request.get_json = lambda: parameters.request_body + mock_request.form = parameters.request_form + + def f(): + body = {} + query = {} + if mock_request.form_params: + body = mock_request.form_params.dict() + if mock_request.body_params: + body = mock_request.body_params.dict() + if mock_request.query_params: + query = mock_request.query_params.dict() + return parameters.response_model(**body, **query) + + response = validate( + query=parameters.query_model, + body=parameters.body_model, + on_success_status=parameters.on_success_status, + exclude_none=parameters.exclude_none, + response_many=parameters.response_many, + request_body_many=parameters.request_body_many, + form=parameters.form_model, + )(f)() + + assert response.status_code == parameters.expected_status_code + assert_matches(parameters.expected_response_body, response.json) + if 200 <= response.status_code < 300: + assert ( + mock_request.body_params.dict(exclude_none=True, exclude_defaults=True) + == parameters.request_body + ) + assert ( + mock_request.query_params.dict(exclude_none=True, exclude_defaults=True) + == parameters.request_query.to_dict() + ) + + @pytest.mark.parametrize("parameters", validate_test_cases) + def test_validate_kwargs(self, mocker, request_ctx, parameters: ValidateParams): + mock_request = mocker.patch.object(request_ctx, "request") + mock_request.args = parameters.request_query + mock_request.get_json = lambda: parameters.request_body + mock_request.form = parameters.request_form + + def f( + body: parameters.body_model, + query: parameters.query_model, + form: parameters.form_model, + ): + return parameters.response_model( + **body.dict(), **query.dict(), **form.dict() + ) + + response = validate( + on_success_status=parameters.on_success_status, + exclude_none=parameters.exclude_none, + response_many=parameters.response_many, + request_body_many=parameters.request_body_many, + )(f)() + + assert_matches(parameters.expected_response_body, response.json) + assert response.status_code == parameters.expected_status_code + if 200 <= response.status_code < 300: + assert ( + mock_request.body_params.dict(exclude_none=True, exclude_defaults=True) + == parameters.request_body + ) + assert ( + mock_request.query_params.dict(exclude_none=True, exclude_defaults=True) + == parameters.request_query.to_dict() + ) + + @pytest.mark.usefixtures("request_ctx") + def test_response_with_status(self): + expected_status_code = 201 + expected_response_body = dict(q1=1, q2="2", b1=3.14, b2="b2") + + def f(): + return ResponseModel(q1=1, q2="2", b1=3.14, b2="b2"), expected_status_code + + response = validate()(f)() + assert response.status_code == expected_status_code + assert_matches(expected_response_body, response.json) + + @pytest.mark.usefixtures("request_ctx") + def test_response_already_response(self): + expected_response_body = {"a": 1, "b": 2} + + def f(): + return jsonify(expected_response_body) + + response = validate()(f)() + assert_matches(expected_response_body, response.json) + + @pytest.mark.usefixtures("request_ctx") + def test_response_many_response_objs(self): + response_content = [ + ResponseModel(q1=1, q2="2", b1=3.14, b2="b2"), + ResponseModel(q1=2, q2="3", b1=3.14), + ResponseModel(q1=3, q2="4", b1=6.9, b2="b4"), + ] + expected_response_body = [ + {"q1": 1, "q2": "2", "b1": 3.14, "b2": "b2"}, + {"q1": 2, "q2": "3", "b1": 3.14}, + {"q1": 3, "q2": "4", "b1": 6.9, "b2": "b4"}, + ] + + def f(): + return response_content + + response = validate(exclude_none=True, response_many=True)(f)() + assert_matches(expected_response_body, response.json) + + @pytest.mark.usefixtures("request_ctx") + def test_invalid_many_raises(self): + def f(): + return ResponseModel(q1=1, q2="2", b1=3.14, b2="b2") + + with pytest.raises(InvalidIterableOfModelsException): + validate(response_many=True)(f)() + + def test_valid_array_object_request_body(self, mocker, request_ctx): + mock_request = mocker.patch.object(request_ctx, "request") + mock_request.args = ImmutableMultiDict({"q1": 1}) + mock_request.get_json = lambda: [ + {"b1": 1.0, "b2": "str1"}, + {"b1": 2.0, "b2": "str2"}, + ] + expected_response_body = [ + {"q1": 1, "q2": "default", "b1": 1.0, "b2": "str1"}, + {"q1": 1, "q2": "default", "b1": 2.0, "b2": "str2"}, + ] + + def f(): + query_params = mock_request.query_params + body_params = mock_request.body_params + return [ + ResponseModel( + q1=query_params.q1, + q2=query_params.q2, + b1=obj.b1, + b2=obj.b2, + ) + for obj in body_params + ] + + response = validate( + query=QueryModel, + body=RequestBodyModel, + request_body_many=True, + response_many=True, + )(f)() + + assert response.status_code == 200 + assert_matches(expected_response_body, response.json) + + def test_unsupported_media_type(self, request_ctx, mocker): + mock_request = mocker.patch.object(request_ctx, "request") + content_type = "text/plain" + mock_request.headers = {"Content-Type": content_type} + mock_request.get_json = lambda: None + body_model = RequestBodyModel + response = validate(body_model)(lambda x: x)() + assert response.status_code == 415 + assert response.json == { + "detail": f"Unsupported media type '{content_type}' in request. " + "'application/json' is required." + } + + def test_invalid_body_model_root(self, request_ctx, mocker): + mock_request = mocker.patch.object(request_ctx, "request") + content_type = "application/json" + mock_request.headers = {"Content-Type": content_type} + mock_request.get_json = lambda: None + body_model = RequestBodyModelRoot + response = validate(body_model)(lambda x: x)() + assert response.status_code == 400 + + assert_matches( + { + "validation_error": { + "body_params": [ + { + "loc": ["__root__"], + "msg": "none is not an allowed value", + "type": "type_error.none.not_allowed", + } + ] + } + }, + response.json, + ) + + def test_damaged_request_body_json_with_charset(self, request_ctx, mocker): + mock_request = mocker.patch.object(request_ctx, "request") + content_type = "application/json;charset=utf-8" + mock_request.headers = {"Content-Type": content_type} + mock_request.get_json = lambda: None + body_model = RequestBodyModel + with pytest.raises(JsonBodyParsingError): + validate(body_model)(lambda x: x)() + + def test_damaged_request_body(self, request_ctx, mocker): + mock_request = mocker.patch.object(request_ctx, "request") + content_type = "application/json" + mock_request.headers = {"Content-Type": content_type} + mock_request.get_json = lambda: None + body_model = RequestBodyModel + with pytest.raises(JsonBodyParsingError): + validate(body_model)(lambda x: x)() + + @pytest.mark.parametrize("parameters", validate_test_cases) + def test_validate_func_having_return_type_annotation( + self, mocker, request_ctx, parameters: ValidateParams + ): + mock_request = mocker.patch.object(request_ctx, "request") + mock_request.args = parameters.request_query + mock_request.get_json = lambda: parameters.request_body + mock_request.form = parameters.request_form + + def f() -> Any: + body = {} + query = {} + if mock_request.form_params: + body = mock_request.form_params.dict() + if mock_request.body_params: + body = mock_request.body_params.dict() + if mock_request.query_params: + query = mock_request.query_params.dict() + return parameters.response_model(**body, **query) + + response = validate( + query=parameters.query_model, + body=parameters.body_model, + form=parameters.form_model, + on_success_status=parameters.on_success_status, + exclude_none=parameters.exclude_none, + response_many=parameters.response_many, + request_body_many=parameters.request_body_many, + )(f)() + + assert response.status_code == parameters.expected_status_code + assert_matches(parameters.expected_response_body, response.json) + if 200 <= response.status_code < 300: + assert ( + mock_request.body_params.dict(exclude_none=True, exclude_defaults=True) + == parameters.request_body + ) + assert ( + mock_request.query_params.dict(exclude_none=True, exclude_defaults=True) + == parameters.request_query.to_dict() + ) + + def test_fail_validation_custom_status_code(self, app, request_ctx, mocker): + app.config["FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE"] = 422 + mock_request = mocker.patch.object(request_ctx, "request") + content_type = "application/json" + mock_request.headers = {"Content-Type": content_type} + mock_request.get_json = lambda: None + body_model = RequestBodyModelRoot + response = validate(body_model)(lambda x: x)() + assert response.status_code == 422 + + assert_matches( + { + "validation_error": { + "body_params": [ + { + "loc": ["__root__"], + "msg": "none is not an allowed value", + "type": "type_error.none.not_allowed", + } + ] + } + }, + response.json, + ) + + def test_body_fail_validation_raise_exception(self, app, request_ctx, mocker): + app.config["FLASK_PYDANTIC_VALIDATION_ERROR_RAISE"] = True + mock_request = mocker.patch.object(request_ctx, "request") + content_type = "application/json" + mock_request.headers = {"Content-Type": content_type} + mock_request.get_json = lambda: None + body_model = RequestBodyModelRoot + with pytest.raises(ValidationError) as excinfo: + validate(body_model)(lambda x: x)() + assert_matches( + [ + { + "loc": ("__root__",), + "msg": "none is not an allowed value", + "type": "type_error.none.not_allowed", + } + ], + excinfo.value.body_params, + ) + + def test_query_fail_validation_raise_exception(self, app, request_ctx, mocker): + app.config["FLASK_PYDANTIC_VALIDATION_ERROR_RAISE"] = True + mock_request = mocker.patch.object(request_ctx, "request") + content_type = "application/json" + mock_request.headers = {"Content-Type": content_type} + mock_request.get_json = lambda: None + query_model = QueryModel + with pytest.raises(ValidationError) as excinfo: + validate(query=query_model)(lambda x: x)() + assert_matches( + [ + { + "loc": ("q1",), + "msg": "field required", + "type": "value_error.missing", + } + ], + excinfo.value.query_params, + ) + + def test_form_fail_validation_raise_exception(self, app, request_ctx, mocker): + app.config["FLASK_PYDANTIC_VALIDATION_ERROR_RAISE"] = True + mock_request = mocker.patch.object(request_ctx, "request") + content_type = "application/json" + mock_request.headers = {"Content-Type": content_type} + mock_request.get_json = lambda: None + form_model = FormModel + with pytest.raises(ValidationError) as excinfo: + validate(form=form_model)(lambda x: x)() + assert_matches( + [ + { + "loc": ("f1",), + "msg": "field required", + "type": "value_error.missing", + } + ], + excinfo.value.form_params, + ) + + +class TestIsIterableOfModels: + def test_simple_true_case(self): + models = [ + QueryModel(q1=1, q2="w"), + QueryModel(q1=2, q2="wsdf"), + RequestBodyModel(b1=3.1), + RequestBodyModel(b1=0.1), + ] + assert is_iterable_of_models(models) + + def test_false_for_non_iterable(self): + assert not is_iterable_of_models(1) + + def test_false_for_single_model(self): + assert not is_iterable_of_models(RequestBodyModel(b1=12)) + + +convert_query_params_test_cases = [ + pytest.param( + ImmutableMultiDict({"a": 1, "b": "b"}), {"a": 1, "b": "b"}, id="primitive types" + ), + pytest.param( + ImmutableMultiDict({"a": 1, "b": "b", "c": ["one"]}), + {"a": 1, "b": "b", "c": ["one"]}, + id="one element in array", + ), + pytest.param( + ImmutableMultiDict({"a": 1, "b": "b", "c": ["one"], "d": [1]}), + {"a": 1, "b": "b", "c": ["one"], "d": [1]}, + id="one element in arrays", + ), + pytest.param( + ImmutableMultiDict({"a": 1, "b": "b", "c": ["one"], "d": [1, 2, 3]}), + {"a": 1, "b": "b", "c": ["one"], "d": [1, 2, 3]}, + id="one element in array, multiple in the other", + ), + pytest.param( + ImmutableMultiDict({"a": 1, "b": "b", "c": ["one", "two", "three"]}), + {"a": 1, "b": "b", "c": ["one", "two", "three"]}, + id="multiple elements in array", + ), + pytest.param( + ImmutableMultiDict( + {"a": 1, "b": "b", "c": ["one", "two", "three"], "d": [1, 2, 3]} + ), + {"a": 1, "b": "b", "c": ["one", "two", "three"], "d": [1, 2, 3]}, + id="multiple in both arrays", + ), +] + + +@pytest.mark.parametrize( + "query_params,expected_result", convert_query_params_test_cases +) +def test_convert_query_params(query_params: ImmutableMultiDict, expected_result: dict): + class Model(BaseModel): + a: int + b: str + c: Optional[List[str]] + d: Optional[List[int]] + + assert convert_query_params(query_params, Model) == expected_result