From 7d581a74b38f6bc8bb337e0c07f8d0e2a9c01ce3 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 5 Oct 2023 18:43:04 -0500 Subject: [PATCH 1/6] feat: Adds a top-level pydantic plugin --- litestar/contrib/pydantic/__init__.py | 37 +++++++++++++++++-- litestar/contrib/pydantic/config.py | 0 .../contrib/pydantic/pydantic_init_plugin.py | 21 +++++++---- .../pydantic/pydantic_schema_plugin.py | 7 ++++ .../test_plugin_serialization.py | 28 ++++++++++++++ 5 files changed, 82 insertions(+), 11 deletions(-) create mode 100644 litestar/contrib/pydantic/config.py diff --git a/litestar/contrib/pydantic/__init__.py b/litestar/contrib/pydantic/__init__.py index de217e56f7..bf83416e43 100644 --- a/litestar/contrib/pydantic/__init__.py +++ b/litestar/contrib/pydantic/__init__.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Any +from litestar.plugins import InitPluginProtocol + from .pydantic_dto_factory import PydanticDTO from .pydantic_init_plugin import PydanticInitPlugin from .pydantic_schema_plugin import PydanticSchemaPlugin @@ -9,7 +11,9 @@ if TYPE_CHECKING: import pydantic -__all__ = ("PydanticDTO", "PydanticInitPlugin", "PydanticSchemaPlugin") + from litestar.config.app import AppConfig + +__all__ = ("PydanticDTO", "PydanticInitPlugin", "PydanticSchemaPlugin", "PydanticPlugin") def _model_dump(model: pydantic.BaseModel, *, by_alias: bool = False) -> dict[str, Any]: @@ -20,5 +24,32 @@ def _model_dump(model: pydantic.BaseModel, *, by_alias: bool = False) -> dict[st ) -def _model_dump_json(model: pydantic.BaseModel) -> str: - return model.model_dump_json() if hasattr(model, "model_dump_json") else model.json() +def _model_dump_json(model: pydantic.BaseModel, by_alias: bool = False) -> str: + return ( + model.model_dump_json(by_alias=by_alias) if hasattr(model, "model_dump_json") else model.json(by_alias=by_alias) + ) + + +class PydanticPlugin(InitPluginProtocol): + """A plugin that provides Pydantic integration.""" + + __slots__ = ("prefer_alias",) + + def __init__(self, prefer_alias: bool = False) -> None: + """Initialize ``PydanticPlugin``. + + Args: + prefer_alias: OpenAPI and type_encoders will export by alias + """ + self.prefer_alias = prefer_alias + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + """Configure application for use with Pydantic. + + Args: + app_config: The :class:`AppConfig <.config.app.AppConfig>` instance. + """ + app_config.plugins.extend( + [PydanticInitPlugin(prefer_alias=self.prefer_alias), PydanticSchemaPlugin(prefer_alias=self.prefer_alias)] + ) + return app_config diff --git a/litestar/contrib/pydantic/config.py b/litestar/contrib/pydantic/config.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/litestar/contrib/pydantic/pydantic_init_plugin.py b/litestar/contrib/pydantic/pydantic_init_plugin.py index 4a57bf5aba..eb7b851f93 100644 --- a/litestar/contrib/pydantic/pydantic_init_plugin.py +++ b/litestar/contrib/pydantic/pydantic_init_plugin.py @@ -71,11 +71,16 @@ def _is_pydantic_uuid(value: Any) -> bool: # pragma: no cover class PydanticInitPlugin(InitPluginProtocol): + __slots__ = ("prefer_alias",) + + def __init__(self, prefer_alias: bool = False) -> None: + self.prefer_alias = prefer_alias + @classmethod - def encoders(cls) -> dict[Any, Callable[[Any], Any]]: + def encoders(cls, prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: if pydantic.VERSION.startswith("1"): # pragma: no cover - return {**_base_encoders, **cls._create_pydantic_v1_encoders()} - return {**_base_encoders, **cls._create_pydantic_v2_encoders()} + return {**_base_encoders, **cls._create_pydantic_v1_encoders(prefer_alias)} + return {**_base_encoders, **cls._create_pydantic_v2_encoders(prefer_alias)} @classmethod def decoders(cls) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]: @@ -89,10 +94,10 @@ def decoders(cls) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any] return decoders @staticmethod - def _create_pydantic_v1_encoders() -> dict[Any, Callable[[Any], Any]]: # pragma: no cover + def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover return { pydantic.BaseModel: lambda model: { - k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict().items() + k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict(by_alias=prefer_alias).items() }, pydantic.SecretField: str, pydantic.StrictBool: int, @@ -102,9 +107,9 @@ def _create_pydantic_v1_encoders() -> dict[Any, Callable[[Any], Any]]: # pragma } @staticmethod - def _create_pydantic_v2_encoders() -> dict[Any, Callable[[Any], Any]]: + def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: encoders: dict[Any, Callable[[Any], Any]] = { - pydantic.BaseModel: lambda model: model.model_dump(mode="json"), + pydantic.BaseModel: lambda model: model.model_dump(mode="json", by_alias=prefer_alias), pydantic.types.SecretStr: lambda val: "**********" if val else "", pydantic.types.SecretBytes: lambda val: "**********" if val else "", } @@ -117,6 +122,6 @@ def _create_pydantic_v2_encoders() -> dict[Any, Callable[[Any], Any]]: return encoders def on_app_init(self, app_config: AppConfig) -> AppConfig: - app_config.type_encoders = {**self.encoders(), **(app_config.type_encoders or {})} + app_config.type_encoders = {**self.encoders(self.prefer_alias), **(app_config.type_encoders or {})} app_config.type_decoders = [*self.decoders(), *(app_config.type_decoders or [])] return app_config diff --git a/litestar/contrib/pydantic/pydantic_schema_plugin.py b/litestar/contrib/pydantic/pydantic_schema_plugin.py index b644fcb570..eaf04131e2 100644 --- a/litestar/contrib/pydantic/pydantic_schema_plugin.py +++ b/litestar/contrib/pydantic/pydantic_schema_plugin.py @@ -132,6 +132,11 @@ class PydanticSchemaPlugin(OpenAPISchemaPluginProtocol): + __slots__ = ("prefer_alias",) + + def __init__(self, prefer_alias: bool = False) -> None: + self.prefer_alias = prefer_alias + @staticmethod def is_plugin_supported_type(value: Any) -> bool: return isinstance(value, _supported_types) or is_class_and_subclass(value, _supported_types) # type: ignore @@ -146,6 +151,8 @@ def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: S Returns: An :class:`OpenAPI ` instance. """ + if schema_creator.prefer_alias != self.prefer_alias: + schema_creator.prefer_alias = True if is_pydantic_model_class(field_definition.annotation): return self.for_pydantic_model(annotation=field_definition.annotation, schema_creator=schema_creator) return PYDANTIC_TYPE_MAP[field_definition.annotation] # pragma: no cover diff --git a/tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py b/tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py index d730d33f47..f1d3fb125d 100644 --- a/tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py +++ b/tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py @@ -191,6 +191,22 @@ def test_pydantic_json_compatibility(model: BaseModel) -> None: assert raw_result == encoded_result +def test_pydantic_json_by_alias_compatibility(model: BaseModel) -> None: + raw = _model_dump_json(model, by_alias=True) + encoded_json = encode_json(model, serializer=get_serializer(PydanticInitPlugin.encoders(prefer_alias=True))) + + raw_result = json.loads(raw) + encoded_result = json.loads(encoded_json) + + if VERSION.startswith("1"): + # pydantic v1 dumps decimals into floats as json, we therefore regard this as an error + assert raw_result.get("condecimal") == float(encoded_result.get("condecimal")) + del raw_result["condecimal"] + del encoded_result["condecimal"] + + assert raw_result == encoded_result + + @pytest.mark.parametrize("encoder", [encode_json, encode_msgpack]) def test_encoder_raises_serialization_exception(model: BaseModel, encoder: Any) -> None: with pytest.raises(SerializationException): @@ -219,3 +235,15 @@ def test_decode_msgpack_typed(model: BaseModel) -> None: ).json() == model_json ) + + +def test_decode_msgpack_typed_aliased(model: BaseModel) -> None: + model_json = _model_dump_json(model, by_alias=True) + assert ( + decode_msgpack( + encode_msgpack(model, serializer=get_serializer(PydanticInitPlugin.encoders(prefer_alias=True))), + Model, + type_decoders=PydanticInitPlugin.decoders(), + ).json() + == model_json + ) From 92e3d3751adc9da926499c60641bb8b48bb5fdc7 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 5 Oct 2023 19:38:10 -0500 Subject: [PATCH 2/6] fix: override default plugin setting --- litestar/app.py | 24 +++++--- .../test_plugin_serialization.py | 58 +++++++------------ tests/unit/test_openapi/test_config.py | 40 +++++++++++++ 3 files changed, 77 insertions(+), 45 deletions(-) diff --git a/litestar/app.py b/litestar/app.py index e50e808fa7..4a302ecbfc 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -341,7 +341,9 @@ def __init__( opt=dict(opt or {}), parameters=parameters or {}, pdb_on_exception=pdb_on_exception, - plugins=[*(plugins or []), *self._get_default_plugins()], + plugins=[ + *self._get_default_plugins(list(plugins or [])), + ], request_class=request_class, response_cache_config=response_cache_config or ResponseCacheConfig(), response_class=response_class, @@ -466,18 +468,24 @@ def serialization_plugins(self) -> list[SerializationPluginProtocol]: return list(self.plugins.serialization) @staticmethod - def _get_default_plugins() -> list[PluginProtocol]: - default_plugins: list[PluginProtocol] = [] + def _get_default_plugins(plugins: list[PluginProtocol] | None = None) -> list[PluginProtocol]: + if plugins is None: + plugins = [] with suppress(MissingDependencyException): - from litestar.contrib.pydantic import PydanticInitPlugin, PydanticSchemaPlugin - - default_plugins.extend((PydanticInitPlugin(), PydanticSchemaPlugin())) + from litestar.contrib.pydantic import PydanticInitPlugin, PydanticPlugin, PydanticSchemaPlugin + pre_configured = any( + isinstance(plugin, (PydanticPlugin, PydanticInitPlugin, PydanticSchemaPlugin)) for plugin in plugins + ) + if not pre_configured: + plugins.append(PydanticPlugin()) with suppress(MissingDependencyException): from litestar.contrib.attrs import AttrsSchemaPlugin - default_plugins.append(AttrsSchemaPlugin()) - return default_plugins + pre_configured = any(isinstance(plugin, AttrsSchemaPlugin) for plugin in plugins) + if not pre_configured: + plugins.append(AttrsSchemaPlugin()) + return plugins @property def debug(self) -> bool: diff --git a/tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py b/tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py index f1d3fb125d..ba2c52d4b2 100644 --- a/tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py +++ b/tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py @@ -175,25 +175,13 @@ def test_serialization_of_model_instance(model: BaseModel) -> None: assert serializer(model) == _model_dump(model) -def test_pydantic_json_compatibility(model: BaseModel) -> None: - raw = _model_dump_json(model) - encoded_json = encode_json(model, serializer=get_serializer(PydanticInitPlugin.encoders())) - - raw_result = json.loads(raw) - encoded_result = json.loads(encoded_json) - - if VERSION.startswith("1"): - # pydantic v1 dumps decimals into floats as json, we therefore regard this as an error - assert raw_result.get("condecimal") == float(encoded_result.get("condecimal")) - del raw_result["condecimal"] - del encoded_result["condecimal"] - - assert raw_result == encoded_result - - -def test_pydantic_json_by_alias_compatibility(model: BaseModel) -> None: - raw = _model_dump_json(model, by_alias=True) - encoded_json = encode_json(model, serializer=get_serializer(PydanticInitPlugin.encoders(prefer_alias=True))) +@pytest.mark.parametrize( + "prefer_alias", + [(False), (True)], +) +def test_pydantic_json_compatibility(model: BaseModel, prefer_alias: bool) -> None: + raw = _model_dump_json(model, by_alias=prefer_alias) + encoded_json = encode_json(model, serializer=get_serializer(PydanticInitPlugin.encoders(prefer_alias=prefer_alias))) raw_result = json.loads(raw) encoded_result = json.loads(encoded_json) @@ -219,29 +207,25 @@ def test_decode_json_raises_serialization_exception(model: BaseModel, decoder: A decoder(b"str") -def test_decode_json_typed(model: BaseModel) -> None: - dumped_model = _model_dump_json(model) +@pytest.mark.parametrize( + "prefer_alias", + [(False), (True)], +) +def test_decode_json_typed(model: BaseModel, prefer_alias: bool) -> None: + dumped_model = _model_dump_json(model, by_alias=prefer_alias) decoded_model = decode_json(value=dumped_model, target_type=Model, type_decoders=PydanticInitPlugin.decoders()) - assert _model_dump_json(decoded_model) == dumped_model - - -def test_decode_msgpack_typed(model: BaseModel) -> None: - model_json = _model_dump_json(model) - assert ( - decode_msgpack( - encode_msgpack(model, serializer=get_serializer(PydanticInitPlugin.encoders())), - Model, - type_decoders=PydanticInitPlugin.decoders(), - ).json() - == model_json - ) + assert _model_dump_json(decoded_model, by_alias=prefer_alias) == dumped_model -def test_decode_msgpack_typed_aliased(model: BaseModel) -> None: - model_json = _model_dump_json(model, by_alias=True) +@pytest.mark.parametrize( + "prefer_alias", + [(False), (True)], +) +def test_decode_msgpack_typed(model: BaseModel, prefer_alias: bool) -> None: + model_json = _model_dump_json(model, by_alias=prefer_alias) assert ( decode_msgpack( - encode_msgpack(model, serializer=get_serializer(PydanticInitPlugin.encoders(prefer_alias=True))), + encode_msgpack(model, serializer=get_serializer(PydanticInitPlugin.encoders(prefer_alias=prefer_alias))), Model, type_decoders=PydanticInitPlugin.decoders(), ).json() diff --git a/tests/unit/test_openapi/test_config.py b/tests/unit/test_openapi/test_config.py index 24c9e7e8a8..83fb90c594 100644 --- a/tests/unit/test_openapi/test_config.py +++ b/tests/unit/test_openapi/test_config.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field from litestar import Litestar, get, post +from litestar.contrib.pydantic import PydanticPlugin from litestar.exceptions import ImproperlyConfiguredException from litestar.openapi.config import OpenAPIConfig from litestar.openapi.spec import Components, Example, OpenAPIHeader, OpenAPIType, Schema @@ -83,6 +84,45 @@ def handler(data: RequestWithAlias) -> ResponseWithAlias: assert response.json() == {response_key: "foo"} +def test_pydantic_plugin_override_by_alias() -> None: + class RequestWithAlias(BaseModel): + first: str = Field(alias="second") + + class ResponseWithAlias(BaseModel): + first: str = Field(alias="second") + + @post("/") + def handler(data: RequestWithAlias) -> ResponseWithAlias: + return ResponseWithAlias(second=data.first) + + app = Litestar( + route_handlers=[handler], + openapi_config=OpenAPIConfig(title="my title", version="1.0.0"), + plugins=[PydanticPlugin(prefer_alias=True)], + ) + + assert app.openapi_schema + schemas = app.openapi_schema.to_schema()["components"]["schemas"] + request_key = "second" + assert schemas["RequestWithAlias"] == { + "properties": {request_key: {"type": "string"}}, + "type": "object", + "required": [request_key], + "title": "RequestWithAlias", + } + response_key = "second" + assert schemas["ResponseWithAlias"] == { + "properties": {response_key: {"type": "string"}}, + "type": "object", + "required": [response_key], + "title": "ResponseWithAlias", + } + + with TestClient(app) as client: + response = client.post("/", json={request_key: "foo"}) + assert response.json() == {response_key: "foo"} + + def test_allows_customization_of_operation_id_creator() -> None: def operation_id_creator(handler: "HTTPRouteHandler", _: Any, __: Any) -> str: return handler.name or "" From b8a2b1631d0c0bfea6c0085fe552ebe4805dd018 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 5 Oct 2023 19:43:02 -0500 Subject: [PATCH 3/6] chore: linting --- .github/PULL_REQUEST_TEMPLATE.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 68ff2cd56b..59184125e7 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -18,7 +18,7 @@ By submitting this pull request, you agree to: Please describe your pull request for new release changelog purposes --> -- +- ### Close Issue(s) -- +- From ff153fd60b743da57972f08d9712c2b411e3194d Mon Sep 17 00:00:00 2001 From: Jacob Coffee Date: Thu, 5 Oct 2023 20:14:37 -0500 Subject: [PATCH 4/6] Update litestar/contrib/pydantic/__init__.py --- litestar/contrib/pydantic/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litestar/contrib/pydantic/__init__.py b/litestar/contrib/pydantic/__init__.py index bf83416e43..3c3d0dcf9e 100644 --- a/litestar/contrib/pydantic/__init__.py +++ b/litestar/contrib/pydantic/__init__.py @@ -39,7 +39,7 @@ def __init__(self, prefer_alias: bool = False) -> None: """Initialize ``PydanticPlugin``. Args: - prefer_alias: OpenAPI and type_encoders will export by alias + prefer_alias: OpenAPI and ``type_encoders`` will export by alias """ self.prefer_alias = prefer_alias From 09c11a8f9d29e2853a6223a71811a7d51e78a00f Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 5 Oct 2023 21:44:13 -0500 Subject: [PATCH 5/6] fix: remove useless unpack --- litestar/app.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/litestar/app.py b/litestar/app.py index 4a302ecbfc..02ffb8cc98 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -341,9 +341,7 @@ def __init__( opt=dict(opt or {}), parameters=parameters or {}, pdb_on_exception=pdb_on_exception, - plugins=[ - *self._get_default_plugins(list(plugins or [])), - ], + plugins=self._get_default_plugins(list(plugins or [])), request_class=request_class, response_cache_config=response_cache_config or ResponseCacheConfig(), response_class=response_class, From 8aefa8450793d14acb99d2971e842cba4ed561e2 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 6 Oct 2023 10:09:46 -0500 Subject: [PATCH 6/6] fix: properly load plugins based on what is pre-configured --- litestar/app.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/litestar/app.py b/litestar/app.py index 02ffb8cc98..c16a9eb4d2 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -472,11 +472,15 @@ def _get_default_plugins(plugins: list[PluginProtocol] | None = None) -> list[Pl with suppress(MissingDependencyException): from litestar.contrib.pydantic import PydanticInitPlugin, PydanticPlugin, PydanticSchemaPlugin - pre_configured = any( - isinstance(plugin, (PydanticPlugin, PydanticInitPlugin, PydanticSchemaPlugin)) for plugin in plugins - ) - if not pre_configured: + pydantic_plugin_found = any(isinstance(plugin, PydanticPlugin) for plugin in plugins) + pydantic_init_plugin_found = any(isinstance(plugin, PydanticInitPlugin) for plugin in plugins) + pydantic_schema_plugin_found = any(isinstance(plugin, PydanticSchemaPlugin) for plugin in plugins) + if not pydantic_plugin_found and not pydantic_init_plugin_found and not pydantic_schema_plugin_found: plugins.append(PydanticPlugin()) + elif not pydantic_plugin_found and pydantic_init_plugin_found and not pydantic_schema_plugin_found: + plugins.append(PydanticSchemaPlugin()) + elif not pydantic_plugin_found and not pydantic_init_plugin_found and pydantic_schema_plugin_found: + plugins.append(PydanticInitPlugin()) with suppress(MissingDependencyException): from litestar.contrib.attrs import AttrsSchemaPlugin