Skip to content

Commit

Permalink
Keep retrocompatibility with Pydantic V1 base models
Browse files Browse the repository at this point in the history
  • Loading branch information
Merinorus committed Jun 14, 2024
1 parent c0dfd89 commit 5b5089d
Showing 1 changed file with 35 additions and 12 deletions.
47 changes: 35 additions & 12 deletions flask_pydantic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

from flask import Response, current_app, jsonify, make_response, request
from pydantic import BaseModel, ValidationError, TypeAdapter, RootModel
from pydantic.v1 import BaseModel as V1BaseModel
from pydantic.v1.tools import parse_obj_as
V1OrV2BaseModel = Union[BaseModel, V1BaseModel]

from .converters import convert_query_params
from .exceptions import (
Expand All @@ -17,19 +20,25 @@
except ImportError:
pass

def _model_dump_json(model: V1OrV2BaseModel, **kwargs):
"""Adapter to dump a model to json, whether it's a Pydantic V1 or V2 model."""
if isinstance(model, BaseModel):
return model.model_dump_json(**kwargs)
elif isinstance(model, V1BaseModel):
return model.json(**kwargs)

def make_json_response(
content: Union[BaseModel, Iterable[BaseModel]],
content: Union[V1OrV2BaseModel, Iterable[V1OrV2BaseModel]],
status_code: int,
by_alias: bool,
exclude_none: bool = False,
many: bool = False,
) -> Response:
"""serializes model, creates JSON response with given status code"""
if many:
js = f"[{', '.join([model.model_dump_json(exclude_none=exclude_none, by_alias=by_alias) for model in content])}]"
js = f"[{', '.join([_model_dump_json(model, exclude_none=exclude_none, by_alias=by_alias) for model in content])}]"
else:
js = content.model_dump_json(exclude_none=exclude_none, by_alias=by_alias)
js = _model_dump_json(content, exclude_none=exclude_none, by_alias=by_alias)
response = make_response(js, status_code)
response.mimetype = "application/json"
return response
Expand All @@ -45,12 +54,12 @@ def unsupported_media_type_response(request_cont_type: str) -> Response:

def is_iterable_of_models(content: Any) -> bool:
try:
return all(isinstance(obj, BaseModel) for obj in content)
return all(isinstance(obj, V1OrV2BaseModel) for obj in content)
except TypeError:
return False


def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel]:
def validate_many_models(model: Type[V1OrV2BaseModel], content: Any) -> List[V1OrV2BaseModel]:
try:
return [model(**fields) for fields in content]
except TypeError:
Expand All @@ -74,8 +83,12 @@ def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]:
if name in {"query", "body", "form", "return"}:
continue
try:
adapter = TypeAdapter(type_)
validated[name] = adapter.validate_python(kwargs.get(name))
if not isinstance(type_, V1BaseModel):
adapter = TypeAdapter(type_)
validated[name] = adapter.validate_python(kwargs.get(name))
else:
value = parse_obj_as(type_, kwargs.get(name))
validated[name] = value
except ValidationError as e:
err = e.errors()[0]
err["loc"] = [name]
Expand All @@ -92,15 +105,15 @@ def get_body_dict(**params):


def validate(
body: Optional[Type[BaseModel]] = None,
query: Optional[Type[BaseModel]] = None,
body: Optional[Type[V1OrV2BaseModel]] = None,
query: Optional[Type[V1OrV2BaseModel]] = None,
on_success_status: int = 200,
exclude_none: bool = False,
response_many: bool = False,
request_body_many: bool = False,
response_by_alias: bool = False,
get_json_params: Optional[dict] = None,
form: Optional[Type[BaseModel]] = None,
form: Optional[Type[V1OrV2BaseModel]] = None,
):
"""
Decorator for route methods which will validate query, body and form parameters
Expand Down Expand Up @@ -186,6 +199,11 @@ def wrapper(*args, **kwargs):
b = body_model(body_params)
except ValidationError as ve:
err["body_params"] = ve.errors()
elif isinstance(body, V1BaseModel) and "__root__" in body_model.__fields__:
try:
b = body_model(__root__=body_params).__root__
except ValidationError as ve:
err["body_params"] = ve.errors()
elif request_body_many:
try:
b = validate_many_models(body_model, body_params)
Expand All @@ -212,6 +230,11 @@ def wrapper(*args, **kwargs):
f = form_model(form_params)
except ValidationError as ve:
err["form_params"] = ve.errors()
elif issubclass(form, V1BaseModel) and "__root__" in form_model.__fields__:
try:
f = form_model(form_params)
except ValidationError as ve:
err["form_params"] = ve.errors()
else:
try:
f = form_model(**form_params)
Expand Down Expand Up @@ -260,7 +283,7 @@ def wrapper(*args, **kwargs):
else:
raise InvalidIterableOfModelsException(res)

if isinstance(res, BaseModel):
if isinstance(res, V1OrV2BaseModel):
return make_json_response(
res,
on_success_status,
Expand All @@ -271,7 +294,7 @@ def wrapper(*args, **kwargs):
if (
isinstance(res, tuple)
and len(res) in [2, 3]
and isinstance(res[0], BaseModel)
and isinstance(res[0], V1OrV2BaseModel)
):
headers = None
status = on_success_status
Expand Down

0 comments on commit 5b5089d

Please sign in to comment.