From ca2009d72a52a98bf43aafa9ad270a4fcfabfc89 Mon Sep 17 00:00:00 2001 From: Brandt Bucher Date: Fri, 25 Jun 2021 08:20:43 -0700 Subject: [PATCH] bpo-43977: Properly update the tp_flags of existing subclasses when their parents are registered (GH-26864) --- Doc/library/dis.rst | 13 ++- Lib/test/test_patma.py | 110 ++++++++++++++---- .../2021-06-22-16-45-48.bpo-43977.bamAGF.rst | 3 + Modules/_abc.c | 37 +++++- 4 files changed, 129 insertions(+), 34 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2021-06-22-16-45-48.bpo-43977.bamAGF.rst diff --git a/Doc/library/dis.rst b/Doc/library/dis.rst index 65fbb00bd6597b..645e94a669fd13 100644 --- a/Doc/library/dis.rst +++ b/Doc/library/dis.rst @@ -772,17 +772,20 @@ iterations of the loop. .. opcode:: MATCH_MAPPING - If TOS is an instance of :class:`collections.abc.Mapping`, push ``True`` onto - the stack. Otherwise, push ``False``. + If TOS is an instance of :class:`collections.abc.Mapping` (or, more technically: if + it has the :const:`Py_TPFLAGS_MAPPING` flag set in its + :c:member:`~PyTypeObject.tp_flags`), push ``True`` onto the stack. Otherwise, push + ``False``. .. versionadded:: 3.10 .. opcode:: MATCH_SEQUENCE - If TOS is an instance of :class:`collections.abc.Sequence` and is *not* an - instance of :class:`str`/:class:`bytes`/:class:`bytearray`, push ``True`` - onto the stack. Otherwise, push ``False``. + If TOS is an instance of :class:`collections.abc.Sequence` and is *not* an instance + of :class:`str`/:class:`bytes`/:class:`bytearray` (or, more technically: if it has + the :const:`Py_TPFLAGS_SEQUENCE` flag set in its :c:member:`~PyTypeObject.tp_flags`), + push ``True`` onto the stack. Otherwise, push ``False``. .. versionadded:: 3.10 diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index 91b1f7e2aa4c73..a8889cae1091fc 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -24,7 +24,43 @@ def test_refleaks(self): class TestInheritance(unittest.TestCase): - def test_multiple_inheritance(self): + @staticmethod + def check_sequence_then_mapping(x): + match x: + case [*_]: + return "seq" + case {}: + return "map" + + @staticmethod + def check_mapping_then_sequence(x): + match x: + case {}: + return "map" + case [*_]: + return "seq" + + def test_multiple_inheritance_mapping(self): + class C: + pass + class M1(collections.UserDict, collections.abc.Sequence): + pass + class M2(C, collections.UserDict, collections.abc.Sequence): + pass + class M3(collections.UserDict, C, list): + pass + class M4(dict, collections.abc.Sequence, C): + pass + self.assertEqual(self.check_sequence_then_mapping(M1()), "map") + self.assertEqual(self.check_sequence_then_mapping(M2()), "map") + self.assertEqual(self.check_sequence_then_mapping(M3()), "map") + self.assertEqual(self.check_sequence_then_mapping(M4()), "map") + self.assertEqual(self.check_mapping_then_sequence(M1()), "map") + self.assertEqual(self.check_mapping_then_sequence(M2()), "map") + self.assertEqual(self.check_mapping_then_sequence(M3()), "map") + self.assertEqual(self.check_mapping_then_sequence(M4()), "map") + + def test_multiple_inheritance_sequence(self): class C: pass class S1(collections.UserList, collections.abc.Mapping): @@ -35,32 +71,60 @@ class S3(list, C, collections.abc.Mapping): pass class S4(collections.UserList, dict, C): pass - class M1(collections.UserDict, collections.abc.Sequence): + self.assertEqual(self.check_sequence_then_mapping(S1()), "seq") + self.assertEqual(self.check_sequence_then_mapping(S2()), "seq") + self.assertEqual(self.check_sequence_then_mapping(S3()), "seq") + self.assertEqual(self.check_sequence_then_mapping(S4()), "seq") + self.assertEqual(self.check_mapping_then_sequence(S1()), "seq") + self.assertEqual(self.check_mapping_then_sequence(S2()), "seq") + self.assertEqual(self.check_mapping_then_sequence(S3()), "seq") + self.assertEqual(self.check_mapping_then_sequence(S4()), "seq") + + def test_late_registration_mapping(self): + class Parent: pass - class M2(C, collections.UserDict, collections.abc.Sequence): + class ChildPre(Parent): pass - class M3(collections.UserDict, C, list): + class GrandchildPre(ChildPre): pass - class M4(dict, collections.abc.Sequence, C): + collections.abc.Mapping.register(Parent) + class ChildPost(Parent): pass - def f(x): - match x: - case []: - return "seq" - case {}: - return "map" - def g(x): - match x: - case {}: - return "map" - case []: - return "seq" - for Seq in (S1, S2, S3, S4): - self.assertEqual(f(Seq()), "seq") - self.assertEqual(g(Seq()), "seq") - for Map in (M1, M2, M3, M4): - self.assertEqual(f(Map()), "map") - self.assertEqual(g(Map()), "map") + class GrandchildPost(ChildPost): + pass + self.assertEqual(self.check_sequence_then_mapping(Parent()), "map") + self.assertEqual(self.check_sequence_then_mapping(ChildPre()), "map") + self.assertEqual(self.check_sequence_then_mapping(GrandchildPre()), "map") + self.assertEqual(self.check_sequence_then_mapping(ChildPost()), "map") + self.assertEqual(self.check_sequence_then_mapping(GrandchildPost()), "map") + self.assertEqual(self.check_mapping_then_sequence(Parent()), "map") + self.assertEqual(self.check_mapping_then_sequence(ChildPre()), "map") + self.assertEqual(self.check_mapping_then_sequence(GrandchildPre()), "map") + self.assertEqual(self.check_mapping_then_sequence(ChildPost()), "map") + self.assertEqual(self.check_mapping_then_sequence(GrandchildPost()), "map") + + def test_late_registration_sequence(self): + class Parent: + pass + class ChildPre(Parent): + pass + class GrandchildPre(ChildPre): + pass + collections.abc.Sequence.register(Parent) + class ChildPost(Parent): + pass + class GrandchildPost(ChildPost): + pass + self.assertEqual(self.check_sequence_then_mapping(Parent()), "seq") + self.assertEqual(self.check_sequence_then_mapping(ChildPre()), "seq") + self.assertEqual(self.check_sequence_then_mapping(GrandchildPre()), "seq") + self.assertEqual(self.check_sequence_then_mapping(ChildPost()), "seq") + self.assertEqual(self.check_sequence_then_mapping(GrandchildPost()), "seq") + self.assertEqual(self.check_mapping_then_sequence(Parent()), "seq") + self.assertEqual(self.check_mapping_then_sequence(ChildPre()), "seq") + self.assertEqual(self.check_mapping_then_sequence(GrandchildPre()), "seq") + self.assertEqual(self.check_mapping_then_sequence(ChildPost()), "seq") + self.assertEqual(self.check_mapping_then_sequence(GrandchildPost()), "seq") class TestPatma(unittest.TestCase): diff --git a/Misc/NEWS.d/next/Library/2021-06-22-16-45-48.bpo-43977.bamAGF.rst b/Misc/NEWS.d/next/Library/2021-06-22-16-45-48.bpo-43977.bamAGF.rst new file mode 100644 index 00000000000000..5f8cb7b7ea7294 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-06-22-16-45-48.bpo-43977.bamAGF.rst @@ -0,0 +1,3 @@ +Set the proper :const:`Py_TPFLAGS_MAPPING` and :const:`Py_TPFLAGS_SEQUENCE` +flags for subclasses created before a parent has been registered as a +:class:`collections.abc.Mapping` or :class:`collections.abc.Sequence`. diff --git a/Modules/_abc.c b/Modules/_abc.c index 7720d4051fe9eb..8aa68359039e7d 100644 --- a/Modules/_abc.c +++ b/Modules/_abc.c @@ -481,6 +481,32 @@ _abc__abc_init(PyObject *module, PyObject *self) Py_RETURN_NONE; } +static void +set_collection_flag_recursive(PyTypeObject *child, unsigned long flag) +{ + assert(flag == Py_TPFLAGS_MAPPING || flag == Py_TPFLAGS_SEQUENCE); + if (PyType_HasFeature(child, Py_TPFLAGS_IMMUTABLETYPE) || + (child->tp_flags & COLLECTION_FLAGS) == flag) + { + return; + } + child->tp_flags &= ~COLLECTION_FLAGS; + child->tp_flags |= flag; + PyObject *grandchildren = child->tp_subclasses; + if (grandchildren == NULL) { + return; + } + assert(PyDict_CheckExact(grandchildren)); + Py_ssize_t i = 0; + while (PyDict_Next(grandchildren, &i, NULL, &grandchildren)) { + assert(PyWeakref_CheckRef(grandchildren)); + PyObject *grandchild = PyWeakref_GET_OBJECT(grandchildren); + if (PyType_Check(grandchild)) { + set_collection_flag_recursive((PyTypeObject *)grandchild, flag); + } + } +} + /*[clinic input] _abc._abc_register @@ -532,12 +558,11 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass) get_abc_state(module)->abc_invalidation_counter++; /* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */ - if (PyType_Check(self) && - !PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE) && - ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS) - { - ((PyTypeObject *)subclass)->tp_flags &= ~COLLECTION_FLAGS; - ((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS); + if (PyType_Check(self)) { + unsigned long collection_flag = ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS; + if (collection_flag) { + set_collection_flag_recursive((PyTypeObject *)subclass, collection_flag); + } } Py_INCREF(subclass); return subclass;