Skip to content

Commit

Permalink
Support for bounded execution when window frame involves UNBOUNDED PR…
Browse files Browse the repository at this point in the history
…ECEDING (#5003)

* initial support for aggregators

* Move common functionality to super trait for aggregates

* update tests

* bounded first_value, nth_value, last_value support

* nth_value bug fix

* minor changes

* Review and refactor

* Change naming: NonSliding -> Plain

* remove redundant check

* Remove window function state

* minor changes

* Remove unnecessary continue

* Address reviews

* Address reviews

* Address reviews

Co-authored-by: Mehmet Ozan Kabak <ozankabak@gmail.com>
  • Loading branch information
mustafasrepo and ozankabak authored Jan 23, 2023
1 parent 930c8de commit 624f02d
Show file tree
Hide file tree
Showing 13 changed files with 629 additions and 325 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -548,13 +548,14 @@ impl SortedPartitionByBoundedWindowStream {
for window_agg_state in self.window_agg_states.iter_mut() {
window_agg_state.retain(|_, WindowState { state, .. }| !state.is_end);
for (partition_row, WindowState { state: value, .. }) in window_agg_state {
let n_prune =
min(value.window_frame_range.start, value.last_calculated_index);
if let Some(state) = n_prune_each_partition.get_mut(partition_row) {
if value.window_frame_range.start < *state {
*state = value.window_frame_range.start;
if n_prune < *state {
*state = n_prune;
}
} else {
n_prune_each_partition
.insert(partition_row.clone(), value.window_frame_range.start);
n_prune_each_partition.insert(partition_row.clone(), n_prune);
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/src/physical_plan/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ mod window_agg_exec;

pub use bounded_window_agg_exec::BoundedWindowAggExec;
pub use datafusion_physical_expr::window::{
AggregateWindowExpr, BuiltInWindowExpr, WindowExpr,
BuiltInWindowExpr, PlainAggregateWindowExpr, WindowExpr,
};
pub use window_agg_exec::WindowAggExec;

Expand All @@ -70,7 +70,7 @@ pub fn create_window_expr(
window_frame,
))
} else {
Arc::new(AggregateWindowExpr::new(
Arc::new(PlainAggregateWindowExpr::new(
aggregate,
partition_by,
order_by,
Expand All @@ -84,7 +84,7 @@ pub fn create_window_expr(
order_by,
window_frame,
)),
WindowFunction::AggregateUDF(fun) => Arc::new(AggregateWindowExpr::new(
WindowFunction::AggregateUDF(fun) => Arc::new(PlainAggregateWindowExpr::new(
udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?,
partition_by,
order_by,
Expand Down
139 changes: 127 additions & 12 deletions datafusion/core/tests/sql/window.rs

Large diffs are not rendered by default.

156 changes: 118 additions & 38 deletions datafusion/core/tests/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::compute::{concat_batches, SortOptions};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use hashbrown::HashMap;
Expand All @@ -38,7 +39,7 @@ use datafusion_expr::{
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::{col, lit};
use datafusion_physical_expr::PhysicalSortExpr;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use test_utils::add_empty_batches;

#[cfg(test)]
Expand All @@ -51,7 +52,7 @@ mod tests {
let distincts = vec![1, 100];
for distinct in distincts {
let mut handles = Vec::new();
for i in 1..n {
for i in 0..n {
let job = tokio::spawn(run_window_test(
make_staggered_batches::<true>(1000, distinct, i),
i,
Expand All @@ -74,7 +75,7 @@ mod tests {
// since we have sorted pairs (a,b) to not violate per partition soring
// partition should be field a, order by should be field b
let mut handles = Vec::new();
for i in 1..n {
for i in 0..n {
let job = tokio::spawn(run_window_test(
make_staggered_batches::<true>(1000, distinct, i),
i,
Expand All @@ -90,17 +91,11 @@ mod tests {
}
}

/// Perform batch and running window same input
/// and verify outputs of `WindowAggExec` and `BoundedWindowAggExec` are equal
async fn run_window_test(
input1: Vec<RecordBatch>,
random_seed: u64,
orderby_columns: Vec<&str>,
partition_by_columns: Vec<&str>,
) {
let mut rng = StdRng::seed_from_u64(random_seed);
let schema = input1[0].schema();
let mut args = vec![col("x", &schema).unwrap()];
fn get_random_function(
schema: &SchemaRef,
rng: &mut StdRng,
) -> (WindowFunction, Vec<Arc<dyn PhysicalExpr>>, String) {
let mut args = vec![col("x", schema).unwrap()];
let mut window_fn_map = HashMap::new();
// HashMap values consists of tuple first element is WindowFunction, second is additional argument
// window function requires if any. For most of the window functions additional argument is empty
Expand Down Expand Up @@ -188,16 +183,44 @@ async fn run_window_test(
),
);

let session_config = SessionConfig::new().with_batch_size(50);
let ctx = SessionContext::with_config(session_config);
let rand_fn_idx = rng.gen_range(0..window_fn_map.len());
let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];
let (window_fn, new_args) = window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
for new_arg in new_args {
args.push(new_arg.clone());
}
let preceding = rng.gen_range(0..50);
let following = rng.gen_range(0..50);

(window_fn.clone(), args, fn_name.to_string())
}

fn get_random_window_frame(rng: &mut StdRng) -> WindowFrame {
struct Utils {
val: i32,
is_preceding: bool,
}
let first_bound = Utils {
val: rng.gen_range(0..50),
is_preceding: rng.gen_range(0..2) == 0,
};
let second_bound = Utils {
val: rng.gen_range(0..50),
is_preceding: rng.gen_range(0..2) == 0,
};
let (start_bound, end_bound) =
if first_bound.is_preceding == second_bound.is_preceding {
if (first_bound.val > second_bound.val && first_bound.is_preceding)
|| (first_bound.val < second_bound.val && !first_bound.is_preceding)
{
(first_bound, second_bound)
} else {
(second_bound, first_bound)
}
} else if first_bound.is_preceding {
(first_bound, second_bound)
} else {
(second_bound, first_bound)
};
// 0 means Range, 1 means Rows, 2 means GROUPS
let rand_num = rng.gen_range(0..3);
let units = if rand_num < 1 {
WindowFrameUnits::Range
Expand All @@ -208,26 +231,83 @@ async fn run_window_test(
// TODO: once GROUPS handling is available, use WindowFrameUnits::GROUPS in randomized tests also.
WindowFrameUnits::Range
};
let window_frame = match units {
match units {
// In range queries window frame boundaries should match column type
WindowFrameUnits::Range => WindowFrame {
units,
start_bound: WindowFrameBound::Preceding(ScalarValue::Int32(Some(preceding))),
end_bound: WindowFrameBound::Following(ScalarValue::Int32(Some(following))),
},
WindowFrameUnits::Range => {
let start_bound = if start_bound.is_preceding {
WindowFrameBound::Preceding(ScalarValue::Int32(Some(start_bound.val)))
} else {
WindowFrameBound::Following(ScalarValue::Int32(Some(start_bound.val)))
};
let end_bound = if end_bound.is_preceding {
WindowFrameBound::Preceding(ScalarValue::Int32(Some(end_bound.val)))
} else {
WindowFrameBound::Following(ScalarValue::Int32(Some(end_bound.val)))
};
let mut window_frame = WindowFrame {
units,
start_bound,
end_bound,
};
// with 10% use unbounded preceding in tests
if rng.gen_range(0..10) == 0 {
window_frame.start_bound =
WindowFrameBound::Preceding(ScalarValue::Int32(None));
}
window_frame
}
// In window queries, window frame boundary should be Uint64
WindowFrameUnits::Rows => WindowFrame {
units,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(
preceding as u64,
))),
end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(
following as u64,
))),
},
WindowFrameUnits::Rows => {
let start_bound = if start_bound.is_preceding {
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(
start_bound.val as u64,
)))
} else {
WindowFrameBound::Following(ScalarValue::UInt64(Some(
start_bound.val as u64,
)))
};
let end_bound = if end_bound.is_preceding {
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(
end_bound.val as u64,
)))
} else {
WindowFrameBound::Following(ScalarValue::UInt64(Some(
end_bound.val as u64,
)))
};
let mut window_frame = WindowFrame {
units,
start_bound,
end_bound,
};
// with 10% use unbounded preceding in tests
if rng.gen_range(0..10) == 0 {
window_frame.start_bound =
WindowFrameBound::Preceding(ScalarValue::UInt64(None));
}
window_frame
}
// Once GROUPS support is added construct window frame for this case also
_ => todo!(),
};
}
}

/// Perform batch and running window same input
/// and verify outputs of `WindowAggExec` and `BoundedWindowAggExec` are equal
async fn run_window_test(
input1: Vec<RecordBatch>,
random_seed: u64,
orderby_columns: Vec<&str>,
partition_by_columns: Vec<&str>,
) {
let mut rng = StdRng::seed_from_u64(random_seed);
let schema = input1[0].schema();
let session_config = SessionConfig::new().with_batch_size(50);
let ctx = SessionContext::with_config(session_config);
let (window_fn, args, fn_name) = get_random_function(&schema, &mut rng);

let window_frame = get_random_window_frame(&mut rng);
let mut orderby_exprs = vec![];
for column in orderby_columns {
orderby_exprs.push(PhysicalSortExpr {
Expand Down Expand Up @@ -257,8 +337,8 @@ async fn run_window_test(
let usual_window_exec = Arc::new(
WindowAggExec::try_new(
vec![create_window_expr(
window_fn,
fn_name.to_string(),
&window_fn,
fn_name.clone(),
&args,
&partitionby_exprs,
&orderby_exprs,
Expand All @@ -278,8 +358,8 @@ async fn run_window_test(
let running_window_exec = Arc::new(
BoundedWindowAggExec::try_new(
vec![create_window_expr(
window_fn,
fn_name.to_string(),
&window_fn,
fn_name,
&args,
&partitionby_exprs,
&orderby_exprs,
Expand Down
10 changes: 9 additions & 1 deletion datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use datafusion_expr::Accumulator;
use datafusion_row::accessor::RowAccessor;

/// AVG aggregate expression
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Avg {
name: String,
expr: Arc<dyn PhysicalExpr>,
Expand Down Expand Up @@ -111,6 +111,10 @@ impl AggregateExpr for Avg {
is_row_accumulator_support_dtype(&self.data_type)
}

fn supports_bounded_execution(&self) -> bool {
true
}

fn create_row_accumulator(
&self,
start_index: usize,
Expand All @@ -121,6 +125,10 @@ impl AggregateExpr for Avg {
)))
}

fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
Some(Arc::new(self.clone()))
}

fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(AvgAccumulator::try_new(&self.data_type)?))
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub use crate::window::cume_dist::cume_dist;
pub use crate::window::cume_dist::CumeDist;
pub use crate::window::lead_lag::WindowShift;
pub use crate::window::lead_lag::{lag, lead};
pub use crate::window::nth_value::{NthValue, NthValueKind};
pub use crate::window::nth_value::NthValue;
pub use crate::window::ntile::Ntile;
pub use crate::window::rank::{dense_rank, percent_rank, rank};
pub use crate::window::rank::{Rank, RankType};
Expand Down
Loading

0 comments on commit 624f02d

Please sign in to comment.