From 41c9517b0d215ba85f5d668325810c44ca877c8b Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Tue, 15 Oct 2024 18:47:36 +0300 Subject: [PATCH 1/3] fix(optimizer): Add type annotations for eliminate & merge subqueries --- sqlglot/optimizer/eliminate_subqueries.py | 31 ++++++++--- sqlglot/optimizer/merge_subqueries.py | 64 ++++++++++------------- 2 files changed, 50 insertions(+), 45 deletions(-) diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index f61f899155..fa46a0d97a 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -1,11 +1,20 @@ +from __future__ import annotations + import itertools +import typing as t from sqlglot import expressions as exp from sqlglot.helper import find_new_name -from sqlglot.optimizer.scope import build_scope +from sqlglot.optimizer.scope import Scope, build_scope + +if t.TYPE_CHECKING: + from sqlglot._typing import E + + ExistingCTEsMapping = t.Dict[E, str] + TakenNameMaping = t.Dict[t.Any, t.Union[Scope, E]] -def eliminate_subqueries(expression): +def eliminate_subqueries(expression: exp.Expression) -> exp.Expression: """ Rewrite derived tables as CTES, deduplicating if possible. @@ -38,7 +47,7 @@ def eliminate_subqueries(expression): # Map of alias->Scope|Table # These are all aliases that are already used in the expression. # We don't want to create new CTEs that conflict with these names. - taken = {} + taken: TakenNameMaping = {} # All CTE aliases in the root scope are taken for scope in root.cte_scopes: @@ -56,7 +65,7 @@ def eliminate_subqueries(expression): # Map of Expression->alias # Existing CTES in the root expression. We'll use this for deduplication. - existing_ctes = {} + existing_ctes: ExistingCTEsMapping = {} with_ = root.expression.args.get("with") recursive = False @@ -95,7 +104,7 @@ def eliminate_subqueries(expression): return expression -def _eliminate(scope, existing_ctes, taken): +def _eliminate(scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMaping): if scope.is_derived_table: return _eliminate_derived_table(scope, existing_ctes, taken) @@ -103,7 +112,9 @@ def _eliminate(scope, existing_ctes, taken): return _eliminate_cte(scope, existing_ctes, taken) -def _eliminate_derived_table(scope, existing_ctes, taken): +def _eliminate_derived_table( + scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMaping +) -> t.Optional[exp.Expression]: # This makes sure that we don't: # - drop the "pivot" arg from a pivoted subquery # - eliminate a lateral correlated subquery @@ -121,7 +132,9 @@ def _eliminate_derived_table(scope, existing_ctes, taken): return cte -def _eliminate_cte(scope, existing_ctes, taken): +def _eliminate_cte( + scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMaping +) -> t.Optional[exp.Expression]: parent = scope.expression.parent name, cte = _new_cte(scope, existing_ctes, taken) @@ -140,7 +153,9 @@ def _eliminate_cte(scope, existing_ctes, taken): return cte -def _new_cte(scope, existing_ctes, taken): +def _new_cte( + scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMaping +) -> t.Tuple[str, t.Optional[exp.Expression]]: """ Returns: tuple of (name, cte) diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 866f78c239..412ffa2025 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -1,11 +1,18 @@ +from __future__ import annotations + +import typing as t + from collections import defaultdict from sqlglot import expressions as exp from sqlglot.helper import find_new_name from sqlglot.optimizer.scope import Scope, traverse_scope +if t.TYPE_CHECKING: + FromOrJoin = t.Union[exp.From, exp.Join] + -def merge_subqueries(expression, leave_tables_isolated=False): +def merge_subqueries(expression: exp.Expression, leave_tables_isolated=False): """ Rewrite sqlglot AST to merge derived tables into the outer query. @@ -58,7 +65,7 @@ def merge_subqueries(expression, leave_tables_isolated=False): ) -def merge_ctes(expression, leave_tables_isolated=False): +def merge_ctes(expression: exp.Expression, leave_tables_isolated=False) -> exp.Expression: scopes = traverse_scope(expression) # All places where we select from CTEs. @@ -92,7 +99,7 @@ def merge_ctes(expression, leave_tables_isolated=False): return expression -def merge_derived_tables(expression, leave_tables_isolated=False): +def merge_derived_tables(expression: exp.Expression, leave_tables_isolated=False): for outer_scope in traverse_scope(expression): for subquery in outer_scope.derived_tables: from_or_join = subquery.find_ancestor(exp.From, exp.Join) @@ -111,17 +118,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False): return expression -def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): +def _mergeable( + outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin +) -> bool: """ Return True if `inner_select` can be merged into outer query. - - Args: - outer_scope (Scope) - inner_scope (Scope) - leave_tables_isolated (bool) - from_or_join (exp.From|exp.Join) - Returns: - bool: True if can be merged """ inner_select = inner_scope.expression.unnest() @@ -195,7 +196,7 @@ def _is_recursive(): and not outer_scope.expression.is_star and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) - and inner_select.args.get("from") + and inner_select.args.get("from") is not None and not outer_scope.pivots and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions) and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) @@ -218,14 +219,9 @@ def _is_recursive(): ) -def _rename_inner_sources(outer_scope, inner_scope, alias): +def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None: """ Renames any sources in the inner query that conflict with names in the outer query. - - Args: - outer_scope (sqlglot.optimizer.scope.Scope) - inner_scope (sqlglot.optimizer.scope.Scope) - alias (str) """ inner_taken = set(inner_scope.selected_sources) outer_taken = set(outer_scope.selected_sources) @@ -253,15 +249,14 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): inner_scope.rename_source(conflict, new_name) -def _merge_from(outer_scope, inner_scope, node_to_replace, alias): +def _merge_from( + outer_scope: Scope, + inner_scope: Scope, + node_to_replace: t.Union[exp.Subquery, exp.Table], + alias: str, +) -> None: """ Merge FROM clause of inner query into outer query. - - Args: - outer_scope (sqlglot.optimizer.scope.Scope) - inner_scope (sqlglot.optimizer.scope.Scope) - node_to_replace (exp.Subquery|exp.Table) - alias (str) """ new_subquery = inner_scope.expression.args["from"].this new_subquery.set("joins", node_to_replace.args.get("joins")) @@ -277,14 +272,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): ) -def _merge_joins(outer_scope, inner_scope, from_or_join): +def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None: """ Merge JOIN clauses of inner query into outer query. - - Args: - outer_scope (sqlglot.optimizer.scope.Scope) - inner_scope (sqlglot.optimizer.scope.Scope) - from_or_join (exp.From|exp.Join) """ new_joins = [] @@ -341,7 +331,7 @@ def _merge_expressions(outer_scope, inner_scope, alias): column.replace(expression.copy()) -def _merge_where(outer_scope, inner_scope, from_or_join): +def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None: """ Merge WHERE clause of inner query into outer query. @@ -360,7 +350,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join): # Merge predicates from an outer join to the ON clause # if it only has columns that are already joined from_ = expression.args.get("from") - sources = {from_.alias_or_name} if from_ else {} + sources = {from_.alias_or_name} if from_ else set() for join in expression.args["joins"]: source = join.alias_or_name @@ -376,7 +366,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join): expression.where(where.this, copy=False) -def _merge_order(outer_scope, inner_scope): +def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None: """ Merge ORDER clause of inner query into outer query. @@ -396,7 +386,7 @@ def _merge_order(outer_scope, inner_scope): outer_scope.expression.set("order", inner_scope.expression.args.get("order")) -def _merge_hints(outer_scope, inner_scope): +def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None: inner_scope_hint = inner_scope.expression.args.get("hint") if not inner_scope_hint: return @@ -408,7 +398,7 @@ def _merge_hints(outer_scope, inner_scope): outer_scope.expression.set("hint", inner_scope_hint) -def _pop_cte(inner_scope): +def _pop_cte(inner_scope: Scope) -> None: """ Remove CTE from the AST. From d8cfba9809c63fb8a68cbad3f211337402b314a8 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Fri, 18 Oct 2024 18:14:43 +0300 Subject: [PATCH 2/3] PR Feedback 1 --- sqlglot/optimizer/eliminate_subqueries.py | 16 ++++++++++------ sqlglot/optimizer/merge_subqueries.py | 10 ++++++---- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index fa46a0d97a..c230398513 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -11,7 +11,7 @@ from sqlglot._typing import E ExistingCTEsMapping = t.Dict[E, str] - TakenNameMaping = t.Dict[t.Any, t.Union[Scope, E]] + TakenNameMapping = t.Dict[str, t.Union[Scope, E]] def eliminate_subqueries(expression: exp.Expression) -> exp.Expression: @@ -47,7 +47,7 @@ def eliminate_subqueries(expression: exp.Expression) -> exp.Expression: # Map of alias->Scope|Table # These are all aliases that are already used in the expression. # We don't want to create new CTEs that conflict with these names. - taken: TakenNameMaping = {} + taken: TakenNameMapping = {} # All CTE aliases in the root scope are taken for scope in root.cte_scopes: @@ -104,16 +104,20 @@ def eliminate_subqueries(expression: exp.Expression) -> exp.Expression: return expression -def _eliminate(scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMaping): +def _eliminate( + scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping +) -> t.Optional[exp.Expression]: if scope.is_derived_table: return _eliminate_derived_table(scope, existing_ctes, taken) if scope.is_cte: return _eliminate_cte(scope, existing_ctes, taken) + return None + def _eliminate_derived_table( - scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMaping + scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping ) -> t.Optional[exp.Expression]: # This makes sure that we don't: # - drop the "pivot" arg from a pivoted subquery @@ -133,7 +137,7 @@ def _eliminate_derived_table( def _eliminate_cte( - scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMaping + scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping ) -> t.Optional[exp.Expression]: parent = scope.expression.parent name, cte = _new_cte(scope, existing_ctes, taken) @@ -154,7 +158,7 @@ def _eliminate_cte( def _new_cte( - scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMaping + scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping ) -> t.Tuple[str, t.Optional[exp.Expression]]: """ Returns: diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 412ffa2025..4bd7ee2291 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -9,10 +9,12 @@ from sqlglot.optimizer.scope import Scope, traverse_scope if t.TYPE_CHECKING: + from sqlglot._typing import E + FromOrJoin = t.Union[exp.From, exp.Join] -def merge_subqueries(expression: exp.Expression, leave_tables_isolated=False): +def merge_subqueries(expression: E, leave_tables_isolated=False) -> E: """ Rewrite sqlglot AST to merge derived tables into the outer query. @@ -65,7 +67,7 @@ def merge_subqueries(expression: exp.Expression, leave_tables_isolated=False): ) -def merge_ctes(expression: exp.Expression, leave_tables_isolated=False) -> exp.Expression: +def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E: scopes = traverse_scope(expression) # All places where we select from CTEs. @@ -99,7 +101,7 @@ def merge_ctes(expression: exp.Expression, leave_tables_isolated=False) -> exp.E return expression -def merge_derived_tables(expression: exp.Expression, leave_tables_isolated=False): +def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E: for outer_scope in traverse_scope(expression): for subquery in outer_scope.derived_tables: from_or_join = subquery.find_ancestor(exp.From, exp.Join) @@ -297,7 +299,7 @@ def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoi outer_scope.expression.set("joins", outer_joins) -def _merge_expressions(outer_scope, inner_scope, alias): +def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None: """ Merge projections of inner query into outer query. From 404a9cf75fa864103ccfbc47e3d133ec5de047c1 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Mon, 21 Oct 2024 19:05:04 +0300 Subject: [PATCH 3/3] PR Feedback 2 --- sqlglot/optimizer/eliminate_subqueries.py | 6 ++---- sqlglot/optimizer/merge_subqueries.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index c230398513..b661003690 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -8,10 +8,8 @@ from sqlglot.optimizer.scope import Scope, build_scope if t.TYPE_CHECKING: - from sqlglot._typing import E - - ExistingCTEsMapping = t.Dict[E, str] - TakenNameMapping = t.Dict[str, t.Union[Scope, E]] + ExistingCTEsMapping = t.Dict[exp.Expression, str] + TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]] def eliminate_subqueries(expression: exp.Expression) -> exp.Expression: diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 4bd7ee2291..f02c6b0b6a 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -14,7 +14,7 @@ FromOrJoin = t.Union[exp.From, exp.Join] -def merge_subqueries(expression: E, leave_tables_isolated=False) -> E: +def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E: """ Rewrite sqlglot AST to merge derived tables into the outer query.