From 0c0fc61548c29911bca58726d72fac47b1c6a2b1 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Sun, 8 May 2022 23:42:46 +0200 Subject: [PATCH] Types: Converted all __eq__, __le__, and __ge__ methods into dispatched funcs --- runtype/base_types.py | 242 +++++++++++++++++------------------------- runtype/pytypes.py | 81 ++++++-------- tests/test_types.py | 45 ++++++-- 3 files changed, 167 insertions(+), 201 deletions(-) diff --git a/runtype/base_types.py b/runtype/base_types.py index 9520be9..51328e4 100644 --- a/runtype/base_types.py +++ b/runtype/base_types.py @@ -9,6 +9,9 @@ from typing import Callable, Sequence, Optional, Union from abc import ABC, abstractmethod +from .dispatch import MultiDispatch +from .typesystem import PythonBasic +dp = MultiDispatch(PythonBasic()) class RuntypeError(TypeError): pass @@ -30,9 +33,6 @@ def __add__(self, other: _Type): def __mul__(self, other: _Type): return ProductType.create((self, other)) - @abstractmethod - def __le__(self, other: _Type): - return NotImplemented class AnyType(Type): """Represents the Any type. @@ -42,21 +42,6 @@ class AnyType(Type): def __add__(self, other): return self - def __ge__(self, other): - if isinstance(other, (type, Type)): - return True - - return NotImplemented - - def __le__(self, other): - if other is self: # Optimization - return True - elif isinstance(other, (type, Type)): - if not isinstance(other, SumType): - return False - - return NotImplemented - def __repr__(self): return 'Any' @@ -69,15 +54,6 @@ class DataType(Type): Example of possible data-types: int, float, text """ - def __le__(self, other): - if isinstance(other, DataType): - return self == other - - return super().__le__(other) - - def __ge__(self, other): - return NotImplemented - class SumType(Type): """Implements a sum type, i.e. a disjoint union of a set of types. @@ -106,23 +82,10 @@ def create(cls, types): def __repr__(self): return '(%s)' % '+'.join(map(repr, self.types)) - def __le__(self, other): - return all(t <= other for t in self.types) - - def __ge__(self, other): - return any(other <= t for t in self.types) - - def __eq__(self, other): - if isinstance(other, SumType): - return self.types == other.types - - return NotImplemented - def __hash__(self): return hash(frozenset(self.types)) - class ProductType(Type): """Implements a product type, i.e. a tuple of types. """ @@ -148,29 +111,6 @@ def __repr__(self): def __hash__(self): return hash(self.types) - def __eq__(self, other): - if isinstance(other, ProductType): - return self.types == other.types - - return NotImplemented - - def __le__(self, other): - if isinstance(other, ProductType): - if len(self.types) != len(other.types): - return False - - return all(t1<=t2 for t1, t2 in zip(self.types, other.types)) - elif isinstance(other, DataType): - return False - - return NotImplemented - - def __ge__(self, other): - if isinstance(other, DataType): - return False - - return NotImplemented - class ContainerType(DataType): """Base class for containers, such as generics. @@ -204,63 +144,15 @@ def __repr__(self): def __getitem__(self, item): return type(self)(self, item) - def __eq__(self, other): - if isinstance(other, GenericType): - return self.base == other.base and self.item == other.item - - elif isinstance(other, Type): - return Any <= self.item and self.base == other - - return NotImplemented - - - def __le__(self, other): - if isinstance(other, GenericType): - return self.base <= other.base and self.item <= other.item - - elif isinstance(other, DataType): - return self.base <= other - - return NotImplemented - - def __ge__(self, other): - if isinstance(other, type(self)): - return self.base >= other.base and self.item >= other.item - - elif isinstance(other, DataType): - return self.base >= other - - return NotImplemented - def __hash__(self): return hash((self.base, self.item)) - class PhantomType(Type): """Implements a base for phantom types. """ def __getitem__(self, other): return PhantomGenericType(self, other) - def __le__(self, other): - if isinstance(other, PhantomType): - return self == other - - elif not isinstance(other, PhantomGenericType): - return False - - return NotImplemented - - - def __ge__(self, other): - if isinstance(other, PhantomType): - return self == other - - elif not isinstance(other, PhantomGenericType): - return False - - return NotImplemented - class PhantomGenericType(Type): """Implements a generic phantom type, for carrying metadata within the type signature. @@ -271,32 +163,6 @@ def __init__(self, base, item=Any): self.base = base self.item = item - def __le__(self, other): - if isinstance(other, PhantomType): - return self.base <= other or self.item <= other - - elif isinstance(other, PhantomGenericType): - return (self.base <= other.base and self.item <= other.item) or self.item <= other - - elif isinstance(other, DataType): - return self.item <= other - - return NotImplemented - - def __eq__(self, other): - if isinstance(other, PhantomGenericType): - return self.base == other.base and self.item == other.base - - return NotImplemented - - def __ge__(self, other): - if isinstance(other, PhantomType): - return False - - elif isinstance(other, DataType): - return other <= self.item - - return NotImplemented SamplerType = Callable[[Sequence], Sequence] @@ -326,7 +192,7 @@ def test_instance(self, obj, sampler=None): return False -class Constraint(Validator, PhantomType): +class Constraint(Validator, Type): """Defines a constraint, which activates during validation. """ @@ -342,9 +208,101 @@ def validate_instance(self, inst, sampler=None): if not p(inst): raise TypeMismatchError(inst, self) - def __ge__(self, other): - # Arbitrary predicates prevent it from being a superclass + +@dp +def le(self, other): + return NotImplemented +@dp(priority=-1) +def le(self: Type, other: Type): + return self == other +@dp +def ge(self, other): + return le(other, self) +@dp +def eq(self, other): + return NotImplemented + +@dp +def eq(self: SumType, other: SumType): + return self.types == other.types +@dp +def eq(self: ProductType, other: ProductType): + return self.types == other.types +@dp +def eq(self: GenericType, other: GenericType): + return self.base == other.base and self.item == other.item +@dp +def eq(self: GenericType, other: Type): + return Any <= self.item and self.base == other +@dp +def eq(self: PhantomGenericType, other: PhantomGenericType): + return self.base == other.base and self.item == other.base + + +# Special cases (AnyType, SumType) +@dp(priority=100) +def le(self: Type, other: AnyType): + return True +@dp +def le(self: type, other: AnyType): + return True + + +@dp(priority=51) +def le(self: SumType, other: Type): + return all(t <= other for t in self.types) + +@dp(priority=50) +def le(self: Type, other: SumType): + return any(self <= t for t in other.types) + +@dp +def le(self: ProductType, other: ProductType): + if len(self.types) != len(other.types): return False - def __le__(self, other): - return self.type <= other + return all(t1<=t2 for t1, t2 in zip(self.types, other.types)) + +# Generics + +@dp +def le(self: GenericType, other: GenericType): + return self.base <= other.base and self.item <= other.item + +@dp +def le(self: GenericType, other: Type): + return self.base <= other +@dp +def le(self: Type, other: GenericType): + return other.item is Any and self <= other.base + + +# Phantom Types + +@dp(priority=1) +def le(self: Type, other: PhantomGenericType): + return self <= other.item + +@dp +def le(self: PhantomGenericType, other: Type): + return self.item <= other + +@dp +def le(self: PhantomGenericType, other: PhantomType): + return self.base <= other or self.item <= other + +# Constraints + +@dp +def le(self: Constraint, other: Constraint): + # Arbitrary predicates prevent it from being a superclass + return self == other + +@dp(priority=1) +def le(self: Constraint, other: Type): + return self.type <= other + + +Type.__eq__ = eq +Type.__le__ = le +Type.__ge__ = ge diff --git a/runtype/pytypes.py b/runtype/pytypes.py index f2ef657..63ec386 100644 --- a/runtype/pytypes.py +++ b/runtype/pytypes.py @@ -14,7 +14,7 @@ from types import FrameType from .utils import ForwardRef -from .base_types import DataType, Validator, TypeMismatchError +from .base_types import DataType, Validator, TypeMismatchError, dp from . import base_types from . import datetime_parse @@ -129,12 +129,6 @@ class PythonDataType(DataType, PythonType): def __init__(self, kernel, supertypes={Any}): self.kernel = kernel - def __le__(self, other): - if isinstance(other, PythonDataType): - return issubclass(self.kernel, other.kernel) - - return NotImplemented - def validate_instance(self, obj, sampler=None): if not isinstance(obj, self.kernel): raise TypeMismatchError(obj, self) @@ -160,28 +154,7 @@ def cast_from(self, obj): return obj - class TupleType(PythonType): - def __le__(self, other): - # No superclasses or subclasses - if other is Any: - return True - - return isinstance(other, TupleType) - - def __ge__(self, other): - if isinstance(other, TupleEllipsisType): - return True - elif isinstance(other, TupleType): - return True - elif isinstance(other, DataType): - return False - elif isinstance(other, ProductType): - # Products are a tuple, but with length and types - return True - - return NotImplemented - def validate_instance(self, obj, sampler=None): if not isinstance(obj, tuple): raise TypeMismatchError(obj, self) @@ -199,25 +172,6 @@ class OneOf(PythonType): def __init__(self, values): self.values = values - def __le__(self, other): - if isinstance(other, OneOf): - return set(self.values) <= set(other.values) - elif isinstance(other, PythonType): - try: - for v in self.values: - other.validate_instance(v) - except TypeMismatchError: - return False - return True - return NotImplemented - - def __ge__(self, other): - if isinstance(other, OneOf): - return set(self.values) >= set(other.values) - elif isinstance(other, PythonType): - return False - return NotImplemented - def validate_instance(self, obj, sampler=None): tok = cv_type_checking.set(True) try: @@ -374,10 +328,6 @@ def __call__(self, min_length=None, max_length=None): return Constraint(self, predicates) - def __le__(self, other): - if isinstance(other, SequenceType): - return self <= other.base and self <= other.item - return super().__le__(other) class _DateTime(PythonDataType): @@ -613,3 +563,32 @@ def to_canon(self, t) -> PythonType: type_caster = TypeCaster() + + +@dp +def le(self: PythonDataType, other: PythonDataType): + return issubclass(self.kernel, other.kernel) +@dp +def le(self: TupleEllipsisType, other: TupleType): + return True +@dp +def le(self: ProductType, other: TupleType): + # Products are a tuple, but with length and specific types + return True + +@dp +def le(self: PythonDataType, other: SequenceType): + return self <= other.base and self <= other.item + +@dp +def le(self: OneOf, other: OneOf): + return set(self.values) <= set(other.values) + +@dp +def le(self: OneOf, other: PythonType): + try: + for v in self.values: + other.validate_instance(v) + except TypeMismatchError: + return False + return True \ No newline at end of file diff --git a/tests/test_types.py b/tests/test_types.py index 1b09e62..ebd18c5 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -48,16 +48,25 @@ def test_phantom(self): assert Int <= P[Int] assert P[Int] <= Int assert P[Int] <= P[Int] - assert P[Int] <= P assert not P <= P[Int] - assert Q[P] <= Q + assert P[Q] <= Q + assert P[Q] <= P + assert not P <= Q[Int] + + assert P[Q[Int]] <= P + assert P[Q[Int]] <= Q + assert P[Q[Int]] <= Int assert P[Q[Int]] <= P[Q] assert P[Q[Int]] <= P[Int] assert P[Q[Int]] <= Q[Int] - assert P[Q[Int]] <= Int + assert P[Q[Int]] <= P[Q[Int]] + assert Int <= P[Q[Int]] + assert P <= P + Int + assert not P <= Dict + assert not P <= Int + Dict def test_pytypes1(self): @@ -114,6 +123,23 @@ def test_constraint(self): assert not i.test_instance(9) assert not i.test_instance(13) + assert int_pair == int_pair + assert int_pair <= int_pair + assert int_pair >= int_pair + assert int_pair <= Any + assert Any >= int_pair + assert not int_pair <= Dict + assert not int_pair <= Int + assert not int_pair <= Int + Dict + assert not int_pair <= Tuple + + assert int_pair <= List + assert List >= int_pair + assert int_pair <= List[Int] + assert List[Int] >= int_pair + assert not int_pair <= List[String] + assert not List[String] >= int_pair + def test_typesystem(self): @@ -142,12 +168,15 @@ def test_pytypes2(self): assert not Tuple <= Int assert not Int <= Tuple - assert Literal([1]) <= Literal([1, 2]) - assert not Literal([1, 3]) <= Literal([1, 2]) - assert Literal([1, 3]) >= Literal([1]) + one = Literal([1]) + one_two = Literal([1, 2]) + one_three = Literal([1, 3]) + assert one <= one_two + assert not one_three <= one_two + assert one_three >= one assert not Literal([1]) <= Tuple - assert not Literal(1) >= Tuple - assert not Tuple <= Literal(1) + assert not Literal([1]) >= Tuple + assert not Tuple <= Literal([1]) Tuple.validate_instance((1, 2)) self.assertRaises(TypeError, Tuple.validate_instance, 1)