Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(optimizer)!: Add type hints for optimizer rules eliminate & merge subqueries #4267

Merged
merged 3 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions sqlglot/optimizer/eliminate_subqueries.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from __future__ import annotations

import itertools
import typing as t

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import build_scope
from sqlglot.optimizer.scope import Scope, build_scope

if t.TYPE_CHECKING:
ExistingCTEsMapping = t.Dict[exp.Expression, str]
TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]]


def eliminate_subqueries(expression):
def eliminate_subqueries(expression: exp.Expression) -> exp.Expression:
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
"""
Rewrite derived tables as CTEs, deduplicating if possible.

Expand Down Expand Up @@ -38,7 +45,7 @@ def eliminate_subqueries(expression):
# Map of alias->Scope|Table
# These are all aliases that are already used in the expression.
# We don't want to create new CTEs that conflict with these names.
taken = {}
taken: TakenNameMapping = {}

# All CTE aliases in the root scope are taken
for scope in root.cte_scopes:
Expand All @@ -56,7 +63,7 @@ def eliminate_subqueries(expression):

# Map of Expression->alias
# Existing CTEs in the root expression. We'll use this for deduplication.
existing_ctes = {}
existing_ctes: ExistingCTEsMapping = {}

with_ = root.expression.args.get("with")
recursive = False
Expand Down Expand Up @@ -95,15 +102,21 @@ def eliminate_subqueries(expression):
return expression


def _eliminate(
    scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
    """
    Dispatch a scope to the elimination routine that matches its kind.

    Args:
        scope: The scope to process.
        existing_ctes: Passed through unchanged to the selected routine.
        taken: Passed through unchanged to the selected routine.

    Returns:
        The result of `_eliminate_derived_table` when `scope.is_derived_table` is truthy,
        otherwise the result of `_eliminate_cte` when `scope.is_cte` is truthy,
        otherwise None.
    """
    # The derived-table check is evaluated first, so it wins if both flags are set.
    if scope.is_derived_table:
        return _eliminate_derived_table(scope, existing_ctes, taken)

    if scope.is_cte:
        return _eliminate_cte(scope, existing_ctes, taken)

    # Explicit None keeps every code path consistent with the Optional return annotation.
    return None


def _eliminate_derived_table(scope, existing_ctes, taken):
def _eliminate_derived_table(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
# This makes sure that we don't:
# - drop the "pivot" arg from a pivoted subquery
# - eliminate a lateral correlated subquery
Expand All @@ -121,7 +134,9 @@ def _eliminate_derived_table(scope, existing_ctes, taken):
return cte


def _eliminate_cte(scope, existing_ctes, taken):
def _eliminate_cte(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Optional[exp.Expression]:
parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)

Expand All @@ -140,7 +155,9 @@ def _eliminate_cte(scope, existing_ctes, taken):
return cte


def _new_cte(scope, existing_ctes, taken):
def _new_cte(
scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping
) -> t.Tuple[str, t.Optional[exp.Expression]]:
"""
Returns:
tuple of (name, cte)
Expand Down
68 changes: 30 additions & 38 deletions sqlglot/optimizer/merge_subqueries.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from __future__ import annotations

import typing as t

from collections import defaultdict

from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope

if t.TYPE_CHECKING:
from sqlglot._typing import E

FromOrJoin = t.Union[exp.From, exp.Join]

def merge_subqueries(expression, leave_tables_isolated=False):

def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E:
"""
Rewrite sqlglot AST to merge derived tables into the outer query.

Expand Down Expand Up @@ -58,7 +67,7 @@ def merge_subqueries(expression, leave_tables_isolated=False):
)


def merge_ctes(expression, leave_tables_isolated=False):
def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E:
scopes = traverse_scope(expression)

# All places where we select from CTEs.
Expand Down Expand Up @@ -92,7 +101,7 @@ def merge_ctes(expression, leave_tables_isolated=False):
return expression


def merge_derived_tables(expression, leave_tables_isolated=False):
def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E:
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
Expand All @@ -111,17 +120,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
return expression


def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
def _mergeable(
outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin
) -> bool:
"""
Return True if `inner_select` can be merged into outer query.

Args:
outer_scope (Scope)
inner_scope (Scope)
leave_tables_isolated (bool)
from_or_join (exp.From|exp.Join)
Returns:
bool: True if can be merged
"""
inner_select = inner_scope.expression.unnest()

Expand Down Expand Up @@ -195,7 +198,7 @@ def _is_recursive():
and not outer_scope.expression.is_star
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and inner_select.args.get("from") is not None
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
and not outer_scope.pivots
and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
Expand All @@ -218,14 +221,9 @@ def _is_recursive():
)


def _rename_inner_sources(outer_scope, inner_scope, alias):
def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
"""
Renames any sources in the inner query that conflict with names in the outer query.

Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
alias (str)
"""
inner_taken = set(inner_scope.selected_sources)
outer_taken = set(outer_scope.selected_sources)
Expand Down Expand Up @@ -253,15 +251,14 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
inner_scope.rename_source(conflict, new_name)


def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
def _merge_from(
outer_scope: Scope,
inner_scope: Scope,
node_to_replace: t.Union[exp.Subquery, exp.Table],
alias: str,
) -> None:
"""
Merge FROM clause of inner query into outer query.

Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
node_to_replace (exp.Subquery|exp.Table)
alias (str)
"""
new_subquery = inner_scope.expression.args["from"].this
new_subquery.set("joins", node_to_replace.args.get("joins"))
Expand All @@ -277,14 +274,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
)


def _merge_joins(outer_scope, inner_scope, from_or_join):
def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
"""
Merge JOIN clauses of inner query into outer query.

Args:
outer_scope (sqlglot.optimizer.scope.Scope)
inner_scope (sqlglot.optimizer.scope.Scope)
from_or_join (exp.From|exp.Join)
"""

new_joins = []
Expand All @@ -307,7 +299,7 @@ def _merge_joins(outer_scope, inner_scope, from_or_join):
outer_scope.expression.set("joins", outer_joins)


def _merge_expressions(outer_scope, inner_scope, alias):
def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None:
"""
Merge projections of inner query into outer query.

Expand Down Expand Up @@ -341,7 +333,7 @@ def _merge_expressions(outer_scope, inner_scope, alias):
column.replace(expression.copy())


def _merge_where(outer_scope, inner_scope, from_or_join):
def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None:
"""
Merge WHERE clause of inner query into outer query.

Expand All @@ -360,7 +352,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
# Merge predicates from an outer join to the ON clause
# if it only has columns that are already joined
from_ = expression.args.get("from")
sources = {from_.alias_or_name} if from_ else {}
sources = {from_.alias_or_name} if from_ else set()

for join in expression.args["joins"]:
source = join.alias_or_name
Expand All @@ -376,7 +368,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
expression.where(where.this, copy=False)


def _merge_order(outer_scope, inner_scope):
def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None:
"""
Merge ORDER clause of inner query into outer query.

Expand All @@ -396,7 +388,7 @@ def _merge_order(outer_scope, inner_scope):
outer_scope.expression.set("order", inner_scope.expression.args.get("order"))


def _merge_hints(outer_scope, inner_scope):
def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None:
inner_scope_hint = inner_scope.expression.args.get("hint")
if not inner_scope_hint:
return
Expand All @@ -408,7 +400,7 @@ def _merge_hints(outer_scope, inner_scope):
outer_scope.expression.set("hint", inner_scope_hint)


def _pop_cte(inner_scope):
def _pop_cte(inner_scope: Scope) -> None:
"""
Remove CTE from the AST.

Expand Down
Loading