From f84fb2458bc9f396eb08c8b91bc0119adece252c Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Mon, 13 May 2024 21:10:38 -0500 Subject: [PATCH] Ensure proper order when doing recursive delete_instance(). This could break if models had multiple FKs at different levels. Code is also somewhat simplified. Fixes #2893 --- CHANGELOG.md | 3 ++ peewee.py | 12 +++++--- tests/base.py | 6 ++++ tests/models.py | 21 +++++++------- tests/regressions.py | 69 ++++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 95 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d472d89e..e2c4b0834 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ https://github.com/coleifer/peewee/releases ## master +* Fix bug in recursive `model.delete_instance()` when a table contains + foreign-keys at multiple depths of the graph, #2893. + [View commits](https://github.com/coleifer/peewee/compare/3.17.5...master) ## 3.17.5 diff --git a/peewee.py b/peewee.py index ffecdc098..81a92d5da 100644 --- a/peewee.py +++ b/peewee.py @@ -6956,9 +6956,10 @@ def is_dirty(self): def dirty_fields(self): return [f for f in self._meta.sorted_fields if f.name in self._dirty] - def dependencies(self, search_nullable=False): + def dependencies(self, search_nullable=True): model_class = type(self) stack = [(type(self), None)] + queries = {} seen = set() while stack: @@ -6974,13 +6975,16 @@ def dependencies(self, search_nullable=False): subquery = (rel_model.select(rel_model._meta.primary_key) .where(node)) if not fk.null or search_nullable: + queries.setdefault(rel_model, []).append((node, fk)) stack.append((rel_model, subquery)) - yield (node, fk) + + for m in reversed(sort_models(seen)): + for sq, q in queries.get(m, ()): + yield sq, q def delete_instance(self, recursive=False, delete_nullable=False): if recursive: - dependencies = self.dependencies(delete_nullable) - for query, fk in reversed(list(dependencies)): + for query, fk in self.dependencies(): model = fk.model if fk.null and not delete_nullable: model.update(**{fk.name: None}).where(query).execute() diff --git a/tests/base.py b/tests/base.py index 8259886ea..6b7fd86b3 100644 --- a/tests/base.py +++ b/tests/base.py @@ -179,6 +179,12 @@ def assertSQL(self, query, sql, params=None, **state): if params is not None: self.assertEqual(qparams, params) + def assertHistory(self, n, expected): + queries = [logrecord.msg for logrecord in self._qh.queries[-n:]] + queries = [(sql.replace('%s', '?').replace('`', '"'), params) + for sql, params in queries] + self.assertEqual(queries, expected) + @property def history(self): return self._qh.queries diff --git a/tests/models.py b/tests/models.py index 15ed41ebb..0e805c02e 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1595,22 +1595,24 @@ def setUp(self): def test_delete_instance_recursive(self): huey = User.get(User.username == 'huey') + a = [] + for d in huey.dependencies(): + a.append(d) with self.assertQueryCount(5): huey.delete_instance(recursive=True) - queries = [logrecord.msg for logrecord in self._qh.queries[-5:]] - self.assertEqual(sorted(queries), [ + self.assertHistory(5, [ + ('DELETE FROM "favorite" WHERE ("favorite"."user_id" = ?)', + [huey.id]), ('DELETE FROM "favorite" WHERE (' '"favorite"."tweet_id" IN (' 'SELECT "t1"."id" FROM "tweet" AS "t1" WHERE (' '"t1"."user_id" = ?)))', [huey.id]), - ('DELETE FROM "favorite" WHERE ("favorite"."user_id" = ?)', - [huey.id]), ('DELETE FROM "tweet" WHERE ("tweet"."user_id" = ?)', [huey.id]), - ('DELETE FROM "users" WHERE ("users"."id" = ?)', [huey.id]), ('UPDATE "account" SET "user_id" = ? ' 'WHERE ("account"."user_id" = ?)', [None, huey.id]), + ('DELETE FROM "users" WHERE ("users"."id" = ?)', [huey.id]), ]) # Only one user left. @@ -1638,17 +1640,16 @@ def test_delete_nullable(self): huey.delete_instance(recursive=True, delete_nullable=True) # Get the last 5 delete queries. - queries = [logrecord.msg for logrecord in self._qh.queries[-5:]] - self.assertEqual(sorted(queries), [ - ('DELETE FROM "account" WHERE ("account"."user_id" = ?)', + self.assertHistory(5, [ + ('DELETE FROM "favorite" WHERE ("favorite"."user_id" = ?)', [huey.id]), ('DELETE FROM "favorite" WHERE (' '"favorite"."tweet_id" IN (' 'SELECT "t1"."id" FROM "tweet" AS "t1" WHERE (' '"t1"."user_id" = ?)))', [huey.id]), - ('DELETE FROM "favorite" WHERE ("favorite"."user_id" = ?)', - [huey.id]), ('DELETE FROM "tweet" WHERE ("tweet"."user_id" = ?)', [huey.id]), + ('DELETE FROM "account" WHERE ("account"."user_id" = ?)', + [huey.id]), ('DELETE FROM "users" WHERE ("users"."id" = ?)', [huey.id]), ]) diff --git a/tests/regressions.py b/tests/regressions.py index 55f7197bd..da6ea7fc5 100644 --- a/tests/regressions.py +++ b/tests/regressions.py @@ -138,8 +138,7 @@ def test_delete_instance_regression(self): with self.assertQueryCount(5): a2.delete_instance(recursive=True) - queries = [logrecord.msg for logrecord in self._qh.queries[-5:]] - self.assertEqual(sorted(queries, reverse=True), [ + self.assertHistory(5, [ ('DELETE FROM "di_d" WHERE ("di_d"."c_id" IN (' 'SELECT "t1"."id" FROM "di_c" AS "t1" WHERE ("t1"."b_id" IN (' 'SELECT "t2"."id" FROM "di_b" AS "t2" WHERE ("t2"."a_id" = ?)' @@ -1823,3 +1822,69 @@ def test_sum_case_subquery(self): case = Case(None, [(Sample.id.in_(subq), Sample.value)], 0) q = Sample.select(fn.SUM(case)) self.assertEqual(q.scalar(), 4.0) + + +class I(TestModel): + name = TextField() +class S(TestModel): + i = ForeignKeyField(I) +class P(TestModel): + i = ForeignKeyField(I) +class PS(TestModel): + p = ForeignKeyField(P) + s = ForeignKeyField(S) +class PP(TestModel): + ps = ForeignKeyField(PS) +class O(TestModel): + ps = ForeignKeyField(PS) + s = ForeignKeyField(S) +class OX(TestModel): + o = ForeignKeyField(O, null=True) + +class TestDeleteInstanceDFS(ModelTestCase): + requires = [I, S, P, PS, PP, O, OX] + + def test_delete_instance_dfs(self): + i1, i2 = [I.create(name=n) for n in ('i1', 'i2')] + for i in (i1, i2): + s = S.create(i=i) + p = P.create(i=i) + ps = PS.create(p=p, s=s) + pp = PP.create(ps=ps) + o = O.create(ps=ps, s=s) + ox = OX.create(o=o) + + with self.assertQueryCount(9): + i1.delete_instance(recursive=True) + + self.assertHistory(9, [ + ('DELETE FROM "pp" WHERE (' + '"pp"."ps_id" IN (SELECT "t1"."id" FROM "ps" AS "t1" WHERE (' + '"t1"."p_id" IN (SELECT "t2"."id" FROM "p" AS "t2" WHERE (' + '"t2"."i_id" = ?)))))', [i1.id]), + ('UPDATE "ox" SET "o_id" = ? WHERE (' + '"ox"."o_id" IN (SELECT "t1"."id" FROM "o" AS "t1" WHERE (' + '"t1"."ps_id" IN (SELECT "t2"."id" FROM "ps" AS "t2" WHERE (' + '"t2"."p_id" IN (SELECT "t3"."id" FROM "p" AS "t3" WHERE (' + '"t3"."i_id" = ?)))))))', [None, i1.id]), + ('DELETE FROM "o" WHERE (' + '"o"."ps_id" IN (SELECT "t1"."id" FROM "ps" AS "t1" WHERE (' + '"t1"."p_id" IN (SELECT "t2"."id" FROM "p" AS "t2" WHERE (' + '"t2"."i_id" = ?)))))', [i1.id]), + ('DELETE FROM "o" WHERE (' + '"o"."s_id" IN (SELECT "t1"."id" FROM "s" AS "t1" WHERE (' + '"t1"."i_id" = ?)))', [i1.id]), + ('DELETE FROM "ps" WHERE (' + '"ps"."p_id" IN (SELECT "t1"."id" FROM "p" AS "t1" WHERE (' + '"t1"."i_id" = ?)))', [i1.id]), + ('DELETE FROM "ps" WHERE (' + '"ps"."s_id" IN (SELECT "t1"."id" FROM "s" AS "t1" WHERE (' + '"t1"."i_id" = ?)))', [i1.id]), + ('DELETE FROM "s" WHERE ("s"."i_id" = ?)', [i1.id]), + ('DELETE FROM "p" WHERE ("p"."i_id" = ?)', [i1.id]), + ('DELETE FROM "i" WHERE ("i"."id" = ?)', [i1.id]), + ]) + + counts = {OX: 2} + for m in self.requires: + self.assertEqual(m.select().count(), counts.get(m, 1))