From 7a88082ad0a52ba1fe389d1456b0e5b817bfb067 Mon Sep 17 00:00:00 2001 From: Anuranjan Srivastava Date: Thu, 4 Jul 2024 02:06:44 +0530 Subject: [PATCH] feat: added parameters to pydantic plugin --- litestar/contrib/pydantic/__init__.py | 51 ++++++- .../contrib/pydantic/pydantic_init_plugin.py | 140 ++++++++++++++++-- litestar/types/serialization.py | 9 +- .../test_pydantic/test_integration.py | 83 +++++++++++ 4 files changed, 260 insertions(+), 23 deletions(-) diff --git a/litestar/contrib/pydantic/__init__.py b/litestar/contrib/pydantic/__init__.py index 9bab707c31..094560fe94 100644 --- a/litestar/contrib/pydantic/__init__.py +++ b/litestar/contrib/pydantic/__init__.py @@ -14,6 +14,7 @@ from pydantic.v1 import BaseModel as BaseModelV1 from litestar.config.app import AppConfig + from litestar.types.serialization import PydanticV1FieldsListType, PydanticV2FieldsListType __all__ = ( "PydanticDTO", @@ -43,15 +44,43 @@ def _model_dump_json(model: BaseModel | BaseModelV1, by_alias: bool = False) -> class PydanticPlugin(InitPluginProtocol): """A plugin that provides Pydantic integration.""" - __slots__ = ("prefer_alias",) - - def __init__(self, prefer_alias: bool = False) -> None: - """Initialize ``PydanticPlugin``. + __slots__ = ( + "exclude", + "exclude_defaults", + "exclude_none", + "exclude_unset", + "include", + "prefer_alias", + "validate_strict", + ) - Args: - prefer_alias: OpenAPI and ``type_encoders`` will export by alias + def __init__( + self, + exclude: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_unset: bool = False, + include: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None, + prefer_alias: bool = False, + validate_strict: bool = False, + ) -> None: + """Pydantic Plugin to support serialization / validation of Pydantic types / models + + :param exclude: Fields to exclude during serialization + :param exclude_defaults: Fields to exclude during serialization when they are set to their default value + :param exclude_none: Fields to exclude during serialization when they are set to ``None`` + :param exclude_unset: Fields to exclude during serialization when they arenot set + :param include: Fields to exclude during serialization + :param prefer_alias: Use the ``by_alias=True`` flag when dumping models + :param validate_strict: Use ``strict=True`` when calling ``.model_validate`` on Pydantic 2.x models """ + self.exclude = exclude + self.exclude_defaults = exclude_defaults + self.exclude_none = exclude_none + self.exclude_unset = exclude_unset + self.include = include self.prefer_alias = prefer_alias + self.validate_strict = validate_strict def on_app_init(self, app_config: AppConfig) -> AppConfig: """Configure application for use with Pydantic. @@ -61,7 +90,15 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig: """ app_config.plugins.extend( [ - PydanticInitPlugin(prefer_alias=self.prefer_alias), + PydanticInitPlugin( + exclude=self.exclude, + exclude_defaults=self.exclude_defaults, + exclude_none=self.exclude_none, + exclude_unset=self.exclude_unset, + include=self.include, + prefer_alias=self.prefer_alias, + validate_strict=self.validate_strict, + ), PydanticSchemaPlugin(prefer_alias=self.prefer_alias), PydanticDIPlugin(), ] diff --git a/litestar/contrib/pydantic/pydantic_init_plugin.py b/litestar/contrib/pydantic/pydantic_init_plugin.py index 95bf53f171..1d425f3420 100644 --- a/litestar/contrib/pydantic/pydantic_init_plugin.py +++ b/litestar/contrib/pydantic/pydantic_init_plugin.py @@ -1,6 +1,7 @@ from __future__ import annotations from contextlib import suppress +from functools import partial from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast from uuid import UUID @@ -34,6 +35,7 @@ if TYPE_CHECKING: from litestar.config.app import AppConfig + from litestar.types.serialization import PydanticV1FieldsListType, PydanticV2FieldsListType T = TypeVar("T") @@ -46,9 +48,9 @@ def _dec_pydantic_v1(model_type: type[pydantic_v1.BaseModel], value: Any) -> pyd raise ExtendedMsgSpecValidationError(errors=cast("list[dict[str, Any]]", e.errors())) from e -def _dec_pydantic_v2(model_type: type[pydantic_v2.BaseModel], value: Any) -> pydantic_v2.BaseModel: +def _dec_pydantic_v2(model_type: type[pydantic_v2.BaseModel], value: Any, strict: bool) -> pydantic_v2.BaseModel: try: - return model_type.model_validate(value, strict=False) + return model_type.model_validate(value, strict=strict) except pydantic_v2.ValidationError as e: raise ExtendedMsgSpecValidationError(errors=cast("list[dict[str, Any]]", e.errors())) from e @@ -123,36 +125,116 @@ def extract(annotation: Any, default: Any) -> Any: class PydanticInitPlugin(InitPluginProtocol): - __slots__ = ("prefer_alias",) + __slots__ = ( + "exclude", + "exclude_defaults", + "exclude_none", + "exclude_unset", + "include", + "prefer_alias", + "validate_strict", + ) - def __init__(self, prefer_alias: bool = False) -> None: + def __init__( + self, + exclude: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_unset: bool = False, + include: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None, + prefer_alias: bool = False, + validate_strict: bool = False, + ) -> None: + """Pydantic Plugin to support serialization / validation of Pydantic types / models + + :param exclude: Fields to exclude during serialization + :param exclude_defaults: Fields to exclude during serialization when they are set to their default value + :param exclude_none: Fields to exclude during serialization when they are set to ``None`` + :param exclude_unset: Fields to exclude during serialization when they arenot set + :param include: Fields to exclude during serialization + :param prefer_alias: Use the ``by_alias=True`` flag when dumping models + :param validate_strict: Use ``strict=True`` when calling ``.model_validate`` on Pydantic 2.x models + """ + self.exclude = exclude + self.exclude_defaults = exclude_defaults + self.exclude_none = exclude_none + self.exclude_unset = exclude_unset + self.include = include self.prefer_alias = prefer_alias + self.validate_strict = validate_strict @classmethod - def encoders(cls, prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: - encoders = {**_base_encoders, **cls._create_pydantic_v1_encoders(prefer_alias)} + def encoders( + cls, + exclude: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_unset: bool = False, + include: PydanticV1FieldsListType | PydanticV2FieldsListType | None = None, + prefer_alias: bool = False, + ) -> dict[Any, Callable[[Any], Any]]: + encoders = { + **_base_encoders, + **cls._create_pydantic_v1_encoders( + prefer_alias=prefer_alias, + exclude=exclude, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + exclude_unset=exclude_unset, + include=include, + ), + } if pydantic_v2 is not None: # pragma: no cover - encoders.update(cls._create_pydantic_v2_encoders(prefer_alias)) + encoders.update( + cls._create_pydantic_v2_encoders( + prefer_alias=prefer_alias, + exclude=exclude, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + exclude_unset=exclude_unset, + include=include, + ) + ) return encoders @classmethod - def decoders(cls) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]: + def decoders(cls, validate_strict: bool = False) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]: decoders: list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]] = [ (is_pydantic_v1_model_class, _dec_pydantic_v1) ] if pydantic_v2 is not None: # pragma: no cover - decoders.append((is_pydantic_v2_model_class, _dec_pydantic_v2)) + decoders.append( + ( + is_pydantic_v2_model_class, + partial(_dec_pydantic_v2, strict=validate_strict), + ) + ) decoders.append((_is_pydantic_v1_uuid, _dec_pydantic_uuid)) return decoders @staticmethod - def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover + def _create_pydantic_v1_encoders( + exclude: PydanticV1FieldsListType | None = None, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_unset: bool = False, + include: PydanticV1FieldsListType | None = None, + prefer_alias: bool = False, + ) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover return { pydantic_v1.BaseModel: lambda model: { - k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict(by_alias=prefer_alias).items() + k: v.decode() if isinstance(v, bytes) else v + for k, v in model.dict( + by_alias=prefer_alias, + exclude=exclude, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + exclude_unset=exclude_unset, + include=include, + ).items() }, pydantic_v1.SecretField: str, pydantic_v1.StrictBool: int, @@ -163,9 +245,24 @@ def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callab } @staticmethod - def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: + def _create_pydantic_v2_encoders( + exclude: PydanticV2FieldsListType | None = None, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_unset: bool = False, + include: PydanticV2FieldsListType | None = None, + prefer_alias: bool = False, + ) -> dict[Any, Callable[[Any], Any]]: encoders: dict[Any, Callable[[Any], Any]] = { - pydantic_v2.BaseModel: lambda model: model.model_dump(mode="json", by_alias=prefer_alias), + pydantic_v2.BaseModel: lambda model: model.model_dump( + by_alias=prefer_alias, + exclude=exclude, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + exclude_unset=exclude_unset, + include=include, + mode="json", + ), pydantic_v2.types.SecretStr: lambda val: "**********" if val else "", pydantic_v2.types.SecretBytes: lambda val: "**********" if val else "", pydantic_v2.AnyUrl: str, @@ -179,8 +276,21 @@ def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callab return encoders def on_app_init(self, app_config: AppConfig) -> AppConfig: - app_config.type_encoders = {**self.encoders(self.prefer_alias), **(app_config.type_encoders or {})} - app_config.type_decoders = [*self.decoders(), *(app_config.type_decoders or [])] + app_config.type_encoders = { + **self.encoders( + prefer_alias=self.prefer_alias, + exclude=self.exclude, + exclude_defaults=self.exclude_defaults, + exclude_none=self.exclude_none, + exclude_unset=self.exclude_unset, + include=self.include, + ), + **(app_config.type_encoders or {}), + } + app_config.type_decoders = [ + *self.decoders(validate_strict=self.validate_strict), + *(app_config.type_decoders or []), + ] _KWARG_META_EXTRACTORS.add(ConstrainedFieldMetaExtractor) return app_config diff --git a/litestar/types/serialization.py b/litestar/types/serialization.py index 0f61e10533..15ec4308ec 100644 --- a/litestar/types/serialization.py +++ b/litestar/types/serialization.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Dict, Set if TYPE_CHECKING: from collections import deque @@ -28,8 +28,13 @@ try: from pydantic import BaseModel + from pydantic.main import IncEx + from pydantic.typing import AbstractSetIntStr, MappingIntStrAny except ImportError: BaseModel = Any # type: ignore[assignment, misc] + IncEx = Any # type: ignore[misc] + AbstractSetIntStr = Any + MappingIntStrAny = Any try: from attrs import AttrsInstance @@ -57,3 +62,5 @@ EncodableMsgSpecType: TypeAlias = "Ext | Raw | Struct" LitestarEncodableType: TypeAlias = "EncodableBuiltinType | EncodableBuiltinCollectionType | EncodableStdLibType | EncodableStdLibIPType | EncodableMsgSpecType | BaseModel | AttrsInstance" # pyright: ignore DataContainerType: TypeAlias = "Struct | BaseModel | AttrsInstance | TypedDictClass | DataclassProtocol" # pyright: ignore +PydanticV2FieldsListType: TypeAlias = "Set[int] | Set[str] | Dict[int, Any] | Dict[str, Any]" +PydanticV1FieldsListType: TypeAlias = "IncEx | AbstractSetIntStr | MappingIntStrAny" # pyright: ignore diff --git a/tests/unit/test_contrib/test_pydantic/test_integration.py b/tests/unit/test_contrib/test_pydantic/test_integration.py index 9cc9d285c3..2c054cec48 100644 --- a/tests/unit/test_contrib/test_pydantic/test_integration.py +++ b/tests/unit/test_contrib/test_pydantic/test_integration.py @@ -7,6 +7,7 @@ from typing_extensions import Annotated from litestar import post +from litestar.contrib.pydantic import PydanticInitPlugin, PydanticPlugin from litestar.contrib.pydantic.pydantic_dto_factory import PydanticDTO from litestar.enums import RequestEncodingType from litestar.params import Body, Parameter @@ -305,3 +306,85 @@ async def handler(data: Model) -> Model: res = client.post("/", json={"foo": in_}) assert res.status_code == 201 assert res.json() == {"foo": in_} + + +@pytest.mark.parametrize( + "plugin_params, response", + ( + ( + {"exclude": {"alias"}}, + { + "none": None, + "default": "default", + }, + ), + ({"exclude_defaults": True}, {"alias": "prefer_alias"}), + ({"exclude_none": True}, {"alias": "prefer_alias", "default": "default"}), + ({"exclude_unset": True}, {"alias": "prefer_alias"}), + ({"include": {"alias"}}, {"alias": "prefer_alias"}), + ({"prefer_alias": True}, {"prefer_alias": "prefer_alias", "default": "default", "none": None}), + ), + ids=( + "Exclude alias field", + "Exclude default fields", + "Exclude None field", + "Exclude unset fields", + "Include alias field", + "Use alias in response", + ), +) +def test_params_with_v1_and_v2_models(plugin_params: dict, response: dict) -> None: + class ModelV1(pydantic_v1.BaseModel): # pyright: ignore + alias: str = pydantic_v1.fields.Field(alias="prefer_alias") + default: str = "default" + none: None = None + + class Config: + allow_population_by_field_name = True + + class ModelV2(pydantic_v2.BaseModel): + alias: str = pydantic_v2.fields.Field(serialization_alias="prefer_alias") + default: str = "default" + none: None = None + + @post("/v1") + async def handler_v1() -> ModelV1: + return ModelV1(alias="prefer_alias") # type: ignore[call-arg] + + @post("/v2") + async def handler_v2() -> ModelV2: + return ModelV2(alias="prefer_alias") + + with create_test_client([handler_v1, handler_v2], plugins=[PydanticPlugin(**plugin_params)]) as client: + assert client.post("/v1").json() == response + assert client.post("/v2").json() == response + + +@pytest.mark.parametrize( + "validate_strict,expect_error", + [ + (False, False), + (None, False), + (True, True), + ], +) +def test_v2_strict_validate( + validate_strict: bool, + expect_error: bool, +) -> None: + # https://github.com/litestar-org/litestar/issues/3572 + + class Model(pydantic_v2.BaseModel): + test_bool: pydantic_v2.StrictBool + + @post("/") + async def handler(data: Model) -> None: + return None + + plugins = [] + if validate_strict is not None: + plugins.append(PydanticInitPlugin(validate_strict=validate_strict)) + + with create_test_client([handler], plugins=plugins) as client: + res = client.post("/", json={"test_bool": "YES"}) + assert res.status_code == 400 if expect_error else 201