Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1590 #1591

Merged
merged 4 commits into from
Apr 21, 2024
Merged

#1590 #1591

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 428
__build__ = 430

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
8 changes: 7 additions & 1 deletion opteryx/components/binder/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,13 @@ def inner_binder(node: Node, context: Any) -> Tuple[Node, Any]:
node.query_column = node.alias or column_name

identifiers = get_all_nodes_of_type(node, (NodeType.IDENTIFIER,))
node.relations = {col.source for col in identifiers if col.source is not None}
sources = []
for col in identifiers:
if col.source is not None:
sources.append(col.source)
if col.schema_column is not None:
sources.extend(col.schema_column.origin)
node.relations = set(sources)

context.schemas = schemas
return node, context
6 changes: 4 additions & 2 deletions opteryx/components/binder/binder_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,9 +837,11 @@ def visit_subquery(self, node: Node, context: BindingContext) -> Tuple[Node, Bin
),
None,
)
if not schema_column.origin:
schema_column.origin = []
source_relations.extend(schema_column.origin or [])
projection_column.source = node.alias
schema_column.origin = [node.alias]
schema_column.origin += [node.alias]

schema_column.name = (
projection_column.current_name if projection_column else schema_column.name
Expand All @@ -858,7 +860,7 @@ def visit_subquery(self, node: Node, context: BindingContext) -> Tuple[Node, Bin
schema = RelationSchema(name=node.alias, columns=columns)

context.schemas = {"$derived": derived.schema(), node.alias: schema}
context.relations = {node.alias}
context.relations.add(node.alias)
node.schema = schema
node.source_relations = set(source_relations)
return node, context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,7 @@ def visit(self, node: LogicalPlanNode, context: OptimizerContext) -> OptimizerCo
if not context.optimized_plan:
context.optimized_plan = context.pre_optimized_tree.copy() # type: ignore

if node.node_type in (
LogicalPlanStepType.Scan,
LogicalPlanStepType.FunctionDataset,
LogicalPlanStepType.Subquery,
):
if node.node_type in (LogicalPlanStepType.Scan, LogicalPlanStepType.FunctionDataset):
# Handle predicates specific to node types
context = self._handle_predicates(node, context)

Expand Down Expand Up @@ -236,7 +232,7 @@ def _handle_predicates(
) -> OptimizerContext:
remaining_predicates = []
for predicate in context.collected_predicates:
if len(predicate.relations) == 1 and predicate.relations.intersection(
if len(predicate.relations) >= 1 and predicate.relations.intersection(
(node.relation, node.alias)
):
if node.connector:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,14 @@ def visit(self, node: LogicalPlanNode, context: OptimizerContext) -> OptimizerCo
new_node.columns = get_all_nodes_of_type(
predicate, select_nodes=(NodeType.IDENTIFIER,)
)
new_node.relations = {c.source for c in new_node.columns}

sources = []
for col in new_node.columns:
if col.source is not None:
sources.append(col.source)
if col.schema_column is not None:
sources.extend(col.schema_column.origin)
new_node.relations = set(sources)
new_nodes.append(new_node)
else:
new_nodes = [node]
Expand Down
3 changes: 2 additions & 1 deletion opteryx/models/connection_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import datetime
from dataclasses import dataclass
from dataclasses import field
from typing import Iterable
from typing import List
from typing import Tuple

Expand Down Expand Up @@ -51,7 +52,7 @@ class ConnectionContext:
connected_at: datetime.datetime = field(default_factory=datetime.datetime.utcnow, init=False)
user: str = None
schema: str = None
memberships: str = None
memberships: Iterable[str] = None
variables: SystemVariablesContainer = field(init=False)
history: List[HistoryItem] = field(default_factory=list, init=False)

Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cython
numpy
numpy==1.*
orjson
orso>=0.0.147
orso>=0.0.151
pyarrow>=12.0.1
typer==0.11.*

179 changes: 87 additions & 92 deletions tests/plan_optimization/test_predicate_pushdown_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

import os
import sys
import time

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

import opteryx
from opteryx.connectors import SqlConnector
from opteryx.utils.formatter import format_sql

opteryx.register_store(
"sqlite",
Expand All @@ -17,98 +19,91 @@
connection="sqlite:///testdata/sqlite/database.db",
)


def test_predicate_pushdowns_sqlite_eq():
"""
This is the same test as the collection pushdown - but on a different dataset
"""

conn = opteryx.connect()

cur = conn.cursor()
cur.execute("SELECT * FROM sqlite.planets WHERE name = 'Mercury';")
# when pushdown is enabled, we only read the matching rows from the source
assert cur.rowcount == 1, cur.rowcount
assert cur.stats.get("rows_read", 0) == 1, cur.stats

cur = conn.cursor()
cur.execute("SELECT * FROM sqlite.planets WHERE name = 'Mercury' AND gravity = 3.7;")
# test with a two part filter
assert cur.rowcount == 1, cur.rowcount
assert cur.stats.get("rows_read", 0) == 1, cur.stats

cur = conn.cursor()
cur.execute(
"SELECT * FROM sqlite.planets WHERE name = 'Mercury' AND gravity = 3.7 AND escapeVelocity = 5.0;"
)
# test with A three part filter
assert cur.rowcount == 0, cur.rowcount
assert cur.stats.get("rows_read", 0) == 0, cur.stats

cur = conn.cursor()
cur.execute(
"SELECT * FROM sqlite.planets WHERE gravity = 3.7 AND name IN ('Mercury', 'Venus');"
)
# we don't push all predicates down,
assert cur.rowcount == 1, cur.rowcount
assert cur.stats.get("rows_read", 0) == 2, cur.stats

cur = conn.cursor()
cur.execute("SELECT * FROM sqlite.planets WHERE surfacePressure IS NULL;")
# We push unary ops to SQL
assert cur.rowcount == 4, cur.rowcount
assert cur.stats.get("rows_read", 0) == 4, cur.stats

cur = conn.cursor()
cur.execute(
"SELECT * FROM sqlite.planets WHERE orbitalInclination IS FALSE AND name IN ('Earth', 'Mars');"
)
# We push unary ops to SQL
assert cur.rowcount == 1, cur.rowcount
assert cur.stats.get("rows_read", 0) == 1, cur.stats

conn.close()


def test_predicate_pushdown_sqlite_other():
res = opteryx.query("SELECT * FROM sqlite.planets WHERE gravity <= 3.7")
assert res.rowcount == 3, res.rowcount
assert res.stats.get("rows_read", 0) == 3, res.stats

res = opteryx.query("SELECT * FROM sqlite.planets WHERE name != 'Earth'")
assert res.rowcount == 8, res.rowcount
assert res.stats.get("rows_read", 0) == 8, res.stats

res = opteryx.query("SELECT * FROM sqlite.planets WHERE name != 'E\"arth'")
assert res.rowcount == 9, res.rowcount
assert res.stats.get("rows_read", 0) == 9, res.stats

res = opteryx.query("SELECT * FROM sqlite.planets WHERE gravity != 3.7")
assert res.rowcount == 7, res.rowcount
assert res.stats.get("rows_read", 0) == 7, res.stats

res = opteryx.query("SELECT * FROM sqlite.planets WHERE gravity < 3.7")
assert res.rowcount == 1, res.rowcount
assert res.stats.get("rows_read", 0) == 1, res.stats

res = opteryx.query("SELECT * FROM sqlite.planets WHERE gravity > 3.7")
assert res.rowcount == 6, res.rowcount
assert res.stats.get("rows_read", 0) == 6, res.stats

res = opteryx.query("SELECT * FROM sqlite.planets WHERE gravity >= 3.7")
assert res.rowcount == 8, res.rowcount
assert res.stats.get("rows_read", 0) == 8, res.stats

res = opteryx.query("SELECT * FROM sqlite.planets WHERE name LIKE '%a%'")
assert res.rowcount == 4, res.rowcount
assert res.stats.get("rows_read", 0) == 4, res.stats

res = opteryx.query("SELECT * FROM sqlite.planets WHERE id > gravity")
assert res.rowcount == 2, res.rowcount
assert res.stats.get("rows_read", 0) == 9, res.stats
test_cases = [
("SELECT * FROM sqlite.planets WHERE name = 'Mercury';", 1, 1),
("SELECT * FROM sqlite.planets WHERE name = 'Mercury' AND gravity = 3.7;", 1, 1),
(
"SELECT * FROM sqlite.planets WHERE name = 'Mercury' AND gravity = 3.7 AND escapeVelocity = 5.0;",
0,
0,
),
("SELECT * FROM sqlite.planets WHERE gravity = 3.7 AND name IN ('Mercury', 'Venus');", 1, 2),
("SELECT * FROM sqlite.planets WHERE surfacePressure IS NULL;", 4, 4),
(
"SELECT * FROM sqlite.planets WHERE orbitalInclination IS FALSE AND name IN ('Earth', 'Mars');",
1,
1,
),
("SELECT * FROM (SELECT name FROM sqlite.planets) AS $temp WHERE name = 'Earth';", 1, 1),
("SELECT * FROM sqlite.planets WHERE gravity <= 3.7", 3, 3),
("SELECT * FROM sqlite.planets WHERE name != 'Earth'", 8, 8),
("SELECT * FROM sqlite.planets WHERE name != 'E\"arth'", 9, 9),
("SELECT * FROM sqlite.planets WHERE gravity != 3.7", 7, 7),
("SELECT * FROM sqlite.planets WHERE gravity < 3.7", 1, 1),
("SELECT * FROM sqlite.planets WHERE gravity > 3.7", 6, 6),
("SELECT * FROM sqlite.planets WHERE gravity >= 3.7", 8, 8),
("SELECT * FROM sqlite.planets WHERE name LIKE '%a%'", 4, 4),
("SELECT * FROM sqlite.planets WHERE id > gravity", 2, 9),
]


import pytest


@pytest.mark.parametrize("statement,expected_rowcount,expected_rows_read", test_cases)
def test_predicate_pushdown_postgres_parameterized(
statement, expected_rowcount, expected_rows_read
):
res = opteryx.query(statement)
assert res.rowcount == expected_rowcount, f"Expected {expected_rowcount}, got {res.rowcount}"
assert (
res.stats.get("rows_read", 0) == expected_rows_read
), f"Expected {expected_rows_read}, got {res.stats.get('rows_read', 0)}"


if __name__ == "__main__": # pragma: no cover
from tests.tools import run_tests

run_tests()
import shutil

from tests.tools import trunc_printable

start_suite = time.monotonic_ns()
passed = 0
failed = 0

width = shutil.get_terminal_size((80, 20))[0] - 15

print(f"RUNNING BATTERY OF {len(test_cases)} TESTS")
for index, (statement, returned_rows, read_rows) in enumerate(test_cases):
print(
f"\033[38;2;255;184;108m{(index + 1):04}\033[0m"
f" {trunc_printable(format_sql(statement), width - 1)}",
end="",
flush=True,
)
try:
start = time.monotonic_ns()
test_predicate_pushdown_postgres_parameterized(statement, returned_rows, read_rows)
print(
f"\033[38;2;26;185;67m{str(int((time.monotonic_ns() - start)/1e6)).rjust(4)}ms\033[0m ✅",
end="",
)
passed += 1
if failed > 0:
print(" \033[0;31m*\033[0m")
else:
print()
except Exception as err:
print(f"\033[0;31m{str(int((time.monotonic_ns() - start)/1e6)).rjust(4)}ms ❌ *\033[0m")
print(">", err)
failed += 1

print("--- ✅ \033[0;32mdone\033[0m")

if failed > 0:
print("\n\033[38;2;139;233;253m\033[3mFAILURES\033[0m")

print(
f"\n\033[38;2;139;233;253m\033[3mCOMPLETE\033[0m ({((time.monotonic_ns() - start_suite) / 1e9):.2f} seconds)\n"
f" \033[38;2;26;185;67m{passed} passed ({(passed * 100) // (passed + failed)}%)\033[0m\n"
f" \033[38;2;255;121;198m{failed} failed\033[0m"
)
73 changes: 72 additions & 1 deletion tests/sql_battery/test_results_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import orjson

import opteryx
from opteryx.utils.formatter import format_sql

OS_SEP = os.sep

Expand Down Expand Up @@ -57,7 +58,7 @@ def test_results_tests(test):
), f"Outcome:\n{printable_result}\nExpected:\n{printable_expected}"


if __name__ == "__main__": # pragma: no cover
if __name__ == "__dmain__": # pragma: no cover
import shutil
import time

Expand All @@ -77,3 +78,73 @@ def test_results_tests(test):
print(f"\033[0;32m{str(int((time.monotonic_ns() - start)/1000000)).rjust(4)}ms\033[0m ✅")

print("--- ✅ \033[0;32mdone\033[0m")


if __name__ == "__main__": # pragma: no cover
"""
Running in the IDE we do some formatting - it's not functional but helps
when reading the outputs.
"""

import shutil
import time

from tests.tools import trunc_printable

start_suite = time.monotonic_ns()

width = shutil.get_terminal_size((80, 20))[0] - 45

passed = 0
failed = 0

nl = "\n"

failures = []

print(f"RUNNING BATTERY OF {len(RESULTS_TESTS)} RESULTS TESTS")
for index, test in enumerate(RESULTS_TESTS):

printable = test["statement"]
test_id = test["file"].split(OS_SEP)[-1].split(".")[0][0:25].ljust(25)
if hasattr(printable, "decode"):
printable = printable.decode()
print(
f"\033[38;2;255;184;108m{(index + 1):04}\033[0m",
f"\033[0;35m{test_id}\033[0m",
f" {trunc_printable(format_sql(printable), width - 1)}",
end="",
flush=True,
)
try:
start = time.monotonic_ns()
test_results_tests(test)
print(
f"\033[38;2;26;185;67m{str(int((time.monotonic_ns() - start)/1e6)).rjust(4)}ms\033[0m ✅",
end="",
)
passed += 1
if failed > 0:
print(" \033[0;31m*\033[0m")
else:
print()
except Exception as err:
print(f"\033[0;31m{str(int((time.monotonic_ns() - start)/1e6)).rjust(4)}ms ❌ *\033[0m")
print(">", err)
failed += 1
failures.append((test_id, test["statement"], err))

print("--- ✅ \033[0;32mdone\033[0m")

if failed > 0:
print("\n\033[38;2;139;233;253m\033[3mFAILURES\033[0m")
for test, statement, err in failures:
print(
f"\033[38;2;26;185;67m{test}\033[0m\n{format_sql(statement)}\n\033[38;2;255;121;198m{err}\033[0m\n"
)

print(
f"\n\033[38;2;139;233;253m\033[3mCOMPLETE\033[0m ({((time.monotonic_ns() - start_suite) / 1e9):.2f} seconds)\n"
f" \033[38;2;26;185;67m{passed} passed ({(passed * 100) // (passed + failed)}%)\033[0m\n"
f" \033[38;2;255;121;198m{failed} failed\033[0m"
)
Loading