Skip to content

Commit

Permalink
feat(Pydantic): added parameters in pydantic plugin to support strict…
Browse files Browse the repository at this point in the history
… validation and all the `model_dump` args (#3608)

feat: added parameters to pydantic plugin
  • Loading branch information
Anu-cool-007 committed Jul 11, 2024
1 parent ec77ce6 commit 5bcb256
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 23 deletions.
51 changes: 44 additions & 7 deletions litestar/contrib/pydantic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pydantic.v1 import BaseModel as BaseModelV1

from litestar.config.app import AppConfig
from litestar.types.serialization import PydanticV1FieldsListType, PydanticV2FieldsListType

__all__ = (
"PydanticDTO",
Expand Down Expand Up @@ -43,15 +44,43 @@ def _model_dump_json(model: BaseModel | BaseModelV1, by_alias: bool = False) ->
class PydanticPlugin(InitPluginProtocol):
"""A plugin that provides Pydantic integration."""

__slots__ = ("prefer_alias",)

def __init__(self, prefer_alias: bool = False) -> None:
"""Initialize ``PydanticPlugin``.
__slots__ = (
"exclude",
"exclude_defaults",
"exclude_none",
"exclude_unset",
"include",
"prefer_alias",
"validate_strict",
)

Args:
prefer_alias: OpenAPI and ``type_encoders`` will export by alias
def __init__(
self,
exclude: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
prefer_alias: bool = False,
validate_strict: bool = False,
) -> None:
"""Pydantic Plugin to support serialization / validation of Pydantic types / models
:param exclude: Fields to exclude during serialization
:param exclude_defaults: Fields to exclude during serialization when they are set to their default value
:param exclude_none: Fields to exclude during serialization when they are set to ``None``
:param exclude_unset: Fields to exclude during serialization when they arenot set
:param include: Fields to exclude during serialization
:param prefer_alias: Use the ``by_alias=True`` flag when dumping models
:param validate_strict: Use ``strict=True`` when calling ``.model_validate`` on Pydantic 2.x models
"""
self.exclude = exclude
self.exclude_defaults = exclude_defaults
self.exclude_none = exclude_none
self.exclude_unset = exclude_unset
self.include = include
self.prefer_alias = prefer_alias
self.validate_strict = validate_strict

def on_app_init(self, app_config: AppConfig) -> AppConfig:
"""Configure application for use with Pydantic.
Expand All @@ -61,7 +90,15 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig:
"""
app_config.plugins.extend(
[
PydanticInitPlugin(prefer_alias=self.prefer_alias),
PydanticInitPlugin(
exclude=self.exclude,
exclude_defaults=self.exclude_defaults,
exclude_none=self.exclude_none,
exclude_unset=self.exclude_unset,
include=self.include,
prefer_alias=self.prefer_alias,
validate_strict=self.validate_strict,
),
PydanticSchemaPlugin(prefer_alias=self.prefer_alias),
PydanticDIPlugin(),
]
Expand Down
140 changes: 125 additions & 15 deletions litestar/contrib/pydantic/pydantic_init_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from contextlib import suppress
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from uuid import UUID

Expand Down Expand Up @@ -34,6 +35,7 @@

if TYPE_CHECKING:
from litestar.config.app import AppConfig
from litestar.types.serialization import PydanticV1FieldsListType, PydanticV2FieldsListType


T = TypeVar("T")
Expand All @@ -46,9 +48,9 @@ def _dec_pydantic_v1(model_type: type[pydantic_v1.BaseModel], value: Any) -> pyd
raise ExtendedMsgSpecValidationError(errors=cast("list[dict[str, Any]]", e.errors())) from e


def _dec_pydantic_v2(model_type: type[pydantic_v2.BaseModel], value: Any) -> pydantic_v2.BaseModel:
def _dec_pydantic_v2(model_type: type[pydantic_v2.BaseModel], value: Any, strict: bool) -> pydantic_v2.BaseModel:
try:
return model_type.model_validate(value, strict=False)
return model_type.model_validate(value, strict=strict)
except pydantic_v2.ValidationError as e:
raise ExtendedMsgSpecValidationError(errors=cast("list[dict[str, Any]]", e.errors())) from e

Expand Down Expand Up @@ -123,36 +125,116 @@ def extract(annotation: Any, default: Any) -> Any:


class PydanticInitPlugin(InitPluginProtocol):
__slots__ = ("prefer_alias",)
__slots__ = (
"exclude",
"exclude_defaults",
"exclude_none",
"exclude_unset",
"include",
"prefer_alias",
"validate_strict",
)

def __init__(self, prefer_alias: bool = False) -> None:
def __init__(
self,
exclude: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
prefer_alias: bool = False,
validate_strict: bool = False,
) -> None:
"""Pydantic Plugin to support serialization / validation of Pydantic types / models
:param exclude: Fields to exclude during serialization
:param exclude_defaults: Fields to exclude during serialization when they are set to their default value
:param exclude_none: Fields to exclude during serialization when they are set to ``None``
:param exclude_unset: Fields to exclude during serialization when they arenot set
:param include: Fields to exclude during serialization
:param prefer_alias: Use the ``by_alias=True`` flag when dumping models
:param validate_strict: Use ``strict=True`` when calling ``.model_validate`` on Pydantic 2.x models
"""
self.exclude = exclude
self.exclude_defaults = exclude_defaults
self.exclude_none = exclude_none
self.exclude_unset = exclude_unset
self.include = include
self.prefer_alias = prefer_alias
self.validate_strict = validate_strict

@classmethod
def encoders(cls, prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]:
encoders = {**_base_encoders, **cls._create_pydantic_v1_encoders(prefer_alias)}
def encoders(
cls,
exclude: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
prefer_alias: bool = False,
) -> dict[Any, Callable[[Any], Any]]:
encoders = {
**_base_encoders,
**cls._create_pydantic_v1_encoders(
prefer_alias=prefer_alias,
exclude=exclude,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
include=include,
),
}
if pydantic_v2 is not None: # pragma: no cover
encoders.update(cls._create_pydantic_v2_encoders(prefer_alias))
encoders.update(
cls._create_pydantic_v2_encoders(
prefer_alias=prefer_alias,
exclude=exclude,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
include=include,
)
)
return encoders

@classmethod
def decoders(cls) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]:
def decoders(cls, validate_strict: bool = False) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]:
decoders: list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]] = [
(is_pydantic_v1_model_class, _dec_pydantic_v1)
]

if pydantic_v2 is not None: # pragma: no cover
decoders.append((is_pydantic_v2_model_class, _dec_pydantic_v2))
decoders.append(
(
is_pydantic_v2_model_class,
partial(_dec_pydantic_v2, strict=validate_strict),
)
)

decoders.append((_is_pydantic_v1_uuid, _dec_pydantic_uuid))

return decoders

@staticmethod
def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover
def _create_pydantic_v1_encoders(
exclude: PydanticV1FieldsListType | None = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV1FieldsListType | None = None,
prefer_alias: bool = False,
) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover
return {
pydantic_v1.BaseModel: lambda model: {
k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict(by_alias=prefer_alias).items()
k: v.decode() if isinstance(v, bytes) else v
for k, v in model.dict(
by_alias=prefer_alias,
exclude=exclude,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
include=include,
).items()
},
pydantic_v1.SecretField: str,
pydantic_v1.StrictBool: int,
Expand All @@ -163,9 +245,24 @@ def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callab
}

@staticmethod
def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]:
def _create_pydantic_v2_encoders(
exclude: PydanticV2FieldsListType | None = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV2FieldsListType | None = None,
prefer_alias: bool = False,
) -> dict[Any, Callable[[Any], Any]]:
encoders: dict[Any, Callable[[Any], Any]] = {
pydantic_v2.BaseModel: lambda model: model.model_dump(mode="json", by_alias=prefer_alias),
pydantic_v2.BaseModel: lambda model: model.model_dump(
by_alias=prefer_alias,
exclude=exclude,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
include=include,
mode="json",
),
pydantic_v2.types.SecretStr: lambda val: "**********" if val else "",
pydantic_v2.types.SecretBytes: lambda val: "**********" if val else "",
pydantic_v2.AnyUrl: str,
Expand All @@ -179,8 +276,21 @@ def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callab
return encoders

def on_app_init(self, app_config: AppConfig) -> AppConfig:
app_config.type_encoders = {**self.encoders(self.prefer_alias), **(app_config.type_encoders or {})}
app_config.type_decoders = [*self.decoders(), *(app_config.type_decoders or [])]
app_config.type_encoders = {
**self.encoders(
prefer_alias=self.prefer_alias,
exclude=self.exclude,
exclude_defaults=self.exclude_defaults,
exclude_none=self.exclude_none,
exclude_unset=self.exclude_unset,
include=self.include,
),
**(app_config.type_encoders or {}),
}
app_config.type_decoders = [
*self.decoders(validate_strict=self.validate_strict),
*(app_config.type_decoders or []),
]

_KWARG_META_EXTRACTORS.add(ConstrainedFieldMetaExtractor)
return app_config
9 changes: 8 additions & 1 deletion litestar/types/serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict, Set

if TYPE_CHECKING:
from collections import deque
Expand Down Expand Up @@ -28,8 +28,13 @@

try:
from pydantic import BaseModel
from pydantic.main import IncEx
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny
except ImportError:
BaseModel = Any # type: ignore[assignment, misc]
IncEx = Any # type: ignore[misc]
AbstractSetIntStr = Any
MappingIntStrAny = Any

try:
from attrs import AttrsInstance
Expand Down Expand Up @@ -57,3 +62,5 @@
EncodableMsgSpecType: TypeAlias = "Ext | Raw | Struct"
LitestarEncodableType: TypeAlias = "EncodableBuiltinType | EncodableBuiltinCollectionType | EncodableStdLibType | EncodableStdLibIPType | EncodableMsgSpecType | BaseModel | AttrsInstance" # pyright: ignore
DataContainerType: TypeAlias = "Struct | BaseModel | AttrsInstance | TypedDictClass | DataclassProtocol" # pyright: ignore
PydanticV2FieldsListType: TypeAlias = "Set[int] | Set[str] | Dict[int, Any] | Dict[str, Any]"
PydanticV1FieldsListType: TypeAlias = "IncEx | AbstractSetIntStr | MappingIntStrAny" # pyright: ignore
Loading

0 comments on commit 5bcb256

Please sign in to comment.