Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Pydantic): added parameters in pydantic plugin to support strict validation and all the model_dump args #3608

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions litestar/contrib/pydantic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pydantic.v1 import BaseModel as BaseModelV1

from litestar.config.app import AppConfig
from litestar.types.serialization import PydanticV1FieldsListType, PydanticV2FieldsListType

__all__ = (
"PydanticDTO",
Expand Down Expand Up @@ -43,15 +44,43 @@ def _model_dump_json(model: BaseModel | BaseModelV1, by_alias: bool = False) ->
class PydanticPlugin(InitPluginProtocol):
"""A plugin that provides Pydantic integration."""

__slots__ = ("prefer_alias",)

def __init__(self, prefer_alias: bool = False) -> None:
"""Initialize ``PydanticPlugin``.
__slots__ = (
"exclude",
"exclude_defaults",
"exclude_none",
"exclude_unset",
"include",
"prefer_alias",
"validate_strict",
)

Args:
prefer_alias: OpenAPI and ``type_encoders`` will export by alias
def __init__(
self,
exclude: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
prefer_alias: bool = False,
validate_strict: bool = False,
) -> None:
"""Pydantic Plugin to support serialization / validation of Pydantic types / models

:param exclude: Fields to exclude during serialization
:param exclude_defaults: Fields to exclude during serialization when they are set to their default value
:param exclude_none: Fields to exclude during serialization when they are set to ``None``
:param exclude_unset: Fields to exclude during serialization when they arenot set
:param include: Fields to exclude during serialization
:param prefer_alias: Use the ``by_alias=True`` flag when dumping models
:param validate_strict: Use ``strict=True`` when calling ``.model_validate`` on Pydantic 2.x models
"""
self.exclude = exclude
self.exclude_defaults = exclude_defaults
self.exclude_none = exclude_none
self.exclude_unset = exclude_unset
self.include = include
self.prefer_alias = prefer_alias
self.validate_strict = validate_strict

def on_app_init(self, app_config: AppConfig) -> AppConfig:
"""Configure application for use with Pydantic.
Expand All @@ -61,7 +90,15 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig:
"""
app_config.plugins.extend(
[
PydanticInitPlugin(prefer_alias=self.prefer_alias),
PydanticInitPlugin(
exclude=self.exclude,
exclude_defaults=self.exclude_defaults,
exclude_none=self.exclude_none,
exclude_unset=self.exclude_unset,
include=self.include,
prefer_alias=self.prefer_alias,
validate_strict=self.validate_strict,
),
PydanticSchemaPlugin(prefer_alias=self.prefer_alias),
PydanticDIPlugin(),
]
Expand Down
140 changes: 125 additions & 15 deletions litestar/contrib/pydantic/pydantic_init_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from contextlib import suppress
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from uuid import UUID

Expand Down Expand Up @@ -34,6 +35,7 @@

if TYPE_CHECKING:
from litestar.config.app import AppConfig
from litestar.types.serialization import PydanticV1FieldsListType, PydanticV2FieldsListType


T = TypeVar("T")
Expand All @@ -46,9 +48,9 @@ def _dec_pydantic_v1(model_type: type[pydantic_v1.BaseModel], value: Any) -> pyd
raise ExtendedMsgSpecValidationError(errors=cast("list[dict[str, Any]]", e.errors())) from e


def _dec_pydantic_v2(model_type: type[pydantic_v2.BaseModel], value: Any) -> pydantic_v2.BaseModel:
def _dec_pydantic_v2(model_type: type[pydantic_v2.BaseModel], value: Any, strict: bool) -> pydantic_v2.BaseModel:
try:
return model_type.model_validate(value, strict=False)
return model_type.model_validate(value, strict=strict)
except pydantic_v2.ValidationError as e:
raise ExtendedMsgSpecValidationError(errors=cast("list[dict[str, Any]]", e.errors())) from e

Expand Down Expand Up @@ -123,36 +125,116 @@ def extract(annotation: Any, default: Any) -> Any:


class PydanticInitPlugin(InitPluginProtocol):
__slots__ = ("prefer_alias",)
__slots__ = (
"exclude",
"exclude_defaults",
"exclude_none",
"exclude_unset",
"include",
"prefer_alias",
"validate_strict",
)

def __init__(self, prefer_alias: bool = False) -> None:
def __init__(
self,
exclude: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
prefer_alias: bool = False,
validate_strict: bool = False,
) -> None:
"""Pydantic Plugin to support serialization / validation of Pydantic types / models

:param exclude: Fields to exclude during serialization
:param exclude_defaults: Fields to exclude during serialization when they are set to their default value
:param exclude_none: Fields to exclude during serialization when they are set to ``None``
:param exclude_unset: Fields to exclude during serialization when they arenot set
:param include: Fields to exclude during serialization
:param prefer_alias: Use the ``by_alias=True`` flag when dumping models
:param validate_strict: Use ``strict=True`` when calling ``.model_validate`` on Pydantic 2.x models
"""
self.exclude = exclude
self.exclude_defaults = exclude_defaults
self.exclude_none = exclude_none
self.exclude_unset = exclude_unset
self.include = include
self.prefer_alias = prefer_alias
self.validate_strict = validate_strict

@classmethod
def encoders(cls, prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]:
encoders = {**_base_encoders, **cls._create_pydantic_v1_encoders(prefer_alias)}
def encoders(
cls,
exclude: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None,
prefer_alias: bool = False,
) -> dict[Any, Callable[[Any], Any]]:
encoders = {
**_base_encoders,
**cls._create_pydantic_v1_encoders(
prefer_alias=prefer_alias,
exclude=exclude,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
include=include,
),
}
if pydantic_v2 is not None: # pragma: no cover
encoders.update(cls._create_pydantic_v2_encoders(prefer_alias))
encoders.update(
cls._create_pydantic_v2_encoders(
prefer_alias=prefer_alias,
exclude=exclude,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
include=include,
)
)
return encoders

@classmethod
def decoders(cls) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]:
def decoders(cls, validate_strict: bool = False) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]:
decoders: list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]] = [
(is_pydantic_v1_model_class, _dec_pydantic_v1)
]

if pydantic_v2 is not None: # pragma: no cover
decoders.append((is_pydantic_v2_model_class, _dec_pydantic_v2))
decoders.append(
(
is_pydantic_v2_model_class,
partial(_dec_pydantic_v2, strict=validate_strict),
)
)

decoders.append((_is_pydantic_v1_uuid, _dec_pydantic_uuid))

return decoders

@staticmethod
def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover
def _create_pydantic_v1_encoders(
exclude: PydanticV1FieldsListType | None = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV1FieldsListType | None = None,
prefer_alias: bool = False,
) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover
return {
pydantic_v1.BaseModel: lambda model: {
k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict(by_alias=prefer_alias).items()
k: v.decode() if isinstance(v, bytes) else v
for k, v in model.dict(
by_alias=prefer_alias,
exclude=exclude,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
include=include,
).items()
},
pydantic_v1.SecretField: str,
pydantic_v1.StrictBool: int,
Expand All @@ -163,9 +245,24 @@ def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callab
}

@staticmethod
def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]:
def _create_pydantic_v2_encoders(
exclude: PydanticV2FieldsListType | None = None,
exclude_defaults: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
include: PydanticV2FieldsListType | None = None,
prefer_alias: bool = False,
) -> dict[Any, Callable[[Any], Any]]:
encoders: dict[Any, Callable[[Any], Any]] = {
pydantic_v2.BaseModel: lambda model: model.model_dump(mode="json", by_alias=prefer_alias),
pydantic_v2.BaseModel: lambda model: model.model_dump(
by_alias=prefer_alias,
exclude=exclude,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
exclude_unset=exclude_unset,
include=include,
mode="json",
),
pydantic_v2.types.SecretStr: lambda val: "**********" if val else "",
pydantic_v2.types.SecretBytes: lambda val: "**********" if val else "",
pydantic_v2.AnyUrl: str,
Expand All @@ -179,8 +276,21 @@ def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callab
return encoders

def on_app_init(self, app_config: AppConfig) -> AppConfig:
app_config.type_encoders = {**self.encoders(self.prefer_alias), **(app_config.type_encoders or {})}
app_config.type_decoders = [*self.decoders(), *(app_config.type_decoders or [])]
app_config.type_encoders = {
**self.encoders(
prefer_alias=self.prefer_alias,
exclude=self.exclude,
exclude_defaults=self.exclude_defaults,
exclude_none=self.exclude_none,
exclude_unset=self.exclude_unset,
include=self.include,
),
**(app_config.type_encoders or {}),
}
app_config.type_decoders = [
*self.decoders(validate_strict=self.validate_strict),
*(app_config.type_decoders or []),
]

_KWARG_META_EXTRACTORS.add(ConstrainedFieldMetaExtractor)
return app_config
9 changes: 8 additions & 1 deletion litestar/types/serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict, Set

if TYPE_CHECKING:
from collections import deque
Expand Down Expand Up @@ -28,8 +28,13 @@

try:
from pydantic import BaseModel
from pydantic.main import IncEx
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny
except ImportError:
BaseModel = Any # type: ignore[assignment, misc]
IncEx = Any # type: ignore[misc]
AbstractSetIntStr = Any
MappingIntStrAny = Any

try:
from attrs import AttrsInstance
Expand Down Expand Up @@ -57,3 +62,5 @@
EncodableMsgSpecType: TypeAlias = "Ext | Raw | Struct"
LitestarEncodableType: TypeAlias = "EncodableBuiltinType | EncodableBuiltinCollectionType | EncodableStdLibType | EncodableStdLibIPType | EncodableMsgSpecType | BaseModel | AttrsInstance" # pyright: ignore
DataContainerType: TypeAlias = "Struct | BaseModel | AttrsInstance | TypedDictClass | DataclassProtocol" # pyright: ignore
PydanticV2FieldsListType: TypeAlias = "Set[int] | Set[str] | Dict[int, Any] | Dict[str, Any]"
PydanticV1FieldsListType: TypeAlias = "IncEx | AbstractSetIntStr | MappingIntStrAny" # pyright: ignore
Loading
Loading