Skip to content

Commit

Permalink
feat(data-exploration): countIf and other aggregations (#13854)
Browse files Browse the repository at this point in the history
* add new aggregations

* Revert "add new aggregations"

This reverts commit 8e6c7c9.

* Revert "Revert "add new aggregations""

This reverts commit 1385f6a.
  • Loading branch information
mariusandra authored Jan 20, 2023
1 parent 49c29c4 commit 31ac050
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 39 deletions.
76 changes: 41 additions & 35 deletions posthog/hogql/hogql.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,22 @@
"trunc": "trunc",
}
# Permitted HogQL aggregations
HOGQL_AGGREGATIONS = [
"count",
"min",
"max",
"sum",
"avg",
"any",
]
HOGQL_AGGREGATIONS = {
"count": 0,
"countIf": 1,
"countDistinct": 1,
"countDistinctIf": 2,
"min": 1,
"minIf": 2,
"max": 1,
"maxIf": 2,
"sum": 1,
"sumIf": 2,
"avg": 1,
"avgIf": 2,
"any": 1,
"anyIf": 2,
}
# Keywords passed to ClickHouse without transformation
KEYWORDS = ["true", "false", "null"]

Expand Down Expand Up @@ -259,36 +267,34 @@ def translate_ast(node: ast.AST, stack: List[ast.AST], context: HogQLContext) ->
call_name = node.func.id
if call_name in HOGQL_AGGREGATIONS:
context.found_aggregation = True
required_arg_count = HOGQL_AGGREGATIONS[call_name]

if call_name == "count" and len(node.args) == 0:
response = "count(*)"
else:
if call_name == "count" and len(node.args) != 1:
raise ValueError(f"Aggregation 'count' expects one or zero arguments.")
elif len(node.args) != 1:
raise ValueError(f"Aggregation '{call_name}' expects just one argument.")
if required_arg_count != len(node.args):
raise ValueError(
f"Aggregation '{call_name}' requires {required_arg_count} argument{'s' if required_arg_count != 1 else ''}, found {len(node.args)}"
)

# check that we're not running inside another aggregate
for stack_node in stack:
if (
stack_node != node
and isinstance(stack_node, ast.Call)
and isinstance(stack_node.func, ast.Name)
and stack_node.func.id in HOGQL_AGGREGATIONS
):
raise ValueError(
f"Aggregation '{call_name}' cannot be nested inside another aggregation '{stack_node.func.id}'."
)
# check that we're not running inside another aggregate
for stack_node in stack:
if (
stack_node != node
and isinstance(stack_node, ast.Call)
and isinstance(stack_node.func, ast.Name)
and stack_node.func.id in HOGQL_AGGREGATIONS
):
raise ValueError(
f"Aggregation '{call_name}' cannot be nested inside another aggregation '{stack_node.func.id}'."
)

# check that we're running an aggregate on a property
properties_before = len(context.attribute_list)
if call_name == "count":
response = f"{call_name}(distinct {translate_ast(node.args[0], stack, context)})"
else:
response = f"{call_name}({translate_ast(node.args[0], stack, context)})"
properties_after = len(context.attribute_list)
if properties_after == properties_before:
raise ValueError(f"{call_name}(...) must be called on fields or properties, not literals.")
translated_args = ", ".join([translate_ast(arg, stack, context) for arg in node.args])
if call_name == "count":
response = "count(*)"
elif call_name == "countDistinct":
response = f"count(distinct {translated_args})"
elif call_name == "countDistinctIf":
response = f"countIf(distinct {translated_args})"
else:
response = f"{call_name}({translated_args})"

elif node.func.id in CLICKHOUSE_FUNCTIONS:
response = f"{CLICKHOUSE_FUNCTIONS[node.func.id]}({', '.join([translate_ast(arg, stack, context) for arg in node.args])})"
Expand Down
11 changes: 7 additions & 4 deletions posthog/hogql/test/test_hogql.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def test_hogql_materialized_fields_and_properties(self):

def test_hogql_methods(self):
self.assertEqual(self._translate("count()"), "count(*)")
self.assertEqual(self._translate("count(event)"), "count(distinct event)")
self.assertEqual(self._translate("countDistinct(event)"), "count(distinct event)")
self.assertEqual(self._translate("countDistinctIf(event, 1 == 2)"), "countIf(distinct event, equals(1, 2))")
self.assertEqual(self._translate("sumIf(1, 1 == 2)"), "sumIf(1, equals(1, 2))")

def test_hogql_functions(self):
context = HogQLContext(values=None) # inline values
Expand All @@ -77,8 +79,10 @@ def test_hogql_expr_parse_errors(self):
self._assert_value_error("())", "SyntaxError: unmatched ')'")
self._assert_value_error("this makes little sense", "SyntaxError: invalid syntax")
self._assert_value_error("avg(bla)", "Unknown event field 'bla'")
self._assert_value_error("count(2,4)", "Aggregation 'count' expects one or zero arguments.")
self._assert_value_error("avg(2,1)", "Aggregation 'avg' expects just one argument.")
self._assert_value_error("count(2)", "Aggregation 'count' requires 0 arguments, found 1")
self._assert_value_error("count(2,4)", "Aggregation 'count' requires 0 arguments, found 2")
self._assert_value_error("countIf()", "Aggregation 'countIf' requires 1 argument, found 0")
self._assert_value_error("countIf(2,4)", "Aggregation 'countIf' requires 1 argument, found 2")
self._assert_value_error(
"bla.avg(bla)", "Can only call simple functions like 'avg(properties.bla)' or 'count()'"
)
Expand All @@ -89,7 +93,6 @@ def test_hogql_expr_parse_errors(self):
self._assert_value_error("['properties']['value']['bla']", "Unknown node in field access chain:")
self._assert_value_error("chipotle", "Unknown event field 'chipotle'")
self._assert_value_error("person.chipotle", "Unknown person field 'chipotle'")
self._assert_value_error("avg(2)", "avg(...) must be called on fields or properties, not literals.")
self._assert_value_error(
"avg(avg(properties.bla))", "Aggregation 'avg' cannot be nested inside another aggregation 'avg'."
)
Expand Down

0 comments on commit 31ac050

Please sign in to comment.