Skip to content

Commit

Permalink
Merge pull request #1710 from mabel-dev/#1709
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer authored May 30, 2024
2 parents 7917c65 + 292f125 commit fb2c8ca
Show file tree
Hide file tree
Showing 12 changed files with 38 additions and 34 deletions.
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 537
__build__ = 538

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion opteryx/connectors/base/base_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def chunk_dictset(
_id = record.pop("_id", None)
# column selection
if columns:
record = {k.name: record.get(k.name) for k in columns}
record = {k.source_column: record.get(k.source_column) for k in columns}
record["id"] = None if _id is None else str(_id)

chunk.append(record)
Expand Down
2 changes: 1 addition & 1 deletion opteryx/connectors/cql_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def read_dataset( # type:ignore

# Update the SQL and the target morsel schema if we've pushed a projection
if columns:
column_names = [f'"{col.name}"' for col in columns]
column_names = [f'"{col.source_column}"' for col in columns]
query_builder.add("SELECT", *column_names)
result_schema.columns = [ # type:ignore
col for col in self.schema.columns if f'"{col.name}"' in column_names # type:ignore
Expand Down
3 changes: 3 additions & 0 deletions opteryx/models/logical_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ def __init__(
self,
node_type,
source_column: str,
source_connector: Optional[str] = None,
source: Optional[str] = None,
alias: Optional[str] = None,
schema_column=None,
):
self.node_type = node_type
self.source_column = source_column
self.source_connector = source_connector
self.source = source
self.alias = alias
self.schema_column = schema_column
Expand Down Expand Up @@ -89,6 +91,7 @@ def to_dict(self) -> dict:
"class": "LogicalColumn",
"node_type": self.node_type.name,
"source_column": self.source_column,
"source_connector": self.source_connector,
"source": self.source,
"alias": self.alias,
"schema_column": dataclass_to_dict(self.schema_column),
Expand Down
3 changes: 3 additions & 0 deletions opteryx/planner/binder/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def create_variable_node(node: Node, context: BindingContext) -> Node:

# Update node.schema_column with the found column
node.schema_column = column
node.source_connector = {context.relations.get(a) for a in found_source_relation.aliases} - {
None
}
# if may need to map source aliases to the columns if they weren't able to be
# mapped before now
if column.origin and len(column.origin) == 1:
Expand Down
19 changes: 11 additions & 8 deletions opteryx/planner/binder/binder_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ def visit_scan(self, node: Node, context: BindingContext) -> Tuple[Node, Binding
context.schemas[node.alias] = node.schema
for column in node.schema.columns:
column.origin = [node.alias]
context.relations.add(node.alias)
context.relations[node.alias] = node.connector.__mode__

return node, context

Expand Down Expand Up @@ -864,21 +864,21 @@ def visit_subquery(self, node: Node, context: BindingContext) -> Tuple[Node, Bin
schema_column.aliases = []
columns.append(schema_column)
if name[0] != "$" and name in context.relations:
context.relations.remove(name)
context.relations.add(node.alias)
context.relations.pop(name)
context.relations[node.alias] = "subquery"

schema = RelationSchema(name=node.alias, columns=columns)

context.schemas = {"$derived": derived.schema(), node.alias: schema}
context.relations.add(node.alias)
context.relations[node.alias] = "subquery"
node.schema = schema
node.source_relations = set(source_relations)
return node, context

def visit_union(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
for relation in node.right_relation_names:
context.schemas.pop(relation, None)
context.relations = set(node.left_relation_names)
context.relations = {n: "union" for n in node.left_relation_names}

if len(node.columns) == 1 and node.columns[0].node_type == NodeType.WILDCARD:
columns = []
Expand Down Expand Up @@ -950,9 +950,12 @@ def traverse(
exit_context.schemas = merge_schemas(child_context.schemas, exit_context.schemas)

# Update relations if necessary
context.relations = context.relations.union(exit_context.relations).union(
child_context.relations
)
merged_relations = {
**context.relations,
**exit_context.relations,
**child_context.relations,
}
context.relations = merged_relations

context.schemas = merge_schemas(context.schemas, exit_context.schemas)

Expand Down
7 changes: 3 additions & 4 deletions opteryx/planner/binder/binding_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Set

from opteryx.models import ConnectionContext
from opteryx.models import QueryStatistics
Expand All @@ -40,7 +39,7 @@ class BindingContext:
schemas: Dict[str, Any]
qid: str
connection: ConnectionContext
relations: Set
relations: Dict[str, str]
statistics: QueryStatistics

@classmethod
Expand All @@ -61,7 +60,7 @@ def initialize(cls, qid: str, connection=None) -> "BindingContext":
schemas={"$derived": derived.schema()}, # Replace with the actual schema
qid=qid,
connection=connection,
relations=set(),
relations={},
statistics=QueryStatistics(qid),
)

Expand All @@ -76,6 +75,6 @@ def copy(self) -> "BindingContext":
schemas=deepcopy(self.schemas),
qid=self.qid,
connection=self.connection,
relations=set(self.relations),
relations={k: v for k, v in self.relations.items()},
statistics=self.statistics,
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from opteryx.connectors.capabilities import PredicatePushable
from opteryx.exceptions import UnsupportedSyntaxError
from opteryx.functions import FUNCTIONS
from opteryx.managers.expression import NodeType
from opteryx.managers.expression import get_all_nodes_of_type
from opteryx.models import Node
Expand Down Expand Up @@ -68,7 +67,11 @@ def _rewrite_predicate(predicate):
):
predicate.value = LIKE_REWRITES[predicate.value]
return predicate
if predicate.value in {"Like", "ILike"} and predicate.right.value:
if (
predicate.value in {"Like", "ILike"}
and predicate.right.value
and predicate.left.source_connector.isdisjoint({"Sql", "Cql"})
):
ignore_case = predicate.value == "ILike"
# Rewrite LIKEs as STARTS_WITH
if (
Expand Down
2 changes: 1 addition & 1 deletion tests/misc/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def test_functions():

rounded = evaluate(_round, planets)
assert len(rounded) == 9
assert set(r for r in rounded) == {4, 23, 9, 1, 11, 10}
assert set(r.as_py() for r in rounded) == {4, 23, 9, 1, 11, 10}, list(rounded)


if __name__ == "__main__": # pragma: no cover
Expand Down
15 changes: 4 additions & 11 deletions tests/plan_optimization/test_predicate_pushdown_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,14 @@
connection="sqlite:///testdata/sqlite/database.db",
)

# fmt: off
test_cases = [
("SELECT * FROM sqlite.planets WHERE name = 'Mercury';", 1, 1),
("SELECT * FROM sqlite.planets WHERE name = 'Mercury' AND gravity = 3.7;", 1, 1),
(
"SELECT * FROM sqlite.planets WHERE name = 'Mercury' AND gravity = 3.7 AND escapeVelocity = 5.0;",
0,
0,
),
("SELECT * FROM sqlite.planets WHERE name = 'Mercury' AND gravity = 3.7 AND escapeVelocity = 5.0;", 0, 0),
("SELECT * FROM sqlite.planets WHERE gravity = 3.7 AND name IN ('Mercury', 'Venus');", 1, 2),
("SELECT * FROM sqlite.planets WHERE surfacePressure IS NULL;", 4, 4),
(
"SELECT * FROM sqlite.planets WHERE orbitalInclination IS FALSE AND name IN ('Earth', 'Mars');",
1,
1,
),
("SELECT * FROM sqlite.planets WHERE orbitalInclination IS FALSE AND name IN ('Earth', 'Mars');", 1, 1),
("SELECT * FROM (SELECT name FROM sqlite.planets) AS $temp WHERE name = 'Earth';", 1, 1),
("SELECT * FROM sqlite.planets WHERE gravity <= 3.7", 3, 3),
("SELECT * FROM sqlite.planets WHERE name != 'Earth'", 8, 8),
Expand All @@ -45,7 +38,7 @@
("SELECT * FROM sqlite.planets WHERE name LIKE '%a%'", 4, 4),
("SELECT * FROM sqlite.planets WHERE id > gravity", 2, 9),
]

# fmt:on

import pytest

Expand Down
8 changes: 4 additions & 4 deletions tests/query_planner/test_binder_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,23 @@ def visit_scan(self, node, context):

def visit_filter(self, node, context):
# the filter has the left scan before it
node.sources = set(context.schemas.keys())
node.sources = {a: "test" for a in context.schemas.keys()}
node.columns = []
return node, context

def visit_union(self, node, context):
node.sources = set(context.schemas.keys())
node.sources = {a: "test" for a in context.schemas.keys()}
node.columns = []
return node, context

def visit_project(self, node, context):
# the project has the left and right scans before it
node.sources = set(context.schemas.keys())
node.sources = {a: "test" for a in context.schemas.keys()}
node.columns = []
return node, context

context = BindingContext(
schemas={}, qid="12345", connection=None, relations=set(), statistics=None
schemas={}, qid="12345", connection=None, relations={}, statistics=None
)

visitor = TestBinderVisitor()
Expand Down
2 changes: 1 addition & 1 deletion tests/storage/test_collection_mongodb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Test we can read from MinIO
Test we can read from Mongo
"""

import os
Expand Down

0 comments on commit fb2c8ca

Please sign in to comment.