Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize nth_value, remove first_value, last_value structs and use idiomatic rust style #452

Merged
merged 2 commits into from
Jun 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub use literal::{lit, Literal};
pub use min_max::{Max, Min};
pub use negative::{negative, NegativeExpr};
pub use not::{not, NotExpr};
pub use nth_value::{FirstValue, LastValue, NthValue};
pub use nth_value::NthValue;
pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES};
pub use row_number::RowNumber;
pub use sum::{sum_return_type, Sum};
Expand Down
238 changes: 127 additions & 111 deletions datafusion/src/physical_plan/expressions/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,129 +27,69 @@ use std::any::Any;
use std::convert::TryFrom;
use std::sync::Arc;

/// first_value expression
/// Selects which scanned row an nth_value-style window function keeps.
///
/// The accumulator counts rows starting from 1 and compares that running
/// count against the variant to decide whether to capture the current value.
#[derive(Debug, Copy, Clone)]
enum NthValueKind {
/// Keep the value of the first scanned row (row count == 1).
First,
/// Keep the value of the most recently scanned row; overwritten on every scan.
Last,
/// Keep the value of the n-th scanned row; `n` is 1-based and the public
/// constructor rejects `n == 0`.
Nth(u32),
}

/// nth_value expression
#[derive(Debug)]
pub struct FirstValue {
pub struct NthValue {
name: String,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
kind: NthValueKind,
}

impl FirstValue {
impl NthValue {
/// Create a new FIRST_VALUE window aggregate function
pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType) -> Self {
pub fn first_value(
name: String,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
) -> Self {
Self {
name,
data_type,
expr,
data_type,
kind: NthValueKind::First,
}
}
}

impl BuiltInWindowFunctionExpr for FirstValue {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}

fn field(&self) -> Result<Field> {
let nullable = true;
Ok(Field::new(&self.name, self.data_type.clone(), nullable))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}

fn name(&self) -> &str {
&self.name
}

fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
Ok(Box::new(NthValueAccumulator::try_new(
1,
self.data_type.clone(),
)?))
}
}

// sql values start with 1, so we can use 0 to indicate the special last value behavior
const SPECIAL_SIZE_VALUE_FOR_LAST: u32 = 0;

/// last_value expression
#[derive(Debug)]
pub struct LastValue {
name: String,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
}

impl LastValue {
/// Create a new FIRST_VALUE window aggregate function
pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType) -> Self {
/// Create a new LAST_VALUE window aggregate function
pub fn last_value(
name: String,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
) -> Self {
Self {
name,
data_type,
expr,
data_type,
kind: NthValueKind::Last,
}
}
}

impl BuiltInWindowFunctionExpr for LastValue {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}

fn field(&self) -> Result<Field> {
let nullable = true;
Ok(Field::new(&self.name, self.data_type.clone(), nullable))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}

fn name(&self) -> &str {
&self.name
}

fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
Ok(Box::new(NthValueAccumulator::try_new(
SPECIAL_SIZE_VALUE_FOR_LAST,
self.data_type.clone(),
)?))
}
}

/// nth_value expression
#[derive(Debug)]
pub struct NthValue {
name: String,
n: u32,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
}

impl NthValue {
/// Create a new NTH_VALUE window aggregate function
pub fn try_new(
expr: Arc<dyn PhysicalExpr>,
pub fn nth_value(
name: String,
n: u32,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
n: u32,
) -> Result<Self> {
if n == SPECIAL_SIZE_VALUE_FOR_LAST {
Err(DataFusionError::Execution(
match n {
0 => Err(DataFusionError::Execution(
"nth_value expect n to be > 0".to_owned(),
))
} else {
Ok(Self {
)),
_ => Ok(Self {
name,
n,
data_type,
expr,
})
data_type,
kind: NthValueKind::Nth(n),
}),
}
}
}
Expand All @@ -175,27 +115,24 @@ impl BuiltInWindowFunctionExpr for NthValue {

fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
Ok(Box::new(NthValueAccumulator::try_new(
self.n,
self.kind,
self.data_type.clone(),
)?))
}
}

#[derive(Debug)]
struct NthValueAccumulator {
// n the target nth_value, however we'll reuse it for last_value acc, so when n == 0 it specifically
// means last; also note that it is totally valid for n to be larger than the number of rows input
// in which case all the values shall be null
n: u32,
kind: NthValueKind,
offset: u32,
value: ScalarValue,
}

impl NthValueAccumulator {
/// Create a new nth_value accumulator whose value starts as the null of `data_type`
pub fn try_new(n: u32, data_type: DataType) -> Result<Self> {
pub fn try_new(kind: NthValueKind, data_type: DataType) -> Result<Self> {
Ok(Self {
n,
kind,
offset: 0,
// null value of that data_type by default
value: ScalarValue::try_from(&data_type)?,
Expand All @@ -205,19 +142,98 @@ impl NthValueAccumulator {

impl WindowAccumulator for NthValueAccumulator {
fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
if self.n == SPECIAL_SIZE_VALUE_FOR_LAST {
// for last_value function
self.value = values[0].clone();
} else if self.offset < self.n {
self.offset += 1;
if self.offset == self.n {
self.offset += 1;
match self.kind {
NthValueKind::Last => {
self.value = values[0].clone();
}
NthValueKind::First if self.offset == 1 => {
self.value = values[0].clone();
}
NthValueKind::Nth(n) if self.offset == n => {
self.value = values[0].clone();
}
_ => {}
}

Ok(None)
}

/// Returns a clone of the value currently held by the accumulator, always
/// wrapped in `Some`. If no scanned row has been captured yet, this is the
/// value the accumulator was constructed with (the null of its data type).
fn evaluate(&self) -> Result<Option<ScalarValue>> {
Ok(Some(self.value.clone()))
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::error::Result;
use crate::physical_plan::expressions::col;
use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};

fn test_i32_result(expr: Arc<NthValue>, expected: i32) -> Result<()> {
let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, for tests like this you can also use `RecordBatch::try_from_iter` to avoid having to construct the `Schema` directly.

This way is great too; I just figured I would point it out for the future.


let mut acc = expr.create_accumulator()?;
let expr = expr.expressions();
let values = expr
.iter()
.map(|e| e.evaluate(&batch))
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
.collect::<Result<Vec<_>>>()?;
let result = acc.scan_batch(batch.num_rows(), &values)?;
assert_eq!(false, result.is_some());
let result = acc.evaluate()?;
assert_eq!(Some(ScalarValue::Int32(Some(expected))), result);
Ok(())
}

#[test]
fn first_value() -> Result<()> {
    // FIRST_VALUE over the shared `test_i32_result` input must yield 1.
    let expr =
        NthValue::first_value(String::from("first_value"), col("arr"), DataType::Int32);
    test_i32_result(Arc::new(expr), 1)
}

#[test]
fn last_value() -> Result<()> {
    // LAST_VALUE over the shared `test_i32_result` input must yield 8.
    let expr =
        NthValue::last_value(String::from("last_value"), col("arr"), DataType::Int32);
    test_i32_result(Arc::new(expr), 8)
}

#[test]
fn nth_value_1() -> Result<()> {
    // NTH_VALUE with n = 1 must agree with FIRST_VALUE on the shared input: 1.
    let expr =
        NthValue::nth_value(String::from("nth_value"), col("arr"), DataType::Int32, 1)?;
    test_i32_result(Arc::new(expr), 1)
}

#[test]
fn nth_value_2() -> Result<()> {
    // NTH_VALUE with n = 2 must pick the second row of the shared input: -2.
    let expr =
        NthValue::nth_value(String::from("nth_value"), col("arr"), DataType::Int32, 2)?;
    test_i32_result(Arc::new(expr), -2)
}
}
8 changes: 4 additions & 4 deletions datafusion/src/physical_plan/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{
aggregates,
expressions::{FirstValue, LastValue, Literal, NthValue, RowNumber},
expressions::{Literal, NthValue, RowNumber},
type_coercion::coerce,
window_functions::signature_for_built_in,
window_functions::BuiltInWindowFunctionExpr,
Expand Down Expand Up @@ -105,19 +105,19 @@ fn create_built_in_window_expr(
.map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
let n: u32 = n as u32;
let data_type = args[0].data_type(input_schema)?;
Ok(Arc::new(NthValue::try_new(arg, name, n, data_type)?))
Ok(Arc::new(NthValue::nth_value(name, arg, data_type, n)?))
}
BuiltInWindowFunction::FirstValue => {
let arg =
coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
let data_type = args[0].data_type(input_schema)?;
Ok(Arc::new(FirstValue::new(arg, name, data_type)))
Ok(Arc::new(NthValue::first_value(name, arg, data_type)))
}
BuiltInWindowFunction::LastValue => {
let arg =
coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
let data_type = args[0].data_type(input_schema)?;
Ok(Arc::new(LastValue::new(arg, name, data_type)))
Ok(Arc::new(NthValue::last_value(name, arg, data_type)))
}
_ => Err(DataFusionError::NotImplemented(format!(
"Window function with {:?} not yet implemented",
Expand Down