Skip to content

Commit

Permalink
feat(hogql): support UNION ALL (#14593)
Browse files Browse the repository at this point in the history
* feat(hogql): union all

* test and stack pop for print visit

* make union its own thing

* python 3.10

* better type and test for parser

* fixes
  • Loading branch information
mariusandra authored Mar 10, 2023
1 parent 2e4c4d4 commit 33ad337
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 22 deletions.
43 changes: 32 additions & 11 deletions posthog/hogql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,15 @@ class SelectQueryRef(Ref):
# all refs a select query exports
columns: Dict[str, Ref] = PydanticField(default_factory=dict)
# all from and join, tables and subqueries with aliases
tables: Dict[str, Union[BaseTableRef, "SelectQueryRef", "SelectQueryAliasRef"]] = PydanticField(
default_factory=dict
)
tables: Dict[
str, Union[BaseTableRef, "SelectUnionQueryRef", "SelectQueryRef", "SelectQueryAliasRef"]
] = PydanticField(default_factory=dict)
# all from and join subqueries without aliases
anonymous_tables: List["SelectQueryRef"] = PydanticField(default_factory=list)
anonymous_tables: List[Union["SelectQueryRef", "SelectUnionQueryRef"]] = PydanticField(default_factory=list)

def get_alias_for_table_ref(
self,
table_ref: Union[BaseTableRef, "SelectQueryRef", "SelectQueryAliasRef"],
table_ref: Union[BaseTableRef, "SelectUnionQueryRef", "SelectQueryRef", "SelectQueryAliasRef"],
) -> Optional[str]:
for key, value in self.tables.items():
if value == table_ref:
Expand All @@ -143,9 +143,25 @@ def has_child(self, name: str) -> bool:
return name in self.columns


class SelectUnionQueryRef(Ref):
    """Ref for a set of SELECT queries combined into one union.

    Every lookup is delegated to the first member ref (``refs[0]``), so each
    method below raises ``IndexError`` when ``refs`` is empty.
    """

    # Refs of the member SELECT queries that make up the union.
    refs: List[SelectQueryRef]

    def get_alias_for_table_ref(
        self,
        table_ref: Union[BaseTableRef, "SelectUnionQueryRef", SelectQueryRef, "SelectQueryAliasRef"],
    ) -> Optional[str]:
        """Return the alias the first member query reports for ``table_ref``.

        The accepted ref types mirror ``SelectQueryRef.get_alias_for_table_ref``,
        which performs the actual lookup.
        """
        return self.refs[0].get_alias_for_table_ref(table_ref)

    def get_child(self, name: str) -> Ref:
        """Return the child ref called ``name`` from the first member query."""
        return self.refs[0].get_child(name)

    def has_child(self, name: str) -> bool:
        """Return whether the first member query has a child called ``name``."""
        return self.refs[0].has_child(name)


class SelectQueryAliasRef(Ref):
name: str
ref: SelectQueryRef
ref: SelectQueryRef | SelectUnionQueryRef

def get_child(self, name: str) -> Ref:
if name == "*":
Expand All @@ -171,17 +187,17 @@ class ConstantRef(Ref):


class AsteriskRef(Ref):
    """Ref of an asterisk (``*``) expression; ``table`` is the table-like ref it belongs to."""

    # Declared exactly once: a second declaration of the same field would
    # silently override the first, leaving a misleading, narrower type behind.
    table: BaseTableRef | SelectQueryRef | SelectQueryAliasRef | SelectUnionQueryRef


class FieldTraverserRef(Ref):
    """Ref that carries a ``chain`` of names together with a table-like ref.

    NOTE(review): presumably the chain is followed starting from ``table`` —
    confirm against the resolver.
    """

    chain: List[str]
    # Declared exactly once: a second declaration of the same field would
    # silently override the first, leaving a misleading, narrower type behind.
    table: BaseTableRef | SelectQueryRef | SelectQueryAliasRef | SelectUnionQueryRef


class FieldRef(Ref):
name: str
table: Union[BaseTableRef, SelectQueryRef, SelectQueryAliasRef]
table: BaseTableRef | SelectQueryRef | SelectQueryAliasRef | SelectUnionQueryRef

def resolve_database_field(self) -> Optional[DatabaseField]:
if isinstance(self.table, BaseTableRef):
Expand Down Expand Up @@ -304,7 +320,7 @@ class Call(Expr):

class JoinExpr(Expr):
join_type: Optional[str] = None
table: Optional[Union["SelectQuery", Field]] = None
table: Optional[Union["SelectQuery", "SelectUnionQuery", Field]] = None
alias: Optional[str] = None
table_final: Optional[bool] = None
constraint: Optional[Expr] = None
Expand All @@ -328,5 +344,10 @@ class SelectQuery(Expr):
offset: Optional[Expr] = None


class SelectUnionQuery(Expr):
    """Expression node holding several SELECT queries that are combined with UNION ALL."""

    # Ref describing the union as a whole; stays None until one is assigned.
    ref: Optional[SelectUnionQueryRef] = None
    # The member SELECT queries of the union.
    select_queries: List[SelectQuery]


# JoinExpr annotates its fields with string forward references (for example
# "SelectQuery" and "SelectUnionQuery"); resolve them in a single pass now
# that all of those classes exist.
JoinExpr.update_forward_refs(
    SelectUnionQuery=SelectUnionQuery,
    SelectQuery=SelectQuery,
    JoinExpr=JoinExpr,
)
22 changes: 16 additions & 6 deletions posthog/hogql/parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Literal, Optional, cast
from typing import Dict, List, Literal, Optional, cast

from antlr4 import CommonTokenStream, InputStream, ParseTreeVisitor
from antlr4.error.ErrorListener import ErrorListener
Expand Down Expand Up @@ -37,7 +37,7 @@ def parse_order_expr(

def parse_select(
statement: str, placeholders: Optional[Dict[str, ast.Expr]] = None, no_placeholders=False
) -> ast.SelectQuery:
) -> ast.SelectQuery | ast.SelectUnionQuery:
parse_tree = get_parser(statement).select()
node = HogQLParseTreeConverter().visit(parse_tree)
if placeholders:
Expand Down Expand Up @@ -67,10 +67,20 @@ def visitSelect(self, ctx: HogQLParser.SelectContext):
return self.visit(ctx.selectUnionStmt() or ctx.selectStmt())

def visitSelectUnionStmt(self, ctx: HogQLParser.SelectUnionStmtContext):
    """Convert a ``selectUnionStmt`` parse node into an AST node.

    Every parenthesised select is visited first. Nested ``SelectUnionQuery``
    results are then flattened so the returned union only contains plain
    ``SelectQuery`` nodes.

    Returns:
        The single ``SelectQuery`` when only one query remains after
        flattening, otherwise a ``SelectUnionQuery`` wrapping all of them.

    Raises:
        Exception: if visiting a child yields something that is neither a
            ``SelectQuery`` nor a ``SelectUnionQuery``.
    """
    # No early "Unsupported: UNION ALL" bail-out here: it would return (or
    # raise) before the flattening below could ever run.
    select_queries: List[ast.SelectQuery | ast.SelectUnionQuery] = [
        self.visit(select) for select in ctx.selectStmtWithParens()
    ]

    flattened_queries: List[ast.SelectQuery] = []
    for query in select_queries:
        if isinstance(query, ast.SelectQuery):
            flattened_queries.append(query)
        elif isinstance(query, ast.SelectUnionQuery):
            # "a UNION ALL (b UNION ALL c)" becomes the flat list [a, b, c].
            flattened_queries.extend(query.select_queries)
        else:
            raise Exception(f"Unexpected query node type {type(query).__name__}")

    # A "union" of exactly one query is just that query.
    if len(flattened_queries) == 1:
        return flattened_queries[0]
    return ast.SelectUnionQuery(select_queries=flattened_queries)

def visitSelectStmtWithParens(self, ctx: HogQLParser.SelectStmtWithParensContext):
    """Visit the wrapped statement: the plain select if present, otherwise the select union."""
    inner_statement = ctx.selectStmt() or ctx.selectUnionStmt()
    return self.visit(inner_statement)
Expand Down
17 changes: 15 additions & 2 deletions posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,20 @@ def visit(self, node: ast.AST):
self.stack.pop()
return response

def visit_select_union_query(self, node: ast.SelectUnionQuery):
    """Print every member query and join the results with ``UNION ALL``.

    The joined text is wrapped in parentheses when more than one node is on
    the visit stack, and returned as-is otherwise.
    """
    printed_queries = [self.visit(select_query) for select_query in node.select_queries]
    union_sql = " UNION ALL ".join(printed_queries)
    is_nested = len(self.stack) > 1
    return f"({union_sql})" if is_nested else union_sql

def visit_select_query(self, node: ast.SelectQuery):
if self.dialect == "clickhouse" and not self.context.select_team_id:
raise ValueError("Full SELECT queries are disabled if context.select_team_id is not set")

# if we are the first parsed node in the tree, or a child of a SelectUnionQuery, mark us as a top level query
part_of_select_union = len(self.stack) >= 2 and isinstance(self.stack[-2], ast.SelectUnionQuery)
is_top_level_query = len(self.stack) <= 1 or (len(self.stack) == 2 and part_of_select_union)

# We will add extra clauses onto this from the joined tables
where = node.where

Expand Down Expand Up @@ -119,7 +129,7 @@ def visit_select_query(self, node: ast.SelectQuery):
]

limit = node.limit
if self.context.limit_top_select and len(self.stack) == 1:
if self.context.limit_top_select and is_top_level_query:
if limit is not None:
if isinstance(limit, ast.Constant) and isinstance(limit.value, int):
limit.value = min(limit.value, MAX_SELECT_RETURNED_ROWS)
Expand All @@ -140,7 +150,7 @@ def visit_select_query(self, node: ast.SelectQuery):
response = " ".join([clause for clause in clauses if clause])

# If we are printing a SELECT subquery (neither the top level query nor a member of a SELECT UNION), wrap it in parentheses.
if len(self.stack) > 1:
if not part_of_select_union and not is_top_level_query:
response = f"({response})"

return response
Expand Down Expand Up @@ -177,6 +187,9 @@ def visit_join_expr(self, node: ast.JoinExpr) -> JoinExprResponse:
elif isinstance(node.ref, ast.SelectQueryRef):
join_strings.append(self.visit(node.table))

elif isinstance(node.ref, ast.SelectUnionQueryRef):
join_strings.append(self.visit(node.table))

elif isinstance(node.ref, ast.SelectQueryAliasRef) and node.alias is not None:
join_strings.append(self.visit(node.table))
join_strings.append(f"AS {self._print_identifier(node.alias)}")
Expand Down
8 changes: 7 additions & 1 deletion posthog/hogql/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ def __init__(self, scope: Optional[ast.SelectQueryRef] = None):
# Each SELECT query creates a new scope. Store all of them in a list as we traverse the tree.
self.scopes: List[ast.SelectQueryRef] = [scope] if scope else []

def visit_select_union_query(self, node):
    """Visit each member query, then give the union a ref built from the members' refs.

    Returns the ``SelectUnionQueryRef`` that was stored on ``node.ref``.
    """
    for select_query in node.select_queries:
        self.visit(select_query)

    # Collected only after all members were visited, exactly like the visits above.
    member_refs = [select_query.ref for select_query in node.select_queries]
    node.ref = ast.SelectUnionQueryRef(refs=member_refs)
    return node.ref

def visit_select_query(self, node):
"""Visit each SELECT query or subquery."""
if node.ref is not None:
Expand Down Expand Up @@ -93,7 +99,7 @@ def visit_join_expr(self, node):
else:
raise ResolverException(f'Unknown table "{table_name}".')

elif isinstance(node.table, ast.SelectQuery):
elif isinstance(node.table, ast.SelectQuery) or isinstance(node.table, ast.SelectUnionQuery):
node.table.ref = self.visit(node.table)
if node.alias is not None:
if node.alias in scope.tables:
Expand Down
12 changes: 12 additions & 0 deletions posthog/hogql/test/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,3 +711,15 @@ def test_select_placeholders(self):
),
),
)

def test_select_union_all(self):
    """Three selects chained with ``union all`` parse into one flat SelectUnionQuery."""
    parsed = parse_select("select 1 union all select 2 union all select 3")
    expected = ast.SelectUnionQuery(
        select_queries=[ast.SelectQuery(select=[ast.Constant(value=value)]) for value in (1, 2, 3)]
    )
    self.assertEqual(parsed, expected)
24 changes: 24 additions & 0 deletions posthog/hogql/test/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,3 +382,27 @@ def test_select_subquery(self):
self._select("SELECT event from (select distinct event from events group by event, timestamp) e"),
"SELECT e.event FROM (SELECT DISTINCT event FROM events WHERE equals(team_id, 42) GROUP BY event, timestamp) AS e LIMIT 65535",
)

def test_select_union_all(self):
    """Check the printed SQL for plain, chained, parenthesised and nested UNION ALL queries."""
    cases = [
        (
            "SELECT event FROM events UNION ALL SELECT event FROM events WHERE 1 = 2",
            "SELECT event FROM events WHERE equals(team_id, 42) LIMIT 65535 UNION ALL SELECT event FROM events WHERE and(equals(team_id, 42), equals(1, 2)) LIMIT 65535",
        ),
        (
            "SELECT event FROM events UNION ALL SELECT event FROM events WHERE 1 = 2 UNION ALL SELECT event FROM events WHERE 1 = 2",
            "SELECT event FROM events WHERE equals(team_id, 42) LIMIT 65535 UNION ALL SELECT event FROM events WHERE and(equals(team_id, 42), equals(1, 2)) LIMIT 65535 UNION ALL SELECT event FROM events WHERE and(equals(team_id, 42), equals(1, 2)) LIMIT 65535",
        ),
        (
            # A parenthesised union inside a union prints as one flat union.
            "SELECT 1 UNION ALL (SELECT 1 UNION ALL SELECT 1) UNION ALL SELECT 1",
            "SELECT 1 LIMIT 65535 UNION ALL SELECT 1 LIMIT 65535 UNION ALL SELECT 1 LIMIT 65535 UNION ALL SELECT 1 LIMIT 65535",
        ),
        (
            "SELECT 1 UNION ALL SELECT 1 UNION ALL SELECT 1 UNION ALL SELECT 1",
            "SELECT 1 LIMIT 65535 UNION ALL SELECT 1 LIMIT 65535 UNION ALL SELECT 1 LIMIT 65535 UNION ALL SELECT 1 LIMIT 65535",
        ),
        (
            # A union used as a subquery is parenthesised; only the outer query gets the limit.
            "SELECT 1 FROM (SELECT 1 UNION ALL SELECT 1)",
            "SELECT 1 FROM (SELECT 1 UNION ALL SELECT 1) LIMIT 65535",
        ),
    ]
    for hogql, expected_sql in cases:
        self.assertEqual(self._select(hogql), expected_sql)
20 changes: 20 additions & 0 deletions posthog/hogql/test/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,23 @@ def test_resolve_virtual_events_poe(self):
self.assertEqual(expr.where, expected.where)
self.assertEqual(expr.ref, expected.ref)
self.assertEqual(expr, expected)

def test_resolve_union_all(self):
    """Both sides of a UNION ALL get their selected fields resolved against the events table."""
    node = parse_select("select event, timestamp from events union all select event, timestamp from events")
    resolve_refs(node)

    events_table_ref = ast.TableRef(table=database.events)
    expected_select = [
        ast.Field(chain=[name], ref=ast.FieldRef(name=name, table=events_table_ref))
        for name in ("event", "timestamp")
    ]
    # The two member queries are identical, so both must resolve to the same fields.
    for index in (0, 1):
        self.assertEqual(node.select_queries[index].select, expected_select)
8 changes: 6 additions & 2 deletions posthog/hogql/transforms/asterisk.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ def visit_select_query(self, node: ast.SelectQuery):
ref = ast.FieldRef(name=key, table=asterisk.table)
columns.append(ast.Field(chain=[key], ref=ref))
node.ref.columns[key] = ref
elif isinstance(asterisk.table, ast.SelectQueryRef) or isinstance(
asterisk.table, ast.SelectQueryAliasRef
elif (
isinstance(asterisk.table, ast.SelectUnionQueryRef)
or isinstance(asterisk.table, ast.SelectQueryRef)
or isinstance(asterisk.table, ast.SelectQueryAliasRef)
):
select = asterisk.table
while isinstance(select, ast.SelectQueryAliasRef):
select = select.ref
if isinstance(select, ast.SelectUnionQueryRef):
select = select.refs[0]
if isinstance(select, ast.SelectQueryRef):
for name in select.columns.keys():
ref = ast.FieldRef(name=name, table=asterisk.table)
Expand Down
42 changes: 42 additions & 0 deletions posthog/hogql/transforms/test/test_asterisk.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,45 @@ def test_asterisk_expander_multiple_table_error(self):
self.assertEqual(
str(e.exception), "Cannot use '*' without table name when there are multiple tables in the query"
)

def test_asterisk_expander_select_union(self):
    """``select *`` over a UNION ALL subquery expands to the columns of the union's member queries."""
    node = parse_select("select * from (select * from events union all select * from events)")
    resolve_refs(node)
    expand_asterisks(node)

    column_names = ["uuid", "event", "properties", "timestamp", "distinct_id", "elements_chain", "created_at"]

    events_table_ref = ast.TableRef(table=database.events)
    member_ref = ast.SelectQueryRef(
        tables={"events": events_table_ref},
        anonymous_tables=[],
        aliases={},
        columns={name: ast.FieldRef(name=name, table=events_table_ref) for name in column_names},
    )
    # Both sides of the union are the same query, hence the same ref twice.
    inner_select_ref = ast.SelectUnionQueryRef(refs=[member_ref] * 2)

    expected_select = [
        ast.Field(chain=[name], ref=ast.FieldRef(name=name, table=inner_select_ref)) for name in column_names
    ]
    self.assertEqual(node.select, expected_select)
11 changes: 11 additions & 0 deletions posthog/hogql/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def visit_select_query(self, node: ast.SelectQuery):
self.visit(node.limit),
self.visit(node.offset),

def visit_select_union_query(self, node: ast.SelectUnionQuery):
    """Traverse into every member query of the union, in order."""
    for select_query in node.select_queries:
        self.visit(select_query)

def visit_field_alias_ref(self, node: ast.FieldAliasRef):
    """Traverse into the ref stored on the field alias ref."""
    aliased_ref = node.ref
    self.visit(aliased_ref)

Expand All @@ -94,6 +98,10 @@ def visit_select_query_ref(self, node: ast.SelectQueryRef):
for expr in node.columns.values():
self.visit(expr)

def visit_select_union_query_ref(self, node: ast.SelectUnionQueryRef):
    """Traverse into every member ref of the union ref, in order."""
    for member_ref in node.refs:
        self.visit(member_ref)

def visit_table_ref(self, node: ast.TableRef):
    """Do nothing for table refs; the traversal stops here."""
    return None

Expand Down Expand Up @@ -207,3 +215,6 @@ def visit_select_query(self, node: ast.SelectQuery):
offset=self.visit(node.offset),
distinct=node.distinct,
)

def visit_select_union_query(self, node: ast.SelectUnionQuery):
    """Build a new SelectUnionQuery from the visited result of each member query."""
    visited_queries = [self.visit(select_query) for select_query in node.select_queries]
    return ast.SelectUnionQuery(select_queries=visited_queries)

0 comments on commit 33ad337

Please sign in to comment.