Skip to content

Commit

Permalink
fix: #2133 - openapi schema generation when creating examples with Py…
Browse files Browse the repository at this point in the history
…dantic models (#2178)

* fix: pass serializer when encoding Open API schema to JSON

* fix: remove UnsetType from Unions when creating examples

* fix: properly encode OpenAPI schema as YAML when examples are included

* test: test OpenAPI schema generation with create_examples=True
  • Loading branch information
guacs authored Aug 20, 2023
1 parent 3e712e7 commit be75fbb
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 19 deletions.
16 changes: 16 additions & 0 deletions litestar/_openapi/schema_generation/examples.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from __future__ import annotations

import typing
from dataclasses import replace
from decimal import Decimal
from enum import Enum
from typing import TYPE_CHECKING, Any

import msgspec
from polyfactory.exceptions import ParameterException
from polyfactory.factories import DataclassFactory
from polyfactory.field_meta import FieldMeta, Null
from polyfactory.utils.helpers import unwrap_annotation
from polyfactory.utils.predicates import is_union
from typing_extensions import get_args

from litestar.openapi.spec import Example
from litestar.types import Empty
Expand All @@ -25,6 +29,18 @@ class ExampleFactory(DataclassFactory[Example]):

def _normalize_example_value(value: Any) -> Any:
"""Normalize the example value to make it look a bit prettier."""
# if UnsetType is part of the union, then it might get chosen as the value
# but that will not be properly serialized by msgspec unless it is for a field
# in a msgspec Struct
if is_union(value):
args = list(get_args(value))
try:
args.remove(msgspec.UnsetType)
value = typing.Union[tuple(args)] # pyright: ignore
except ValueError:
# UnsetType not part of the Union
pass

value = unwrap_annotation(annotation=value, random=ExampleFactory.__random__)
if isinstance(value, (Decimal, float)):
value = round(float(value), 2)
Expand Down
32 changes: 21 additions & 11 deletions litestar/openapi/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from litestar.handlers import get
from litestar.response.base import ASGIResponse
from litestar.serialization import encode_json
from litestar.serialization.msgspec_hooks import decode_json
from litestar.status_codes import HTTP_404_NOT_FOUND

__all__ = ("OpenAPIController",)
Expand Down Expand Up @@ -64,7 +65,8 @@ class OpenAPIController(Controller):
"""Download url for the Stoplight Elements JS bundle."""

# internal
_dumped_schema: str = ""
_dumped_json_schema: str = ""
_dumped_yaml_schema: bytes = b""
# until swagger-ui supports v3.1.* of OpenAPI officially, we need to modify the schema for it and keep it
# separate from the redoc version of the schema, which is unmodified.
dto = None
Expand Down Expand Up @@ -143,10 +145,11 @@ def retrieve_schema_yaml(self, request: Request[Any, Any, Any]) -> ASGIResponse:
A Response instance with the YAML object rendered into a string.
"""
if self.should_serve_endpoint(request):
content = dump_yaml(self.get_schema_from_request(request).to_schema(), default_flow_style=False).encode(
"utf-8"
)
return ASGIResponse(body=content, media_type=OpenAPIMediaType.OPENAPI_YAML)
if not self._dumped_json_schema:
schema_json = decode_json(self._get_schema_as_json(request))
schema_yaml = dump_yaml(schema_json, default_flow_style=False)
self._dumped_yaml_schema = schema_yaml.encode("utf-8")
return ASGIResponse(body=self._dumped_yaml_schema, media_type=OpenAPIMediaType.OPENAPI_YAML)
return ASGIResponse(body=b"", status_code=HTTP_404_NOT_FOUND, media_type=MediaType.HTML)

@get(path="/openapi.json", media_type=OpenAPIMediaType.OPENAPI_JSON, include_in_schema=False, sync_to_thread=False)
Expand All @@ -162,7 +165,7 @@ def retrieve_schema_json(self, request: Request[Any, Any, Any]) -> ASGIResponse:
"""
if self.should_serve_endpoint(request):
return ASGIResponse(
body=encode_json(self.get_schema_from_request(request).to_schema()),
body=self._get_schema_as_json(request),
media_type=OpenAPIMediaType.OPENAPI_JSON,
)
return ASGIResponse(body=b"", status_code=HTTP_404_NOT_FOUND, media_type=MediaType.HTML)
Expand Down Expand Up @@ -272,7 +275,7 @@ def render_swagger_ui(self, request: Request[Any, Any, Any]) -> bytes:
<div id='swagger-container'/>
<script type="text/javascript">
const ui = SwaggerUIBundle({{
spec: {encode_json(schema.to_schema()).decode("utf-8")},
spec: {self._get_schema_as_json(request)},
dom_id: '#swagger-container',
deepLinking: true,
showExtensions: true,
Expand Down Expand Up @@ -353,9 +356,6 @@ def render_redoc(self, request: Request[Any, Any, Any]) -> bytes: # pragma: no
"""
schema = self.get_schema_from_request(request)

if not self._dumped_schema:
self._dumped_schema = encode_json(schema.to_schema()).decode("utf-8")

head = f"""
<head>
<title>{schema.info.title}</title>
Expand All @@ -382,7 +382,7 @@ def render_redoc(self, request: Request[Any, Any, Any]) -> bytes: # pragma: no
<div id='redoc-container'/>
<script type="text/javascript">
Redoc.init(
{self._dumped_schema},
{self._get_schema_as_json(request)},
undefined,
document.getElementById('redoc-container')
)
Expand Down Expand Up @@ -422,3 +422,13 @@ def render_404_page(self) -> bytes:
</body>
</html>
""".encode()

def _get_schema_as_json(self, request: Request) -> str:
"""Get the schema encoded as a JSON string."""

if not self._dumped_json_schema:
schema = self.get_schema_from_request(request).to_schema()
json_encoded_schema = encode_json(schema, request.route_handler.default_serializer)
self._dumped_json_schema = json_encoded_schema.decode("utf-8")

return self._dumped_json_schema
35 changes: 27 additions & 8 deletions tests/unit/test_openapi/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Type

import msgspec
import pytest
import yaml
from pydantic import BaseModel, Field
from typing_extensions import Annotated
Expand All @@ -9,32 +10,46 @@
from litestar.app import DEFAULT_OPENAPI_CONFIG
from litestar.enums import OpenAPIMediaType
from litestar.openapi import OpenAPIConfig, OpenAPIController
from litestar.serialization.msgspec_hooks import decode_json, encode_json, get_serializer
from litestar.status_codes import HTTP_200_OK, HTTP_404_NOT_FOUND
from litestar.testing import create_test_client

CREATE_EXAMPLES_VALUES = (True, False)

def test_openapi_yaml(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None:
with create_test_client([person_controller, pet_controller], openapi_config=DEFAULT_OPENAPI_CONFIG) as client:

@pytest.mark.parametrize("create_examples", CREATE_EXAMPLES_VALUES)
def test_openapi_yaml(
person_controller: Type[Controller], pet_controller: Type[Controller], create_examples: bool
) -> None:
openapi_config = OpenAPIConfig("Example API", "1.0.0", create_examples=create_examples)
with create_test_client([person_controller, pet_controller], openapi_config=openapi_config) as client:
assert client.app.openapi_schema
openapi_schema = client.app.openapi_schema
assert openapi_schema.paths
response = client.get("/schema/openapi.yaml")
assert response.status_code == HTTP_200_OK
assert response.headers["content-type"] == OpenAPIMediaType.OPENAPI_YAML.value
assert client.app.openapi_schema
assert yaml.unsafe_load(response.content) == client.app.openapi_schema.to_schema()
serializer = get_serializer(client.app.type_encoders)
schema_json = decode_json(encode_json(openapi_schema.to_schema(), serializer))
assert response.content.decode("utf-8") == yaml.dump(schema_json)


def test_openapi_json(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None:
with create_test_client([person_controller, pet_controller], openapi_config=DEFAULT_OPENAPI_CONFIG) as client:
@pytest.mark.parametrize("create_examples", CREATE_EXAMPLES_VALUES)
def test_openapi_json(
person_controller: Type[Controller], pet_controller: Type[Controller], create_examples: bool
) -> None:
openapi_config = OpenAPIConfig("Example API", "1.0.0", create_examples=create_examples)
with create_test_client([person_controller, pet_controller], openapi_config=openapi_config) as client:
assert client.app.openapi_schema
openapi_schema = client.app.openapi_schema
assert openapi_schema.paths
response = client.get("/schema/openapi.json")
assert response.status_code == HTTP_200_OK
assert response.headers["content-type"] == OpenAPIMediaType.OPENAPI_JSON.value
assert client.app.openapi_schema
assert response.json() == client.app.openapi_schema.to_schema()
serializer = get_serializer(client.app.type_encoders)
assert response.content == encode_json(openapi_schema.to_schema(), serializer)


def test_openapi_yaml_not_allowed(person_controller: Type[Controller], pet_controller: Type[Controller]) -> None:
Expand Down Expand Up @@ -118,7 +133,8 @@ class CustomOpenAPIController(OpenAPIController):
assert response.status_code == HTTP_200_OK


def test_msgspec_schema_generation() -> None:
@pytest.mark.parametrize("create_examples", CREATE_EXAMPLES_VALUES)
def test_msgspec_schema_generation(create_examples: bool) -> None:
class Lookup(msgspec.Struct):
id: Annotated[
str,
Expand All @@ -139,6 +155,7 @@ async def example_route() -> Lookup:
openapi_config=OpenAPIConfig(
title="Example API",
version="1.0.0",
create_examples=create_examples,
),
) as client:
response = client.get("/schema/openapi.json")
Expand All @@ -152,7 +169,8 @@ async def example_route() -> Lookup:
}


def test_pydantic_schema_generation() -> None:
@pytest.mark.parametrize("create_examples", CREATE_EXAMPLES_VALUES)
def test_pydantic_schema_generation(create_examples: bool) -> None:
class Lookup(BaseModel):
id: Annotated[
str,
Expand All @@ -173,6 +191,7 @@ async def example_route() -> Lookup:
openapi_config=OpenAPIConfig(
title="Example API",
version="1.0.0",
create_examples=create_examples,
),
) as client:
response = client.get("/schema/openapi.json")
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_openapi/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from litestar import Controller, Litestar, Router, get
from litestar._openapi.parameters import create_parameter_for_handler
from litestar._openapi.schema_generation import SchemaCreator
from litestar._openapi.schema_generation.examples import ExampleFactory
from litestar._openapi.typescript_converter.schema_parsing import is_schema_value
from litestar._signature import SignatureModel
from litestar.di import Provide
Expand Down Expand Up @@ -41,6 +42,8 @@ def _create_parameters(app: Litestar, path: str) -> List["OpenAPIParameter"]:


def test_create_parameters(person_controller: Type[Controller]) -> None:
ExampleFactory.seed_random(10)

parameters = _create_parameters(app=Litestar(route_handlers=[person_controller]), path="/{service_id}/person")
assert len(parameters) == 9
page, name, page_size, service_id, from_date, to_date, gender, secret_header, cookie_value = tuple(parameters)
Expand Down

0 comments on commit be75fbb

Please sign in to comment.