Skip to content

Commit

Permalink
feat: support Pydantic v1 and v2 simultaneously (#492)
Browse files Browse the repository at this point in the history
  • Loading branch information
guacs authored Jan 21, 2024
1 parent b44c68b commit 46ecdc6
Show file tree
Hide file tree
Showing 27 changed files with 220 additions and 58 deletions.
152 changes: 103 additions & 49 deletions polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,79 @@

try:
import pydantic
from pydantic import VERSION, BaseModel, Json
from pydantic import VERSION, Json
from pydantic.fields import FieldInfo
except ImportError as e:
msg = "pydantic is not installed"
raise MissingDependencyException(msg) from e

try:
from pydantic.fields import ModelField # type: ignore[attr-defined]
# pydantic v1
from pydantic import (
UUID1,
UUID3,
UUID4,
UUID5,
AmqpDsn,
AnyHttpUrl,
AnyUrl,
DirectoryPath,
FilePath,
HttpUrl,
KafkaDsn,
PostgresDsn,
PyObject,
RedisDsn,
)
from pydantic import BaseModel as BaseModelV1
from pydantic.color import Color
from pydantic.fields import ( # type: ignore[attr-defined]
DeferredType, # pyright: ignore[reportGeneralTypeIssues]
ModelField, # pyright: ignore[reportGeneralTypeIssues]
Undefined, # pyright: ignore[reportGeneralTypeIssues]
)

# prevent unbound variable warnings
BaseModelV2 = BaseModelV1
UndefinedV2 = Undefined
except ImportError:
ModelField = Any
from pydantic_core import PydanticUndefined as Undefined
# pydantic v2

with suppress(ImportError):
# v2 specific imports
from pydantic import BaseModel as BaseModelV2
from pydantic_core import PydanticUndefined as UndefinedV2
from pydantic_core import to_json

from pydantic.v1 import ( # v1 compat imports
UUID1,
UUID3,
UUID4,
UUID5,
AmqpDsn,
AnyHttpUrl,
AnyUrl,
DirectoryPath,
FilePath,
HttpUrl,
KafkaDsn,
PostgresDsn,
PyObject,
RedisDsn,
)
from pydantic.v1 import BaseModel as BaseModelV1 # type: ignore[assignment]
from pydantic.v1.color import Color # type: ignore[assignment]
from pydantic.v1.fields import DeferredType, ModelField, Undefined


if TYPE_CHECKING:
from random import Random
from typing import Callable, Sequence

from typing_extensions import NotRequired, TypeGuard

T = TypeVar("T", bound=BaseModel)
T = TypeVar("T", bound="BaseModelV1 | BaseModelV2")

_IS_PYDANTIC_V1 = VERSION.startswith("1")


class PydanticConstraints(Constraints):
Expand Down Expand Up @@ -109,11 +160,7 @@ def from_field_info(
if callable(field_info.default_factory):
default_value = field_info.default_factory()
else:
default_value = (
field_info.default
if field_info.default is not Undefined # pyright: ignore[reportUnboundVariable]
else Null
)
default_value = field_info.default if field_info.default is not UndefinedV2 else Null

annotation = unwrap_new_type(field_info.annotation)
children: list[FieldMeta,] | None = None
Expand Down Expand Up @@ -189,8 +236,6 @@ def from_model_field( # pragma: no cover
("max_collection_length", max_collection_length),
),
)
from pydantic import AmqpDsn, AnyHttpUrl, AnyUrl, HttpUrl, KafkaDsn, PostgresDsn, RedisDsn
from pydantic.fields import DeferredType, Undefined # type: ignore[attr-defined]

if model_field.default is not Undefined:
default_value = model_field.default
Expand Down Expand Up @@ -295,7 +340,7 @@ def from_model_field( # pragma: no cover
constraints=cast("PydanticConstraints", {k: v for k, v in constraints.items() if v is not None}) or None,
)

if VERSION.startswith("2"):
if not _IS_PYDANTIC_V1:

@classmethod
def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]:
Expand All @@ -319,12 +364,12 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
super().__init_subclass__(*args, **kwargs)

if (
VERSION.startswith("1")
and getattr(cls, "__model__", None)
getattr(cls, "__model__", None)
and _is_pydantic_v1_model(cls.__model__)
and hasattr(cls.__model__, "update_forward_refs")
):
with suppress(NameError): # pragma: no cover
cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__)
cls.__model__.update_forward_refs(**cls.__forward_ref_resolution_type_mapping__) # type: ignore[attr-defined]

@classmethod
def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
Expand All @@ -333,7 +378,8 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
:param value: An arbitrary value.
:returns: A typeguard
"""
return is_safe_subclass(value, BaseModel)

return _is_pydantic_v1_model(value) or _is_pydantic_v2_model(value)

@classmethod
def get_model_fields(cls) -> list["FieldMeta"]:
Expand All @@ -344,24 +390,27 @@ def get_model_fields(cls) -> list["FieldMeta"]:
"""
if "_fields_metadata" not in cls.__dict__:
if VERSION.startswith("1"):
if _is_pydantic_v1_model(cls.__model__):
cls._fields_metadata = [
PydanticFieldMeta.from_model_field(
field,
use_alias=not cls.__model__.__config__.allow_population_by_field_name, # type: ignore[attr-defined]
random=cls.__random__,
)
for field in cls.__model__.__fields__.values() # type: ignore[attr-defined]
for field in cls.__model__.__fields__.values()
]
else:
cls._fields_metadata = [
PydanticFieldMeta.from_field_info(
field_info=field_info,
field_name=field_name,
random=cls.__random__,
use_alias=not cls.__model__.model_config.get("populate_by_name", False),
use_alias=not cls.__model__.model_config.get( # pyright: ignore[reportGeneralTypeIssues]
"populate_by_name",
False,
),
)
for field_name, field_info in cls.__model__.model_fields.items()
for field_name, field_info in cls.__model__.model_fields.items() # pyright: ignore[reportGeneralTypeIssues]
]
return cls._fields_metadata

Expand Down Expand Up @@ -392,13 +441,11 @@ def build(
processed_kwargs = cls.process_kwargs(**kwargs)

if factory_use_construct:
return (
cls.__model__.model_construct(**processed_kwargs)
if hasattr(cls.__model__, "model_construct")
else cls.__model__.construct(**processed_kwargs)
)
if _is_pydantic_v1_model(cls.__model__):
return cls.__model__.construct(**processed_kwargs) # type: ignore[return-value]
return cls.__model__.model_construct(**processed_kwargs) # type: ignore[return-value]

return cls.__model__(**processed_kwargs)
return cls.__model__(**processed_kwargs) # type: ignore[return-value]

@classmethod
def is_custom_root_field(cls, field_meta: FieldMeta) -> bool:
Expand Down Expand Up @@ -457,29 +504,28 @@ def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
pydantic.FutureDate: cls.__faker__.future_date,
}

if pydantic.VERSION.startswith("1"):
# v1 only values - these will raise an exception in v2
# in pydantic v2 these are all aliases for Annotated with a constraint.
# we therefore do not need them in v2
mapping.update(
{
pydantic.PyObject: lambda: "decimal.Decimal",
pydantic.AmqpDsn: lambda: "amqps://example.com",
pydantic.KafkaDsn: lambda: "kafka://localhost:9092",
pydantic.PostgresDsn: lambda: "postgresql://user:secret@localhost",
pydantic.RedisDsn: lambda: "redis://localhost:6379/0",
pydantic.FilePath: lambda: Path(realpath(__file__)),
pydantic.DirectoryPath: lambda: Path(realpath(__file__)).parent,
pydantic.UUID1: uuid1,
pydantic.UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()),
pydantic.UUID4: cls.__faker__.uuid4,
pydantic.UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()),
pydantic.color.Color: cls.__faker__.hex_color, # pyright: ignore[reportGeneralTypeIssues]
},
)
else:
# v1 only values
mapping.update(
{
PyObject: lambda: "decimal.Decimal",
AmqpDsn: lambda: "amqps://example.com",
KafkaDsn: lambda: "kafka://localhost:9092",
PostgresDsn: lambda: "postgresql://user:secret@localhost",
RedisDsn: lambda: "redis://localhost:6379/0",
FilePath: lambda: Path(realpath(__file__)),
DirectoryPath: lambda: Path(realpath(__file__)).parent,
UUID1: uuid1,
UUID3: lambda: uuid3(NAMESPACE_DNS, cls.__faker__.pystr()),
UUID4: cls.__faker__.uuid4,
UUID5: lambda: uuid5(NAMESPACE_DNS, cls.__faker__.pystr()),
Color: cls.__faker__.hex_color, # pyright: ignore[reportGeneralTypeIssues]
},
)

if not _IS_PYDANTIC_V1:
mapping.update(
{
# pydantic v2 specific types
pydantic.PastDatetime: cls.__faker__.past_datetime,
pydantic.FutureDatetime: cls.__faker__.future_datetime,
pydantic.AwareDatetime: partial(cls.__faker__.date_time, timezone.utc),
Expand All @@ -489,3 +535,11 @@ def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:

mapping.update(super().get_provider_map())
return mapping


def _is_pydantic_v1_model(model: Any) -> TypeGuard[BaseModelV1]:
return is_safe_subclass(model, BaseModelV1)


def _is_pydantic_v2_model(model: Any) -> TypeGuard[BaseModelV2]:
return not _IS_PYDANTIC_V1 and is_safe_subclass(model, BaseModelV2)
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ classmethod-decorators = [

[tool.ruff.isort]
known-first-party = ["polyfactory", "tests"]
section-order = ["future", "standard-library", "third-party", "pydantic", "pydantic_v1", "first-party", "local-folder"]

[tool.ruff.isort.sections]
pydantic = ["pydantic", "pydantic_core"]
pydantic_v1 = ["pydantic.v1"]

[tool.ruff.per-file-ignores]
"**/*.*" = ["ANN101", "ANN401", "ANN102", "TD",]
Expand Down
1 change: 1 addition & 0 deletions tests/constraints/test_date_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from hypothesis import given
from hypothesis.strategies import dates

from pydantic import BaseModel, condate

from polyfactory.factories.pydantic_factory import ModelFactory
Expand Down
1 change: 1 addition & 0 deletions tests/constraints/test_decimal_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from hypothesis import given
from hypothesis.strategies import decimals, integers

from pydantic import BaseModel, condecimal

from polyfactory.exceptions import ParameterException
Expand Down
1 change: 1 addition & 0 deletions tests/constraints/test_list_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from hypothesis import given
from hypothesis.strategies import integers

from pydantic import VERSION

from polyfactory.exceptions import ParameterException
Expand Down
3 changes: 2 additions & 1 deletion tests/test_annotated_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import Literal, Tuple, Union

from annotated_types import Ge, Le, LowerCase, UpperCase
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from pydantic import BaseModel, Field

from polyfactory.factories import DataclassFactory
from polyfactory.factories.pydantic_factory import ModelFactory

Expand Down
1 change: 1 addition & 0 deletions tests/test_base_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict

import pytest

from pydantic.main import BaseModel

from polyfactory.factories import DataclassFactory
Expand Down
1 change: 1 addition & 0 deletions tests/test_build.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from uuid import uuid4

import pytest

from pydantic import VERSION, BaseModel, Field, ValidationError

from polyfactory.factories.pydantic_factory import ModelFactory
Expand Down
1 change: 1 addition & 0 deletions tests/test_collection_length.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, List, Optional, Set, Tuple

import pytest

from pydantic.dataclasses import dataclass

from polyfactory.factories import DataclassFactory
Expand Down
1 change: 1 addition & 0 deletions tests/test_complex_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)

import pytest

from pydantic import VERSION, BaseModel

from polyfactory.exceptions import ParameterException
Expand Down
1 change: 1 addition & 0 deletions tests/test_constrained_attribute_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, List, Optional, Tuple

import pytest

from pydantic import (
VERSION,
BaseModel,
Expand Down
5 changes: 3 additions & 2 deletions tests/test_data_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from typing import Callable, Literal, Optional
from uuid import UUID

import pydantic
import pytest
from typing_extensions import TypeAlias

import pydantic
from pydantic import (
UUID1,
UUID3,
Expand Down Expand Up @@ -56,7 +58,6 @@
StrictInt,
StrictStr,
)
from typing_extensions import TypeAlias

from polyfactory.exceptions import ParameterException
from polyfactory.factories.pydantic_factory import ModelFactory
Expand Down
1 change: 1 addition & 0 deletions tests/test_dicts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, Union

import pytest

from pydantic import VERSION, BaseModel

from polyfactory.factories.pydantic_factory import ModelFactory
Expand Down
1 change: 1 addition & 0 deletions tests/test_factory_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Optional, Union

import pytest

from pydantic import BaseModel

from polyfactory.decorators import post_generated
Expand Down
1 change: 1 addition & 0 deletions tests/test_factory_subclassing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass

import pytest

from pydantic import BaseModel

from polyfactory import ConfigurationException
Expand Down
1 change: 1 addition & 0 deletions tests/test_new_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

import pytest

from pydantic import (
VERSION,
BaseModel,
Expand Down
Loading

0 comments on commit 46ecdc6

Please sign in to comment.