From 77e60bb8f547e325763bdc47aba6f9156b6e3609 Mon Sep 17 00:00:00 2001 From: Matej Aleksandrov Date: Mon, 28 Oct 2024 19:52:53 +0000 Subject: [PATCH] Fix #2628 by adding ignore_duplicate parameter --- astroid/brain/brain_dataclasses.py | 4 +-- astroid/nodes/scoped_nodes/scoped_nodes.py | 21 ++++++++++----- tests/brain/test_dataclasses.py | 30 ++++++++++++++++++++++ tests/test_scoped_nodes.py | 3 +++ 4 files changed, 50 insertions(+), 8 deletions(-) diff --git a/astroid/brain/brain_dataclasses.py b/astroid/brain/brain_dataclasses.py index 92d983e2b0..fca95f76bd 100644 --- a/astroid/brain/brain_dataclasses.py +++ b/astroid/brain/brain_dataclasses.py @@ -171,7 +171,7 @@ def _find_arguments_from_base_classes( # See TODO down below # all_have_defaults = True - for base in reversed(node.mro()): + for base in reversed(node.mro(ignore_duplicates=True)): if not base.is_dataclass: continue try: @@ -221,7 +221,7 @@ def _parse_arguments_into_strings( def _get_previous_field_default(node: nodes.ClassDef, name: str) -> nodes.NodeNG | None: """Get the default value of a previously defined field.""" - for base in reversed(node.mro()): + for base in reversed(node.mro(ignore_duplicates=True)): if not base.is_dataclass: continue if name in base.locals: diff --git a/astroid/nodes/scoped_nodes/scoped_nodes.py b/astroid/nodes/scoped_nodes/scoped_nodes.py index 99ed79675b..726becfabd 100644 --- a/astroid/nodes/scoped_nodes/scoped_nodes.py +++ b/astroid/nodes/scoped_nodes/scoped_nodes.py @@ -144,12 +144,13 @@ def clean_duplicates_mro( sequences: list[list[ClassDef]], cls: ClassDef, context: InferenceContext | None, + ignore_duplicates: bool, ) -> list[list[ClassDef]]: for sequence in sequences: seen = set() for node in sequence: lineno_and_qname = (node.lineno, node.qname()) - if lineno_and_qname in seen: + if lineno_and_qname in seen and not ignore_duplicates: raise DuplicateBasesError( message="Duplicates found in MROs {mros} for {cls!r}.", mros=sequences, @@ -2834,7 +2835,9 @@ def _inferred_bases(self, context: InferenceContext | None = None): else: yield from baseobj.bases - def _compute_mro(self, context: InferenceContext | None = None): + def _compute_mro( + self, context: InferenceContext | None = None, ignore_duplicates: bool = False + ): if self.qname() == "builtins.object": return [self] @@ -2844,15 +2847,21 @@ def _compute_mro(self, context: InferenceContext | None = None): if base is self: continue - mro = base._compute_mro(context=context) + mro = base._compute_mro( + context=context, ignore_duplicates=ignore_duplicates + ) bases_mro.append(mro) unmerged_mro: list[list[ClassDef]] = [[self], *bases_mro, inferred_bases] - unmerged_mro = clean_duplicates_mro(unmerged_mro, self, context) + unmerged_mro = clean_duplicates_mro( + unmerged_mro, self, context, ignore_duplicates=ignore_duplicates + ) clean_typing_generic_mro(unmerged_mro) return _c3_merge(unmerged_mro, self, context) - def mro(self, context: InferenceContext | None = None) -> list[ClassDef]: + def mro( + self, context: InferenceContext | None = None, ignore_duplicates: bool = False + ) -> list[ClassDef]: """Get the method resolution order, using C3 linearization. :returns: The list of ancestors, sorted by the mro. @@ -2860,7 +2869,7 @@ def mro(self, context: InferenceContext | None = None) -> list[ClassDef]: :raises DuplicateBasesError: Duplicate bases in the same class base :raises InconsistentMroError: A class' MRO is inconsistent """ - return self._compute_mro(context=context) + return self._compute_mro(context=context, ignore_duplicates=ignore_duplicates) def bool_value(self, context: InferenceContext | None = None) -> Literal[True]: """Determine the boolean value of this node. diff --git a/tests/brain/test_dataclasses.py b/tests/brain/test_dataclasses.py index cd3fcb4cfb..a8b1c23523 100644 --- a/tests/brain/test_dataclasses.py +++ b/tests/brain/test_dataclasses.py @@ -1350,3 +1350,33 @@ def attr(self, value: int) -> None: fourth_init: bases.UnboundMethod = next(fourth.infer()) assert [a.name for a in fourth_init.args.args] == ["self", "other_attr", "attr"] assert [a.name for a in fourth_init.args.defaults] == ["Uninferable"] + + +@parametrize_module +def test_dataclass_inherited_from_multiple_protocol_bases(module: str): + code = astroid.extract_node( + f""" + from {module} import dataclass + from typing import TypeVar, Protocol + + BaseT = TypeVar("BaseT") + T = TypeVar("T", bound=BaseT) + + + class A(Protocol[BaseT]): + pass + + + class B(A[T], Protocol[T]): + pass + + + @dataclass + class Dataclass(B[T]): + pass + """ + ) + inferred = code.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], nodes.ClassDef) + assert inferred[0].is_dataclass diff --git a/tests/test_scoped_nodes.py b/tests/test_scoped_nodes.py index 5e5bb581d4..028b2d5b2c 100644 --- a/tests/test_scoped_nodes.py +++ b/tests/test_scoped_nodes.py @@ -1938,6 +1938,7 @@ class A(Generic[T1], Generic[T2]): ... assert isinstance(cls, nodes.ClassDef) with self.assertRaises(DuplicateBasesError): cls.mro() + assert len(cls.mro(ignore_duplicates=True)) == 3 def test_mro_generic_error_2(self): cls = builder.extract_node( @@ -1951,6 +1952,8 @@ class B(A[T], A[T]): ... assert isinstance(cls, nodes.ClassDef) with self.assertRaises(DuplicateBasesError): cls.mro() + with self.assertRaises(InconsistentMroError): + cls.mro(ignore_duplicates=True) def test_mro_typing_extensions(self): """Regression test for mro() inference on typing_extensions.