diff --git a/polyfactory/factories/pydantic_factory.py b/polyfactory/factories/pydantic_factory.py index 9a823189..7efa2e99 100644 --- a/polyfactory/factories/pydantic_factory.py +++ b/polyfactory/factories/pydantic_factory.py @@ -23,28 +23,79 @@ try: import pydantic - from pydantic import VERSION, BaseModel, Json + from pydantic import VERSION, Json from pydantic.fields import FieldInfo except ImportError as e: msg = "pydantic is not installed" raise MissingDependencyException(msg) from e try: - from pydantic.fields import ModelField # type: ignore[attr-defined] + # pydantic v1 + from pydantic import ( + UUID1, + UUID3, + UUID4, + UUID5, + AmqpDsn, + AnyHttpUrl, + AnyUrl, + DirectoryPath, + FilePath, + HttpUrl, + KafkaDsn, + PostgresDsn, + PyObject, + RedisDsn, + ) + from pydantic import BaseModel as BaseModelV1 + from pydantic.color import Color + from pydantic.fields import ( # type: ignore[attr-defined] + DeferredType, # pyright: ignore[reportGeneralTypeIssues] + ModelField, # pyright: ignore[reportGeneralTypeIssues] + Undefined, # pyright: ignore[reportGeneralTypeIssues] + ) + + # prevent unbound variable warnings + BaseModelV2 = BaseModelV1 + UndefinedV2 = Undefined except ImportError: - ModelField = Any - from pydantic_core import PydanticUndefined as Undefined + # pydantic v2 -with suppress(ImportError): + # v2 specific imports + from pydantic import BaseModel as BaseModelV2 + from pydantic_core import PydanticUndefined as UndefinedV2 from pydantic_core import to_json + from pydantic.v1 import ( # v1 compat imports + UUID1, + UUID3, + UUID4, + UUID5, + AmqpDsn, + AnyHttpUrl, + AnyUrl, + DirectoryPath, + FilePath, + HttpUrl, + KafkaDsn, + PostgresDsn, + PyObject, + RedisDsn, + ) + from pydantic.v1 import BaseModel as BaseModelV1 # type: ignore[assignment] + from pydantic.v1.color import Color # type: ignore[assignment] + from pydantic.v1.fields import DeferredType, ModelField, Undefined + + if TYPE_CHECKING: from random import Random from typing import Callable, Sequence from typing_extensions import NotRequired, TypeGuard -T = TypeVar("T", bound=BaseModel) +T = TypeVar("T", bound="BaseModelV1 | BaseModelV2") + +_IS_PYDANTIC_V1 = VERSION.startswith("1") class PydanticConstraints(Constraints): @@ -109,11 +160,7 @@ def from_field_info( if callable(field_info.default_factory): default_value = field_info.default_factory() else: - default_value = ( - field_info.default - if field_info.default is not Undefined # pyright: ignore[reportUnboundVariable] - else Null - ) + default_value = field_info.default if field_info.default is not UndefinedV2 else Null annotation = unwrap_new_type(field_info.annotation) children: list[FieldMeta,] | None = None @@ -189,8 +236,6 @@ def from_model_field( # pragma: no cover ("max_collection_length", max_collection_length), ), ) - from pydantic import AmqpDsn, AnyHttpUrl, AnyUrl, HttpUrl, KafkaDsn, PostgresDsn, RedisDsn - from pydantic.fields import DeferredType, Undefined # type: ignore[attr-defined] if model_field.default is not Undefined: default_value = model_field.default @@ -295,7 +340,7 @@ def from_model_field( # pragma: no cover constraints=cast("PydanticConstraints", {k: v for k, v in constraints.items() if v is not None}) or None, ) - if VERSION.startswith("2"): + if not _IS_PYDANTIC_V1: @classmethod def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]: @@ -319,12 +364,12 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: super().__init_subclass__(*args, **kwargs) if ( - VERSION.startswith("1") - and getattr(cls, "__model__", None) + getattr(cls, "__model__", None) + and _is_pydantic_v1_model(cls.__model__) and hasattr(cls.__model__, "update_forward_refs") ): with suppress(NameError): # pragma: no cover - cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__) + cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__) # type: ignore[attr-defined] @classmethod def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: @@ -333,7 +378,8 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]: :param value: An arbitrary value. :returns: A typeguard """ - return is_safe_subclass(value, BaseModel) + + return _is_pydantic_v1_model(value) or _is_pydantic_v2_model(value) @classmethod def get_model_fields(cls) -> list["FieldMeta"]: @@ -344,14 +390,14 @@ def get_model_fields(cls) -> list["FieldMeta"]: """ if "_fields_metadata" not in cls.__dict__: - if VERSION.startswith("1"): + if _is_pydantic_v1_model(cls.__model__): cls._fields_metadata = [ PydanticFieldMeta.from_model_field( field, use_alias=not cls.__model__.__config__.allow_population_by_field_name, # type: ignore[attr-defined] random=cls.__random__, ) - for field in cls.__model__.__fields__.values() # type: ignore[attr-defined] + for field in cls.__model__.__fields__.values() ] else: cls._fields_metadata = [ @@ -359,9 +405,12 @@ def get_model_fields(cls) -> list["FieldMeta"]: field_info=field_info, field_name=field_name, random=cls.__random__, - use_alias=not cls.__model__.model_config.get("populate_by_name", False), + use_alias=not cls.__model__.model_config.get( # pyright: ignore[reportGeneralTypeIssues] + "populate_by_name", + False, + ), ) - for field_name, field_info in cls.__model__.model_fields.items() + for field_name, field_info in cls.__model__.model_fields.items() # pyright: ignore[reportGeneralTypeIssues] ] return cls._fields_metadata @@ -392,13 +441,11 @@ def build( processed_kwargs = cls.process_kwargs(**kwargs) if factory_use_construct: - return ( - cls.__model__.model_construct(**processed_kwargs) - if hasattr(cls.__model__, "model_construct") - else cls.__model__.construct(**processed_kwargs) - ) + if _is_pydantic_v1_model(cls.__model__): + return cls.__model__.construct(**processed_kwargs) # type: ignore[return-value] + return cls.__model__.model_construct(**processed_kwargs) # type: ignore[return-value] - return cls.__model__(**processed_kwargs) + return cls.__model__(**processed_kwargs) # type: ignore[return-value] @classmethod def is_custom_root_field(cls, field_meta: FieldMeta) -> bool: @@ -457,29 +504,28 @@ def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: pydantic.FutureDate: cls.__faker__.future_date, } - if pydantic.VERSION.startswith("1"): - # v1 only values - these will raise an exception in v2 - # in pydantic v2 these are all aliases for Annotated with a constraint. - # we therefore do not need them in v2 - mapping.update( - { - pydantic.PyObject: lambda: "decimal.Decimal", - pydantic.AmqpDsn: lambda: "amqps://example.com", - pydantic.KafkaDsn: lambda: "kafka://localhost:9092", - pydantic.PostgresDsn: lambda: "postgresql://user:secret@localhost", - pydantic.RedisDsn: lambda: "redis://localhost:6379/0", - pydantic.FilePath: lambda: Path(realpath(__file__)), - pydantic.DirectoryPath: lambda: Path(realpath(__file__)).parent, - pydantic.UUID1: uuid1, - pydantic.UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()), - pydantic.UUID4: cls.__faker__.uuid4, - pydantic.UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()), - pydantic.color.Color: cls.__faker__.hex_color, # pyright: ignore[reportGeneralTypeIssues] - }, - ) - else: + # v1 only values + mapping.update( + { + PyObject: lambda: "decimal.Decimal", + AmqpDsn: lambda: "amqps://example.com", + KafkaDsn: lambda: "kafka://localhost:9092", + PostgresDsn: lambda: "postgresql://user:secret@localhost", + RedisDsn: lambda: "redis://localhost:6379/0", + FilePath: lambda: Path(realpath(__file__)), + DirectoryPath: lambda: Path(realpath(__file__)).parent, + UUID1: uuid1, + UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()), + UUID4: cls.__faker__.uuid4, + UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()), + Color: cls.__faker__.hex_color, # pyright: ignore[reportGeneralTypeIssues] + }, + ) + + if not _IS_PYDANTIC_V1: mapping.update( { + # pydantic v2 specific types pydantic.PastDatetime: cls.__faker__.past_datetime, pydantic.FutureDatetime: cls.__faker__.future_datetime, pydantic.AwareDatetime: partial(cls.__faker__.date_time, timezone.utc), @@ -489,3 +535,11 @@ def get_provider_map(cls) -> dict[Any, Callable[[], Any]]: mapping.update(super().get_provider_map()) return mapping + + +def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]: + return is_safe_subclass(model, BaseModelV1) + + +def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]: + return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2) diff --git a/pyproject.toml b/pyproject.toml index 7ef75cf1..a8a28baf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -203,6 +203,11 @@ classmethod-decorators = [ [tool.ruff.isort] known-first-party = ["polyfactory", "tests"] +section-order = ["future", "standard-library", "third-party", "pydantic", "pydantic_v1", "first-party", "local-folder"] + +[tool.ruff.isort.sections] +pydantic = ["pydantic", "pydantic_core"] +pydantic_v1 = ["pydantic.v1"] [tool.ruff.per-file-ignores] "**/*.*" = ["ANN101", "ANN401", "ANN102", "TD",] diff --git a/tests/constraints/test_date_constraints.py b/tests/constraints/test_date_constraints.py index 8df7ae06..6ccee453 100644 --- a/tests/constraints/test_date_constraints.py +++ b/tests/constraints/test_date_constraints.py @@ -4,6 +4,7 @@ import pytest from hypothesis import given from hypothesis.strategies import dates + from pydantic import BaseModel, condate from polyfactory.factories.pydantic_factory import ModelFactory diff --git a/tests/constraints/test_decimal_constraints.py b/tests/constraints/test_decimal_constraints.py index 5791d596..b6a76a3c 100644 --- a/tests/constraints/test_decimal_constraints.py +++ b/tests/constraints/test_decimal_constraints.py @@ -5,6 +5,7 @@ import pytest from hypothesis import given from hypothesis.strategies import decimals, integers + from pydantic import BaseModel, condecimal from polyfactory.exceptions import ParameterException diff --git a/tests/constraints/test_list_constraints.py b/tests/constraints/test_list_constraints.py index e92cf495..03edf91a 100644 --- a/tests/constraints/test_list_constraints.py +++ b/tests/constraints/test_list_constraints.py @@ -5,6 +5,7 @@ import pytest from hypothesis import given from hypothesis.strategies import integers + from pydantic import VERSION from polyfactory.exceptions import ParameterException diff --git a/tests/test_annotated_fields.py b/tests/test_annotated_fields.py index 65cc2086..510a7b28 100644 --- a/tests/test_annotated_fields.py +++ b/tests/test_annotated_fields.py @@ -2,9 +2,10 @@ from typing import Literal, Tuple, Union from annotated_types import Ge, Le, LowerCase, UpperCase -from pydantic import BaseModel, Field from typing_extensions import Annotated +from pydantic import BaseModel, Field + from polyfactory.factories import DataclassFactory from polyfactory.factories.pydantic_factory import ModelFactory diff --git a/tests/test_base_factories.py b/tests/test_base_factories.py index 751b6b1c..922b4b02 100644 --- a/tests/test_base_factories.py +++ b/tests/test_base_factories.py @@ -2,6 +2,7 @@ from typing import Any, Dict import pytest + from pydantic.main import BaseModel from polyfactory.factories import DataclassFactory diff --git a/tests/test_build.py b/tests/test_build.py index 7ea29dbd..d64deecc 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -1,6 +1,7 @@ from uuid import uuid4 import pytest + from pydantic import VERSION, BaseModel, Field, ValidationError from polyfactory.factories.pydantic_factory import ModelFactory diff --git a/tests/test_collection_length.py b/tests/test_collection_length.py index b8563165..be85307c 100644 --- a/tests/test_collection_length.py +++ b/tests/test_collection_length.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional, Set, Tuple import pytest + from pydantic.dataclasses import dataclass from polyfactory.factories import DataclassFactory diff --git a/tests/test_complex_types.py b/tests/test_complex_types.py index 04d10bbf..41566c8c 100644 --- a/tests/test_complex_types.py +++ b/tests/test_complex_types.py @@ -20,6 +20,7 @@ ) import pytest + from pydantic import VERSION, BaseModel from polyfactory.exceptions import ParameterException diff --git a/tests/test_constrained_attribute_parsing.py b/tests/test_constrained_attribute_parsing.py index 5df51def..dc810bc8 100644 --- a/tests/test_constrained_attribute_parsing.py +++ b/tests/test_constrained_attribute_parsing.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple import pytest + from pydantic import ( VERSION, BaseModel, diff --git a/tests/test_data_parsing.py b/tests/test_data_parsing.py index 30e46413..00e61b70 100644 --- a/tests/test_data_parsing.py +++ b/tests/test_data_parsing.py @@ -14,8 +14,10 @@ from typing import Callable, Literal, Optional from uuid import UUID -import pydantic import pytest +from typing_extensions import TypeAlias + +import pydantic from pydantic import ( UUID1, UUID3, @@ -56,7 +58,6 @@ StrictInt, StrictStr, ) -from typing_extensions import TypeAlias from polyfactory.exceptions import ParameterException from polyfactory.factories.pydantic_factory import ModelFactory diff --git a/tests/test_dicts.py b/tests/test_dicts.py index 8fa61950..a60a05a6 100644 --- a/tests/test_dicts.py +++ b/tests/test_dicts.py @@ -1,6 +1,7 @@ from typing import Dict, Union import pytest + from pydantic import VERSION, BaseModel from polyfactory.factories.pydantic_factory import ModelFactory diff --git a/tests/test_factory_fields.py b/tests/test_factory_fields.py index b6d01511..7bcd6808 100644 --- a/tests/test_factory_fields.py +++ b/tests/test_factory_fields.py @@ -3,6 +3,7 @@ from typing import Any, Optional, Union import pytest + from pydantic import BaseModel from polyfactory.decorators import post_generated diff --git a/tests/test_factory_subclassing.py b/tests/test_factory_subclassing.py index bef892f1..f42631cc 100644 --- a/tests/test_factory_subclassing.py +++ b/tests/test_factory_subclassing.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import pytest + from pydantic import BaseModel from polyfactory import ConfigurationException diff --git a/tests/test_new_types.py b/tests/test_new_types.py index 3788d6b7..10e513bd 100644 --- a/tests/test_new_types.py +++ b/tests/test_new_types.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union import pytest + from pydantic import ( VERSION, BaseModel, diff --git a/tests/test_optional_model_field_inference.py b/tests/test_optional_model_field_inference.py index 3aeefda7..d8aac75b 100644 --- a/tests/test_optional_model_field_inference.py +++ b/tests/test_optional_model_field_inference.py @@ -3,11 +3,12 @@ import pytest from attrs import define from msgspec import Struct -from pydantic import BaseModel -from pydantic.generics import GenericModel from sqlalchemy import Column, Integer from sqlalchemy.orm.decl_api import DeclarativeMeta, registry +from pydantic import BaseModel +from pydantic.generics import GenericModel + from polyfactory import ConfigurationException from polyfactory.factories import TypedDictFactory from polyfactory.factories.attrs_factory import AttrsFactory diff --git a/tests/test_options_validation.py b/tests/test_options_validation.py index 654691cc..54bb53a8 100644 --- a/tests/test_options_validation.py +++ b/tests/test_options_validation.py @@ -1,6 +1,7 @@ from typing import List, Optional import pytest + from pydantic import BaseModel, Field from polyfactory import ConfigurationException diff --git a/tests/test_persistence_handlers.py b/tests/test_persistence_handlers.py index d797870c..cdd62e0e 100644 --- a/tests/test_persistence_handlers.py +++ b/tests/test_persistence_handlers.py @@ -1,6 +1,7 @@ from typing import Any import pytest + from pydantic import BaseModel from polyfactory import AsyncPersistenceProtocol, SyncPersistenceProtocol diff --git a/tests/test_pydantic_factory.py b/tests/test_pydantic_factory.py index f04fb1fc..eab588ad 100644 --- a/tests/test_pydantic_factory.py +++ b/tests/test_pydantic_factory.py @@ -2,9 +2,10 @@ from typing import Dict, List, Optional, Set, Tuple import pytest -from pydantic import VERSION, BaseModel, Field, Json from typing_extensions import Annotated +from pydantic import VERSION, BaseModel, Field, Json + from polyfactory.factories.pydantic_factory import ModelFactory diff --git a/tests/test_pydantic_v1_v2.py b/tests/test_pydantic_v1_v2.py new file mode 100644 index 00000000..7fad1450 --- /dev/null +++ b/tests/test_pydantic_v1_v2.py @@ -0,0 +1,79 @@ +"""Tests to check that usage of pydantic v1 and v2 at the same time works.""" + +from typing import Dict, List, Optional, Type, Union + +import pytest +from typing_extensions import Annotated + +import pydantic + +from polyfactory.factories.pydantic_factory import ModelFactory + +if pydantic.VERSION.startswith("1"): + pytest.skip("only for pydantic v2", allow_module_level=True) + +from pydantic import BaseModel as BaseModelV2 + +try: + from pydantic.v1 import BaseModel as BaseModelV1 +except ImportError: + from pydantic import BaseModel as BaseModelV1 # type: ignore[assignment] + + +@pytest.mark.parametrize("base_model", [BaseModelV1, BaseModelV2]) +def test_is_supported_type(base_model: Type[Union[BaseModelV1, BaseModelV2]]) -> None: + class Foo(base_model): # type: ignore[valid-type, misc] + ... + + assert ModelFactory.is_supported_type(Foo) is True + + +@pytest.mark.parametrize("base_model", [BaseModelV1, BaseModelV2]) +def test_build(base_model: Type[Union[BaseModelV1, BaseModelV2]]) -> None: + class Foo(base_model): # type: ignore[valid-type, misc] + a: int + b: str + c: bool + + FooFactory = ModelFactory.create_factory(Foo) + foo = FooFactory.build() + + assert isinstance(foo.a, int) + assert isinstance(foo.b, str) + assert isinstance(foo.c, bool) + + +def test_build_v1_with_contrained_fields() -> None: + from pydantic.v1.fields import Field + + ConstrainedInt = Annotated[int, Field(ge=100, le=200)] + ConstrainedStr = Annotated[str, Field(min_length=1, max_length=3)] + + class Foo(pydantic.v1.BaseModel): # pyright: ignore[reportGeneralTypeIssues] + a: ConstrainedInt + b: ConstrainedStr + c: Union[ConstrainedInt, ConstrainedStr] + d: Optional[ConstrainedInt] + e: Optional[Union[ConstrainedInt, ConstrainedStr]] + f: List[ConstrainedInt] + g: Dict[ConstrainedInt, ConstrainedStr] + + ModelFactory.create_factory(Foo).build() # type: ignore[type-var] + + +def test_build_v2_with_contrained_fields() -> None: + from pydantic.fields import Field + + ConstrainedInt = Annotated[int, Field(ge=100, le=200)] + ConstrainedStr = Annotated[str, Field(min_length=1, max_length=3)] + + class Foo(pydantic.BaseModel): # pyright: ignore[reportGeneralTypeIssues] + a: ConstrainedInt + b: ConstrainedStr + c: Union[ConstrainedInt, ConstrainedStr] + d: Optional[ConstrainedInt] + e: Optional[Union[ConstrainedInt, ConstrainedStr]] + f: List[ConstrainedInt] + g: Dict[ConstrainedInt, ConstrainedStr] + + ModelFactory.create_factory(Foo).build() diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 42a53de4..4708a5c1 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -1,6 +1,7 @@ from typing import List import pytest + from pydantic import BaseModel from polyfactory.exceptions import ParameterException diff --git a/tests/test_recursive_models.py b/tests/test_recursive_models.py index 9f8c4ad1..36002ef9 100644 --- a/tests/test_recursive_models.py +++ b/tests/test_recursive_models.py @@ -4,6 +4,7 @@ from typing import List, Optional, Union import pytest + from pydantic import BaseModel, Field from polyfactory.factories.dataclass_factory import DataclassFactory diff --git a/tests/test_setting_default_factories.py b/tests/test_setting_default_factories.py index 9e60aed6..3c2ad5ec 100644 --- a/tests/test_setting_default_factories.py +++ b/tests/test_setting_default_factories.py @@ -1,9 +1,10 @@ from dataclasses import dataclass as vanilla_dataclass from typing import List -from pydantic import BaseModel from typing_extensions import TypedDict +from pydantic import BaseModel + from polyfactory.factories import DataclassFactory, TypedDictFactory from polyfactory.factories.pydantic_factory import ModelFactory diff --git a/tests/test_typeddict_factory.py b/tests/test_typeddict_factory.py index c9a32674..e68d8fb6 100644 --- a/tests/test_typeddict_factory.py +++ b/tests/test_typeddict_factory.py @@ -1,9 +1,10 @@ from typing import Dict, List, Optional from annotated_types import Ge -from pydantic import BaseModel from typing_extensions import Annotated, NotRequired, Required, TypedDict +from pydantic import BaseModel + from polyfactory.factories import TypedDictFactory from polyfactory.factories.pydantic_factory import ModelFactory diff --git a/tests/test_union_handling.py b/tests/test_union_handling.py index 240a6e37..31818d2d 100644 --- a/tests/test_union_handling.py +++ b/tests/test_union_handling.py @@ -2,9 +2,10 @@ import pytest from annotated_types import Ge, MinLen -from pydantic import BaseModel from typing_extensions import Annotated +from pydantic import BaseModel + from polyfactory.factories.pydantic_factory import ModelFactory diff --git a/tests/test_utils.py b/tests/test_utils.py index 9dcb00b5..1ef25144 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,6 +6,7 @@ import pytest from hypothesis import given from hypothesis.strategies import decimals, floats, integers + from pydantic import BaseModel from polyfactory.factories.pydantic_factory import ModelFactory