From 9f6a56b3feb1f330cb3cc5a9dac1a77f5a52ba95 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Sun, 13 Jun 2021 13:22:41 +0800 Subject: [PATCH] add window function implementation with order_by clause --- datafusion/src/execution/context.rs | 55 +++- .../physical_plan/expressions/nth_value.rs | 137 ++++------ .../physical_plan/expressions/row_number.rs | 89 +------ .../src/physical_plan/hash_aggregate.rs | 4 +- datafusion/src/physical_plan/mod.rs | 130 +++++----- datafusion/src/physical_plan/planner.rs | 15 +- .../src/physical_plan/window_functions.rs | 14 +- datafusion/src/physical_plan/windows.rs | 244 +++++++++--------- datafusion/src/scalar.rs | 2 +- datafusion/src/sql/planner.rs | 2 +- datafusion/tests/sql.rs | 147 +++++++++-- .../simple_window_ordered_aggregation.sql | 26 ++ integration-tests/test_psql_parity.py | 2 +- 13 files changed, 476 insertions(+), 391 deletions(-) create mode 100644 integration-tests/sqls/simple_window_ordered_aggregation.sql diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index f09d7f4f90c9..183524497940 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1273,7 +1273,17 @@ mod tests { #[tokio::test] async fn window() -> Result<()> { let results = execute( - "SELECT c1, c2, SUM(c2) OVER (), COUNT(c2) OVER (), MAX(c2) OVER (), MIN(c2) OVER (), AVG(c2) OVER () FROM test ORDER BY c1, c2 LIMIT 5", + "SELECT \ + c1, \ + c2, \ + SUM(c2) OVER (), \ + COUNT(c2) OVER (), \ + MAX(c2) OVER (), \ + MIN(c2) OVER (), \ + AVG(c2) OVER () \ + FROM test \ + ORDER BY c1, c2 \ + LIMIT 5", 4, ) .await?; @@ -1299,6 +1309,49 @@ mod tests { Ok(()) } + #[tokio::test] + async fn window_order_by() -> Result<()> { + let results = execute( + "SELECT \ + c1, \ + c2, \ + ROW_NUMBER() OVER (ORDER BY c1, c2), \ + FIRST_VALUE(c2) OVER (ORDER BY c1, c2), \ + LAST_VALUE(c2) OVER (ORDER BY c1, c2), \ + NTH_VALUE(c2, 2) OVER (ORDER BY c1, c2), \ + SUM(c2) OVER (ORDER BY c1, c2), \ + COUNT(c2) OVER (ORDER BY c1, c2), \ + MAX(c2) OVER (ORDER BY c1, c2), \ + MIN(c2) OVER (ORDER BY c1, c2), \ + AVG(c2) OVER (ORDER BY c1, c2) \ + FROM test \ + ORDER BY c1, c2 \ + LIMIT 5", + 4, + ) + .await?; + // result in one batch, although e.g. 
having 2 batches do not change + // result semantics, having a len=1 assertion upfront keeps surprises + // at bay + assert_eq!(results.len(), 1); + + let expected = vec![ + "+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+", + "| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2) | LAST_VALUE(c2) | NTH_VALUE(c2,Int64(2)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |", + "+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+", + "| 0 | 1 | 1 | 1 | 10 | 2 | 1 | 1 | 1 | 1 | 1 |", + "| 0 | 2 | 2 | 1 | 10 | 2 | 3 | 2 | 2 | 1 | 1.5 |", + "| 0 | 3 | 3 | 1 | 10 | 2 | 6 | 3 | 3 | 1 | 2 |", + "| 0 | 4 | 4 | 1 | 10 | 2 | 10 | 4 | 4 | 1 | 2.5 |", + "| 0 | 5 | 5 | 1 | 10 | 2 | 15 | 5 | 5 | 1 | 3 |", + "+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+", + ]; + + // window function shall respect ordering + assert_batches_eq!(expected, &results); + Ok(()) + } + #[tokio::test] async fn aggregate() -> Result<()> { let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?; diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs index fb0e79f7ad3c..98083fa26eaa 100644 --- a/datafusion/src/physical_plan/expressions/nth_value.rs +++ b/datafusion/src/physical_plan/expressions/nth_value.rs @@ -18,13 +18,11 @@ //! Defines physical expressions that can evaluated at runtime during query execution use crate::error::{DataFusionError, Result}; -use crate::physical_plan::{ - window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator, -}; +use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr}; use crate::scalar::ScalarValue; +use arrow::array::{new_empty_array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use std::any::Any; -use std::convert::TryFrom; use std::sync::Arc; /// nth_value kind @@ -113,54 +111,32 @@ impl BuiltInWindowFunctionExpr for NthValue { &self.name } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(NthValueAccumulator::try_new( - self.kind, - self.data_type.clone(), - )?)) - } -} - -#[derive(Debug)] -struct NthValueAccumulator { - kind: NthValueKind, - offset: u32, - value: ScalarValue, -} - -impl NthValueAccumulator { - /// new count accumulator - pub fn try_new(kind: NthValueKind, data_type: DataType) -> Result { - Ok(Self { - kind, - offset: 0, - // null value of that data_type by default - value: ScalarValue::try_from(&data_type)?, - }) - } -} - -impl WindowAccumulator for NthValueAccumulator { - fn scan(&mut self, values: &[ScalarValue]) -> Result> { - self.offset += 1; - match self.kind { - NthValueKind::Last => { - self.value = values[0].clone(); - } - NthValueKind::First if self.offset == 1 => { - self.value = values[0].clone(); - } - NthValueKind::Nth(n) if self.offset == n => { - self.value = values[0].clone(); - } - _ => {} + fn evaluate(&self, num_rows: usize, values: &[ArrayRef]) -> Result { + if values.is_empty() { + return Err(DataFusionError::Execution(format!( + "No arguments supplied to {}", + self.name() + ))); } - - Ok(None) - } - - fn evaluate(&self) -> Result> { - Ok(Some(self.value.clone())) + let value = &values[0]; + if value.len() != num_rows { + return Err(DataFusionError::Execution(format!( + "Invalid data supplied to {}, expect {} rows, got {} rows", + self.name(), + 
num_rows, + value.len() + ))); + } + if num_rows == 0 { + return Ok(new_empty_array(value.data_type())); + } + let index: usize = match self.kind { + NthValueKind::First => 0, + NthValueKind::Last => (num_rows as usize) - 1, + NthValueKind::Nth(n) => (n as usize) - 1, + }; + let value = ScalarValue::try_from_array(value, index)?; + Ok(value.to_array_of_size(num_rows)) } } @@ -172,68 +148,47 @@ mod tests { use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; - fn test_i32_result(expr: Arc, expected: i32) -> Result<()> { + fn test_i32_result(expr: NthValue, expected: Vec) -> Result<()> { let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8])); + let values = vec![arr]; let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?; - - let mut acc = expr.create_accumulator()?; - let expr = expr.expressions(); - let values = expr - .iter() - .map(|e| e.evaluate(&batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) - .collect::>>()?; - let result = acc.scan_batch(batch.num_rows(), &values)?; - assert_eq!(false, result.is_some()); - let result = acc.evaluate()?; - assert_eq!(Some(ScalarValue::Int32(Some(expected))), result); + let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?; + let result = expr.evaluate(batch.num_rows(), &values)?; + let result = result.as_any().downcast_ref::().unwrap(); + let result = result.values(); + assert_eq!(expected, result); Ok(()) } #[test] fn first_value() -> Result<()> { - let first_value = Arc::new(NthValue::first_value( - "first_value".to_owned(), - col("arr"), - DataType::Int32, - )); - test_i32_result(first_value, 1)?; + let first_value = + NthValue::first_value("first_value".to_owned(), col("arr"), DataType::Int32); + test_i32_result(first_value, vec![1; 8])?; Ok(()) } #[test] fn last_value() -> Result<()> { - let last_value = Arc::new(NthValue::last_value( - "last_value".to_owned(), - col("arr"), - DataType::Int32, - )); - test_i32_result(last_value, 8)?; + let last_value = + NthValue::last_value("last_value".to_owned(), col("arr"), DataType::Int32); + test_i32_result(last_value, vec![8; 8])?; Ok(()) } #[test] fn nth_value_1() -> Result<()> { - let nth_value = Arc::new(NthValue::nth_value( - "nth_value".to_owned(), - col("arr"), - DataType::Int32, - 1, - )?); - test_i32_result(nth_value, 1)?; + let nth_value = + NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 1)?; + test_i32_result(nth_value, vec![1; 8])?; Ok(()) } #[test] fn nth_value_2() -> Result<()> { - let nth_value = Arc::new(NthValue::nth_value( - "nth_value".to_owned(), - col("arr"), - DataType::Int32, - 2, - )?); - test_i32_result(nth_value, -2)?; + let nth_value = + NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 2)?; + test_i32_result(nth_value, vec![-2; 8])?; Ok(()) } } diff --git a/datafusion/src/physical_plan/expressions/row_number.rs b/datafusion/src/physical_plan/expressions/row_number.rs index eaf9b21cbc64..0444ee971f40 100644 --- a/datafusion/src/physical_plan/expressions/row_number.rs +++ b/datafusion/src/physical_plan/expressions/row_number.rs @@ -18,10 +18,7 @@ //! 
Defines physical expression for `row_number` that can evaluated at runtime during query execution use crate::error::Result; -use crate::physical_plan::{ - window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator, -}; -use crate::scalar::ScalarValue; +use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr}; use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; use std::any::Any; @@ -60,46 +57,10 @@ impl BuiltInWindowFunctionExpr for RowNumber { self.name.as_str() } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(RowNumberAccumulator::new())) - } -} - -#[derive(Debug)] -struct RowNumberAccumulator { - row_number: u64, -} - -impl RowNumberAccumulator { - /// new row_number accumulator - pub fn new() -> Self { - // row number is 1 based - Self { row_number: 1 } - } -} - -impl WindowAccumulator for RowNumberAccumulator { - fn scan(&mut self, _values: &[ScalarValue]) -> Result> { - let result = Some(ScalarValue::UInt64(Some(self.row_number))); - self.row_number += 1; - Ok(result) - } - - fn scan_batch( - &mut self, - num_rows: usize, - _values: &[ArrayRef], - ) -> Result> { - let new_row_number = self.row_number + (num_rows as u64); - // TODO: probably would be nice to have a (optimized) kernel for this at some point to - // generate an array like this. - let result = UInt64Array::from_iter_values(self.row_number..new_row_number); - self.row_number = new_row_number; - Ok(Some(Arc::new(result))) - } - - fn evaluate(&self) -> Result> { - Ok(None) + fn evaluate(&self, num_rows: usize, _values: &[ArrayRef]) -> Result { + Ok(Arc::new(UInt64Array::from_iter_values( + (1..num_rows + 1).map(|i| i as u64), + ))) } } @@ -117,27 +78,11 @@ mod tests { ])); let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]); let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?; - - let row_number = Arc::new(RowNumber::new("row_number".to_owned())); - - let mut acc = row_number.create_accumulator()?; - let expr = row_number.expressions(); - let values = expr - .iter() - .map(|e| e.evaluate(&batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) - .collect::>>()?; - - let result = acc.scan_batch(batch.num_rows(), &values)?; - assert_eq!(true, result.is_some()); - - let result = result.unwrap(); + let row_number = RowNumber::new("row_number".to_owned()); + let result = row_number.evaluate(batch.num_rows(), &[])?; let result = result.as_any().downcast_ref::().unwrap(); let result = result.values(); assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result); - - let result = acc.evaluate()?; - assert_eq!(false, result.is_some()); Ok(()) } @@ -148,27 +93,11 @@ mod tests { ])); let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]); let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?; - - let row_number = Arc::new(RowNumber::new("row_number".to_owned())); - - let mut acc = row_number.create_accumulator()?; - let expr = row_number.expressions(); - let values = expr - .iter() - .map(|e| e.evaluate(&batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) - .collect::>>()?; - - let result = acc.scan_batch(batch.num_rows(), &values)?; - assert_eq!(true, result.is_some()); - - let result = result.unwrap(); + let row_number = RowNumber::new("row_number".to_owned()); + let result = row_number.evaluate(batch.num_rows(), &[])?; let result = result.as_any().downcast_ref::().unwrap(); let result = result.values(); assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result); - - let result = 
acc.evaluate()?; - assert_eq!(false, result.is_some()); Ok(()) } } diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index 453d500e98bd..f1611ebd7a77 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -500,7 +500,7 @@ fn dictionary_create_key_for_col( let dict_col = col.as_any().downcast_ref::>().unwrap(); // look up the index in the values dictionary - let keys_col = dict_col.keys_array(); + let keys_col = dict_col.keys(); let values_index = keys_col.value(row).to_usize().ok_or_else(|| { DataFusionError::Internal(format!( "Can not convert index to usize in dictionary of type creating group by value {:?}", @@ -1083,7 +1083,7 @@ fn dictionary_create_group_by_value( let dict_col = col.as_any().downcast_ref::>().unwrap(); // look up the index in the values dictionary - let keys_col = dict_col.keys_array(); + let keys_col = dict_col.keys(); let values_index = keys_col.value(row).to_usize().ok_or_else(|| { DataFusionError::Internal(format!( "Can not convert index to usize in dictionary of type creating group by value {:?}", diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 2dcba802560a..713956f00a9e 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -17,17 +17,16 @@ //! Traits for physical query plan, supporting parallel execution for partitioned relations. -use std::fmt; -use std::fmt::{Debug, Display}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; - +use self::{display::DisplayableExecutionPlan, merge::MergeExec}; use crate::execution::context::ExecutionContextState; use crate::logical_plan::LogicalPlan; +use crate::physical_plan::expressions::PhysicalSortExpr; use crate::{ error::{DataFusionError, Result}, scalar::ScalarValue, }; +use arrow::compute::kernels::partition::lexicographical_partition_ranges; +use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; @@ -35,10 +34,13 @@ use arrow::{array::ArrayRef, datatypes::Field}; use async_trait::async_trait; pub use display::DisplayFormatType; use futures::stream::Stream; -use std::{any::Any, pin::Pin}; - -use self::{display::DisplayableExecutionPlan, merge::MergeExec}; use hashbrown::HashMap; +use std::fmt; +use std::fmt::{Debug, Display}; +use std::ops::Range; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::{any::Any, pin::Pin}; /// Trait for types that stream [arrow::record_batch::RecordBatch] pub trait RecordBatchStream: Stream> { @@ -465,15 +467,65 @@ pub trait WindowExpr: Send + Sync + Debug { "WindowExpr: default name" } - /// the accumulator used to accumulate values from the expressions. - /// the accumulator expects the same number of arguments as `expressions` and must - /// return states with the same description as `state_fields` - fn create_accumulator(&self) -> Result>; - /// expressions that are passed to the WindowAccumulator. /// Functions which take a single input argument, such as `sum`, return a single [`Expr`], /// others (e.g. `cov`) return many. fn expressions(&self) -> Vec>; + + /// evaluate the window function arguments against the batch and return + /// array ref, normally the resulting vec is a single element one. 
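+    /// Each element holds the corresponding argument expression evaluated to `batch.num_rows()` rows.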
+ fn evaluate_args(&self, batch: &RecordBatch) -> Result> { + self.expressions() + .iter() + .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect() + } + + /// evaluate the window function values against the batch + fn evaluate(&self, batch: &RecordBatch) -> Result; + + /// evaluate the sort partition points + fn evaluate_sort_partition_points( + &self, + batch: &RecordBatch, + ) -> Result>> { + let sort_columns = self.sort_columns(batch)?; + if sort_columns.is_empty() { + Ok(vec![Range { + start: 0, + end: batch.num_rows(), + }]) + } else { + lexicographical_partition_ranges(&sort_columns) + .map_err(DataFusionError::ArrowError) + } + } + + /// expressions that's from the window function's partition by clause, empty if absent + fn partition_by(&self) -> &[Arc]; + + /// expressions that's from the window function's order by clause, empty if absent + fn order_by(&self) -> &[PhysicalSortExpr]; + + /// get sort columns that can be used for partitioning, empty if absent + fn sort_columns(&self, batch: &RecordBatch) -> Result> { + self.partition_by() + .iter() + .map(|expr| { + PhysicalSortExpr { + expr: expr.clone(), + options: SortOptions::default(), + } + .evaluate_to_sort_column(batch) + }) + .chain( + self.order_by() + .iter() + .map(|e| e.evaluate_to_sort_column(batch)), + ) + .collect() + } } /// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and @@ -528,58 +580,6 @@ pub trait Accumulator: Send + Sync + Debug { fn evaluate(&self) -> Result; } -/// A window accumulator represents a stateful object that lives throughout the evaluation of multiple -/// rows and generically accumulates values. -/// -/// An accumulator knows how to: -/// * update its state from inputs via `update` -/// * convert its internal state to a vector of scalar values -/// * update its state from multiple accumulators' states via `merge` -/// * compute the final value from its internal state via `evaluate` -pub trait WindowAccumulator: Send + Sync + Debug { - /// scans the accumulator's state from a vector of scalars, similar to Accumulator it also - /// optionally generates values. - fn scan(&mut self, values: &[ScalarValue]) -> Result>; - - /// scans the accumulator's state from a vector of arrays. - fn scan_batch( - &mut self, - num_rows: usize, - values: &[ArrayRef], - ) -> Result> { - if values.is_empty() { - return Ok(None); - }; - // transpose columnar to row based so that we can apply window - let result = (0..num_rows) - .map(|index| { - let v = values - .iter() - .map(|array| ScalarValue::try_from_array(array, index)) - .collect::>>()?; - self.scan(&v) - }) - .collect::>>>()? - .into_iter() - .collect::>>(); - - Ok(match result { - Some(arr) if num_rows == arr.len() => Some(ScalarValue::iter_to_array(arr)?), - None => None, - Some(arr) => { - return Err(DataFusionError::Internal(format!( - "expect scan batch to return {:?} rows, but got {:?}", - num_rows, - arr.len() - ))) - } - }) - } - - /// returns its value based on its current state. - fn evaluate(&self) -> Result>; -} - pub mod aggregates; pub mod array_expressions; pub mod coalesce_batches; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 31b3749dd354..1121c28184bd 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -143,7 +143,12 @@ impl DefaultPhysicalPlanner { LogicalPlan::Window { input, window_expr, .. 
} => { - // Initially need to perform the aggregate and then merge the partitions + if window_expr.is_empty() { + return Err(DataFusionError::Internal( + "Impossibly got empty window expression".to_owned(), + )); + } + let input_exec = self.create_initial_plan(input, ctx_state)?; let input_schema = input_exec.schema(); @@ -364,7 +369,7 @@ impl DefaultPhysicalPlanner { let left_expr = keys.iter().map(|x| col(&x.0)).collect(); let right_expr = keys.iter().map(|x| col(&x.1)).collect(); - // Use hash partition by defualt to parallelize hash joins + // Use hash partition by default to parallelize hash joins Ok(Arc::new(HashJoinExec::try_new( Arc::new(RepartitionExec::try_new( left, @@ -776,12 +781,6 @@ impl DefaultPhysicalPlanner { .to_owned(), )); } - if !order_by.is_empty() { - return Err(DataFusionError::NotImplemented( - "window expression with non-empty order by clause is not yet supported" - .to_owned(), - )); - } if window_frame.is_some() { return Err(DataFusionError::NotImplemented( "window expression with window frame definition is not yet supported" diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index e6afcaad8ad6..4f56aa7d3826 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -20,11 +20,12 @@ //! //! see also https://www.postgresql.org/docs/current/functions-window.html +use crate::arrow::array::ArrayRef; use crate::arrow::datatypes::Field; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ aggregates, aggregates::AggregateFunction, functions::Signature, - type_coercion::data_types, PhysicalExpr, WindowAccumulator, + type_coercion::data_types, PhysicalExpr, }; use arrow::datatypes::DataType; use std::any::Any; @@ -207,7 +208,10 @@ pub(super) fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature { } } -/// A window expression that is a built-in window function +/// A window expression that is a built-in window function. +/// +/// Note that unlike aggregation based window functions, built-in window functions normally ignore +/// window frame spec, with the exception of first_value, last_value, and nth_value. pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Returns the aggregate expression as [`Any`](std::any::Any) so that it can be /// downcast to a specific implementation. @@ -226,10 +230,8 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { "BuiltInWindowFunctionExpr: default name" } - /// the accumulator used to accumulate values from the expressions. - /// the accumulator expects the same number of arguments as `expressions` and must - /// return states with the same description as `state_fields` - fn create_accumulator(&self) -> Result>; + /// Evaluate the built-in window function against the number of rows and the arguments + fn evaluate(&self, num_rows: usize, values: &[ArrayRef]) -> Result; } #[cfg(test)] diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index f95dd446844d..e5570971cf16 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -18,8 +18,7 @@ //! 
Execution plan for window functions use crate::error::{DataFusionError, Result}; - -use crate::logical_plan::window_frames::WindowFrame; +use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits}; use crate::physical_plan::{ aggregates, common, expressions::{Literal, NthValue, PhysicalSortExpr, RowNumber}, @@ -28,9 +27,9 @@ use crate::physical_plan::{ window_functions::BuiltInWindowFunctionExpr, window_functions::{BuiltInWindowFunction, WindowFunction}, Accumulator, AggregateExpr, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, - RecordBatchStream, SendableRecordBatchStream, WindowAccumulator, WindowExpr, + RecordBatchStream, SendableRecordBatchStream, WindowExpr, }; -use crate::scalar::ScalarValue; +use arrow::compute::concat; use arrow::{ array::ArrayRef, datatypes::{Field, Schema, SchemaRef}, @@ -43,6 +42,7 @@ use futures::Future; use pin_project_lite::pin_project; use std::any::Any; use std::convert::TryInto; +use std::ops::Range; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -65,12 +65,9 @@ pub fn create_window_expr( fun: &WindowFunction, name: String, args: &[Arc], - // https://github.com/apache/arrow-datafusion/issues/299 - _partition_by: &[Arc], - // https://github.com/apache/arrow-datafusion/issues/360 - _order_by: &[PhysicalSortExpr], - // https://github.com/apache/arrow-datafusion/issues/361 - _window_frame: Option, + partition_by: &[Arc], + order_by: &[PhysicalSortExpr], + window_frame: Option, input_schema: &Schema, ) -> Result> { Ok(match fun { @@ -82,9 +79,15 @@ pub fn create_window_expr( input_schema, name, )?, + partition_by: partition_by.to_vec(), + order_by: order_by.to_vec(), + window_frame, }), WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr { window: create_built_in_window_expr(fun, args, input_schema, name)?, + partition_by: partition_by.to_vec(), + order_by: order_by.to_vec(), + window_frame, }), }) } @@ -136,6 +139,9 @@ fn create_built_in_window_expr( #[derive(Debug)] pub struct BuiltInWindowExpr { window: Arc, + partition_by: Vec>, + order_by: Vec, + window_frame: Option, } impl WindowExpr for BuiltInWindowExpr { @@ -156,8 +162,20 @@ impl WindowExpr for BuiltInWindowExpr { self.window.expressions() } - fn create_accumulator(&self) -> Result> { - self.window.create_accumulator() + fn partition_by(&self) -> &[Arc] { + &self.partition_by + } + + fn order_by(&self) -> &[PhysicalSortExpr] { + &self.order_by + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + // FIXME, for now we assume all the rows belong to the same partition, which will not be the + // case when partition_by is supported, in which case we'll parallelize the calls. + // See https://github.com/apache/arrow-datafusion/issues/299 + let values = self.evaluate_args(batch)?; + self.window.evaluate(batch.num_rows(), &values) } } @@ -165,22 +183,51 @@ impl WindowExpr for BuiltInWindowExpr { #[derive(Debug)] pub struct AggregateWindowExpr { aggregate: Arc, + partition_by: Vec>, + order_by: Vec, + window_frame: Option, } -#[derive(Debug)] -struct AggregateWindowAccumulator { - accumulator: Box, -} +impl AggregateWindowExpr { + /// the aggregate window function operates based on window frame, and by default the mode is + /// "range". 
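+    /// Only the default RANGE mode is evaluated for now; ROWS and GROUPS fall through to the
+    /// not-yet-implemented paths below.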
+ fn evaluation_mode(&self) -> WindowFrameUnits { + self.window_frame.unwrap_or_default().units + } -impl WindowAccumulator for AggregateWindowAccumulator { - fn scan(&mut self, values: &[ScalarValue]) -> Result> { - self.accumulator.update(values)?; - Ok(None) + /// create a new accumulator based on the underlying aggregation function + fn create_accumulator(&self) -> Result { + let accumulator = self.aggregate.create_accumulator()?; + Ok(AggregateWindowAccumulator { accumulator }) } - /// returns its value based on its current state. - fn evaluate(&self) -> Result> { - Ok(Some(self.accumulator.evaluate()?)) + /// peer based evaluation based on the fact that batch is pre-sorted given the sort columns + /// and then per partition point we'll evaluate the peer group (e.g. SUM or MAX gives the same + /// results for peers) and concatenate the results. + fn peer_based_evaluate(&self, batch: &RecordBatch) -> Result { + let sort_partition_points = self.evaluate_sort_partition_points(batch)?; + let mut window_accumulators = self.create_accumulator()?; + let values = self.evaluate_args(batch)?; + let results = sort_partition_points + .iter() + .map(|peer_range| window_accumulators.scan_peers(&values, peer_range)) + .collect::>>()?; + let results = results.iter().map(|i| i.as_ref()).collect::>(); + concat(&results).map_err(DataFusionError::ArrowError) + } + + fn group_based_evaluate(&self, _batch: &RecordBatch) -> Result { + Err(DataFusionError::NotImplemented(format!( + "Group based evaluation for {} is not yet implemented", + self.name() + ))) + } + + fn row_based_evaluate(&self, _batch: &RecordBatch) -> Result { + Err(DataFusionError::NotImplemented(format!( + "Row based evaluation for {} is not yet implemented", + self.name() + ))) } } @@ -202,9 +249,55 @@ impl WindowExpr for AggregateWindowExpr { self.aggregate.expressions() } - fn create_accumulator(&self) -> Result> { - let accumulator = self.aggregate.create_accumulator()?; - Ok(Box::new(AggregateWindowAccumulator { accumulator })) + fn partition_by(&self) -> &[Arc] { + &self.partition_by + } + + fn order_by(&self) -> &[PhysicalSortExpr] { + &self.order_by + } + + /// evaluate the window function values against the batch + fn evaluate(&self, batch: &RecordBatch) -> Result { + // FIXME, for now we assume all the rows belong to the same partition, which will not be the + // case when partition_by is supported, in which case we'll parallelize the calls. + // See https://github.com/apache/arrow-datafusion/issues/299 + match self.evaluation_mode() { + WindowFrameUnits::Range => self.peer_based_evaluate(batch), + WindowFrameUnits::Rows => self.row_based_evaluate(batch), + WindowFrameUnits::Groups => self.group_based_evaluate(batch), + } + } +} + +/// Aggregate window accumulator utilizes the accumulator from aggregation and do a accumulative sum +/// across evaluation arguments based on peer equivalences. +#[derive(Debug)] +struct AggregateWindowAccumulator { + accumulator: Box, +} + +impl AggregateWindowAccumulator { + /// scan one peer group of values (as arguments to window function) given by the value_range + /// and return evaluation result that are of the same number of rows. 
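+    /// The accumulator state is carried across calls, so successive peer groups yield running
+    /// (cumulative) aggregate values.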
+ fn scan_peers( + &mut self, + values: &[ArrayRef], + value_range: &Range, + ) -> Result { + if value_range.is_empty() { + return Err(DataFusionError::Internal( + "Value range cannot be empty".to_owned(), + )); + } + let len = value_range.end - value_range.start; + let values = values + .iter() + .map(|v| v.slice(value_range.start, len)) + .collect::>(); + self.accumulator.update_batch(&values)?; + let value = self.accumulator.evaluate()?; + Ok(value.to_array_of_size(len)) } } @@ -329,106 +422,17 @@ pin_project! { } } -type WindowAccumulatorItem = Box; - -fn window_expressions( - window_expr: &[Arc], -) -> Result>>> { - Ok(window_expr - .iter() - .map(|expr| expr.expressions()) - .collect::>()) -} - -fn window_aggregate_batch( - batch: &RecordBatch, - window_accumulators: &mut [WindowAccumulatorItem], - expressions: &[Vec>], -) -> Result>> { - window_accumulators - .iter_mut() - .zip(expressions) - .map(|(window_acc, expr)| { - let values = &expr - .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) - .collect::>>()?; - window_acc.scan_batch(batch.num_rows(), values) - }) - .collect::>>() -} - -/// returns a vector of ArrayRefs, where each entry corresponds to one window expr -fn finalize_window_aggregation( - window_accumulators: &[WindowAccumulatorItem], -) -> Result>> { - window_accumulators - .iter() - .map(|window_accumulator| window_accumulator.evaluate()) - .collect::>>() -} - -fn create_window_accumulators( - window_expr: &[Arc], -) -> Result> { - window_expr - .iter() - .map(|expr| expr.create_accumulator()) - .collect::>>() -} - /// Compute the window aggregate columns -/// -/// 1. get a list of window accumulators -/// 2. evaluate the args -/// 3. scan args with window functions -/// 4. concat with final aggregations -/// -/// FIXME so far this fn does not support: -/// 1. partition by -/// 2. order by -/// 3. window frame -/// -/// which will require further work: -/// 1. inter-partition order by using vec partition-point (https://github.com/apache/arrow-datafusion/issues/360) -/// 2. inter-partition parallelism using one-shot channel (https://github.com/apache/arrow-datafusion/issues/299) -/// 3. convert aggregation based window functions to be self-contain so that: (https://github.com/apache/arrow-datafusion/issues/361) -/// a. some can be grow-only window-accumulating -/// b. some can be grow-and-shrink window-accumulating -/// c. some can be based on segment tree fn compute_window_aggregates( window_expr: Vec>, batch: &RecordBatch, ) -> Result> { - let mut window_accumulators = create_window_accumulators(&window_expr)?; - let expressions = Arc::new(window_expressions(&window_expr)?); - let num_rows = batch.num_rows(); - let window_aggregates = - window_aggregate_batch(batch, &mut window_accumulators, &expressions)?; - let final_aggregates = finalize_window_aggregation(&window_accumulators)?; - - // both must equal to window_expr.len() - if window_aggregates.len() != final_aggregates.len() { - return Err(DataFusionError::Internal( - "Impossibly got len mismatch".to_owned(), - )); - } - - window_aggregates + // FIXME, for now we assume all the rows belong to the same partition, which will not be the + // case when partition_by is supported, in which case we'll parallelize the calls. 
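+    // Each window expression is evaluated independently over the entire batch.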
+ // See https://github.com/apache/arrow-datafusion/issues/299 + window_expr .iter() - .zip(final_aggregates) - .map(|(wa, fa)| { - Ok(match (wa, fa) { - (None, Some(fa)) => fa.to_array_of_size(num_rows), - (Some(wa), None) if wa.len() == num_rows => wa.clone(), - _ => { - return Err(DataFusionError::Execution( - "Invalid window function behavior".to_owned(), - )) - } - }) - }) + .map(|window_expr| window_expr.evaluate(batch)) .collect() } diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index ac7deeed22c7..933bb8cebcb1 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -771,7 +771,7 @@ impl ScalarValue { let dict_array = array.as_any().downcast_ref::>().unwrap(); // look up the index in the values dictionary - let keys_col = dict_array.keys_array(); + let keys_col = dict_array.keys(); let values_index = keys_col.value(index).to_usize().ok_or_else(|| { DataFusionError::Internal(format!( "Can not convert index to usize in dictionary of type creating group by value {:?}", diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index e860bd74641d..4c1d8610dfdd 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -703,7 +703,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut plan = input; let mut groups = group_window_expr_by_sort_keys(&window_exprs)?; // sort by sort_key len descending, so that more deeply sorted plans gets nested further - // down as children; to further minic the behavior of PostgreSQL, we want stable sort + // down as children; to further mimic the behavior of PostgreSQL, we want stable sort // and a reverse so that tieing sort keys are reversed in order; note that by this rule // if there's an empty over, it'll be at the top level groups.sort_by(|(key_a, _), (key_b, _)| key_a.len().cmp(&key_b.len())); diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index d9d77648c742..21da793b5538 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -802,25 +802,142 @@ async fn csv_query_window_with_empty_over() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; let sql = "select \ - c2, \ - sum(c3) over (), \ - avg(c3) over (), \ - count(c3) over (), \ - max(c3) over (), \ - min(c3) over (), \ - first_value(c3) over (), \ - last_value(c3) over (), \ - nth_value(c3, 2) over () + c9, \ + count(c5) over (), \ + max(c5) over (), \ + min(c5) over (), \ + first_value(c5) over (), \ + last_value(c5) over (), \ + nth_value(c5, 2) over () \ from aggregate_test_100 \ - order by c2 + order by c9 \ limit 5"; let actual = execute(&mut ctx, sql).await; let expected = vec![ - vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"], - vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"], - vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"], - vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"], - vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"], + vec![ + "28774375", + "100", + "2143473091", + "-2141999138", + "2033001162", + "61035129", + "706441268", + ], + vec![ + "63044568", + "100", + "2143473091", + "-2141999138", + "2033001162", + "61035129", + "706441268", + ], + vec![ + "141047417", + "100", + "2143473091", + "-2141999138", + "2033001162", + "61035129", + "706441268", + ], + vec![ + "141680161", + "100", + "2143473091", + "-2141999138", + "2033001162", + "61035129", + "706441268", + ], + vec![ + "145294611", + "100", + "2143473091", + "-2141999138", + "2033001162", + 
"61035129", + "706441268", + ], + ]; + assert_eq!(expected, actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_window_with_order_by() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx)?; + let sql = "select \ + c9, \ + sum(c5) over (order by c9), \ + avg(c5) over (order by c9), \ + count(c5) over (order by c9), \ + max(c5) over (order by c9), \ + min(c5) over (order by c9), \ + first_value(c5) over (order by c9), \ + last_value(c5) over (order by c9), \ + nth_value(c5, 2) over (order by c9) \ + from aggregate_test_100 \ + order by c9 \ + limit 5"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![ + vec![ + "28774375", + "61035129", + "61035129", + "1", + "61035129", + "61035129", + "61035129", + "2025611582", + "-108973366", + ], + vec![ + "63044568", + "-47938237", + "-23969118.5", + "2", + "61035129", + "-108973366", + "61035129", + "2025611582", + "-108973366", + ], + vec![ + "141047417", + "575165281", + "191721760.33333334", + "3", + "623103518", + "-108973366", + "61035129", + "2025611582", + "-108973366", + ], + vec![ + "141680161", + "-1352462829", + "-338115707.25", + "4", + "623103518", + "-1927628110", + "61035129", + "2025611582", + "-108973366", + ], + vec![ + "145294611", + "-3251637940", + "-650327588", + "5", + "623103518", + "-1927628110", + "61035129", + "2025611582", + "-108973366", + ], ]; assert_eq!(expected, actual); Ok(()) diff --git a/integration-tests/sqls/simple_window_ordered_aggregation.sql b/integration-tests/sqls/simple_window_ordered_aggregation.sql new file mode 100644 index 000000000000..d9f467b0cb09 --- /dev/null +++ b/integration-tests/sqls/simple_window_ordered_aggregation.sql @@ -0,0 +1,26 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at + +-- http://www.apache.org/licenses/LICENSE-2.0 + +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language gOVERning permissions and +-- limitations under the License. + +SELECT + c9, + row_number() OVER (ORDER BY c2, c9) AS row_number, + count(c3) OVER (ORDER BY c9) AS count_c3, + avg(c3) OVER (ORDER BY c2) AS avg_c3_by_c2, + sum(c3) OVER (ORDER BY c2) AS sum_c3_by_c2, + max(c3) OVER (ORDER BY c2) AS max_c3_by_c2, + min(c3) OVER (ORDER BY c2) AS min_c3_by_c2 +FROM test +ORDER BY row_number; diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py index 51861c583f8a..4e0878c24b81 100644 --- a/integration-tests/test_psql_parity.py +++ b/integration-tests/test_psql_parity.py @@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase): def test_parity(self): root = Path(os.path.dirname(__file__)) / "sqls" files = set(root.glob("*.sql")) - self.assertEqual(len(files), 6, msg="tests are missed") + self.assertEqual(len(files), 7, msg="tests are missed") for fname in files: with self.subTest(fname=fname): datafusion_output = pd.read_csv(