From 070dae418b8275200e9260d4963a7e857f28eea9 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Fri, 24 Apr 2020 18:34:05 +0200 Subject: [PATCH] If a generic class is subscripted, infer that class itself Closes PyCQA/pylint#3131 Closes PyCQA/pylint#3505 --- ChangeLog | 5 +++++ astroid/inference.py | 4 ++++ tests/unittest_inference.py | 24 ++++++++++++++++++++++++ 3 files changed, 33 insertions(+) diff --git a/ChangeLog b/ChangeLog index 272fd8f11c..00ea6cf09b 100644 --- a/ChangeLog +++ b/ChangeLog @@ -2,6 +2,11 @@ astroid's ChangeLog =================== +* If a generic class is subscripted, infer that class itself + + Closes PyCQA/pylint#3131 + Closes PyCQA/pylint#3505 + What's New in astroid 2.4.0? ============================ Release Date: 2020-04-27 diff --git a/astroid/inference.py b/astroid/inference.py index bc3e1f9701..4f30a8de82 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -368,6 +368,10 @@ def infer_subscript(self, context=None): if value is util.Uninferable: yield util.Uninferable return None + if isinstance(value, nodes.ClassDef): + if value.is_subtype_of("typing.Generic"): + yield value + return None for index in self.slice.infer(context): if index is util.Uninferable: yield util.Uninferable diff --git a/tests/unittest_inference.py b/tests/unittest_inference.py index e267f97022..8541a04855 100644 --- a/tests/unittest_inference.py +++ b/tests/unittest_inference.py @@ -374,6 +374,30 @@ class A(B): #@ self.assertIs(a2_ancestors[0], b) self.assertIs(a2_ancestors[1], a1) + @pytest.mark.skipif(sys.version_info < (3, 5), reason="Needs 'typing' module") + def test_ancestors_generic(self): + code = """ + from typing import Generic, TypeVar + + T = TypeVar('T') + + class A(Generic[T]): #@ + pass + + class B(A[T]): #@ + pass + + class C(B[int]): #@ + pass + """ + a, b, c = extract_node(code, __name__) + ancestors = list(c.ancestors()) + self.assertEqual(len(ancestors), 4) + self.assertIs(ancestors[0], b) + self.assertIs(ancestors[1], a) + self.assertEqual(ancestors[2].name, "Generic") + self.assertEqual(ancestors[3].name, "object") + def test_f_arg_f(self): code = """ def f(f=1):