Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hogql): support UNION ALL #14593

Merged
merged 8 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions posthog/hogql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,15 @@ class SelectQueryRef(Ref):
# all refs a select query exports
columns: Dict[str, Ref] = PydanticField(default_factory=dict)
# all from and join, tables and subqueries with aliases
tables: Dict[str, Union[BaseTableRef, "SelectQueryRef", "SelectQueryAliasRef"]] = PydanticField(
default_factory=dict
)
tables: Dict[
str, Union[BaseTableRef, "SelectUnionQueryRef", "SelectQueryRef", "SelectQueryAliasRef"]
] = PydanticField(default_factory=dict)
# all from and join subqueries without aliases
anonymous_tables: List["SelectQueryRef"] = PydanticField(default_factory=list)
anonymous_tables: List[Union["SelectQueryRef", "SelectUnionQueryRef"]] = PydanticField(default_factory=list)

def get_alias_for_table_ref(
self,
table_ref: Union[BaseTableRef, "SelectQueryRef", "SelectQueryAliasRef"],
table_ref: Union[BaseTableRef, "SelectUnionQueryRef", "SelectQueryRef", "SelectQueryAliasRef"],
) -> Optional[str]:
for key, value in self.tables.items():
if value == table_ref:
Expand All @@ -143,9 +143,25 @@ def has_child(self, name: str) -> bool:
return name in self.columns


class SelectUnionQueryRef(Ref):
    """Ref for a set of SELECT queries combined into one union.

    Every lookup is answered by the first entry of ``refs``; the remaining
    entries are never consulted by the methods below.
    """

    # One SelectQueryRef per SELECT that takes part in the union.
    refs: List[SelectQueryRef]

    def get_alias_for_table_ref(
        self,
        table_ref: Union[BaseTableRef, "SelectUnionQueryRef", SelectQueryRef, "SelectQueryAliasRef"],
    ) -> Optional[str]:
        """Return the alias the first SELECT's ref reports for ``table_ref``.

        Accepts the same ref types as ``SelectQueryRef.get_alias_for_table_ref``,
        which does the actual lookup. Returns None when the union has no refs.
        """
        if not self.refs:
            return None
        return self.refs[0].get_alias_for_table_ref(table_ref)

    def get_child(self, name: str) -> Ref:
        """Return the child ref called ``name`` from the first SELECT's ref.

        Raises IndexError when ``refs`` is empty, since there is no first
        SELECT to ask.
        """
        return self.refs[0].get_child(name)

    def has_child(self, name: str) -> bool:
        """Tell whether the first SELECT's ref has a child called ``name``.

        An empty union has no children, so the answer is False in that case.
        """
        if not self.refs:
            return False
        return self.refs[0].has_child(name)


class SelectQueryAliasRef(Ref):
name: str
ref: SelectQueryRef
ref: SelectQueryRef | SelectUnionQueryRef

def get_child(self, name: str) -> Ref:
if name == "*":
Expand All @@ -171,17 +187,17 @@ class ConstantRef(Ref):


class AsteriskRef(Ref):
    """Ref for a `*` in a select list."""

    # The table or (sub)query ref the asterisk stands for. Declared once; the
    # union-query ref is one of the accepted kinds.
    table: BaseTableRef | SelectQueryRef | SelectQueryAliasRef | SelectUnionQueryRef


class FieldTraverserRef(Ref):
    # List of name segments. NOTE(review): how the chain is followed is defined
    # outside this view; confirm before relying on it.
    chain: List[str]
    # The table or (sub)query ref the chain is attached to. Declared once; the
    # union-query ref is one of the accepted kinds.
    table: BaseTableRef | SelectQueryRef | SelectQueryAliasRef | SelectUnionQueryRef


class FieldRef(Ref):
name: str
table: Union[BaseTableRef, SelectQueryRef, SelectQueryAliasRef]
table: BaseTableRef | SelectQueryRef | SelectQueryAliasRef | SelectUnionQueryRef

def resolve_database_field(self) -> Optional[DatabaseField]:
if isinstance(self.table, BaseTableRef):
Expand Down Expand Up @@ -304,7 +320,7 @@ class Call(Expr):

class JoinExpr(Expr):
join_type: Optional[str] = None
table: Optional[Union["SelectQuery", Field]] = None
table: Optional[Union["SelectQuery", "SelectUnionQuery", Field]] = None
alias: Optional[str] = None
table_final: Optional[bool] = None
constraint: Optional[Expr] = None
Expand All @@ -328,5 +344,10 @@ class SelectQuery(Expr):
offset: Optional[Expr] = None


class SelectUnionQuery(Expr):
    """AST node holding several SELECT queries that form one union."""

    # Ref describing this union; defaults to None when the node is created
    # without one.
    ref: Optional[SelectUnionQueryRef] = None
    # The member SELECT queries. Required: there is no default.
    select_queries: List[SelectQuery]


# JoinExpr names some types as strings in its annotations (e.g. "SelectQuery",
# "SelectUnionQuery"). Hand it the real classes now that they are defined.
# NOTE(review): the JoinExpr=JoinExpr call suggests a self-referencing field
# that is not visible in this view; confirm.
JoinExpr.update_forward_refs(SelectUnionQuery=SelectUnionQuery)
JoinExpr.update_forward_refs(SelectQuery=SelectQuery)
JoinExpr.update_forward_refs(JoinExpr=JoinExpr)
22 changes: 16 additions & 6 deletions posthog/hogql/parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Literal, Optional, cast
from typing import Dict, List, Literal, Optional, cast

from antlr4 import CommonTokenStream, InputStream, ParseTreeVisitor
from antlr4.error.ErrorListener import ErrorListener
Expand Down Expand Up @@ -37,7 +37,7 @@ def parse_order_expr(

def parse_select(
statement: str, placeholders: Optional[Dict[str, ast.Expr]] = None, no_placeholders=False
) -> ast.SelectQuery:
) -> ast.SelectQuery | ast.SelectUnionQuery:
parse_tree = get_parser(statement).select()
node = HogQLParseTreeConverter().visit(parse_tree)
if placeholders:
Expand Down Expand Up @@ -67,10 +67,20 @@ def visitSelect(self, ctx: HogQLParser.SelectContext):
return self.visit(ctx.selectUnionStmt() or ctx.selectStmt())

def visitSelectUnionStmt(self, ctx: HogQLParser.SelectUnionStmtContext):
    """Convert a ``selectUnionStmt`` parse node into an AST query.

    Each ``selectStmtWithParens`` child is converted first. A child that is
    itself a union contributes its member queries directly, so the result
    never nests one SelectUnionQuery inside another.

    Returns the single SelectQuery when exactly one query remains after
    flattening, otherwise a SelectUnionQuery wrapping all of them.

    Raises Exception when a child converts to anything other than a
    SelectQuery or SelectUnionQuery.
    """
    # Convert every child before inspecting any of them.
    select_queries: List[ast.SelectQuery | ast.SelectUnionQuery] = [
        self.visit(select) for select in ctx.selectStmtWithParens()
    ]

    flattened_queries: List[ast.SelectQuery] = []
    for query in select_queries:
        if isinstance(query, ast.SelectQuery):
            flattened_queries.append(query)
        elif isinstance(query, ast.SelectUnionQuery):
            # Parenthesised unions are merged into the outer list.
            flattened_queries.extend(query.select_queries)
        else:
            raise Exception(f"Unexpected query node type {type(query).__name__}")

    # A "union" of one query is just that query.
    if len(flattened_queries) == 1:
        return flattened_queries[0]
    return ast.SelectUnionQuery(select_queries=flattened_queries)

def visitSelectStmtWithParens(self, ctx: HogQLParser.SelectStmtWithParensContext):
    """Convert whichever child is present: the plain select, else the union."""
    return self.visit(ctx.selectStmt() or ctx.selectUnionStmt())
Expand Down
17 changes: 15 additions & 2 deletions posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,20 @@ def visit(self, node: ast.AST):
self.stack.pop()
return response

def visit_select_union_query(self, node: ast.SelectUnionQuery):
    """Print every member SELECT and glue the results with ``UNION ALL``.

    The joined text is wrapped in parentheses when ``self.stack`` holds more
    than one entry, i.e. when this union is not the first node being visited.
    """
    printed_selects = [self.visit(select_query) for select_query in node.select_queries]
    union_sql = " UNION ALL ".join(printed_selects)
    is_nested = len(self.stack) > 1
    return f"({union_sql})" if is_nested else union_sql

def visit_select_query(self, node: ast.SelectQuery):
if self.dialect == "clickhouse" and not self.context.select_team_id:
raise ValueError("Full SELECT queries are disabled if context.select_team_id is not set")

# if we are the first parsed node in the tree, or a child of a SelectUnionQuery, mark us as a top level query
part_of_select_union = len(self.stack) >= 2 and isinstance(self.stack[-2], ast.SelectUnionQuery)
is_top_level_query = len(self.stack) <= 1 or (len(self.stack) == 2 and part_of_select_union)

# We will add extra clauses onto this from the joined tables
where = node.where

Expand Down Expand Up @@ -119,7 +129,7 @@ def visit_select_query(self, node: ast.SelectQuery):
]

limit = node.limit
if self.context.limit_top_select and len(self.stack) == 1:
if self.context.limit_top_select and is_top_level_query:
if limit is not None:
if isinstance(limit, ast.Constant) and isinstance(limit.value, int):
limit.value = min(limit.value, MAX_SELECT_RETURNED_ROWS)
Expand All @@ -140,7 +150,7 @@ def visit_select_query(self, node: ast.SelectQuery):
response = " ".join([clause for clause in clauses if clause])

# If we are printing a SELECT subquery (not the first AST node we are visiting), wrap it in parentheses.
if len(self.stack) > 1:
if not part_of_select_union and not is_top_level_query:
response = f"({response})"

return response
Expand Down Expand Up @@ -177,6 +187,9 @@ def visit_join_expr(self, node: ast.JoinExpr) -> JoinExprResponse:
elif isinstance(node.ref, ast.SelectQueryRef):
join_strings.append(self.visit(node.table))

elif isinstance(node.ref, ast.SelectUnionQueryRef):
join_strings.append(self.visit(node.table))

elif isinstance(node.ref, ast.SelectQueryAliasRef) and node.alias is not None:
join_strings.append(self.visit(node.table))
join_strings.append(f"AS {self._print_identifier(node.alias)}")
Expand Down
8 changes: 7 additions & 1 deletion posthog/hogql/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ def __init__(self, scope: Optional[ast.SelectQueryRef] = None):
# Each SELECT query creates a new scope. Store all of them in a list as we traverse the tree.
self.scopes: List[ast.SelectQueryRef] = [scope] if scope else []

def visit_select_union_query(self, node):
    """Resolve each SELECT of a union, then attach a combined ref to the node.

    All member queries are visited before the union ref is built from their
    ``ref`` attributes. The ref stored on the node is also the return value.
    """
    for select_query in node.select_queries:
        self.visit(select_query)

    member_refs = [select_query.ref for select_query in node.select_queries]
    node.ref = ast.SelectUnionQueryRef(refs=member_refs)
    return node.ref

def visit_select_query(self, node):
"""Visit each SELECT query or subquery."""
if node.ref is not None:
Expand Down Expand Up @@ -93,7 +99,7 @@ def visit_join_expr(self, node):
else:
raise ResolverException(f'Unknown table "{table_name}".')

elif isinstance(node.table, ast.SelectQuery):
elif isinstance(node.table, ast.SelectQuery) or isinstance(node.table, ast.SelectUnionQuery):
node.table.ref = self.visit(node.table)
if node.alias is not None:
if node.alias in scope.tables:
Expand Down
12 changes: 12 additions & 0 deletions posthog/hogql/test/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,3 +711,15 @@ def test_select_placeholders(self):
),
),
)

def test_select_union_all(self):
    """Three chained UNION ALL selects parse into one flat SelectUnionQuery."""
    parsed = parse_select("select 1 union all select 2 union all select 3")
    expected = ast.SelectUnionQuery(
        select_queries=[ast.SelectQuery(select=[ast.Constant(value=number)]) for number in (1, 2, 3)]
    )
    self.assertEqual(parsed, expected)
24 changes: 24 additions & 0 deletions posthog/hogql/test/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,3 +382,27 @@ def test_select_subquery(self):
self._select("SELECT event from (select distinct event from events group by event, timestamp) e"),
"SELECT e.event FROM (SELECT DISTINCT event FROM events WHERE equals(team_id, 42) GROUP BY event, timestamp) AS e LIMIT 65535",
)

def test_select_union_all(self):
    """Check the printed SQL for a series of UNION ALL queries."""
    # (HogQL input, expected printed output), checked in this order.
    cases = [
        (
            "SELECT event FROM events UNION ALL SELECT event FROM events WHERE 1 = 2",
            "SELECT event FROM events WHERE equals(team_id, 42) LIMIT 65535 UNION ALL SELECT event FROM events WHERE and(equals(team_id, 42), equals(1, 2)) LIMIT 65535",
        ),
        (
            "SELECT event FROM events UNION ALL SELECT event FROM events WHERE 1 = 2 UNION ALL SELECT event FROM events WHERE 1 = 2",
            "SELECT event FROM events WHERE equals(team_id, 42) LIMIT 65535 UNION ALL SELECT event FROM events WHERE and(equals(team_id, 42), equals(1, 2)) LIMIT 65535 UNION ALL SELECT event FROM events WHERE and(equals(team_id, 42), equals(1, 2)) LIMIT 65535",
        ),
        (
            # A parenthesised union in the middle prints the same as the flat form below.
            "SELECT 1 UNION ALL (SELECT 1 UNION ALL SELECT 1) UNION ALL SELECT 1",
            "SELECT 1 LIMIT 65535 UNION ALL SELECT 1 LIMIT 65535 UNION ALL SELECT 1 LIMIT 65535 UNION ALL SELECT 1 LIMIT 65535",
        ),
        (
            "SELECT 1 UNION ALL SELECT 1 UNION ALL SELECT 1 UNION ALL SELECT 1",
            "SELECT 1 LIMIT 65535 UNION ALL SELECT 1 LIMIT 65535 UNION ALL SELECT 1 LIMIT 65535 UNION ALL SELECT 1 LIMIT 65535",
        ),
        (
            # Union used as a subquery: parenthesised, limit only on the outer select.
            "SELECT 1 FROM (SELECT 1 UNION ALL SELECT 1)",
            "SELECT 1 FROM (SELECT 1 UNION ALL SELECT 1) LIMIT 65535",
        ),
    ]
    for hogql, expected_sql in cases:
        self.assertEqual(self._select(hogql), expected_sql)
20 changes: 20 additions & 0 deletions posthog/hogql/test/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,23 @@ def test_resolve_virtual_events_poe(self):
self.assertEqual(expr.where, expected.where)
self.assertEqual(expr.ref, expected.ref)
self.assertEqual(expr, expected)

def test_resolve_union_all(self):
    """Both sides of a UNION ALL get their fields resolved against events."""
    node = parse_select("select event, timestamp from events union all select event, timestamp from events")
    resolve_refs(node)

    events_table_ref = ast.TableRef(table=database.events)
    expected_select = [
        ast.Field(chain=[column], ref=ast.FieldRef(name=column, table=events_table_ref))
        for column in ("event", "timestamp")
    ]
    # The two member queries are identical, so each must resolve to the same fields.
    for position in (0, 1):
        self.assertEqual(node.select_queries[position].select, expected_select)
8 changes: 6 additions & 2 deletions posthog/hogql/transforms/asterisk.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ def visit_select_query(self, node: ast.SelectQuery):
ref = ast.FieldRef(name=key, table=asterisk.table)
columns.append(ast.Field(chain=[key], ref=ref))
node.ref.columns[key] = ref
elif isinstance(asterisk.table, ast.SelectQueryRef) or isinstance(
asterisk.table, ast.SelectQueryAliasRef
elif (
isinstance(asterisk.table, ast.SelectUnionQueryRef)
or isinstance(asterisk.table, ast.SelectQueryRef)
or isinstance(asterisk.table, ast.SelectQueryAliasRef)
):
select = asterisk.table
while isinstance(select, ast.SelectQueryAliasRef):
select = select.ref
if isinstance(select, ast.SelectUnionQueryRef):
select = select.refs[0]
if isinstance(select, ast.SelectQueryRef):
for name in select.columns.keys():
ref = ast.FieldRef(name=name, table=asterisk.table)
Expand Down
42 changes: 42 additions & 0 deletions posthog/hogql/transforms/test/test_asterisk.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,45 @@ def test_asterisk_expander_multiple_table_error(self):
self.assertEqual(
str(e.exception), "Cannot use '*' without table name when there are multiple tables in the query"
)

def test_asterisk_expander_select_union(self):
    """`select *` over a union subquery expands to the union's columns."""
    node = parse_select("select * from (select * from events union all select * from events)")
    resolve_refs(node)
    expand_asterisks(node)

    # Expected columns, in the order the expansion must produce them.
    column_names = ["uuid", "event", "properties", "timestamp", "distinct_id", "elements_chain", "created_at"]

    events_table_ref = ast.TableRef(table=database.events)
    member_ref = ast.SelectQueryRef(
        tables={"events": events_table_ref},
        anonymous_tables=[],
        aliases={},
        columns={column: ast.FieldRef(name=column, table=events_table_ref) for column in column_names},
    )
    # Both sides of the union are the same query, so the same ref appears twice.
    inner_select_ref = ast.SelectUnionQueryRef(refs=[member_ref, member_ref])

    expected_select = [
        ast.Field(chain=[column], ref=ast.FieldRef(name=column, table=inner_select_ref))
        for column in column_names
    ]
    self.assertEqual(node.select, expected_select)
11 changes: 11 additions & 0 deletions posthog/hogql/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def visit_select_query(self, node: ast.SelectQuery):
self.visit(node.limit),
self.visit(node.offset),

def visit_select_union_query(self, node: ast.SelectUnionQuery):
    """Visit each member SELECT of the union in order; returns nothing."""
    for select_query in node.select_queries:
        self.visit(select_query)

def visit_field_alias_ref(self, node: ast.FieldAliasRef):
    # Continue the traversal into the ref held by this alias ref.
    self.visit(node.ref)

Expand All @@ -94,6 +98,10 @@ def visit_select_query_ref(self, node: ast.SelectQueryRef):
for expr in node.columns.values():
self.visit(expr)

def visit_select_union_query_ref(self, node: ast.SelectUnionQueryRef):
    """Visit every member ref of the union ref in order; returns nothing."""
    for member_ref in node.refs:
        self.visit(member_ref)

def visit_table_ref(self, node: ast.TableRef):
    # No-op: nothing beneath a TableRef is visited here.
    pass

Expand Down Expand Up @@ -207,3 +215,6 @@ def visit_select_query(self, node: ast.SelectQuery):
offset=self.visit(node.offset),
distinct=node.distinct,
)

def visit_select_union_query(self, node: ast.SelectUnionQuery):
    """Build a new SelectUnionQuery from the visited member SELECTs.

    Only ``select_queries`` is passed to the new node; ``ref`` is not carried
    over from ``node``.
    """
    visited_selects = [self.visit(select_query) for select_query in node.select_queries]
    return ast.SelectUnionQuery(select_queries=visited_selects)