
Commit

joocer committed Sep 17, 2024
1 parent 42ecf80 commit 3be71a3
Showing 2 changed files with 43 additions and 29 deletions.
opteryx/models/logical_column.py (4 changes: 3 additions & 1 deletion)
@@ -78,7 +78,9 @@ def copy(self):
source_column=self.source_column,
source=self.source,
alias=self.alias,
schema_column=self.schema_column,
schema_column=None
if self.schema_column is None
else self.schema_column.to_flatcolumn(),
)

def __repr__(self) -> str:
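The copy() change above stops an original column and its copy from sharing one schema_column object. Below is a minimal, self-contained sketch of the aliasing problem it avoids, using toy stand-ins (ToySchemaColumn, ToyLogicalColumn) rather than the real opteryx/orso classes, and assuming to_flatcolumn() returns an independent copy of the column metadata:

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToySchemaColumn:
    name: str
    identity: str

    def to_flatcolumn(self) -> "ToySchemaColumn":
        # Assumed behaviour: return an independent copy of the column metadata.
        return ToySchemaColumn(name=self.name, identity=self.identity)


@dataclass
class ToyLogicalColumn:
    source_column: str
    schema_column: Optional[ToySchemaColumn]

    def copy(self) -> "ToyLogicalColumn":
        # Mirrors the commit: guard against None, otherwise give the copy its own schema_column.
        return ToyLogicalColumn(
            source_column=self.source_column,
            schema_column=None
            if self.schema_column is None
            else self.schema_column.to_flatcolumn(),
        )


original = ToyLogicalColumn("price", ToySchemaColumn("price", identity="abc123"))
duplicate = original.copy()
duplicate.schema_column.identity = "changed"        # mutate the copy only
assert original.schema_column.identity == "abc123"  # the original is untouched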
opteryx/utils/file_decoders.py (68 changes: 40 additions & 28 deletions)
@@ -19,7 +19,6 @@
from typing import BinaryIO
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
@@ -104,21 +103,35 @@ def do_nothing(buffer: Union[memoryview, bytes], **kwargs): # pragma: no cover
return False


def filter_records(filter: Optional[Union[List, Tuple]], table: pyarrow.Table):
def filter_records(filters: Optional[list], table: pyarrow.Table) -> pyarrow.Table:
"""
When we can't push predicates to the actual read, use this to filter records
just after the read.
Apply filters that could not be pushed down during the read operation to a PyArrow table.
This is a post-read filtering step.
Parameters:
filters: Optional[list]
A list of filter conditions (predicates) to apply to the table.
table: pyarrow.Table
The PyArrow table to be filtered.
Returns:
pyarrow.Table:
A new PyArrow table with rows filtered according to the specified conditions.
Note:
At this point the columns are the raw column names from the file so we need to ensure
the filters reference the raw column names, not the engine-internal 'identity'.
"""
# notes:
# at this point we've not renamed any columns, this may affect some filters
from opteryx.managers.expression import NodeType
from opteryx.managers.expression import evaluate
from opteryx.models import Node

if isinstance(filter, list) and filter:
filter_copy = list(filter)
if isinstance(filters, list) and filters:
# Create a copy of the filters list to avoid mutating the original.
filter_copy = [f.copy() for f in filters]
root = filter_copy.pop()

# If the left or right side of the root filter node is an identifier, set its identity.
# This step ensures that the filtering logic aligns with the schema before any renaming.
if root.left.node_type == NodeType.IDENTIFIER:
root.left.schema_column.identity = root.left.source_column
if root.right.node_type == NodeType.IDENTIFIER:
@@ -130,14 +143,15 @@ def filter_records(filter: Optional[Union[List, Tuple]], table: pyarrow.Table):
right.left.schema_column.identity = right.left.source_column
if right.right.node_type == NodeType.IDENTIFIER:
right.right.schema_column.identity = right.right.source_column
# Combine the current root with the next filter using an AND node.
root = Node(
NodeType.AND,
left=root,
right=right,
schema_column=Node("schema_column", identity=random_string()),
)
else:
root = filter
root = filters

mask = evaluate(root, table)
return table.filter(mask)
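filter_records above folds the unpushed predicates into a single AND tree and evaluates it against the table as a boolean mask. Here is a hedged sketch of the same idea using plain pyarrow.compute expressions instead of opteryx's Node/evaluate machinery; apply_residual_filters and the lambda predicates are illustrative only, not part of the codebase:

import pyarrow
import pyarrow.compute as pc


def apply_residual_filters(table: pyarrow.Table, predicates: list) -> pyarrow.Table:
    # Each predicate is a callable mapping a table to a boolean array.
    if not predicates:
        return table
    remaining = list(predicates)                      # copy so the caller's list is untouched
    mask = remaining.pop()(table)
    while remaining:
        mask = pc.and_(mask, remaining.pop()(table))  # fold the rest together with AND
    return table.filter(mask)


data = pyarrow.table({"price": [1, 5, 9], "qty": [3, 1, 4]})
filtered = apply_residual_filters(
    data,
    [lambda t: pc.greater(t["price"], 2), lambda t: pc.greater(t["qty"], 2)],
)
print(filtered.to_pydict())  # {'price': [9], 'qty': [4]}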
@@ -212,11 +226,11 @@ def parquet_decoder(
Returns:
Tuple containing number of rows, number of columns, and the table or schema.
"""
selected_columns = None

# we need to work out if we have a selection which may force us
# to fetch columns just for filtering
dnf_filter, selection = PredicatePushable.to_dnf(selection) if selection else (None, None)
dnf_filter, processed_selection = (
PredicatePushable.to_dnf(selection) if selection else (None, None)
)

# try to avoid turning a memoryview buffer into bytes, it's quite slow
stream: BinaryIO = (
@@ -230,33 +244,31 @@
if just_schema:
return convert_arrow_schema_to_orso_schema(parquet_file.schema_arrow)

# Projection processing
columns_in_filters = {c.value for c in get_all_nodes_of_type(selection, (NodeType.IDENTIFIER,))}
arrow_schema_columns_set = set(parquet_file.schema_arrow.names)
projection_names = {
name for proj_col in projection for name in proj_col.schema_column.all_names
}.union(columns_in_filters)
selected_columns = list(arrow_schema_columns_set.intersection(projection_names))

# If no columns are selected, set to None
if not selected_columns:
selected_columns = None
# Determine the columns needed for projection and filtering
projection_set = set(p.source_column for p in projection or [])
filter_columns = {
c.value for c in get_all_nodes_of_type(processed_selection, (NodeType.IDENTIFIER,))
}
selected_columns = list(
projection_set.union(filter_columns).intersection(parquet_file.schema_arrow.names)
)

if selected_columns is None and not force_read:
# Read all columns if none are selected, unless force_read is set
if not selected_columns and not force_read:
selected_columns = []

# Read the parquet table with the optimized column list and selection filters
table = parquet.read_table(
stream,
columns=selected_columns,
columns=selected_columns if selected_columns else None,
pre_buffer=False,
filters=dnf_filter,
use_threads=False,
)

# Any filters we couldn't push to PyArrow to read we run here
if selection:
table = filter_records(selection, table)
if processed_selection:
table = filter_records(processed_selection, table)

return (parquet_file.metadata.num_rows, parquet_file.metadata.num_columns, table)
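Taken together, parquet_decoder now reads only the columns the projection and the filters need, pushes DNF-shaped predicates into pyarrow, and leaves anything unpushable to filter_records. A minimal sketch of that flow using pyarrow.parquet directly, assuming a local file named example.parquet and hard-coded column sets:

import pyarrow.parquet as pq

# Hypothetical inputs: the columns the query projects and the columns its predicates touch.
projection_columns = {"price", "qty"}
filter_columns = {"price"}

parquet_file = pq.ParquetFile("example.parquet")
schema_columns = set(parquet_file.schema_arrow.names)

# Intersect with the file schema so we never ask pyarrow for a column it doesn't have.
selected = list((projection_columns | filter_columns) & schema_columns)

table = pq.read_table(
    "example.parquet",
    columns=selected or None,     # None means "read every column"
    filters=[("price", ">", 2)],  # DNF-shaped predicate pushed into the reader
    use_threads=False,
)
print(table.num_rows, table.num_columns)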

