Skip to content

Commit

Permalink
Better handling of generic aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Apr 6, 2021
1 parent 8e28720 commit a4fc868
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 49 deletions.
132 changes: 90 additions & 42 deletions astroid/brain/brain_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,51 @@ class {0}(metaclass=Meta):
"""
TYPING_MEMBERS = set(typing.__all__)

TYPING_ALIAS = frozenset(
(
"typing.Hashable",
"typing.Awaitable",
"typing.Coroutine",
"typing.AsyncIterable",
"typing.AsyncIterator",
"typing.Iterable",
"typing.Iterator",
"typing.Reversible",
"typing.Sized",
"typing.Container",
"typing.Collection",
"typing.Callable",
"typing.AbstractSet",
"typing.MutableSet",
"typing.Mapping",
"typing.MutableMapping",
"typing.Sequence",
"typing.MutableSequence",
"typing.ByteString",
"typing.Tuple",
"typing.List",
"typing.Deque",
"typing.Set",
"typing.FrozenSet",
"typing.MappingView",
"typing.KeysView",
"typing.ItemsView",
"typing.ValuesView",
"typing.ContextManager",
"typing.AsyncContextManager",
"typing.Dict",
"typing.DefaultDict",
"typing.OrderedDict",
"typing.Counter",
"typing.ChainMap",
"typing.Generator",
"typing.AsyncGenerator",
"typing.Type",
"typing.Pattern",
"typing.Match",
)
)


def looks_like_typing_typevar_or_newtype(node):
func = node.func
Expand Down Expand Up @@ -88,7 +133,13 @@ def infer_typing_attr(node, context=None):
except InferenceError as exc:
raise UseInferenceDefault from exc

if not value.qname().startswith("typing."):
if (
not value.qname().startswith("typing.")
or PY37
and value.qname() in TYPING_ALIAS
):
# If typing subscript belongs to an alias
# (PY37+) handle it separately later.
raise UseInferenceDefault

node = extract_node(TYPING_TYPE_TEMPLATE.format(value.qname().split(".")[-1]))
Expand Down Expand Up @@ -161,8 +212,6 @@ def full_raiser(origin_func, attr, *args, **kwargs):
else:
return origin_func(attr, *args, **kwargs)

if not isinstance(node, nodes.ClassDef):
raise TypeError("The parameter type should be ClassDef")
try:
node.getattr("__class_getitem__")
# If we are here, then we are sure to modify object that do have __class_getitem__ method (which origin is one the
Expand All @@ -179,52 +228,51 @@ def infer_typing_alias(
) -> typing.Optional[node_classes.NodeNG]:
"""
Infers the call to _alias function
Insert ClassDef with same name as aliased class
in mro to simulate _GenericAlias.
:param node: call node
:param context: inference context
"""
if (
not isinstance(node.parent, nodes.Assign)
or not len(node.parent.targets) == 1
or not isinstance(node.parent.targets[0], nodes.AssignName)
):
return None
res = next(node.args[0].infer(context=ctx))
assign_name = node.parent.targets[0]

class_def = nodes.ClassDef(
name=assign_name.name,
lineno=assign_name.lineno,
col_offset=assign_name.col_offset,
parent=node.parent,
)
if res != astroid.Uninferable and isinstance(res, nodes.ClassDef):
if not PY39:
# Here the node is a typing object which is an alias toward
# the corresponding object of collection.abc module.
# Before python3.9 there is no subscript allowed for any of the collections.abc objects.
# The subscript ability is given through the typing._GenericAlias class
# which is the metaclass of the typing object but not the metaclass of the inferred
# collections.abc object.
# Thus we fake subscript ability of the collections.abc object
# by mocking the existence of a __class_getitem__ method.
# We can not add `__getitem__` method in the metaclass of the object because
# the metaclass is shared by subscriptable and not subscriptable object
maybe_type_var = node.args[1]
if not (
isinstance(maybe_type_var, node_classes.Tuple)
and not maybe_type_var.elts
):
# The typing object is subscriptable if the second argument of the _alias function
# is a TypeVar or a tuple of TypeVar. We could check the type of the second argument but
# it appears that in the typing module the second argument is only TypeVar or a tuple of TypeVar or empty tuple.
# This last value means the type is not Generic and thus cannot be subscriptable
func_to_add = astroid.extract_node(CLASS_GETITEM_TEMPLATE)
res.locals["__class_getitem__"] = [func_to_add]
else:
# If we are here, then we are sure to modify object that do have __class_getitem__ method (which origin is one the
# protocol defined in collections module) whereas the typing module consider it should not
# We do not want __class_getitem__ to be found in the classdef
_forbid_class_getitem_access(res)
else:
# Within python3.9 discrepencies exist between some collections.abc containers that are subscriptable whereas
# corresponding containers in the typing module are not! This is the case at least for ByteString.
# It is far more to complex and dangerous to try to remove __class_getitem__ method from all the ancestors of the
# current class. Instead we raise an AttributeInferenceError if we try to access it.
maybe_type_var = node.args[1]
if isinstance(maybe_type_var, nodes.Const) and maybe_type_var.value == 0:
# Starting with Python39 the _alias function is in fact instantiation of _SpecialGenericAlias class.
# Thus the type is not Generic if the second argument of the call is equal to zero
_forbid_class_getitem_access(res)
return iter([res])
return iter([astroid.Uninferable])
# Only add `res` as base if it's a `ClassDef`
# This isn't the case for `typing.Pattern` and `typing.Match`
class_def.postinit(bases=[res], body=[], decorators=None)

maybe_type_var = node.args[1]
if (
not PY39
and not (
isinstance(maybe_type_var, node_classes.Tuple) and not maybe_type_var.elts
)
or PY39
and isinstance(maybe_type_var, nodes.Const)
and maybe_type_var.value > 0
):
# If typing alias is subscriptable, add `__class_getitem__` to ClassDef
func_to_add = astroid.extract_node(CLASS_GETITEM_TEMPLATE)
class_def.locals["__class_getitem__"] = [func_to_add]
else:
# If not, make sure that `__class_getitem__` access is forbidden.
# This is an issue in cases where the aliased class implements it,
# but the typing alias doesn't. E.g. `typing.ByteString` for PY39+
_forbid_class_getitem_access(class_def)
return iter([class_def])


MANAGER.register_transform(
Expand Down
35 changes: 28 additions & 7 deletions tests/unittest_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,7 +1167,7 @@ class Derived(collections.abc.Iterator[int]):
],
)

@test_utils.require_version(maxver="3.8")
@test_utils.require_version(maxver="3.9")
def test_collections_object_not_yet_subscriptable_2(self):
"""Before python39 Iterator in the collection.abc module is not subscriptable"""
node = builder.extract_node(
Expand All @@ -1194,6 +1194,28 @@ def test_collections_object_subscriptable_3(self):
inferred.getattr("__class_getitem__")[0], nodes.FunctionDef
)

@test_utils.require_version(minver="3.9")
def test_collections_object_subscriptable_4(self):
"""Multiple inheritance with subscriptable collection class"""
node = builder.extract_node(
"""
import collections.abc
class Derived(collections.abc.Hashable, collections.abc.Iterator[int]):
pass
"""
)
inferred = next(node.infer())
assertEqualMro(
inferred,
[
"Derived",
"Hashable",
"Iterator",
"Iterable",
"object",
],
)


@test_utils.require_version("3.6")
class TypingBrain(unittest.TestCase):
Expand Down Expand Up @@ -1398,12 +1420,12 @@ class Derived1(MutableSet[T]):
"""
)
inferred = next(node.infer())
check_metaclass_is_abc(inferred)
assertEqualMro(
inferred,
[
"Derived1",
"MutableSet",
"MutableSet",
"Set",
"Collection",
"Sized",
Expand All @@ -1429,19 +1451,18 @@ class Derived2(typing.OrderedDict[int, str]):
"""
)
inferred = next(node.infer())
# OrderedDict has no metaclass because it
# inherits from dict which is C coded
self.assertIsNone(inferred.metaclass())
assertEqualMro(
inferred,
[
"Derived2",
"OrderedDict",
"OrderedDict",
"dict",
"object",
],
)

@test_utils.require_version(minver="3.7")
def test_typing_object_not_subscriptable(self):
"""Hashable is not subscriptable"""
wrong_node = builder.extract_node(
Expand All @@ -1459,10 +1480,10 @@ def test_typing_object_not_subscriptable(self):
"""
)
inferred = next(right_node.infer())
check_metaclass_is_abc(inferred)
assertEqualMro(
inferred,
[
"Hashable",
"Hashable",
"object",
],
Expand All @@ -1480,10 +1501,10 @@ def test_typing_object_subscriptable(self):
"""
)
inferred = next(right_node.infer())
check_metaclass_is_abc(inferred)
assertEqualMro(
inferred,
[
"MutableSet",
"MutableSet",
"Set",
"Collection",
Expand Down

0 comments on commit a4fc868

Please sign in to comment.