From 399769db79bf97842ae77b2013bce36068486bf1 Mon Sep 17 00:00:00 2001 From: Marius Andra Date: Thu, 6 Apr 2023 15:47:57 +0200 Subject: [PATCH] feat(hogql): really lazy tables (#14927) --- frontend/public/blank-dashboard-hog.png | Bin 5703 -> 1610 bytes posthog/hogql/ast.py | 22 +- posthog/hogql/database.py | 225 +++++++++++------- posthog/hogql/printer.py | 21 +- posthog/hogql/resolver.py | 7 +- .../test/__snapshots__/test_database.ambr | 136 +++++++---- posthog/hogql/test/test_printer.py | 22 +- posthog/hogql/test/test_query.py | 8 +- posthog/hogql/test/test_resolver.py | 56 +++-- posthog/hogql/transforms/lazy_tables.py | 179 +++++++++----- posthog/hogql/transforms/property_types.py | 4 +- .../hogql/transforms/test/test_lazy_tables.py | 75 ++++-- .../transforms/test/test_property_types.py | 14 +- posthog/hogql/visitor.py | 7 +- posthog/queries/funnels/test/test_funnel.py | 1 - 15 files changed, 513 insertions(+), 264 deletions(-) diff --git a/frontend/public/blank-dashboard-hog.png b/frontend/public/blank-dashboard-hog.png index ddc56da808dde2a19222db17f5fa4fc6ebf6a0b8..6d90e9bf9ad26b4759a7c690df6ab1effd2988d6 100644 GIT binary patch literal 1610 zcmeAS@N?(olHy`uVBq!ia0y~yVA;UHz*NS;3>2xD*PREX5(0ceTo2b}9d9c>-d1v| zul7`b?YU`f=ccuvo8EqDUjL?^bJ` zpLQq0f!A5jV21RDJ+*q93@sB1gCjTisj=(i&fYA{UBQ!_xna*yrfJi>nx~%Oi-^!m zD?6}h*QbQFKKueY3mzNihUwFOqg- zr~TWqt2wuFL)rHqwb3o=8>EBy6nm|2x7^ri_0(ZmbFWjx+X>B*Qqiq5c1<}v)2h04 z!uP9Fn7#6JIx@~E{@oJU@!Xbu>&9q~^e(A2C-W3%$y-`)WXRv3WqnJpLa9b1{B~=^ zESVYgW)3>9PaLe;8qbj`Z5eTIS91cVm_kGoD+h#}mcR&PgGkkg1}F(q&;V5i7KZ2` zp81U%s`i68)b6F7GHZ_2DNf726BnB!-X!^$EqhBW$FhFuHGAJWnTY$hGzhQW)O^>; zX1-zd*Tl^N0h5)!y}ENt$K?vE&b6JJ!#S#wc^}SWzZ))aJ$glc>+*YUid*VBJ|$n7wf}UF-7?G?vsqCwX3?nCgwS?`W|dQRGgjXz@-PxZeN( literal 5703 zcmeHK`%_bQ7C#{%uT@kYaVj94PG#7o27JJZyrk6VNMP%BOUK08el1!ogP! zn~)M$bcmDytwa7gbmL6hOWq4`o9DTl=>p)ZI*_>c69&$p7FVTHSQRrD!@|Dty>@#m ziHb`5^iv#s~BUrvKKDEp^k-;&J_`2|@8*?X=yhX_9T^FW5X z`{8|zfc}Yv>V&`K#J~IWN^fBSb+LAzt;@`cEvPS9M{K_5HCALa*@wqg7TD`hEO24T zj>BZ8lOHB>9Kr#Cg9VO4_`mjo-&{?SI3W)~WQX`)2=cN3gtWPvhTDajevZt0oNh3R zncD+ZtH%&r894ByPbn&A=BsGZ#r{e2%RluShfe63%yK4cVV0U^%>W3!RN(@l&wN^w zyJpq5J3C_j#a}gSCML`3MRmnoxs>G21H3=v~epz3{k1Y+E`ySF7174r-$Uo}!p8Qo&zMhz$w~ zjGTUjWmxE_rhK}zz1Mf!WmjPMV=gvSHWi%y<|!;b0i>O`TmC_-sK)AIG%!HI&RPr zWPk5e#uk?X4Ax^7pKlq|w44(U?mX`TVKg+z!g5^+>xDJKeUR}ZX8MknX&(7^UTjB0 zhZ79Ehin7YNhf>G%V zk3V@jKCamzTUYbiDlRRbUR8@0`YmRTj235|ES!{f;6MK?oV#lm3j){de1DE!z5ayk+6!sXX3z8#ABynL2uP zh*7XQ3Lrm+XqtPh)Vd<4cbjemE!XfNIzS|18o8pIp>S13al>V-O}a>F{nv$v zAgHTQQ;CClcPa?iREy?+8ug5vo=Vzl>8xEr4WS&7yq8>?lHxTz6=N^R4~Qy7+Pt5& zcks5!-X=EvXT{(0nxYI-LHhO_UVV>KMFcup&$=dNCfH26xWXg?{D3ONO&L|n1Y5ps zLWhIDBB$#+3N`M1t}nvt7nSn>WMnz9k}}Q6J7s?g@D=*x!gEH?PPgdftwOkZ@Gb7< z6*pYX#EF65;)Vl{riB~aAREhHTA8g6uZW8z zcI|*S*?P$j5h$xS_HwqKfPlkA)aT}G zeBF$ufKhcXqOnXz?9<9$;FeGDZ4;*&*)$Uogz4xs)DK%)YR!$PFY^q4BvOR$ z3~~@8J6Z^MHbEdcX+RV_q+em~&UJwr(wm`fyp=XEs=C#h8i=+7Y;Ltjkb@tePROB8&T(Ep zhyc)HtL1T%)=HPA+UG@g*MLNYnbMQJsZHt8aa!%LS;iX~v2OX&8P*G-;-fmgP&?nz z{onF=v!I0boU@swr2Li7OhY${BOkS(QhM>-uJUlS6vZPKepz(du5EVv7*`qB`Fr_@ z;UJ+8)syh9A5>gdpr3xv?c`%=Y}*#8SbZ&FPztl2Rs)$$gZEDMX%Zn~0c-gDUKtRS_=NboFnB3{%VNGd=fd z Ref: return AsteriskRef(table=self) if self.has_child(name): field = self.resolve_database_table().get_field(name) + if isinstance(field, LazyJoin): + return LazyJoinRef(table=self, field=name, lazy_join=field) if isinstance(field, LazyTable): - return LazyTableRef(table=self, field=name, lazy_table=field) + return LazyTableRef(table=field) if isinstance(field, FieldTraverser): return FieldTraverserRef(table=self, chain=field.chain) if isinstance(field, VirtualTable): @@ -101,13 +104,20 @@ def resolve_database_table(self) -> Table: return self.table_ref.table -class LazyTableRef(BaseTableRef): +class LazyJoinRef(BaseTableRef): table: BaseTableRef field: str - lazy_table: LazyTable + lazy_join: LazyJoin + + def resolve_database_table(self) -> Table: + return self.lazy_join.join_table + + +class LazyTableRef(BaseTableRef): + table: LazyTable def resolve_database_table(self) -> Table: - return self.lazy_table.table + return self.table class VirtualTableRef(BaseTableRef): @@ -328,6 +338,8 @@ class Call(Expr): class JoinExpr(Expr): + ref: Optional[BaseTableRef | SelectQueryRef | SelectQueryAliasRef | SelectUnionQueryRef] + join_type: Optional[str] = None table: Optional[Union["SelectQuery", "SelectUnionQuery", Field]] = None alias: Optional[str] = None diff --git a/posthog/hogql/database.py b/posthog/hogql/database.py index 2bfbc8af1017b..d6470ca61b961 100644 --- a/posthog/hogql/database.py +++ b/posthog/hogql/database.py @@ -65,8 +65,7 @@ def get_asterisk(self) -> Dict[str, DatabaseField]: asterisk[key] = database_field elif ( isinstance(database_field, Table) - or isinstance(database_field, LazyTable) - or isinstance(database_field, VirtualTable) + or isinstance(database_field, LazyJoin) or isinstance(database_field, FieldTraverser) ): pass # ignore virtual tables for now @@ -75,15 +74,23 @@ def get_asterisk(self) -> Dict[str, DatabaseField]: return asterisk -class LazyTable(BaseModel): +class LazyJoin(BaseModel): class Config: extra = Extra.forbid join_function: Callable[[str, str, Dict[str, Any]], Any] - table: Table + join_table: Table from_field: str +class LazyTable(Table): + class Config: + extra = Extra.forbid + + def lazy_select(self, requested_fields: Dict[str, Any]) -> Any: + raise NotImplementedError("LazyTable.lazy_select not overridden") + + class VirtualTable(Table): class Config: extra = Extra.forbid @@ -96,19 +103,51 @@ class Config: chain: List[str] -class EventsPersonSubTable(VirtualTable): - id: StringDatabaseField = StringDatabaseField(name="person_id") - created_at: DateTimeDatabaseField = DateTimeDatabaseField(name="person_created_at") - properties: StringJSONDatabaseField = StringJSONDatabaseField(name="person_properties") +def select_from_persons_table(requested_fields: Dict[str, Any]): + from posthog.hogql import ast - def clickhouse_table(self): - return "events" + if not requested_fields: + raise ValueError("No fields requested from persons table.") - def hogql_table(self): - return "events" + fields_to_select: List[ast.Expr] = [] + argmax_version: Callable[[ast.Expr], ast.Expr] = lambda field: ast.Call( + name="argMax", args=[field, ast.Field(chain=["version"])] + ) + for field, expr in requested_fields.items(): + if field != "id": + fields_to_select.append(ast.Alias(alias=field, expr=argmax_version(expr))) + + id = ast.Field(chain=["id"]) + return ast.SelectQuery( + select=fields_to_select + [id], + select_from=ast.JoinExpr(table=ast.Field(chain=["raw_persons"])), + group_by=[id], + having=ast.CompareOperation( + op=ast.CompareOperationType.Eq, + left=argmax_version(ast.Field(chain=["is_deleted"])), + right=ast.Constant(value=0), + ), + ) -class PersonsTable(Table): + +def join_with_persons_table(from_table: str, to_table: str, requested_fields: Dict[str, Any]): + from posthog.hogql import ast + + if not requested_fields: + raise ValueError("No fields requested from persons table.") + join_expr = ast.JoinExpr(table=select_from_persons_table(requested_fields)) + join_expr.join_type = "INNER JOIN" + join_expr.alias = to_table + join_expr.constraint = ast.CompareOperation( + op=ast.CompareOperationType.Eq, + left=ast.Field(chain=[from_table, "person_id"]), + right=ast.Field(chain=[to_table, "id"]), + ) + return join_expr + + +class RawPersonsTable(Table): id: StringDatabaseField = StringDatabaseField(name="id") created_at: DateTimeDatabaseField = DateTimeDatabaseField(name="created_at") team_id: IntegerDatabaseField = IntegerDatabaseField(name="team_id") @@ -117,6 +156,23 @@ class PersonsTable(Table): is_deleted: BooleanDatabaseField = BooleanDatabaseField(name="is_deleted") version: IntegerDatabaseField = IntegerDatabaseField(name="version") + def clickhouse_table(self): + return "person" + + def hogql_table(self): + return "raw_persons" + + +class PersonsTable(LazyTable): + id: StringDatabaseField = StringDatabaseField(name="id") + created_at: DateTimeDatabaseField = DateTimeDatabaseField(name="created_at") + team_id: IntegerDatabaseField = IntegerDatabaseField(name="team_id") + properties: StringJSONDatabaseField = StringJSONDatabaseField(name="properties") + is_identified: BooleanDatabaseField = BooleanDatabaseField(name="is_identified") + + def lazy_select(self, requested_fields: Dict[str, Any]): + return select_from_persons_table(requested_fields) + def clickhouse_table(self): return "person" @@ -124,101 +180,92 @@ def hogql_table(self): return "persons" -def join_with_persons_table(from_table: str, to_table: str, requested_fields: Dict[str, Any]): +def select_from_person_distinct_ids_table(requested_fields: Dict[str, Any]): from posthog.hogql import ast if not requested_fields: - raise ValueError("No fields requested from persons table. Why are we joining it?") + requested_fields = {"person_id": ast.Field(chain=["person_id"])} - # contains the list of fields we will select from this table fields_to_select: List[ast.Expr] = [] - argmax_version: Callable[[ast.Expr], ast.Expr] = lambda field: ast.Call( name="argMax", args=[field, ast.Field(chain=["version"])] ) for field, expr in requested_fields.items(): - if field != "id": + if field != "distinct_id": fields_to_select.append(ast.Alias(alias=field, expr=argmax_version(expr))) - id = ast.Field(chain=["id"]) + distinct_id = ast.Field(chain=["distinct_id"]) - return ast.JoinExpr( - join_type="INNER JOIN", - table=ast.SelectQuery( - select=fields_to_select + [id], - select_from=ast.JoinExpr(table=ast.Field(chain=["persons"])), - group_by=[id], - having=ast.CompareOperation( - op=ast.CompareOperationType.Eq, - left=argmax_version(ast.Field(chain=["is_deleted"])), - right=ast.Constant(value=0), - ), - ), - alias=to_table, - constraint=ast.CompareOperation( + return ast.SelectQuery( + select=fields_to_select + [distinct_id], + select_from=ast.JoinExpr(table=ast.Field(chain=["raw_person_distinct_ids"])), + group_by=[distinct_id], + having=ast.CompareOperation( op=ast.CompareOperationType.Eq, - left=ast.Field(chain=[from_table, "person_id"]), - right=ast.Field(chain=[to_table, "id"]), + left=argmax_version(ast.Field(chain=["is_deleted"])), + right=ast.Constant(value=0), ), ) -class PersonDistinctIdTable(Table): +def join_with_person_distinct_ids_table(from_table: str, to_table: str, requested_fields: Dict[str, Any]): + from posthog.hogql import ast + + if not requested_fields: + raise ValueError("No fields requested from person_distinct_ids.") + join_expr = ast.JoinExpr(table=select_from_person_distinct_ids_table(requested_fields)) + join_expr.join_type = "INNER JOIN" + join_expr.alias = to_table + join_expr.constraint = ast.CompareOperation( + op=ast.CompareOperationType.Eq, + left=ast.Field(chain=[from_table, "distinct_id"]), + right=ast.Field(chain=[to_table, "distinct_id"]), + ) + return join_expr + + +class RawPersonDistinctIdTable(Table): team_id: IntegerDatabaseField = IntegerDatabaseField(name="team_id") distinct_id: StringDatabaseField = StringDatabaseField(name="distinct_id") person_id: StringDatabaseField = StringDatabaseField(name="person_id") is_deleted: BooleanDatabaseField = BooleanDatabaseField(name="is_deleted") version: IntegerDatabaseField = IntegerDatabaseField(name="version") - person: LazyTable = LazyTable(from_field="person_id", table=PersonsTable(), join_function=join_with_persons_table) - - def avoid_asterisk_fields(self): - return ["is_deleted", "version"] - def clickhouse_table(self): return "person_distinct_id2" def hogql_table(self): - return "person_distinct_ids" + return "raw_person_distinct_ids" -def join_with_max_person_distinct_id_table(from_table: str, to_table: str, requested_fields: Dict[str, Any]): - from posthog.hogql import ast +class PersonDistinctIdTable(LazyTable): + team_id: IntegerDatabaseField = IntegerDatabaseField(name="team_id") + distinct_id: StringDatabaseField = StringDatabaseField(name="distinct_id") + person_id: StringDatabaseField = StringDatabaseField(name="person_id") + person: LazyJoin = LazyJoin( + from_field="person_id", join_table=PersonsTable(), join_function=join_with_persons_table + ) - if not requested_fields: - requested_fields = {"person_id": ast.Field(chain=["person_id"])} + def lazy_select(self, requested_fields: Dict[str, Any]): + return select_from_person_distinct_ids_table(requested_fields) - # contains the list of fields we will select from this table - fields_to_select: List[ast.Expr] = [] + def clickhouse_table(self): + return "person_distinct_id2" - argmax_version: Callable[[ast.Expr], ast.Expr] = lambda field: ast.Call( - name="argMax", args=[field, ast.Field(chain=["version"])] - ) - for field, expr in requested_fields.items(): - if field != "distinct_id": - fields_to_select.append(ast.Alias(alias=field, expr=argmax_version(expr))) + def hogql_table(self): + return "person_distinct_ids" - distinct_id = ast.Field(chain=["distinct_id"]) - return ast.JoinExpr( - join_type="INNER JOIN", - table=ast.SelectQuery( - select=fields_to_select + [distinct_id], - select_from=ast.JoinExpr(table=ast.Field(chain=["person_distinct_ids"])), - group_by=[distinct_id], - having=ast.CompareOperation( - op=ast.CompareOperationType.Eq, - left=argmax_version(ast.Field(chain=["is_deleted"])), - right=ast.Constant(value=0), - ), - ), - alias=to_table, - constraint=ast.CompareOperation( - op=ast.CompareOperationType.Eq, - left=ast.Field(chain=[from_table, "distinct_id"]), - right=ast.Field(chain=[to_table, "distinct_id"]), - ), - ) +class EventsPersonSubTable(VirtualTable): + id: StringDatabaseField = StringDatabaseField(name="person_id") + created_at: DateTimeDatabaseField = DateTimeDatabaseField(name="person_created_at") + properties: StringJSONDatabaseField = StringJSONDatabaseField(name="person_properties") + + def clickhouse_table(self): + return "events" + + def hogql_table(self): + return "events" class EventsTable(Table): @@ -232,8 +279,10 @@ class EventsTable(Table): created_at: DateTimeDatabaseField = DateTimeDatabaseField(name="created_at") # lazy table that adds a join to the persons table - pdi: LazyTable = LazyTable( - from_field="distinct_id", table=PersonDistinctIdTable(), join_function=join_with_max_person_distinct_id_table + pdi: LazyJoin = LazyJoin( + from_field="distinct_id", + join_table=PersonDistinctIdTable(), + join_function=join_with_person_distinct_ids_table, ) # person fields on the event itself poe: EventsPersonSubTable = EventsPersonSubTable() @@ -267,8 +316,10 @@ class SessionRecordingEvents(Table): last_event_timestamp: DateTimeDatabaseField = DateTimeDatabaseField(name="last_event_timestamp") urls: StringDatabaseField = StringDatabaseField(name="urls", array=True) - pdi: LazyTable = LazyTable( - from_field="distinct_id", table=PersonDistinctIdTable(), join_function=join_with_max_person_distinct_id_table + pdi: LazyJoin = LazyJoin( + from_field="distinct_id", + join_table=PersonDistinctIdTable(), + join_function=join_with_person_distinct_ids_table, ) person: FieldTraverser = FieldTraverser(chain=["pdi", "person"]) @@ -290,7 +341,9 @@ class CohortPeople(Table): # TODO: automatically add "HAVING SUM(sign) > 0" to fields selected from this table? - person: LazyTable = LazyTable(from_field="person_id", table=PersonsTable(), join_function=join_with_persons_table) + person: LazyJoin = LazyJoin( + from_field="person_id", join_table=PersonsTable(), join_function=join_with_persons_table + ) def clickhouse_table(self): return "cohortpeople" @@ -304,7 +357,9 @@ class StaticCohortPeople(Table): cohort_id: IntegerDatabaseField = IntegerDatabaseField(name="cohort_id") team_id: IntegerDatabaseField = IntegerDatabaseField(name="team_id") - person: LazyTable = LazyTable(from_field="person_id", table=PersonsTable(), join_function=join_with_persons_table) + person: LazyJoin = LazyJoin( + from_field="person_id", join_table=PersonsTable(), join_function=join_with_persons_table + ) def avoid_asterisk_fields(self): return ["_timestamp", "_offset"] @@ -336,12 +391,16 @@ class Config: # Users can query from the tables below events: EventsTable = EventsTable() + groups: Groups = Groups() persons: PersonsTable = PersonsTable() person_distinct_ids: PersonDistinctIdTable = PersonDistinctIdTable() + session_recording_events: SessionRecordingEvents = SessionRecordingEvents() cohort_people: CohortPeople = CohortPeople() static_cohort_people: StaticCohortPeople = StaticCohortPeople() - groups: Groups = Groups() + + raw_person_distinct_ids: RawPersonDistinctIdTable = RawPersonDistinctIdTable() + raw_persons: RawPersonsTable = RawPersonsTable() def __init__(self, timezone: Optional[str]): super().__init__() @@ -394,8 +453,8 @@ def serialize_database(database: Database) -> dict: fields.append({"key": field_key, "type": "boolean"}) elif isinstance(field, StringJSONDatabaseField): fields.append({"key": field_key, "type": "json"}) - elif isinstance(field, LazyTable): - fields.append({"key": field_key, "type": "lazy_table", "table": field.table.hogql_table()}) + elif isinstance(field, LazyJoin): + fields.append({"key": field_key, "type": "lazy_table", "table": field.join_table.hogql_table()}) elif isinstance(field, VirtualTable): fields.append( { diff --git a/posthog/hogql/printer.py b/posthog/hogql/printer.py index 639b3f477d1e8..c1c021f10cf8d 100644 --- a/posthog/hogql/printer.py +++ b/posthog/hogql/printer.py @@ -252,11 +252,6 @@ def visit_join_expr(self, node: ast.JoinExpr) -> JoinExprResponse: else: join_strings.append(self._print_identifier(node.ref.table.hogql_table())) - if node.sample is not None: - sample_clause = self.visit_sample_expr(node.sample) - if sample_clause is not None: - join_strings.append(sample_clause) - if self.dialect == "clickhouse": # TODO: do this in a separate pass before printing, along with person joins and other transforms extra_where = team_id_guard_for_table(node.ref, self.context) @@ -270,12 +265,21 @@ def visit_join_expr(self, node: ast.JoinExpr) -> JoinExprResponse: elif isinstance(node.ref, ast.SelectQueryAliasRef) and node.alias is not None: join_strings.append(self.visit(node.table)) join_strings.append(f"AS {self._print_identifier(node.alias)}") + + elif isinstance(node.ref, ast.LazyTableRef) and self.dialect == "hogql": + join_strings.append(self._print_identifier(node.ref.table.hogql_table())) + else: raise ValueError("Only selecting from a table or a subquery is supported") if node.table_final: join_strings.append("FINAL") + if node.sample is not None: + sample_clause = self.visit_sample_expr(node.sample) + if sample_clause is not None: + join_strings.append(sample_clause) + if node.constraint is not None: join_strings.append(f"ON {self.visit(node.constraint)}") @@ -590,8 +594,11 @@ def visit_virtual_table_ref(self, ref: ast.VirtualTableRef): def visit_asterisk_ref(self, ref: ast.AsteriskRef): return "*" - def visit_lazy_table_ref(self, ref: ast.LazyTableRef): - raise ValueError("Unexpected ast.LazyTableRef. Make sure LazyTableResolver has run on the AST.") + def visit_lazy_join_ref(self, ref: ast.LazyJoinRef): + raise ValueError("Unexpected ast.LazyJoinRef. Make sure LazyJoinResolver has run on the AST.") + + def visit_lazy_table_ref(self, ref: ast.LazyJoinRef): + raise ValueError("Unexpected ast.LazyTableRef. Make sure LazyJoinResolver has run on the AST.") def visit_field_traverser_ref(self, ref: ast.FieldTraverserRef): raise ValueError("Unexpected ast.FieldTraverserRef. This should have been resolved.") diff --git a/posthog/hogql/resolver.py b/posthog/hogql/resolver.py index 5089307e9aa85..ea041b2169d16 100644 --- a/posthog/hogql/resolver.py +++ b/posthog/hogql/resolver.py @@ -91,7 +91,12 @@ def visit_join_expr(self, node): raise ResolverException(f'Already have joined a table called "{table_alias}". Can\'t redefine.') if self.database.has_table(table_name): - node.table.ref = ast.TableRef(table=self.database.get_table(table_name)) + database_table = self.database.get_table(table_name) + if isinstance(database_table, ast.LazyTable): + node.table.ref = ast.LazyTableRef(table=database_table) + else: + node.table.ref = ast.TableRef(table=database_table) + if table_alias == table_name: node.ref = node.table.ref else: diff --git a/posthog/hogql/test/__snapshots__/test_database.ambr b/posthog/hogql/test/__snapshots__/test_database.ambr index f5f15ea0e90e7..8d12e5bee43e7 100644 --- a/posthog/hogql/test/__snapshots__/test_database.ambr +++ b/posthog/hogql/test/__snapshots__/test_database.ambr @@ -62,9 +62,13 @@ ] } ], - "persons": [ + "groups": [ { - "key": "id", + "key": "index", + "type": "integer" + }, + { + "key": "key", "type": "string" }, { @@ -74,18 +78,24 @@ { "key": "properties", "type": "json" + } + ], + "persons": [ + { + "key": "id", + "type": "string" }, { - "key": "is_identified", - "type": "boolean" + "key": "created_at", + "type": "datetime" }, { - "key": "is_deleted", - "type": "boolean" + "key": "properties", + "type": "json" }, { - "key": "version", - "type": "integer" + "key": "is_identified", + "type": "boolean" } ], "person_distinct_ids": [ @@ -97,14 +107,6 @@ "key": "person_id", "type": "string" }, - { - "key": "is_deleted", - "type": "boolean" - }, - { - "key": "version", - "type": "integer" - }, { "key": "person", "type": "lazy_table", @@ -232,13 +234,27 @@ "table": "persons" } ], - "groups": [ + "raw_person_distinct_ids": [ { - "key": "index", - "type": "integer" + "key": "distinct_id", + "type": "string" }, { - "key": "key", + "key": "person_id", + "type": "string" + }, + { + "key": "is_deleted", + "type": "boolean" + }, + { + "key": "version", + "type": "integer" + } + ], + "raw_persons": [ + { + "key": "id", "type": "string" }, { @@ -248,6 +264,18 @@ { "key": "properties", "type": "json" + }, + { + "key": "is_identified", + "type": "boolean" + }, + { + "key": "is_deleted", + "type": "boolean" + }, + { + "key": "version", + "type": "integer" } ] } @@ -312,9 +340,13 @@ "type": "string" } ], - "persons": [ + "groups": [ { - "key": "id", + "key": "index", + "type": "integer" + }, + { + "key": "key", "type": "string" }, { @@ -324,18 +356,24 @@ { "key": "properties", "type": "json" + } + ], + "persons": [ + { + "key": "id", + "type": "string" }, { - "key": "is_identified", - "type": "boolean" + "key": "created_at", + "type": "datetime" }, { - "key": "is_deleted", - "type": "boolean" + "key": "properties", + "type": "json" }, { - "key": "version", - "type": "integer" + "key": "is_identified", + "type": "boolean" } ], "person_distinct_ids": [ @@ -347,14 +385,6 @@ "key": "person_id", "type": "string" }, - { - "key": "is_deleted", - "type": "boolean" - }, - { - "key": "version", - "type": "integer" - }, { "key": "person", "type": "lazy_table", @@ -482,13 +512,27 @@ "table": "persons" } ], - "groups": [ + "raw_person_distinct_ids": [ { - "key": "index", - "type": "integer" + "key": "distinct_id", + "type": "string" }, { - "key": "key", + "key": "person_id", + "type": "string" + }, + { + "key": "is_deleted", + "type": "boolean" + }, + { + "key": "version", + "type": "integer" + } + ], + "raw_persons": [ + { + "key": "id", "type": "string" }, { @@ -498,6 +542,18 @@ { "key": "properties", "type": "json" + }, + { + "key": "is_identified", + "type": "boolean" + }, + { + "key": "is_deleted", + "type": "boolean" + }, + { + "key": "version", + "type": "integer" } ] } diff --git a/posthog/hogql/test/test_printer.py b/posthog/hogql/test/test_printer.py index fa2fab68426c1..0d16ba8595f7c 100644 --- a/posthog/hogql/test/test_printer.py +++ b/posthog/hogql/test/test_printer.py @@ -441,29 +441,31 @@ def test_select_sample(self): self._select( "SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 JOIN persons ON persons.id=events.person_id" ), - f"SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 JOIN person ON equals(person.id, events__pdi.person_id) INNER JOIN (SELECT argMax(person_distinct_id2.person_id, person_distinct_id2.version) AS person_id, person_distinct_id2.distinct_id FROM person_distinct_id2 WHERE equals(person_distinct_id2.team_id, {self.team.pk}) GROUP BY person_distinct_id2.distinct_id HAVING equals(argMax(person_distinct_id2.is_deleted, person_distinct_id2.version), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id) WHERE and(equals(person.team_id, {self.team.pk}), equals(events.team_id, {self.team.pk})) LIMIT 65535", + f"SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 INNER JOIN (SELECT argMax(person_distinct_id2.person_id, person_distinct_id2.version) AS person_id, person_distinct_id2.distinct_id FROM person_distinct_id2 WHERE equals(person_distinct_id2.team_id, {self.team.pk}) GROUP BY person_distinct_id2.distinct_id HAVING equals(argMax(person_distinct_id2.is_deleted, person_distinct_id2.version), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id) JOIN (SELECT person.id FROM person WHERE equals(person.team_id, {self.team.pk}) GROUP BY person.id HAVING equals(argMax(person.is_deleted, person.version), 0)) AS persons ON equals(persons.id, events__pdi.person_id) WHERE equals(events.team_id, {self.team.pk}) LIMIT 65535", ) self.assertEqual( self._select( "SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 JOIN persons SAMPLE 0.1 ON persons.id=events.person_id" ), - f"SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 JOIN person SAMPLE 0.1 ON equals(person.id, events__pdi.person_id) INNER JOIN (SELECT argMax(person_distinct_id2.person_id, person_distinct_id2.version) AS person_id, person_distinct_id2.distinct_id FROM person_distinct_id2 WHERE equals(person_distinct_id2.team_id, {self.team.pk}) GROUP BY person_distinct_id2.distinct_id HAVING equals(argMax(person_distinct_id2.is_deleted, person_distinct_id2.version), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id) WHERE and(equals(person.team_id, {self.team.pk}), equals(events.team_id, {self.team.pk})) LIMIT 65535", + f"SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 INNER JOIN (SELECT argMax(person_distinct_id2.person_id, person_distinct_id2.version) AS person_id, person_distinct_id2.distinct_id FROM person_distinct_id2 WHERE equals(person_distinct_id2.team_id, {self.team.pk}) GROUP BY person_distinct_id2.distinct_id HAVING equals(argMax(person_distinct_id2.is_deleted, person_distinct_id2.version), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id) JOIN (SELECT person.id FROM person WHERE equals(person.team_id, {self.team.pk}) GROUP BY person.id HAVING equals(argMax(person.is_deleted, person.version), 0)) AS persons SAMPLE 0.1 ON equals(persons.id, events__pdi.person_id) WHERE equals(events.team_id, {self.team.pk}) LIMIT 65535", ) with override_settings(PERSON_ON_EVENTS_OVERRIDE=True): + expected = self._select( + "SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 JOIN persons ON persons.id=events.person_id" + ) self.assertEqual( - self._select( - "SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 JOIN persons ON persons.id=events.person_id" - ), - f"SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 JOIN person ON equals(person.id, events.person_id) WHERE and(equals(person.team_id, {self.team.pk}), equals(events.team_id, {self.team.pk})) LIMIT 65535", + expected, + f"SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 JOIN (SELECT person.id FROM person WHERE equals(person.team_id, {self.team.pk}) GROUP BY person.id HAVING equals(argMax(person.is_deleted, person.version), 0)) AS persons ON equals(persons.id, events.person_id) WHERE equals(events.team_id, {self.team.pk}) LIMIT 65535", ) + expected = self._select( + "SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 JOIN persons SAMPLE 0.1 ON persons.id=events.person_id" + ) self.assertEqual( - self._select( - "SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 JOIN persons SAMPLE 0.1 ON persons.id=events.person_id" - ), - f"SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 JOIN person SAMPLE 0.1 ON equals(person.id, events.person_id) WHERE and(equals(person.team_id, {self.team.pk}), equals(events.team_id, {self.team.pk})) LIMIT 65535", + expected, + f"SELECT events.event FROM events SAMPLE 2/78 OFFSET 999 JOIN (SELECT person.id FROM person WHERE equals(person.team_id, {self.team.pk}) GROUP BY person.id HAVING equals(argMax(person.is_deleted, person.version), 0)) AS persons SAMPLE 0.1 ON equals(persons.id, events.person_id) WHERE equals(events.team_id, {self.team.pk}) LIMIT 65535", ) def test_count_distinct(self): diff --git a/posthog/hogql/test/test_query.py b/posthog/hogql/test/test_query.py index 4bf8cd48a1e43..bf14287dc3935 100644 --- a/posthog/hogql/test/test_query.py +++ b/posthog/hogql/test/test_query.py @@ -91,7 +91,7 @@ def test_query(self): ) self.assertEqual( response.clickhouse, - f"SELECT DISTINCT replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_0)s), '^\"|\"$', '') FROM person WHERE and(equals(person.team_id, {self.team.id}), equals(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_1)s), '^\"|\"$', ''), %(hogql_val_2)s)) LIMIT 100 SETTINGS readonly=1, max_execution_time=60", + f"SELECT DISTINCT persons.properties___sneaky_mail FROM (SELECT argMax(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_0)s), '^\"|\"$', ''), person.version) AS properties___sneaky_mail, argMax(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_1)s), '^\"|\"$', ''), person.version) AS properties___random_uuid, person.id FROM person WHERE equals(person.team_id, {self.team.id}) GROUP BY person.id HAVING equals(argMax(person.is_deleted, person.version), 0)) AS persons WHERE equals(persons.properties___random_uuid, %(hogql_val_2)s) LIMIT 100 SETTINGS readonly=1, max_execution_time=60", ) self.assertEqual( response.hogql, @@ -105,7 +105,7 @@ def test_query(self): ) self.assertEqual( response.clickhouse, - f"SELECT DISTINCT person_distinct_id2.person_id, person_distinct_id2.distinct_id FROM person_distinct_id2 WHERE equals(person_distinct_id2.team_id, {self.team.id}) LIMIT 100 SETTINGS readonly=1, max_execution_time=60", + f"SELECT DISTINCT person_distinct_ids.person_id, person_distinct_ids.distinct_id FROM (SELECT argMax(person_distinct_id2.person_id, person_distinct_id2.version) AS person_id, person_distinct_id2.distinct_id FROM person_distinct_id2 WHERE equals(person_distinct_id2.team_id, {self.team.id}) GROUP BY person_distinct_id2.distinct_id HAVING equals(argMax(person_distinct_id2.is_deleted, person_distinct_id2.version), 0)) AS person_distinct_ids LIMIT 100 SETTINGS readonly=1, max_execution_time=60", ) self.assertEqual( response.hogql, @@ -150,7 +150,7 @@ def test_query_joins_pdi(self): INNER JOIN ( SELECT distinct_id, argMax(person_id, version) as person_id - FROM person_distinct_ids + FROM raw_person_distinct_ids GROUP BY distinct_id HAVING argMax(is_deleted, version) = 0 ) AS pdi @@ -169,7 +169,7 @@ def test_query_joins_pdi(self): ) self.assertEqual( response.hogql, - "SELECT event, timestamp, pdi.person_id FROM events AS e INNER JOIN (SELECT distinct_id, argMax(person_id, version) AS person_id FROM person_distinct_ids GROUP BY distinct_id HAVING equals(argMax(is_deleted, version), 0)) AS pdi ON equals(e.distinct_id, pdi.distinct_id) LIMIT 100", + "SELECT event, timestamp, pdi.person_id FROM events AS e INNER JOIN (SELECT distinct_id, argMax(person_id, version) AS person_id FROM raw_person_distinct_ids GROUP BY distinct_id HAVING equals(argMax(is_deleted, version), 0)) AS pdi ON equals(e.distinct_id, pdi.distinct_id) LIMIT 100", ) self.assertTrue(len(response.results) > 0) diff --git a/posthog/hogql/test/test_resolver.py b/posthog/hogql/test/test_resolver.py index b1dbe411df9eb..d2e1463915405 100644 --- a/posthog/hogql/test/test_resolver.py +++ b/posthog/hogql/test/test_resolver.py @@ -252,8 +252,8 @@ def test_resolve_lazy_pdi_person_table(self): chain=["person", "id"], ref=ast.FieldRef( name="id", - table=ast.LazyTableRef( - table=pdi_table_ref, field="person", lazy_table=self.database.person_distinct_ids.person + table=ast.LazyJoinRef( + table=pdi_table_ref, field="person", lazy_join=self.database.person_distinct_ids.person ), ), ), @@ -269,9 +269,9 @@ def test_resolve_lazy_pdi_person_table(self): "distinct_id": ast.FieldRef(name="distinct_id", table=pdi_table_ref), "id": ast.FieldRef( name="id", - table=ast.LazyTableRef( + table=ast.LazyJoinRef( table=pdi_table_ref, - lazy_table=self.database.person_distinct_ids.person, + lazy_join=self.database.person_distinct_ids.person, field="person", ), ), @@ -299,9 +299,7 @@ def test_resolve_lazy_events_pdi_table(self): chain=["pdi", "person_id"], ref=ast.FieldRef( name="person_id", - table=ast.LazyTableRef( - table=events_table_ref, field="pdi", lazy_table=self.database.events.pdi - ), + table=ast.LazyJoinRef(table=events_table_ref, field="pdi", lazy_join=self.database.events.pdi), ), ), ], @@ -316,9 +314,9 @@ def test_resolve_lazy_events_pdi_table(self): "event": ast.FieldRef(name="event", table=events_table_ref), "person_id": ast.FieldRef( name="person_id", - table=ast.LazyTableRef( + table=ast.LazyJoinRef( table=events_table_ref, - lazy_table=self.database.events.pdi, + lazy_join=self.database.events.pdi, field="pdi", ), ), @@ -347,8 +345,8 @@ def test_resolve_lazy_events_pdi_table_aliased(self): chain=["e", "pdi", "person_id"], ref=ast.FieldRef( name="person_id", - table=ast.LazyTableRef( - table=events_table_alias_ref, field="pdi", lazy_table=self.database.events.pdi + table=ast.LazyJoinRef( + table=events_table_alias_ref, field="pdi", lazy_join=self.database.events.pdi ), ), ), @@ -365,9 +363,9 @@ def test_resolve_lazy_events_pdi_table_aliased(self): "event": ast.FieldRef(name="event", table=events_table_alias_ref), "person_id": ast.FieldRef( name="person_id", - table=ast.LazyTableRef( + table=ast.LazyJoinRef( table=events_table_alias_ref, - lazy_table=self.database.events.pdi, + lazy_join=self.database.events.pdi, field="pdi", ), ), @@ -395,12 +393,12 @@ def test_resolve_lazy_events_pdi_person_table(self): chain=["pdi", "person", "id"], ref=ast.FieldRef( name="id", - table=ast.LazyTableRef( - table=ast.LazyTableRef( - table=events_table_ref, field="pdi", lazy_table=self.database.events.pdi + table=ast.LazyJoinRef( + table=ast.LazyJoinRef( + table=events_table_ref, field="pdi", lazy_join=self.database.events.pdi ), field="person", - lazy_table=self.database.events.pdi.table.person, + lazy_join=self.database.events.pdi.join_table.person, ), ), ), @@ -416,12 +414,12 @@ def test_resolve_lazy_events_pdi_person_table(self): "event": ast.FieldRef(name="event", table=events_table_ref), "id": ast.FieldRef( name="id", - table=ast.LazyTableRef( - table=ast.LazyTableRef( - table=events_table_ref, field="pdi", lazy_table=self.database.events.pdi + table=ast.LazyJoinRef( + table=ast.LazyJoinRef( + table=events_table_ref, field="pdi", lazy_join=self.database.events.pdi ), field="person", - lazy_table=self.database.events.pdi.table.person, + lazy_join=self.database.events.pdi.join_table.person, ), ), }, @@ -449,12 +447,12 @@ def test_resolve_lazy_events_pdi_person_table_aliased(self): chain=["e", "pdi", "person", "id"], ref=ast.FieldRef( name="id", - table=ast.LazyTableRef( - table=ast.LazyTableRef( - table=events_table_alias_ref, field="pdi", lazy_table=self.database.events.pdi + table=ast.LazyJoinRef( + table=ast.LazyJoinRef( + table=events_table_alias_ref, field="pdi", lazy_join=self.database.events.pdi ), field="person", - lazy_table=self.database.events.pdi.table.person, + lazy_join=self.database.events.pdi.join_table.person, ), ), ), @@ -471,12 +469,12 @@ def test_resolve_lazy_events_pdi_person_table_aliased(self): "event": ast.FieldRef(name="event", table=events_table_alias_ref), "id": ast.FieldRef( name="id", - table=ast.LazyTableRef( - table=ast.LazyTableRef( - table=events_table_alias_ref, field="pdi", lazy_table=self.database.events.pdi + table=ast.LazyJoinRef( + table=ast.LazyJoinRef( + table=events_table_alias_ref, field="pdi", lazy_join=self.database.events.pdi ), field="person", - lazy_table=self.database.events.pdi.table.person, + lazy_join=self.database.events.pdi.join_table.person, ), ), }, diff --git a/posthog/hogql/transforms/lazy_tables.py b/posthog/hogql/transforms/lazy_tables.py index 4fbdb883791c9..7855c5c5c2668 100644 --- a/posthog/hogql/transforms/lazy_tables.py +++ b/posthog/hogql/transforms/lazy_tables.py @@ -1,10 +1,9 @@ import dataclasses -from typing import Dict, List, Optional, cast +from typing import Dict, List, Optional from posthog.hogql import ast -from posthog.hogql.ast import LazyTableRef from posthog.hogql.context import HogQLContext -from posthog.hogql.database import LazyTable +from posthog.hogql.database import LazyJoin, LazyTable from posthog.hogql.resolver import resolve_refs from posthog.hogql.visitor import TraversingVisitor @@ -19,11 +18,17 @@ def resolve_lazy_tables(node: ast.Expr, stack: Optional[List[ast.SelectQuery]] = @dataclasses.dataclass class JoinToAdd: fields_accessed: Dict[str, ast.Expr] - lazy_table: LazyTable + lazy_join: LazyJoin from_table: str to_table: str +@dataclasses.dataclass +class TableToAdd: + fields_accessed: Dict[str, ast.Expr] + lazy_table: LazyTable + + class LazyTableResolver(TraversingVisitor): def __init__(self, stack: Optional[List[ast.SelectQuery]] = None, context: HogQLContext = None): super().__init__() @@ -33,11 +38,13 @@ def __init__(self, stack: Optional[List[ast.SelectQuery]] = None, context: HogQL def _get_long_table_name(self, select: ast.SelectQueryRef, ref: ast.BaseTableRef) -> str: if isinstance(ref, ast.TableRef): return select.get_alias_for_table_ref(ref) + elif isinstance(ref, ast.LazyTableRef): + return ref.table.hogql_table() elif isinstance(ref, ast.TableAliasRef): return ref.name elif isinstance(ref, ast.SelectQueryAliasRef): return ref.name - elif isinstance(ref, ast.LazyTableRef): + elif isinstance(ref, ast.LazyJoinRef): return f"{self._get_long_table_name(select, ref.table)}__{ref.field}" elif isinstance(ref, ast.VirtualTableRef): return f"{self._get_long_table_name(select, ref.table)}__{ref.field}" @@ -48,7 +55,7 @@ def visit_property_ref(self, node: ast.PropertyRef): if node.joined_subquery is not None: # we have already visited this property return - if isinstance(node.parent.table, ast.LazyTableRef): + if isinstance(node.parent.table, ast.LazyJoinRef) or isinstance(node.parent.table, ast.LazyTableRef): if self.context and self.context.within_non_hogql_query: # If we're in a non-HogQL query, traverse deeper, just like we normally would have. self.visit(node.parent) @@ -59,7 +66,7 @@ def visit_property_ref(self, node: ast.PropertyRef): self.stack_of_fields[-1].append(node) def visit_field_ref(self, node: ast.FieldRef): - if isinstance(node.table, ast.LazyTableRef): + if isinstance(node.table, ast.LazyJoinRef) or isinstance(node.table, ast.LazyTableRef): # Each time we find a field, we place it in a list for processing in "visit_select_query" if len(self.stack_of_fields) == 0: raise ValueError("Can't access a lazy field when not in a SelectQuery context") @@ -70,7 +77,7 @@ def visit_select_query(self, node: ast.SelectQuery): if not select_ref: raise ValueError("Select query must have a ref") - # Collect each `ast.Field` with `ast.LazyTableRef` + # Collect each `ast.Field` with `ast.LazyJoinRef` field_collector: List[ast.FieldRef] = [] self.stack_of_fields.append(field_collector) @@ -79,7 +86,18 @@ def visit_select_query(self, node: ast.SelectQuery): # Collect all the joins we need to add to the select query joins_to_add: Dict[str, JoinToAdd] = {} - for field_or_property in field_collector: + tables_to_add: Dict[str, TableToAdd] = {} + + # First properties, then fields. This way we always get the smallest units to query first. + matched_properties: List[ast.PropertyRef | ast.FieldRef] = [ + property for property in field_collector if isinstance(property, ast.PropertyRef) + ] + matched_fields: List[ast.PropertyRef | ast.FieldRef] = [ + field for field in field_collector if isinstance(field, ast.FieldRef) + ] + sorted_properties: List[ast.PropertyRef | ast.FieldRef] = matched_properties + matched_fields + + for field_or_property in sorted_properties: if isinstance(field_or_property, ast.FieldRef): property = None field = field_or_property @@ -92,70 +110,121 @@ def visit_select_query(self, node: ast.SelectQuery): # Traverse the lazy tables until we reach a real table, collecting them in a list. # Usually there's just one or two. - table_refs: List[LazyTableRef] = [] - while isinstance(table_ref, ast.LazyTableRef): + table_refs: List[ast.LazyJoinRef | ast.LazyTableRef] = [] + while isinstance(table_ref, ast.LazyJoinRef) or isinstance(table_ref, ast.LazyTableRef): table_refs.append(table_ref) table_ref = table_ref.table # Loop over the collected lazy tables in reverse order to create the joins for table_ref in reversed(table_refs): - from_table = self._get_long_table_name(select_ref, table_ref.table) - to_table = self._get_long_table_name(select_ref, table_ref) - if to_table not in joins_to_add: - joins_to_add[to_table] = JoinToAdd( - fields_accessed={}, # collect here all fields accessed on this table - lazy_table=table_ref.lazy_table, - from_table=from_table, - to_table=to_table, - ) - new_join = joins_to_add[to_table] - if table_ref == field.table: - chain = [] - if isinstance(table_ref, ast.LazyTableRef): - chain.append(table_ref.resolve_database_table().hogql_table()) - chain.append(field.name) - if property is not None: - chain.extend(property.chain) - property.joined_subquery_field_name = f"{field.name}___{'___'.join(property.chain)}" - new_join.fields_accessed[property.joined_subquery_field_name] = ast.Field(chain=chain) - else: - new_join.fields_accessed[field.name] = ast.Field(chain=chain) + if isinstance(table_ref, ast.LazyJoinRef): + from_table = self._get_long_table_name(select_ref, table_ref.table) + to_table = self._get_long_table_name(select_ref, table_ref) + if to_table not in joins_to_add: + joins_to_add[to_table] = JoinToAdd( + fields_accessed={}, # collect here all fields accessed on this table + lazy_join=table_ref.lazy_join, + from_table=from_table, + to_table=to_table, + ) + new_join = joins_to_add[to_table] + if table_ref == field.table: + chain = [] + chain.append(field.name) + if property is not None: + chain.extend(property.chain) + property.joined_subquery_field_name = f"{field.name}___{'___'.join(property.chain)}" + new_join.fields_accessed[property.joined_subquery_field_name] = ast.Field(chain=chain) + else: + new_join.fields_accessed[field.name] = ast.Field(chain=chain) + elif isinstance(table_ref, ast.LazyTableRef): + table_name = self._get_long_table_name(select_ref, table_ref) + if table_name not in tables_to_add: + tables_to_add[table_name] = TableToAdd( + fields_accessed={}, # collect here all fields accessed on this table + lazy_table=table_ref.table, + ) + new_table = tables_to_add[table_name] + if table_ref == field.table: + chain = [] + chain.append(field.name) + if property is not None: + chain.extend(property.chain) + property.joined_subquery_field_name = f"{field.name}___{'___'.join(property.chain)}" + new_table.fields_accessed[property.joined_subquery_field_name] = ast.Field(chain=chain) + else: + new_table.fields_accessed[field.name] = ast.Field(chain=chain) # Make sure we also add fields we will use for the join's "ON" condition into the list of fields accessed. # Without this "pdi.person.id" won't work if you did not ALSO select "pdi.person_id" explicitly for the join. for new_join in joins_to_add.values(): if new_join.from_table in joins_to_add: - joins_to_add[new_join.from_table].fields_accessed[new_join.lazy_table.from_field] = ast.Field( - chain=[new_join.lazy_table.from_field] + joins_to_add[new_join.from_table].fields_accessed[new_join.lazy_join.from_field] = ast.Field( + chain=[new_join.lazy_join.from_field] ) - # Move the "last_join" pointer to the last join in the SELECT query - last_join = node.select_from - while last_join and last_join.next_join is not None: - last_join = last_join.next_join + # For all the collected tables, create the subqueries, and add them to the table. + for table_name, table_to_add in tables_to_add.items(): + subquery = table_to_add.lazy_table.lazy_select(table_to_add.fields_accessed) + resolve_refs(subquery, self.context.database, select_ref) + old_table_ref = select_ref.tables[table_name] + select_ref.tables[table_name] = ast.SelectQueryAliasRef(name=table_name, ref=subquery.ref) + + join_ptr = node.select_from + while join_ptr: + if join_ptr.table.ref == old_table_ref: + join_ptr.table = subquery + join_ptr.ref = select_ref.tables[table_name] + join_ptr.alias = table_name + break + join_ptr = join_ptr.next_join # For all the collected joins, create the join subqueries, and add them to the table. - for to_table, scope in joins_to_add.items(): - next_join = scope.lazy_table.join_function(scope.from_table, scope.to_table, scope.fields_accessed) - resolve_refs(next_join, self.context.database, select_ref) - select_ref.tables[to_table] = next_join.ref - - # Link up the joins properly - if last_join is None: - node.select_from = next_join - last_join = next_join - else: - last_join.next_join = next_join - while last_join.next_join is not None: - last_join = last_join.next_join + for to_table, join_scope in joins_to_add.items(): + join_to_add: ast.JoinExpr = join_scope.lazy_join.join_function( + join_scope.from_table, join_scope.to_table, join_scope.fields_accessed + ) + resolve_refs(join_to_add, self.context.database, select_ref) + select_ref.tables[to_table] = join_to_add.ref + + join_ptr = node.select_from + added = False + while join_ptr: + if join_scope.from_table == join_ptr.alias or ( + isinstance(join_ptr.table, ast.Field) and join_scope.from_table == join_ptr.table.chain[0] + ): + join_to_add.next_join = join_ptr.next_join + join_ptr.next_join = join_to_add + added = True + break + if join_ptr.next_join: + join_ptr = join_ptr.next_join + else: + break + if not added: + if join_ptr: + join_ptr.next_join = join_to_add + elif node.select_from: + node.select_from.next_join = join_to_add + else: + node.select_from = join_to_add # Assign all refs on the fields we collected earlier for field_or_property in field_collector: if isinstance(field_or_property, ast.FieldRef): - to_table = self._get_long_table_name(select_ref, field_or_property.table) - field_or_property.table = select_ref.tables[to_table] + table_ref = field_or_property.table + elif isinstance(field_or_property, ast.PropertyRef): + table_ref = field_or_property.parent.table + else: + raise Exception("Should not be reachable") + + table_name = self._get_long_table_name(select_ref, table_ref) + table_ref = select_ref.tables[table_name] + + if isinstance(field_or_property, ast.FieldRef): + field_or_property.table = table_ref elif isinstance(field_or_property, ast.PropertyRef): - to_table = self._get_long_table_name(select_ref, cast(ast.PropertyRef, field_or_property).parent.table) - field_or_property.joined_subquery = select_ref.tables[to_table] + field_or_property.parent.table = table_ref + field_or_property.joined_subquery = table_ref self.stack_of_fields.pop() diff --git a/posthog/hogql/transforms/property_types.py b/posthog/hogql/transforms/property_types.py index 43dd7ab65ec4e..da3a076747090 100644 --- a/posthog/hogql/transforms/property_types.py +++ b/posthog/hogql/transforms/property_types.py @@ -59,7 +59,7 @@ def visit_property_ref(self, node: ast.PropertyRef): if node.parent.name == "properties" and len(node.chain) == 1: if isinstance(node.parent.table, ast.BaseTableRef): table = node.parent.table.resolve_database_table().hogql_table() - if table == "persons": + if table == "persons" or table == "raw_persons": self.person_properties.add(node.chain[0]) if table == "events": self.event_properties.add(node.chain[0]) @@ -86,7 +86,7 @@ def visit_field(self, node: ast.Field): if isinstance(ref, ast.PropertyRef) and ref.parent.name == "properties" and len(ref.chain) == 1: if isinstance(ref.parent.table, ast.BaseTableRef): table = ref.parent.table.resolve_database_table().hogql_table() - if table == "persons": + if table == "persons" or table == "raw_persons": if ref.chain[0] in self.person_properties: return self._add_type_to_string_field(node, self.person_properties[ref.chain[0]]) if table == "events": diff --git a/posthog/hogql/transforms/test/test_lazy_tables.py b/posthog/hogql/transforms/test/test_lazy_tables.py index 9421858006b8d..9ad5282a666d7 100644 --- a/posthog/hogql/transforms/test/test_lazy_tables.py +++ b/posthog/hogql/transforms/test/test_lazy_tables.py @@ -6,7 +6,7 @@ from posthog.test.base import BaseTest -class TestLazyTables(BaseTest): +class TestLazyJoins(BaseTest): def test_resolve_lazy_tables(self): printed = self._print_select("select event, pdi.person_id from events") expected = ( @@ -67,15 +67,14 @@ def test_resolve_lazy_tables_two_levels_traversed(self): def test_resolve_lazy_tables_one_level_properties(self): printed = self._print_select("select person.properties.$browser from person_distinct_ids") expected = ( - "SELECT person_distinct_ids__person.`properties___$browser` " - "FROM person_distinct_id2 INNER JOIN " - "(SELECT argMax(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_0)s), '^\"|\"$', ''), person.version) " - "AS `properties___$browser`, person.id FROM person " - f"WHERE equals(person.team_id, {self.team.pk}) GROUP BY person.id " - "HAVING equals(argMax(person.is_deleted, person.version), 0)" - ") AS person_distinct_ids__person ON equals(person_distinct_id2.person_id, person_distinct_ids__person.id) " - f"WHERE equals(person_distinct_id2.team_id, {self.team.pk}) " - "LIMIT 65535" + f"SELECT person_distinct_ids__person.`properties___$browser` FROM " + f"(SELECT argMax(person_distinct_id2.person_id, person_distinct_id2.version) AS person_id, person_distinct_id2.distinct_id " + f"FROM person_distinct_id2 WHERE equals(person_distinct_id2.team_id, {self.team.pk}) GROUP BY person_distinct_id2.distinct_id " + f"HAVING equals(argMax(person_distinct_id2.is_deleted, person_distinct_id2.version), 0)) AS person_distinct_ids " + f"INNER JOIN (SELECT argMax(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_0)s), '^\"|\"$', ''), person.version) " + f"AS `properties___$browser`, person.id FROM person WHERE equals(person.team_id, {self.team.pk}) GROUP BY person.id " + f"HAVING equals(argMax(person.is_deleted, person.version), 0)) AS person_distinct_ids__person " + f"ON equals(person_distinct_ids.person_id, person_distinct_ids__person.id) LIMIT 65535" ) self.assertEqual(printed, expected) @@ -83,15 +82,14 @@ def test_resolve_lazy_tables_one_level_properties(self): def test_resolve_lazy_tables_one_level_properties_deep(self): printed = self._print_select("select person.properties.$browser.in.json from person_distinct_ids") expected = ( - "SELECT person_distinct_ids__person.`properties___$browser___in___json` " - "FROM person_distinct_id2 INNER JOIN " - "(SELECT argMax(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_0)s, %(hogql_val_1)s, %(hogql_val_2)s), '^\"|\"$', ''), person.version) " - "AS `properties___$browser___in___json`, person.id FROM person " - f"WHERE equals(person.team_id, {self.team.pk}) GROUP BY person.id " - "HAVING equals(argMax(person.is_deleted, person.version), 0)" - ") AS person_distinct_ids__person ON equals(person_distinct_id2.person_id, person_distinct_ids__person.id) " - f"WHERE equals(person_distinct_id2.team_id, {self.team.pk}) " - "LIMIT 65535" + f"SELECT person_distinct_ids__person.`properties___$browser___in___json` FROM " + f"(SELECT argMax(person_distinct_id2.person_id, person_distinct_id2.version) AS person_id, person_distinct_id2.distinct_id " + f"FROM person_distinct_id2 WHERE equals(person_distinct_id2.team_id, {self.team.pk}) GROUP BY person_distinct_id2.distinct_id " + f"HAVING equals(argMax(person_distinct_id2.is_deleted, person_distinct_id2.version), 0)) AS person_distinct_ids " + f"INNER JOIN (SELECT argMax(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_0)s, %(hogql_val_1)s, %(hogql_val_2)s), '^\"|\"$', ''), person.version) " + f"AS `properties___$browser___in___json`, person.id FROM person WHERE equals(person.team_id, {self.team.pk}) GROUP BY person.id " + f"HAVING equals(argMax(person.is_deleted, person.version), 0)) AS person_distinct_ids__person " + f"ON equals(person_distinct_ids.person_id, person_distinct_ids__person.id) LIMIT 65535" ) self.assertEqual(printed, expected) @@ -119,11 +117,40 @@ def test_resolve_lazy_tables_two_levels_properties_duplicate(self): f"person_distinct_id2.distinct_id FROM person_distinct_id2 WHERE equals(person_distinct_id2.team_id, {self.team.pk}) " f"GROUP BY person_distinct_id2.distinct_id HAVING equals(argMax(person_distinct_id2.is_deleted, " f"person_distinct_id2.version), 0)) AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id) " - f"INNER JOIN (SELECT argMax(person.properties, person.version) AS properties, " - f"argMax(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_0)s), '^\"|\"$', ''), person.version) " - f"AS properties___name, person.id FROM person WHERE equals(person.team_id, {self.team.pk}) GROUP BY person.id " - f"HAVING equals(argMax(person.is_deleted, person.version), 0)) AS events__pdi__person ON " - f"equals(events__pdi.person_id, events__pdi__person.id) WHERE equals(events.team_id, {self.team.pk}) LIMIT 65535" + f"INNER JOIN (SELECT argMax(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_0)s), '^\"|\"$', ''), person.version) " + f"AS properties___name, argMax(person.properties, person.version) AS properties, person.id FROM person " + f"WHERE equals(person.team_id, {self.team.pk}) GROUP BY person.id HAVING equals(argMax(person.is_deleted, person.version), 0)) " + f"AS events__pdi__person ON equals(events__pdi.person_id, events__pdi__person.id) WHERE equals(events.team_id, {self.team.pk}) LIMIT 65535" + ) + self.assertEqual(printed, expected) + + @override_settings(PERSON_ON_EVENTS_OVERRIDE=False) + def test_resolve_lazy_table_as_select_table(self): + printed = self._print_select("select id, properties.email, properties.$browser from persons") + expected = ( + f"SELECT persons.id, persons.properties___email, persons.`properties___$browser` FROM " + f"(SELECT argMax(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_0)s), '^\"|\"$', ''), person.version) AS " + f"properties___email, argMax(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_1)s), '^\"|\"$', ''), person.version) " + f"AS `properties___$browser`, person.id FROM person WHERE equals(person.team_id, {self.team.pk}) GROUP BY person.id " + f"HAVING equals(argMax(person.is_deleted, person.version), 0)) AS persons LIMIT 65535" + ) + self.assertEqual(printed, expected) + + @override_settings(PERSON_ON_EVENTS_OVERRIDE=False) + def test_resolve_lazy_table_as_table_in_join(self): + printed = self._print_select( + "select event, distinct_id, events.person_id, persons.properties.email from events left join persons on persons.id = events.person_id limit 10" + ) + expected = ( + f"SELECT events.event, events.distinct_id, events__pdi.person_id, persons.properties___email FROM events " + f"INNER JOIN (SELECT argMax(person_distinct_id2.person_id, person_distinct_id2.version) AS person_id, " + f"person_distinct_id2.distinct_id FROM person_distinct_id2 WHERE equals(person_distinct_id2.team_id, {self.team.pk}) " + f"GROUP BY person_distinct_id2.distinct_id HAVING equals(argMax(person_distinct_id2.is_deleted, person_distinct_id2.version), 0)) " + f"AS events__pdi ON equals(events.distinct_id, events__pdi.distinct_id) LEFT JOIN (SELECT " + f"argMax(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_0)s), '^\"|\"$', ''), person.version) AS properties___email, " + f"person.id FROM person WHERE equals(person.team_id, {self.team.pk}) GROUP BY person.id " + f"HAVING equals(argMax(person.is_deleted, person.version), 0)) AS persons ON equals(persons.id, events__pdi.person_id) " + f"WHERE equals(events.team_id, {self.team.pk}) LIMIT 10" ) self.assertEqual(printed, expected) diff --git a/posthog/hogql/transforms/test/test_property_types.py b/posthog/hogql/transforms/test/test_property_types.py index 2b128c9e886e7..eca0a3020e2a4 100644 --- a/posthog/hogql/transforms/test/test_property_types.py +++ b/posthog/hogql/transforms/test/test_property_types.py @@ -54,9 +54,21 @@ def test_resolve_property_types_event(self): ) self.assertEqual(printed, expected) + def test_resolve_property_types_person_raw(self): + printed = self._print_select( + "select properties.tickets, properties.provided_timestamp, properties.$initial_browser from raw_persons" + ) + expected = ( + "SELECT toFloat64OrNull(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_0)s), '^\"|\"$', '')), " + "toDateTimeOrNull(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_1)s), '^\"|\"$', ''), %(hogql_val_2)s), " + "replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_3)s), '^\"|\"$', '') " + f"FROM person WHERE equals(person.team_id, {self.team.pk}) LIMIT 65535" + ) + self.assertEqual(printed, expected) + def test_resolve_property_types_person(self): printed = self._print_select( - "select properties.tickets, properties.provided_timestamp, properties.$initial_browser from persons" + "select properties.tickets, properties.provided_timestamp, properties.$initial_browser from raw_persons" ) expected = ( "SELECT toFloat64OrNull(replaceRegexpAll(JSONExtractRaw(person.properties, %(hogql_val_0)s), '^\"|\"$', '')), " diff --git a/posthog/hogql/visitor.py b/posthog/hogql/visitor.py index 382f38b4e9482..e5f52d404560f 100644 --- a/posthog/hogql/visitor.py +++ b/posthog/hogql/visitor.py @@ -118,10 +118,13 @@ def visit_select_union_query_ref(self, node: ast.SelectUnionQueryRef): def visit_table_ref(self, node: ast.TableRef): pass - def visit_field_traverser_ref(self, node: ast.LazyTableRef): + def visit_lazy_table_ref(self, node: ast.TableRef): + pass + + def visit_field_traverser_ref(self, node: ast.LazyJoinRef): self.visit(node.table) - def visit_lazy_table_ref(self, node: ast.LazyTableRef): + def visit_lazy_join_ref(self, node: ast.LazyJoinRef): self.visit(node.table) def visit_virtual_table_ref(self, node: ast.VirtualTableRef): diff --git a/posthog/queries/funnels/test/test_funnel.py b/posthog/queries/funnels/test/test_funnel.py index 3eda9bc2cd871..94a9eef2742d1 100644 --- a/posthog/queries/funnels/test/test_funnel.py +++ b/posthog/queries/funnels/test/test_funnel.py @@ -176,7 +176,6 @@ def test_funnel_events(self): @override_settings(PERSON_ON_EVENTS_V2_OVERRIDE=True) @snapshot_clickhouse_queries def test_funnel_events_with_person_on_events_v2(self): - # KLUDGE: We need to do this to ensure create_person_id_override_by_distinct_id # works correctly. Worth considering other approaches as we generally like to # avoid truncating tables in tests for speed.