Skip to content

Commit

Permalink
fix(optimizer)!: Add type hints for optimizer rules eliminate & merge subqueries (#4267)
Browse files: browse the repository at this point in the history

* fix(optimizer): Add type annotations for eliminate & merge subqueries

* PR Feedback 1

* PR Feedback 2
  • Loading branch information
VaggelisD authored Oct 21, 2024
1 parent 36f6841 commit 222152e
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 46 deletions.
33 changes: 25 additions & 8 deletions sqlglot/optimizer/eliminate_subqueries.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from __future__ import annotations

import itertools
import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.scope import Scope, build_scope

if t.TYPE_CHECKING:
ExistingCTEsMapping = t.Dict[exp.Expression, str]
TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]]


def eliminate_subqueries(expression):
def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
"""
Rewrite derived tables as CTES, deduplicating if possible.
Expand Down Expand Up @@ -38,7 +45,7 @@ def eliminate_subqueries(expression):
# Map of alias->Scope|Table
# These are all aliases that are already used in the expression.
# We don't want to create new CTEs that conflict with these names.
taken = {}
taken: TakenNameMapping = {}

# All CTE aliases in the root scope are taken
for scope in root.cte_scopes:
Expand All @@ -56,7 +63,7 @@ def eliminate_subqueries(expression):

# Map of Expression->alias
# Existing CTES in the root expression. We'll use this for deduplication.
existing_ctes = {}
existing_ctes: ExistingCTEsMapping = {}

with_ = root.expression.args.get("with")
recursive = False
Expand Down Expand Up @@ -95,15 +102,21 @@ def eliminate_subqueries(expression):
return expression


def _eliminate(scope, existing_ctes, taken):
def _eliminate(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
if scope.is_derived_table:
return _eliminate_derived_table(scope, existing_ctes, taken)

if scope.is_cte:
return _eliminate_cte(scope, existing_ctes, taken)

return None


def _eliminate_derived_table(scope, existing_ctes, taken):
def _eliminate_derived_table(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
# This makes sure that we don't:
# - drop the "pivot" arg from a pivoted subquery
# - eliminate a lateral correlated subquery
Expand All @@ -121,7 +134,9 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
return cte


def _eliminate_cte(scope, existing_ctes, taken):
def _eliminate_cte(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)

Expand All @@ -140,7 +155,9 @@ def _eliminate_cte(scope, existing_ctes, taken):
return cte


def _new_cte(scope, existing_ctes, taken):
def _new_cte(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Tuple[str, t.Optional[exp.Expression]]:
"""
Returns:
tuple of (name, cte)
Expand Down
68 changes: 30 additions & 38 deletions sqlglot/optimizer/merge_subqueries.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from __future__ import annotations

import typing as t

from collections import defaultdict

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope

if t.TYPE_CHECKING:
from sqlglot._typing import E

FromOrJoin = t.Union[exp.From, exp.Join]

def merge_subqueries(expression, leave_tables_isolated=False):

def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
"""
Rewrite sqlglot AST to merge derived tables into the outer query.
Expand Down Expand Up @@ -58,7 +67,7 @@ def merge_subqueries(expression, leave_tables_isolated=False):
)


def merge_ctes(expression, leave_tables_isolated=False):
def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
scopes = traverse_scope(expression)

# All places where we select from CTEs.
Expand Down Expand Up @@ -92,7 +101,7 @@ def merge_ctes(expression, leave_tables_isolated=False):
return expression


def merge_derived_tables(expression, leave_tables_isolated=False):
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
Expand All @@ -111,17 +120,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
return expression


def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
def _mergeable(
outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin
) -> bool:
"""
Return True if `inner_select` can be merged into outer query.
Args:
outer_scope (Scope)
inner_scope (Scope)
leave_tables_isolated (bool)
from_or_join (exp.From|exp.Join)
Returns:
bool: True if can be merged
"""
inner_select = inner_scope.expression.unnest()

Expand Down Expand Up @@ -195,7 +198,7 @@ def _is_recursive():
and not outer_scope.expression.is_star
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and inner_select.args.get("from") is not None
and not outer_scope.pivots
and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
Expand All @@ -218,14 +221,9 @@ def _is_recursive():
)


def _rename_inner_sources(outer_scope, inner_scope, alias):
def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
"""
Renames any sources in the inner query that conflict with names in the outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
alias (str)
"""
inner_taken = set(inner_scope.selected_sources)
outer_taken = set(outer_scope.selected_sources)
Expand Down Expand Up @@ -253,15 +251,14 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
inner_scope.rename_source(conflict, new_name)


def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
def _merge_from(
outer_scope: Scope,
inner_scope: Scope,
node_to_replace: t.Union[exp.Subquery, exp.Table],
alias: str,
) -> None:
"""
Merge FROM clause of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
node_to_replace (exp.Subquery|exp.Table)
alias (str)
"""
new_subquery = inner_scope.expression.args["from"].this
new_subquery.set("joins", node_to_replace.args.get("joins"))
Expand All @@ -277,14 +274,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
)


def _merge_joins(outer_scope, inner_scope, from_or_join):
def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
"""
Merge JOIN clauses of inner query into outer query.
Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
from_or_join (exp.From|exp.Join)
"""

new_joins = []
Expand All @@ -307,7 +299,7 @@ def _merge_joins(outer_scope, inner_scope, from_or_join):
outer_scope.expression.set("joins", outer_joins)


def _merge_expressions(outer_scope, inner_scope, alias):
def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
"""
Merge projections of inner query into outer query.
Expand Down Expand Up @@ -341,7 +333,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
column.replace(expression.copy())


def _merge_where(outer_scope, inner_scope, from_or_join):
def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
"""
Merge WHERE clause of inner query into outer query.
Expand All @@ -360,7 +352,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
# Merge predicates from an outer join to the ON clause
# if it only has columns that are already joined
from_ = expression.args.get("from")
sources = {from_.alias_or_name} if from_ else {}
sources = {from_.alias_or_name} if from_ else set()

for join in expression.args["joins"]:
source = join.alias_or_name
Expand All @@ -376,7 +368,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
expression.where(where.this, copy=False)


def _merge_order(outer_scope, inner_scope):
def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None:
"""
Merge ORDER clause of inner query into outer query.
Expand All @@ -396,7 +388,7 @@ def _merge_order(outer_scope, inner_scope):
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))


def _merge_hints(outer_scope, inner_scope):
def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None:
inner_scope_hint = inner_scope.expression.args.get("hint")
if not inner_scope_hint:
return
Expand All @@ -408,7 +400,7 @@ def _merge_hints(outer_scope, inner_scope):
outer_scope.expression.set("hint", inner_scope_hint)


def _pop_cte(inner_scope):
def _pop_cte(inner_scope: Scope) -> None:
"""
Remove CTE from the AST.
Expand Down

0 comments on commit 222152e

Please sign in to comment.