Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Apr 19, 2024
1 parent b8dbbc6 commit 58aa8fb
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ class ConstantFoldingStrategy(OptimizationStrategy):
def visit(self, node: LogicalPlanNode, context: OptimizerContext) -> OptimizerContext:
"""
Constant Folding is when we precalculate expressions (or sub expressions)
which contain only constant or literal or literal values. These don't
tend to happen IRL, but it's a simple enough strategy so should be
included.
which contain only constant or literal values.
"""
if not context.optimized_plan:
context.optimized_plan = context.pre_optimized_tree.copy() # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ def _add_condition(existing_condition, new_condition):

def _rewrite_predicate(predicate):
"""
Rewrite individual predicates
Rewrite individual predicates to forms able to push to more places
"""
if predicate.value in LIKE_REWRITES:
# LIKE conditions with no wildcards => Eq
if (
predicate.right.node_type == NodeType.LITERAL
and "%" not in predicate.right.value
Expand All @@ -51,6 +52,7 @@ def _rewrite_predicate(predicate):
predicate.value = LIKE_REWRITES[predicate.value]
return predicate
if predicate.value in IN_REWRITES:
# IN conditions on single values => Eq
if predicate.right.node_type == NodeType.LITERAL and len(predicate.right.value) == 1:
predicate.value = IN_REWRITES[predicate.value]
predicate.right.value = predicate.right.value.pop()
Expand Down Expand Up @@ -156,11 +158,24 @@ def _inner(node):
elif node.type == "cross join" and node.unnest_column:
# if it's a CROSS JOIN UNNEST - don't try to push any further
# IMPROVE: we should push everything that doesn't reference the unnested column
# don't push filters we can't resolve here though
remaining_predicates = []
for predicate in context.collected_predicates:
context.optimized_plan.insert_node_after(
predicate.nid, predicate, context.node_id
)
context.collected_predicates = []
known_columns = set(col.schema_column.identity for col in predicate.columns)
query_columns = {
predicate.condition.left.schema_column.identity,
predicate.condition.right.schema_column.identity,
}
if (
query_columns == (known_columns)
or node.unnest_target.identity in query_columns
):
context.optimized_plan.insert_node_after(
predicate.nid, predicate, context.node_id
)
else:
remaining_predicates.append(predicate)
context.collected_predicates = remaining_predicates
elif node.type in ("cross join",): # , "inner"):
# IMPROVE: add predicates to INNER JOIN conditions
# we may be able to rewrite as an inner join
Expand Down
2 changes: 1 addition & 1 deletion opteryx/components/logical_planner/logical_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __str__(self):
if self.function == "VALUES":
return f"VALUES (({', '.join(self.columns)}) x {len(self.values)} AS {self.alias})"
if self.function == "UNNEST":
return f"UNNEST ({', '.join(format_expression(arg) for arg in self.args)}{' AS ' + self.alias if self.alias else ''})"
return f"UNNEST ({', '.join(format_expression(arg) for arg in self.args)}{' AS ' + self.unnest_target if self.unnest_target else ''})"
if node_type == LogicalPlanStepType.Filter:
return f"FILTER ({format_expression(self.condition)})"
if node_type == LogicalPlanStepType.Join:
Expand Down
5 changes: 5 additions & 0 deletions opteryx/managers/expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,11 @@ def _inner_evaluate(root: Node, table: Table, context: ExecutionContext):
if node_type == NodeType.EXPRESSION_LIST:
values = [_inner_evaluate(val, table, context) for val in root.parameters]
return values
from opteryx.exceptions import ColumnNotFoundError

raise ColumnNotFoundError(
message=f"Unable to locate column '{root.source_column}' this is likely due to differences in SELECT and GROUP BY clauses."
)


def evaluate(expression: Node, table: Table, context: Optional[ExecutionContext] = None):
Expand Down
6 changes: 4 additions & 2 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@
("SELECT LEFT(name, 1), COUNT(*) FROM $satellites GROUP BY LEFT(name, 1) ORDER BY 2 DESC", 21, 2, UnsupportedSyntaxError),
("SELECT LEFT(name, 1), COUNT(*) FROM $satellites GROUP BY name ORDER BY 2 DESC", 177, 2, UnsupportedSyntaxError),
("SELECT LEFT(name, 2) as le, COUNT(*) FROM $satellites GROUP BY LEFT(name, 2) ORDER BY 2 DESC", 87, 2, UnsupportedSyntaxError),
("SELECT LEFT(name, 2) as le, COUNT(*) FROM $satellites GROUP BY LEFT(name, 1)", 87, 2, TypeError),
("SELECT LEFT(name, 2) as le, COUNT(*) FROM $satellites GROUP BY LEFT(name, 1)", 87, 2, ColumnNotFoundError),
("SELECT RIGHT(name, 10), COUNT(*) FROM $satellites GROUP BY RIGHT(name, 10) ORDER BY 2 DESC", 177, 2, UnsupportedSyntaxError),
("SELECT RIGHT(name, 2) as le, COUNT(*) FROM $satellites GROUP BY RIGHT(name, 2) ORDER BY 2 DESC", 91, 2, UnsupportedSyntaxError),
("SELECT RIGHT(name, 2) as le, COUNT(*) FROM $satellites GROUP BY le ORDER BY 2 DESC", 91, 2, UnsupportedSyntaxError),
Expand Down Expand Up @@ -1537,7 +1537,9 @@
("SELECT * FROM $planets AS p INNER JOIN $planets AS s ON p.id = s.id WHERE p.name = 'Jupiter' AND p.id = 1.0", 0, 40, None),
("SELECT * FROM sqlite.planets AS p INNER JOIN sqlite.planets AS s ON p.id = s.id WHERE p.name RLIKE 'Jupiter' AND s.id = 1.0", 0, 40, None),
("SELECT * FROM sqlite.planets AS p INNER JOIN sqlite.planets AS s ON p.id = s.id WHERE p.name RLIKE 'Jupiter' AND s.name RLIKE 'Jupiter'", 1, 40, None),

# 1587
("SELECT name, Mission_Status, Mission FROM $astronauts CROSS JOIN UNNEST (missions) AS mission_names INNER JOIN $missions ON Mission = mission_names WHERE mission_names = 'Apollo 11'", 3, 3, None),
("SELECT name, Mission_Status, Mission FROM $astronauts CROSS JOIN UNNEST (missions) AS mission_names INNER JOIN $missions ON Mission = mission_names WHERE Mission = 'Apollo 11'", 3, 3, None),
]
# fmt:on

Expand Down

0 comments on commit 58aa8fb

Please sign in to comment.