diff --git a/runtype/base_types.py b/runtype/base_types.py index 9520be9..685d4a3 100644 --- a/runtype/base_types.py +++ b/runtype/base_types.py @@ -1,6 +1,14 @@ """ Base Type Classes - contains the basic building blocks of a generic type system +There are five kinds of types: (to date) +- Any - Contains every type +- Sum - marks a union between types +- Product - marks a record / tuple / struct +- Data - marks any type that contains non-type information +- Phantom - A "meta"-type that can wrap existing types, + but is transparent, and has no effect otherwise. + We use comparison operators to indicate whether a type is a subtype of another: - t1 <= t2 means "t1 is a subtype of t2" - t1 >= t2 means "t2 is a subtype of t1" @@ -9,6 +17,11 @@ from typing import Callable, Sequence, Optional, Union from abc import ABC, abstractmethod +from .dispatch import MultiDispatch +from .typesystem import PythonBasic + +dp = MultiDispatch(PythonBasic()) + class RuntypeError(TypeError): pass @@ -19,46 +32,33 @@ def __str__(self) -> str: v, t = self.args return f"Expected type '{t}', but got value: {v}." -_Type = Union['Type', type] + +_Type = Union["Type", type] + class Type(ABC): - """Abstract Type class. All types inherit from it. - """ + """Abstract Type class. All types inherit from it.""" + def __add__(self, other: _Type): return SumType.create((self, other)) def __mul__(self, other: _Type): return ProductType.create((self, other)) - @abstractmethod - def __le__(self, other: _Type): - return NotImplemented class AnyType(Type): """Represents the Any type. + Any contains every other type. + For any type 't' within the typesystem, t is a subtype of Any (or: t <= Any) """ + def __add__(self, other): return self - def __ge__(self, other): - if isinstance(other, (type, Type)): - return True - - return NotImplemented - - def __le__(self, other): - if other is self: # Optimization - return True - elif isinstance(other, (type, Type)): - if not isinstance(other, SumType): - return False - - return NotImplemented - def __repr__(self): - return 'Any' + return "Any" Any = AnyType() @@ -67,16 +67,10 @@ def __repr__(self): class DataType(Type): """Abstract class for a data type. - Example of possible data-types: int, float, text - """ - def __le__(self, other): - if isinstance(other, DataType): - return self == other - - return super().__le__(other) + A data-type is any type that contains non-type information. - def __ge__(self, other): - return NotImplemented + Example of possible data-types: int, float, text, list + """ class SumType(Type): @@ -84,6 +78,7 @@ class SumType(Type): Similar to Python's `typing.Union`. """ + def __init__(self, types): self.types = frozenset(types) @@ -99,33 +94,19 @@ def create(cls, types): else: x.add(t) - if len(x) == 1: # SumType([x]) is x + if len(x) == 1: # SumType([x]) is x return list(x)[0] return cls(x) def __repr__(self): - return '(%s)' % '+'.join(map(repr, self.types)) - - def __le__(self, other): - return all(t <= other for t in self.types) - - def __ge__(self, other): - return any(other <= t for t in self.types) - - def __eq__(self, other): - if isinstance(other, SumType): - return self.types == other.types - - return NotImplemented + return "(%s)" % "+".join(map(repr, self.types)) def __hash__(self): return hash(frozenset(self.types)) - class ProductType(Type): - """Implements a product type, i.e. a tuple of types. - """ + """Implements a product type, i.e. a record / tuple / struct""" def __init__(self, types): self.types = tuple(types) @@ -143,54 +124,35 @@ def create(cls, types): return cls(x) def __repr__(self): - return '(%s)' % '*'.join(map(repr, self.types)) + return "(%s)" % "*".join(map(repr, self.types)) def __hash__(self): return hash(self.types) - def __eq__(self, other): - if isinstance(other, ProductType): - return self.types == other.types - - return NotImplemented - - def __le__(self, other): - if isinstance(other, ProductType): - if len(self.types) != len(other.types): - return False - - return all(t1<=t2 for t1, t2 in zip(self.types, other.types)) - elif isinstance(other, DataType): - return False - - return NotImplemented - - def __ge__(self, other): - if isinstance(other, DataType): - return False - - return NotImplemented - class ContainerType(DataType): - """Base class for containers, such as generics. - """ + """Base class for containers, such as generics.""" + def __getitem__(self, other): return GenericType(self, other) + class GenericType(ContainerType): """Implements a generic type. i.e. a container for items of a specific type. For any two generic types a[i] and b[j], it's true that a[i] <= b[j] iff a <= b and i <= j. """ + base: Type item: Union[type, Type] - def __init__(self, base: Type, item: Union[type, Type]=Any): + def __init__(self, base: Type, item: Union[type, Type] = Any): assert isinstance(item, (Type, type)), item if isinstance(base, GenericType): if not item <= base.item: - raise TypeError(f"Expecting new generic to be a subtype of base, but {item} = other.base and self.item >= other.item - - elif isinstance(other, DataType): - return self.base >= other - - return NotImplemented - def __hash__(self): return hash((self.base, self.item)) class PhantomType(Type): """Implements a base for phantom types. + + A phantom type is a "meta" type that can wrap existing types, + but it is transparent (subtype checks may skip over it), and has no effect otherwise. """ + def __getitem__(self, other): return PhantomGenericType(self, other) - def __le__(self, other): - if isinstance(other, PhantomType): - return self == other - - elif not isinstance(other, PhantomGenericType): - return False - - return NotImplemented - - - def __ge__(self, other): - if isinstance(other, PhantomType): - return self == other - - elif not isinstance(other, PhantomGenericType): - return False - - return NotImplemented - class PhantomGenericType(Type): """Implements a generic phantom type, for carrying metadata within the type signature. For any phantom type p[i], it's true that p[i] <= p but also p[i] <= i and i <= p[i]. """ + def __init__(self, base, item=Any): self.base = base self.item = item - def __le__(self, other): - if isinstance(other, PhantomType): - return self.base <= other or self.item <= other - - elif isinstance(other, PhantomGenericType): - return (self.base <= other.base and self.item <= other.item) or self.item <= other - - elif isinstance(other, DataType): - return self.item <= other - - return NotImplemented - - def __eq__(self, other): - if isinstance(other, PhantomGenericType): - return self.base == other.base and self.item == other.base - - return NotImplemented - - def __ge__(self, other): - if isinstance(other, PhantomType): - return False - - elif isinstance(other, DataType): - return other <= self.item - - return NotImplemented SamplerType = Callable[[Sequence], Sequence] + class Validator(ABC): - """Defines the validator interface. - """ + """Defines the validator interface.""" + @abstractmethod - def validate_instance(self, obj, sampler: Optional[SamplerType]=None): + def validate_instance(self, obj, sampler: Optional[SamplerType] = None): """Validates obj, raising a TypeMismatchError if it does not conform. - If sampler is provided, it will be applied to the instance in order to + If sampler is provided, it will be applied to the instance in order to validate only a sample of the object. This approach may validate much faster, but might miss anomalies in the data. """ - def test_instance(self, obj, sampler=None): """Tests obj, returning a True/False for whether it conforms or not. - If sampler is provided, it will be applied to the instance in order to + If sampler is provided, it will be applied to the instance in order to validate only a sample of the object. """ try: @@ -325,10 +219,9 @@ def test_instance(self, obj, sampler=None): except TypeMismatchError: return False - -class Constraint(Validator, PhantomType): - """Defines a constraint, which activates during validation. - """ + +class Constraint(Validator, Type): + """Defines a constraint, which activates during validation.""" def __init__(self, for_type, predicates): self.type = for_type @@ -342,9 +235,128 @@ def validate_instance(self, inst, sampler=None): if not p(inst): raise TypeMismatchError(inst, self) - def __ge__(self, other): - # Arbitrary predicates prevent it from being a superclass + +# fmt: off +@dp +def le(self, other): + return NotImplemented + +@dp(priority=-1) +def le(self: Type, other: Type): + return self == other + +@dp +def ge(self, other): + return le(other, self) + +@dp +def eq(self, other): + return NotImplemented + +@dp +def eq(self: SumType, other: SumType): + return self.types == other.types + +@dp +def eq(self: ProductType, other: ProductType): + return self.types == other.types + +@dp +def eq(self: GenericType, other: GenericType): + return self.base == other.base and self.item == other.item + +@dp +def eq(self: GenericType, other: Type): + return Any <= self.item and self.base == other + +@dp +def eq(self: PhantomGenericType, other: PhantomGenericType): + return self.base == other.base and self.item == other.base + + +# le() for AnyType + + +@dp(priority=100) +def le(self: Type, other: AnyType): + # Any contains all types + return True + +@dp +def le(self: type, other: AnyType): + # Any contains all types + return True + + +# le() for SumType + + +@dp(priority=51) +def le(self: SumType, other: Type): + return all(t <= other for t in self.types) + +@dp(priority=50) +def le(self: Type, other: SumType): + return any(self <= t for t in other.types) + + +# le() for ProductType + + +@dp +def le(self: ProductType, other: ProductType): + if len(self.types) != len(other.types): return False - def __le__(self, other): - return self.type <= other + return all(t1 <= t2 for t1, t2 in zip(self.types, other.types)) + + +# le() for GenericType + + +@dp +def le(self: GenericType, other: GenericType): + return self.base <= other.base and self.item <= other.item + +@dp +def le(self: GenericType, other: Type): + return self.base <= other + +@dp +def le(self: Type, other: GenericType): + return other.item is Any and self <= other.base + + +# le() for PhantomType and PhantomGenericType + + +@dp(priority=1) +def le(self: Type, other: PhantomGenericType): + return self <= other.item + +@dp +def le(self: PhantomGenericType, other: Type): + return self.item <= other + +@dp +def le(self: PhantomGenericType, other: PhantomType): + # Only phantom types can match the base of a phantom-generic + return self.base <= other or self.item <= other + +# le() for Constraint + +@dp +def le(self: Constraint, other: Constraint): + # Arbitrary predicates prevent it from being a superclass + return self == other + +@dp(priority=1) +def le(self: Constraint, other: Type): + return self.type <= other + + +Type.__eq__ = eq +Type.__le__ = le +Type.__ge__ = ge + +# fmt: on diff --git a/runtype/dispatch.py b/runtype/dispatch.py index d09a14c..13e948d 100644 --- a/runtype/dispatch.py +++ b/runtype/dispatch.py @@ -1,7 +1,9 @@ from collections import defaultdict from functools import wraps -from typing import Dict, Callable, Sequence +from typing import Any, Dict, Callable, Sequence +from operator import itemgetter +from dataclasses import dataclass from .utils import get_func_signatures from .typesystem import TypeSystem @@ -21,21 +23,39 @@ class MultiDispatch: """ def __init__(self, typesystem: TypeSystem, test_subtypes: Sequence[int] = ()): - self.fname_to_tree : Dict[str, TypeTree] = {} + self.fname_to_tree: Dict[str, TypeTree] = {} self.typesystem: TypeSystem = typesystem self.test_subtypes = test_subtypes - def __call__(self, f): - fname = f.__name__ + def __call__(self, func=None, *, priority=None): + """Decorate the function + + Warning: Priority is still an experimental feature + """ + if func is None: + if priority is None: + raise ValueError( + "Must either provide a function to decorate, or set a priority" + ) + return MultiDispatchWithOptions(self, priority=priority) + + if priority is not None: + raise ValueError( + "Must either provide a function to decorate, or set a priority" + ) + + fname = func.__name__ try: tree = self.fname_to_tree[fname] except KeyError: - tree = self.fname_to_tree[fname] = TypeTree(fname, self.typesystem, self.test_subtypes) + tree = self.fname_to_tree[fname] = TypeTree( + fname, self.typesystem, self.test_subtypes + ) - tree.define_function(f) + tree.define_function(func) find_function_cached = tree.find_function_cached - @wraps(f) + @wraps(func) def dispatched_f(*args, **kw): return find_function_cached(args)(*args, **kw) @@ -45,10 +65,20 @@ def dispatched_f(*args, **kw): def __enter__(self): return self - def __exit__(self, exc_type, exc_value, exc_traceback): + def __exit__(self, exc_type, exc_value, exc_traceback): pass +@dataclass +class MultiDispatchWithOptions: + dispatch: MultiDispatch + priority: int + + def __call__(self, f): + f.__dispatch_priority__ = self.priority + return self.dispatch(f) + + class TypeNode: def __init__(self): self.follow_type = defaultdict(TypeNode) @@ -81,28 +111,36 @@ def get_arg_types(self, args): get_type = self.typesystem.get_type if self.test_subtypes: # TODO can be made more efficient - return tuple((a if i in self.test_subtypes else get_type(a)) - for i, a in enumerate(args)) + return tuple( + (a if i in self.test_subtypes else get_type(a)) + for i, a in enumerate(args) + ) return tuple(map(get_type, args)) def find_function(self, args): nodes = [self.root] for i, a in enumerate(args): - nodes = [n for node in nodes - for n in node.follow_arg(a, self.typesystem, test_subtype=i in self.test_subtypes)] + nodes = [ + n + for node in nodes + for n in node.follow_arg( + a, self.typesystem, test_subtype=i in self.test_subtypes + ) + ] funcs = [node.func for node in nodes if node.func] if len(funcs) == 0: - raise DispatchError(f"Function '{self.name}' not found for signature {self.get_arg_types(args)}") + raise DispatchError( + f"Function '{self.name}' not found for signature {self.get_arg_types(args)}" + ) elif len(funcs) > 1: - f, _sig = self.choose_most_specific_function(*funcs) + f, _sig = self.choose_most_specific_function(args, *funcs) else: - (f, _sig) ,= funcs + ((f, _sig),) = funcs return f - def find_function_cached(self, args): "Memoized version of find_function" sig = self.get_arg_types(args) @@ -113,7 +151,6 @@ def find_function_cached(self, args): self._cache[sig] = f return f - def define_function(self, f): for signature in get_func_signatures(self.typesystem, f): node = self.root @@ -122,11 +159,12 @@ def define_function(self, f): if node.func is not None: code_obj = node.func[0].__code__ - raise ValueError(f"Function {f.__name__} at {code_obj.co_filename}:{code_obj.co_firstlineno} matches existing signature: {signature}!") + raise ValueError( + f"Function {f.__name__} at {code_obj.co_filename}:{code_obj.co_firstlineno} matches existing signature: {signature}!" + ) node.func = f, signature - - def choose_most_specific_function(self, *funcs): + def choose_most_specific_function(self, args, *funcs): issubclass = self.typesystem.issubclass class IsSubclass: @@ -142,35 +180,47 @@ def __lt__(self, other): if all_eq(zipped_params): continue x = sorted(enumerate(zipped_params), key=IsSubclass) - ms_i, ms_t = x[0] # Most significant index and type - ms_set = {ms_i} # Init set of indexes of most significant params + ms_i, ms_t = x[0] # Most significant index and type + ms_set = {ms_i} # Init set of indexes of most significant params for i, t in x[1:]: if ms_t == t: - ms_set.add(i) # Add more indexes with the same type + # Add more indexes with the same type + ms_set.add(i) elif issubclass(t, ms_t) or not issubclass(ms_t, t): - # Cannot resolve ordering of these two types - n = funcs[0][0].__name__ - msg = f"Ambiguous dispatch in '{n}', argument #{arg_idx+1}: Unable to resolve the specificity of the types: \n\t- {t}\n\t- {ms_t}\n" - msg += '\nThis error occured because neither is a subclass of the other.' - msg += '\nRelevant functions:\n' - for f, sig in funcs: - c = f.__code__ - msg += f'\t- {c.co_filename}:{c.co_firstlineno} :: {sig}\n' - - raise DispatchError(msg) + # Possibly ambiguous. We might throw an error below + # TODO secondary candidates should still obscure less specific candidates + # by only considering the top match, we are ignoring them + ms_set.add(i) most_specific_per_param.append(ms_set) # Is there only one function that matches each and every parameter? most_specific = set.intersection(*most_specific_per_param) - if len(most_specific) != 1: - n = funcs[0][0].__name__ - msg = f"Ambiguous dispatch in '{n}': Unable to resolve the specificity of the functions" - msg += ''.join(f'\n\t- {n}{tuple(f[1])}' for f in funcs) - raise DispatchError(msg) + if len(most_specific) == 1: + (ms,) = most_specific + return funcs[ms] + + ambig_funcs = [funcs[i] for i in set.union(*most_specific_per_param)] + assert len(ambig_funcs) > 1 + p_ambig_funcs = [ + (getattr(f, "__dispatch_priority__", 0), f, params) + for f, params in ambig_funcs + ] + p_ambig_funcs.sort(key=itemgetter(0), reverse=True) + if p_ambig_funcs[0][0] > p_ambig_funcs[1][0]: + # If one item has a higher priority than all others, choose it + p, f, params = p_ambig_funcs[0] + return f, params + + # Could not resolve ambiguity. Throw error + n = funcs[0][0].__name__ + msg = f"Ambiguous dispatch in '{n}': Unable to resolve the specificity of the functions" + msg += "".join( + f"\n\t- {n}{tuple(params)} [priority={p}]" for p, f, params in p_ambig_funcs + ) + msg += f"\nFor arguments: {args}" + raise DispatchError(msg) - ms ,= most_specific - return funcs[ms] def all_eq(xs): a = xs[0] @@ -178,4 +228,3 @@ def all_eq(xs): if a != b: return False return True - diff --git a/runtype/pytypes.py b/runtype/pytypes.py index f2ef657..63ec386 100644 --- a/runtype/pytypes.py +++ b/runtype/pytypes.py @@ -14,7 +14,7 @@ from types import FrameType from .utils import ForwardRef -from .base_types import DataType, Validator, TypeMismatchError +from .base_types import DataType, Validator, TypeMismatchError, dp from . import base_types from . import datetime_parse @@ -129,12 +129,6 @@ class PythonDataType(DataType, PythonType): def __init__(self, kernel, supertypes={Any}): self.kernel = kernel - def __le__(self, other): - if isinstance(other, PythonDataType): - return issubclass(self.kernel, other.kernel) - - return NotImplemented - def validate_instance(self, obj, sampler=None): if not isinstance(obj, self.kernel): raise TypeMismatchError(obj, self) @@ -160,28 +154,7 @@ def cast_from(self, obj): return obj - class TupleType(PythonType): - def __le__(self, other): - # No superclasses or subclasses - if other is Any: - return True - - return isinstance(other, TupleType) - - def __ge__(self, other): - if isinstance(other, TupleEllipsisType): - return True - elif isinstance(other, TupleType): - return True - elif isinstance(other, DataType): - return False - elif isinstance(other, ProductType): - # Products are a tuple, but with length and types - return True - - return NotImplemented - def validate_instance(self, obj, sampler=None): if not isinstance(obj, tuple): raise TypeMismatchError(obj, self) @@ -199,25 +172,6 @@ class OneOf(PythonType): def __init__(self, values): self.values = values - def __le__(self, other): - if isinstance(other, OneOf): - return set(self.values) <= set(other.values) - elif isinstance(other, PythonType): - try: - for v in self.values: - other.validate_instance(v) - except TypeMismatchError: - return False - return True - return NotImplemented - - def __ge__(self, other): - if isinstance(other, OneOf): - return set(self.values) >= set(other.values) - elif isinstance(other, PythonType): - return False - return NotImplemented - def validate_instance(self, obj, sampler=None): tok = cv_type_checking.set(True) try: @@ -374,10 +328,6 @@ def __call__(self, min_length=None, max_length=None): return Constraint(self, predicates) - def __le__(self, other): - if isinstance(other, SequenceType): - return self <= other.base and self <= other.item - return super().__le__(other) class _DateTime(PythonDataType): @@ -613,3 +563,32 @@ def to_canon(self, t) -> PythonType: type_caster = TypeCaster() + + +@dp +def le(self: PythonDataType, other: PythonDataType): + return issubclass(self.kernel, other.kernel) +@dp +def le(self: TupleEllipsisType, other: TupleType): + return True +@dp +def le(self: ProductType, other: TupleType): + # Products are a tuple, but with length and specific types + return True + +@dp +def le(self: PythonDataType, other: SequenceType): + return self <= other.base and self <= other.item + +@dp +def le(self: OneOf, other: OneOf): + return set(self.values) <= set(other.values) + +@dp +def le(self: OneOf, other: PythonType): + try: + for v in self.values: + other.validate_instance(v) + except TypeMismatchError: + return False + return True \ No newline at end of file diff --git a/tests/test_types.py b/tests/test_types.py index 1b09e62..ebd18c5 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -48,16 +48,25 @@ def test_phantom(self): assert Int <= P[Int] assert P[Int] <= Int assert P[Int] <= P[Int] - assert P[Int] <= P assert not P <= P[Int] - assert Q[P] <= Q + assert P[Q] <= Q + assert P[Q] <= P + assert not P <= Q[Int] + + assert P[Q[Int]] <= P + assert P[Q[Int]] <= Q + assert P[Q[Int]] <= Int assert P[Q[Int]] <= P[Q] assert P[Q[Int]] <= P[Int] assert P[Q[Int]] <= Q[Int] - assert P[Q[Int]] <= Int + assert P[Q[Int]] <= P[Q[Int]] + assert Int <= P[Q[Int]] + assert P <= P + Int + assert not P <= Dict + assert not P <= Int + Dict def test_pytypes1(self): @@ -114,6 +123,23 @@ def test_constraint(self): assert not i.test_instance(9) assert not i.test_instance(13) + assert int_pair == int_pair + assert int_pair <= int_pair + assert int_pair >= int_pair + assert int_pair <= Any + assert Any >= int_pair + assert not int_pair <= Dict + assert not int_pair <= Int + assert not int_pair <= Int + Dict + assert not int_pair <= Tuple + + assert int_pair <= List + assert List >= int_pair + assert int_pair <= List[Int] + assert List[Int] >= int_pair + assert not int_pair <= List[String] + assert not List[String] >= int_pair + def test_typesystem(self): @@ -142,12 +168,15 @@ def test_pytypes2(self): assert not Tuple <= Int assert not Int <= Tuple - assert Literal([1]) <= Literal([1, 2]) - assert not Literal([1, 3]) <= Literal([1, 2]) - assert Literal([1, 3]) >= Literal([1]) + one = Literal([1]) + one_two = Literal([1, 2]) + one_three = Literal([1, 3]) + assert one <= one_two + assert not one_three <= one_two + assert one_three >= one assert not Literal([1]) <= Tuple - assert not Literal(1) >= Tuple - assert not Tuple <= Literal(1) + assert not Literal([1]) >= Tuple + assert not Tuple <= Literal([1]) Tuple.validate_instance((1, 2)) self.assertRaises(TypeError, Tuple.validate_instance, 1)