diff --git a/runtype/__init__.py b/runtype/__init__.py index 95ba493..76cdfbd 100644 --- a/runtype/__init__.py +++ b/runtype/__init__.py @@ -1,6 +1,6 @@ from .dataclass import dataclass from .dispatch import DispatchError, MultiDispatch -from .validation import PythonTyping, TypeSystem, TypeMismatchError, assert_isa, isa, issubclass, validate_func +from .validation import PythonTyping, TypeSystem, TypeMismatchError, assert_isa, isa, issubclass, validate_func, is_subtype from .pytypes import Constraint, String, Int __version__ = "0.3.1" diff --git a/runtype/pytypes.py b/runtype/pytypes.py index 488f10b..6607d03 100644 --- a/runtype/pytypes.py +++ b/runtype/pytypes.py @@ -166,6 +166,8 @@ def __init__(self, values): self.values = values def __le__(self, other): + if isinstance(other, OneOf): + return set(self.values) <= set(other.values) return NotImplemented def validate_instance(self, obj, sampler=None): diff --git a/tests/test_basic.py b/tests/test_basic.py index 15f1e28..fe599a5 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -11,7 +11,7 @@ import logging logging.basicConfig(level=logging.INFO) -from runtype import Dispatch, DispatchError, dataclass, isa, issubclass, assert_isa, String, Int, validate_func +from runtype import Dispatch, DispatchError, dataclass, isa, is_subtype, issubclass, assert_isa, String, Int, validate_func from runtype.dataclass import Configuration @@ -106,6 +106,8 @@ def test_py38(self): assert isa('a', typing.Literal['a', 'b']) assert not isa('c', typing.Literal['a', 'b']) + assert is_subtype(typing.Literal[1], typing.Literal[1,2]) + def test_validate_func(self): @validate_func def f(a: int, b: str, c: List[int] = []):