Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade datafusion #867

Merged
merged 26 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
25e42c8
update dependencies
emgeee Sep 9, 2024
9cea1fb
update get_logical_plan signature
emgeee Sep 9, 2024
6fca28b
remove row_number() function
emgeee Sep 9, 2024
f2b3d3b
remove unneeded dependency
emgeee Sep 9, 2024
4b45a4b
fix pyo3 warnings
emgeee Sep 10, 2024
6353aa9
update object_store dependency
emgeee Sep 10, 2024
815b6d7
change PyExpr -> PySortExpr
emgeee Sep 10, 2024
92806a8
comment out key.extract::<&PyTuple>() condition statement
emgeee Sep 10, 2024
e2fa24e
change more instances of PyExpr > PySortExpr
emgeee Sep 10, 2024
21013a7
update function signatures to use _bound versions
emgeee Sep 10, 2024
142e4ed
remove clone
emgeee Sep 10, 2024
e971add
Working through some of the sort requirement changes
timsaucer Sep 10, 2024
c89357e
remove unused import
emgeee Sep 10, 2024
8255f09
expr.display_name is deprecated, used format!() + schema_name() instead
emgeee Sep 10, 2024
df46054
expr.canonical_name() is deprecated, use format!() expr instead
emgeee Sep 10, 2024
6c27614
remove comment
emgeee Sep 10, 2024
70546e2
fix tuple extraction in dataframe.__getitem__()
emgeee Sep 10, 2024
836061f
remove unneeded import
emgeee Sep 10, 2024
4945661
Add docstring comments to SortExpr python class
emgeee Sep 10, 2024
cd04c44
change extract() to downcast()
emgeee Sep 10, 2024
afcc9f1
deprecate Expr::display_name
Michael-J-Ward Aug 10, 2024
7f6187a
fix lint errors
emgeee Sep 10, 2024
8aebaea
update datafusion commit hash
emgeee Sep 11, 2024
afa303f
fix type in cargo file for arrow features
emgeee Sep 17, 2024
f4574ec
upgrade to datafusion 42
emgeee Sep 17, 2024
88ccbd8
cleanup
emgeee Sep 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
737 changes: 357 additions & 380 deletions Cargo.lock

Large diffs are not rendered by default.

15 changes: 9 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,24 @@ substrait = ["dep:datafusion-substrait"]
[dependencies]
tokio = { version = "1.39", features = ["macros", "rt", "rt-multi-thread", "sync"] }
rand = "0.8"
pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38"] }
arrow = { version = "52", feature = ["pyarrow"] }
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
arrow = { version = "53", feature = ["pyarrow"] }
datafusion = { version = "41.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
datafusion-substrait = { version = "41.0.0", optional = true }
prost = "0.12" # keep in line with `datafusion-substrait`
prost-types = "0.12" # keep in line with `datafusion-substrait`
prost = "0.13" # keep in line with `datafusion-substrait`
prost-types = "0.13" # keep in line with `datafusion-substrait`
uuid = { version = "1.9", features = ["v4"] }
mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] }
async-trait = "0.1"
futures = "0.3"
object_store = { version = "0.10.1", features = ["aws", "gcp", "azure"] }
object_store = { version = "0.11.0", features = ["aws", "gcp", "azure"] }
parking_lot = "0.12"
regex-syntax = "0.8"
syn = "2.0.68"
url = "2"

[build-dependencies]
pyo3-build-config = "0.21"
pyo3-build-config = "0.22"

[lib]
name = "datafusion_python"
Expand All @@ -63,3 +63,6 @@ crate-type = ["cdylib", "rlib"]
lto = true
codegen-units = 1

[patch.crates-io]
datafusion = { git = "https://github.com/apache/datafusion.git", rev = "c71a9d7508e37e5d082e22d2953a12b61d290df5" }
datafusion-substrait = { git = "https://github.com/apache/datafusion.git", rev = "c71a9d7508e37e5d082e22d2953a12b61d290df5" }
13 changes: 8 additions & 5 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from datafusion._internal import AggregateUDF
from datafusion.catalog import Catalog, Table
from datafusion.dataframe import DataFrame
from datafusion.expr import Expr
from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
from datafusion.record_batch import RecordBatchStream
from datafusion.udf import ScalarUDF

Expand Down Expand Up @@ -466,7 +466,7 @@ def register_listing_table(
table_partition_cols: list[tuple[str, str]] | None = None,
file_extension: str = ".parquet",
schema: pyarrow.Schema | None = None,
file_sort_order: list[list[Expr]] | None = None,
file_sort_order: list[list[Expr | SortExpr]] | None = None,
) -> None:
"""Register multiple files as a single table.

Expand All @@ -484,15 +484,18 @@ def register_listing_table(
"""
if table_partition_cols is None:
table_partition_cols = []
if file_sort_order is not None:
file_sort_order = [[x.expr for x in xs] for xs in file_sort_order]
file_sort_order_raw = (
[sort_list_to_raw_sort_list(f) for f in file_sort_order]
if file_sort_order is not None
else None
)
self.ctx.register_listing_table(
name,
str(path),
table_partition_cols,
file_extension,
schema,
file_sort_order,
file_sort_order_raw,
)

def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
Expand Down
8 changes: 4 additions & 4 deletions python/datafusion/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from typing import Callable

from datafusion._internal import DataFrame as DataFrameInternal
from datafusion.expr import Expr
from datafusion.expr import Expr, SortExpr, sort_or_default
from datafusion._internal import (
LogicalPlan,
ExecutionPlan,
Expand Down Expand Up @@ -199,7 +199,7 @@ def aggregate(
aggs = [e.expr for e in aggs]
return DataFrame(self.df.aggregate(group_by, aggs))

def sort(self, *exprs: Expr) -> DataFrame:
def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
"""Sort the DataFrame by the specified sorting expressions.

Note that any expression can be turned into a sort expression by
Expand All @@ -211,8 +211,8 @@ def sort(self, *exprs: Expr) -> DataFrame:
Returns:
DataFrame after sorting.
"""
exprs = [expr.expr for expr in exprs]
return DataFrame(self.df.sort(*exprs))
exprs_raw = [sort_or_default(expr) for expr in exprs]
return DataFrame(self.df.sort(*exprs_raw))

def limit(self, count: int, offset: int = 0) -> DataFrame:
"""Return a new :py:class:`DataFrame` with a limited number of rows.
Expand Down
59 changes: 53 additions & 6 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
ScalarVariable = expr_internal.ScalarVariable
SimilarTo = expr_internal.SimilarTo
Sort = expr_internal.Sort
SortExpr = expr_internal.SortExpr
# SortExpr = expr_internal.SortExpr
emgeee marked this conversation as resolved.
Show resolved Hide resolved
Subquery = expr_internal.Subquery
SubqueryAlias = expr_internal.SubqueryAlias
TableScan = expr_internal.TableScan
Expand Down Expand Up @@ -159,6 +159,27 @@
]


def expr_list_to_raw_expr_list(
expr_list: Optional[list[Expr]],
) -> Optional[list[expr_internal.Expr]]:
"""Helper function to convert an optional list to raw expressions."""
return [e.expr for e in expr_list] if expr_list is not None else None


def sort_or_default(e: Expr | SortExpr) -> expr_internal.SortExpr:
"""Helper function to return a default Sort if an Expr is provided."""
if isinstance(e, SortExpr):
return e.raw_sort
return SortExpr(e.expr, True, True).raw_sort


def sort_list_to_raw_sort_list(
sort_list: Optional[list[Expr | SortExpr]],
) -> Optional[list[expr_internal.SortExpr]]:
"""Helper function to return an optional sort list to raw variant."""
return [sort_or_default(e) for e in sort_list] if sort_list is not None else None


class Expr:
"""Expression object.

Expand Down Expand Up @@ -355,14 +376,14 @@ def alias(self, name: str) -> Expr:
"""Assign a name to the expression."""
return Expr(self.expr.alias(name))

def sort(self, ascending: bool = True, nulls_first: bool = True) -> Expr:
def sort(self, ascending: bool = True, nulls_first: bool = True) -> SortExpr:
"""Creates a sort :py:class:`Expr` from an existing :py:class:`Expr`.

Args:
ascending: If true, sort in ascending order.
nulls_first: Return null values first.
"""
return Expr(self.expr.sort(ascending=ascending, nulls_first=nulls_first))
return SortExpr(self.expr, ascending=ascending, nulls_first=nulls_first)

def is_null(self) -> Expr:
"""Returns ``True`` if this expression is null."""
Expand Down Expand Up @@ -439,14 +460,14 @@ def column_name(self, plan: LogicalPlan) -> str:
"""Compute the output column name based on the provided logical plan."""
return self.expr.column_name(plan)

def order_by(self, *exprs: Expr) -> ExprFuncBuilder:
def order_by(self, *exprs: Expr | SortExpr) -> ExprFuncBuilder:
"""Set the ordering for a window or aggregate function.

This function will create an :py:class:`ExprFuncBuilder` that can be used to
set parameters for either window or aggregate functions. If used on any other
type of expression, an error will be generated when ``build()`` is called.
"""
return ExprFuncBuilder(self.expr.order_by(list(e.expr for e in exprs)))
return ExprFuncBuilder(self.expr.order_by([sort_or_default(e) for e in exprs]))

def filter(self, filter: Expr) -> ExprFuncBuilder:
"""Filter an aggregate function.
Expand Down Expand Up @@ -506,7 +527,9 @@ def order_by(self, *exprs: Expr) -> ExprFuncBuilder:
Values given in ``exprs`` must be sort expressions. You can convert any other
expression to a sort expression using `.sort()`.
"""
return ExprFuncBuilder(self.builder.order_by(list(e.expr for e in exprs)))
return ExprFuncBuilder(
self.builder.order_by([sort_or_default(e) for e in exprs])
)

def filter(self, filter: Expr) -> ExprFuncBuilder:
"""Filter values during aggregation."""
Expand Down Expand Up @@ -643,3 +666,27 @@ def end(self) -> Expr:
Any non-matching cases will end in a `null` value.
"""
return Expr(self.case_builder.end())


class SortExpr:
emgeee marked this conversation as resolved.
Show resolved Hide resolved
"""Used to specify sorting on either a DataFrame or function."""

def __init__(self, expr: Expr, ascending: bool, nulls_first: bool) -> None:
"""This constructor should not be called by the end user."""
self.raw_sort = expr_internal.SortExpr(expr, ascending, nulls_first)

def expr(self) -> Expr:
"""Return the raw expr backing teh SortExpr."""
emgeee marked this conversation as resolved.
Show resolved Hide resolved
return Expr(self.raw_sort.expr())

def ascending(self) -> bool:
"""Return ascending property."""
return self.raw_sort.ascending()

def nulls_first(self) -> bool:
"""Return nulls_first property."""
return self.raw_sort.nulls_first()

def __repr__(self) -> str:
"""Generate a string representation of this expression."""
return self.raw_sort.__repr__()
Loading
Loading