From 58aa8fbc328e7ed9f965b60ecd9c83d494603d30 Mon Sep 17 00:00:00 2001 From: joocer Date: Fri, 19 Apr 2024 17:27:42 +0100 Subject: [PATCH] #1587 --- .../strategies/constant_folding.py | 4 +-- .../strategies/predicate_pushdown.py | 25 +++++++++++++++---- .../logical_planner/logical_planner.py | 2 +- opteryx/managers/expression/__init__.py | 5 ++++ .../test_shapes_and_errors_battery.py | 6 +++-- 5 files changed, 31 insertions(+), 11 deletions(-) diff --git a/opteryx/components/cost_based_optimizer/strategies/constant_folding.py b/opteryx/components/cost_based_optimizer/strategies/constant_folding.py index 7ff81847c..a34d3abfd 100644 --- a/opteryx/components/cost_based_optimizer/strategies/constant_folding.py +++ b/opteryx/components/cost_based_optimizer/strategies/constant_folding.py @@ -73,9 +73,7 @@ class ConstantFoldingStrategy(OptimizationStrategy): def visit(self, node: LogicalPlanNode, context: OptimizerContext) -> OptimizerContext: """ Constant Folding is when we precalculate expressions (or sub expressions) - which contain only constant or literal or literal values. These don't - tend to happen IRL, but it's a simple enough strategy so should be - included. + which contain only constant or literal values. """ if not context.optimized_plan: context.optimized_plan = context.pre_optimized_tree.copy() # type: ignore diff --git a/opteryx/components/cost_based_optimizer/strategies/predicate_pushdown.py b/opteryx/components/cost_based_optimizer/strategies/predicate_pushdown.py index c9f8f1793..80cdfe657 100644 --- a/opteryx/components/cost_based_optimizer/strategies/predicate_pushdown.py +++ b/opteryx/components/cost_based_optimizer/strategies/predicate_pushdown.py @@ -40,9 +40,10 @@ def _add_condition(existing_condition, new_condition): def _rewrite_predicate(predicate): """ - Rewrite individual predicates + Rewrite individual predicates to forms able to push to more places """ if predicate.value in LIKE_REWRITES: + # LIKE conditions with no wildcards => Eq if ( predicate.right.node_type == NodeType.LITERAL and "%" not in predicate.right.value @@ -51,6 +52,7 @@ def _rewrite_predicate(predicate): predicate.value = LIKE_REWRITES[predicate.value] return predicate if predicate.value in IN_REWRITES: + # IN conditions on single values => Eq if predicate.right.node_type == NodeType.LITERAL and len(predicate.right.value) == 1: predicate.value = IN_REWRITES[predicate.value] predicate.right.value = predicate.right.value.pop() @@ -156,11 +158,24 @@ def _inner(node): elif node.type == "cross join" and node.unnest_column: # if it's a CROSS JOIN UNNEST - don't try to push any further # IMPROVE: we should push everything that doesn't reference the unnested column + # don't push filters we can't resolve here though + remaining_predicates = [] for predicate in context.collected_predicates: - context.optimized_plan.insert_node_after( - predicate.nid, predicate, context.node_id - ) - context.collected_predicates = [] + known_columns = set(col.schema_column.identity for col in predicate.columns) + query_columns = { + predicate.condition.left.schema_column.identity, + predicate.condition.right.schema_column.identity, + } + if ( + query_columns == (known_columns) + or node.unnest_target.identity in query_columns + ): + context.optimized_plan.insert_node_after( + predicate.nid, predicate, context.node_id + ) + else: + remaining_predicates.append(predicate) + context.collected_predicates = remaining_predicates elif node.type in ("cross join",): # , "inner"): # IMPROVE: add predicates to INNER JOIN conditions # we may be able to rewrite as an inner join diff --git a/opteryx/components/logical_planner/logical_planner.py b/opteryx/components/logical_planner/logical_planner.py index 7ceb1a441..60a46babf 100644 --- a/opteryx/components/logical_planner/logical_planner.py +++ b/opteryx/components/logical_planner/logical_planner.py @@ -124,7 +124,7 @@ def __str__(self): if self.function == "VALUES": return f"VALUES (({', '.join(self.columns)}) x {len(self.values)} AS {self.alias})" if self.function == "UNNEST": - return f"UNNEST ({', '.join(format_expression(arg) for arg in self.args)}{' AS ' + self.alias if self.alias else ''})" + return f"UNNEST ({', '.join(format_expression(arg) for arg in self.args)}{' AS ' + self.unnest_target if self.unnest_target else ''})" if node_type == LogicalPlanStepType.Filter: return f"FILTER ({format_expression(self.condition)})" if node_type == LogicalPlanStepType.Join: diff --git a/opteryx/managers/expression/__init__.py b/opteryx/managers/expression/__init__.py index 8f2db9c02..84da2c35c 100644 --- a/opteryx/managers/expression/__init__.py +++ b/opteryx/managers/expression/__init__.py @@ -298,6 +298,11 @@ def _inner_evaluate(root: Node, table: Table, context: ExecutionContext): if node_type == NodeType.EXPRESSION_LIST: values = [_inner_evaluate(val, table, context) for val in root.parameters] return values + from opteryx.exceptions import ColumnNotFoundError + + raise ColumnNotFoundError( + message=f"Unable to locate column '{root.source_column}' this is likely due to differences in SELECT and GROUP BY clauses." + ) def evaluate(expression: Node, table: Table, context: Optional[ExecutionContext] = None): diff --git a/tests/sql_battery/test_shapes_and_errors_battery.py b/tests/sql_battery/test_shapes_and_errors_battery.py index cfbbcf061..440e223e1 100644 --- a/tests/sql_battery/test_shapes_and_errors_battery.py +++ b/tests/sql_battery/test_shapes_and_errors_battery.py @@ -486,7 +486,7 @@ ("SELECT LEFT(name, 1), COUNT(*) FROM $satellites GROUP BY LEFT(name, 1) ORDER BY 2 DESC", 21, 2, UnsupportedSyntaxError), ("SELECT LEFT(name, 1), COUNT(*) FROM $satellites GROUP BY name ORDER BY 2 DESC", 177, 2, UnsupportedSyntaxError), ("SELECT LEFT(name, 2) as le, COUNT(*) FROM $satellites GROUP BY LEFT(name, 2) ORDER BY 2 DESC", 87, 2, UnsupportedSyntaxError), - ("SELECT LEFT(name, 2) as le, COUNT(*) FROM $satellites GROUP BY LEFT(name, 1)", 87, 2, TypeError), + ("SELECT LEFT(name, 2) as le, COUNT(*) FROM $satellites GROUP BY LEFT(name, 1)", 87, 2, ColumnNotFoundError), ("SELECT RIGHT(name, 10), COUNT(*) FROM $satellites GROUP BY RIGHT(name, 10) ORDER BY 2 DESC", 177, 2, UnsupportedSyntaxError), ("SELECT RIGHT(name, 2) as le, COUNT(*) FROM $satellites GROUP BY RIGHT(name, 2) ORDER BY 2 DESC", 91, 2, UnsupportedSyntaxError), ("SELECT RIGHT(name, 2) as le, COUNT(*) FROM $satellites GROUP BY le ORDER BY 2 DESC", 91, 2, UnsupportedSyntaxError), @@ -1537,7 +1537,9 @@ ("SELECT * FROM $planets AS p INNER JOIN $planets AS s ON p.id = s.id WHERE p.name = 'Jupiter' AND p.id = 1.0", 0, 40, None), ("SELECT * FROM sqlite.planets AS p INNER JOIN sqlite.planets AS s ON p.id = s.id WHERE p.name RLIKE 'Jupiter' AND s.id = 1.0", 0, 40, None), ("SELECT * FROM sqlite.planets AS p INNER JOIN sqlite.planets AS s ON p.id = s.id WHERE p.name RLIKE 'Jupiter' AND s.name RLIKE 'Jupiter'", 1, 40, None), - + # 1587 + ("SELECT name, Mission_Status, Mission FROM $astronauts CROSS JOIN UNNEST (missions) AS mission_names INNER JOIN $missions ON Mission = mission_names WHERE mission_names = 'Apollo 11'", 3, 3, None), + ("SELECT name, Mission_Status, Mission FROM $astronauts CROSS JOIN UNNEST (missions) AS mission_names INNER JOIN $missions ON Mission = mission_names WHERE Mission = 'Apollo 11'", 3, 3, None), ] # fmt:on