Skip to content

Commit

Permalink
Ensure proper order when doing recursive delete_instance().
Browse files Browse the repository at this point in the history
This could break if models had multiple FKs at different levels. Code is
also somewhat simplified.

Fixes #2893
  • Loading branch information
coleifer committed May 14, 2024
1 parent b7e40dd commit f84fb24
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 16 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ https://github.com/coleifer/peewee/releases

## master

* Fix bug in recursive `model.delete_instance()` when a table contains
foreign-keys at multiple depths of the graph, #2893.

[View commits](https://github.com/coleifer/peewee/compare/3.17.5...master)

## 3.17.5
Expand Down
12 changes: 8 additions & 4 deletions peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -6956,9 +6956,10 @@ def is_dirty(self):
def dirty_fields(self):
return [f for f in self._meta.sorted_fields if f.name in self._dirty]

def dependencies(self, search_nullable=False):
def dependencies(self, search_nullable=True):
model_class = type(self)
stack = [(type(self), None)]
queries = {}
seen = set()

while stack:
Expand All @@ -6974,13 +6975,16 @@ def dependencies(self, search_nullable=False):
subquery = (rel_model.select(rel_model._meta.primary_key)
.where(node))
if not fk.null or search_nullable:
queries.setdefault(rel_model, []).append((node, fk))
stack.append((rel_model, subquery))
yield (node, fk)

for m in reversed(sort_models(seen)):
for sq, q in queries.get(m, ()):
yield sq, q

def delete_instance(self, recursive=False, delete_nullable=False):
if recursive:
dependencies = self.dependencies(delete_nullable)
for query, fk in reversed(list(dependencies)):
for query, fk in self.dependencies():
model = fk.model
if fk.null and not delete_nullable:
model.update(**{fk.name: None}).where(query).execute()
Expand Down
6 changes: 6 additions & 0 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ def assertSQL(self, query, sql, params=None, **state):
if params is not None:
self.assertEqual(qparams, params)

def assertHistory(self, n, expected):
queries = [logrecord.msg for logrecord in self._qh.queries[-n:]]
queries = [(sql.replace('%s', '?').replace('`', '"'), params)
for sql, params in queries]
self.assertEqual(queries, expected)

@property
def history(self):
return self._qh.queries
Expand Down
21 changes: 11 additions & 10 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,22 +1595,24 @@ def setUp(self):

def test_delete_instance_recursive(self):
huey = User.get(User.username == 'huey')
a = []
for d in huey.dependencies():
a.append(d)
with self.assertQueryCount(5):
huey.delete_instance(recursive=True)

queries = [logrecord.msg for logrecord in self._qh.queries[-5:]]
self.assertEqual(sorted(queries), [
self.assertHistory(5, [
('DELETE FROM "favorite" WHERE ("favorite"."user_id" = ?)',
[huey.id]),
('DELETE FROM "favorite" WHERE ('
'"favorite"."tweet_id" IN ('
'SELECT "t1"."id" FROM "tweet" AS "t1" WHERE ('
'"t1"."user_id" = ?)))', [huey.id]),
('DELETE FROM "favorite" WHERE ("favorite"."user_id" = ?)',
[huey.id]),
('DELETE FROM "tweet" WHERE ("tweet"."user_id" = ?)', [huey.id]),
('DELETE FROM "users" WHERE ("users"."id" = ?)', [huey.id]),
('UPDATE "account" SET "user_id" = ? '
'WHERE ("account"."user_id" = ?)',
[None, huey.id]),
('DELETE FROM "users" WHERE ("users"."id" = ?)', [huey.id]),
])

# Only one user left.
Expand Down Expand Up @@ -1638,17 +1640,16 @@ def test_delete_nullable(self):
huey.delete_instance(recursive=True, delete_nullable=True)

# Get the last 5 delete queries.
queries = [logrecord.msg for logrecord in self._qh.queries[-5:]]
self.assertEqual(sorted(queries), [
('DELETE FROM "account" WHERE ("account"."user_id" = ?)',
self.assertHistory(5, [
('DELETE FROM "favorite" WHERE ("favorite"."user_id" = ?)',
[huey.id]),
('DELETE FROM "favorite" WHERE ('
'"favorite"."tweet_id" IN ('
'SELECT "t1"."id" FROM "tweet" AS "t1" WHERE ('
'"t1"."user_id" = ?)))', [huey.id]),
('DELETE FROM "favorite" WHERE ("favorite"."user_id" = ?)',
[huey.id]),
('DELETE FROM "tweet" WHERE ("tweet"."user_id" = ?)', [huey.id]),
('DELETE FROM "account" WHERE ("account"."user_id" = ?)',
[huey.id]),
('DELETE FROM "users" WHERE ("users"."id" = ?)', [huey.id]),
])

Expand Down
69 changes: 67 additions & 2 deletions tests/regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def test_delete_instance_regression(self):
with self.assertQueryCount(5):
a2.delete_instance(recursive=True)

queries = [logrecord.msg for logrecord in self._qh.queries[-5:]]
self.assertEqual(sorted(queries, reverse=True), [
self.assertHistory(5, [
('DELETE FROM "di_d" WHERE ("di_d"."c_id" IN ('
'SELECT "t1"."id" FROM "di_c" AS "t1" WHERE ("t1"."b_id" IN ('
'SELECT "t2"."id" FROM "di_b" AS "t2" WHERE ("t2"."a_id" = ?)'
Expand Down Expand Up @@ -1823,3 +1822,69 @@ def test_sum_case_subquery(self):
case = Case(None, [(Sample.id.in_(subq), Sample.value)], 0)
q = Sample.select(fn.SUM(case))
self.assertEqual(q.scalar(), 4.0)


class I(TestModel):
name = TextField()
class S(TestModel):
i = ForeignKeyField(I)
class P(TestModel):
i = ForeignKeyField(I)
class PS(TestModel):
p = ForeignKeyField(P)
s = ForeignKeyField(S)
class PP(TestModel):
ps = ForeignKeyField(PS)
class O(TestModel):
ps = ForeignKeyField(PS)
s = ForeignKeyField(S)
class OX(TestModel):
o = ForeignKeyField(O, null=True)

class TestDeleteInstanceDFS(ModelTestCase):
requires = [I, S, P, PS, PP, O, OX]

def test_delete_instance_dfs(self):
i1, i2 = [I.create(name=n) for n in ('i1', 'i2')]
for i in (i1, i2):
s = S.create(i=i)
p = P.create(i=i)
ps = PS.create(p=p, s=s)
pp = PP.create(ps=ps)
o = O.create(ps=ps, s=s)
ox = OX.create(o=o)

with self.assertQueryCount(9):
i1.delete_instance(recursive=True)

self.assertHistory(9, [
('DELETE FROM "pp" WHERE ('
'"pp"."ps_id" IN (SELECT "t1"."id" FROM "ps" AS "t1" WHERE ('
'"t1"."p_id" IN (SELECT "t2"."id" FROM "p" AS "t2" WHERE ('
'"t2"."i_id" = ?)))))', [i1.id]),
('UPDATE "ox" SET "o_id" = ? WHERE ('
'"ox"."o_id" IN (SELECT "t1"."id" FROM "o" AS "t1" WHERE ('
'"t1"."ps_id" IN (SELECT "t2"."id" FROM "ps" AS "t2" WHERE ('
'"t2"."p_id" IN (SELECT "t3"."id" FROM "p" AS "t3" WHERE ('
'"t3"."i_id" = ?)))))))', [None, i1.id]),
('DELETE FROM "o" WHERE ('
'"o"."ps_id" IN (SELECT "t1"."id" FROM "ps" AS "t1" WHERE ('
'"t1"."p_id" IN (SELECT "t2"."id" FROM "p" AS "t2" WHERE ('
'"t2"."i_id" = ?)))))', [i1.id]),
('DELETE FROM "o" WHERE ('
'"o"."s_id" IN (SELECT "t1"."id" FROM "s" AS "t1" WHERE ('
'"t1"."i_id" = ?)))', [i1.id]),
('DELETE FROM "ps" WHERE ('
'"ps"."p_id" IN (SELECT "t1"."id" FROM "p" AS "t1" WHERE ('
'"t1"."i_id" = ?)))', [i1.id]),
('DELETE FROM "ps" WHERE ('
'"ps"."s_id" IN (SELECT "t1"."id" FROM "s" AS "t1" WHERE ('
'"t1"."i_id" = ?)))', [i1.id]),
('DELETE FROM "s" WHERE ("s"."i_id" = ?)', [i1.id]),
('DELETE FROM "p" WHERE ("p"."i_id" = ?)', [i1.id]),
('DELETE FROM "i" WHERE ("i"."id" = ?)', [i1.id]),
])

counts = {OX: 2}
for m in self.requires:
self.assertEqual(m.select().count(), counts.get(m, 1))

0 comments on commit f84fb24

Please sign in to comment.