Skip to content

Commit

Permalink
feat: aggregates as windows (#871)
Browse files Browse the repository at this point in the history
* Add  to turn any aggregate function into a window function

* Rename Window to WindowExpr so we can define Window to mean a window definition to be reused

* Add unit test to cover default frames

* Improve error report
  • Loading branch information
timsaucer authored Sep 18, 2024
1 parent 6c8bf5f commit a00cfbf
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 48 deletions.
57 changes: 56 additions & 1 deletion python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
Union = expr_internal.Union
Unnest = expr_internal.Unnest
UnnestExpr = expr_internal.UnnestExpr
Window = expr_internal.Window
WindowExpr = expr_internal.WindowExpr

__all__ = [
"Expr",
Expand Down Expand Up @@ -154,6 +154,7 @@
"Partitioning",
"Repartition",
"Window",
"WindowExpr",
"WindowFrame",
"WindowFrameBound",
]
Expand Down Expand Up @@ -542,6 +543,36 @@ def window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder:
"""
return ExprFuncBuilder(self.expr.window_frame(window_frame.window_frame))

def over(self, window: Window) -> Expr:
"""Turn an aggregate function into a window function.
This function turns any aggregate function into a window function. With the
exception of ``partition_by``, how each of the parameters is used is determined
by the underlying aggregate function.
Args:
window: Window definition
"""
partition_by_raw = expr_list_to_raw_expr_list(window._partition_by)
order_by_raw = sort_list_to_raw_sort_list(window._order_by)
window_frame_raw = (
window._window_frame.window_frame
if window._window_frame is not None
else None
)
null_treatment_raw = (
window._null_treatment.value if window._null_treatment is not None else None
)

return Expr(
self.expr.over(
partition_by=partition_by_raw,
order_by=order_by_raw,
window_frame=window_frame_raw,
null_treatment=null_treatment_raw,
)
)


class ExprFuncBuilder:
def __init__(self, builder: expr_internal.ExprFuncBuilder):
Expand Down Expand Up @@ -584,6 +615,30 @@ def build(self) -> Expr:
return Expr(self.builder.build())


class Window:
"""Define reusable window parameters."""

def __init__(
self,
partition_by: Optional[list[Expr]] = None,
window_frame: Optional[WindowFrame] = None,
order_by: Optional[list[SortExpr | Expr]] = None,
null_treatment: Optional[NullTreatment] = None,
) -> None:
"""Construct a window definition.
Args:
partition_by: Partitions for window operation
window_frame: Define the start and end bounds of the window frame
order_by: Set ordering
null_treatment: Indicate how nulls are to be treated
"""
self._partition_by = partition_by
self._window_frame = window_frame
self._order_by = order_by
self._null_treatment = null_treatment


class WindowFrame:
"""Defines a window frame for performing window operations."""

Expand Down
75 changes: 54 additions & 21 deletions python/datafusion/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
literal,
udf,
)
from datafusion.expr import Window


@pytest.fixture
Expand Down Expand Up @@ -386,38 +387,32 @@ def test_distinct():
),
[-1, -1, None, 7, -1, -1, None],
),
# TODO update all aggregate functions as windows once upstream merges https://github.com/apache/datafusion-python/issues/833
pytest.param(
(
"first_value",
f.window(
"first_value",
[column("a")],
order_by=[f.order_by(column("b"))],
partition_by=[column("c")],
f.first_value(column("a")).over(
Window(partition_by=[column("c")], order_by=[column("b")])
),
[1, 1, 1, 1, 5, 5, 5],
),
pytest.param(
(
"last_value",
f.window("last_value", [column("a")])
.window_frame(WindowFrame("rows", 0, None))
.order_by(column("b"))
.partition_by(column("c"))
.build(),
f.last_value(column("a")).over(
Window(
partition_by=[column("c")],
order_by=[column("b")],
window_frame=WindowFrame("rows", None, None),
)
),
[3, 3, 3, 3, 6, 6, 6],
),
pytest.param(
(
"3rd_value",
f.window(
"nth_value",
[column("b"), literal(3)],
order_by=[f.order_by(column("a"))],
),
f.nth_value(column("b"), 3).over(Window(order_by=[column("a")])),
[None, None, 7, 7, 7, 7, 7],
),
pytest.param(
(
"avg",
f.round(f.window("avg", [column("b")], order_by=[column("a")]), literal(3)),
f.round(f.avg(column("b")).over(Window(order_by=[column("a")])), literal(3)),
[7.0, 7.0, 7.0, 7.333, 7.75, 7.75, 8.0],
),
]
Expand Down Expand Up @@ -473,6 +468,44 @@ def test_invalid_window_frame(units, start_bound, end_bound):
WindowFrame(units, start_bound, end_bound)


def test_window_frame_defaults_match_postgres(partitioned_df):
# ref: https://github.com/apache/datafusion-python/issues/688

window_frame = WindowFrame("rows", None, None)

col_a = column("a")

# Using `f.window` with or without an unbounded window_frame produces the same
# results. These tests are included as a regression check but can be removed when
# f.window() is deprecated in favor of using the .over() approach.
no_frame = f.window("avg", [col_a]).alias("no_frame")
with_frame = f.window("avg", [col_a], window_frame=window_frame).alias("with_frame")
df_1 = partitioned_df.select(col_a, no_frame, with_frame)

expected = {
"a": [0, 1, 2, 3, 4, 5, 6],
"no_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
"with_frame": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
}

assert df_1.sort(col_a).to_pydict() == expected

# When order is not set, the default frame should be unounded preceeding to
# unbounded following. When order is set, the default frame is unbounded preceeding
# to current row.
no_order = f.avg(col_a).over(Window()).alias("over_no_order")
with_order = f.avg(col_a).over(Window(order_by=[col_a])).alias("over_with_order")
df_2 = partitioned_df.select(col_a, no_order, with_order)

expected = {
"a": [0, 1, 2, 3, 4, 5, 6],
"over_no_order": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
"over_with_order": [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0],
}

assert df_2.sort(col_a).to_pydict() == expected


def test_get_dataframe(tmp_path):
ctx = SessionContext()

Expand Down
46 changes: 44 additions & 2 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
// under the License.

use datafusion::logical_expr::utils::exprlist_to_fields;
use datafusion::logical_expr::{ExprFuncBuilder, ExprFunctionExt, LogicalPlan};
use datafusion::logical_expr::{
ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition,
};
use pyo3::{basic::CompareOp, prelude::*};
use std::convert::{From, Into};
use std::sync::Arc;
Expand All @@ -39,6 +41,7 @@ use crate::expr::aggregate_expr::PyAggregateFunction;
use crate::expr::binary_expr::PyBinaryExpr;
use crate::expr::column::PyColumn;
use crate::expr::literal::PyLiteral;
use crate::functions::add_builder_fns_to_window;
use crate::sql::logical::PyLogicalPlan;

use self::alias::PyAlias;
Expand Down Expand Up @@ -558,6 +561,45 @@ impl PyExpr {
pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
self.expr.clone().window_frame(window_frame.into()).into()
}

#[pyo3(signature = (partition_by=None, window_frame=None, order_by=None, null_treatment=None))]
pub fn over(
&self,
partition_by: Option<Vec<PyExpr>>,
window_frame: Option<PyWindowFrame>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>,
) -> PyResult<PyExpr> {
match &self.expr {
Expr::AggregateFunction(agg_fn) => {
let window_fn = Expr::WindowFunction(WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()),
agg_fn.args.clone(),
));

add_builder_fns_to_window(
window_fn,
partition_by,
window_frame,
order_by,
null_treatment,
)
}
Expr::WindowFunction(_) => add_builder_fns_to_window(
self.expr.clone(),
partition_by,
window_frame,
order_by,
null_treatment,
),
_ => Err(
DataFusionError::ExecutionError(datafusion::error::DataFusionError::Plan(
format!("Using {} with `over` is not allowed. Must use an aggregate or window function.", self.expr.variant_name()),
))
.into(),
),
}
}
}

#[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)]
Expand Down Expand Up @@ -749,7 +791,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<drop_table::PyDropTable>()?;
m.add_class::<repartition::PyPartitioning>()?;
m.add_class::<repartition::PyRepartition>()?;
m.add_class::<window::PyWindow>()?;
m.add_class::<window::PyWindowExpr>()?;
m.add_class::<window::PyWindowFrame>()?;
m.add_class::<window::PyWindowFrameBound>()?;
Ok(())
Expand Down
20 changes: 10 additions & 10 deletions src/expr/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ use super::py_expr_list;

use crate::errors::py_datafusion_err;

#[pyclass(name = "Window", module = "datafusion.expr", subclass)]
#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyWindow {
pub struct PyWindowExpr {
window: Window,
}

Expand Down Expand Up @@ -62,15 +62,15 @@ pub struct PyWindowFrameBound {
frame_bound: WindowFrameBound,
}

impl From<PyWindow> for Window {
fn from(window: PyWindow) -> Window {
impl From<PyWindowExpr> for Window {
fn from(window: PyWindowExpr) -> Window {
window.window
}
}

impl From<Window> for PyWindow {
fn from(window: Window) -> PyWindow {
PyWindow { window }
impl From<Window> for PyWindowExpr {
fn from(window: Window) -> PyWindowExpr {
PyWindowExpr { window }
}
}

Expand All @@ -80,7 +80,7 @@ impl From<WindowFrameBound> for PyWindowFrameBound {
}
}

impl Display for PyWindow {
impl Display for PyWindowExpr {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(
f,
Expand All @@ -103,7 +103,7 @@ impl Display for PyWindowFrame {
}

#[pymethods]
impl PyWindow {
impl PyWindowExpr {
/// Returns the schema of the Window
pub fn schema(&self) -> PyResult<PyDFSchema> {
Ok(self.window.schema.as_ref().clone().into())
Expand Down Expand Up @@ -283,7 +283,7 @@ impl PyWindowFrameBound {
}
}

impl LogicalNode for PyWindow {
impl LogicalNode for PyWindowExpr {
fn inputs(&self) -> Vec<PyLogicalPlan> {
vec![self.window.input.as_ref().clone().into()]
}
Expand Down
Loading

0 comments on commit a00cfbf

Please sign in to comment.