From aab03ed32c83542c8ff2685acce91e8e3000a6f1 Mon Sep 17 00:00:00 2001 From: Marius Andra Date: Fri, 17 Mar 2023 10:51:15 +0100 Subject: [PATCH] feat(hogql): automatic event and person property types (#14795) --- posthog/hogql/placeholders.py | 1 + posthog/hogql/printer.py | 2 + posthog/hogql/test/test_visitor.py | 1 + posthog/hogql/transforms/property_types.py | 88 +++++++++++++++++++ .../transforms/test/test_property_types.py | 87 ++++++++++++++++++ posthog/hogql/visitor.py | 38 ++++++-- 6 files changed, 208 insertions(+), 9 deletions(-) create mode 100644 posthog/hogql/transforms/property_types.py create mode 100644 posthog/hogql/transforms/test/test_property_types.py diff --git a/posthog/hogql/placeholders.py b/posthog/hogql/placeholders.py index 6d9eb5b017c4a..1b347b1af35c1 100644 --- a/posthog/hogql/placeholders.py +++ b/posthog/hogql/placeholders.py @@ -10,6 +10,7 @@ def replace_placeholders(node: ast.Expr, placeholders: Dict[str, ast.Expr]) -> a class ReplacePlaceholders(CloningVisitor): def __init__(self, placeholders: Dict[str, ast.Expr]): + super().__init__() self.placeholders = placeholders def visit_placeholder(self, node): diff --git a/posthog/hogql/printer.py b/posthog/hogql/printer.py index 6c2815d2e4993..fe37f834aaa23 100644 --- a/posthog/hogql/printer.py +++ b/posthog/hogql/printer.py @@ -9,6 +9,7 @@ from posthog.hogql.print_string import print_clickhouse_identifier, print_hogql_identifier from posthog.hogql.resolver import ResolverException, lookup_field_by_name, resolve_refs from posthog.hogql.transforms import expand_asterisks, resolve_lazy_tables +from posthog.hogql.transforms.property_types import resolve_property_types from posthog.hogql.visitor import Visitor from posthog.models.property import PropertyName, TableColumn @@ -40,6 +41,7 @@ def print_ast( # modify the cloned tree as needed if dialect == "clickhouse": + node = resolve_property_types(node, context) expand_asterisks(node) resolve_lazy_tables(node, stack, context) diff --git a/posthog/hogql/test/test_visitor.py b/posthog/hogql/test/test_visitor.py index 476aac693a8fb..87aebcc8b466b 100644 --- a/posthog/hogql/test/test_visitor.py +++ b/posthog/hogql/test/test_visitor.py @@ -8,6 +8,7 @@ class TestVisitor(BaseTest): def test_visitor_pattern(self): class ConstantVisitor(CloningVisitor): def __init__(self): + super().__init__() self.constants = [] self.fields = [] self.operations = [] diff --git a/posthog/hogql/transforms/property_types.py b/posthog/hogql/transforms/property_types.py new file mode 100644 index 0000000000000..9edcc9e3b3b2c --- /dev/null +++ b/posthog/hogql/transforms/property_types.py @@ -0,0 +1,88 @@ +from typing import Dict, Set + +from posthog.hogql import ast +from posthog.hogql.context import HogQLContext +from posthog.hogql.parser import parse_expr +from posthog.hogql.visitor import CloningVisitor, TraversingVisitor + + +def resolve_property_types(node: ast.Expr, context: HogQLContext = None) -> ast.Expr: + from posthog.models import PropertyDefinition + + # find all properties + property_finder = PropertyFinder() + property_finder.visit(node) + + # fetch them + event_property_values = ( + PropertyDefinition.objects.filter( + name__in=property_finder.event_properties, + team_id=context.team_id, + type__in=[None, PropertyDefinition.Type.EVENT], + ).values_list("name", "property_type") + if property_finder.event_properties + else [] + ) + event_properties = {name: property_type for name, property_type in event_property_values if property_type} + + person_property_values = ( + PropertyDefinition.objects.filter( + name__in=property_finder.person_properties, + team_id=context.team_id, + type=PropertyDefinition.Type.PERSON, + ).values_list("name", "property_type") + if property_finder.person_properties + else [] + ) + person_properties = {name: property_type for name, property_type in person_property_values if property_type} + + # swap them out + if len(event_properties) == 0 and len(person_properties) == 0: + return node + property_swapper = PropertySwapper(event_properties=event_properties, person_properties=person_properties) + return property_swapper.visit(node) + + +class PropertyFinder(TraversingVisitor): + def __init__(self): + super().__init__() + self.person_properties: Set[str] = set() + self.event_properties: Set[str] = set() + + def visit_property_ref(self, node: ast.PropertyRef): + if node.parent.name == "properties": + if isinstance(node.parent.table, ast.BaseTableRef): + table = node.parent.table.resolve_database_table().hogql_table() + if table == "persons": + self.person_properties.add(node.name) + if table == "events": + self.event_properties.add(node.name) + + +class PropertySwapper(CloningVisitor): + def __init__(self, event_properties: Dict[str, str], person_properties: Dict[str, str]): + super().__init__(clear_refs=False) + self.event_properties = event_properties + self.person_properties = person_properties + + def visit_field(self, node: ast.Field): + ref = node.ref + if isinstance(ref, ast.PropertyRef) and ref.parent.name == "properties": + if isinstance(ref.parent.table, ast.BaseTableRef): + table = ref.parent.table.resolve_database_table().hogql_table() + if table == "persons": + if ref.name in self.person_properties: + return self._add_type_to_string_field(node, self.person_properties[ref.name]) + if table == "events": + if ref.name in self.event_properties: + return self._add_type_to_string_field(node, self.event_properties[ref.name]) + return node + + def _add_type_to_string_field(self, node: ast.Field, type: str): + if type == "DateTime": + return ast.Call(name="toDateTime", args=[node]) + if type == "Numeric": + return ast.Call(name="toFloat", args=[node]) + if type == "Boolean": + return parse_expr("{node} = 'true'", {"node": node}) + return node diff --git a/posthog/hogql/transforms/test/test_property_types.py b/posthog/hogql/transforms/test/test_property_types.py new file mode 100644 index 0000000000000..d0597067bde1a --- /dev/null +++ b/posthog/hogql/transforms/test/test_property_types.py @@ -0,0 +1,87 @@ +from django.test import override_settings + +from posthog.hogql.context import HogQLContext +from posthog.hogql.parser import parse_select +from posthog.hogql.printer import print_ast +from posthog.models import PropertyDefinition +from posthog.test.base import BaseTest + + +class TestPropertyTypes(BaseTest): + def setUp(self): + super().setUp() + PropertyDefinition.objects.get_or_create( + team=self.team, + type=PropertyDefinition.Type.EVENT, + name="$screen_height", + defaults={"property_type": "Numeric"}, + ) + PropertyDefinition.objects.get_or_create( + team=self.team, + type=PropertyDefinition.Type.EVENT, + name="$screen_width", + defaults={"property_type": "Numeric"}, + ) + PropertyDefinition.objects.get_or_create( + team=self.team, type=PropertyDefinition.Type.EVENT, name="bool", defaults={"property_type": "Boolean"} + ) + PropertyDefinition.objects.get_or_create( + team=self.team, type=PropertyDefinition.Type.PERSON, name="tickets", defaults={"property_type": "Numeric"} + ) + PropertyDefinition.objects.get_or_create( + team=self.team, + type=PropertyDefinition.Type.PERSON, + name="provided_timestamp", + defaults={"property_type": "DateTime"}, + ) + PropertyDefinition.objects.get_or_create( + team=self.team, + type=PropertyDefinition.Type.PERSON, + name="$initial_browser", + defaults={"property_type": "String"}, + ) + + def test_resolve_property_types_event(self): + printed = self._print_select( + "select properties.$screen_width * properties.$screen_height, properties.bool from events" + ) + expected = ( + "SELECT multiply(" + "toFloat64OrNull(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_0)s), '^\"|\"$', '')), " + "toFloat64OrNull(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_1)s), '^\"|\"$', ''))), " + "equals(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_2)s), '^\"|\"$', ''), %(hogql_val_3)s) " + f"FROM events WHERE equals(team_id, {self.team.pk}) LIMIT 65535" + ) + self.assertEqual(printed, expected) + + def test_resolve_property_types_perosn(self): + printed = self._print_select( + "select properties.tickets, properties.provided_timestamp, properties.$initial_browser from persons" + ) + expected = ( + "SELECT toFloat64OrNull(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_0)s), '^\"|\"$', '')), " + "parseDateTimeBestEffort(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_1)s), '^\"|\"$', '')), " + "replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_2)s), '^\"|\"$', '') " + f"FROM person WHERE equals(team_id, {self.team.pk}) LIMIT 65535" + ) + self.assertEqual(printed, expected) + + @override_settings(PERSON_ON_EVENTS_OVERRIDE=False) + def test_resolve_property_types_combined(self): + printed = self._print_select("select properties.$screen_width * person.properties.tickets from events") + expected = ( + "SELECT multiply(" + "toFloat64OrNull(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_1)s), '^\"|\"$', '')), " + "toFloat64OrNull(events__pdi__person.properties___tickets)) FROM events INNER JOIN " + "(SELECT argMax(person_distinct_id2.person_id, version) AS person_id, distinct_id FROM person_distinct_id2 " + f"WHERE equals(team_id, {self.team.pk}) GROUP BY distinct_id HAVING equals(argMax(is_deleted, version), 0)) AS events__pdi " + "ON equals(events.distinct_id, events__pdi.distinct_id) INNER JOIN (SELECT " + "argMax(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_0)s), '^\"|\"$', ''), version) AS properties___tickets, " + f"id FROM person WHERE equals(team_id, {self.team.pk}) GROUP BY id HAVING equals(argMax(is_deleted, version), 0)) AS events__pdi__person " + f"ON equals(events__pdi.person_id, events__pdi__person.id) WHERE equals(team_id, {self.team.pk}) LIMIT 65535" + ) + self.assertEqual(printed, expected) + + def _print_select(self, select: str): + expr = parse_select(select) + return print_ast(expr, HogQLContext(team_id=self.team.pk, enable_select_queries=True), "clickhouse") diff --git a/posthog/hogql/visitor.py b/posthog/hogql/visitor.py index 97c9ea7b05764..cfa93a31387b5 100644 --- a/posthog/hogql/visitor.py +++ b/posthog/hogql/visitor.py @@ -1,3 +1,5 @@ +from typing import Optional + from posthog.hogql import ast @@ -145,67 +147,82 @@ def visit_property_ref(self, node: ast.PropertyRef): class CloningVisitor(Visitor): """Visitor that traverses and clones the AST tree. Clears refs.""" + def __init__(self, clear_refs: Optional[bool] = True): + self.clear_refs = clear_refs + def visit_expr(self, node: ast.Expr): raise ValueError("Can not visit generic Expr node") def visit_alias(self, node: ast.Alias): return ast.Alias( + ref=None if self.clear_refs else node.ref, alias=node.alias, expr=self.visit(node.expr), ) def visit_binary_operation(self, node: ast.BinaryOperation): return ast.BinaryOperation( + ref=None if self.clear_refs else node.ref, left=self.visit(node.left), right=self.visit(node.right), op=node.op, ) def visit_and(self, node: ast.And): - return ast.And(exprs=[self.visit(expr) for expr in node.exprs]) + return ast.And(ref=None if self.clear_refs else node.ref, exprs=[self.visit(expr) for expr in node.exprs]) def visit_or(self, node: ast.Or): - return ast.Or(exprs=[self.visit(expr) for expr in node.exprs]) + return ast.Or(ref=None if self.clear_refs else node.ref, exprs=[self.visit(expr) for expr in node.exprs]) def visit_compare_operation(self, node: ast.CompareOperation): return ast.CompareOperation( + ref=None if self.clear_refs else node.ref, left=self.visit(node.left), right=self.visit(node.right), op=node.op, ) def visit_not(self, node: ast.Not): - return ast.Not(expr=self.visit(node.expr)) + return ast.Not(ref=None if self.clear_refs else node.ref, expr=self.visit(node.expr)) def visit_order_expr(self, node: ast.OrderExpr): return ast.OrderExpr( + ref=None if self.clear_refs else node.ref, expr=self.visit(node.expr), order=node.order, ) def visit_constant(self, node: ast.Constant): - return node + return ast.Constant(ref=None if self.clear_refs else node.ref, value=node.value) def visit_field(self, node: ast.Field): - return node + return ast.Field(ref=None if self.clear_refs else node.ref, chain=node.chain) def visit_placeholder(self, node: ast.Placeholder): - return node + return ast.Placeholder(ref=None if self.clear_refs else node.ref, field=node.field) def visit_call(self, node: ast.Call): return ast.Call( + ref=None if self.clear_refs else node.ref, name=node.name, args=[self.visit(arg) for arg in node.args], ) def visit_ratio_expr(self, node: ast.RatioExpr): - return ast.RatioExpr(left=self.visit(node.left), right=self.visit(node.right)) + return ast.RatioExpr( + ref=None if self.clear_refs else node.ref, left=self.visit(node.left), right=self.visit(node.right) + ) def visit_sample_expr(self, node: ast.SampleExpr): - return ast.SampleExpr(sample_value=self.visit(node.sample_value), offset_value=self.visit(node.offset_value)) + return ast.SampleExpr( + ref=None if self.clear_refs else node.ref, + sample_value=self.visit(node.sample_value), + offset_value=self.visit(node.offset_value), + ) def visit_join_expr(self, node: ast.JoinExpr): return ast.JoinExpr( + ref=None if self.clear_refs else node.ref, table=self.visit(node.table), next_join=self.visit(node.next_join), table_final=node.table_final, @@ -217,6 +234,7 @@ def visit_join_expr(self, node: ast.JoinExpr): def visit_select_query(self, node: ast.SelectQuery): return ast.SelectQuery( + ref=None if self.clear_refs else node.ref, select=[self.visit(expr) for expr in node.select] if node.select else None, select_from=self.visit(node.select_from), where=self.visit(node.where), @@ -232,4 +250,6 @@ def visit_select_query(self, node: ast.SelectQuery): ) def visit_select_union_query(self, node: ast.SelectUnionQuery): - return ast.SelectUnionQuery(select_queries=[self.visit(expr) for expr in node.select_queries]) + return ast.SelectUnionQuery( + ref=None if self.clear_refs else node.ref, select_queries=[self.visit(expr) for expr in node.select_queries] + )