Skip to content

Commit

Permalink
optimize nth_value
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiayu Liu committed May 31, 2021
1 parent c8ab5a4 commit e7a3964
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 111 deletions.
2 changes: 1 addition & 1 deletion datafusion/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub use literal::{lit, Literal};
pub use min_max::{Max, Min};
pub use negative::{negative, NegativeExpr};
pub use not::{not, NotExpr};
pub use nth_value::{FirstValue, LastValue, NthValue};
pub use nth_value::NthValue;
pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES};
pub use row_number::RowNumber;
pub use sum::{sum_return_type, Sum};
Expand Down
155 changes: 49 additions & 106 deletions datafusion/src/physical_plan/expressions/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,128 +27,69 @@ use std::any::Any;
use std::convert::TryFrom;
use std::sync::Arc;

/// first_value expression
/// Selects which scanned value an `NthValue` window expression keeps.
///
/// The accumulator counts every call to `scan` and uses this kind to decide
/// when the incoming value should be stored as the result.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum NthValueKind {
/// Keep the value seen on the first scan only.
First,
/// Overwrite the stored value on every scan, so the latest one wins.
Last,
/// Keep the value seen on the n-th scan (1-based; `NthValue::nth_value`
/// rejects `n == 0`).
Nth(u32),
}

/// nth_value expression
#[derive(Debug)]
pub struct FirstValue {
pub struct NthValue {
name: String,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
kind: NthValueKind,
}

impl FirstValue {
impl NthValue {
/// Create a new FIRST_VALUE window aggregate function
pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType) -> Self {
pub fn first_value(
name: String,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
) -> Self {
Self {
name,
data_type,
expr,
data_type,
kind: NthValueKind::First,
}
}
}

impl BuiltInWindowFunctionExpr for FirstValue {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}

fn field(&self) -> Result<Field> {
let nullable = true;
Ok(Field::new(&self.name, self.data_type.clone(), nullable))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}

fn name(&self) -> &str {
&self.name
}

fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
Ok(Box::new(NthValueAccumulator::try_new(
1,
self.data_type.clone(),
)?))
}
}

// sql values start with 1, so we can use 0 to indicate the special last value behavior
const SPECIAL_SIZE_VALUE_FOR_LAST: u32 = 0;

/// last_value expression
#[derive(Debug)]
pub struct LastValue {
name: String,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
}

impl LastValue {
/// Create a new FIRST_VALUE window aggregate function
pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType) -> Self {
/// Create a new LAST_VALUE window aggregate function
pub fn last_value(
name: String,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
) -> Self {
Self {
name,
data_type,
expr,
data_type,
kind: NthValueKind::Last,
}
}
}

impl BuiltInWindowFunctionExpr for LastValue {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}

fn field(&self) -> Result<Field> {
let nullable = true;
Ok(Field::new(&self.name, self.data_type.clone(), nullable))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}

fn name(&self) -> &str {
&self.name
}

fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
Ok(Box::new(NthValueAccumulator::try_new(
SPECIAL_SIZE_VALUE_FOR_LAST,
self.data_type.clone(),
)?))
}
}

/// nth_value expression
#[derive(Debug)]
pub struct NthValue {
name: String,
n: u32,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
}

impl NthValue {
/// Create a new NTH_VALUE window aggregate function
pub fn try_new(
expr: Arc<dyn PhysicalExpr>,
pub fn nth_value(
name: String,
n: u32,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
n: u32,
) -> Result<Self> {
if n == SPECIAL_SIZE_VALUE_FOR_LAST {
if n == 0 {
Err(DataFusionError::Execution(
"nth_value expect n to be > 0".to_owned(),
))
} else {
Ok(Self {
name,
n,
data_type,
expr,
data_type,
kind: NthValueKind::Nth(n),
})
}
}
Expand All @@ -175,27 +116,24 @@ impl BuiltInWindowFunctionExpr for NthValue {

fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
Ok(Box::new(NthValueAccumulator::try_new(
self.n,
self.kind,
self.data_type.clone(),
)?))
}
}

#[derive(Debug)]
struct NthValueAccumulator {
// n the target nth_value, however we'll reuse it for last_value acc, so when n == 0 it specifically
// means last; also note that it is totally valid for n to be larger than the number of rows input
// in which case all the values shall be null
n: u32,
kind: NthValueKind,
offset: u32,
value: ScalarValue,
}

impl NthValueAccumulator {
/// Create a new nth_value accumulator
pub fn try_new(n: u32, data_type: DataType) -> Result<Self> {
pub fn try_new(kind: NthValueKind, data_type: DataType) -> Result<Self> {
Ok(Self {
n,
kind,
offset: 0,
// null value of that data_type by default
value: ScalarValue::try_from(&data_type)?,
Expand All @@ -205,15 +143,20 @@ impl NthValueAccumulator {

impl WindowAccumulator for NthValueAccumulator {
fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
if self.n == SPECIAL_SIZE_VALUE_FOR_LAST {
// for last_value function
self.value = values[0].clone();
} else if self.offset < self.n {
self.offset += 1;
if self.offset == self.n {
self.offset += 1;
match self.kind {
NthValueKind::Last => {
self.value = values[0].clone();
}
NthValueKind::First if self.offset == 1 => {
self.value = values[0].clone();
}
NthValueKind::Nth(n) if self.offset == n => {
self.value = values[0].clone();
}
_ => {}
}

Ok(None)
}

Expand Down
8 changes: 4 additions & 4 deletions datafusion/src/physical_plan/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{
aggregates,
expressions::{FirstValue, LastValue, Literal, NthValue, RowNumber},
expressions::{Literal, NthValue, RowNumber},
type_coercion::coerce,
window_functions::signature_for_built_in,
window_functions::BuiltInWindowFunctionExpr,
Expand Down Expand Up @@ -105,19 +105,19 @@ fn create_built_in_window_expr(
.map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
let n: u32 = n as u32;
let data_type = args[0].data_type(input_schema)?;
Ok(Arc::new(NthValue::try_new(arg, name, n, data_type)?))
Ok(Arc::new(NthValue::nth_value(name, arg, data_type, n)?))
}
BuiltInWindowFunction::FirstValue => {
let arg =
coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
let data_type = args[0].data_type(input_schema)?;
Ok(Arc::new(FirstValue::new(arg, name, data_type)))
Ok(Arc::new(NthValue::first_value(name, arg, data_type)))
}
BuiltInWindowFunction::LastValue => {
let arg =
coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
let data_type = args[0].data_type(input_schema)?;
Ok(Arc::new(LastValue::new(arg, name, data_type)))
Ok(Arc::new(NthValue::last_value(name, arg, data_type)))
}
_ => Err(DataFusionError::NotImplemented(format!(
"Window function with {:?} not yet implemented",
Expand Down

0 comments on commit e7a3964

Please sign in to comment.