From 9e6edabd4ee5242c54f3ddb66e539a415da86901 Mon Sep 17 00:00:00 2001 From: jeffry <36665036+wangxin688@users.noreply.github.com> Date: Sun, 12 May 2024 20:00:02 +0800 Subject: [PATCH] =?UTF-8?q?fix(sqla=5Ffactory):=20fix=20json=20type=20erro?= =?UTF-8?q?r=20and=20pg=20dialect=20default=20value=20e=E2=80=A6=20(#542)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: guacs <126393040+guacs@users.noreply.github.com> --- polyfactory/factories/sqlalchemy_factory.py | 21 ++++++++++++------- .../test_sqlalchemy_factory_v2.py | 14 ++++++++++++- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index f9897dd8..d9b0ac66 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -10,7 +10,7 @@ try: from sqlalchemy import Column, inspect, types - from sqlalchemy.dialects import mysql, postgresql + from sqlalchemy.dialects import mssql, mysql, postgresql, sqlite from sqlalchemy.exc import NoInspectionAvailable from sqlalchemy.orm import InstanceState, Mapper except ImportError as e: @@ -85,22 +85,28 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]): @classmethod def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]: - """Get mapping of types where column type.""" + """Get mapping of types where column type. + for sqlalchemy dialect `JSON` type, accepted only basic types in pydict in case sqlalchemy process `JSON` raise serialize error. + """ return { types.TupleType: cls.__faker__.pytuple, + mssql.JSON: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)), mysql.YEAR: lambda: cls.__random__.randint(1901, 2155), - postgresql.CIDR: lambda: cls.__faker__.ipv4(network=False), + mysql.JSON: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)), + postgresql.CIDR: lambda: cls.__faker__.ipv4(network=True), postgresql.DATERANGE: lambda: (cls.__faker__.past_date(), date.today()), # noqa: DTZ011 - postgresql.INET: lambda: cls.__faker__.ipv4(network=True), + postgresql.INET: lambda: cls.__faker__.ipv4(network=False), postgresql.INT4RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), postgresql.INT8RANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), postgresql.MACADDR: lambda: cls.__faker__.hexify(text="^^:^^:^^:^^:^^:^^", upper=True), postgresql.NUMRANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])), postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005 - postgresql.HSTORE: lambda: cls.__faker__.pydict(), - # `types.JSON` is compatible for sqlachemy extend dialects. Such as `pg.JSON` and `JSONB` - types.JSON: lambda: cls.__faker__.pydict(), + postgresql.HSTORE: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)), + postgresql.JSON: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)), + postgresql.JSONB: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)), + sqlite.JSON: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)), + types.JSON: lambda: cls.__faker__.pydict(value_types=(str, int, bool, float)), } @classmethod @@ -148,7 +154,6 @@ def get_type_from_column(cls, column: Column) -> type: if column.nullable: annotation = Union[annotation, None] # type: ignore[assignment] - return annotation @classmethod diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py index de2ab7e4..b7d7d7d7 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py @@ -71,7 +71,7 @@ class ModelFactory(SQLAlchemyFactory[Model]): assert isinstance(instance.str_array_type[0], str) -def test_pg_dialect_types() -> None: +def test_sqla_dialect_types() -> None: class Base(orm.DeclarativeBase): ... class SqlaModel(Base): @@ -115,11 +115,23 @@ class ModelFactory(SQLAlchemyFactory[SqlaModel]): assert isinstance(instance.mut_nested_arry_inet[0], str) assert ip_network(instance.mut_nested_arry_inet[0]) assert isinstance(instance.pg_json_type, dict) + for value in instance.pg_json_type.values(): + assert isinstance(value, (str, int, bool, float)) assert isinstance(instance.pg_jsonb_type, dict) + for value in instance.pg_jsonb_type.values(): + assert isinstance(value, (str, int, bool, float)) assert isinstance(instance.common_json_type, dict) + for value in instance.common_json_type.values(): + assert isinstance(value, (str, int, bool, float)) assert isinstance(instance.mysql_json, dict) + for value in instance.mysql_json.values(): + assert isinstance(value, (str, int, bool, float)) assert isinstance(instance.sqlite_json, dict) + for value in instance.sqlite_json.values(): + assert isinstance(value, (str, int, bool, float)) assert isinstance(instance.mssql_json, dict) + for value in instance.mssql_json.values(): + assert isinstance(value, (str, int, bool, float)) assert isinstance(instance.multible_pg_json_type, dict) assert isinstance(instance.multible_pg_jsonb_type, dict) assert isinstance(instance.multible_common_json_type, dict)