From 18a6ffa4c59f46c3f70c37c07a853288e3ffd2bc Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Mon, 31 May 2021 09:51:15 +0800 Subject: [PATCH] optimize nth_value --- .../src/physical_plan/expressions/mod.rs | 2 +- .../physical_plan/expressions/nth_value.rs | 229 ++++++++++-------- datafusion/src/physical_plan/windows.rs | 8 +- 3 files changed, 128 insertions(+), 111 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 77da95c3a04a..d18365c47ed5 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -59,7 +59,7 @@ pub use literal::{lit, Literal}; pub use min_max::{Max, Min}; pub use negative::{negative, NegativeExpr}; pub use not::{not, NotExpr}; -pub use nth_value::{FirstValue, LastValue, NthValue}; +pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use row_number::RowNumber; pub use sum::{sum_return_type, Sum}; diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs index e90ad322aae9..ecf1a4c000b9 100644 --- a/datafusion/src/physical_plan/expressions/nth_value.rs +++ b/datafusion/src/physical_plan/expressions/nth_value.rs @@ -27,128 +27,69 @@ use std::any::Any; use std::convert::TryFrom; use std::sync::Arc; -/// first_value expression +/// nth_value kind +#[derive(Debug, Copy, Clone)] +enum NthValueKind { + First, + Last, + Nth(u32), +} + +/// nth_value expression #[derive(Debug)] -pub struct FirstValue { +pub struct NthValue { name: String, - data_type: DataType, expr: Arc, + data_type: DataType, + kind: NthValueKind, } -impl FirstValue { +impl NthValue { /// Create a new FIRST_VALUE window aggregate function - pub fn new(expr: Arc, name: String, data_type: DataType) -> Self { + pub fn first_value( + name: String, + expr: Arc, + data_type: DataType, + ) -> Self { Self { name, - data_type, expr, + data_type, + kind: NthValueKind::First, } } -} - -impl BuiltInWindowFunctionExpr for FirstValue { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = true; - Ok(Field::new(&self.name, self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - fn name(&self) -> &str { - &self.name - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(NthValueAccumulator::try_new( - 1, - self.data_type.clone(), - )?)) - } -} - -// sql values start with 1, so we can use 0 to indicate the special last value behavior -const SPECIAL_SIZE_VALUE_FOR_LAST: u32 = 0; - -/// last_value expression -#[derive(Debug)] -pub struct LastValue { - name: String, - data_type: DataType, - expr: Arc, -} - -impl LastValue { - /// Create a new FIRST_VALUE window aggregate function - pub fn new(expr: Arc, name: String, data_type: DataType) -> Self { + /// Create a new LAST_VALUE window aggregate function + pub fn last_value( + name: String, + expr: Arc, + data_type: DataType, + ) -> Self { Self { name, - data_type, expr, + data_type, + kind: NthValueKind::Last, } } -} -impl BuiltInWindowFunctionExpr for LastValue { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = true; - Ok(Field::new(&self.name, self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(NthValueAccumulator::try_new( - SPECIAL_SIZE_VALUE_FOR_LAST, - self.data_type.clone(), - )?)) - } -} - -/// nth_value expression -#[derive(Debug)] -pub struct NthValue { - name: String, - n: u32, - data_type: DataType, - expr: Arc, -} - -impl NthValue { /// Create a new NTH_VALUE window aggregate function - pub fn try_new( - expr: Arc, + pub fn nth_value( name: String, - n: u32, + expr: Arc, data_type: DataType, + n: u32, ) -> Result { - if n == SPECIAL_SIZE_VALUE_FOR_LAST { + if n == 0 { Err(DataFusionError::Execution( "nth_value expect n to be > 0".to_owned(), )) } else { Ok(Self { name, - n, - data_type, expr, + data_type, + kind: NthValueKind::Nth(n), }) } } @@ -175,7 +116,7 @@ impl BuiltInWindowFunctionExpr for NthValue { fn create_accumulator(&self) -> Result> { Ok(Box::new(NthValueAccumulator::try_new( - self.n, + self.kind, self.data_type.clone(), )?)) } @@ -183,19 +124,16 @@ impl BuiltInWindowFunctionExpr for NthValue { #[derive(Debug)] struct NthValueAccumulator { - // n the target nth_value, however we'll reuse it for last_value acc, so when n == 0 it specifically - // means last; also note that it is totally valid for n to be larger than the number of rows input - // in which case all the values shall be null - n: u32, + kind: NthValueKind, offset: u32, value: ScalarValue, } impl NthValueAccumulator { /// new count accumulator - pub fn try_new(n: u32, data_type: DataType) -> Result { + pub fn try_new(kind: NthValueKind, data_type: DataType) -> Result { Ok(Self { - n, + kind, offset: 0, // null value of that data_type by default value: ScalarValue::try_from(&data_type)?, @@ -205,15 +143,20 @@ impl NthValueAccumulator { impl WindowAccumulator for NthValueAccumulator { fn scan(&mut self, values: &[ScalarValue]) -> Result> { - if self.n == SPECIAL_SIZE_VALUE_FOR_LAST { - // for last_value function - self.value = values[0].clone(); - } else if self.offset < self.n { - self.offset += 1; - if self.offset == self.n { + self.offset += 1; + match self.kind { + NthValueKind::Last => { + self.value = values[0].clone(); + } + NthValueKind::First if self.offset == 1 => { self.value = values[0].clone(); } + NthValueKind::Nth(n) if self.offset == n => { + self.value = values[0].clone(); + } + _ => {} } + Ok(None) } @@ -221,3 +164,77 @@ impl WindowAccumulator for NthValueAccumulator { Ok(Some(self.value.clone())) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use crate::physical_plan::expressions::col; + use arrow::record_batch::RecordBatch; + use arrow::{array::*, datatypes::*}; + + fn test_i32_result(expr: Arc, expected: i32) -> Result<()> { + let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8])); + let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?; + + let mut acc = expr.create_accumulator()?; + let expr = expr.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(&batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + let result = acc.scan_batch(batch.num_rows(), &values)?; + assert_eq!(false, result.is_some()); + let result = acc.evaluate()?; + assert_eq!(Some(ScalarValue::Int32(Some(expected))), result); + Ok(()) + } + + #[test] + fn first_value() -> Result<()> { + let first_value = Arc::new(NthValue::first_value( + "first_value".to_owned(), + col("arr"), + DataType::Int32, + )); + test_i32_result(first_value, 1)?; + Ok(()) + } + + #[test] + fn last_value() -> Result<()> { + let last_value = Arc::new(NthValue::last_value( + "last_value".to_owned(), + col("arr"), + DataType::Int32, + )); + test_i32_result(last_value, 8)?; + Ok(()) + } + + #[test] + fn nth_value_1() -> Result<()> { + let nth_value = Arc::new(NthValue::nth_value( + "nth_value".to_owned(), + col("arr"), + DataType::Int32, + 1, + )?); + test_i32_result(nth_value, 1)?; + Ok(()) + } + + #[test] + fn nth_value_2() -> Result<()> { + let nth_value = Arc::new(NthValue::nth_value( + "nth_value".to_owned(), + col("arr"), + DataType::Int32, + 1, + )?); + test_i32_result(nth_value, -2)?; + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index e790eeaca749..659d2183819d 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -20,7 +20,7 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ aggregates, - expressions::{FirstValue, LastValue, Literal, NthValue, RowNumber}, + expressions::{Literal, NthValue, RowNumber}, type_coercion::coerce, window_functions::signature_for_built_in, window_functions::BuiltInWindowFunctionExpr, @@ -105,19 +105,19 @@ fn create_built_in_window_expr( .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; let n: u32 = n as u32; let data_type = args[0].data_type(input_schema)?; - Ok(Arc::new(NthValue::try_new(arg, name, n, data_type)?)) + Ok(Arc::new(NthValue::nth_value(name, arg, data_type, n)?)) } BuiltInWindowFunction::FirstValue => { let arg = coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone(); let data_type = args[0].data_type(input_schema)?; - Ok(Arc::new(FirstValue::new(arg, name, data_type))) + Ok(Arc::new(NthValue::first_value(name, arg, data_type))) } BuiltInWindowFunction::LastValue => { let arg = coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone(); let data_type = args[0].data_type(input_schema)?; - Ok(Arc::new(LastValue::new(arg, name, data_type))) + Ok(Arc::new(NthValue::last_value(name, arg, data_type))) } _ => Err(DataFusionError::NotImplemented(format!( "Window function with {:?} not yet implemented",