Skip to content

Commit

Permalink
Types: Converted all __eq__, __le__, and __ge__ methods into dispatch…
Browse files Browse the repository at this point in the history
…ed funcs
  • Loading branch information
erezsh committed Mar 2, 2024
1 parent 0322821 commit 0c0fc61
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 201 deletions.
242 changes: 100 additions & 142 deletions runtype/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from typing import Callable, Sequence, Optional, Union
from abc import ABC, abstractmethod

from .dispatch import MultiDispatch
from .typesystem import PythonBasic
dp = MultiDispatch(PythonBasic())

class RuntypeError(TypeError):
pass
Expand All @@ -30,9 +33,6 @@ def __add__(self, other: _Type):
def __mul__(self, other: _Type):
return ProductType.create((self, other))

@abstractmethod
def __le__(self, other: _Type):
return NotImplemented

class AnyType(Type):
"""Represents the Any type.
Expand All @@ -42,21 +42,6 @@ class AnyType(Type):
def __add__(self, other):
return self

def __ge__(self, other):
if isinstance(other, (type, Type)):
return True

return NotImplemented

def __le__(self, other):
if other is self: # Optimization
return True
elif isinstance(other, (type, Type)):
if not isinstance(other, SumType):
return False

return NotImplemented

def __repr__(self):
return 'Any'

Expand All @@ -69,15 +54,6 @@ class DataType(Type):
Example of possible data-types: int, float, text
"""
def __le__(self, other):
if isinstance(other, DataType):
return self == other

return super().__le__(other)

def __ge__(self, other):
return NotImplemented


class SumType(Type):
"""Implements a sum type, i.e. a disjoint union of a set of types.
Expand Down Expand Up @@ -106,23 +82,10 @@ def create(cls, types):
def __repr__(self):
return '(%s)' % '+'.join(map(repr, self.types))

def __le__(self, other):
return all(t <= other for t in self.types)

def __ge__(self, other):
return any(other <= t for t in self.types)

def __eq__(self, other):
if isinstance(other, SumType):
return self.types == other.types

return NotImplemented

def __hash__(self):
return hash(frozenset(self.types))



class ProductType(Type):
"""Implements a product type, i.e. a tuple of types.
"""
Expand All @@ -148,29 +111,6 @@ def __repr__(self):
def __hash__(self):
return hash(self.types)

def __eq__(self, other):
if isinstance(other, ProductType):
return self.types == other.types

return NotImplemented

def __le__(self, other):
if isinstance(other, ProductType):
if len(self.types) != len(other.types):
return False

return all(t1<=t2 for t1, t2 in zip(self.types, other.types))
elif isinstance(other, DataType):
return False

return NotImplemented

def __ge__(self, other):
if isinstance(other, DataType):
return False

return NotImplemented


class ContainerType(DataType):
"""Base class for containers, such as generics.
Expand Down Expand Up @@ -204,63 +144,15 @@ def __repr__(self):
def __getitem__(self, item):
return type(self)(self, item)

def __eq__(self, other):
if isinstance(other, GenericType):
return self.base == other.base and self.item == other.item

elif isinstance(other, Type):
return Any <= self.item and self.base == other

return NotImplemented


def __le__(self, other):
if isinstance(other, GenericType):
return self.base <= other.base and self.item <= other.item

elif isinstance(other, DataType):
return self.base <= other

return NotImplemented

def __ge__(self, other):
if isinstance(other, type(self)):
return self.base >= other.base and self.item >= other.item

elif isinstance(other, DataType):
return self.base >= other

return NotImplemented

def __hash__(self):
return hash((self.base, self.item))


class PhantomType(Type):
"""Implements a base for phantom types.
"""
def __getitem__(self, other):
return PhantomGenericType(self, other)

def __le__(self, other):
if isinstance(other, PhantomType):
return self == other

elif not isinstance(other, PhantomGenericType):
return False

return NotImplemented


def __ge__(self, other):
if isinstance(other, PhantomType):
return self == other

elif not isinstance(other, PhantomGenericType):
return False

return NotImplemented


class PhantomGenericType(Type):
"""Implements a generic phantom type, for carrying metadata within the type signature.
Expand All @@ -271,32 +163,6 @@ def __init__(self, base, item=Any):
self.base = base
self.item = item

def __le__(self, other):
if isinstance(other, PhantomType):
return self.base <= other or self.item <= other

elif isinstance(other, PhantomGenericType):
return (self.base <= other.base and self.item <= other.item) or self.item <= other

elif isinstance(other, DataType):
return self.item <= other

return NotImplemented

def __eq__(self, other):
if isinstance(other, PhantomGenericType):
return self.base == other.base and self.item == other.base

return NotImplemented

def __ge__(self, other):
if isinstance(other, PhantomType):
return False

elif isinstance(other, DataType):
return other <= self.item

return NotImplemented

SamplerType = Callable[[Sequence], Sequence]

Expand Down Expand Up @@ -326,7 +192,7 @@ def test_instance(self, obj, sampler=None):
return False


class Constraint(Validator, PhantomType):
class Constraint(Validator, Type):
"""Defines a constraint, which activates during validation.
"""

Expand All @@ -342,9 +208,101 @@ def validate_instance(self, inst, sampler=None):
if not p(inst):
raise TypeMismatchError(inst, self)

def __ge__(self, other):
# Arbitrary predicates prevent it from being a superclass

@dp
def le(self, other):
return NotImplemented
@dp(priority=-1)
def le(self: Type, other: Type):
return self == other
@dp
def ge(self, other):
return le(other, self)
@dp
def eq(self, other):
return NotImplemented

@dp
def eq(self: SumType, other: SumType):
return self.types == other.types
@dp
def eq(self: ProductType, other: ProductType):
return self.types == other.types
@dp
def eq(self: GenericType, other: GenericType):
return self.base == other.base and self.item == other.item
@dp
def eq(self: GenericType, other: Type):
return Any <= self.item and self.base == other
@dp
def eq(self: PhantomGenericType, other: PhantomGenericType):
return self.base == other.base and self.item == other.base


# Special cases (AnyType, SumType)
@dp(priority=100)
def le(self: Type, other: AnyType):
return True
@dp
def le(self: type, other: AnyType):
return True


@dp(priority=51)
def le(self: SumType, other: Type):
return all(t <= other for t in self.types)

@dp(priority=50)
def le(self: Type, other: SumType):
return any(self <= t for t in other.types)

@dp
def le(self: ProductType, other: ProductType):
if len(self.types) != len(other.types):
return False

def __le__(self, other):
return self.type <= other
return all(t1<=t2 for t1, t2 in zip(self.types, other.types))

# Generics

@dp
def le(self: GenericType, other: GenericType):
return self.base <= other.base and self.item <= other.item

@dp
def le(self: GenericType, other: Type):
return self.base <= other
@dp
def le(self: Type, other: GenericType):
return other.item is Any and self <= other.base


# Phantom Types

@dp(priority=1)
def le(self: Type, other: PhantomGenericType):
return self <= other.item

@dp
def le(self: PhantomGenericType, other: Type):
return self.item <= other

@dp
def le(self: PhantomGenericType, other: PhantomType):
return self.base <= other or self.item <= other

# Constraints

@dp
def le(self: Constraint, other: Constraint):
# Arbitrary predicates prevent it from being a superclass
return self == other

@dp(priority=1)
def le(self: Constraint, other: Type):
return self.type <= other


Type.__eq__ = eq
Type.__le__ = le
Type.__ge__ = ge
Loading

0 comments on commit 0c0fc61

Please sign in to comment.