diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 77da95c3a04a3..d18365c47ed5e 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -59,7 +59,7 @@ pub use literal::{lit, Literal}; pub use min_max::{Max, Min}; pub use negative::{negative, NegativeExpr}; pub use not::{not, NotExpr}; -pub use nth_value::{FirstValue, LastValue, NthValue}; +pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use row_number::RowNumber; pub use sum::{sum_return_type, Sum}; diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs index e90ad322aae9d..4fd18244ca2a9 100644 --- a/datafusion/src/physical_plan/expressions/nth_value.rs +++ b/datafusion/src/physical_plan/expressions/nth_value.rs @@ -27,128 +27,69 @@ use std::any::Any; use std::convert::TryFrom; use std::sync::Arc; -/// first_value expression +/// nth_value kind +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum NthValueKind { + First, + Last, + Nth(u32), +} + +/// nth_value expression #[derive(Debug)] -pub struct FirstValue { +pub struct NthValue { name: String, - data_type: DataType, expr: Arc, + data_type: DataType, + kind: NthValueKind, } -impl FirstValue { +impl NthValue { /// Create a new FIRST_VALUE window aggregate function - pub fn new(expr: Arc, name: String, data_type: DataType) -> Self { + pub fn first_value( + name: String, + expr: Arc, + data_type: DataType, + ) -> Self { Self { name, - data_type, expr, + data_type, + kind: NthValueKind::First, } } -} -impl BuiltInWindowFunctionExpr for FirstValue { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = true; - Ok(Field::new(&self.name, self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(NthValueAccumulator::try_new( - 1, - self.data_type.clone(), - )?)) - } -} - -// sql values start with 1, so we can use 0 to indicate the special last value behavior -const SPECIAL_SIZE_VALUE_FOR_LAST: u32 = 0; - -/// last_value expression -#[derive(Debug)] -pub struct LastValue { - name: String, - data_type: DataType, - expr: Arc, -} - -impl LastValue { - /// Create a new FIRST_VALUE window aggregate function - pub fn new(expr: Arc, name: String, data_type: DataType) -> Self { + /// Create a new LAST_VALUE window aggregate function + pub fn last_value( + name: String, + expr: Arc, + data_type: DataType, + ) -> Self { Self { name, - data_type, expr, + data_type, + kind: NthValueKind::Last, } } -} - -impl BuiltInWindowFunctionExpr for LastValue { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = true; - Ok(Field::new(&self.name, self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(NthValueAccumulator::try_new( - SPECIAL_SIZE_VALUE_FOR_LAST, - self.data_type.clone(), - )?)) - } -} - -/// nth_value expression -#[derive(Debug)] -pub struct NthValue { - name: String, - n: u32, - data_type: DataType, - expr: Arc, -} -impl NthValue { /// Create a new NTH_VALUE window aggregate function - pub fn try_new( - expr: Arc, + pub fn nth_value( name: String, - n: u32, + expr: Arc, data_type: DataType, + n: u32, ) -> Result { - if n == SPECIAL_SIZE_VALUE_FOR_LAST { + if n == 0 { Err(DataFusionError::Execution( "nth_value expect n to be > 0".to_owned(), )) } else { Ok(Self { name, - n, - data_type, expr, + data_type, + kind: NthValueKind::Nth(n), }) } } @@ -175,7 +116,7 @@ impl BuiltInWindowFunctionExpr for NthValue { fn create_accumulator(&self) -> Result> { Ok(Box::new(NthValueAccumulator::try_new( - self.n, + self.kind, self.data_type.clone(), )?)) } @@ -183,19 +124,16 @@ impl BuiltInWindowFunctionExpr for NthValue { #[derive(Debug)] struct NthValueAccumulator { - // n the target nth_value, however we'll reuse it for last_value acc, so when n == 0 it specifically - // means last; also note that it is totally valid for n to be larger than the number of rows input - // in which case all the values shall be null - n: u32, + kind: NthValueKind, offset: u32, value: ScalarValue, } impl NthValueAccumulator { /// new count accumulator - pub fn try_new(n: u32, data_type: DataType) -> Result { + pub fn try_new(kind: NthValueKind, data_type: DataType) -> Result { Ok(Self { - n, + kind, offset: 0, // null value of that data_type by default value: ScalarValue::try_from(&data_type)?, @@ -205,15 +143,20 @@ impl NthValueAccumulator { impl WindowAccumulator for NthValueAccumulator { fn scan(&mut self, values: &[ScalarValue]) -> Result> { - if self.n == SPECIAL_SIZE_VALUE_FOR_LAST { - // for last_value function - self.value = values[0].clone(); - } else if self.offset < self.n { - self.offset += 1; - if self.offset == self.n { + self.offset += 1; + match self.kind { + NthValueKind::Last => { + self.value = values[0].clone(); + } + NthValueKind::First if self.offset == 1 => { self.value = values[0].clone(); } + NthValueKind::Nth(n) if self.offset == n => { + self.value = values[0].clone(); + } + _ => {} } + Ok(None) } diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index e790eeaca749e..659d2183819d3 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -20,7 +20,7 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ aggregates, - expressions::{FirstValue, LastValue, Literal, NthValue, RowNumber}, + expressions::{Literal, NthValue, RowNumber}, type_coercion::coerce, window_functions::signature_for_built_in, window_functions::BuiltInWindowFunctionExpr, @@ -105,19 +105,19 @@ fn create_built_in_window_expr( .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; let n: u32 = n as u32; let data_type = args[0].data_type(input_schema)?; - Ok(Arc::new(NthValue::try_new(arg, name, n, data_type)?)) + Ok(Arc::new(NthValue::nth_value(name, arg, data_type, n)?)) } BuiltInWindowFunction::FirstValue => { let arg = coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone(); let data_type = args[0].data_type(input_schema)?; - Ok(Arc::new(FirstValue::new(arg, name, data_type))) + Ok(Arc::new(NthValue::first_value(name, arg, data_type))) } BuiltInWindowFunction::LastValue => { let arg = coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone(); let data_type = args[0].data_type(input_schema)?; - Ok(Arc::new(LastValue::new(arg, name, data_type))) + Ok(Arc::new(NthValue::last_value(name, arg, data_type))) } _ => Err(DataFusionError::NotImplemented(format!( "Window function with {:?} not yet implemented",