From 0e39d9bbac930e8588e6e695cf1f648d5cbe2035 Mon Sep 17 00:00:00 2001 From: joocer Date: Sat, 28 Sep 2024 21:32:33 +0100 Subject: [PATCH 1/2] #2029 --- opteryx/managers/expression/formatter.py | 6 +-- opteryx/operators/heap_sort_node.py | 49 ++++++++++++++++++- opteryx/planner/binder/binder_visitor.py | 45 ++++++++++++----- .../strategies/projection_pushdown.py | 18 +++++-- .../test_shapes_and_errors_battery.py | 2 + 5 files changed, 101 insertions(+), 19 deletions(-) diff --git a/opteryx/managers/expression/formatter.py b/opteryx/managers/expression/formatter.py index 69838f2b6..83477d0da 100644 --- a/opteryx/managers/expression/formatter.py +++ b/opteryx/managers/expression/formatter.py @@ -99,6 +99,9 @@ def format_expression(root, qualify: bool = False): "BitwiseXor": "^", "ShiftLeft": "<<", "ShiftRight": ">>", + "Arrow": "->", + "LongArrow": "->>", + "AtQuestion": "@?", } return f"{format_expression(root.left, qualify)} {_map.get(root.value, root.value).upper()} {format_expression(root.right, qualify)}" if node_type == NodeType.EXPRESSION_LIST: @@ -112,9 +115,6 @@ def format_expression(root, qualify: bool = False): "BitwiseOr": "|", "LtEq": "<=", "GtEq": ">=", - "Arrow": "->", - "LongArrow": "->>", - "AtQuestion": "@?", } return f"{format_expression(root.left, qualify)} {_map.get(root.value, root.value).upper()} {format_expression(root.right, qualify)}" if node_type == NodeType.UNARY_OPERATOR: diff --git a/opteryx/operators/heap_sort_node.py b/opteryx/operators/heap_sort_node.py index 5f1b8fe2a..b089038fa 100644 --- a/opteryx/operators/heap_sort_node.py +++ b/opteryx/operators/heap_sort_node.py @@ -28,7 +28,9 @@ from dataclasses import dataclass from typing import Generator +import numpy import pyarrow +import pyarrow.compute from pyarrow import concat_tables from opteryx.exceptions import ColumnNotFoundError @@ -96,8 +98,51 @@ def execute(self) -> Generator[pyarrow.Table, None, None]: # pragma: no cover else: table = morsel - # Sort and slice the concatenated 
table to maintain the limit - table = table.sort_by(mapped_order).slice(offset=0, length=self.limit) + # Determine if any columns are string-based + use_pyarrow_sort = any( + pyarrow.types.is_string(table.column(column_name).type) + or pyarrow.types.is_binary(table.column(column_name).type) + for column_name, _ in mapped_order + ) + + # strings are sorted faster using pyarrow, single columns faster using compute + if len(mapped_order) == 1 and use_pyarrow_sort: + column_name, sort_direction = mapped_order[0] + column = table.column(column_name) + if sort_direction == "ascending": + sort_indices = pyarrow.compute.sort_indices(column) + else: + sort_indices = pyarrow.compute.sort_indices(column)[::-1] + table = table.take(sort_indices[: self.limit]) + # strings are sorted faster using pyarrow + elif use_pyarrow_sort: + table = table.sort_by(mapped_order).slice(offset=0, length=self.limit) + # single column sort using numpy + elif len(mapped_order) == 1: + # Single-column sort using mergesort to take advantage of partially sorted data + column_name, sort_direction = mapped_order[0] + column = table.column(column_name).to_numpy() + if sort_direction == "ascending": + sort_indices = numpy.argsort(column) + else: + sort_indices = numpy.argsort(column)[::-1] # Reverse for descending + # Slice the sorted table + table = table.take(sort_indices[: self.limit]) + # multi column sort using numpy + else: + # Multi-column sort using lexsort + columns_for_sorting = [] + directions = [] + for column_name, sort_direction in mapped_order: + column = table.column(column_name).to_numpy() + columns_for_sorting.append(column) + directions.append(1 if sort_direction == "ascending" else -1) + + sort_indices = numpy.lexsort( + [col[::direction] for col, direction in zip(columns_for_sorting, directions)] + ) + # Slice the sorted table + table = table.take(sort_indices[: self.limit]) self.statistics.time_heap_sorting += time.time_ns() - start_time diff --git 
a/opteryx/planner/binder/binder_visitor.py b/opteryx/planner/binder/binder_visitor.py index a903abec7..18d696e32 100644 --- a/opteryx/planner/binder/binder_visitor.py +++ b/opteryx/planner/binder/binder_visitor.py @@ -789,36 +789,58 @@ def visit_project(self, node: Node, context: BindingContext) -> Tuple[Node, Bind context.schemas = merge_schemas(*[ctx.schemas for ctx in group_contexts]) # Check for duplicates - all_identities = [c.schema_column.identity for c in node.columns] - if len(set(all_identities)) != len(all_identities): + all_top_level_identities = [c.schema_column.identity for c in node.columns] + if len(set(all_top_level_identities)) != len(all_top_level_identities): from collections import Counter from opteryx.exceptions import AmbiguousIdentifierError - duplicates = [column for column, count in Counter(all_identities).items() if count > 1] + duplicates = [ + column for column, count in Counter(all_top_level_identities).items() if count > 1 + ] matches = {c.value for c in node.columns if c.schema_column.identity in duplicates} raise AmbiguousIdentifierError( message=f"Query result contains multiple instances of the same column(s) - `{'`, `'.join(matches)}`" ) + # get any column or field from a relation referenced + # 1984 + all_identities = set( + [ + item.schema_column.identity + for sublist in [ + get_all_nodes_of_type(c, (NodeType.IDENTIFIER,)) for c in node.columns + ] + for item in sublist + ] + + all_top_level_identities + ) + # Remove columns not being projected from the schemas, and remove empty schemas columns = [] for relation, schema in list(context.schemas.items()): schema_columns = [ - column for column in schema.columns if column.identity in all_identities + column for column in schema.columns if column.identity in all_top_level_identities ] if len(schema_columns) == 0: context.schemas.pop(relation) else: for column in schema_columns: - node_column = [ - n for n in node.columns if n.schema_column.identity == column.identity - ][0] - if 
node_column.alias: - node_column.schema_column.aliases.append(node_column.alias) - column.aliases.append(node_column.alias) + # for each column in the schema, try to find the node's columns + node_column = next( + (n for n in node.columns if n.schema_column.identity == column.identity), + None, + ) + if node_column: + # update the column reference with any AS aliases + if node_column.alias: + node_column.schema_column.aliases.append(node_column.alias) + column.aliases.append(node_column.alias) + # update the schema with columns we have references to, removing redundant columns schema.columns = schema_columns for column in node.columns: + # indirect references are when we're keeping a column for a function or sort + # 1984 column.direct_reference = column.identity in all_top_level_identities if column.schema_column.identity in [i.identity for i in schema_columns]: columns.append(column) @@ -964,7 +986,8 @@ def visit_subquery(self, node: Node, context: BindingContext) -> Tuple[Node, Bin if not schema_column.origin: schema_column.origin = [] source_relations.extend(schema_column.origin or []) - projection_column.source = node.alias + if projection_column: + projection_column.source = node.alias schema_column.origin += [node.alias] schema_column.name = ( diff --git a/opteryx/planner/cost_based_optimizer/strategies/projection_pushdown.py b/opteryx/planner/cost_based_optimizer/strategies/projection_pushdown.py index 1ff6ba9ba..39a70585b 100644 --- a/opteryx/planner/cost_based_optimizer/strategies/projection_pushdown.py +++ b/opteryx/planner/cost_based_optimizer/strategies/projection_pushdown.py @@ -36,9 +36,21 @@ def visit(self, node: LogicalPlanNode, context: OptimizerContext) -> OptimizerCo A tuple containing the potentially modified node and the updated context. 
""" node.pre_update_columns = set(context.collected_identities) - if node.columns: # Assumes node.columns is an iterable or None - collected_columns = self.collect_columns(node) - context.collected_identities.update(collected_columns) + + # If we're at a project, we only keep the columns that are referenced + # this is mainly when we have columns in a subquery which aren't used + # in the outer query + # 1984 + # if node.node_type == LogicalPlanStepType.Project: + # node.columns = [ + # n for n in node.columns if n.schema_column.identity in context.collected_identities + # ] + + # Subqueries act like all columns are referenced + if node.node_type != LogicalPlanStepType.Subquery: + if node.columns: # Assumes node.columns is an iterable or None + collected_columns = self.collect_columns(node) + context.collected_identities.update(collected_columns) if ( node.node_type diff --git a/tests/sql_battery/test_shapes_and_errors_battery.py b/tests/sql_battery/test_shapes_and_errors_battery.py index 91300cfe2..b7aa46fa5 100644 --- a/tests/sql_battery/test_shapes_and_errors_battery.py +++ b/tests/sql_battery/test_shapes_and_errors_battery.py @@ -1718,6 +1718,8 @@ ("SHOW CREATE VIEW mission.reports", 1, 1, DatasetNotFoundError), ("SHOW CREATE TABLE mission_reports", 1, 1, UnsupportedSyntaxError), + ("SELECT name FROM (SELECT MD5(name) AS hash, name FROM $planets) AS S", 9, 1, None), + # **************************************************************************************** # These are queries which have been found to return the wrong result or not run correctly From 916ab036ccfac6552ad7b2b13a3fc1f6d481a3c4 Mon Sep 17 00:00:00 2001 From: XB500 Date: Sat, 28 Sep 2024 20:32:58 +0000 Subject: [PATCH 2/2] Opteryx Version 0.17.2-alpha.809 --- opteryx/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opteryx/__version__.py b/opteryx/__version__.py index e6ef632a3..9a832d156 100644 --- a/opteryx/__version__.py +++ b/opteryx/__version__.py @@ -1,4 +1,4 @@ 
-__build__ = 808 +__build__ = 809 # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.