Skip to content

Commit

Permalink
Add backward compatibility with Pydantic V1 base models
Browse files Browse the repository at this point in the history
  • Loading branch information
Merinorus committed Jul 13, 2024
1 parent 8595fa8 commit 4bff748
Show file tree
Hide file tree
Showing 8 changed files with 1,172 additions and 32 deletions.
33 changes: 23 additions & 10 deletions flask_pydantic/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
from typing_extensions import get_args, get_origin

from pydantic import BaseModel
from pydantic.v1 import BaseModel as V1BaseModel
from werkzeug.datastructures import ImmutableMultiDict

V1OrV2BaseModel = Union[BaseModel, V1BaseModel]


def _is_list(type_: Type) -> bool:
origin = get_origin(type_)
Expand All @@ -19,7 +22,7 @@ def _is_list(type_: Type) -> bool:


def convert_query_params(
query_params: ImmutableMultiDict, model: Type[BaseModel]
query_params: ImmutableMultiDict, model: Type[V1OrV2BaseModel]
) -> dict:
"""
group query parameters into lists if model defines them
Expand All @@ -28,12 +31,22 @@ def convert_query_params(
:param model: query parameter's model
:return: resulting parameters
"""
return {
**query_params.to_dict(),
**{
key: value
for key, value in query_params.to_dict(flat=False).items()
if key in model.model_fields
and _is_list(model.model_fields[key].annotation)
},
}
if issubclass(model, BaseModel):
return {
**query_params.to_dict(),
**{
key: value
for key, value in query_params.to_dict(flat=False).items()
if key in model.model_fields
and _is_list(model.model_fields[key].annotation)
},
}
else:
return {
**query_params.to_dict(),
**{
key: value
for key, value in query_params.to_dict(flat=False).items()
if key in model.__fields__ and model.__fields__[key].is_complex()
},
}
78 changes: 56 additions & 22 deletions flask_pydantic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from flask import Response, current_app, jsonify, make_response, request
from pydantic import BaseModel, ValidationError, TypeAdapter, RootModel

from pydantic.v1 import BaseModel as V1BaseModel
from pydantic.v1.error_wrappers import ValidationError as V1ValidationError
from pydantic.v1.tools import parse_obj_as
from .converters import convert_query_params
from .exceptions import (
InvalidIterableOfModelsException,
Expand All @@ -17,19 +19,29 @@
except ImportError:
pass

V1OrV2BaseModel = Union[BaseModel, V1BaseModel]


def _model_dump_json(model: V1OrV2BaseModel, **kwargs):
"""Adapter to dump a model to json, whether it's a Pydantic V1 or V2 model."""
if isinstance(model, BaseModel):
return model.model_dump_json(**kwargs)
else:
return model.json(**kwargs)


def make_json_response(
content: Union[BaseModel, Iterable[BaseModel]],
content: Union[V1OrV2BaseModel, Iterable[V1OrV2BaseModel]],
status_code: int,
by_alias: bool,
exclude_none: bool = False,
many: bool = False,
) -> Response:
"""serializes model, creates JSON response with given status code"""
if many:
js = f"[{', '.join([model.model_dump_json(exclude_none=exclude_none, by_alias=by_alias) for model in content])}]"
js = f"[{', '.join([_model_dump_json(model, exclude_none=exclude_none, by_alias=by_alias) for model in content])}]"
else:
js = content.model_dump_json(exclude_none=exclude_none, by_alias=by_alias)
js = _model_dump_json(content, exclude_none=exclude_none, by_alias=by_alias)
response = make_response(js, status_code)
response.mimetype = "application/json"
return response
Expand All @@ -45,12 +57,14 @@ def unsupported_media_type_response(request_cont_type: str) -> Response:

def is_iterable_of_models(content: Any) -> bool:
try:
return all(isinstance(obj, BaseModel) for obj in content)
return all(isinstance(obj, V1OrV2BaseModel) for obj in content)
except TypeError:
return False


def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel]:
def validate_many_models(
model: Type[V1OrV2BaseModel], content: Any
) -> List[V1OrV2BaseModel]:
try:
return [model(**fields) for fields in content]
except TypeError:
Expand All @@ -63,7 +77,7 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel
}
]
raise ManyModelValidationError(err)
except ValidationError as ve:
except (ValidationError, V1ValidationError) as ve:
raise ManyModelValidationError(ve.errors())


Expand All @@ -74,9 +88,13 @@ def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]:
if name in {"query", "body", "form", "return"}:
continue
try:
adapter = TypeAdapter(type_)
validated[name] = adapter.validate_python(kwargs.get(name))
except ValidationError as e:
if not isinstance(type_, V1BaseModel):
adapter = TypeAdapter(type_)
validated[name] = adapter.validate_python(kwargs.get(name))
else:
value = parse_obj_as(type_, kwargs.get(name))
validated[name] = value
except (ValidationError, V1ValidationError) as e:
err = e.errors()[0]
err["loc"] = [name]
errors.append(err)
Expand All @@ -92,15 +110,15 @@ def get_body_dict(**params):


def validate(
body: Optional[Type[BaseModel]] = None,
query: Optional[Type[BaseModel]] = None,
body: Optional[Type[V1OrV2BaseModel]] = None,
query: Optional[Type[V1OrV2BaseModel]] = None,
on_success_status: int = 200,
exclude_none: bool = False,
response_many: bool = False,
request_body_many: bool = False,
response_by_alias: bool = False,
get_json_params: Optional[dict] = None,
form: Optional[Type[BaseModel]] = None,
form: Optional[Type[V1OrV2BaseModel]] = None,
):
"""
Decorator for route methods which will validate query, body and form parameters
Expand Down Expand Up @@ -175,16 +193,24 @@ def wrapper(*args, **kwargs):
query_params = convert_query_params(request.args, query_model)
try:
q = query_model(**query_params)
except ValidationError as ve:
except (ValidationError, V1ValidationError) as ve:
err["query_params"] = ve.errors()
body_in_kwargs = func.__annotations__.get("body")
body_model = body_in_kwargs or body
if body_model:
body_params = get_body_dict(**(get_json_params or {}))
if issubclass(body_model, RootModel):
if (
issubclass(body_model, V1BaseModel)
and "__root__" in body_model.__fields__
):
try:
b = body_model(__root__=body_params).__root__
except (ValidationError, V1ValidationError) as ve:
err["body_params"] = ve.errors()
elif issubclass(body_model, RootModel):
try:
b = body_model(body_params)
except ValidationError as ve:
except (ValidationError, V1ValidationError) as ve:
err["body_params"] = ve.errors()
elif request_body_many:
try:
Expand All @@ -201,16 +227,24 @@ def wrapper(*args, **kwargs):
return unsupported_media_type_response(content_type)
else:
raise JsonBodyParsingError()
except ValidationError as ve:
except (ValidationError, V1ValidationError) as ve:
err["body_params"] = ve.errors()
form_in_kwargs = func.__annotations__.get("form")
form_model = form_in_kwargs or form
if form_model:
form_params = request.form
if issubclass(form_model, RootModel):
if (
isinstance(form, V1BaseModel)
and "__root__" in form_model.__fields__
):
try:
f = form_model(form_params)
except (ValidationError, V1ValidationError) as ve:
err["form_params"] = ve.errors()
elif issubclass(form_model, RootModel):
try:
f = form_model(form_params)
except ValidationError as ve:
except (ValidationError, V1ValidationError) as ve:
err["form_params"] = ve.errors()
else:
try:
Expand All @@ -222,7 +256,7 @@ def wrapper(*args, **kwargs):
return unsupported_media_type_response(content_type)
else:
raise JsonBodyParsingError
except ValidationError as ve:
except (ValidationError, V1ValidationError) as ve:
err["form_params"] = ve.errors()
request.query_params = q
request.body_params = b
Expand Down Expand Up @@ -260,7 +294,7 @@ def wrapper(*args, **kwargs):
else:
raise InvalidIterableOfModelsException(res)

if isinstance(res, BaseModel):
if isinstance(res, V1OrV2BaseModel):
return make_json_response(
res,
on_success_status,
Expand All @@ -271,7 +305,7 @@ def wrapper(*args, **kwargs):
if (
isinstance(res, tuple)
and len(res) in [2, 3]
and isinstance(res[0], BaseModel)
and isinstance(res[0], V1OrV2BaseModel)
):
headers = None
status = on_success_status
Expand Down
Empty file added tests/pydantic_v1/__init__.py
Empty file.
56 changes: 56 additions & 0 deletions tests/pydantic_v1/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
Specific confest.py file for testing behavior with Pydantic V1.
The fixtures below override the confest.py's fixtures for this module only.
"""
from typing import List, Optional, Type

import pytest

from pydantic.v1 import BaseModel


@pytest.fixture
def query_model() -> Type[BaseModel]:
class Query(BaseModel):
limit: int = 2
min_views: Optional[int] = None

return Query


@pytest.fixture
def body_model() -> Type[BaseModel]:
class Body(BaseModel):
search_term: str
exclude: Optional[str] = None

return Body


@pytest.fixture
def form_model() -> Type[BaseModel]:
class Form(BaseModel):
search_term: str
exclude: Optional[str] = None

return Form


@pytest.fixture
def post_model() -> Type[BaseModel]:
class Post(BaseModel):
title: str
text: str
views: int

return Post


@pytest.fixture
def response_model(post_model: BaseModel) -> Type[BaseModel]:
class Response(BaseModel):
results: List[post_model]
count: int

return Response
Empty file.
Loading

0 comments on commit 4bff748

Please sign in to comment.