From 839d5059b60a3fb5efab513c44b84741d4f2675c Mon Sep 17 00:00:00 2001 From: Marius Andra Date: Wed, 8 Feb 2023 16:16:18 +0100 Subject: [PATCH] parse limit by --- posthog/hogql/printer.py | 20 +++++++++++++++----- posthog/hogql/test/test_printer.py | 19 +++++++++++++++++++ posthog/hogql/test/test_visitor.py | 6 ++++-- posthog/hogql/visitor.py | 6 ++++-- 4 files changed, 42 insertions(+), 9 deletions(-) diff --git a/posthog/hogql/printer.py b/posthog/hogql/printer.py index d07599d080e21..e2ebe83dc4d23 100644 --- a/posthog/hogql/printer.py +++ b/posthog/hogql/printer.py @@ -71,9 +71,12 @@ def print_ast( limit = node.limit if context.limit_top_select: if limit is not None: - limit = max(0, min(node.limit, MAX_SELECT_RETURNED_ROWS)) - if len(stack) == 1 and limit is None: - limit = MAX_SELECT_RETURNED_ROWS + if isinstance(limit, ast.Constant) and isinstance(limit.value, int): + limit.value = min(limit.value, MAX_SELECT_RETURNED_ROWS) + else: + limit = ast.Call(name="min2", args=[ast.Constant(value=MAX_SELECT_RETURNED_ROWS), limit]) + elif len(stack) == 1: + limit = ast.Constant(value=MAX_SELECT_RETURNED_ROWS) clauses = [ f"SELECT {'DISTINCT ' if node.distinct else ''}{', '.join(columns)}", @@ -83,9 +86,16 @@ def print_ast( "HAVING " + having if having else None, "PREWHERE " + prewhere if prewhere else None, f"ORDER BY {', '.join(order_by)}" if order_by and len(order_by) > 0 else None, - f"LIMIT {limit}" if limit is not None else None, - f"OFFSET {node.offset}" if node.offset is not None else None, ] + if limit is not None: + clauses.append(f"LIMIT {print_ast(limit, stack, context, dialect)}"), + if node.offset is not None: + clauses.append(f"OFFSET {print_ast(node.offset, stack, context, dialect)}") + if node.limit_by is not None: + clauses.append(f"BY {', '.join([print_ast(expr, stack, context, dialect) for expr in node.limit_by])}") + if node.limit_with_ties: + clauses.append("WITH TIES") + response = " ".join([clause for clause in clauses if clause]) if len(stack) > 1: response = f"({response})" diff --git a/posthog/hogql/test/test_printer.py b/posthog/hogql/test/test_printer.py index ded3198704c69..e1e67e8106497 100644 --- a/posthog/hogql/test/test_printer.py +++ b/posthog/hogql/test/test_printer.py @@ -405,6 +405,15 @@ def test_select_limit(self): self._select("select event from events limit 10000000"), "SELECT event FROM events WHERE equals(team_id, 42) LIMIT 65535", ) + self.assertEqual( + self._select("select event from events limit (select 1000000000)"), + "SELECT event FROM events WHERE equals(team_id, 42) LIMIT min2(65535, (SELECT 1000000000))", + ) + + self.assertEqual( + self._select("select event from events limit (select 1000000000) with ties"), + "SELECT event FROM events WHERE equals(team_id, 42) LIMIT min2(65535, (SELECT 1000000000)) WITH TIES", + ) def test_select_offset(self): self.assertEqual( @@ -415,6 +424,16 @@ def test_select_offset(self): self._select("select event from events limit 10 offset 0"), "SELECT event FROM events WHERE equals(team_id, 42) LIMIT 10 OFFSET 0", ) + self.assertEqual( + self._select("select event from events limit 10 offset 0 with ties"), + "SELECT event FROM events WHERE equals(team_id, 42) LIMIT 10 OFFSET 0 WITH TIES", + ) + + def test_select_limit_by(self): + self.assertEqual( + self._select("select event from events limit 10 offset 0 by 1,event"), + "SELECT event FROM events WHERE equals(team_id, 42) LIMIT 10 OFFSET 0 BY 1, event", + ) def test_select_group_by(self): self.assertEqual( diff --git a/posthog/hogql/test/test_visitor.py b/posthog/hogql/test/test_visitor.py index f60812f53aab8..738b4ac3aed12 100644 --- a/posthog/hogql/test/test_visitor.py +++ b/posthog/hogql/test/test_visitor.py @@ -85,8 +85,10 @@ def test_everything_visitor(self): having=ast.Constant(value=True), group_by=[ast.Constant(value=True)], order_by=[ast.OrderExpr(expr=ast.Constant(value=True), order="DESC")], - limit=1, - offset=0, + limit=ast.Constant(value=1), + limit_by=[ast.Constant(value=True)], + limit_with_ties=True, + offset=ast.Or(exprs=[ast.Constant(value=1)]), distinct=True, ), ] diff --git a/posthog/hogql/visitor.py b/posthog/hogql/visitor.py index 1bca4d05a78f4..66365f368bbc1 100644 --- a/posthog/hogql/visitor.py +++ b/posthog/hogql/visitor.py @@ -81,7 +81,9 @@ def visit_select_query(self, node: ast.SelectQuery): having=self.visit(node.having), group_by=[self.visit(expr) for expr in node.group_by] if node.group_by else None, order_by=[self.visit(expr) for expr in node.order_by] if node.order_by else None, - limit=node.limit, - offset=node.offset, + limit_by=[self.visit(expr) for expr in node.limit_by] if node.limit_by else None, + limit=self.visit(node.limit), + limit_with_ties=node.limit_with_ties, + offset=self.visit(node.offset), distinct=node.distinct, )