Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support Pydantic v1 and v2 simultaneously #492

Merged
merged 5 commits into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 103 additions & 49 deletions polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,79 @@

try:
import pydantic
from pydantic import VERSION, BaseModel, Json
from pydantic import VERSION, Json
from pydantic.fields import FieldInfo
except ImportError as e:
msg = "pydantic is not installed"
raise MissingDependencyException(msg) from e

try:
from pydantic.fields import ModelField # type: ignore[attr-defined]
# pydantic v1
from pydantic import (
UUID1,
UUID3,
UUID4,
UUID5,
AmqpDsn,
AnyHttpUrl,
AnyUrl,
DirectoryPath,
FilePath,
HttpUrl,
KafkaDsn,
PostgresDsn,
PyObject,
RedisDsn,
)
from pydantic import BaseModel as BaseModelV1
from pydantic.color import Color
from pydantic.fields import ( # type: ignore[attr-defined]
DeferredType, # pyright: ignore[reportGeneralTypeIssues]
ModelField, # pyright: ignore[reportGeneralTypeIssues]
Undefined, # pyright: ignore[reportGeneralTypeIssues]
)

# prevent unbound variable warnings
BaseModelV2 = BaseModelV1
UndefinedV2 = Undefined
except ImportError:
ModelField = Any
from pydantic_core import PydanticUndefined as Undefined
# pydantic v2

with suppress(ImportError):
# v2 specific imports
from pydantic import BaseModel as BaseModelV2
from pydantic_core import PydanticUndefined as UndefinedV2
from pydantic_core import to_json

from pydantic.v1 import ( # v1 compat imports
UUID1,
UUID3,
UUID4,
UUID5,
AmqpDsn,
AnyHttpUrl,
AnyUrl,
DirectoryPath,
FilePath,
HttpUrl,
KafkaDsn,
PostgresDsn,
PyObject,
RedisDsn,
)
from pydantic.v1 import BaseModel as BaseModelV1 # type: ignore[assignment]
from pydantic.v1.color import Color # type: ignore[assignment]
from pydantic.v1.fields import DeferredType, ModelField, Undefined


if TYPE_CHECKING:
from random import Random
from typing import Callable, Sequence

from typing_extensions import NotRequired, TypeGuard

T = TypeVar("T", bound=BaseModel)
T = TypeVar("T", bound="BaseModelV1 | BaseModelV2")

_IS_PYDANTIC_V1 = VERSION.startswith("1")


class PydanticConstraints(Constraints):
Expand Down Expand Up @@ -109,11 +160,7 @@
if callable(field_info.default_factory):
default_value = field_info.default_factory()
else:
default_value = (
field_info.default
if field_info.default is not Undefined # pyright: ignore[reportUnboundVariable]
else Null
)
default_value = field_info.default if field_info.default is not UndefinedV2 else Null

annotation = unwrap_new_type(field_info.annotation)
children: list[FieldMeta,] | None = None
Expand Down Expand Up @@ -189,8 +236,6 @@
("max_collection_length", max_collection_length),
),
)
from pydantic import AmqpDsn, AnyHttpUrl, AnyUrl, HttpUrl, KafkaDsn, PostgresDsn, RedisDsn
from pydantic.fields import DeferredType, Undefined # type: ignore[attr-defined]

if model_field.default is not Undefined:
default_value = model_field.default
Expand Down Expand Up @@ -295,7 +340,7 @@
constraints=cast("PydanticConstraints", {k: v for k, v in constraints.items() if v is not None}) or None,
)

if VERSION.startswith("2"):
if not _IS_PYDANTIC_V1:

@classmethod
def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]:
Expand All @@ -319,12 +364,12 @@
super().__init_subclass__(*args, **kwargs)

if (
VERSION.startswith("1")
and getattr(cls, "__model__", None)
getattr(cls, "__model__", None)
and _is_pydantic_v1_model(cls.__model__)
and hasattr(cls.__model__, "update_forward_refs")
):
with suppress(NameError): # pragma: no cover
cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__)
cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__) # type: ignore[attr-defined]

@classmethod
def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
Expand All @@ -333,7 +378,8 @@
:param value: An arbitrary value.
:returns: A typeguard
"""
return is_safe_subclass(value, BaseModel)

return _is_pydantic_v1_model(value) or _is_pydantic_v2_model(value)

@classmethod
def get_model_fields(cls) -> list["FieldMeta"]:
Expand All @@ -344,24 +390,27 @@

"""
if "_fields_metadata" not in cls.__dict__:
if VERSION.startswith("1"):
if _is_pydantic_v1_model(cls.__model__):
cls._fields_metadata = [
PydanticFieldMeta.from_model_field(
field,
use_alias=not cls.__model__.__config__.allow_population_by_field_name, # type: ignore[attr-defined]
random=cls.__random__,
)
for field in cls.__model__.__fields__.values() # type: ignore[attr-defined]
for field in cls.__model__.__fields__.values()
]
else:
cls._fields_metadata = [
PydanticFieldMeta.from_field_info(
field_info=field_info,
field_name=field_name,
random=cls.__random__,
use_alias=not cls.__model__.model_config.get("populate_by_name", False),
use_alias=not cls.__model__.model_config.get( # pyright: ignore[reportGeneralTypeIssues]
"populate_by_name",
False,
),
)
for field_name, field_info in cls.__model__.model_fields.items()
for field_name, field_info in cls.__model__.model_fields.items() # pyright: ignore[reportGeneralTypeIssues]
]
return cls._fields_metadata

Expand Down Expand Up @@ -392,13 +441,11 @@
processed_kwargs = cls.process_kwargs(**kwargs)

if factory_use_construct:
return (
cls.__model__.model_construct(**processed_kwargs)
if hasattr(cls.__model__, "model_construct")
else cls.__model__.construct(**processed_kwargs)
)
if _is_pydantic_v1_model(cls.__model__):
return cls.__model__.construct(**processed_kwargs) # type: ignore[return-value]
return cls.__model__.model_construct(**processed_kwargs) # type: ignore[return-value]

return cls.__model__(**processed_kwargs)
return cls.__model__(**processed_kwargs) # type: ignore[return-value]

@classmethod
def is_custom_root_field(cls, field_meta: FieldMeta) -> bool:
Expand Down Expand Up @@ -457,29 +504,28 @@
pydantic.FutureDate: cls.__faker__.future_date,
}

if pydantic.VERSION.startswith("1"):
# v1 only values - these will raise an exception in v2
# in pydantic v2 these are all aliases for Annotated with a constraint.
# we therefore do not need them in v2
mapping.update(
{
pydantic.PyObject: lambda: "decimal.Decimal",
pydantic.AmqpDsn: lambda: "amqps://example.com",
pydantic.KafkaDsn: lambda: "kafka://localhost:9092",
pydantic.PostgresDsn: lambda: "postgresql://user:secret@localhost",
pydantic.RedisDsn: lambda: "redis://localhost:6379/0",
pydantic.FilePath: lambda: Path(realpath(__file__)),
pydantic.DirectoryPath: lambda: Path(realpath(__file__)).parent,
pydantic.UUID1: uuid1,
pydantic.UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()),
pydantic.UUID4: cls.__faker__.uuid4,
pydantic.UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()),
pydantic.color.Color: cls.__faker__.hex_color, # pyright: ignore[reportGeneralTypeIssues]
},
)
else:
# v1 only values
mapping.update(
{
PyObject: lambda: "decimal.Decimal",
AmqpDsn: lambda: "amqps://example.com",
KafkaDsn: lambda: "kafka://localhost:9092",
PostgresDsn: lambda: "postgresql://user:secret@localhost",
Dismissed Show dismissed Hide dismissed
RedisDsn: lambda: "redis://localhost:6379/0",
FilePath: lambda: Path(realpath(__file__)),
DirectoryPath: lambda: Path(realpath(__file__)).parent,
UUID1: uuid1,
UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()),
UUID4: cls.__faker__.uuid4,
UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()),
Color: cls.__faker__.hex_color, # pyright: ignore[reportGeneralTypeIssues]
},
)

if not _IS_PYDANTIC_V1:
mapping.update(
{
# pydantic v2 specific types
pydantic.PastDatetime: cls.__faker__.past_datetime,
pydantic.FutureDatetime: cls.__faker__.future_datetime,
pydantic.AwareDatetime: partial(cls.__faker__.date_time, timezone.utc),
Expand All @@ -489,3 +535,11 @@

mapping.update(super().get_provider_map())
return mapping


def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]:
return is_safe_subclass(model, BaseModelV1)


def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]:
return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2)
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ classmethod-decorators = [

[tool.ruff.isort]
known-first-party = ["polyfactory", "tests"]
section-order = ["future", "standard-library", "third-party", "pydantic", "pydantic_v1", "first-party", "local-folder"]

[tool.ruff.isort.sections]
pydantic = ["pydantic", "pydantic_core"]
pydantic_v1 = ["pydantic.v1"]

[tool.ruff.per-file-ignores]
"**/*.*" = ["ANN101", "ANN401", "ANN102", "TD",]
Expand Down
1 change: 1 addition & 0 deletions tests/constraints/test_date_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from hypothesis import given
from hypothesis.strategies import dates

from pydantic import BaseModel, condate

from polyfactory.factories.pydantic_factory import ModelFactory
Expand Down
1 change: 1 addition & 0 deletions tests/constraints/test_decimal_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from hypothesis import given
from hypothesis.strategies import decimals, integers

from pydantic import BaseModel, condecimal

from polyfactory.exceptions import ParameterException
Expand Down
1 change: 1 addition & 0 deletions tests/constraints/test_list_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from hypothesis import given
from hypothesis.strategies import integers

from pydantic import VERSION

from polyfactory.exceptions import ParameterException
Expand Down
3 changes: 2 additions & 1 deletion tests/test_annotated_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import Literal, Tuple, Union

from annotated_types import Ge, Le, LowerCase, UpperCase
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from pydantic import BaseModel, Field

from polyfactory.factories import DataclassFactory
from polyfactory.factories.pydantic_factory import ModelFactory

Expand Down
1 change: 1 addition & 0 deletions tests/test_base_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict

import pytest

from pydantic.main import BaseModel

from polyfactory.factories import DataclassFactory
Expand Down
1 change: 1 addition & 0 deletions tests/test_build.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from uuid import uuid4

import pytest

from pydantic import VERSION, BaseModel, Field, ValidationError

from polyfactory.factories.pydantic_factory import ModelFactory
Expand Down
1 change: 1 addition & 0 deletions tests/test_collection_length.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, List, Optional, Set, Tuple

import pytest

from pydantic.dataclasses import dataclass

from polyfactory.factories import DataclassFactory
Expand Down
1 change: 1 addition & 0 deletions tests/test_complex_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)

import pytest

from pydantic import VERSION, BaseModel

from polyfactory.exceptions import ParameterException
Expand Down
1 change: 1 addition & 0 deletions tests/test_constrained_attribute_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, List, Optional, Tuple

import pytest

from pydantic import (
VERSION,
BaseModel,
Expand Down
5 changes: 3 additions & 2 deletions tests/test_data_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from typing import Callable, Literal, Optional
from uuid import UUID

import pydantic
import pytest
from typing_extensions import TypeAlias

import pydantic
from pydantic import (
UUID1,
UUID3,
Expand Down Expand Up @@ -56,7 +58,6 @@
StrictInt,
StrictStr,
)
from typing_extensions import TypeAlias

from polyfactory.exceptions import ParameterException
from polyfactory.factories.pydantic_factory import ModelFactory
Expand Down
1 change: 1 addition & 0 deletions tests/test_dicts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, Union

import pytest

from pydantic import VERSION, BaseModel

from polyfactory.factories.pydantic_factory import ModelFactory
Expand Down
1 change: 1 addition & 0 deletions tests/test_factory_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Optional, Union

import pytest

from pydantic import BaseModel

from polyfactory.decorators import post_generated
Expand Down
1 change: 1 addition & 0 deletions tests/test_factory_subclassing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass

import pytest

from pydantic import BaseModel

from polyfactory import ConfigurationException
Expand Down
1 change: 1 addition & 0 deletions tests/test_new_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

import pytest

from pydantic import (
VERSION,
BaseModel,
Expand Down
Loading
Loading