Skip to content

Commit

Permalink
feat: allow customization of Pydantic integration (#2404)
Browse files Browse the repository at this point in the history
* feat: Adds a top-level pydantic plugin

* fix: override default plugin setting

* chore: linting

* Update litestar/contrib/pydantic/__init__.py

* fix: remove useless unpack

* fix: properly load plugins based on what is pre-configured
  • Loading branch information
cofin committed Oct 6, 2023
1 parent 690b662 commit 2222d2d
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 31 deletions.
4 changes: 2 additions & 2 deletions .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ By submitting this pull request, you agree to:
Please describe your pull request for new release changelog purposes
-->

-
-

### Close Issue(s)
<!--
Please add in issue numbers this pull request will close, if applicable
Examples: Fixes #4321 or Closes #1234
-->

-
-
28 changes: 19 additions & 9 deletions litestar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def __init__(
opt=dict(opt or {}),
parameters=parameters or {},
pdb_on_exception=pdb_on_exception,
plugins=[*(plugins or []), *self._get_default_plugins()],
plugins=self._get_default_plugins(list(plugins or [])),
request_class=request_class,
response_cache_config=response_cache_config or ResponseCacheConfig(),
response_class=response_class,
Expand Down Expand Up @@ -466,18 +466,28 @@ def serialization_plugins(self) -> list[SerializationPluginProtocol]:
return list(self.plugins.serialization)

@staticmethod
def _get_default_plugins() -> list[PluginProtocol]:
default_plugins: list[PluginProtocol] = []
def _get_default_plugins(plugins: list[PluginProtocol] | None = None) -> list[PluginProtocol]:
if plugins is None:
plugins = []
with suppress(MissingDependencyException):
from litestar.contrib.pydantic import PydanticInitPlugin, PydanticSchemaPlugin

default_plugins.extend((PydanticInitPlugin(), PydanticSchemaPlugin()))

from litestar.contrib.pydantic import PydanticInitPlugin, PydanticPlugin, PydanticSchemaPlugin

pydantic_plugin_found = any(isinstance(plugin, PydanticPlugin) for plugin in plugins)
pydantic_init_plugin_found = any(isinstance(plugin, PydanticInitPlugin) for plugin in plugins)
pydantic_schema_plugin_found = any(isinstance(plugin, PydanticSchemaPlugin) for plugin in plugins)
if not pydantic_plugin_found and not pydantic_init_plugin_found and not pydantic_schema_plugin_found:
plugins.append(PydanticPlugin())
elif not pydantic_plugin_found and pydantic_init_plugin_found and not pydantic_schema_plugin_found:
plugins.append(PydanticSchemaPlugin())
elif not pydantic_plugin_found and not pydantic_init_plugin_found and pydantic_schema_plugin_found:
plugins.append(PydanticInitPlugin())
with suppress(MissingDependencyException):
from litestar.contrib.attrs import AttrsSchemaPlugin

default_plugins.append(AttrsSchemaPlugin())
return default_plugins
pre_configured = any(isinstance(plugin, AttrsSchemaPlugin) for plugin in plugins)
if not pre_configured:
plugins.append(AttrsSchemaPlugin())
return plugins

@property
def debug(self) -> bool:
Expand Down
37 changes: 34 additions & 3 deletions litestar/contrib/pydantic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@

from typing import TYPE_CHECKING, Any

from litestar.plugins import InitPluginProtocol

from .pydantic_dto_factory import PydanticDTO
from .pydantic_init_plugin import PydanticInitPlugin
from .pydantic_schema_plugin import PydanticSchemaPlugin

if TYPE_CHECKING:
import pydantic

__all__ = ("PydanticDTO", "PydanticInitPlugin", "PydanticSchemaPlugin")
from litestar.config.app import AppConfig

__all__ = ("PydanticDTO", "PydanticInitPlugin", "PydanticSchemaPlugin", "PydanticPlugin")


def _model_dump(model: pydantic.BaseModel, *, by_alias: bool = False) -> dict[str, Any]:
Expand All @@ -20,5 +24,32 @@ def _model_dump(model: pydantic.BaseModel, *, by_alias: bool = False) -> dict[st
)


def _model_dump_json(model: pydantic.BaseModel) -> str:
return model.model_dump_json() if hasattr(model, "model_dump_json") else model.json()
def _model_dump_json(model: pydantic.BaseModel, by_alias: bool = False) -> str:
return (
model.model_dump_json(by_alias=by_alias) if hasattr(model, "model_dump_json") else model.json(by_alias=by_alias)
)


class PydanticPlugin(InitPluginProtocol):
"""A plugin that provides Pydantic integration."""

__slots__ = ("prefer_alias",)

def __init__(self, prefer_alias: bool = False) -> None:
"""Initialize ``PydanticPlugin``.
Args:
prefer_alias: OpenAPI and ``type_encoders`` will export by alias
"""
self.prefer_alias = prefer_alias

def on_app_init(self, app_config: AppConfig) -> AppConfig:
"""Configure application for use with Pydantic.
Args:
app_config: The :class:`AppConfig <.config.app.AppConfig>` instance.
"""
app_config.plugins.extend(
[PydanticInitPlugin(prefer_alias=self.prefer_alias), PydanticSchemaPlugin(prefer_alias=self.prefer_alias)]
)
return app_config
Empty file.
21 changes: 13 additions & 8 deletions litestar/contrib/pydantic/pydantic_init_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,16 @@ def _is_pydantic_uuid(value: Any) -> bool: # pragma: no cover


class PydanticInitPlugin(InitPluginProtocol):
__slots__ = ("prefer_alias",)

def __init__(self, prefer_alias: bool = False) -> None:
self.prefer_alias = prefer_alias

@classmethod
def encoders(cls) -> dict[Any, Callable[[Any], Any]]:
def encoders(cls, prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]:
if pydantic.VERSION.startswith("1"): # pragma: no cover
return {**_base_encoders, **cls._create_pydantic_v1_encoders()}
return {**_base_encoders, **cls._create_pydantic_v2_encoders()}
return {**_base_encoders, **cls._create_pydantic_v1_encoders(prefer_alias)}
return {**_base_encoders, **cls._create_pydantic_v2_encoders(prefer_alias)}

@classmethod
def decoders(cls) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]]:
Expand All @@ -89,10 +94,10 @@ def decoders(cls) -> list[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]
return decoders

@staticmethod
def _create_pydantic_v1_encoders() -> dict[Any, Callable[[Any], Any]]: # pragma: no cover
def _create_pydantic_v1_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]: # pragma: no cover
return {
pydantic.BaseModel: lambda model: {
k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict().items()
k: v.decode() if isinstance(v, bytes) else v for k, v in model.dict(by_alias=prefer_alias).items()
},
pydantic.SecretField: str,
pydantic.StrictBool: int,
Expand All @@ -102,9 +107,9 @@ def _create_pydantic_v1_encoders() -> dict[Any, Callable[[Any], Any]]: # pragma
}

@staticmethod
def _create_pydantic_v2_encoders() -> dict[Any, Callable[[Any], Any]]:
def _create_pydantic_v2_encoders(prefer_alias: bool = False) -> dict[Any, Callable[[Any], Any]]:
encoders: dict[Any, Callable[[Any], Any]] = {
pydantic.BaseModel: lambda model: model.model_dump(mode="json"),
pydantic.BaseModel: lambda model: model.model_dump(mode="json", by_alias=prefer_alias),
pydantic.types.SecretStr: lambda val: "**********" if val else "",
pydantic.types.SecretBytes: lambda val: "**********" if val else "",
}
Expand All @@ -117,6 +122,6 @@ def _create_pydantic_v2_encoders() -> dict[Any, Callable[[Any], Any]]:
return encoders

def on_app_init(self, app_config: AppConfig) -> AppConfig:
app_config.type_encoders = {**self.encoders(), **(app_config.type_encoders or {})}
app_config.type_encoders = {**self.encoders(self.prefer_alias), **(app_config.type_encoders or {})}
app_config.type_decoders = [*self.decoders(), *(app_config.type_decoders or [])]
return app_config
7 changes: 7 additions & 0 deletions litestar/contrib/pydantic/pydantic_schema_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@


class PydanticSchemaPlugin(OpenAPISchemaPluginProtocol):
__slots__ = ("prefer_alias",)

def __init__(self, prefer_alias: bool = False) -> None:
self.prefer_alias = prefer_alias

@staticmethod
def is_plugin_supported_type(value: Any) -> bool:
return isinstance(value, _supported_types) or is_class_and_subclass(value, _supported_types) # type: ignore
Expand All @@ -146,6 +151,8 @@ def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: S
Returns:
An :class:`OpenAPI <litestar.openapi.spec.schema.Schema>` instance.
"""
if schema_creator.prefer_alias != self.prefer_alias:
schema_creator.prefer_alias = True
if is_pydantic_model_class(field_definition.annotation):
return self.for_pydantic_model(annotation=field_definition.annotation, schema_creator=schema_creator)
return PYDANTIC_TYPE_MAP[field_definition.annotation] # pragma: no cover
Expand Down
30 changes: 21 additions & 9 deletions tests/unit/test_contrib/test_pydantic/test_plugin_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,13 @@ def test_serialization_of_model_instance(model: BaseModel) -> None:
assert serializer(model) == _model_dump(model)


def test_pydantic_json_compatibility(model: BaseModel) -> None:
raw = _model_dump_json(model)
encoded_json = encode_json(model, serializer=get_serializer(PydanticInitPlugin.encoders()))
@pytest.mark.parametrize(
"prefer_alias",
[(False), (True)],
)
def test_pydantic_json_compatibility(model: BaseModel, prefer_alias: bool) -> None:
raw = _model_dump_json(model, by_alias=prefer_alias)
encoded_json = encode_json(model, serializer=get_serializer(PydanticInitPlugin.encoders(prefer_alias=prefer_alias)))

raw_result = json.loads(raw)
encoded_result = json.loads(encoded_json)
Expand All @@ -203,17 +207,25 @@ def test_decode_json_raises_serialization_exception(model: BaseModel, decoder: A
decoder(b"str")


def test_decode_json_typed(model: BaseModel) -> None:
dumped_model = _model_dump_json(model)
@pytest.mark.parametrize(
"prefer_alias",
[(False), (True)],
)
def test_decode_json_typed(model: BaseModel, prefer_alias: bool) -> None:
dumped_model = _model_dump_json(model, by_alias=prefer_alias)
decoded_model = decode_json(value=dumped_model, target_type=Model, type_decoders=PydanticInitPlugin.decoders())
assert _model_dump_json(decoded_model) == dumped_model
assert _model_dump_json(decoded_model, by_alias=prefer_alias) == dumped_model


def test_decode_msgpack_typed(model: BaseModel) -> None:
model_json = _model_dump_json(model)
@pytest.mark.parametrize(
"prefer_alias",
[(False), (True)],
)
def test_decode_msgpack_typed(model: BaseModel, prefer_alias: bool) -> None:
model_json = _model_dump_json(model, by_alias=prefer_alias)
assert (
decode_msgpack(
encode_msgpack(model, serializer=get_serializer(PydanticInitPlugin.encoders())),
encode_msgpack(model, serializer=get_serializer(PydanticInitPlugin.encoders(prefer_alias=prefer_alias))),
Model,
type_decoders=PydanticInitPlugin.decoders(),
).json()
Expand Down
40 changes: 40 additions & 0 deletions tests/unit/test_openapi/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic import BaseModel, Field

from litestar import Litestar, get, post
from litestar.contrib.pydantic import PydanticPlugin
from litestar.exceptions import ImproperlyConfiguredException
from litestar.openapi.config import OpenAPIConfig
from litestar.openapi.spec import Components, Example, OpenAPIHeader, OpenAPIType, Schema
Expand Down Expand Up @@ -83,6 +84,45 @@ def handler(data: RequestWithAlias) -> ResponseWithAlias:
assert response.json() == {response_key: "foo"}


def test_pydantic_plugin_override_by_alias() -> None:
class RequestWithAlias(BaseModel):
first: str = Field(alias="second")

class ResponseWithAlias(BaseModel):
first: str = Field(alias="second")

@post("/")
def handler(data: RequestWithAlias) -> ResponseWithAlias:
return ResponseWithAlias(second=data.first)

app = Litestar(
route_handlers=[handler],
openapi_config=OpenAPIConfig(title="my title", version="1.0.0"),
plugins=[PydanticPlugin(prefer_alias=True)],
)

assert app.openapi_schema
schemas = app.openapi_schema.to_schema()["components"]["schemas"]
request_key = "second"
assert schemas["RequestWithAlias"] == {
"properties": {request_key: {"type": "string"}},
"type": "object",
"required": [request_key],
"title": "RequestWithAlias",
}
response_key = "second"
assert schemas["ResponseWithAlias"] == {
"properties": {response_key: {"type": "string"}},
"type": "object",
"required": [response_key],
"title": "ResponseWithAlias",
}

with TestClient(app) as client:
response = client.post("/", json={request_key: "foo"})
assert response.json() == {response_key: "foo"}


def test_allows_customization_of_operation_id_creator() -> None:
def operation_id_creator(handler: "HTTPRouteHandler", _: Any, __: Any) -> str:
return handler.name or ""
Expand Down

0 comments on commit 2222d2d

Please sign in to comment.