Skip to content

Commit

Permalink
Merge pull request #2031 from mabel-dev/#2029-2
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer authored Sep 28, 2024
2 parents 3a1f724 + 916ab03 commit 89ff100
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 20 deletions.
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 808
__build__ = 809

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
6 changes: 3 additions & 3 deletions opteryx/managers/expression/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def format_expression(root, qualify: bool = False):
"BitwiseXor": "^",
"ShiftLeft": "<<",
"ShiftRight": ">>",
"Arrow": "->",
"LongArrow": "->>",
"AtQuestion": "@?",
}
return f"{format_expression(root.left, qualify)} {_map.get(root.value, root.value).upper()} {format_expression(root.right, qualify)}"
if node_type == NodeType.EXPRESSION_LIST:
Expand All @@ -112,9 +115,6 @@ def format_expression(root, qualify: bool = False):
"BitwiseOr": "|",
"LtEq": "<=",
"GtEq": ">=",
"Arrow": "->",
"LongArrow": "->>",
"AtQuestion": "@?",
}
return f"{format_expression(root.left, qualify)} {_map.get(root.value, root.value).upper()} {format_expression(root.right, qualify)}"
if node_type == NodeType.UNARY_OPERATOR:
Expand Down
49 changes: 47 additions & 2 deletions opteryx/operators/heap_sort_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
from dataclasses import dataclass
from typing import Generator

import numpy
import pyarrow
import pyarrow.compute
from pyarrow import concat_tables

from opteryx.exceptions import ColumnNotFoundError
Expand Down Expand Up @@ -96,8 +98,51 @@ def execute(self) -> Generator[pyarrow.Table, None, None]: # pragma: no cover
else:
table = morsel

# Sort and slice the concatenated table to maintain the limit
table = table.sort_by(mapped_order).slice(offset=0, length=self.limit)
# Determine if any columns are string-based
use_pyarrow_sort = any(
pyarrow.types.is_string(table.column(column_name).type)
or pyarrow.types.is_binary(table.column(column_name).type)
for column_name, _ in mapped_order
)

# strings are sorted faster user pyarrow, single columns faster using compute
if len(mapped_order) == 1 and use_pyarrow_sort:
column_name, sort_direction = mapped_order[0]
column = table.column(column_name)
if sort_direction == "ascending":
sort_indices = pyarrow.compute.sort_indices(column)
else:
sort_indices = pyarrow.compute.sort_indices(column)[::-1]
table = table.take(sort_indices[: self.limit])
# strings are sorted faster using pyarrow
elif use_pyarrow_sort:
table = table.sort_by(mapped_order).slice(offset=0, length=self.limit)
# single column sort using numpy
elif len(mapped_order) == 1:
# Single-column sort using mergesort to take advantage of partially sorted data
column_name, sort_direction = mapped_order[0]
column = table.column(column_name).to_numpy()
if sort_direction == "ascending":
sort_indices = numpy.argsort(column)
else:
sort_indices = numpy.argsort(column)[::-1] # Reverse for descending
# Slice the sorted table
table = table.take(sort_indices[: self.limit])
# multi column sort using numpy
else:
# Multi-column sort using lexsort
columns_for_sorting = []
directions = []
for column_name, sort_direction in mapped_order:
column = table.column(column_name).to_numpy()
columns_for_sorting.append(column)
directions.append(1 if sort_direction == "ascending" else -1)

sort_indices = numpy.lexsort(
[col[::direction] for col, direction in zip(columns_for_sorting, directions)]
)
# Slice the sorted table
table = table.take(sort_indices[: self.limit])

self.statistics.time_heap_sorting += time.time_ns() - start_time

Expand Down
45 changes: 34 additions & 11 deletions opteryx/planner/binder/binder_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,36 +789,58 @@ def visit_project(self, node: Node, context: BindingContext) -> Tuple[Node, Bind
context.schemas = merge_schemas(*[ctx.schemas for ctx in group_contexts])

# Check for duplicates
all_identities = [c.schema_column.identity for c in node.columns]
if len(set(all_identities)) != len(all_identities):
all_top_level_identities = [c.schema_column.identity for c in node.columns]
if len(set(all_top_level_identities)) != len(all_top_level_identities):
from collections import Counter

from opteryx.exceptions import AmbiguousIdentifierError

duplicates = [column for column, count in Counter(all_identities).items() if count > 1]
duplicates = [
column for column, count in Counter(all_top_level_identities).items() if count > 1
]
matches = {c.value for c in node.columns if c.schema_column.identity in duplicates}
raise AmbiguousIdentifierError(
message=f"Query result contains multiple instances of the same column(s) - `{'`, `'.join(matches)}`"
)

# get any column or field from a realtion referenced
# 1984
all_identities = set(
[
item.schema_column.identity
for sublist in [
get_all_nodes_of_type(c, (NodeType.IDENTIFIER,)) for c in node.columns
]
for item in sublist
]
+ all_top_level_identities
)

# Remove columns not being projected from the schemas, and remove empty schemas
columns = []
for relation, schema in list(context.schemas.items()):
schema_columns = [
column for column in schema.columns if column.identity in all_identities
column for column in schema.columns if column.identity in all_top_level_identities
]
if len(schema_columns) == 0:
context.schemas.pop(relation)
else:
for column in schema_columns:
node_column = [
n for n in node.columns if n.schema_column.identity == column.identity
][0]
if node_column.alias:
node_column.schema_column.aliases.append(node_column.alias)
column.aliases.append(node_column.alias)
# for each column in the schema, try to find the node's columns
node_column = next(
(n for n in node.columns if n.schema_column.identity == column.identity),
None,
)
if node_column:
# update the column reference with any AS aliases
if node_column.alias:
node_column.schema_column.aliases.append(node_column.alias)
column.aliases.append(node_column.alias)
# update the schema with columns we have references to, removing redundant columns
schema.columns = schema_columns
for column in node.columns:
# indirect references are when we're keeping a column for a function or sort
# 1984 column.direct_reference = column.identity in all_top_level_identities
if column.schema_column.identity in [i.identity for i in schema_columns]:
columns.append(column)

Expand Down Expand Up @@ -964,7 +986,8 @@ def visit_subquery(self, node: Node, context: BindingContext) -> Tuple[Node, Bin
if not schema_column.origin:
schema_column.origin = []
source_relations.extend(schema_column.origin or [])
projection_column.source = node.alias
if projection_column:
projection_column.source = node.alias
schema_column.origin += [node.alias]

schema_column.name = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,21 @@ def visit(self, node: LogicalPlanNode, context: OptimizerContext) -> OptimizerCo
A tuple containing the potentially modified node and the updated context.
"""
node.pre_update_columns = set(context.collected_identities)
if node.columns: # Assumes node.columns is an iterable or None
collected_columns = self.collect_columns(node)
context.collected_identities.update(collected_columns)

# If we're at a project, we only keep the columns that are referenced
# this is mainly when we have columns in a subquery which aren't used
# in the outer query
# 1984
# if node.node_type == LogicalPlanStepType.Project:
# node.columns = [
# n for n in node.columns if n.schema_column.identity in context.collected_identities
# ]

# Subqueries act like all columns are referenced
if node.node_type != LogicalPlanStepType.Subquery:
if node.columns: # Assumes node.columns is an iterable or None
collected_columns = self.collect_columns(node)
context.collected_identities.update(collected_columns)

if (
node.node_type
Expand Down
2 changes: 2 additions & 0 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,6 +1718,8 @@
("SHOW CREATE VIEW mission.reports", 1, 1, DatasetNotFoundError),
("SHOW CREATE TABLE mission_reports", 1, 1, UnsupportedSyntaxError),

("SELECT name FROM (SELECT MD5(name) AS hash, name FROM $planets) AS S", 9, 1, None),

# ****************************************************************************************

# These are queries which have been found to return the wrong result or not run correctly
Expand Down

0 comments on commit 89ff100

Please sign in to comment.