Skip to content

Commit

Permalink
feat(DTO): deterministic transfer model names (#2389)
Browse files Browse the repository at this point in the history
* feat(DTO): make transfer model names deterministic

---------

Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>
  • Loading branch information
provinzkraut committed Sep 30, 2023
1 parent 2756f7e commit 4ef5fee
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 15 deletions.
28 changes: 15 additions & 13 deletions litestar/dto/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""
from __future__ import annotations

import secrets
from dataclasses import replace
from typing import TYPE_CHECKING, AbstractSet, Any, Callable, ClassVar, Collection, Final, Mapping, Union, cast

Expand All @@ -27,6 +26,7 @@
from litestar.serialization import decode_json, decode_msgpack
from litestar.types import Empty
from litestar.typing import FieldDefinition
from litestar.utils import unique_name_for_scope
from litestar.utils.typing import safe_generic_origin_map

if TYPE_CHECKING:
Expand Down Expand Up @@ -154,6 +154,18 @@ def parse_model(
defined_fields.append(transfer_field_definition)
return tuple(defined_fields)

def _create_transfer_model_name(self, model_name: str) -> str:
long_name_prefix = self.handler_id.split("::")[0]
short_name_prefix = _camelize(long_name_prefix.split(".")[-1], True)

name_suffix = "RequestBody" if self.is_data_field else "ResponseBody"

if (short_name := f"{short_name_prefix}{model_name}{name_suffix}") not in self._seen_model_names:
return short_name
if (long_name := f"{long_name_prefix}{model_name}{name_suffix}") not in self._seen_model_names:
return long_name
return unique_name_for_scope(long_name, self._seen_model_names)

def create_transfer_model_type(
self, model_name: str, field_definitions: tuple[TransferDTOFieldDefinition, ...]
) -> type[Struct]:
Expand All @@ -166,19 +178,9 @@ def create_transfer_model_type(
Returns:
A ``BackendT`` class.
"""
long_name_prefix = self.handler_id.split("::")[0]
short_name_prefix = _camelize(long_name_prefix.split(".")[-1], True)

name_suffix = "RequestBody" if self.is_data_field else "ResponseBody"

if f"{short_name_prefix}{model_name}{name_suffix}" not in self._seen_model_names:
struct_name = f"{short_name_prefix}{model_name}{name_suffix}"
elif f"{long_name_prefix}{model_name}{name_suffix}" not in self._seen_model_names:
struct_name = f"{long_name_prefix}{model_name}{name_suffix}"
else:
struct_name = f"{long_name_prefix}{model_name}{name_suffix}{secrets.token_hex(8)}"

struct_name = self._create_transfer_model_name(model_name)
self._seen_model_names.add(struct_name)

struct = _create_struct_for_field_definitions(struct_name, field_definitions)
setattr(struct, "__schema_name__", struct_name)
return struct
Expand Down
3 changes: 2 additions & 1 deletion litestar/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from litestar.utils.deprecation import deprecated, warn_deprecation

from .helpers import Ref, get_enum_string_value, get_name, url_quote
from .helpers import Ref, get_enum_string_value, get_name, unique_name_for_scope, url_quote
from .path import join_paths, normalize_path
from .predicates import (
is_annotated_type,
Expand Down Expand Up @@ -75,6 +75,7 @@
"normalize_path",
"set_litestar_scope_state",
"unique",
"unique_name_for_scope",
"url_quote",
"warn_deprecation",
)
11 changes: 11 additions & 0 deletions litestar/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from urllib.parse import quote

if TYPE_CHECKING:
from collections.abc import Container
from typing import Iterable

from litestar.datastructures import Cookie
Expand All @@ -18,6 +19,7 @@
"get_name",
"unwrap_partial",
"url_quote",
"unique_name_for_scope",
)

T = TypeVar("T")
Expand Down Expand Up @@ -99,3 +101,12 @@ def url_quote(value: str | bytes) -> str:
A quoted URL.
"""
return quote(value, safe="/#%[]=:;$&()+,!?*@'~")


def unique_name_for_scope(base_name: str, scope: Container[str]) -> str:
"""Create a name derived from ``base_name`` that's unique within ``scope``"""
i = 0
while True:
if (unique_name := f"{base_name}_{i}") not in scope:
return unique_name
i += 1
10 changes: 9 additions & 1 deletion tests/unit/test_utils/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial

from litestar.utils.helpers import unwrap_partial
from litestar.utils.helpers import unique_name_for_scope, unwrap_partial


def test_unwrap_partial() -> None:
Expand All @@ -11,3 +11,11 @@ def func(*args: int) -> int:

assert wrapped() == 3
assert unwrap_partial(wrapped) is func


def test_unique_name_for_scope() -> None:
assert unique_name_for_scope("a", []) == "a_0"

assert unique_name_for_scope("a", ["a", "a_0", "b"]) == "a_1"

assert unique_name_for_scope("b", ["a", "a_0", "b"]) == "b_0"

0 comments on commit 4ef5fee

Please sign in to comment.