Merge pull request #2020 from mabel-dev/#2019
joocer authored Sep 18, 2024
2 parents 1278363 + 2890776 commit d4005b9
Showing 9 changed files with 104 additions and 43 deletions.
2 changes: 1 addition & 1 deletion opteryx/__version__.py
@@ -1,4 +1,4 @@
__build__ = 799
__build__ = 801

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
3 changes: 3 additions & 0 deletions opteryx/exceptions.py
@@ -362,6 +362,9 @@ class MissingSqlStatement(ProgrammingError):
class InconsistentSchemaError(DataError):
"""Raised when, despite efforts, we can't get a consistent schema."""

def __init__(self, *args, **kwargs):
    pass


class DatasetReadError(DataError):
"""Raised when we can't read the data we're pretty sure is there"""
4 changes: 3 additions & 1 deletion opteryx/models/logical_column.py
@@ -78,7 +78,9 @@ def copy(self):
source_column=self.source_column,
source=self.source,
alias=self.alias,
schema_column=self.schema_column,
schema_column=None
if self.schema_column is None
else self.schema_column.to_flatcolumn(),
)

def __repr__(self) -> str:
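The `copy()` fix above stops a copy from sharing its live `schema_column` with the original; `to_flatcolumn()` yields a detached snapshot instead of a reference. A minimal sketch of the aliasing bug this guards against, using `copy.copy` as a stand-in for the real conversion (the classes here are simplified illustrations, not the opteryx ones):

import copy

class Column:
    """Stand-in for a mutable schema column."""
    def __init__(self, name):
        self.name = name
        self.identity = None

class LogicalColumn:
    """Simplified logical column that owns a schema column."""
    def __init__(self, schema_column=None):
        self.schema_column = schema_column

    def copy(self):
        # Detach the schema column so mutations on the copy don't leak back.
        detached = None if self.schema_column is None else copy.copy(self.schema_column)
        return LogicalColumn(schema_column=detached)

original = LogicalColumn(Column("mass"))
clone = original.copy()
clone.schema_column.identity = "col-123"
assert original.schema_column.identity is None  # the original is unaffected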
8 changes: 4 additions & 4 deletions opteryx/operators/exit_node.py
@@ -87,10 +87,10 @@ def execute(self) -> Generator:
if len(set(final_names)) != len(final_names): # we have duplicate names
final_names = []
for column in self.columns:
if column.schema_column.origin:
final_names.append(f"{column.schema_column.origin[0]}.{column.current_name}")
else:
final_names.append(column.qualified_name)
# if column.schema_column.origin:
# final_names.append(f"{column.schema_column.origin[0]}.{column.current_name}")
# else:
final_names.append(column.qualified_name)

self.statistics.time_exiting += time.monotonic_ns() - start
for morsel in morsels.execute():
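With the origin-prefixed renaming commented out above, the exit node now disambiguates duplicate output names by falling back to the qualified name for every column. A standalone sketch of that fallback (the function is illustrative, not the opteryx API):

def disambiguate(names, qualified_names):
    # Keep the plain names when they are unique; otherwise use the
    # fully qualified form for every column.
    if len(set(names)) == len(names):
        return names
    return qualified_names

print(disambiguate(["id", "mass"], ["p.id", "p.mass"]))  # ['id', 'mass']
print(disambiguate(["id", "id"], ["s.id", "p.id"]))      # ['s.id', 'p.id']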
4 changes: 4 additions & 0 deletions opteryx/operators/outer_join_node.py
@@ -61,6 +61,10 @@ def left_join(left_relation, right_relation, left_columns: List[str], right_colu
right_indexes = deque()

right_relation = pyarrow.concat_tables(right_relation.execute(), promote_options="none")

if len(set(left_columns) & set(right_relation.column_names)) > 0:
left_columns, right_columns = right_columns, left_columns

right_hash = hash_join_map(right_relation, right_columns)

for left_batch in left_relation.execute():
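The guard added above swaps the two key lists when the names bound to the left side are actually columns of the right relation, which indicates the keys were attached to the wrong sides. A minimal sketch of the check, assuming plain lists of column names:

def align_join_keys(left_keys, right_keys, right_relation_columns):
    # If any 'left' key name is really a column of the right relation,
    # the keys were mapped to the wrong sides - swap the lists.
    if set(left_keys) & set(right_relation_columns):
        left_keys, right_keys = right_keys, left_keys
    return left_keys, right_keys

left, right = align_join_keys(["id"], ["planetId"], ["id", "name", "mass"])
print(left, right)  # ['planetId'] ['id']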
16 changes: 16 additions & 0 deletions opteryx/planner/binder/binder_visitor.py
@@ -398,6 +398,8 @@ def name_column(qualifier, column):
if projection_column.current_name:
return projection_column.current_name

if needs_qualifier:
return f"{qualifier}.{column.name}"
return column.name

def keep_column(column, identities):
@@ -931,6 +933,20 @@ def visit_show_columns(
def visit_subquery(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
node, context = self.visit_exit(node, context)

# Extract the column names to check for duplicates
column_names = [n.schema_column.name for n in node.columns]
seen = set()
duplicates = [name for name in column_names if name in seen or seen.add(name)]

# Raise if any names collide - the subquery's output would be ambiguous
if duplicates:
    from opteryx.exceptions import AmbiguousIdentifierError

    raise AmbiguousIdentifierError(
        identifier=duplicates,
        message=f"Column name collision in subquery '{node.alias}'; column(s) {', '.join(duplicates)} are ambiguous in the outer query, use AS to provide unique names for these columns.",
    )

# we sack all the tables we previously knew and create a new set of schemas here
columns: list = []
source_relations: list = []
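The duplicate check above leans on `set.add` returning `None`: a name lands in `duplicates` only when `name in seen` is already true, so one comprehension both records first sightings and collects repeats. The same idiom in isolation:

def find_duplicates(names):
    # set.add() returns None (falsy), so the right-hand side of the `or`
    # registers a first sighting without adding the name to the result.
    seen = set()
    return [name for name in names if name in seen or seen.add(name)]

print(find_duplicates(["id", "mass", "id", "name", "mass"]))  # ['id', 'mass']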
68 changes: 40 additions & 28 deletions opteryx/utils/file_decoders.py
@@ -19,7 +19,6 @@
from typing import BinaryIO
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
@@ -104,21 +103,35 @@ def do_nothing(buffer: Union[memoryview, bytes], **kwargs):  # pragma: no cover
return False


def filter_records(filter: Optional[Union[List, Tuple]], table: pyarrow.Table):
def filter_records(filters: Optional[list], table: pyarrow.Table) -> pyarrow.Table:
"""
When we can't push predicates to the actual read, use this to filter records
just after the read.
Apply filters to a PyArrow table that could not be pushed down during the read operation.
This is a post-read filtering step.
Parameters:
filters: Optional[list]
A list of filter conditions (predicates) to apply to the table.
table: pyarrow.Table
The PyArrow table to be filtered.
Returns:
pyarrow.Table:
A new PyArrow table with rows filtered according to the specified conditions.
Note:
At this point the columns are the raw column names from the file, so we need to ensure
the filters reference the raw column names, not the engine's internal 'identity'.
"""
# notes:
# at this point we've not renamed any columns, this may affect some filters
from opteryx.managers.expression import NodeType
from opteryx.managers.expression import evaluate
from opteryx.models import Node

if isinstance(filter, list) and filter:
filter_copy = list(filter)
if isinstance(filters, list) and filters:
# Create a copy of the filters list to avoid mutating the original.
filter_copy = [f.copy() for f in filters]
root = filter_copy.pop()

# If the left or right side of the root filter node is an identifier, set its identity.
# This step ensures that the filtering logic aligns with the schema before any renaming.
if root.left.node_type == NodeType.IDENTIFIER:
root.left.schema_column.identity = root.left.source_column
if root.right.node_type == NodeType.IDENTIFIER:
@@ -130,14 +143,15 @@ def filter_records(filter: Optional[Union[List, Tuple]], table: pyarrow.Table):
right.left.schema_column.identity = right.left.source_column
if right.right.node_type == NodeType.IDENTIFIER:
right.right.schema_column.identity = right.right.source_column
# Combine the current root with the next filter using an AND node.
root = Node(
NodeType.AND,
left=root,
right=right,
schema_column=Node("schema_column", identity=random_string()),
)
else:
root = filter
root = filters

mask = evaluate(root, table)
return table.filter(mask)
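`filter_records` folds the predicates that could not be pushed into the scan into a single expression tree by chaining AND nodes, then evaluates that tree once to build a boolean mask. A simplified sketch of the same post-read filtering, using PyArrow compute functions directly in place of opteryx `Node` objects:

import pyarrow
import pyarrow.compute as pc

def apply_residual_filters(table, predicates):
    # Each predicate maps the table to a boolean array; AND them into one mask.
    mask = predicates[0](table)
    for predicate in predicates[1:]:
        mask = pc.and_(mask, predicate(table))
    return table.filter(mask)

table = pyarrow.table({"id": [1, 2, 3, 4], "mass": [0.5, 2.0, 3.5, 0.1]})
filtered = apply_residual_filters(
    table,
    [lambda t: pc.greater(t["mass"], 1.0), lambda t: pc.less(t["id"], 4)],
)
print(filtered.to_pydict())  # {'id': [2, 3], 'mass': [2.0, 3.5]}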
@@ -212,11 +226,11 @@ def parquet_decoder(
Returns:
Tuple containing number of rows, number of columns, and the table or schema.
"""
selected_columns = None

# we need to work out if we have a selection which may force us
# fetching columns just for filtering
dnf_filter, selection = PredicatePushable.to_dnf(selection) if selection else (None, None)
dnf_filter, processed_selection = (
PredicatePushable.to_dnf(selection) if selection else (None, None)
)

# try to avoid turning a memoryview buffer into bytes, it's quite slow
stream: BinaryIO = (
@@ -230,33 +244,31 @@ def parquet_decoder(
if just_schema:
return convert_arrow_schema_to_orso_schema(parquet_file.schema_arrow)

# Projection processing
columns_in_filters = {c.value for c in get_all_nodes_of_type(selection, (NodeType.IDENTIFIER,))}
arrow_schema_columns_set = set(parquet_file.schema_arrow.names)
projection_names = {
name for proj_col in projection for name in proj_col.schema_column.all_names
}.union(columns_in_filters)
selected_columns = list(arrow_schema_columns_set.intersection(projection_names))

# If no columns are selected, set to None
if not selected_columns:
selected_columns = None
# Determine the columns needed for projection and filtering
projection_set = set(p.source_column for p in projection or [])
filter_columns = {
c.value for c in get_all_nodes_of_type(processed_selection, (NodeType.IDENTIFIER,))
}
selected_columns = list(
projection_set.union(filter_columns).intersection(parquet_file.schema_arrow.names)
)

if selected_columns is None and not force_read:
# Read all columns if none are selected, unless force_read is set
if not selected_columns and not force_read:
selected_columns = []

# Read the parquet table with the optimized column list and selection filters
table = parquet.read_table(
stream,
columns=selected_columns,
columns=selected_columns if selected_columns else None,
pre_buffer=False,
filters=dnf_filter,
use_threads=False,
)

# Any filters we couldn't push to PyArrow to read we run here
if selection:
table = filter_records(selection, table)
if processed_selection:
table = filter_records(processed_selection, table)

return (parquet_file.metadata.num_rows, parquet_file.metadata.num_columns, table)

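The rewritten projection logic reads only the union of the projected columns and the columns referenced by residual filters, intersected with the columns the file actually contains; `None` tells PyArrow to read everything. A simplified sketch of that pruning computation (the helper name is illustrative, and the real decoder additionally handles a `force_read` flag):

def columns_to_read(projection_columns, filter_columns, file_columns):
    # Union of everything the query needs, restricted to what the file has;
    # an empty selection falls back to None, i.e. read all columns.
    wanted = set(projection_columns) | set(filter_columns)
    selected = list(wanted & set(file_columns))
    return selected or None

print(columns_to_read(["name", "mass"], ["id"], ["id", "name", "mass", "radius"]))
# ['name', 'mass', 'id'] (set order is unspecified)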
24 changes: 20 additions & 4 deletions tests/fuzzing/test_sql_fuzzer_single_table_select.py
@@ -161,14 +161,15 @@ def generate_random_sql_select(columns, table)
def test_sql_fuzzing_single_table(i):
seed = random_int()
random.seed(seed)
print(f"Seed: {seed}")

table = TABLES[random.choice(range(len(TABLES)))]
statement = generate_random_sql_select(table["fields"], table["name"])
formatted_statement = format_sql(statement)

print(formatted_statement)

print(f"Seed: {seed}, Cycle: {i}, ", end="")

start_time = time.time() # Start timing the query execution
try:
res = opteryx.query(statement)
@@ -183,9 +184,24 @@ def test_sql_fuzzing_single_table(i):
# Log failing statement and error for analysis
raise e
print()
return execution_time, statement

if __name__ == "__main__": # pragma: no cover
for i in range(TEST_CYCLES):
test_sql_fuzzing_single_table(i)
import heapq

print("✅ okay")
top_n: int = 5
slowest_executions = []

for i in range(TEST_CYCLES):
et, st = test_sql_fuzzing_single_table(i)

# Use a heap to maintain only the top N slowest executions
if len(slowest_executions) < top_n:
# If we have less than `top_n` elements, add the current result
heapq.heappush(slowest_executions, (et, i, st))
else:
# If we already have `top_n` elements, replace the smallest one if the current one is larger
heapq.heappushpop(slowest_executions, (et, i, st))

print("✅ okay\n")
print("\n".join(f"{s[1]:03} {s[0]:.4f} {format_sql(s[2])}" for s in sorted(slowest_executions, key=lambda x: x[0], reverse=True)))
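The reworked harness keeps only the N slowest queries with a min-heap: push until the heap holds N entries, then `heappushpop` evicts the current fastest whenever a slower query arrives, so the heap always holds the N largest times seen so far. The pattern in isolation:

import heapq
import random

top_n = 5
slowest = []  # min-heap keyed on elapsed time

for i in range(100):
    elapsed = random.random()  # stand-in for a measured execution time
    if len(slowest) < top_n:
        heapq.heappush(slowest, (elapsed, i))
    else:
        # Push the new entry, then pop the smallest of the candidates.
        heapq.heappushpop(slowest, (elapsed, i))

for elapsed, i in sorted(slowest, reverse=True):
    print(f"{i:03} {elapsed:.4f}")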
18 changes: 13 additions & 5 deletions tests/sql_battery/test_shapes_and_errors_battery.py
@@ -1972,13 +1972,13 @@
# ("SELECT name, mission FROM (SELECT name, missions FROM $astronauts) as nauts CROSS JOIN UNNEST (nauts.missions) AS mission WHERE VARCHAR(mission) = 'Apollo 11'", 3, 2, None),
# ("SELECT name, mission FROM $astronauts CROSS JOIN UNNEST (missions) AS mission WHERE LEFT(mission, 2) = 'Apollo 11'", 0, 0, None),
# 1887
("SELECT * FROM (SELECT * FROM $satellites LEFT JOIN (SELECT * FROM $planets) AS p ON $satellites.planetId = p.id) AS mapped WHERE mass > 1", 170, 28, None),
("SELECT * FROM (SELECT * FROM $satellites LEFT JOIN $planets AS p ON $satellites.planetId = p.id) AS mapped WHERE mass > 1", 170, 28, None),
("SELECT * FROM (SELECT * FROM $satellites LEFT JOIN (SELECT id AS pid, mass FROM $planets) AS p ON $satellites.planetId = p.pid) AS mapped WHERE mass > 1", 170, 10, None),
("SELECT * FROM (SELECT planetId, mass FROM $satellites LEFT JOIN $planets AS p ON $satellites.planetId = p.id) AS mapped WHERE mass > 1", 170, 2, None),
("SELECT * FROM $satellites LEFT JOIN $planets AS p ON $satellites.planetId = p.id WHERE mass > 1", 170, 28, None),
# ("SELECT * FROM (SELECT * FROM (SELECT * FROM $satellites) AS s LEFT JOIN $planets AS p ON s.planetId = p.id) AS mapped WHERE mass > 1", 170, 28, None),
("SELECT * FROM (SELECT * FROM $satellites) AS s LEFT JOIN (SELECT * FROM $planets) AS p ON s.planetId = p.id WHERE mass > 1", 170, 28, None),
("SELECT * FROM (SELECT p.id, mass FROM (SELECT * FROM $satellites) AS s LEFT JOIN $planets AS p ON s.planetId = p.id) AS mapped WHERE mass > 1", 171, 2, None),
("SELECT * FROM (SELECT * FROM $satellites) AS s LEFT JOIN (SELECT id as pid, mass FROM $planets) AS p ON s.planetId = p.pid WHERE mass > 1", 170, 10, None),
("SELECT * FROM $satellites LEFT JOIN (SELECT * FROM (SELECT * FROM $planets) AS p) AS planets ON $satellites.planetId = planets.id WHERE mass > 1", 170, 28, None),
("SELECT * FROM (SELECT * FROM (SELECT * FROM $satellites LEFT JOIN $planets AS p ON $satellites.planetId = p.id) AS joined) AS mapped WHERE mass > 1", 170, 28, None),
("SELECT * FROM (SELECT * FROM (SELECT p.id, mass FROM $satellites LEFT JOIN $planets AS p ON $satellites.planetId = p.id) AS joined) AS mapped WHERE mass > 1", 170, 2, None),
# 1977
("SELECT s, e FROM GENERATE_SERIES('2024-01-01', '2025-01-01', '1mth') AS s CROSS JOIN GENERATE_SERIES('2024-01-01', '2025-01-01', '1mth') AS e WHERE s = e + INTERVAL '1' MONTH", 12, 2, None),
("SELECT s, e FROM GENERATE_SERIES('2024-01-01', '2025-01-01', '1mth') AS s CROSS JOIN GENERATE_SERIES('2024-01-01', '2025-01-01', '1mth') AS e WHERE s + INTERVAL '1' MONTH = e", 12, 2, None),
@@ -1992,6 +1992,14 @@
("SELECT * FROM testdata.flat.hosts WHERE address | '20.1.0.0/9'", 0, 2, None),
("SELECT * FROM testdata.flat.hosts WHERE address | '20.112.0.0/16'", 26, 2, None),
("SELECT * FROM testdata.flat.hosts WHERE address | '127.0.0.0/24'", 1, 2, None),
# 2019
("SELECT name, mass, density, rotationPeriod, lengthOfDay, perihelion, aphelion, orbitalVelocity, orbitalEccentricity, obliquityToOrbit, surfacePressure, numberOfMoons FROM testdata.planets WHERE orbitalVelocity <> 2787170570 AND NOT orbitalVelocity BETWEEN 2191745.934 AND 402288.158", 9, 12, None),
("SELECT DISTINCT id, gm, density, magnitude FROM testdata.satellites WHERE radius < 1286258.869 AND NOT id > 2730526.873 AND id IS NULL ORDER BY radius DESC", 0, 4, None),
("SELECT Company, Price, Mission FROM testdata.missions WHERE Price <= 4279346967 AND NOT Price BETWEEN 137294968 AND 2336093823 ORDER BY Company DESC LIMIT 9 ", 9, 3, None),

("SELECT * FROM (SELECT * FROM (SELECT * FROM $satellites LEFT JOIN $planets AS p ON $satellites.planetId = p.id) AS joined) AS mapped WHERE mass > 1", 170, 28, AmbiguousIdentifierError),
("SELECT * FROM (SELECT * FROM $satellites LEFT JOIN (SELECT * FROM $planets) AS p ON $satellites.planetId = p.id) AS mapped WHERE mass > 1", 170, 28, AmbiguousIdentifierError),
("SELECT * FROM (SELECT * FROM $satellites LEFT JOIN $planets AS p ON $satellites.planetId = p.id) AS mapped WHERE mass > 1", 170, 28, AmbiguousIdentifierError),
]
# fmt:on

