diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs index 4fd18244ca2a..411739a1c65e 100644 --- a/datafusion/src/physical_plan/expressions/nth_value.rs +++ b/datafusion/src/physical_plan/expressions/nth_value.rs @@ -164,3 +164,77 @@ impl WindowAccumulator for NthValueAccumulator { Ok(Some(self.value.clone())) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use crate::physical_plan::expressions::col; + use arrow::record_batch::RecordBatch; + use arrow::{array::*, datatypes::*}; + + fn test_i32_result(expr: Arc, expected: i32) -> Result<()> { + let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8])); + let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?; + + let mut acc = expr.create_accumulator()?; + let expr = expr.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(&batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + let result = acc.scan_batch(batch.num_rows(), &values)?; + assert_eq!(false, result.is_some()); + let result = acc.evaluate()?; + assert_eq!(Some(ScalarValue::Int32(Some(expected))), result); + Ok(()) + } + + #[test] + fn first_value() -> Result<()> { + let first_value = Arc::new(NthValue::first_value( + "first_value".to_owned(), + col("arr"), + DataType::Int32, + )); + test_i32_result(first_value, 1)?; + Ok(()) + } + + #[test] + fn last_value() -> Result<()> { + let last_value = Arc::new(NthValue::last_value( + "last_value".to_owned(), + col("arr"), + DataType::Int32, + )); + test_i32_result(last_value, 8)?; + Ok(()) + } + + #[test] + fn nth_value_1() -> Result<()> { + let nth_value = Arc::new(NthValue::nth_value( + "nth_value".to_owned(), + col("arr"), + DataType::Int32, + 1, + )?); + test_i32_result(nth_value, 1)?; + Ok(()) + } + + #[test] + fn nth_value_2() -> Result<()> { + let nth_value = Arc::new(NthValue::nth_value( + "nth_value".to_owned(), + col("arr"), + DataType::Int32, + 1, + )?); + test_i32_result(nth_value, -2)?; + Ok(()) + } +}