diff --git a/ChangeLog b/ChangeLog index 4b16576793..a321d9965c 100644 --- a/ChangeLog +++ b/ChangeLog @@ -19,6 +19,12 @@ Release Date: TBA * Use ``inference_tip`` for ``typing.TypedDict`` brain. +* Fix mro for classes that inherit from typing.Generic + +* Add inference tip for typing.Generic and typing.Annotated with ``__class_getitem__`` + + Closes PyCQA/pylint#2822 + What's New in astroid 2.5.2? ============================ diff --git a/astroid/brain/brain_typing.py b/astroid/brain/brain_typing.py index 2e76446561..9d94ca3ee8 100644 --- a/astroid/brain/brain_typing.py +++ b/astroid/brain/brain_typing.py @@ -88,6 +88,12 @@ class {0}(metaclass=Meta): ) ) +CLASS_GETITEM_TEMPLATE = """ +@classmethod +def __class_getitem__(cls, item): + return cls +""" + def looks_like_typing_typevar_or_newtype(node): func = node.func @@ -126,7 +132,9 @@ def _looks_like_typing_subscript(node): return False -def infer_typing_attr(node, context=None): +def infer_typing_attr( + node: nodes.Subscript, ctx: context.InferenceContext = None +) -> typing.Iterator[nodes.ClassDef]: """Infer a typing.X[...] subscript""" try: value = next(node.value.infer()) @@ -142,8 +150,31 @@ def infer_typing_attr(node, context=None): # (PY37+) handle it separately. raise UseInferenceDefault + if ( + PY37 + and isinstance(value, nodes.ClassDef) + and value.qname() + in ("typing.Generic", "typing.Annotated", "typing_extensions.Annotated") + ): + # With PY37+ typing.Generic and typing.Annotated (PY39) are subscriptable + # through __class_getitem__. Since astroid can't easily + # infer the native methods, replace them for an easy inference tip + func_to_add = astroid.extract_node(CLASS_GETITEM_TEMPLATE) + value.locals["__class_getitem__"] = [func_to_add] + if ( + isinstance(node.parent, nodes.ClassDef) + and node in node.parent.bases + and getattr(node.parent, "__cache", None) + ): + # node.parent.slots is evaluated and cached before the inference tip + # is first applied. Remove the last result to allow a recalculation of slots + cache = getattr(node.parent, "__cache") + if cache.get(node.parent.slots) is not None: + del cache[node.parent.slots] + return iter([value]) + node = extract_node(TYPING_TYPE_TEMPLATE.format(value.qname().split(".")[-1])) - return node.infer(context=context) + return node.infer(context=ctx) def _looks_like_typedDict( # pylint: disable=invalid-name @@ -166,13 +197,6 @@ def infer_typedDict( # pylint: disable=invalid-name return iter([class_def]) -CLASS_GETITEM_TEMPLATE = """ -@classmethod -def __class_getitem__(cls, item): - return cls -""" - - def _looks_like_typing_alias(node: nodes.Call) -> bool: """ Returns True if the node corresponds to a call to _alias function. diff --git a/astroid/scoped_nodes.py b/astroid/scoped_nodes.py index dd5aa1257a..c3558e9a50 100644 --- a/astroid/scoped_nodes.py +++ b/astroid/scoped_nodes.py @@ -103,6 +103,42 @@ def _c3_merge(sequences, cls, context): return None +def clean_typing_generic_mro(sequences: List[List["ClassDef"]]) -> None: + """A class can inherit from typing.Generic directly, as base, + and as base of bases. The merged MRO must however only contain the last entry. + To prepare for _c3_merge, remove some typing.Generic entries from + sequences if multiple are present. + + This method will check if Generic is in inferred_bases and also + part of bases_mro. If true, remove it from inferred_bases + as well as its entry the bases_mro. + + Format sequences: [[self]] + bases_mro + [inferred_bases] + """ + bases_mro = sequences[1:-1] + inferred_bases = sequences[-1] + # Check if Generic is part of inferred_bases + for i, base in enumerate(inferred_bases): + if base.qname() == "typing.Generic": + position_in_inferred_bases = i + break + else: + return + # Check if also part of bases_mro + # Ignore entry for typing.Generic + for i, seq in enumerate(bases_mro): + if i == position_in_inferred_bases: + continue + if any(base.qname() == "typing.Generic" for base in seq): + break + else: + return + # Found multiple Generics in mro, remove entry from inferred_bases + # and the corresponding one from bases_mro + inferred_bases.pop(position_in_inferred_bases) + bases_mro.pop(position_in_inferred_bases) + + def clean_duplicates_mro(sequences, cls, context): for sequence in sequences: names = [ @@ -2924,6 +2960,7 @@ def _compute_mro(self, context=None): unmerged_mro = [[self]] + bases_mro + [inferred_bases] unmerged_mro = list(clean_duplicates_mro(unmerged_mro, self, context)) + clean_typing_generic_mro(unmerged_mro) return _c3_merge(unmerged_mro, self, context) def mro(self, context=None) -> List["ClassDef"]: diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index f15c4a1c64..a217f223f7 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -1361,6 +1361,56 @@ def test_typing_types(self): inferred = next(node.infer()) self.assertIsInstance(inferred, nodes.ClassDef, node.as_string()) + @test_utils.require_version(minver="3.7") + def test_typing_generic_subscriptable(self): + """Test typing.Generic is subscriptable with __class_getitem__ (added in PY37)""" + node = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + Generic[T] + """ + ) + inferred = next(node.infer()) + assert isinstance(inferred, nodes.ClassDef) + assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef) + + @test_utils.require_version(minver="3.9") + def test_typing_annotated_subscriptable(self): + """Test typing.Annotated is subscriptable with __class_getitem__""" + node = builder.extract_node( + """ + import typing + typing.Annotated[str, "data"] + """ + ) + inferred = next(node.infer()) + assert isinstance(inferred, nodes.ClassDef) + assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef) + + @test_utils.require_version(minver="3.7") + def test_typing_generic_slots(self): + """Test cache reset for slots if Generic subscript is inferred.""" + node = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A(Generic[T]): + __slots__ = ['value'] + def __init__(self, value): + self.value = value + """ + ) + inferred = next(node.infer()) + assert len(inferred.slots()) == 0 + # Only after the subscript base is inferred and the inference tip applied, + # will slots contain the correct value + next(node.bases[0].infer()) + slots = inferred.slots() + assert len(slots) == 1 + assert isinstance(slots[0], nodes.Const) + assert slots[0].value == "value" + def test_has_dunder_args(self): ast_node = builder.extract_node( """ diff --git a/tests/unittest_scoped_nodes.py b/tests/unittest_scoped_nodes.py index a0f882aee6..b98cd8f836 100644 --- a/tests/unittest_scoped_nodes.py +++ b/tests/unittest_scoped_nodes.py @@ -1275,6 +1275,9 @@ class NodeBase(object): def assertEqualMro(self, klass, expected_mro): self.assertEqual([member.name for member in klass.mro()], expected_mro) + def assertEqualMroQName(self, klass, expected_mro): + self.assertEqual([member.qname() for member in klass.mro()], expected_mro) + @unittest.skipUnless(HAS_SIX, "These tests require the six library") def test_with_metaclass_mro(self): astroid = builder.parse( @@ -1438,6 +1441,142 @@ class C(scope.A, scope.B): ) self.assertEqualMro(cls, ["C", "A", "B", "object"]) + @test_utils.require_version(minver="3.7") + def test_mro_generic_1(self): + cls = builder.extract_node( + """ + import typing + T = typing.TypeVar('T') + class A(typing.Generic[T]): ... + class B: ... + class C(A[T], B): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", "typing.Generic", ".B", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_2(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A: ... + class B(Generic[T]): ... + class C(Generic[T], A, B[T]): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_3(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A: ... + class B(A, Generic[T]): ... + class C(Generic[T]): ... + class D(B[T], C[T], Generic[T]): ... + """ + ) + self.assertEqualMroQName( + cls, [".D", ".B", ".A", ".C", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_4(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A: ... + class B(Generic[T]): ... + class C(A, Generic[T], B[T]): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_5(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T1 = TypeVar('T1') + T2 = TypeVar('T2') + class A(Generic[T1]): ... + class B(Generic[T2]): ... + class C(A[T1], B[T2]): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", ".B", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_6(self): + cls = builder.extract_node( + """ + from typing import Generic as TGeneric, TypeVar + T = TypeVar('T') + class Generic: ... + class A(Generic): ... + class B(TGeneric[T]): ... + class C(A, B[T]): ... + """ + ) + self.assertEqualMroQName( + cls, [".C", ".A", ".Generic", ".B", "typing.Generic", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_7(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A(): ... + class B(Generic[T]): ... + class C(A, B[T]): ... + class D: ... + class E(C[str], D): ... + """ + ) + self.assertEqualMroQName( + cls, [".E", ".C", ".A", ".B", "typing.Generic", ".D", "builtins.object"] + ) + + @test_utils.require_version(minver="3.7") + def test_mro_generic_error_1(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T1 = TypeVar('T1') + T2 = TypeVar('T2') + class A(Generic[T1], Generic[T2]): ... + """ + ) + with self.assertRaises(DuplicateBasesError) as ex: + cls.mro() + + @test_utils.require_version(minver="3.7") + def test_mro_generic_error_2(self): + cls = builder.extract_node( + """ + from typing import Generic, TypeVar + T = TypeVar('T') + class A(Generic[T]): ... + class B(A[T], A[T]): ... + """ + ) + with self.assertRaises(DuplicateBasesError) as ex: + cls.mro() + def test_generator_from_infer_call_result_parent(self): func = builder.extract_node( """