Skip to content

Commit

Permalink
optimize nth_value
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiayu Liu committed May 31, 2021
1 parent c8ab5a4 commit 18a6ffa
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 111 deletions.
2 changes: 1 addition & 1 deletion datafusion/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub use literal::{lit, Literal};
pub use min_max::{Max, Min};
pub use negative::{negative, NegativeExpr};
pub use not::{not, NotExpr};
pub use nth_value::{FirstValue, LastValue, NthValue};
pub use nth_value::NthValue;
pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES};
pub use row_number::RowNumber;
pub use sum::{sum_return_type, Sum};
Expand Down
229 changes: 123 additions & 106 deletions datafusion/src/physical_plan/expressions/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,128 +27,69 @@ use std::any::Any;
use std::convert::TryFrom;
use std::sync::Arc;

/// first_value expression
/// nth_value kind
#[derive(Debug, Copy, Clone)]
enum NthValueKind {
First,
Last,
Nth(u32),
}

/// nth_value expression
#[derive(Debug)]
pub struct FirstValue {
pub struct NthValue {
name: String,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
kind: NthValueKind,
}

impl FirstValue {
impl NthValue {
/// Create a new FIRST_VALUE window aggregate function
pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType) -> Self {
pub fn first_value(
name: String,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
) -> Self {
Self {
name,
data_type,
expr,
data_type,
kind: NthValueKind::First,
}
}
}

impl BuiltInWindowFunctionExpr for FirstValue {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}

fn field(&self) -> Result<Field> {
let nullable = true;
Ok(Field::new(&self.name, self.data_type.clone(), nullable))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}

fn name(&self) -> &str {
&self.name
}

fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
Ok(Box::new(NthValueAccumulator::try_new(
1,
self.data_type.clone(),
)?))
}
}

// sql values start with 1, so we can use 0 to indicate the special last value behavior
const SPECIAL_SIZE_VALUE_FOR_LAST: u32 = 0;

/// last_value expression
#[derive(Debug)]
pub struct LastValue {
name: String,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
}

impl LastValue {
/// Create a new FIRST_VALUE window aggregate function
pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType) -> Self {
/// Create a new LAST_VALUE window aggregate function
pub fn last_value(
name: String,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
) -> Self {
Self {
name,
data_type,
expr,
data_type,
kind: NthValueKind::Last,
}
}
}

impl BuiltInWindowFunctionExpr for LastValue {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}

fn field(&self) -> Result<Field> {
let nullable = true;
Ok(Field::new(&self.name, self.data_type.clone(), nullable))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}

fn name(&self) -> &str {
&self.name
}

fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
Ok(Box::new(NthValueAccumulator::try_new(
SPECIAL_SIZE_VALUE_FOR_LAST,
self.data_type.clone(),
)?))
}
}

/// nth_value expression
#[derive(Debug)]
pub struct NthValue {
name: String,
n: u32,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
}

impl NthValue {
/// Create a new NTH_VALUE window aggregate function
pub fn try_new(
expr: Arc<dyn PhysicalExpr>,
pub fn nth_value(
name: String,
n: u32,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
n: u32,
) -> Result<Self> {
if n == SPECIAL_SIZE_VALUE_FOR_LAST {
if n == 0 {
Err(DataFusionError::Execution(
"nth_value expect n to be > 0".to_owned(),
))
} else {
Ok(Self {
name,
n,
data_type,
expr,
data_type,
kind: NthValueKind::Nth(n),
})
}
}
Expand All @@ -175,27 +116,24 @@ impl BuiltInWindowFunctionExpr for NthValue {

fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
Ok(Box::new(NthValueAccumulator::try_new(
self.n,
self.kind,
self.data_type.clone(),
)?))
}
}

#[derive(Debug)]
struct NthValueAccumulator {
// n the target nth_value, however we'll reuse it for last_value acc, so when n == 0 it specifically
// means last; also note that it is totally valid for n to be larger than the number of rows input
// in which case all the values shall be null
n: u32,
kind: NthValueKind,
offset: u32,
value: ScalarValue,
}

impl NthValueAccumulator {
/// new count accumulator
pub fn try_new(n: u32, data_type: DataType) -> Result<Self> {
pub fn try_new(kind: NthValueKind, data_type: DataType) -> Result<Self> {
Ok(Self {
n,
kind,
offset: 0,
// null value of that data_type by default
value: ScalarValue::try_from(&data_type)?,
Expand All @@ -205,19 +143,98 @@ impl NthValueAccumulator {

impl WindowAccumulator for NthValueAccumulator {
fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
if self.n == SPECIAL_SIZE_VALUE_FOR_LAST {
// for last_value function
self.value = values[0].clone();
} else if self.offset < self.n {
self.offset += 1;
if self.offset == self.n {
self.offset += 1;
match self.kind {
NthValueKind::Last => {
self.value = values[0].clone();
}
NthValueKind::First if self.offset == 1 => {
self.value = values[0].clone();
}
NthValueKind::Nth(n) if self.offset == n => {
self.value = values[0].clone();
}
_ => {}
}

Ok(None)
}

fn evaluate(&self) -> Result<Option<ScalarValue>> {
Ok(Some(self.value.clone()))
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::error::Result;
use crate::physical_plan::expressions::col;
use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};

fn test_i32_result(expr: Arc<NthValue>, expected: i32) -> Result<()> {
let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;

let mut acc = expr.create_accumulator()?;
let expr = expr.expressions();
let values = expr
.iter()
.map(|e| e.evaluate(&batch))
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
.collect::<Result<Vec<_>>>()?;
let result = acc.scan_batch(batch.num_rows(), &values)?;
assert_eq!(false, result.is_some());
let result = acc.evaluate()?;
assert_eq!(Some(ScalarValue::Int32(Some(expected))), result);
Ok(())
}

#[test]
fn first_value() -> Result<()> {
let first_value = Arc::new(NthValue::first_value(
"first_value".to_owned(),
col("arr"),
DataType::Int32,
));
test_i32_result(first_value, 1)?;
Ok(())
}

#[test]
fn last_value() -> Result<()> {
let last_value = Arc::new(NthValue::last_value(
"last_value".to_owned(),
col("arr"),
DataType::Int32,
));
test_i32_result(last_value, 8)?;
Ok(())
}

#[test]
fn nth_value_1() -> Result<()> {
let nth_value = Arc::new(NthValue::nth_value(
"nth_value".to_owned(),
col("arr"),
DataType::Int32,
1,
)?);
test_i32_result(nth_value, 1)?;
Ok(())
}

#[test]
fn nth_value_2() -> Result<()> {
let nth_value = Arc::new(NthValue::nth_value(
"nth_value".to_owned(),
col("arr"),
DataType::Int32,
1,
)?);
test_i32_result(nth_value, -2)?;
Ok(())
}
}
8 changes: 4 additions & 4 deletions datafusion/src/physical_plan/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{
aggregates,
expressions::{FirstValue, LastValue, Literal, NthValue, RowNumber},
expressions::{Literal, NthValue, RowNumber},
type_coercion::coerce,
window_functions::signature_for_built_in,
window_functions::BuiltInWindowFunctionExpr,
Expand Down Expand Up @@ -105,19 +105,19 @@ fn create_built_in_window_expr(
.map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
let n: u32 = n as u32;
let data_type = args[0].data_type(input_schema)?;
Ok(Arc::new(NthValue::try_new(arg, name, n, data_type)?))
Ok(Arc::new(NthValue::nth_value(name, arg, data_type, n)?))
}
BuiltInWindowFunction::FirstValue => {
let arg =
coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
let data_type = args[0].data_type(input_schema)?;
Ok(Arc::new(FirstValue::new(arg, name, data_type)))
Ok(Arc::new(NthValue::first_value(name, arg, data_type)))
}
BuiltInWindowFunction::LastValue => {
let arg =
coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
let data_type = args[0].data_type(input_schema)?;
Ok(Arc::new(LastValue::new(arg, name, data_type)))
Ok(Arc::new(NthValue::last_value(name, arg, data_type)))
}
_ => Err(DataFusionError::NotImplemented(format!(
"Window function with {:?} not yet implemented",
Expand Down

0 comments on commit 18a6ffa

Please sign in to comment.