From 1fed58a66f62ab92cb2967ade74e5286cdca3604 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Tue, 23 May 2023 20:47:12 +0100 Subject: [PATCH 1/6] Fix the bug --- src/test_typing_extensions.py | 108 ++++++++++--- src/typing_extensions.py | 279 ++++++++++++++++++++-------------- 2 files changed, 258 insertions(+), 129 deletions(-) diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 882b4500..6f5a42db 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -2613,6 +2613,52 @@ class CustomProtocolWithoutInitB(Protocol): self.assertEqual(CustomProtocolWithoutInitA.__init__, CustomProtocolWithoutInitB.__init__) + def test_protocol_generic_over_paramspec(self): + P = ParamSpec("P") + T = TypeVar("T") + T2 = TypeVar("T2") + + class MemoizedFunc(Protocol[P, T, T2]): + cache: typing.Dict[T2, T] + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ... + + self.assertEqual(MemoizedFunc.__parameters__, (P, T, T2)) + self.assertTrue(MemoizedFunc._is_protocol) + + with self.assertRaisesRegex(TypeError, "Too few arguments"): + MemoizedFunc[[int, str, str]] + + X = MemoizedFunc[[int, str, str], T, T2] + self.assertEqual(X.__parameters__, (T, T2)) + self.assertEqual(X.__args__, ((int, str, str), T, T2)) + + Y = X[bytes, memoryview] + self.assertEqual(Y.__parameters__, ()) + self.assertEqual(Y.__args__, ((int, str, str), bytes, memoryview)) + + def test_protocol_generic_over_typevartuple(self): + Ts = TypeVarTuple("Ts") + T = TypeVar("T") + T2 = TypeVar("T2") + + class MemoizedFunc(Protocol[Unpack[Ts], T, T2]): + cache: typing.Dict[T2, T] + def __call__(self, *args: Unpack[Ts]) -> T: ... + + self.assertEqual(MemoizedFunc.__parameters__, (Ts, T, T2)) + self.assertTrue(MemoizedFunc._is_protocol) + + with self.assertRaisesRegex(TypeError, "Too few arguments"): + MemoizedFunc[int] + + X = MemoizedFunc[int, T, T2] + self.assertEqual(X.__parameters__, (T, T2)) + self.assertEqual(X.__args__, (int, T, T2)) + + Y = X[bytes, memoryview] + self.assertEqual(Y.__parameters__, ()) + self.assertEqual(Y.__args__, (int, bytes, memoryview)) + class Point2DGeneric(Generic[T], TypedDict): a: T @@ -3402,13 +3448,18 @@ def test_user_generics(self): class X(Generic[T, P]): pass - G1 = X[int, P_2] - self.assertEqual(G1.__args__, (int, P_2)) - self.assertEqual(G1.__parameters__, (P_2,)) + class Y(Protocol[T, P]): + pass + + for klass in X, Y: + with self.subTest(klass=klass.__name__): + G1 = klass[int, P_2] + self.assertEqual(G1.__args__, (int, P_2)) + self.assertEqual(G1.__parameters__, (P_2,)) - G2 = X[int, Concatenate[int, P_2]] - self.assertEqual(G2.__args__, (int, Concatenate[int, P_2])) - self.assertEqual(G2.__parameters__, (P_2,)) + G2 = klass[int, Concatenate[int, P_2]] + self.assertEqual(G2.__args__, (int, Concatenate[int, P_2])) + self.assertEqual(G2.__parameters__, (P_2,)) # The following are some valid uses cases in PEP 612 that don't work: # These do not work in 3.9, _type_check blocks the list and ellipsis. @@ -3421,6 +3472,9 @@ class X(Generic[T, P]): class Z(Generic[P]): pass + class ProtoZ(Protocol[P]): + pass + def test_pickle(self): global P, P_co, P_contra, P_default P = ParamSpec('P') @@ -3727,31 +3781,47 @@ def test_concatenation(self): self.assertEqual(Tuple[int, Unpack[Xs], str].__args__, (int, Unpack[Xs], str)) class C(Generic[Unpack[Xs]]): pass - self.assertEqual(C[int, Unpack[Xs]].__args__, (int, Unpack[Xs])) - self.assertEqual(C[Unpack[Xs], int].__args__, (Unpack[Xs], int)) - self.assertEqual(C[int, Unpack[Xs], str].__args__, - (int, Unpack[Xs], str)) + class D(Protocol[Unpack[Xs]]): pass + for klass in C, D: + with self.subTest(klass=klass.__name__): + self.assertEqual(klass[int, Unpack[Xs]].__args__, (int, Unpack[Xs])) + self.assertEqual(klass[Unpack[Xs], int].__args__, (Unpack[Xs], int)) + self.assertEqual(klass[int, Unpack[Xs], str].__args__, + (int, Unpack[Xs], str)) def test_class(self): Ts = TypeVarTuple('Ts') class C(Generic[Unpack[Ts]]): pass - self.assertEqual(C[int].__args__, (int,)) - self.assertEqual(C[int, str].__args__, (int, str)) + class D(Protocol[Unpack[Ts]]): pass + + for klass in C, D: + with self.subTest(klass=klass.__name__): + self.assertEqual(klass[int].__args__, (int,)) + self.assertEqual(klass[int, str].__args__, (int, str)) with self.assertRaises(TypeError): class C(Generic[Unpack[Ts], int]): pass + with self.assertRaises(TypeError): + class D(Protocol[Unpack[Ts], int]): pass + T1 = TypeVar('T') T2 = TypeVar('T') class C(Generic[T1, T2, Unpack[Ts]]): pass - self.assertEqual(C[int, str].__args__, (int, str)) - self.assertEqual(C[int, str, float].__args__, (int, str, float)) - self.assertEqual(C[int, str, float, bool].__args__, (int, str, float, bool)) - # TODO This should probably also fail on 3.11, pending changes to CPython. - if not TYPING_3_11_0: - with self.assertRaises(TypeError): - C[int] + class D(Protocol[T1, T2, Unpack[Ts]]): pass + for klass in C, D: + with self.subTest(klass=klass.__name__): + self.assertEqual(klass[int, str].__args__, (int, str)) + self.assertEqual(klass[int, str, float].__args__, (int, str, float)) + self.assertEqual( + klass[int, str, float, bool].__args__, (int, str, float, bool) + ) + # TODO This should probably also fail on 3.11, + # pending changes to CPython. + if not TYPING_3_11_0: + with self.assertRaises(TypeError): + klass[int] class TypeVarTupleTests(BaseTestCase): diff --git a/src/typing_extensions.py b/src/typing_extensions.py index e2b2dc6a..17eb8606 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -612,131 +612,190 @@ def __instancecheck__(cls, instance): return False - class Protocol(metaclass=_ProtocolMeta): - # There is quite a lot of overlapping code with typing.Generic. - # Unfortunately it is hard to avoid this while these live in two different - # modules. The duplicated code will be removed when Protocol is moved to typing. - """Base class for protocol classes. Protocol classes are defined as:: - - class Proto(Protocol): - def meth(self) -> int: - ... + def __eq__(cls, other): + # Hack so that typing.Generic.__class_getitem__ + # treats typing_extensions.Protocol + # as equivalent to typing.Protocol on Python 3.8+ + if super().__eq__(other) is True: + return True + return ( + cls is Protocol and other is getattr(typing, "Protocol", object()) + ) - Such classes are primarily used with static type checkers that recognize - structural subtyping (static duck-typing), for example:: + # This has to be defined, or the abc-module cache + # complains about classes with this metaclass being unhashable, + # if we define only __eq__! + def __hash__(cls) -> int: + return type.__hash__(cls) - class C: - def meth(self) -> int: - return 0 + @classmethod + def _proto_hook(cls, other): + if not cls.__dict__.get('_is_protocol', False): + return NotImplemented - def func(x: Proto) -> int: - return x.meth() + # First, perform various sanity checks. + if not getattr(cls, '_is_runtime_protocol', False): + if _allow_reckless_class_checks(): + return NotImplemented + raise TypeError("Instance and class checks can only be used with" + " @runtime_checkable protocols") + + if not isinstance(other, type): + # Same error message as for issubclass(1, int). + raise TypeError('issubclass() arg 1 must be a class') + + # Second, perform the actual structural compatibility check. + for attr in cls.__protocol_attrs__: + for base in other.__mro__: + # Check if the members appears in the class dictionary... + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break - func(C()) # Passes static type check + # ...or in annotations, if it is a sub-protocol. + annotations = getattr(base, '__annotations__', {}) + if ( + isinstance(annotations, collections.abc.Mapping) + and attr in annotations + and issubclass(other, (typing.Generic, _ProtocolMeta)) + and other._is_protocol + ): + break + else: + return NotImplemented + return True - See PEP 544 for details. Protocol classes decorated with - @typing_extensions.runtime act as simple-minded runtime protocol that checks - only the presence of given attributes, ignoring their type signatures. + def _check_proto_bases(cls): + for base in cls.__bases__: + if not (base in (object, typing.Generic) or + base.__module__ in _PROTO_ALLOWLIST and + base.__name__ in _PROTO_ALLOWLIST[base.__module__] or + isinstance(base, _ProtocolMeta) and base._is_protocol): + raise TypeError('Protocols can only inherit from other' + f' protocols, got {repr(base)}') - Protocol classes can be generic, they are defined as:: + if sys.version_info >= (3, 8): + class Protocol(typing.Generic, metaclass=_ProtocolMeta): + __doc__ = typing.Protocol.__doc__ + __slots__ = () + _is_protocol = True + _is_runtime_protocol = False - class GenProto(Protocol[T]): - def meth(self) -> T: - ... - """ - __slots__ = () - _is_protocol = True - _is_runtime_protocol = False + def __init_subclass__(cls, *args, **kwargs): + super().__init_subclass__(*args, **kwargs) - def __new__(cls, *args, **kwds): - if cls is Protocol: - raise TypeError("Type Protocol cannot be instantiated; " - "it can only be used as a base class") - return super().__new__(cls) + # Determine if this is a protocol or a concrete subclass. + if not cls.__dict__.get('_is_protocol', False): + cls._is_protocol = any(b is Protocol for b in cls.__bases__) - @typing._tp_cache - def __class_getitem__(cls, params): - if not isinstance(params, tuple): - params = (params,) - if not params and cls is not typing.Tuple: - raise TypeError( - f"Parameter list to {cls.__qualname__}[...] cannot be empty") - msg = "Parameters to generic types must be types." - params = tuple(typing._type_check(p, msg) for p in params) - if cls is Protocol: - # Generic can only be subscripted with unique type variables. - if not all(isinstance(p, typing.TypeVar) for p in params): - i = 0 - while isinstance(params[i], typing.TypeVar): - i += 1 - raise TypeError( - "Parameters to Protocol[...] must all be type variables." - f" Parameter {i + 1} is {params[i]}") - if len(set(params)) != len(params): - raise TypeError( - "Parameters to Protocol[...] must all be unique") - else: - # Subscripting a regular Generic subclass. - _check_generic(cls, params, len(cls.__parameters__)) - return typing._GenericAlias(cls, params) + # Set (or override) the protocol subclass hook. + if '__subclasshook__' not in cls.__dict__: + cls.__subclasshook__ = _proto_hook - def __init_subclass__(cls, *args, **kwargs): - if '__orig_bases__' in cls.__dict__: - error = typing.Generic in cls.__orig_bases__ - else: - error = typing.Generic in cls.__bases__ - if error: - raise TypeError("Cannot inherit from plain Generic") - _maybe_adjust_parameters(cls) + # We have nothing more to do for non-protocols... + if not cls._is_protocol: + return - # Determine if this is a protocol or a concrete subclass. - if not cls.__dict__.get('_is_protocol', None): - cls._is_protocol = any(b is Protocol for b in cls.__bases__) + # ... otherwise check consistency of bases, and prohibit instantiation. + _check_proto_bases(cls) + if cls.__init__ is Protocol.__init__: + cls.__init__ = _no_init - # Set (or override) the protocol subclass hook. - def _proto_hook(other): + else: + class Protocol(metaclass=_ProtocolMeta): + # There is quite a lot of overlapping code with typing.Generic. + # Unfortunately it is hard to avoid this on Python <3.8, + # as the typing module on Python 3.7 doesn't let us subclass typing.Generic! + """Base class for protocol classes. Protocol classes are defined as:: + + class Proto(Protocol): + def meth(self) -> int: + ... + + Such classes are primarily used with static type checkers that recognize + structural subtyping (static duck-typing), for example:: + + class C: + def meth(self) -> int: + return 0 + + def func(x: Proto) -> int: + return x.meth() + + func(C()) # Passes static type check + + See PEP 544 for details. Protocol classes decorated with + @typing_extensions.runtime act as simple-minded runtime protocol that checks + only the presence of given attributes, ignoring their type signatures. + + Protocol classes can be generic, they are defined as:: + + class GenProto(Protocol[T]): + def meth(self) -> T: + ... + """ + __slots__ = () + _is_protocol = True + _is_runtime_protocol = False + + def __new__(cls, *args, **kwds): + if cls is Protocol: + raise TypeError("Type Protocol cannot be instantiated; " + "it can only be used as a base class") + return super().__new__(cls) + + @typing._tp_cache + def __class_getitem__(cls, params): + if not isinstance(params, tuple): + params = (params,) + if not params and cls is not typing.Tuple: + raise TypeError( + f"Parameter list to {cls.__qualname__}[...] cannot be empty") + msg = "Parameters to generic types must be types." + params = tuple(typing._type_check(p, msg) for p in params) + if cls is Protocol: + # Generic can only be subscripted with unique type variables. + if not all(isinstance(p, typing.TypeVar) for p in params): + i = 0 + while isinstance(params[i], typing.TypeVar): + i += 1 + raise TypeError( + "Parameters to Protocol[...] must all be type variables." + f" Parameter {i + 1} is {params[i]}") + if len(set(params)) != len(params): + raise TypeError( + "Parameters to Protocol[...] must all be unique") + else: + # Subscripting a regular Generic subclass. + _check_generic(cls, params, len(cls.__parameters__)) + return typing._GenericAlias(cls, params) + + def __init_subclass__(cls, *args, **kwargs): + if '__orig_bases__' in cls.__dict__: + error = typing.Generic in cls.__orig_bases__ + else: + error = typing.Generic in cls.__bases__ + if error: + raise TypeError("Cannot inherit from plain Generic") + _maybe_adjust_parameters(cls) + + # Determine if this is a protocol or a concrete subclass. if not cls.__dict__.get('_is_protocol', None): - return NotImplemented - if not getattr(cls, '_is_runtime_protocol', False): - if _allow_reckless_class_checks(): - return NotImplemented - raise TypeError("Instance and class checks can only be used with" - " @runtime_checkable protocols") - if not isinstance(other, type): - # Same error as for issubclass(1, int) - raise TypeError('issubclass() arg 1 must be a class') - for attr in cls.__protocol_attrs__: - for base in other.__mro__: - if attr in base.__dict__: - if base.__dict__[attr] is None: - return NotImplemented - break - annotations = getattr(base, '__annotations__', {}) - if (isinstance(annotations, typing.Mapping) and - attr in annotations and - isinstance(other, _ProtocolMeta) and - other._is_protocol): - break - else: - return NotImplemented - return True - if '__subclasshook__' not in cls.__dict__: - cls.__subclasshook__ = _proto_hook + cls._is_protocol = any(b is Protocol for b in cls.__bases__) - # We have nothing more to do for non-protocols. - if not cls._is_protocol: - return + # Set (or override) the protocol subclass hook. + if '__subclasshook__' not in cls.__dict__: + cls.__subclasshook__ = _proto_hook + + # We have nothing more to do for non-protocols. + if not cls._is_protocol: + return - # Check consistency of bases. - for base in cls.__bases__: - if not (base in (object, typing.Generic) or - base.__module__ in _PROTO_ALLOWLIST and - base.__name__ in _PROTO_ALLOWLIST[base.__module__] or - isinstance(base, _ProtocolMeta) and base._is_protocol): - raise TypeError('Protocols can only inherit from other' - f' protocols, got {repr(base)}') - if cls.__init__ is Protocol.__init__: - cls.__init__ = _no_init + # Check consistency of bases. + _check_proto_bases(cls) + if cls.__init__ is Protocol.__init__: + cls.__init__ = _no_init def runtime_checkable(cls): """Mark a protocol class as a runtime protocol, so that it From 1c74866953b811be363197deef1d7c7d44250f9e Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Tue, 23 May 2023 20:51:46 +0100 Subject: [PATCH 2/6] Docs --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a8bddeb..3a9cba7d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ - Change deprecated `@runtime` to formal API `@runtime_checkable` in the error message. Patch by Xuehai Pan. +- Fix regression in 4.6.0 where attempting to define a `Protocol` that was + generic over a `ParamSpec` or a `TypeVarTuple` would cause `TypeError` to be + raised. Patch by Alex Waygood. # Release 4.6.0 (May 22, 2023) From e233aebf7ea57bac54a87ed3b0ea8c4cef8a70a3 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Tue, 23 May 2023 20:57:27 +0100 Subject: [PATCH 3/6] Fix tests on older Pythons --- src/test_typing_extensions.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 6f5a42db..d806c2b0 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -2625,16 +2625,19 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ... self.assertEqual(MemoizedFunc.__parameters__, (P, T, T2)) self.assertTrue(MemoizedFunc._is_protocol) - with self.assertRaisesRegex(TypeError, "Too few arguments"): + with self.assertRaises(TypeError): MemoizedFunc[[int, str, str]] - X = MemoizedFunc[[int, str, str], T, T2] - self.assertEqual(X.__parameters__, (T, T2)) - self.assertEqual(X.__args__, ((int, str, str), T, T2)) + if sys.version_info >= (3, 10): + # These unfortunately don't pass on <=3.9, + # due to typing._type_check on older Python versions + X = MemoizedFunc[[int, str, str], T, T2] + self.assertEqual(X.__parameters__, (T, T2)) + self.assertEqual(X.__args__, ((int, str, str), T, T2)) - Y = X[bytes, memoryview] - self.assertEqual(Y.__parameters__, ()) - self.assertEqual(Y.__args__, ((int, str, str), bytes, memoryview)) + Y = X[bytes, memoryview] + self.assertEqual(Y.__parameters__, ()) + self.assertEqual(Y.__args__, ((int, str, str), bytes, memoryview)) def test_protocol_generic_over_typevartuple(self): Ts = TypeVarTuple("Ts") @@ -2648,7 +2651,9 @@ def __call__(self, *args: Unpack[Ts]) -> T: ... self.assertEqual(MemoizedFunc.__parameters__, (Ts, T, T2)) self.assertTrue(MemoizedFunc._is_protocol) - with self.assertRaisesRegex(TypeError, "Too few arguments"): + things = "arguments" if sys.version_info >= (3, 11) else "parameters" + + with self.assertRaisesRegex(TypeError, f"Too few {things}"): MemoizedFunc[int] X = MemoizedFunc[int, T, T2] From f1bf575f79c23349e07557cfcf0ba9cafd50c70b Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Tue, 23 May 2023 21:14:53 +0100 Subject: [PATCH 4/6] Update src/test_typing_extensions.py --- src/test_typing_extensions.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index d806c2b0..a00abc2b 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -2653,8 +2653,13 @@ def __call__(self, *args: Unpack[Ts]) -> T: ... things = "arguments" if sys.version_info >= (3, 11) else "parameters" - with self.assertRaisesRegex(TypeError, f"Too few {things}"): - MemoizedFunc[int] + # A bug was fixed in 3.11.1 + # (https://github.com/python/cpython/commit/74920aa27d0c57443dd7f704d6272cca9c507ab3) + # That means this assertion doesn't pass on 3.11.0, + # but it passes on all other Python versions + if sys.version_info[:3] != (3, 11, 0): + with self.assertRaisesRegex(TypeError, f"Too few {things}"): + MemoizedFunc[int] X = MemoizedFunc[int, T, T2] self.assertEqual(X.__parameters__, (T, T2)) From 22e9ec022ef8a605cc69893193c659de7afc17b3 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Tue, 23 May 2023 21:25:08 +0100 Subject: [PATCH 5/6] Fixup docstring while we're here --- src/typing_extensions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 17eb8606..f13859f0 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -726,7 +726,8 @@ def func(x: Proto) -> int: func(C()) # Passes static type check See PEP 544 for details. Protocol classes decorated with - @typing_extensions.runtime act as simple-minded runtime protocol that checks + @typing_extensions.runtime_checkable act + as simple-minded runtime-checkable protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as:: From dfc84e7b4982d0a4fc1719569af09a6cff24bef9 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Tue, 23 May 2023 21:28:07 +0100 Subject: [PATCH 6/6] Unskip a test that now passes --- src/test_typing_extensions.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index a00abc2b..9c605fa4 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -3827,9 +3827,11 @@ class D(Protocol[T1, T2, Unpack[Ts]]): pass self.assertEqual( klass[int, str, float, bool].__args__, (int, str, float, bool) ) - # TODO This should probably also fail on 3.11, - # pending changes to CPython. - if not TYPING_3_11_0: + # A bug was fixed in 3.11.1 + # (https://github.com/python/cpython/commit/74920aa27d0c57443dd7f704d6272cca9c507ab3) + # That means this assertion doesn't pass on 3.11.0, + # but it passes on all other Python versions + if sys.version_info[:3] != (3, 11, 0): with self.assertRaises(TypeError): klass[int]