diff --git a/peewee.py b/peewee.py index 81a92d5da..c4bfafacb 100644 --- a/peewee.py +++ b/peewee.py @@ -1827,6 +1827,16 @@ def __sql__(self, ctx): return ctx.literal(self.window._alias or 'w') +class _InFunction(Node): + def __init__(self, node, in_function=True): + self.node = node + self.in_function = in_function + + def __sql__(self, ctx): + with ctx(in_function=self.in_function): + return ctx.sql(self.node) + + class Case(ColumnBase): def __init__(self, predicate, expression_tuples, default=None): self.predicate = predicate @@ -1838,9 +1848,10 @@ def __sql__(self, ctx): if self.predicate is not None: clauses.append(self.predicate) for expr, value in self.expression_tuples: - clauses.extend((SQL('WHEN'), expr, SQL('THEN'), value)) + clauses.extend((SQL('WHEN'), expr, + SQL('THEN'), _InFunction(value))) if self.default is not None: - clauses.extend((SQL('ELSE'), self.default)) + clauses.extend((SQL('ELSE'), _InFunction(self.default))) clauses.append(SQL('END')) with ctx(in_function=False): return ctx.sql(NodeList(clauses)) diff --git a/tests/sql.py b/tests/sql.py index 6b34bcf53..3251604da 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -1802,6 +1802,20 @@ def test_case_subquery(self): 'CASE WHEN ("t1"."id" IN (SELECT "t1"."id" FROM "n" AS "t1")) ' 'THEN ? ELSE ? END) FROM "n" AS "t1"'), [1, 0]) + case = Case(None, [ + (Name.id < 5, Name.select(fn.SUM(Name.id))), + (Name.id > 5, Name.select(fn.COUNT(Name.name)).distinct())], + Name.select(fn.MAX(Name.id))) + q = Name.select(Name.name, case.alias('magic')) + self.assertSQL(q, ( + 'SELECT "t1"."name", CASE ' + 'WHEN ("t1"."id" < ?) ' + 'THEN (SELECT SUM("t1"."id") FROM "n" AS "t1") ' + 'WHEN ("t1"."id" > ?) ' + 'THEN (SELECT DISTINCT COUNT("t1"."name") FROM "n" AS "t1") ' + 'ELSE (SELECT MAX("t1"."id") FROM "n" AS "t1") END AS "magic" ' + 'FROM "n" AS "t1"'), [5, 5]) + class TestSelectFeatures(BaseTestCase):