Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better tests and small fixes #33

Merged
merged 6 commits into from
Jul 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion runtype/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ def validate_instance(self, obj, sampler: Optional[SamplerType]=None):
validate only a sample of the object. This approach may validate much faster,
but might miss anomalies in the data.
"""
...


def test_instance(self, obj, sampler=None):
Expand Down
24 changes: 11 additions & 13 deletions runtype/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def dataclass_transform(*a, **kw):
from .utils import ForwardRef
from .common import CHECK_TYPES
from .validation import TypeMismatchError, ensure_isa as default_ensure_isa
from .pytypes import TypeCaster, SumType, NoneType, ATypeCaster, PythonType
from .pytypes import TypeCaster, SumType, NoneType, ATypeCaster, PythonType, type_caster

Required = object()
MAX_SAMPLE_SIZE = 16
Expand Down Expand Up @@ -79,7 +79,6 @@ def make_type_caster(self, frame) -> ATypeCaster:
def ensure_isa(self, a, b, sampler=None):
"""Ensure that 'a' is an instance of type 'b'. If not, raise a TypeError.
"""
...

@abstractmethod
def cast(self, obj, t):
Expand All @@ -88,7 +87,6 @@ def cast(self, obj, t):
The result is expected to pass `self.ensure_isa(res, t)` without an error,
however this assertion is not validated, for performance reasons.
"""
...


class PythonConfiguration(Configuration):
Expand Down Expand Up @@ -203,21 +201,22 @@ def replace(inst, **kwargs):

def __iter__(inst):
"Yields a list of tuples [(name, value), ...]"
# TODO: deprecate this method
return ((name, getattr(inst, name)) for name in inst.__dataclass_fields__)

def asdict(inst):
"""Returns a dict of {name: value, ...}
"""
return {name: getattr(inst, name) for name in inst.__dataclass_fields__}

def aslist(inst):
"""Returns a list of values

Equivalent to: ``list(dict(inst).values())``
"""Returns a list of the values
"""
return [getattr(inst, name) for name in inst.__dataclass_fields__]


def astuple(inst):
"""Returns a tuple of values

Equivalent to: ``tuple(dict(inst).values())``
"""Returns a tuple of the values
"""
return tuple(getattr(inst, name) for name in inst.__dataclass_fields__)

Expand Down Expand Up @@ -317,6 +316,7 @@ def __post_init__(self):

_set_if_not_exists(c, {
'replace': replace,
'asdict': asdict,
'aslist': aslist,
'astuple': astuple,
'json': json,
Expand Down Expand Up @@ -391,13 +391,11 @@ def dataclass(
unsafe_hash: bool = False,
frozen: bool = True,
slots: bool = ...,
) -> Callable[[Type[_T]], Type[_T]]:
...
) -> Callable[[Type[_T]], Type[_T]]: ...

@dataclass_transform(field_specifiers=(dataclasses.field, dataclasses.Field), frozen_default=True)
@overload
def dataclass(_cls: Type[_T]) -> Type[_T]:
...
def dataclass(_cls: Type[_T]) -> Type[_T]: ...

@dataclass_transform(field_specifiers=(dataclasses.field, dataclasses.Field), frozen_default=True)
def dataclass(cls: Optional[Type[_T]]=None, *,
Expand Down
44 changes: 5 additions & 39 deletions runtype/datetime_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,13 @@ class DateTimeError(Exception):
MS_WATERSHED = int(2e10)
# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9
MAX_NUMBER = int(3e20)
StrBytesIntFloat = Union[str, bytes, int, float]


def get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]:
if isinstance(value, (int, float)):
return value
def get_numeric(value: str, native_expected_type: str) -> Union[None, int, float]:
try:
return float(value)
except ValueError:
return None
except TypeError:
raise TypeError(f'invalid type; expected {native_expected_type}, string, bytes, int or float')


def from_unix_seconds(seconds: Union[int, float]) -> datetime:
Expand Down Expand Up @@ -107,26 +102,17 @@ def _parse_timezone(value: Optional[str], error: Type[Exception]) -> Union[None,
return None


def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
def parse_date(value: str) -> date:
"""
Parse a date/int/float/string and return a datetime.date.

Raise ValueError if the input is well formatted but not a valid date.
Raise ValueError if the input isn't well formatted.
"""
if isinstance(value, date):
if isinstance(value, datetime):
return value.date()
else:
return value

number = get_numeric(value, 'date')
if number is not None:
return from_unix_seconds(number).date()

if isinstance(value, bytes):
value = value.decode()

match = date_re.match(value) # type: ignore
if match is None:
raise DateError()
Expand All @@ -139,26 +125,20 @@ def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
raise DateError()


def parse_time(value: Union[time, StrBytesIntFloat]) -> time:
def parse_time(value: str) -> time:
"""
Parse a time/string and return a datetime.time.

Raise ValueError if the input is well formatted but not a valid time.
Raise ValueError if the input isn't well formatted, in particular if it contains an offset.
"""
if isinstance(value, time):
return value

number = get_numeric(value, 'time')
if number is not None:
if number >= 86400:
# doesn't make sense since the time time loop back around to 0
raise TimeError()
return (datetime.min + timedelta(seconds=number)).time()

if isinstance(value, bytes):
value = value.decode()

match = time_re.match(value) # type: ignore
if match is None:
raise TimeError()
Expand All @@ -177,7 +157,7 @@ def parse_time(value: Union[time, StrBytesIntFloat]) -> time:
raise TimeError()


def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
def parse_datetime(value: str) -> datetime:
"""
Parse a datetime/int/float/string and return a datetime.datetime.

Expand All @@ -187,16 +167,11 @@ def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
Raise ValueError if the input is well formatted but not a valid datetime.
Raise ValueError if the input isn't well formatted.
"""
if isinstance(value, datetime):
return value

number = get_numeric(value, 'datetime')
if number is not None:
return from_unix_seconds(number)

if isinstance(value, bytes):
value = value.decode()

match = datetime_re.match(value) # type: ignore
if match is None:
raise DateTimeError()
Expand All @@ -215,23 +190,14 @@ def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
raise DateTimeError()


def parse_duration(value: StrBytesIntFloat) -> timedelta:
def parse_duration(value: str) -> timedelta:
"""
Parse a duration int/float/string and return a datetime.timedelta.

The preferred format for durations in Django is '%d %H:%M:%S.%f'.

Also supports ISO 8601 representation.
"""
if isinstance(value, timedelta):
return value

if isinstance(value, (int, float)):
# below code requires a string
value = str(value)
elif isinstance(value, bytes):
value = value.decode()

try:
match = standard_duration_re.match(value) or iso8601_duration_re.match(value)
except TypeError:
Expand Down
1 change: 1 addition & 0 deletions runtype/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class DispatchError(Exception):
pass


# TODO: Remove test_subtypes, replace with support for Type[], like isa(t, Type[t])
class MultiDispatch:
"""Creates a dispatch group for multiple dispatch

Expand Down
60 changes: 48 additions & 12 deletions runtype/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections import abc
import sys
import typing
from datetime import datetime
from datetime import datetime, date, time, timedelta
from types import FrameType

from .utils import ForwardRef
Expand Down Expand Up @@ -112,12 +112,12 @@ def cast_from(self, obj):
raise TypeMismatchError(obj, self)


def _flatten_types(t):
if isinstance(t, SumType):
for t in t.types:
yield from _flatten_types(t)
else:
yield t
# def _flatten_types(t):
# if isinstance(t, SumType):
# for t in t.types:
# yield from _flatten_types(t)
# else:
# yield t



Expand Down Expand Up @@ -271,12 +271,16 @@ def __getitem__(self, item):
return type(self)(self.base, item)

def cast_from(self, obj):
# Must already be a dict
self.base.validate_instance(obj)

# Optimize for Dict[Any] and empty dicts
if self.item is Any or not obj:
return obj
# Already a dict?
if self.base.test_instance(obj):
return obj
# Make sure it's a dict
return dict(obj)

# Must already be a dict
self.base.validate_instance(obj)

# Recursively cast each item
kt, vt = self.item.types
Expand Down Expand Up @@ -347,6 +351,33 @@ def cast_from(self, obj):
raise TypeMismatchError(obj, self)
return super().cast_from(obj)

class _Date(PythonDataType):
def cast_from(self, obj):
if isinstance(obj, str):
try:
return datetime_parse.parse_date(obj)
except datetime_parse.DateTimeError:
raise TypeMismatchError(obj, self)
return super().cast_from(obj)

class _Time(PythonDataType):
def cast_from(self, obj):
if isinstance(obj, str):
try:
return datetime_parse.parse_time(obj)
except datetime_parse.DateTimeError:
raise TypeMismatchError(obj, self)
return super().cast_from(obj)

class _TimeDelta(PythonDataType):
def cast_from(self, obj):
if isinstance(obj, str):
try:
return datetime_parse.parse_duration(obj)
except datetime_parse.DateTimeError:
raise TypeMismatchError(obj, self)
return super().cast_from(obj)


class _NoneType(OneOf):
def __init__(self):
Expand All @@ -364,6 +395,9 @@ def cast_from(self, obj):
Float = _Float(float)
NoneType = _NoneType()
DateTime = _DateTime(datetime)
Date = _Date(date)
Time = _Time(time)
TimeDelta = _TimeDelta(timedelta)


_type_cast_mapping = {
Expand All @@ -381,7 +415,9 @@ def cast_from(self, obj):
object: Any,
typing.Any: Any,
datetime: DateTime,

date: Date,
time: Time,
timedelta: TimeDelta,
}


Expand Down
26 changes: 24 additions & 2 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
logging.basicConfig(level=logging.INFO)

from runtype import Dispatch, DispatchError, dataclass, isa, is_subtype, issubclass, assert_isa, String, Int, validate_func, cv_type_checking
from runtype.dispatch import MultiDispatch
from runtype.dataclass import Configuration

try:
Expand Down Expand Up @@ -102,6 +103,13 @@ def test_basic2(self):
assert is_subtype(int, a)
assert isa(1, a)

def test_issubclass(self):
# test class tuple
t = int, float
assert issubclass(int, t)
assert issubclass(float, t)
assert not issubclass(str, t)


def test_assert(self):
assert_isa(1, int)
Expand Down Expand Up @@ -247,6 +255,10 @@ def to_list(x:dict):
assert to_list([1]) == [1]
assert to_list({1: 2}) == [(1, 2)]

def test_with(self):
with Dispatch() as d:
assert isinstance(d, MultiDispatch)

def test_ambiguity(self):
dp = Dispatch()

Expand Down Expand Up @@ -588,6 +600,9 @@ def __post_init__(self):
assert p2.aslist() == [30, 3]
assert p2.astuple() == (30, 3)

assert p2.asdict() == {'x':30, 'y':3}
assert list(p2.asdict().keys()) == ['x', 'y'] # test order

self.assertRaises(AssertionError, Point, 0, 2)

self.assertRaises(TypeError, Point, 0, "a") # Before post_init
Expand Down Expand Up @@ -940,8 +955,15 @@ class Bar:
@dataclass
class Foo:
bars: List[Bar]

assert Foo([Bar(0)]).json() == {"bars": [{"baz": 0}]}
d: Dict[str, Bar]

assert Foo(
[Bar(0)],
{"a": Bar(2)}
).json() == {
"bars": [{"baz": 0}],
"d": {"a": {"baz": 2}}
}



Expand Down
Loading