Skip to content

Commit

Permalink
fix: fix json types
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewthetechie committed Jan 21, 2023
1 parent 779b1eb commit 38a2608
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 35 deletions.
14 changes: 6 additions & 8 deletions pydantic_aioredis/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,18 @@
import json
from datetime import date
from datetime import datetime
from ipaddress import IPv4Address
from ipaddress import IPv4Network
from ipaddress import IPv6Address
from ipaddress import IPv6Network
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from uuid import UUID

from pydantic import BaseModel
from pydantic_aioredis.config import RedisConfig
from redis import asyncio as aioredis

# STR_DUMP_SHAPES are object types that are serialized to strings using str(obj)
# They are stored in redis as strings and rely on pydantic to deserialize them
STR_DUMP_SHAPES = (IPv4Address, IPv4Network, IPv6Address, IPv6Network, UUID)
from .types import JSON_DUMP_SHAPES
from .types import STR_DUMP_SHAPES


class _AbstractStore(BaseModel):
Expand Down Expand Up @@ -89,6 +83,8 @@ def serialize_partially(cls, data: Dict[str, Any]):
continue
if cls.__fields__[field].type_ not in [str, float, int]:
data[field] = json.dumps(data[field], default=cls.json_default)
if getattr(cls.__fields__[field], "shape", None) in JSON_DUMP_SHAPES:
data[field] = json.dumps(data[field], default=cls.json_default)
if getattr(cls.__fields__[field], "allow_none", False):
if data[field] is None:
data[field] = "None"
Expand All @@ -107,6 +103,8 @@ def deserialize_partially(cls, data: Dict[bytes, Any]):
continue
if cls.__fields__[field].type_ not in [str, float, int]:
data[field] = json.loads(data[field], object_hook=cls.json_object_hook)
if getattr(cls.__fields__[field], "shape", None) in JSON_DUMP_SHAPES:
data[field] = json.loads(data[field], object_hook=cls.json_object_hook)
if getattr(cls.__fields__[field], "allow_none", False):
if data[field] == "None":
data[field] = None
Expand Down
34 changes: 34 additions & 0 deletions pydantic_aioredis/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from enum import Enum
from ipaddress import IPv4Address
from ipaddress import IPv4Network
from ipaddress import IPv6Address
from ipaddress import IPv6Network
from uuid import UUID

from pydantic.fields import SHAPE_DEFAULTDICT
from pydantic.fields import SHAPE_DICT
from pydantic.fields import SHAPE_FROZENSET
from pydantic.fields import SHAPE_LIST
from pydantic.fields import SHAPE_MAPPING
from pydantic.fields import SHAPE_SEQUENCE
from pydantic.fields import SHAPE_SET
from pydantic.fields import SHAPE_TUPLE
from pydantic.fields import SHAPE_TUPLE_ELLIPSIS

# JSON_DUMP_SHAPES are object types that are serialized to JSON using json.dumps
JSON_DUMP_SHAPES = (
SHAPE_LIST,
SHAPE_SET,
SHAPE_MAPPING,
SHAPE_TUPLE,
SHAPE_TUPLE_ELLIPSIS,
SHAPE_SEQUENCE,
SHAPE_FROZENSET,
SHAPE_DICT,
SHAPE_DEFAULTDICT,
Enum,
)

# STR_DUMP_SHAPES are object types that are serialized to strings using str(obj)
# They are stored in redis as strings and rely on pydantic to deserialize them
STR_DUMP_SHAPES = (IPv4Address, IPv4Network, IPv6Address, IPv6Network, UUID)
54 changes: 27 additions & 27 deletions test/test_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,48 +29,48 @@ class SimpleModel(Model):
test_tuple: Tuple[str]


def test_serialize_partially_skip_missing_field():
serialized = SimpleModel.serialize_partially({"unknown": "test"})
assert serialized["unknown"] == "test"


parameters = [
(st.text, [], {}, "test_str", None),
(st.integers, [], {}, "test_int", None),
(st.floats, [], {"allow_nan": False}, "test_float", None),
(st.dates, [], {}, "test_date", lambda x: json.dumps(x.isoformat())),
(st.datetimes, [], {}, "test_datetime", lambda x: json.dumps(x.isoformat())),
(st.ip_addresses, [], {"v": 4}, "test_ip_v4", lambda x: json.dumps(str(x))),
(st.ip_addresses, [], {"v": 6}, "test_ip_v4", lambda x: json.dumps(str(x))),
(
st.lists,
[st.tuples(st.integers(), st.floats())],
{},
"test_list",
lambda x: json.dumps(x),
),
(st.text, [], {}, "test_str", str, False),
(st.dates, [], {}, "test_date", str, False),
(st.datetimes, [], {}, "test_datetime", str, False),
(st.ip_addresses, [], {"v": 4}, "test_ip_v4", str, False),
(st.ip_addresses, [], {"v": 6}, "test_ip_v4", str, False),
(st.lists, [st.tuples(st.integers(), st.floats())], {}, "test_list", str, False),
(
st.dictionaries,
[st.text(), st.tuples(st.integers(), st.floats())],
{},
"test_dict",
lambda x: json.dumps(x),
str,
False,
),
(st.tuples, [st.text()], {}, "test_tuple", lambda x: json.dumps(x)),
(st.tuples, [st.text()], {}, "test_tuple", str, False),
(st.floats, [], {"allow_nan": False}, "test_float", float, True),
(st.integers, [], {}, "test_int", int, True),
]


@pytest.mark.parametrize(
"strategy, strategy_args, strategy_kwargs, model_field, serialize_callable",
"strategy, strategy_args, strategy_kwargs, model_field, expected_type, equality_expected",
parameters,
)
@given(st.data())
def test_serialize_partially(
strategy, strategy_args, strategy_kwargs, model_field, serialize_callable, data
strategy,
strategy_args,
strategy_kwargs,
model_field,
expected_type,
equality_expected,
data,
):
value = data.draw(strategy(*strategy_args, **strategy_kwargs))
serialized = SimpleModel.serialize_partially({model_field: value})
if serialize_callable is None:
assert serialized[model_field] == value
else:
assert serialized[model_field] == serialize_callable(value)


def test_serialize_partially_skip_missing_filed():
serialized = SimpleModel.serialize_partially({"unknown": "test"})
assert serialized["unknown"] == "test"
assert isinstance(serialized.get(model_field), expected_type)
if equality_expected:
assert serialized.get(model_field) == value
22 changes: 22 additions & 0 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ipaddress import IPv6Address
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

Expand Down Expand Up @@ -93,3 +94,24 @@ class UpdateModel(Model):

redis_model = await UpdateModel.select(ids=[test_str])
assert redis_model[0].test_int == update_int


async def test_storing_list(redis_store):
# https://github.com/andrewthetechie/pydantic-aioredis/issues/403
class DataTypeTest(Model):
_primary_key_field: str = "key"

key: str
value: List[int]

redis_store.register_model(DataTypeTest)
key = "test_list_storage"
instance = DataTypeTest(
key=key,
value=[1, 2, 3],
)
await instance.save()

instance_in_redis = await DataTypeTest.select()
assert instance_in_redis[0].key == instance.key
assert len(instance_in_redis[0].value) == len(instance.value)

0 comments on commit 38a2608

Please sign in to comment.