Skip to content

Commit

Permalink
Merge branch 'main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
guacs authored May 12, 2024
2 parents 7d0bbc8 + 2f781ee commit c0b4db5
Show file tree
Hide file tree
Showing 13 changed files with 165 additions and 16 deletions.
9 changes: 9 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,15 @@
"contributions": [
"doc"
]
},
{
"login": "wangxin688",
"name": "jeffry",
"avatar_url": "https://avatars.githubusercontent.com/u/36665036?v=4",
"profile": "https://github.com/wangxin688",
"contributions": [
"code"
]
}
],
"contributorsPerLine": 7,
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ repos:
- id: prettier
exclude: ".all-contributorsrc"
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.9.0"
rev: "v1.10.0"
hooks:
- id: mypy
exclude: "test_decimal_constraints|examples/fields/test_example_2|examples/configuration|tools/"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
<td align="center" valign="top" width="14.28%"><a href="https://github.com/hsorsky"><img src="https://avatars.githubusercontent.com/u/36887638?v=4?s=100" width="100px;" alt="Henry Sorsky"/><br /><sub><b>Henry Sorsky</b></sub></a><br /><a href="#infra-hsorsky" title="Infrastructure (Hosting, Build-Tools, etc)">🚇</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/wer153"><img src="https://avatars.githubusercontent.com/u/23370765?v=4?s=100" width="100px;" alt="Kim Minki"/><br /><sub><b>Kim Minki</b></sub></a><br /><a href="https://github.com/litestar-org/polyfactory/commits?author=wer153" title="Documentation">📖</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://www.timdumol.com/"><img src="https://avatars.githubusercontent.com/u/49169?v=4?s=100" width="100px;" alt="Tim Joseph Dumol"/><br /><sub><b>Tim Joseph Dumol</b></sub></a><br /><a href="https://github.com/litestar-org/polyfactory/commits?author=TimDumol" title="Documentation">📖</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/wangxin688"><img src="https://avatars.githubusercontent.com/u/36665036?v=4?s=100" width="100px;" alt="jeffry"/><br /><sub><b>jeffry</b></sub></a><br /><a href="https://github.com/litestar-org/polyfactory/commits?author=wangxin688" title="Code">💻</a></td>
</tr>
</tbody>
</table>
Expand Down
3 changes: 1 addition & 2 deletions docs/examples/decorators/test_example_1.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import cast

from polyfactory.decorators import post_generated
from polyfactory.factories import DataclassFactory
Expand All @@ -16,7 +15,7 @@ class DatetimeRangeFactory(DataclassFactory[DatetimeRange]):
@post_generated
@classmethod
def to_dt(cls, from_dt: datetime) -> datetime:
return from_dt + cast(timedelta, cls.__faker__.time_delta("+3d"))
return from_dt + cls.__faker__.time_delta("+3d")


def test_post_generated() -> None:
Expand Down
4 changes: 2 additions & 2 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def _create_generic_fn() -> Callable:
# standard library objects
Path: lambda: Path(realpath(__file__)),
Decimal: cls.__faker__.pydecimal,
UUID: lambda: UUID(cls.__faker__.uuid4()),
UUID: lambda: UUID(str(cls.__faker__.uuid4())),
# datetime
datetime: cls.__faker__.date_time_between,
date: cls.__faker__.date_this_decade,
Expand Down Expand Up @@ -777,7 +777,7 @@ def get_field_value_coverage( # noqa: C901
"""
if cls.is_ignored_type(field_meta.annotation):
return [None]
return

for unwrapped_annotation in flatten_annotation(field_meta.annotation):
if unwrapped_annotation in (None, NoneType):
Expand Down
2 changes: 1 addition & 1 deletion polyfactory/factories/beanie_odm_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class BeaniePersistenceHandler(Generic[T], AsyncPersistenceProtocol[T]):

async def save(self, data: T) -> T:
"""Persist a single instance in mongoDB."""
return await data.insert() # type: ignore[no-any-return]
return await data.insert() # pyright: ignore[reportGeneralTypeIssues]

async def save_many(self, data: list[T]) -> list[T]:
"""Persist multiple instances in mongoDB.
Expand Down
2 changes: 1 addition & 1 deletion polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool:

@classmethod
def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
mapping = {
mapping: dict[Any, Callable[[], Any]] = {
pydantic.ByteSize: cls.__faker__.pyint,
pydantic.PositiveInt: cls.__faker__.pyint,
pydantic.NegativeFloat: lambda: cls.__random__.uniform(-100, -1),
Expand Down
14 changes: 13 additions & 1 deletion polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@ def __init__(self, session: AsyncSession) -> None:
async def save(self, data: T) -> T:
self.session.add(data)
await self.session.commit()
await self.session.refresh(data)
return data

async def save_many(self, data: list[T]) -> list[T]:
self.session.add_all(data)
await self.session.commit()
for batch_item in data:
await self.session.refresh(batch_item)
return data


Expand Down Expand Up @@ -95,6 +98,9 @@ def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]:
postgresql.NUMRANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])),
postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005
postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005
postgresql.HSTORE: lambda: cls.__faker__.pydict(),
# `types.JSON` is compatible for sqlachemy extend dialects. Such as `pg.JSON` and `JSONB`
types.JSON: lambda: cls.__faker__.pydict(),
}

@classmethod
Expand Down Expand Up @@ -124,8 +130,14 @@ def should_column_be_set(cls, column: Any) -> bool:
@classmethod
def get_type_from_column(cls, column: Column) -> type:
column_type = type(column.type)
if column_type in cls.get_sqlalchemy_types():
sqla_types = cls.get_sqlalchemy_types()
if column_type in sqla_types:
annotation = column_type
elif issubclass(column_type, postgresql.ARRAY):
if type(column.type.item_type) in sqla_types: # type: ignore[attr-defined]
annotation = List[type(column.type.item_type)] # type: ignore[attr-defined,misc]
else:
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
elif issubclass(column_type, types.ARRAY):
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
else:
Expand Down
4 changes: 2 additions & 2 deletions polyfactory/value_generators/constrained_dates.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import date, datetime, timedelta, timezone, tzinfo
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from faker import Faker
Expand Down Expand Up @@ -38,4 +38,4 @@ def handle_constrained_date(
elif lt:
end_date = lt - timedelta(days=1)

return cast("date", faker.date_between(start_date=start_date, end_date=end_date))
return faker.date_between(start_date=start_date, end_date=end_date)
66 changes: 65 additions & 1 deletion tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Type, Union
from uuid import UUID

import pytest
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, inspect, orm, types
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
create_engine,
func,
inspect,
orm,
text,
types,
)
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.ext.hybrid import hybrid_property
Expand All @@ -14,6 +28,7 @@
from polyfactory.exceptions import ConfigurationException
from polyfactory.factories.base import BaseFactory
from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory
from polyfactory.fields import Ignore


@pytest.fixture()
Expand Down Expand Up @@ -336,6 +351,55 @@ class Factory(SQLAlchemyFactory[AsyncModel]):
assert inspect(batch_item).persistent # type: ignore[union-attr]


@pytest.mark.parametrize(
"session_config",
(
lambda session: session,
lambda session: (lambda: session),
),
)
async def test_async_server_default_refresh(
async_engine: AsyncEngine,
session_config: Callable[[AsyncSession], Any],
) -> None:
_registry = registry()

class Base(metaclass=DeclarativeMeta):
__abstract__ = True

registry = _registry
metadata = _registry.metadata

class AsyncRefreshModel(Base):
__tablename__ = "server_default_test"

id: Any = Column(Integer(), primary_key=True)
test_datetime: Any = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
test_str: Any = Column(String, nullable=False, server_default=text("test_str"))
test_int: Any = Column(Integer, nullable=False, server_default=text("123"))
test_bool: Any = Column(Boolean, nullable=False, server_default=text("False"))

await create_tables(async_engine, Base)

async with AsyncSession(async_engine) as session:

class Factory(SQLAlchemyFactory[AsyncRefreshModel]):
__async_session__ = session_config(session)
__model__ = AsyncRefreshModel
test_datetime = Ignore()
test_str = Ignore()
test_int = Ignore()
test_bool = Ignore()

result = await Factory.create_async()
assert inspect(result).persistent # type: ignore[union-attr]
assert result.test_datetime is not None
assert isinstance(result.test_datetime, datetime)
assert result.test_str == "test_str"
assert result.test_int == 123
assert result.test_bool is False


def test_alias() -> None:
class ModelWithAlias(Base):
__tablename__ = "table"
Expand Down
68 changes: 66 additions & 2 deletions tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from enum import Enum
from typing import Any, List
from ipaddress import ip_network
from typing import Any, Dict, List
from uuid import UUID

import pytest
from sqlalchemy import ForeignKey, __version__, orm, types
from sqlalchemy import ForeignKey, Text, __version__, orm, types
from sqlalchemy.dialects.mssql import JSON as MSSQL_JSON
from sqlalchemy.dialects.mysql import JSON as MYSQL_JSON
from sqlalchemy.dialects.postgresql import ARRAY, CIDR, HSTORE, INET, JSON, JSONB
from sqlalchemy.dialects.sqlite import JSON as SQLITE_JSON
from sqlalchemy.ext.mutable import MutableDict, MutableList

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory

Expand Down Expand Up @@ -64,6 +71,63 @@ class ModelFactory(SQLAlchemyFactory[Model]):
assert isinstance(instance.str_array_type[0], str)


def test_pg_dialect_types() -> None:
class Base(orm.DeclarativeBase): ...

class SqlaModel(Base):
__tablename__ = "sql_models"
id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
uuid_type: orm.Mapped[UUID] = orm.mapped_column(type_=types.UUID)
nested_array_inet: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(INET, dimensions=1))
nested_array_cidr: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(CIDR, dimensions=1))
hstore_type: orm.Mapped[Dict] = orm.mapped_column(type_=HSTORE)
mut_nested_arry_inet: orm.Mapped[List[str]] = orm.mapped_column(
type_=MutableList.as_mutable(ARRAY(INET, dimensions=1))
)
pg_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSON)
pg_jsonb_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSONB)
common_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=types.JSON)
mysql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MYSQL_JSON)
sqlite_json: orm.Mapped[Dict] = orm.mapped_column(type_=SQLITE_JSON)
mssql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MSSQL_JSON)

multible_pg_json_type: orm.Mapped[Dict] = orm.mapped_column(
type_=MutableDict.as_mutable(JSON(astext_type=Text())) # type: ignore[no-untyped-call]
)
multible_pg_jsonb_type: orm.Mapped[Dict] = orm.mapped_column(
type_=MutableDict.as_mutable(JSONB(astext_type=Text())) # type: ignore[no-untyped-call]
)
multible_common_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(types.JSON()))
multible_mysql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(MYSQL_JSON()))
multible_sqlite_json: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(SQLITE_JSON()))
multible_mssql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(MSSQL_JSON()))

class ModelFactory(SQLAlchemyFactory[SqlaModel]):
__model__ = SqlaModel

instance = ModelFactory.build()
assert isinstance(instance.nested_array_inet[0], str)
assert ip_network(instance.nested_array_inet[0])
assert isinstance(instance.nested_array_cidr[0], str)
assert ip_network(instance.nested_array_cidr[0])
assert isinstance(instance.hstore_type, dict)
assert isinstance(instance.uuid_type, UUID)
assert isinstance(instance.mut_nested_arry_inet[0], str)
assert ip_network(instance.mut_nested_arry_inet[0])
assert isinstance(instance.pg_json_type, dict)
assert isinstance(instance.pg_jsonb_type, dict)
assert isinstance(instance.common_json_type, dict)
assert isinstance(instance.mysql_json, dict)
assert isinstance(instance.sqlite_json, dict)
assert isinstance(instance.mssql_json, dict)
assert isinstance(instance.multible_pg_json_type, dict)
assert isinstance(instance.multible_pg_jsonb_type, dict)
assert isinstance(instance.multible_common_json_type, dict)
assert isinstance(instance.multible_mysql_json, dict)
assert isinstance(instance.multible_sqlite_json, dict)
assert isinstance(instance.multible_mssql_json, dict)


@pytest.mark.parametrize(
"type_",
tuple(SQLAlchemyFactory.get_sqlalchemy_types().keys()),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_random_configuration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from random import Random
from typing import List, Union, cast
from typing import List, Union

import pytest
from faker import Faker
Expand Down Expand Up @@ -65,7 +65,7 @@ class FooFactory(DataclassFactory[Foo]):

@classmethod
def foo(cls) -> int:
return cast(int, cls.__faker__.random_digit())
return cls.__faker__.random_digit()

assert FooFactory.build().foo == RANDINT_MAP[seed]

Expand Down
2 changes: 1 addition & 1 deletion tests/test_recursive_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_recursive_pydantic_models(factory_use_construct: bool) -> None:
factory = ModelFactory.create_factory(PydanticNode)

result = factory.build(factory_use_construct)
assert result.child is _Sentinel, "Default is not used"
assert result.child is _Sentinel, "Default is not used" # type: ignore[comparison-overlap]
assert isinstance(result.union_child, int)
assert result.optional_child is None
assert result.list_child == []
Expand Down

0 comments on commit c0b4db5

Please sign in to comment.