Skip to content

Commit

Permalink
bpo-43224: Implement pickling of TypeVarTuples (#32119)
Browse files Browse the repository at this point in the history
Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
  • Loading branch information
mrahtz and JelleZijlstra authored Apr 22, 2022
1 parent 2551a6c commit 5e130a8
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 7 deletions.
56 changes: 55 additions & 1 deletion Lib/test/test_typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import collections
from collections import defaultdict
from functools import lru_cache
from functools import lru_cache, wraps
import inspect
import pickle
import re
Expand Down Expand Up @@ -70,6 +70,18 @@ def clear_caches(self):
f()


def all_pickle_protocols(test_func):
"""Runs `test_func` with various values for `proto` argument."""

@wraps(test_func)
def wrapper(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(pickle_proto=proto):
test_func(self, proto=proto)

return wrapper


class Employee:
pass

Expand Down Expand Up @@ -911,6 +923,48 @@ class C(Generic[Unpack[Ts]]): pass
self.assertNotEqual(C[Unpack[Ts1]], C[Unpack[Ts2]])


class TypeVarTuplePicklingTests(BaseTestCase):
# These are slightly awkward tests to run, because TypeVarTuples are only
# picklable if defined in the global scope. We therefore need to push
# various things defined in these tests into the global scope with `global`
# statements at the start of each test.

@all_pickle_protocols
def test_pickling_then_unpickling_results_in_same_identity(self, proto):
global Ts1 # See explanation at start of class.
Ts1 = TypeVarTuple('Ts1')
Ts2 = pickle.loads(pickle.dumps(Ts1, proto))
self.assertIs(Ts1, Ts2)

@all_pickle_protocols
def test_pickling_then_unpickling_unpacked_results_in_same_identity(self, proto):
global Ts # See explanation at start of class.
Ts = TypeVarTuple('Ts')
unpacked1 = Unpack[Ts]
unpacked2 = pickle.loads(pickle.dumps(unpacked1, proto))
self.assertIs(unpacked1, unpacked2)

@all_pickle_protocols
def test_pickling_then_unpickling_tuple_with_typevartuple_equality(
self, proto
):
global T, Ts # See explanation at start of class.
T = TypeVar('T')
Ts = TypeVarTuple('Ts')

a1 = Tuple[Unpack[Ts]]
a2 = pickle.loads(pickle.dumps(a1, proto))
self.assertEqual(a1, a2)

a1 = Tuple[T, Unpack[Ts]]
a2 = pickle.loads(pickle.dumps(a1, proto))
self.assertEqual(a1, a2)

a1 = Tuple[int, Unpack[Ts]]
a2 = pickle.loads(pickle.dumps(a1, proto))
self.assertEqual(a1, a2)


class UnionTests(BaseTestCase):

def test_basics(self):
Expand Down
25 changes: 19 additions & 6 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,13 @@ def _is_typevar_like(x: Any) -> bool:
return isinstance(x, (TypeVar, ParamSpec)) or _is_unpacked_typevartuple(x)


class _PickleUsingNameMixin:
"""Mixin enabling pickling based on self.__name__."""

def __reduce__(self):
return self.__name__


class _BoundVarianceMixin:
"""Mixin giving __init__ bound and variance arguments.
Expand Down Expand Up @@ -903,11 +910,9 @@ def __repr__(self):
prefix = '~'
return prefix + self.__name__

def __reduce__(self):
return self.__name__


class TypeVar(_Final, _Immutable, _BoundVarianceMixin, _root=True):
class TypeVar(_Final, _Immutable, _BoundVarianceMixin, _PickleUsingNameMixin,
_root=True):
"""Type variable.
Usage::
Expand Down Expand Up @@ -973,7 +978,7 @@ def __typing_subst__(self, arg):
return arg


class TypeVarTuple(_Final, _Immutable, _root=True):
class TypeVarTuple(_Final, _Immutable, _PickleUsingNameMixin, _root=True):
"""Type variable tuple.
Usage:
Expand All @@ -994,11 +999,18 @@ class C(Generic[*Ts]): ...
C[()] # Even this is fine
For more details, see PEP 646.
Note that only TypeVarTuples defined in global scope can be pickled.
"""

def __init__(self, name):
self.__name__ = name

# Used for pickling.
def_mod = _caller()
if def_mod != 'typing':
self.__module__ = def_mod

def __iter__(self):
yield Unpack[self]

Expand Down Expand Up @@ -1057,7 +1069,8 @@ def __eq__(self, other):
return self.__origin__ == other.__origin__


class ParamSpec(_Final, _Immutable, _BoundVarianceMixin, _root=True):
class ParamSpec(_Final, _Immutable, _BoundVarianceMixin, _PickleUsingNameMixin,
_root=True):
"""Parameter specification variable.
Usage::
Expand Down

0 comments on commit 5e130a8

Please sign in to comment.