Commit 9918636

joocer committed Sep 5, 2024
1 parent d349bbb commit 9918636
Showing 13 changed files with 228 additions and 104 deletions.
4 changes: 4 additions & 0 deletions opteryx/__init__.py
@@ -25,6 +25,7 @@

import datetime
import os
+import warnings

from pathlib import Path
from decimal import getcontext
@@ -239,3 +240,6 @@ def opteryx(self, line, cell):
    ip.register_magics(OpteryxMagics)
except Exception as err:  # nosec
    pass

+# Show each DeprecationWarning once (they are hidden by default)
+warnings.simplefilter("once", DeprecationWarning)
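Because the filter is installed at import time, the first use of a deprecated function warns and identical repeats stay quiet. A minimal sketch of that behaviour, assuming the bundled $planets sample dataset and the AVERAGE alias deprecated later in this commit:

import warnings

import opteryx

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("once", DeprecationWarning)
    opteryx.query("SELECT AVERAGE(id) FROM $planets")  # first use: warns
    opteryx.query("SELECT AVERAGE(id) FROM $planets")  # repeat: suppressed by "once"

print([str(w.message) for w in caught])  # expect a single deprecation message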
19 changes: 9 additions & 10 deletions opteryx/connectors/capabilities/predicate_pushable.py
@@ -22,6 +22,7 @@
accept filters and others don't so we 'fake' the read-time filtering.
"""

+import datetime
from typing import Dict

from orso.types import OrsoTypes
@@ -67,7 +68,9 @@ def __init__(self, **kwargs):
    @staticmethod
    def to_dnf(root):
        """
-        Convert a filter to DNF form, this is the form used by pyarrow
+        Convert a filter to DNF form, this is the form used by PyArrow.
+        This is specifically opinionated for the Parquet reader for PyArrow.
        """

def _predicate_to_dnf(root):
@@ -85,19 +88,15 @@ def _predicate_to_dnf(root):
                raise NotSupportedError()
            if root.left.node_type != NodeType.IDENTIFIER:
                root.left, root.right = root.right, root.left
-            if root.right.type in (OrsoTypes.DATE, OrsoTypes.TIMESTAMP):
-                raise NotSupportedError()
+            if root.right.type in (OrsoTypes.DATE,):
+                date_val = root.right.value
+                if hasattr(date_val, "item"):
+                    date_val = date_val.item()
+                root.right.value = datetime.datetime.combine(date_val, datetime.time.min)
            if root.left.node_type != NodeType.IDENTIFIER:
                raise NotSupportedError()
            if root.right.node_type != NodeType.LITERAL:
                raise NotSupportedError()
-            if root.left.type not in (
-                OrsoTypes.DOUBLE,
-                OrsoTypes.INTEGER,
-                OrsoTypes.VARCHAR,
-            ):
-                # not all operands are universally supported
-                raise NotSupportedError()
            return (
                root.left.value,
                PredicatePushable.OPS_XLAT[root.value],
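For orientation, the DNF target here is a list of (column, operator, literal) tuples that PyArrow's Parquet reader accepts directly; a self-contained sketch (not from the codebase) of why the DATE literal is widened to a midnight datetime:

import datetime
import io

import pyarrow
import pyarrow.parquet as parquet

# build a tiny Parquet file in memory with a TIMESTAMP column
buffer = io.BytesIO()
table = pyarrow.table(
    {"name": ["Earth", "Mars"], "ts": [datetime.datetime(2024, 9, 5), datetime.datetime(2024, 9, 6)]}
)
parquet.write_table(table, buffer)
buffer.seek(0)

# a DATE literal only matches the TIMESTAMP column once widened to midnight
date_val = datetime.date(2024, 9, 5)
dnf = [("ts", "=", datetime.datetime.combine(date_val, datetime.time.min))]

print(parquet.read_table(buffer, filters=dnf).to_pydict())  # only the Earth row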
17 changes: 14 additions & 3 deletions opteryx/functions/__init__.py
@@ -316,6 +316,17 @@ def select_values(boolean_arrays, value_arrays):
    return result


+DEPRECATED_FUNCTIONS = {
+    "LIST": "ARRAY_AGG",  # deprecated, remove 0.19.0
+    "MAXIMUM": "MAX",  # deprecated, remove 0.19.0
+    "MINIMUM": "MIN",  # deprecated, remove 0.19.0
+    "AVERAGE": "AVG",  # deprecated, remove 0.19.0
+    "NUMERIC": "DOUBLE",  # deprecated, remove 0.19.0
+    "CEILING": "CEIL",  # deprecated, remove 0.19.0
+    "ABSOLUTE": "ABS",  # deprecated, remove 0.19.0
+    "TRUNCATE": "TRUNC",  # deprecated, remove 0.19.0
+}

# fmt:off
# Function definitions optionally include the type and the function.
# The type is needed particularly when returning Python objects that
@@ -429,14 +440,14 @@ def select_values(boolean_arrays, value_arrays):
    "ROUND": number_functions.round,
    "FLOOR": number_functions.floor,
    "CEIL": number_functions.ceiling,
-    "CEILING": number_functions.ceiling,
+    "CEILING": number_functions.ceiling,  # deprecated, remove 0.19.0
    "ABS": compute.abs,
-    "ABSOLUTE": compute.abs,
+    "ABSOLUTE": compute.abs,  # deprecated, remove 0.19.0
    "SIGN": compute.sign,
    "SIGNUM": compute.sign,
    "SQRT": compute.sqrt,
    "TRUNC": compute.trunc,
-    "TRUNCATE": compute.trunc,
+    "TRUNCATE": compute.trunc,  # deprecated, remove 0.19.0
    "PI": lambda x: None,  # *
    "PHI": lambda x: None,  # *
    "E": lambda x: None,  # *
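Until removal, each alias resolves to the same implementation as its canonical name, so the two spellings stay interchangeable; a quick sketch, assuming the bundled $planets sample dataset:

import opteryx

legacy = opteryx.query("SELECT CEILING(gravity) FROM $planets")  # now warns
canonical = opteryx.query("SELECT CEIL(gravity) FROM $planets")
assert legacy.shape == canonical.shape  # identical result shape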
1 change: 1 addition & 0 deletions opteryx/operators/async_read_node.py
@@ -138,6 +138,7 @@ def execute(self) -> Generator:

        if len(blob_names) == 0:
            # if we don't have any matching blobs, create an empty dataset
+            # TODO: rewrite
            from orso import DataFrame

            as_arrow = DataFrame(rows=[], schema=orso_schema).arrow()
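The fallback is small enough to show in isolation; a sketch of producing a zero-row Arrow table that still carries the expected schema (the RelationSchema construction is an assumption, it is not shown in this diff):

from orso import DataFrame
from orso.schema import FlatColumn, RelationSchema  # assumed import path
from orso.types import OrsoTypes

orso_schema = RelationSchema(
    name="example",
    columns=[FlatColumn(name="id", type=OrsoTypes.INTEGER)],
)
as_arrow = DataFrame(rows=[], schema=orso_schema).arrow()
print(as_arrow.num_rows, as_arrow.schema)  # 0 rows, schema preserved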
14 changes: 9 additions & 5 deletions opteryx/planner/binder/binder.py
@@ -28,15 +28,12 @@
from opteryx.exceptions import ColumnNotFoundError
from opteryx.exceptions import InvalidInternalStateError
from opteryx.exceptions import UnexpectedDatasetReferenceError
-from opteryx.functions import FUNCTIONS
+from opteryx.functions import DEPRECATED_FUNCTIONS
from opteryx.functions import fixed_value_function
from opteryx.managers.expression import NodeType
from opteryx.models import Node
from opteryx.operators.aggregate_node import AGGREGATORS
from opteryx.planner.binder.operator_map import determine_type

-COMBINED_FUNCTIONS = {**FUNCTIONS, **AGGREGATORS}


def merge_schemas(*schemas: Dict[str, RelationSchema]) -> Dict[str, RelationSchema]:
    """
@@ -211,7 +208,7 @@ def traversive_recursive_bind(node: Node, context: Any) -> Tuple[Node, Any]:
    return node, context


def inner_binder(node: Node, context: Any) -> Tuple[Node, Any]:
    """
    Note, this is a tree within a tree. This function represents a single step in the execution
    plan (associated with the relational algebra) which may itself be an evaluation plan
@@ -282,6 +279,13 @@ def inner_binder(node: Node, context: Any) -> Tuple[Node, Any]:

    elif node_type != NodeType.SUBQUERY and not node.do_not_create_column:
        if node_type in (NodeType.FUNCTION, NodeType.AGGREGATOR):
+            if node.value in DEPRECATED_FUNCTIONS:
+                import warnings
+
+                message = f"Function '{node.value}' is deprecated and will be removed in a future version. Use '{DEPRECATED_FUNCTIONS[node.value]}' instead."
+                context.statistics.add_message(message)
+                warnings.warn(message, category=DeprecationWarning, stacklevel=2)
+
            # we need to add this new column to the schema
            aliases = [node.alias] if node.alias else []
            result_type, fixed_function_result = fixed_value_function(node.value, context)
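A sketch of a regression test for this new path (assumes pytest conventions and the $planets sample dataset; not part of this commit):

import warnings

import opteryx


def test_deprecated_function_warns():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always", DeprecationWarning)
        opteryx.query("SELECT MAXIMUM(id) FROM $planets")
    assert any("MAXIMUM" in str(w.message) for w in caught)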
21 changes: 2 additions & 19 deletions opteryx/utils/file_decoders.py
@@ -226,7 +226,7 @@ def parquet_decoder(
    # Open the parquet file only once
    parquet_file = parquet.ParquetFile(stream)

-    if projection or just_schema or selection:
+    if projection or just_schema:  # or selection:
        # Return just the schema if that's all that's needed
        if just_schema:
            return convert_arrow_schema_to_orso_schema(parquet_file.schema_arrow)
@@ -245,24 +245,6 @@
        if not selected_columns:
            selected_columns = None

-        if not columns_in_filters.issubset(arrow_schema_columns_set):
-            if selected_columns is None:
-                selected_columns = list(arrow_schema_columns_set)
-            fields = [pyarrow.field(name, pyarrow.string()) for name in selected_columns]
-            schema = pyarrow.schema(fields)
-
-            # Create an empty table with the schema
-            empty_table = pyarrow.Table.from_arrays(
-                [pyarrow.array([], type=schema.field(i).type) for i in range(len(fields))],
-                schema=schema,
-            )
-
-            return (
-                parquet_file.metadata.num_rows,
-                parquet_file.metadata.num_columns,
-                empty_table,
-            )

        if selected_columns is None and not force_read:
            selected_columns = []

@@ -274,6 +256,7 @@
            filters=dnf_filter,
            use_threads=False,
        )
+
        # Any filters we couldn't push to PyArrow to read we run here
        if selection:
            table = filter_records(selection, table)
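The two-stage strategy, reduced to a standalone sketch (simplified: filter_records and the opteryx internals are stood in for by plain PyArrow calls):

import io

import pyarrow
import pyarrow.compute
import pyarrow.parquet as parquet

buffer = io.BytesIO()
parquet.write_table(pyarrow.table({"a": [1, 2, 3], "b": ["x", "y", "z"]}), buffer)
buffer.seek(0)

# stage 1: predicates PyArrow understands are pushed into the Parquet scan
table = parquet.read_table(buffer, filters=[("a", ">", 1)])

# stage 2: predicates we couldn't push down run against the decoded table
mask = pyarrow.compute.match_substring(table["b"], "y")
print(table.filter(mask).to_pydict())  # {'a': [2], 'b': ['y']}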
4 changes: 2 additions & 2 deletions opteryx/utils/lru_2.py
@@ -101,11 +101,11 @@ def evict(self, details=False):
            value = self.slots.pop(oldest_key)
            self.access_history.pop(oldest_key)
            self.evictions += 1
-            if details:
+            if details:  # pragma: no cover
                return oldest_key, value
            return oldest_key

-        if details:
+        if details:  # pragma: no cover
            return None, None  # No item was evicted
        return None

2 changes: 1 addition & 1 deletion opteryx/utils/memory_view_stream.py
@@ -76,7 +76,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
    def __iter__(self) -> Iterator:
        return iter(self.mv)

-    def __next__(self) -> bytes:
+    def __next__(self) -> bytes:  # pragma: no cover
        if self.offset >= len(self.mv):
            raise StopIteration()
        self.offset += 1
144 changes: 144 additions & 0 deletions tests/fuzzing/test_sql_fuzzer_join.py
@@ -0,0 +1,144 @@
"""
Generate random SQL JOIN statements.

These are pretty basic statements, but this fuzzer has still found bugs.
"""

import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

import datetime
import random
import time

import pytest
from orso.tools import random_int, random_string
from orso.types import OrsoTypes

import opteryx
from opteryx.utils.formatter import format_sql


def random_value(t):
    if t == OrsoTypes.VARCHAR:
        return f"'{random_string(4)}'"
    if t in (OrsoTypes.DATE, OrsoTypes.TIMESTAMP):
        return f"'{datetime.datetime.now() + datetime.timedelta(seconds=random_int())}'"
    if random.random() < 0.5:
        return random_int()
    return random_int() / 1000


def generate_condition(columns):
    where_column = columns[random.choice(range(len(columns)))]
    while where_column.type in (OrsoTypes.ARRAY, OrsoTypes.STRUCT):
        where_column = columns[random.choice(range(len(columns)))]
    if random.random() < 0.1:
        where_operator = random.choice(["IS", "IS NOT"])
        where_value = random.choice(["TRUE", "FALSE", "NULL"])
    elif where_column.type == OrsoTypes.VARCHAR and random.random() < 0.5:
        where_operator = random.choice(
            ["LIKE", "ILIKE", "NOT LIKE", "NOT ILIKE", "RLIKE", "NOT RLIKE"]
        )
        where_value = (
            "'" + random_string(8).replace("1", "%").replace("A", "%").replace("6", "_") + "'"
        )
    else:
        where_operator = random.choice(["=", "!=", "<", "<=", ">", ">="])
        where_value = str(random_value(where_column.type))
    return f"{where_column.name} {where_operator} {where_value}"


def generate_random_sql_join(columns1, table1, columns2, table2) -> str:
    join_type = random.choice(["INNER JOIN", "LEFT JOIN", "RIGHT JOIN", "FULL OUTER JOIN"])

    left_column = columns1[random.choice(range(len(columns1)))]
    right_column = columns2[random.choice(range(len(columns2)))]
    while left_column.type != right_column.type:
        left_column = columns1[random.choice(range(len(columns1)))]
        right_column = columns2[random.choice(range(len(columns2)))]

    join_condition = f"{table1}.{left_column.name} = {table2}.{right_column.name}"
    selected_columns = [f"{table1}.{col.name}" for col in columns1 if random.random() < 0.2] + [
        f"{table2}.{col.name}" for col in columns2 if random.random() < 0.2
    ]
    if len(selected_columns) == 0:
        selected_columns = ["*"]
    select_clause = "SELECT " + ", ".join(selected_columns)

    query = f"{select_clause} FROM {table1} {join_type} {table2} ON {join_condition}"

    return query

from opteryx import virtual_datasets

TABLES = [
    {
        "name": virtual_datasets.planets.schema().name,
        "fields": virtual_datasets.planets.schema().columns,
    },
    {
        "name": virtual_datasets.satellites.schema().name,
        "fields": virtual_datasets.satellites.schema().columns,
    },
    {
        "name": virtual_datasets.astronauts.schema().name,
        "fields": virtual_datasets.astronauts.schema().columns,
    },
    {
        "name": virtual_datasets.missions.schema().name,
        "fields": virtual_datasets.missions.schema().columns,
    },
    {
        "name": "testdata.planets",
        "fields": virtual_datasets.planets.schema().columns,
    },
    {
        "name": "testdata.satellites",
        "fields": virtual_datasets.satellites.schema().columns,
    },
    {
        "name": "testdata.missions",
        "fields": virtual_datasets.missions.schema().columns,
    },
]

TEST_CYCLES: int = 250


@pytest.mark.parametrize("i", range(TEST_CYCLES))
def test_sql_fuzzing_join(i):
    seed = random_int()
    random.seed(seed)
    print(f"Seed: {seed}")

    table1 = TABLES[random.choice(range(len(TABLES)))]
    table2 = TABLES[random.choice(range(len(TABLES)))]
    while table1 == table2:
        table2 = TABLES[random.choice(range(len(TABLES)))]
    statement = generate_random_sql_join(
        table1["fields"], table1["name"], table2["fields"], table2["name"]
    )
    formatted_statement = format_sql(statement)

    print(formatted_statement)

    start_time = time.time()  # Start timing the query execution
    try:
        res = opteryx.query(statement)
        execution_time = time.time() - start_time  # Measure execution time
        print(f"Shape: {res.shape}, Execution Time: {execution_time:.2f} seconds")
        # Additional success criteria checks can be added here
    except Exception as e:
        import traceback

        print(f"\033[0;31mError in Test Cycle {i+1}\033[0m: {e}")
        traceback.print_exc()
        # Log failing statement and error for analysis
        raise e
    print()


if __name__ == "__main__":  # pragma: no cover

    for i in range(TEST_CYCLES):
        test_sql_fuzzing_join(i)

    print("✅ okay")