Skip to content

Commit

Permalink
Ensure parentheses used for CASE with subqueries.
Browse files Browse the repository at this point in the history
Fixes #2873
  • Loading branch information
coleifer committed Apr 19, 2024
1 parent 9dae730 commit 5c41d56
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 12 deletions.
31 changes: 19 additions & 12 deletions peewee.py
Original file line number Diff line number Diff line change
Expand Up @@ -1827,6 +1827,25 @@ def __sql__(self, ctx):
return ctx.literal(self.window._alias or 'w')


class Case(ColumnBase):
def __init__(self, predicate, expression_tuples, default=None):
self.predicate = predicate
self.expression_tuples = expression_tuples
self.default = default

def __sql__(self, ctx):
clauses = [SQL('CASE')]
if self.predicate is not None:
clauses.append(self.predicate)
for expr, value in self.expression_tuples:
clauses.extend((SQL('WHEN'), expr, SQL('THEN'), value))
if self.default is not None:
clauses.extend((SQL('ELSE'), self.default))
clauses.append(SQL('END'))
with ctx(in_function=False):
return ctx.sql(NodeList(clauses))


class ForUpdate(Node):
def __init__(self, expr, of=None, nowait=None):
expr = 'FOR UPDATE' if expr is True else expr
Expand All @@ -1849,18 +1868,6 @@ def __sql__(self, ctx):
return ctx


def Case(predicate, expression_tuples, default=None):
clauses = [SQL('CASE')]
if predicate is not None:
clauses.append(predicate)
for expr, value in expression_tuples:
clauses.extend((SQL('WHEN'), expr, SQL('THEN'), value))
if default is not None:
clauses.extend((SQL('ELSE'), default))
clauses.append(SQL('END'))
return NodeList(clauses)


class NodeList(ColumnBase):
def __init__(self, nodes, glue=' ', parens=False):
self.nodes = nodes
Expand Down
12 changes: 12 additions & 0 deletions tests/regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,3 +1811,15 @@ def test_iteration_single_query(self):
list(User.select())
with self.assertQueryCount(1):
self.assertEqual(User.select().count(), 0)


class TestSumCaseSubquery(ModelTestCase):
requires = [Sample]

def test_sum_case_subquery(self):
Sample.insert_many([(i, i) for i in range(5)]).execute()

subq = Sample.select().where(Sample.counter.in_([1, 3, 5]))
case = Case(None, [(Sample.id.in_(subq), Sample.value)], 0)
q = Sample.select(fn.SUM(case))
self.assertEqual(q.scalar(), 4.0)
10 changes: 10 additions & 0 deletions tests/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1793,6 +1793,16 @@ def test_case_function(self):
'ELSE ? END '
'FROM "nn" AS "t1"'), [1, 'one', 2, 'two', '?'])

def test_case_subquery(self):
Name = Table('n', ('id', 'name',))
case = Case(None, [(Name.id.in_(Name.select(Name.id)), 1)], 0)
q = Name.select(fn.SUM(case))
self.assertSQL(q, (
'SELECT SUM('
'CASE WHEN ("t1"."id" IN (SELECT "t1"."id" FROM "n" AS "t1")) '
'THEN ? ELSE ? END) FROM "n" AS "t1"'), [1, 0])



class TestSelectFeatures(BaseTestCase):
def test_reselect(self):
Expand Down

0 comments on commit 5c41d56

Please sign in to comment.