From d567705a9bb04a08bd82a85343859342a06ccf43 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Fri, 21 May 2021 19:44:33 +0800 Subject: [PATCH] Squashed commit of the following: commit 7fb3640e733bfbbdbf18d58000896f378ba9644c Author: Jiayu Liu Date: Fri May 21 16:38:25 2021 +0800 row number done commit 17239267cd2fbcbb676d5731beeffd0321bbd3ba Author: Jiayu Liu Date: Fri May 21 16:05:50 2021 +0800 add row number commit bf5b8a56f6f33d8eedf3e3009e7fcdb3c388ea5b Author: Jiayu Liu Date: Fri May 21 15:04:49 2021 +0800 save commit d2ce852ead5d8ae3d15962b4dd3062e24bce51de Author: Jiayu Liu Date: Fri May 21 14:53:05 2021 +0800 add streams commit 0a861a76bde0bb43e5561f1cf1ef14fd64e0c08b Author: Jiayu Liu Date: Thu May 20 22:28:34 2021 +0800 save stream commit a9121af7e2e9104d0e4b6ca3ef4f484aaf8baf42 Author: Jiayu Liu Date: Thu May 20 22:01:51 2021 +0800 update unit test commit 2af2a270262ff1bc759af39153d7cd681c32dc0a Author: Jiayu Liu Date: Fri May 21 14:25:12 2021 +0800 fix unit test commit bb57c762b0a1fabc35e207e681bca2bfff7fcf01 Author: Jiayu Liu Date: Fri May 21 14:23:34 2021 +0800 use upper case commit 5d96e525f587fbfaf3e5e9762c9bb10315fcbc3a Author: Jiayu Liu Date: Fri May 21 14:16:16 2021 +0800 fix unit test commit 1ecae8f6cbc6c1898ccf0b38b1e596b6c2e9bb46 Author: Jiayu Liu Date: Fri May 21 12:27:26 2021 +0800 fix unit test commit bc2271d58fd4a9a9cc96126f8abcd6e8f10272ca Author: Jiayu Liu Date: Fri May 21 10:04:29 2021 +0800 fix error commit 880b94f6e27df61b4d3877366f71a51b9b2f5d5d Author: Jiayu Liu Date: Fri May 21 08:24:00 2021 +0800 fix unit test commit 4e792e123a33fd0dcb5f701c679566b55589b0c0 Author: Jiayu Liu Date: Fri May 21 08:05:17 2021 +0800 fix test commit c36c04abf06c74d016597983bf3d3a2a5b5cbdd5 Author: Jiayu Liu Date: Fri May 21 00:07:54 2021 +0800 add more tests commit f5e64de7192a1916df78a4c2fbab7d471c906720 Author: Jiayu Liu Date: Thu May 20 23:41:36 2021 +0800 update commit a1eae864926a6acfeeebe995a12de4ad725ea869 Author: Jiayu Liu Date: Thu May 20 23:36:15 2021 +0800 enrich unit test commit 0d2a214131fe69e19e22144c68fbb992228db6b3 Author: Jiayu Liu Date: Thu May 20 23:25:43 2021 +0800 adding filter by todo commit 8b486d53b09ff1c7a6b9cf4687796ba1c13d6160 Author: Jiayu Liu Date: Thu May 20 23:17:22 2021 +0800 adding more built-in functions commit abf08cd137a80c1381af7de9ae2b3dab05cb4512 Author: Jiayu Liu Date: Thu May 20 22:36:27 2021 +0800 Update datafusion/src/physical_plan/window_functions.rs Co-authored-by: Andrew Lamb commit 0cbca53dac642233520f7d32289b1dfad77b882e Author: Jiayu Liu Date: Thu May 20 22:34:57 2021 +0800 Update datafusion/src/physical_plan/window_functions.rs Co-authored-by: Andrew Lamb commit 831c069f02236a953653b8f1ca25124e393ce20b Author: Jiayu Liu Date: Thu May 20 22:34:04 2021 +0800 Update datafusion/src/logical_plan/builder.rs Co-authored-by: Andrew Lamb commit f70c739fd40e30c4b476253e58b24b9297b42859 Author: Jiayu Liu Date: Thu May 20 22:33:04 2021 +0800 Update datafusion/src/logical_plan/builder.rs Co-authored-by: Andrew Lamb commit 3ee87aa3477c160f17a86628d71a353e03d736b3 Author: Jiayu Liu Date: Wed May 19 22:55:08 2021 +0800 fix unit test commit 5c4d92dc9f570ba6919d84cb8ac70a736d73f40f Author: Jiayu Liu Date: Wed May 19 22:48:26 2021 +0800 fix clippy commit a0b7526c413abbdd4aadab4af8ca9ad8f323f03b Author: Jiayu Liu Date: Wed May 19 22:46:38 2021 +0800 fix unused imports commit 1d3b076acc1c0f248a19c6149c0634e63a5b836e Author: Jiayu Liu Date: Thu May 13 18:51:14 2021 +0800 add window expr --- datafusion/src/execution/context.rs | 7 + .../src/physical_plan/expressions/mod.rs | 4 + .../physical_plan/expressions/nth_value.rs | 221 ++++++++++ .../physical_plan/expressions/row_number.rs | 102 +++++ .../src/physical_plan/hash_aggregate.rs | 7 +- datafusion/src/physical_plan/mod.rs | 80 +++- datafusion/src/physical_plan/planner.rs | 4 +- datafusion/src/physical_plan/sort.rs | 1 + .../src/physical_plan/window_functions.rs | 77 ++-- datafusion/src/physical_plan/windows.rs | 387 +++++++++++++++++- datafusion/tests/sql.rs | 51 ++- 11 files changed, 876 insertions(+), 65 deletions(-) create mode 100644 datafusion/src/physical_plan/expressions/nth_value.rs create mode 100644 datafusion/src/physical_plan/expressions/row_number.rs diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 272e75acba6f..554c0d99e915 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1268,6 +1268,13 @@ mod tests { Ok(()) } + #[tokio::test] + async fn window() -> Result<()> { + let results = execute("SELECT c1, MAX(c2) OVER () FROM test", 4).await?; + assert_eq!(results.len(), 1); + Ok(()) + } + #[tokio::test] async fn aggregate() -> Result<()> { let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?; diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 4d57c39bb31c..77da95c3a04a 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -40,7 +40,9 @@ mod literal; mod min_max; mod negative; mod not; +mod nth_value; mod nullif; +mod row_number; mod sum; mod try_cast; @@ -57,7 +59,9 @@ pub use literal::{lit, Literal}; pub use min_max::{Max, Min}; pub use negative::{negative, NegativeExpr}; pub use not::{not, NotExpr}; +pub use nth_value::{FirstValue, LastValue, NthValue}; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; +pub use row_number::RowNumber; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; /// returns the name of the state diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs new file mode 100644 index 000000000000..2fe8be2a5c94 --- /dev/null +++ b/datafusion/src/physical_plan/expressions/nth_value.rs @@ -0,0 +1,221 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator}; +use crate::scalar::ScalarValue; +use arrow::datatypes::{DataType, Field}; +use std::any::Any; +use std::convert::TryFrom; +use std::sync::Arc; + +/// first_value expression +#[derive(Debug)] +pub struct FirstValue { + name: String, + data_type: DataType, + expr: Arc, +} + +impl FirstValue { + /// Create a new FIRST_VALUE window aggregate function + pub fn new(expr: Arc, name: String, data_type: DataType) -> Self { + Self { + name, + data_type, + expr, + } + } +} + +impl BuiltInWindowFunctionExpr for FirstValue { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + let nullable = true; + Ok(Field::new(&self.name, self.data_type.clone(), nullable)) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(NthValueAccumulator::try_new( + 1, + self.data_type.clone(), + )?)) + } +} + +// sql values start with 1, so we can use 0 to indicate the special last value behavior +const SPECIAL_SIZE_VALUE_FOR_LAST: u32 = 0; + +/// last_value expression +#[derive(Debug)] +pub struct LastValue { + name: String, + data_type: DataType, + expr: Arc, +} + +impl LastValue { + /// Create a new FIRST_VALUE window aggregate function + pub fn new(expr: Arc, name: String, data_type: DataType) -> Self { + Self { + name, + data_type, + expr, + } + } +} + +impl BuiltInWindowFunctionExpr for LastValue { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + let nullable = true; + Ok(Field::new(&self.name, self.data_type.clone(), nullable)) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(NthValueAccumulator::try_new( + SPECIAL_SIZE_VALUE_FOR_LAST, + self.data_type.clone(), + )?)) + } +} + +/// nth_value expression +#[derive(Debug)] +pub struct NthValue { + name: String, + n: u32, + data_type: DataType, + expr: Arc, +} + +impl NthValue { + /// Create a new NTH_VALUE window aggregate function + pub fn try_new( + expr: Arc, + name: String, + n: u32, + data_type: DataType, + ) -> Result { + if n == SPECIAL_SIZE_VALUE_FOR_LAST { + Err(DataFusionError::Execution( + "nth_value expect n to be > 0".to_owned(), + )) + } else { + Ok(Self { + name, + n, + data_type, + expr, + }) + } + } +} + +impl BuiltInWindowFunctionExpr for NthValue { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + let nullable = true; + Ok(Field::new(&self.name, self.data_type.clone(), nullable)) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(NthValueAccumulator::try_new( + self.n, + self.data_type.clone(), + )?)) + } +} + +#[derive(Debug)] +struct NthValueAccumulator { + // n the target nth_value, however we'll reuse it for last_value acc, so when n == 0 it specifically + // means last; also note that it is totally valid for n to be larger than the number of rows input + // in which case all the values shall be null + n: u32, + offset: u32, + value: ScalarValue, +} + +impl NthValueAccumulator { + /// new count accumulator + pub fn try_new(n: u32, data_type: DataType) -> Result { + Ok(Self { + n, + offset: 0, + // null value of that data_type by default + value: ScalarValue::try_from(&data_type)?, + }) + } +} + +impl WindowAccumulator for NthValueAccumulator { + fn scan(&mut self, values: &[ScalarValue]) -> Result> { + if self.n == SPECIAL_SIZE_VALUE_FOR_LAST { + // for last_value function + self.value = values[0].clone(); + } else if self.offset < self.n { + self.offset += 1; + if self.offset == self.n { + self.value = values[0].clone(); + } + } + Ok(None) + } + + fn evaluate(&self) -> Result> { + Ok(Some(self.value.clone())) + } +} diff --git a/datafusion/src/physical_plan/expressions/row_number.rs b/datafusion/src/physical_plan/expressions/row_number.rs new file mode 100644 index 000000000000..2a8a2278ff9c --- /dev/null +++ b/datafusion/src/physical_plan/expressions/row_number.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use crate::error::Result; +use crate::physical_plan::{BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator}; +use crate::scalar::ScalarValue; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field}; +use std::any::Any; +use std::sync::Arc; + +/// row_number expression +#[derive(Debug)] +pub struct RowNumber { + name: String, +} + +impl RowNumber { + /// Create a new MAX aggregate function + pub fn new(name: String) -> Self { + Self { name } + } +} + +impl BuiltInWindowFunctionExpr for RowNumber { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + let nullable = false; + let data_type = DataType::UInt64; + Ok(Field::new(&self.name, data_type, nullable)) + } + + fn expressions(&self) -> Vec> { + vec![] + } + + fn name(&self) -> &str { + &self.name + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(RowNumberAccumulator::new())) + } +} + +#[derive(Debug)] +struct RowNumberAccumulator { + row_number: u64, +} + +impl RowNumberAccumulator { + /// new count accumulator + pub fn new() -> Self { + // row number is 1 based + Self { row_number: 1 } + } +} + +impl WindowAccumulator for RowNumberAccumulator { + fn scan(&mut self, _values: &[ScalarValue]) -> Result> { + let result = Some(ScalarValue::UInt64(Some(self.row_number))); + self.row_number += 1; + Ok(result) + } + + fn scan_batch( + &mut self, + num_rows: usize, + _values: &[ArrayRef], + ) -> Result>> { + let new_row_number = self.row_number + (num_rows as u64); + let result = (self.row_number..new_row_number) + .map(|i| ScalarValue::UInt64(Some(i))) + .collect(); + self.row_number = new_row_number; + Ok(Some(result)) + } + + fn evaluate(&self) -> Result> { + Ok(None) + } +} diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index c9d268619cad..5008f49250b0 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -712,7 +712,7 @@ impl GroupedHashAggregateStream { tx.send(result) }); - GroupedHashAggregateStream { + Self { schema, output: rx, finished: false, @@ -825,7 +825,8 @@ fn aggregate_expressions( } pin_project! { - struct HashAggregateStream { + /// stream struct for hash aggregation + pub struct HashAggregateStream { schema: SchemaRef, #[pin] output: futures::channel::oneshot::Receiver>, @@ -878,7 +879,7 @@ impl HashAggregateStream { tx.send(result) }); - HashAggregateStream { + Self { schema, output: rx, finished: false, diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index c053229bc000..a2a3b4a8ec12 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -457,10 +457,46 @@ pub trait WindowExpr: Send + Sync + Debug { fn name(&self) -> &str { "WindowExpr: default name" } + + /// the accumulator used to accumulate values from the expressions. + /// the accumulator expects the same number of arguments as `expressions` and must + /// return states with the same description as `state_fields` + fn create_accumulator(&self) -> Result>; + + /// expressions that are passed to the WindowAccumulator. + /// Single-column aggregations such as `sum` return a single value, others (e.g. `cov`) return many. + fn expressions(&self) -> Vec>; +} + +/// A window expression that is a built-in window function +pub trait BuiltInWindowFunctionExpr: Send + Sync + Debug { + /// Returns the aggregate expression as [`Any`](std::any::Any) so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + /// the field of the final result of this aggregation. + fn field(&self) -> Result; + + /// expressions that are passed to the Accumulator. + /// Single-column aggregations such as `sum` return a single value, others (e.g. `cov`) return many. + fn expressions(&self) -> Vec>; + + /// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default + /// implementation returns placeholder text. + fn name(&self) -> &str { + "BuiltInWindowFunctionExpr: default name" + } + + /// the accumulator used to accumulate values from the expressions. + /// the accumulator expects the same number of arguments as `expressions` and must + /// return states with the same description as `state_fields` + fn create_accumulator(&self) -> Result>; } /// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and -/// generically accumulates values. An accumulator knows how to: +/// generically accumulates values. +/// +/// An accumulator knows how to: /// * update its state from inputs via `update` /// * convert its internal state to a vector of scalar values /// * update its state from multiple accumulators' states via `merge` @@ -509,6 +545,48 @@ pub trait Accumulator: Send + Sync + Debug { fn evaluate(&self) -> Result; } +/// A window accumulator represents a stateful object that lives throughout the evaluation of multiple +/// rows and generically accumulates values. +/// +/// An accumulator knows how to: +/// * update its state from inputs via `update` +/// * convert its internal state to a vector of scalar values +/// * update its state from multiple accumulators' states via `merge` +/// * compute the final value from its internal state via `evaluate` +pub trait WindowAccumulator: Send + Sync + Debug { + /// scans the accumulator's state from a vector of scalars, similar to Accumulator it also + /// optionally generates values. + fn scan(&mut self, values: &[ScalarValue]) -> Result>; + + /// scans the accumulator's state from a vector of arrays. + fn scan_batch( + &mut self, + num_rows: usize, + values: &[ArrayRef], + ) -> Result>> { + // note that for row_number and rank this might be different + if values.is_empty() { + return Ok(None); + }; + // transpose columnar to row based so that we can apply window + let result: Vec> = (0..num_rows) + .map(|index| { + let v = values + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>()?; + self.scan(&v) + }) + .into_iter() + .collect::>>>()?; + let result: Option> = result.into_iter().collect(); + Ok(result) + } + + /// returns its value based on its current state. + fn evaluate(&self) -> Result>; +} + pub mod aggregates; pub mod array_expressions; pub mod coalesce_batches; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 018925d0e535..7ddfaf8f6897 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -147,8 +147,10 @@ impl DefaultPhysicalPlanner { // Initially need to perform the aggregate and then merge the partitions let input_exec = self.create_initial_plan(input, ctx_state)?; let input_schema = input_exec.schema(); - let physical_input_schema = input_exec.as_ref().schema(); + let logical_input_schema = input.as_ref().schema(); + let physical_input_schema = input_exec.as_ref().schema(); + let window_expr = window_expr .iter() .map(|e| { diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs index caa32cfa264e..175bf9ce5431 100644 --- a/datafusion/src/physical_plan/sort.rs +++ b/datafusion/src/physical_plan/sort.rs @@ -227,6 +227,7 @@ fn sort_batches( } pin_project! { + /// stream for sort plan struct SortStream { #[pin] output: futures::channel::oneshot::Receiver>>, diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index 65d5373d54f4..e3693f310fa7 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -143,49 +143,64 @@ impl FromStr for BuiltInWindowFunction { /// Returns the datatype of the window function pub fn return_type(fun: &WindowFunction, arg_types: &[DataType]) -> Result { + match fun { + WindowFunction::AggregateFunction(fun) => aggregates::return_type(fun, arg_types), + WindowFunction::BuiltInWindowFunction(fun) => { + return_type_for_built_in(fun, arg_types) + } + } +} + +/// Returns the datatype of the built-in window function +pub(super) fn return_type_for_built_in( + fun: &BuiltInWindowFunction, + arg_types: &[DataType], +) -> Result { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. // verify that this is a valid set of data types for this function - data_types(arg_types, &signature(fun))?; + data_types(arg_types, &signature_for_built_in(fun))?; match fun { - WindowFunction::AggregateFunction(fun) => aggregates::return_type(fun, arg_types), - WindowFunction::BuiltInWindowFunction(fun) => match fun { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), - BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { - Ok(DataType::Float64) - } - BuiltInWindowFunction::Ntile => Ok(DataType::UInt32), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue - | BuiltInWindowFunction::NthValue => Ok(arg_types[0].clone()), - }, + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), + BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { + Ok(DataType::Float64) + } + BuiltInWindowFunction::Ntile => Ok(DataType::UInt32), + BuiltInWindowFunction::Lag + | BuiltInWindowFunction::Lead + | BuiltInWindowFunction::FirstValue + | BuiltInWindowFunction::LastValue + | BuiltInWindowFunction::NthValue => Ok(arg_types[0].clone()), } } /// the signatures supported by the function `fun`. -fn signature(fun: &WindowFunction) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. +pub fn signature(fun: &WindowFunction) -> Signature { match fun { WindowFunction::AggregateFunction(fun) => aggregates::signature(fun), - WindowFunction::BuiltInWindowFunction(fun) => match fun { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::PercentRank - | BuiltInWindowFunction::CumeDist => Signature::Any(0), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue => Signature::Any(1), - BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]), - BuiltInWindowFunction::NthValue => Signature::Any(2), - }, + WindowFunction::BuiltInWindowFunction(fun) => signature_for_built_in(fun), + } +} + +/// the signatures supported by the built-in window function `fun`. +pub(super) fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature { + // note: the physical expression must accept the type returned by this function or the execution panics. + match fun { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank + | BuiltInWindowFunction::PercentRank + | BuiltInWindowFunction::CumeDist => Signature::Any(0), + BuiltInWindowFunction::Lag + | BuiltInWindowFunction::Lead + | BuiltInWindowFunction::FirstValue + | BuiltInWindowFunction::LastValue => Signature::Any(1), + BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]), + BuiltInWindowFunction::NthValue => Signature::Any(2), } } diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index bdd25d69fd55..0e59d23d01c0 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -19,13 +19,32 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ - aggregates, window_functions::WindowFunction, AggregateExpr, Distribution, - ExecutionPlan, Partitioning, PhysicalExpr, SendableRecordBatchStream, WindowExpr, + aggregates, + expressions::{FirstValue, LastValue, Literal, NthValue, RowNumber}, + type_coercion::coerce, + window_functions::{signature_for_built_in, BuiltInWindowFunction, WindowFunction}, + Accumulator, AggregateExpr, BuiltInWindowFunctionExpr, Distribution, ExecutionPlan, + Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, + WindowAccumulator, WindowExpr, +}; +use crate::scalar::ScalarValue; +use arrow::compute::concat; +use arrow::{ + array::ArrayRef, + datatypes::{Field, Schema, SchemaRef}, + error::{ArrowError, Result as ArrowResult}, + record_batch::RecordBatch, }; -use arrow::datatypes::{Field, Schema, SchemaRef}; use async_trait::async_trait; +use futures::stream::{Stream, StreamExt}; +use futures::Future; +use pin_project_lite::pin_project; use std::any::Any; +use std::convert::TryInto; +use std::iter; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; /// Window execution plan #[derive(Debug)] @@ -57,18 +76,83 @@ pub fn create_window_expr( name, )?, })), - WindowFunction::BuiltInWindowFunction(fun) => { - Err(DataFusionError::NotImplemented(format!( - "window function with {:?} not implemented", - fun - ))) + WindowFunction::BuiltInWindowFunction(fun) => Ok(Arc::new(BuiltInWindowExpr { + window: create_built_in_window_expr(fun, args, input_schema, name)?, + })), + } +} + +fn create_built_in_window_expr( + fun: &BuiltInWindowFunction, + args: &[Arc], + input_schema: &Schema, + name: String, +) -> Result> { + match fun { + BuiltInWindowFunction::RowNumber => Ok(Arc::new(RowNumber::new(name))), + BuiltInWindowFunction::NthValue => { + let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?; + let arg = coerced_args[0].clone(); + let n = coerced_args[1] + .as_any() + .downcast_ref::() + .unwrap() + .value(); + let n: i64 = n + .clone() + .try_into() + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; + let n: u32 = n as u32; + let data_type = args[0].data_type(input_schema)?; + Ok(Arc::new(NthValue::try_new(arg, name, n, data_type)?)) + } + BuiltInWindowFunction::FirstValue => { + let arg = + coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone(); + let data_type = args[0].data_type(input_schema)?; + Ok(Arc::new(FirstValue::new(arg, name, data_type))) } + BuiltInWindowFunction::LastValue => { + let arg = + coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone(); + let data_type = args[0].data_type(input_schema)?; + Ok(Arc::new(LastValue::new(arg, name, data_type))) + } + _ => Err(DataFusionError::NotImplemented(format!( + "Window function with {:?} not yet implemented", + fun + ))), } } /// A window expr that takes the form of a built in window function #[derive(Debug)] -pub struct BuiltInWindowExpr {} +pub struct BuiltInWindowExpr { + window: Arc, +} + +impl WindowExpr for BuiltInWindowExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.window.name() + } + + fn field(&self) -> Result { + self.window.field() + } + + fn expressions(&self) -> Vec> { + self.window.expressions() + } + + fn create_accumulator(&self) -> Result> { + self.window.create_accumulator() + } +} /// A window expr that takes the form of an aggregate function #[derive(Debug)] @@ -76,6 +160,23 @@ pub struct AggregateWindowExpr { aggregate: Arc, } +#[derive(Debug)] +struct AggregateWindowAccumulator { + accumulator: Box, +} + +impl WindowAccumulator for AggregateWindowAccumulator { + fn scan(&mut self, values: &[ScalarValue]) -> Result> { + self.accumulator.update(values)?; + Ok(None) + } + + /// returns its value based on its current state. + fn evaluate(&self) -> Result> { + Ok(Some(self.accumulator.evaluate()?)) + } +} + impl WindowExpr for AggregateWindowExpr { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -89,6 +190,15 @@ impl WindowExpr for AggregateWindowExpr { fn field(&self) -> Result { self.aggregate.field() } + + fn expressions(&self) -> Vec> { + self.aggregate.expressions() + } + + fn create_accumulator(&self) -> Result> { + let accumulator = self.aggregate.create_accumulator()?; + Ok(Box::new(AggregateWindowAccumulator { accumulator })) + } } fn create_schema( @@ -120,12 +230,17 @@ impl WindowAggExec { }) } + /// Window expressions + pub fn window_expr(&self) -> &[Arc] { + &self.window_expr + } + /// Input plan pub fn input(&self) -> &Arc { &self.input } - /// Get the input schema before any aggregates are applied + /// Get the input schema before any window functions are applied pub fn input_schema(&self) -> SchemaRef { self.input_schema.clone() } @@ -163,7 +278,7 @@ impl ExecutionPlan for WindowAggExec { 1 => Ok(Arc::new(WindowAggExec::try_new( self.window_expr.clone(), children[0].clone(), - children[0].schema(), + self.input_schema.clone(), )?)), _ => Err(DataFusionError::Internal( "WindowAggExec wrong number of children".to_owned(), @@ -186,10 +301,252 @@ impl ExecutionPlan for WindowAggExec { )); } - // let input = self.input.execute(0).await?; + let input = self.input.execute(partition).await?; + + let stream = Box::pin(WindowAggStream::new( + self.schema.clone(), + self.window_expr.clone(), + input, + )); + Ok(stream) + } +} + +pin_project! { + /// stream for window aggregation plan + pub struct WindowAggStream { + schema: SchemaRef, + #[pin] + output: futures::channel::oneshot::Receiver>, + finished: bool, + } +} + +type WindowAccumulatorItem = Box; + +fn window_expressions( + window_expr: &[Arc], +) -> Result>>> { + Ok(window_expr + .iter() + .map(|expr| expr.expressions()) + .collect::>()) +} + +fn window_aggregate_batch( + batch: &RecordBatch, + window_accumulators: &mut [WindowAccumulatorItem], + expressions: &[Vec>], +) -> Result>>> { + // 1.1 iterate accumulators and respective expressions together + // 1.2 evaluate expressions + // 1.3 update / merge window accumulators with the expressions' values + + // 1.1 + window_accumulators + .iter_mut() + .zip(expressions) + .map(|(window_acc, expr)| { + // 1.2 + let values = &expr + .iter() + .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + + window_acc.scan_batch(batch.num_rows(), values) + }) + .into_iter() + .collect::>>() +} + +/// returns a vector of ArrayRefs, where each entry corresponds to either the +/// final value (mode = Final) or states (mode = Partial) +fn finalize_window_aggregation( + window_accumulators: &[WindowAccumulatorItem], +) -> Result>> { + window_accumulators + .iter() + .map(|window_accumulator| window_accumulator.evaluate()) + .collect::>>() +} + +fn create_window_accumulators( + window_expr: &[Arc], +) -> Result> { + window_expr + .iter() + .map(|expr| expr.create_accumulator()) + .collect::>>() +} + +async fn compute_window_aggregate( + schema: SchemaRef, + window_expr: Vec>, + mut input: SendableRecordBatchStream, +) -> ArrowResult { + let mut window_accumulators = create_window_accumulators(&window_expr) + .map_err(DataFusionError::into_arrow_external_error)?; + + let expressions = window_expressions(&window_expr) + .map_err(DataFusionError::into_arrow_external_error)?; + + let expressions = Arc::new(expressions); + + // TODO each element shall have some size hint + let mut accumulator: Vec> = + iter::repeat(vec![]).take(window_expr.len()).collect(); + + let mut original_batches: Vec = vec![]; + + let mut total_num_rows = 0; + + while let Some(batch) = input.next().await { + let batch = batch?; + total_num_rows += batch.num_rows(); + original_batches.push(batch.clone()); + + let batch_aggregated = + window_aggregate_batch(&batch, &mut window_accumulators, &expressions) + .map_err(DataFusionError::into_arrow_external_error)?; + accumulator.iter_mut().zip(batch_aggregated).for_each( + |(acc_for_window, window_batch)| { + if let Some(data) = window_batch { + acc_for_window.extend(data); + } + }, + ); + } + + let aggregated_mapped = finalize_window_aggregation(&window_accumulators) + .map_err(DataFusionError::into_arrow_external_error)?; + + let mut columns: Vec = accumulator + .iter() + .zip(aggregated_mapped) + .map(|(acc, agg)| match (acc, agg) { + // either accumulator values or the aggregated values are non-empty, but not both + (acc, Some(scalar_value)) if acc.is_empty() => { + Ok(scalar_value.to_array_of_size(total_num_rows)) + } + (acc, None) if !acc.is_empty() => ScalarValue::iter_to_array(acc), + _ => Err(DataFusionError::Execution( + "Invalid window function behavior".to_owned(), + )), + }) + .collect::>>() + .map_err(DataFusionError::into_arrow_external_error)?; + + for i in 0..(schema.fields().len() - window_expr.len()) { + let col = concat( + &original_batches + .iter() + .map(|batch| batch.column(i).as_ref()) + .collect::>(), + )?; + columns.push(col); + } + + RecordBatch::try_new(schema.clone(), columns) +} + +impl WindowAggStream { + /// Create a new WindowAggStream + pub fn new( + schema: SchemaRef, + window_expr: Vec>, + input: SendableRecordBatchStream, + ) -> Self { + let (tx, rx) = futures::channel::oneshot::channel(); + let schema_clone = schema.clone(); + tokio::spawn(async move { + let result = compute_window_aggregate(schema_clone, window_expr, input).await; + tx.send(result) + }); + + Self { + output: rx, + finished: false, + schema, + } + } +} + +impl Stream for WindowAggStream { + type Item = ArrowResult; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.finished { + return Poll::Ready(None); + } + + // is the output ready? + let this = self.project(); + let output_poll = this.output.poll(cx); - Err(DataFusionError::NotImplemented( - "WindowAggExec::execute".to_owned(), - )) + match output_poll { + Poll::Ready(result) => { + *this.finished = true; + // check for error in receiving channel and unwrap actual result + let result = match result { + Err(e) => Some(Err(ArrowError::ExternalError(Box::new(e)))), // error receiving + Ok(result) => Some(result), + }; + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } } } + +impl RecordBatchStream for WindowAggStream { + /// Get the schema + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +#[cfg(test)] +mod tests { + // use super::*; + + // /// some mock data to test windows + // fn some_data() -> (Arc, Vec) { + // // define a schema. + // let schema = Arc::new(Schema::new(vec![ + // Field::new("a", DataType::UInt32, false), + // Field::new("b", DataType::Float64, false), + // ])); + + // // define data. + // ( + // schema.clone(), + // vec![ + // RecordBatch::try_new( + // schema.clone(), + // vec![ + // Arc::new(UInt32Array::from(vec![2, 3, 4, 4])), + // Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), + // ], + // ) + // .unwrap(), + // RecordBatch::try_new( + // schema, + // vec![ + // Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), + // Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), + // ], + // ) + // .unwrap(), + // ], + // ) + // } + + // #[tokio::test] + // async fn window_function() -> Result<()> { + // let input: Arc = unimplemented!(); + // let input_schema = input.schema(); + // let window_expr = vec![]; + // WindowAggExec::try_new(window_expr, input, input_schema); + // } +} diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index e68c53b251e6..967f3866a515 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -797,20 +797,43 @@ async fn csv_query_count() -> Result<()> { Ok(()) } -// FIXME uncomment this when exec is done -// #[tokio::test] -// async fn csv_query_window_with_empty_over() -> Result<()> { -// let mut ctx = ExecutionContext::new(); -// register_aggregate_csv(&mut ctx)?; -// let sql = "SELECT count(c12) over () FROM aggregate_test_100"; -// // FIXME: so far the WindowAggExec is not implemented -// // and the current behavior is to throw not implemented exception - -// let result = execute(&mut ctx, sql).await; -// let expected: Vec> = vec![]; -// assert_eq!(result, expected); -// Ok(()) -// } +#[tokio::test] +async fn csv_query_window_with_empty_over() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx)?; + let sql = "select \ + c2, \ + row_number() over (), + sum(c3) over (), \ + avg(c3) over (), \ + count(c3) over (), \ + max(c3) over (), \ + min(c3) over (), \ + first_value(c3) over (), \ + last_value(c3) over (), \ + nth_value(c3, 2) over () + from aggregate_test_100 limit 5"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![ + vec![ + "2", "1", "781", "7.81", "100", "125", "-117", "1", "30", "-40", + ], + vec![ + "5", "2", "781", "7.81", "100", "125", "-117", "1", "30", "-40", + ], + vec![ + "1", "3", "781", "7.81", "100", "125", "-117", "1", "30", "-40", + ], + vec![ + "1", "4", "781", "7.81", "100", "125", "-117", "1", "30", "-40", + ], + vec![ + "5", "5", "781", "7.81", "100", "125", "-117", "1", "30", "-40", + ], + ]; + assert_eq!(expected, actual); + Ok(()) +} #[tokio::test] async fn csv_query_group_by_int_count() -> Result<()> {