Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Variance. List and Dict are now invariant! #54

Merged
merged 3 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions runtype/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
from typing import Callable, Sequence, Optional, Union
from abc import ABC, abstractmethod
from enum import Enum, auto

from .dispatch import MultiDispatch
from .typesystem import PythonBasic
Expand Down Expand Up @@ -131,11 +132,19 @@ def __hash__(self):


class ContainerType(DataType):
"""Base class for containers, such as generics."""
"""Base class for containers, such as generics.

"""

@abstractmethod
def __getitem__(self, other):
return GenericType(self, other)
...


class Variance(Enum):
Covariant = auto()
Contravariant = auto()
Invariant = auto()

class GenericType(ContainerType):
"""Implements a generic type. i.e. a container for items of a specific type.
Expand All @@ -145,8 +154,9 @@ class GenericType(ContainerType):

base: Type
item: Union[type, Type]
variance: Variance

def __init__(self, base: Type, item: Union[type, Type] = Any):
def __init__(self, base: Type, item: Union[type, Type], variance):
assert isinstance(item, (Type, type)), item
if isinstance(base, GenericType):
if not item <= base.item:
Expand All @@ -157,14 +167,15 @@ def __init__(self, base: Type, item: Union[type, Type] = Any):

self.base = base
self.item = item
self.variance = variance

def __repr__(self):
if self.item is Any:
return str(self.base)
return "%r[%r]" % (self.base, self.item)

def __getitem__(self, item):
return type(self)(self, item)
return type(self)(self, item, self.variance)

def __hash__(self):
return hash((self.base, self.item))
Expand Down Expand Up @@ -328,7 +339,13 @@ def le(self: ProductType, other: ProductType):

@dp
def le(self: GenericType, other: GenericType):
return self.base <= other.base and self.item <= other.item
if self.variance == Variance.Covariant:
return self.base <= other.base and self.item <= other.item
elif self.variance == Variance.Contravariant:
return self.base <= other.base and self.item >= other.item
elif self.variance == Variance.Invariant:
return self.base <= other.base and self.item == other.item
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of raising a generic RuntimeError, it's better to throw a more specific error that details what the issue is (for instance, 'Unexpected value for variance')

raise RuntimeError()

@dp
def le(self: GenericType, other: Type):
Expand Down
149 changes: 78 additions & 71 deletions runtype/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Python Types - contains an implementation of a Runtype type system that is parallel to the Python type system.
"""

import typing as t
import contextvars
import types
from abc import abstractmethod, ABC
Expand All @@ -14,7 +15,7 @@
from types import FrameType

from .utils import ForwardRef
from .base_types import DataType, Validator, TypeMismatchError, dp
from .base_types import DataType, Validator, TypeMismatchError, dp, Variance
from . import base_types
from . import datetime_parse

Expand Down Expand Up @@ -215,96 +216,102 @@ def cast_from(self, obj):


class GenericType(base_types.GenericType, PythonType):
base: PythonType
base: PythonDataType
item: PythonType

def __init__(self, base: PythonType, item=Any):
return super().__init__(base, item)
def __init__(self, base: PythonType, item: PythonType=Any, variance: Variance = Variance.Covariant):
return super().__init__(base, item, variance)


class SequenceType(GenericType):
class GenericContainerType(GenericType):
def validate_instance(self, obj, sampler=None):
self.base.validate_instance(obj)
if not self.accepts_any:
self.validate_instance_items(obj, sampler)

def test_instance(self, obj, sampler=None):
if not self.base.test_instance(obj):
return False
if self.item is not Any:
if sampler:
obj = sampler(obj)
for item in obj:
if not self.item.test_instance(item, sampler):
return False
return True

def validate_instance(self, obj, sampler=None):
self.base.validate_instance(obj)
if self.item is not Any:
if sampler:
obj = sampler(obj)
for item in obj:
self.item.validate_instance(item, sampler)
return self.accepts_any or self.test_instance_items(obj, sampler)

def cast_from(self, obj):
# Optimize for List[Any] and empty sequences
if self.item is Any or not obj:
# Already a list?
if self.base.test_instance(obj):
return obj
# Make sure it's a list
return list(obj)
# Optimize for X[Any] and empty containers
if not obj or self.accepts_any:
# Make sure it's the right type
return obj if self.base.test_instance(obj) else self.base.kernel(obj)

return self.cast_from_items(obj)

@property
@abstractmethod
def accepts_any(self) -> bool:
...

@abstractmethod
def validate_instance_items(self, items: t.Iterable, sampler):
...

@abstractmethod
def test_instance_items(self, items: t.Iterable, sampler) -> bool:
...

@abstractmethod
def cast_from_items(self, obj):
...

class SequenceType(GenericContainerType):
@property
def accepts_any(self):
return self.item is Any

def validate_instance_items(self, obj: t.Sequence, sampler):
for item in sampler(obj) if sampler else obj:
self.item.validate_instance(item, sampler)

def test_instance_items(self, obj: t.Sequence, sampler) -> bool:
return all(
self.item.test_instance(item, sampler)
for item in (sampler(obj) if sampler else obj)
)

def cast_from_items(self, obj: t.Sequence):
# Recursively cast each item
return [self.item.cast_from(item) for item in obj]
return self.base.kernel(self.item.cast_from(item) for item in obj)


class DictType(GenericType):
class DictType(GenericContainerType):
item: ProductType

def __init__(self, base: PythonType, item=Any*Any):
super().__init__(base)
def __init__(self, base: PythonType, item=Any*Any, variance: Variance = Variance.Covariant):
super().__init__(base, variance=variance)
if isinstance(item, tuple):
assert len(item) == 2
item = ProductType([type_caster.to_canon(x) for x in item])
self.item = item

def validate_instance(self, obj, sampler=None):
self.base.validate_instance(obj)
if self.item is not Any:
kt, vt = self.item.types
items = obj.items()
if sampler:
items = sampler(items)
for k, v in items:
kt.validate_instance(k, sampler)
vt.validate_instance(v, sampler)
@property
def accepts_any(self):
return self.item is Any or self.item == Any*Any

def test_instance(self, obj, sampler=None):
if not self.base.test_instance(obj):
return False
if self.item is not Any:
kt, vt = self.item.types
items = obj.items()
if sampler:
items = sampler(items)
for k, v in items:
if not kt.test_instance(k, sampler):
return False
if not vt.test_instance(v, sampler):
return False
return True
def validate_instance_items(self, obj: t.Mapping, sampler):
assert isinstance(self.item, base_types.ProductType)
kt, vt = self.item.types
for k, v in sampler(obj.items()) if sampler else obj.items():
kt.validate_instance(k, sampler)
vt.validate_instance(v, sampler)

def test_instance_items(self, obj: t.Mapping, sampler) -> bool:
assert isinstance(self.item, base_types.ProductType)
kt, vt = self.item.types
return all(
kt.test_instance(k, sampler) and vt.test_instance(v, sampler)
for k, v in (sampler(obj.items()) if sampler else obj.items())
)

def __getitem__(self, item):
assert self.item == Any*Any
return type(self)(self.base, item)

def cast_from(self, obj):
# Optimize for Dict[Any] and empty dicts
if self.item is Any or not obj:
# Already a dict?
if self.base.test_instance(obj):
return obj
# Make sure it's a dict
return dict(obj)
return type(self)(self.base, item, self.variance)

def cast_from_items(self, obj: t.Mapping):
# Must already be a dict
self.base.validate_instance(obj)

Expand All @@ -324,15 +331,15 @@ def test_instance(self, obj, sampler=None):

Object = PythonDataType(object)
Iter = SequenceType(PythonDataType(collections.abc.Iterable))
List = SequenceType(PythonDataType(list))
Sequence = SequenceType(PythonDataType(abc.Sequence))
MutableSequence = SequenceType(PythonDataType(abc.MutableSequence))
Set = SequenceType(PythonDataType(set))
List = SequenceType(PythonDataType(list), variance=Variance.Invariant)
MutableSequence = SequenceType(PythonDataType(abc.MutableSequence), variance=Variance.Invariant)
Set = SequenceType(PythonDataType(set), variance=Variance.Invariant)
FrozenSet = SequenceType(PythonDataType(frozenset))
AbstractSet = SequenceType(PythonDataType(abc.Set))
Dict = DictType(PythonDataType(dict))
Mapping = DictType(PythonDataType(abc.Mapping))
MutableMapping = DictType(PythonDataType(abc.MutableMapping))
Dict = DictType(PythonDataType(dict), variance=Variance.Invariant)
MutableMapping = DictType(PythonDataType(abc.MutableMapping), variance=Variance.Invariant)
Tuple = TupleType()
TupleEllipsis = TupleEllipsisType(PythonDataType(tuple))
# Float = PythonDataType(float)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_basic(self):
assert isa(lambda:0, Callable)
assert not isa(1, Callable)

assert issubclass(List[int], list)
assert not issubclass(List[int], list) # invariant
assert issubclass(Tuple[int], tuple)
assert issubclass(Tuple[int, int], tuple)
assert not issubclass(tuple, Tuple[int])
Expand Down Expand Up @@ -96,7 +96,8 @@ def test_basic(self):


def test_issubclass(self):
assert issubclass(List[Tuple], list)
assert not issubclass(List[Tuple], list) # invariant
assert issubclass(Sequence[Tuple], Sequence)

if hasattr(typing, 'Annotated'):
a = typing.Annotated[int, range(1, 10)]
Expand Down
31 changes: 21 additions & 10 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
import typing
import collections.abc as cabc

from runtype.base_types import DataType, ContainerType, PhantomType
from runtype.pytypes import type_caster, List, Dict, Int, Any, Constraint, String, Tuple, Iter, Literal, NoneType
from runtype.base_types import DataType, GenericType, PhantomType, Variance
from runtype.pytypes import type_caster, List, Dict, Int, Any, Constraint, String, Tuple, Iter, Literal, NoneType, Sequence, Mapping
from runtype.typesystem import TypeSystem


class TestTypes(TestCase):
def test_basic_types(self):
Int = DataType()
Str = DataType()
Array = ContainerType()
Array = GenericType(DataType(), Any, Variance.Covariant)

assert Int == Int
assert Int != Str
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_pytypes1(self):
assert List[Any] == List

def test_constraint(self):
int_pair = Constraint(typing.List[int], [lambda a: len(a) == 2])
int_pair = Constraint(typing.Sequence[int], [lambda a: len(a) == 2])
assert int_pair.test_instance([1,2])
assert not int_pair.test_instance([1,2,3])
assert not int_pair.test_instance([1,'a'])
Expand Down Expand Up @@ -133,12 +133,12 @@ def test_constraint(self):
assert not int_pair <= Int + Dict
assert not int_pair <= Tuple

assert int_pair <= List
assert List >= int_pair
assert int_pair <= List[Int]
assert List[Int] >= int_pair
assert not int_pair <= List[String]
assert not List[String] >= int_pair
assert int_pair <= Sequence
assert Sequence >= int_pair
assert int_pair <= Sequence[Int]
assert Sequence[Int] >= int_pair
assert not int_pair <= Sequence[String]
assert not Sequence[String] >= int_pair



Expand Down Expand Up @@ -271,6 +271,17 @@ def test_any(self):
assert Any + Int <= Any
assert Any + NoneType <= Any

def test_invariance(self):
assert List <= Sequence
assert not List[List] <= List[Sequence]
assert not List[Sequence] <= List[List]

assert Dict <= Mapping
assert Dict[Int, Int] <= Mapping[Int, Int]
assert Mapping[Int, List] <= Mapping[Int, Sequence]
assert not Dict[Int, List] <= Dict[Int, Sequence]
assert not Mapping[Int, Sequence] <= Mapping[Int, List]
assert not Dict[Int, Sequence] <= Dict[Int, List]


if __name__ == '__main__':
Expand Down
Loading