Skip to content

Commit

Permalink
Add backward compatibility with Pydantic V1 base models
Browse files Browse the repository at this point in the history
  • Loading branch information
Merinorus committed Jun 14, 2024
1 parent c0dfd89 commit 2b99f09
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 32 deletions.
32 changes: 22 additions & 10 deletions flask_pydantic/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from typing_extensions import get_args, get_origin

from pydantic import BaseModel
from pydantic.v1 import BaseModel as V1BaseModel
from werkzeug.datastructures import ImmutableMultiDict

V1OrV2BaseModel = Union[BaseModel, V1BaseModel]

def _is_list(type_: Type) -> bool:
origin = get_origin(type_)
Expand All @@ -19,7 +21,7 @@ def _is_list(type_: Type) -> bool:


def convert_query_params(
query_params: ImmutableMultiDict, model: Type[BaseModel]
query_params: ImmutableMultiDict, model: Type[V1OrV2BaseModel]
) -> dict:
"""
group query parameters into lists if model defines them
Expand All @@ -28,12 +30,22 @@ def convert_query_params(
:param model: query parameter's model
:return: resulting parameters
"""
return {
**query_params.to_dict(),
**{
key: value
for key, value in query_params.to_dict(flat=False).items()
if key in model.model_fields
and _is_list(model.model_fields[key].annotation)
},
}
if issubclass(model, BaseModel):
return {
**query_params.to_dict(),
**{
key: value
for key, value in query_params.to_dict(flat=False).items()
if key in model.model_fields
and _is_list(model.model_fields[key].annotation)
},
}
else:
return {
**query_params.to_dict(),
**{
key: value
for key, value in query_params.to_dict(flat=False).items()
if key in model.__fields__ and model.__fields__[key].is_complex()
},
}
75 changes: 53 additions & 22 deletions flask_pydantic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from flask import Response, current_app, jsonify, make_response, request
from pydantic import BaseModel, ValidationError, TypeAdapter, RootModel

from pydantic.v1 import BaseModel as V1BaseModel
from pydantic.v1.error_wrappers import ValidationError as V1ValidationError
from pydantic.v1.tools import parse_obj_as
from .converters import convert_query_params
from .exceptions import (
InvalidIterableOfModelsException,
Expand All @@ -17,19 +19,31 @@
except ImportError:
pass

V1OrV2BaseModel = Union[BaseModel, V1BaseModel]


def _model_dump_json(model: V1OrV2BaseModel, **kwargs):
"""Adapter to dump a model to json, whether it's a Pydantic V1 or V2 model."""
if isinstance(model, BaseModel):
# assert False, type(model)
return model.model_dump_json(**kwargs)
else:
# assert False, type(model)
return model.json(**kwargs)


def make_json_response(
content: Union[BaseModel, Iterable[BaseModel]],
content: Union[V1OrV2BaseModel, Iterable[V1OrV2BaseModel]],
status_code: int,
by_alias: bool,
exclude_none: bool = False,
many: bool = False,
) -> Response:
"""serializes model, creates JSON response with given status code"""
if many:
js = f"[{', '.join([model.model_dump_json(exclude_none=exclude_none, by_alias=by_alias) for model in content])}]"
js = f"[{', '.join([_model_dump_json(model, exclude_none=exclude_none, by_alias=by_alias) for model in content])}]"
else:
js = content.model_dump_json(exclude_none=exclude_none, by_alias=by_alias)
js = _model_dump_json(content, exclude_none=exclude_none, by_alias=by_alias)
response = make_response(js, status_code)
response.mimetype = "application/json"
return response
Expand All @@ -45,12 +59,12 @@ def unsupported_media_type_response(request_cont_type: str) -> Response:

def is_iterable_of_models(content: Any) -> bool:
try:
return all(isinstance(obj, BaseModel) for obj in content)
return all(isinstance(obj, V1OrV2BaseModel) for obj in content)
except TypeError:
return False


def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel]:
def validate_many_models(model: Type[V1OrV2BaseModel], content: Any) -> List[V1OrV2BaseModel]:
try:
return [model(**fields) for fields in content]
except TypeError:
Expand All @@ -63,7 +77,7 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel
}
]
raise ManyModelValidationError(err)
except ValidationError as ve:
except (ValidationError, V1ValidationError) as ve:
raise ManyModelValidationError(ve.errors())


Expand All @@ -74,9 +88,13 @@ def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]:
if name in {"query", "body", "form", "return"}:
continue
try:
adapter = TypeAdapter(type_)
validated[name] = adapter.validate_python(kwargs.get(name))
except ValidationError as e:
if not isinstance(type_, V1BaseModel):
adapter = TypeAdapter(type_)
validated[name] = adapter.validate_python(kwargs.get(name))
else:
value = parse_obj_as(type_, kwargs.get(name))
validated[name] = value
except (ValidationError, V1ValidationError) as e:
err = e.errors()[0]
err["loc"] = [name]
errors.append(err)
Expand All @@ -92,15 +110,15 @@ def get_body_dict(**params):


def validate(
body: Optional[Type[BaseModel]] = None,
query: Optional[Type[BaseModel]] = None,
body: Optional[Type[V1OrV2BaseModel]] = None,
query: Optional[Type[V1OrV2BaseModel]] = None,
on_success_status: int = 200,
exclude_none: bool = False,
response_many: bool = False,
request_body_many: bool = False,
response_by_alias: bool = False,
get_json_params: Optional[dict] = None,
form: Optional[Type[BaseModel]] = None,
form: Optional[Type[V1OrV2BaseModel]] = None,
):
"""
Decorator for route methods which will validate query, body and form parameters
Expand Down Expand Up @@ -173,18 +191,26 @@ def wrapper(*args, **kwargs):
query_model = query_in_kwargs or query
if query_model:
query_params = convert_query_params(request.args, query_model)
# assert False, str(query_params)
try:
# assert False, 1
q = query_model(**query_params)
except ValidationError as ve:
# assert False, 2
except (ValidationError, V1ValidationError) as ve:
err["query_params"] = ve.errors()
body_in_kwargs = func.__annotations__.get("body")
body_model = body_in_kwargs or body
if body_model:
body_params = get_body_dict(**(get_json_params or {}))
if issubclass(body_model, RootModel):
if issubclass(body_model, V1BaseModel) and "__root__" in body_model.__fields__:
try:
b = body_model(__root__=body_params).__root__
except (ValidationError, V1ValidationError) as ve:
err["body_params"] = ve.errors()
elif issubclass(body_model, RootModel):
try:
b = body_model(body_params)
except ValidationError as ve:
except (ValidationError, V1ValidationError) as ve:
err["body_params"] = ve.errors()
elif request_body_many:
try:
Expand All @@ -201,16 +227,21 @@ def wrapper(*args, **kwargs):
return unsupported_media_type_response(content_type)
else:
raise JsonBodyParsingError()
except ValidationError as ve:
except (ValidationError, V1ValidationError) as ve:
err["body_params"] = ve.errors()
form_in_kwargs = func.__annotations__.get("form")
form_model = form_in_kwargs or form
if form_model:
form_params = request.form
if issubclass(form_model, RootModel):
if isinstance(form, V1BaseModel) and "__root__" in form_model.__fields__:
try:
f = form_model(form_params)
except (ValidationError, V1ValidationError) as ve:
err["form_params"] = ve.errors()
elif issubclass(form_model, RootModel):
try:
f = form_model(form_params)
except ValidationError as ve:
except (ValidationError, V1ValidationError) as ve:
err["form_params"] = ve.errors()
else:
try:
Expand All @@ -222,7 +253,7 @@ def wrapper(*args, **kwargs):
return unsupported_media_type_response(content_type)
else:
raise JsonBodyParsingError
except ValidationError as ve:
except (ValidationError, V1ValidationError) as ve:
err["form_params"] = ve.errors()
request.query_params = q
request.body_params = b
Expand Down Expand Up @@ -260,7 +291,7 @@ def wrapper(*args, **kwargs):
else:
raise InvalidIterableOfModelsException(res)

if isinstance(res, BaseModel):
if isinstance(res, V1OrV2BaseModel):
return make_json_response(
res,
on_success_status,
Expand All @@ -271,7 +302,7 @@ def wrapper(*args, **kwargs):
if (
isinstance(res, tuple)
and len(res) in [2, 3]
and isinstance(res[0], BaseModel)
and isinstance(res[0], V1OrV2BaseModel)
):
headers = None
status = on_success_status
Expand Down

0 comments on commit 2b99f09

Please sign in to comment.