Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Adjust min/max items to valid lengths for Set[Enum] fields #567

Merged
merged 8 commits into from
Sep 13, 2024
13 changes: 12 additions & 1 deletion polyfactory/value_generators/constrained_collections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, List, Mapping, TypeVar, cast
from enum import EnumMeta
from typing import TYPE_CHECKING, Any, Callable, List, Literal, Mapping, TypeVar, cast

from polyfactory.exceptions import ParameterException
from polyfactory.field_meta import FieldMeta
Expand Down Expand Up @@ -43,6 +44,16 @@ def handle_constrained_collection(
msg = "max_items must be larger or equal to min_items"
raise ParameterException(msg)

if collection_type in (frozenset, set) or unique_items:
max_field_values = max_items
if hasattr(field_meta.annotation, "__origin__") and field_meta.annotation.__origin__ is Literal:
if field_meta.children is not None:
max_field_values = len(field_meta.children)
elif isinstance(field_meta.annotation, EnumMeta):
max_field_values = len(field_meta.annotation)
min_items = min(min_items, max_field_values)
max_items = min(max_items, max_field_values)

collection: set[T] | list[T] = set() if (collection_type in (frozenset, set) or unique_items) else []

try:
Expand Down
52 changes: 51 additions & 1 deletion tests/test_collection_length.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Any, Dict, List, Optional, Set, Tuple
from enum import Enum
from typing import Any, Dict, FrozenSet, List, Literal, Optional, Set, Tuple, get_args

import pytest

from pydantic import BaseModel
from pydantic.dataclasses import dataclass

from polyfactory.factories import DataclassFactory
from polyfactory.factories.pydantic_factory import ModelFactory

MIN_MAX_PARAMETERS = ((10, 15), (20, 25), (30, 40), (40, 50))

Expand Down Expand Up @@ -132,3 +135,50 @@ class FooFactory(DataclassFactory[Foo]):

assert len(foo.foo) >= min_val, len(foo.foo)
assert len(foo.foo) <= max_val, len(foo.foo)


@pytest.mark.parametrize("type_", (List, FrozenSet, Set))
@pytest.mark.parametrize("min_items", (0, 2, 4))
@pytest.mark.parametrize("max_inc", (0, 1, 4))
def test_collection_length_with_literal(type_: type, min_items: int, max_inc: int) -> None:
max_items = min_items + max_inc
literal_type = Literal["Dog", "Cat", "Monkey"]

@dataclass
class MyModel:
animal_collection: type_[literal_type] # type: ignore

class MyFactory(DataclassFactory):
__model__ = MyModel
__randomize_collection_length__ = True
__min_collection_length__ = min_items
__max_collection_length__ = max_items

result = MyFactory.build()
assert len(result.animal_collection) >= min(min_items, len(get_args(literal_type)))
assert len(result.animal_collection) <= max_items


@pytest.mark.parametrize("type_", (List, FrozenSet, Set))
@pytest.mark.parametrize("min_items", (0, 2, 4))
@pytest.mark.parametrize("max_inc", (0, 1, 4))
def test_collection_length_with_enum(type_: type, min_items: int, max_inc: int) -> None:
max_items = min_items + max_inc

class Animal(str, Enum):
DOG = "Dog"
CAT = "Cat"
MONKEY = "Monkey"

class MyModel(BaseModel):
animal_collection: type_[Animal] # type: ignore

class MyFactory(ModelFactory):
__model__ = MyModel
__randomize_collection_length__ = True
__min_collection_length__ = min_items
__max_collection_length__ = max_items

result = MyFactory.build()
assert len(result.animal_collection) >= min(min_items, len(Animal))
assert len(result.animal_collection) <= max_items
Loading