Skip to content

Commit

Permalink
feat(hogql): automatic event and person property types (#14795)
Browse files Browse the repository at this point in the history
  • Loading branch information
mariusandra authored Mar 17, 2023
1 parent 4f53565 commit aab03ed
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 9 deletions.
1 change: 1 addition & 0 deletions posthog/hogql/placeholders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def replace_placeholders(node: ast.Expr, placeholders: Dict[str, ast.Expr]) -> a

class ReplacePlaceholders(CloningVisitor):
def __init__(self, placeholders: Dict[str, ast.Expr]):
super().__init__()
self.placeholders = placeholders

def visit_placeholder(self, node):
Expand Down
2 changes: 2 additions & 0 deletions posthog/hogql/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from posthog.hogql.print_string import print_clickhouse_identifier, print_hogql_identifier
from posthog.hogql.resolver import ResolverException, lookup_field_by_name, resolve_refs
from posthog.hogql.transforms import expand_asterisks, resolve_lazy_tables
from posthog.hogql.transforms.property_types import resolve_property_types
from posthog.hogql.visitor import Visitor
from posthog.models.property import PropertyName, TableColumn

Expand Down Expand Up @@ -40,6 +41,7 @@ def print_ast(

# modify the cloned tree as needed
if dialect == "clickhouse":
node = resolve_property_types(node, context)
expand_asterisks(node)
resolve_lazy_tables(node, stack, context)

Expand Down
1 change: 1 addition & 0 deletions posthog/hogql/test/test_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class TestVisitor(BaseTest):
def test_visitor_pattern(self):
class ConstantVisitor(CloningVisitor):
def __init__(self):
super().__init__()
self.constants = []
self.fields = []
self.operations = []
Expand Down
88 changes: 88 additions & 0 deletions posthog/hogql/transforms/property_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from typing import Dict, Set

from posthog.hogql import ast
from posthog.hogql.context import HogQLContext
from posthog.hogql.parser import parse_expr
from posthog.hogql.visitor import CloningVisitor, TraversingVisitor


def resolve_property_types(node: ast.Expr, context: HogQLContext = None) -> ast.Expr:
from posthog.models import PropertyDefinition

# find all properties
property_finder = PropertyFinder()
property_finder.visit(node)

# fetch them
event_property_values = (
PropertyDefinition.objects.filter(
name__in=property_finder.event_properties,
team_id=context.team_id,
type__in=[None, PropertyDefinition.Type.EVENT],
).values_list("name", "property_type")
if property_finder.event_properties
else []
)
event_properties = {name: property_type for name, property_type in event_property_values if property_type}

person_property_values = (
PropertyDefinition.objects.filter(
name__in=property_finder.person_properties,
team_id=context.team_id,
type=PropertyDefinition.Type.PERSON,
).values_list("name", "property_type")
if property_finder.person_properties
else []
)
person_properties = {name: property_type for name, property_type in person_property_values if property_type}

# swap them out
if len(event_properties) == 0 and len(person_properties) == 0:
return node
property_swapper = PropertySwapper(event_properties=event_properties, person_properties=person_properties)
return property_swapper.visit(node)


class PropertyFinder(TraversingVisitor):
def __init__(self):
super().__init__()
self.person_properties: Set[str] = set()
self.event_properties: Set[str] = set()

def visit_property_ref(self, node: ast.PropertyRef):
if node.parent.name == "properties":
if isinstance(node.parent.table, ast.BaseTableRef):
table = node.parent.table.resolve_database_table().hogql_table()
if table == "persons":
self.person_properties.add(node.name)
if table == "events":
self.event_properties.add(node.name)


class PropertySwapper(CloningVisitor):
def __init__(self, event_properties: Dict[str, str], person_properties: Dict[str, str]):
super().__init__(clear_refs=False)
self.event_properties = event_properties
self.person_properties = person_properties

def visit_field(self, node: ast.Field):
ref = node.ref
if isinstance(ref, ast.PropertyRef) and ref.parent.name == "properties":
if isinstance(ref.parent.table, ast.BaseTableRef):
table = ref.parent.table.resolve_database_table().hogql_table()
if table == "persons":
if ref.name in self.person_properties:
return self._add_type_to_string_field(node, self.person_properties[ref.name])
if table == "events":
if ref.name in self.event_properties:
return self._add_type_to_string_field(node, self.event_properties[ref.name])
return node

def _add_type_to_string_field(self, node: ast.Field, type: str):
if type == "DateTime":
return ast.Call(name="toDateTime", args=[node])
if type == "Numeric":
return ast.Call(name="toFloat", args=[node])
if type == "Boolean":
return parse_expr("{node} = 'true'", {"node": node})
return node
87 changes: 87 additions & 0 deletions posthog/hogql/transforms/test/test_property_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from django.test import override_settings

from posthog.hogql.context import HogQLContext
from posthog.hogql.parser import parse_select
from posthog.hogql.printer import print_ast
from posthog.models import PropertyDefinition
from posthog.test.base import BaseTest


class TestPropertyTypes(BaseTest):
def setUp(self):
super().setUp()
PropertyDefinition.objects.get_or_create(
team=self.team,
type=PropertyDefinition.Type.EVENT,
name="$screen_height",
defaults={"property_type": "Numeric"},
)
PropertyDefinition.objects.get_or_create(
team=self.team,
type=PropertyDefinition.Type.EVENT,
name="$screen_width",
defaults={"property_type": "Numeric"},
)
PropertyDefinition.objects.get_or_create(
team=self.team, type=PropertyDefinition.Type.EVENT, name="bool", defaults={"property_type": "Boolean"}
)
PropertyDefinition.objects.get_or_create(
team=self.team, type=PropertyDefinition.Type.PERSON, name="tickets", defaults={"property_type": "Numeric"}
)
PropertyDefinition.objects.get_or_create(
team=self.team,
type=PropertyDefinition.Type.PERSON,
name="provided_timestamp",
defaults={"property_type": "DateTime"},
)
PropertyDefinition.objects.get_or_create(
team=self.team,
type=PropertyDefinition.Type.PERSON,
name="$initial_browser",
defaults={"property_type": "String"},
)

def test_resolve_property_types_event(self):
printed = self._print_select(
"select properties.$screen_width * properties.$screen_height, properties.bool from events"
)
expected = (
"SELECT multiply("
"toFloat64OrNull(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_0)s), '^\"|\"$', '')), "
"toFloat64OrNull(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_1)s), '^\"|\"$', ''))), "
"equals(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_2)s), '^\"|\"$', ''), %(hogql_val_3)s) "
f"FROM events WHERE equals(team_id, {self.team.pk}) LIMIT 65535"
)
self.assertEqual(printed, expected)

def test_resolve_property_types_perosn(self):
printed = self._print_select(
"select properties.tickets, properties.provided_timestamp, properties.$initial_browser from persons"
)
expected = (
"SELECT toFloat64OrNull(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_0)s), '^\"|\"$', '')), "
"parseDateTimeBestEffort(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_1)s), '^\"|\"$', '')), "
"replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_2)s), '^\"|\"$', '') "
f"FROM person WHERE equals(team_id, {self.team.pk}) LIMIT 65535"
)
self.assertEqual(printed, expected)

@override_settings(PERSON_ON_EVENTS_OVERRIDE=False)
def test_resolve_property_types_combined(self):
printed = self._print_select("select properties.$screen_width * person.properties.tickets from events")
expected = (
"SELECT multiply("
"toFloat64OrNull(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_1)s), '^\"|\"$', '')), "
"toFloat64OrNull(events__pdi__person.properties___tickets)) FROM events INNER JOIN "
"(SELECT argMax(person_distinct_id2.person_id, version) AS person_id, distinct_id FROM person_distinct_id2 "
f"WHERE equals(team_id, {self.team.pk}) GROUP BY distinct_id HAVING equals(argMax(is_deleted, version), 0)) AS events__pdi "
"ON equals(events.distinct_id, events__pdi.distinct_id) INNER JOIN (SELECT "
"argMax(replaceRegexpAll(JSONExtractRaw(properties, %(hogql_val_0)s), '^\"|\"$', ''), version) AS properties___tickets, "
f"id FROM person WHERE equals(team_id, {self.team.pk}) GROUP BY id HAVING equals(argMax(is_deleted, version), 0)) AS events__pdi__person "
f"ON equals(events__pdi.person_id, events__pdi__person.id) WHERE equals(team_id, {self.team.pk}) LIMIT 65535"
)
self.assertEqual(printed, expected)

def _print_select(self, select: str):
expr = parse_select(select)
return print_ast(expr, HogQLContext(team_id=self.team.pk, enable_select_queries=True), "clickhouse")
38 changes: 29 additions & 9 deletions posthog/hogql/visitor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from posthog.hogql import ast


Expand Down Expand Up @@ -145,67 +147,82 @@ def visit_property_ref(self, node: ast.PropertyRef):
class CloningVisitor(Visitor):
"""Visitor that traverses and clones the AST tree. Clears refs."""

def __init__(self, clear_refs: Optional[bool] = True):
self.clear_refs = clear_refs

def visit_expr(self, node: ast.Expr):
raise ValueError("Can not visit generic Expr node")

def visit_alias(self, node: ast.Alias):
return ast.Alias(
ref=None if self.clear_refs else node.ref,
alias=node.alias,
expr=self.visit(node.expr),
)

def visit_binary_operation(self, node: ast.BinaryOperation):
return ast.BinaryOperation(
ref=None if self.clear_refs else node.ref,
left=self.visit(node.left),
right=self.visit(node.right),
op=node.op,
)

def visit_and(self, node: ast.And):
return ast.And(exprs=[self.visit(expr) for expr in node.exprs])
return ast.And(ref=None if self.clear_refs else node.ref, exprs=[self.visit(expr) for expr in node.exprs])

def visit_or(self, node: ast.Or):
return ast.Or(exprs=[self.visit(expr) for expr in node.exprs])
return ast.Or(ref=None if self.clear_refs else node.ref, exprs=[self.visit(expr) for expr in node.exprs])

def visit_compare_operation(self, node: ast.CompareOperation):
return ast.CompareOperation(
ref=None if self.clear_refs else node.ref,
left=self.visit(node.left),
right=self.visit(node.right),
op=node.op,
)

def visit_not(self, node: ast.Not):
return ast.Not(expr=self.visit(node.expr))
return ast.Not(ref=None if self.clear_refs else node.ref, expr=self.visit(node.expr))

def visit_order_expr(self, node: ast.OrderExpr):
return ast.OrderExpr(
ref=None if self.clear_refs else node.ref,
expr=self.visit(node.expr),
order=node.order,
)

def visit_constant(self, node: ast.Constant):
return node
return ast.Constant(ref=None if self.clear_refs else node.ref, value=node.value)

def visit_field(self, node: ast.Field):
return node
return ast.Field(ref=None if self.clear_refs else node.ref, chain=node.chain)

def visit_placeholder(self, node: ast.Placeholder):
return node
return ast.Placeholder(ref=None if self.clear_refs else node.ref, field=node.field)

def visit_call(self, node: ast.Call):
return ast.Call(
ref=None if self.clear_refs else node.ref,
name=node.name,
args=[self.visit(arg) for arg in node.args],
)

def visit_ratio_expr(self, node: ast.RatioExpr):
return ast.RatioExpr(left=self.visit(node.left), right=self.visit(node.right))
return ast.RatioExpr(
ref=None if self.clear_refs else node.ref, left=self.visit(node.left), right=self.visit(node.right)
)

def visit_sample_expr(self, node: ast.SampleExpr):
return ast.SampleExpr(sample_value=self.visit(node.sample_value), offset_value=self.visit(node.offset_value))
return ast.SampleExpr(
ref=None if self.clear_refs else node.ref,
sample_value=self.visit(node.sample_value),
offset_value=self.visit(node.offset_value),
)

def visit_join_expr(self, node: ast.JoinExpr):
return ast.JoinExpr(
ref=None if self.clear_refs else node.ref,
table=self.visit(node.table),
next_join=self.visit(node.next_join),
table_final=node.table_final,
Expand All @@ -217,6 +234,7 @@ def visit_join_expr(self, node: ast.JoinExpr):

def visit_select_query(self, node: ast.SelectQuery):
return ast.SelectQuery(
ref=None if self.clear_refs else node.ref,
select=[self.visit(expr) for expr in node.select] if node.select else None,
select_from=self.visit(node.select_from),
where=self.visit(node.where),
Expand All @@ -232,4 +250,6 @@ def visit_select_query(self, node: ast.SelectQuery):
)

def visit_select_union_query(self, node: ast.SelectUnionQuery):
return ast.SelectUnionQuery(select_queries=[self.visit(expr) for expr in node.select_queries])
return ast.SelectUnionQuery(
ref=None if self.clear_refs else node.ref, select_queries=[self.visit(expr) for expr in node.select_queries]
)

0 comments on commit aab03ed

Please sign in to comment.