Lead/lag window function with offset and default value arguments (#687)
jgoday authored Jul 14, 2021
1 parent fd50dd8 commit 002ca5d
Showing 6 changed files with 221 additions and 18 deletions.
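
For context, the new optional arguments follow the usual SQL form LEAD(expr, offset, default) and LAG(expr, offset, default): the offset defaults to 1 and the default value to NULL when omitted. A minimal sketch of the SQL surface this commit enables (table and column names here are hypothetical; the integration test added below exercises the same argument shapes against the `test` table):

-- Hypothetical table t(v); only the argument shapes matter here.
SELECT
  v,
  LEAD(v) OVER ()       AS next_v,          -- offset 1, NULL where there is no following row
  LEAD(v, 2, 0) OVER () AS next_2_v_or_0,   -- explicit offset and default value
  LAG(v, 1, -1) OVER () AS prev_v_or_neg1   -- default value fills rows with no preceding row
FROM t
ORDER BY v;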
94 changes: 90 additions & 4 deletions datafusion/src/physical_plan/expressions/lead_lag.rs
@@ -21,11 +21,13 @@
use crate::error::{DataFusionError, Result};
use crate::physical_plan::window_functions::PartitionEvaluator;
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
use arrow::array::ArrayRef;
use arrow::compute::kernels::window::shift;
use arrow::compute::cast;
use arrow::datatypes::{DataType, Field};
use arrow::record_batch::RecordBatch;
use std::any::Any;
use std::ops::Neg;
use std::ops::Range;
use std::sync::Arc;

@@ -36,19 +38,23 @@ pub struct WindowShift {
data_type: DataType,
shift_offset: i64,
expr: Arc<dyn PhysicalExpr>,
default_value: Option<ScalarValue>,
}

/// lead() window function
pub fn lead(
name: String,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
shift_offset: Option<i64>,
default_value: Option<ScalarValue>,
) -> WindowShift {
WindowShift {
name,
data_type,
shift_offset: -1,
shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1),
expr,
default_value,
}
}

@@ -57,12 +63,15 @@ pub fn lag(
name: String,
data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
shift_offset: Option<i64>,
default_value: Option<ScalarValue>,
) -> WindowShift {
WindowShift {
name,
data_type,
shift_offset: 1,
shift_offset: shift_offset.unwrap_or(1),
expr,
default_value,
}
}

@@ -98,20 +107,71 @@ impl BuiltInWindowFunctionExpr for WindowShift {
Ok(Box::new(WindowShiftEvaluator {
shift_offset: self.shift_offset,
values,
default_value: self.default_value.clone(),
}))
}
}

pub(crate) struct WindowShiftEvaluator {
shift_offset: i64,
values: Vec<ArrayRef>,
default_value: Option<ScalarValue>,
}

fn create_empty_array(
value: &Option<ScalarValue>,
data_type: &DataType,
size: usize,
) -> Result<ArrayRef> {
use arrow::array::new_null_array;
let array = value
.as_ref()
.map(|scalar| scalar.to_array_of_size(size))
.unwrap_or_else(|| new_null_array(data_type, size));
if array.data_type() != data_type {
cast(&array, data_type).map_err(DataFusionError::ArrowError)
} else {
Ok(array)
}
}

// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
fn shift_with_default_value(
array: &ArrayRef,
offset: i64,
value: &Option<ScalarValue>,
) -> Result<ArrayRef> {
use arrow::compute::concat;

let value_len = array.len() as i64;
if offset == 0 {
Ok(arrow::array::make_array(array.data_ref().clone()))
} else if offset == i64::MIN || offset.abs() >= value_len {
create_empty_array(value, array.data_type(), array.len())
} else {
let slice_offset = (-offset).clamp(0, value_len) as usize;
let length = array.len() - offset.abs() as usize;
let slice = array.slice(slice_offset, length);

// Generate array with remaining `null` items
let nulls = offset.abs() as usize;
let default_values = create_empty_array(value, slice.data_type(), nulls)?;
// Concatenate both arrays, add nulls after if shift > 0 else before
if offset > 0 {
concat(&[default_values.as_ref(), slice.as_ref()])
.map_err(DataFusionError::ArrowError)
} else {
concat(&[slice.as_ref(), default_values.as_ref()])
.map_err(DataFusionError::ArrowError)
}
}
}

impl PartitionEvaluator for WindowShiftEvaluator {
fn evaluate_partition(&self, partition: Range<usize>) -> Result<ArrayRef> {
let value = &self.values[0];
let value = value.slice(partition.start, partition.end - partition.start);
shift(value.as_ref(), self.shift_offset).map_err(DataFusionError::ArrowError)
shift_with_default_value(&value, self.shift_offset, &self.default_value)
}
}

@@ -142,6 +202,8 @@ mod tests {
"lead".to_owned(),
DataType::Float32,
Arc::new(Column::new("c3", 0)),
None,
None,
),
vec![
Some(-2),
@@ -162,6 +224,8 @@
"lead".to_owned(),
DataType::Float32,
Arc::new(Column::new("c3", 0)),
None,
None,
),
vec![
None,
@@ -176,6 +240,28 @@
.iter()
.collect::<Int32Array>(),
)?;

test_i32_result(
lag(
"lead".to_owned(),
DataType::Int32,
Arc::new(Column::new("c3", 0)),
None,
Some(ScalarValue::Int32(Some(100))),
),
vec![
Some(100),
Some(1),
Some(-2),
Some(3),
Some(-4),
Some(5),
Some(-6),
Some(7),
]
.iter()
.collect::<Int32Array>(),
)?;
Ok(())
}
}
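
Worth noting about shift_with_default_value above: when the offset is zero the input column is returned unchanged, and when the absolute offset is at least the partition length the whole output is built from the default value (or from nulls if no default was supplied) rather than raising an error. A hedged SQL illustration, mirroring the out-of-bounds columns in the integration test added by this commit (it assumes the offset 100 is at least the number of rows in the partition):

SELECT
  c8,
  LEAD(c8, 100, 10) OVER () AS next_out_of_bounds_c8,  -- every row filled with the default 10
  LEAD(c8, 100)     OVER () AS next_no_default_c8      -- every row NULL: no default supplied
FROM test
ORDER BY c8;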
35 changes: 28 additions & 7 deletions datafusion/src/physical_plan/type_coercion.rs
@@ -128,13 +128,11 @@ fn get_valid_types(
}
vec![(0..*number).map(|i| current_types[i].clone()).collect()]
}
Signature::OneOf(types) => {
let mut r = vec![];
for s in types {
r.extend(get_valid_types(s, current_types)?);
}
r
}
Signature::OneOf(types) => types
.iter()
.filter_map(|t| get_valid_types(t, current_types).ok())
.flatten()
.collect::<Vec<_>>(),
};

Ok(valid_types)
@@ -367,4 +365,27 @@ mod tests {

Ok(())
}

#[test]
fn test_get_valid_types_one_of() -> Result<()> {
let signature = Signature::OneOf(vec![Signature::Any(1), Signature::Any(2)]);

let invalid_types = get_valid_types(
&signature,
&[DataType::Int32, DataType::Int32, DataType::Int32],
)?;
assert_eq!(invalid_types.len(), 0);

let args = vec![DataType::Int32, DataType::Int32];
let valid_types = get_valid_types(&signature, &args)?;
assert_eq!(valid_types.len(), 1);
assert_eq!(valid_types[0], args);

let args = vec![DataType::Int32];
let valid_types = get_valid_types(&signature, &args)?;
assert_eq!(valid_types.len(), 1);
assert_eq!(valid_types[0], args);

Ok(())
}
}
14 changes: 10 additions & 4 deletions datafusion/src/physical_plan/window_functions.rs
@@ -201,10 +201,16 @@ pub(super) fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature {
| BuiltInWindowFunction::DenseRank
| BuiltInWindowFunction::PercentRank
| BuiltInWindowFunction::CumeDist => Signature::Any(0),
BuiltInWindowFunction::Lag
| BuiltInWindowFunction::Lead
| BuiltInWindowFunction::FirstValue
| BuiltInWindowFunction::LastValue => Signature::Any(1),
BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => {
Signature::OneOf(vec![
Signature::Any(1),
Signature::Any(2),
Signature::Any(3),
])
}
BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => {
Signature::Any(1)
}
BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]),
BuiltInWindowFunction::NthValue => Signature::Any(2),
}
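
With the signature change above, LAG and LEAD now accept one, two, or three arguments. A quick sketch of the three arities (column name hypothetical):

SELECT
  LAG(v) OVER ()       AS prev_v,         -- 1 argument: offset 1, NULL default
  LAG(v, 3) OVER ()    AS prev_3_v,       -- 2 arguments: explicit offset
  LAG(v, 3, 0) OVER () AS prev_3_v_or_0   -- 3 arguments: explicit offset and default value
FROM t;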
67 changes: 65 additions & 2 deletions datafusion/src/physical_plan/windows.rs
@@ -32,6 +32,7 @@ use crate::physical_plan::{
Accumulator, AggregateExpr, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
RecordBatchStream, SendableRecordBatchStream, WindowExpr,
};
use crate::scalar::ScalarValue;
use arrow::compute::concat;
use arrow::{
array::ArrayRef,
@@ -96,6 +97,19 @@ pub fn create_window_expr(
})
}

fn get_scalar_value_from_args(
args: &[Arc<dyn PhysicalExpr>],
index: usize,
) -> Option<ScalarValue> {
args.get(index).map(|v| {
v.as_any()
.downcast_ref::<Literal>()
.unwrap()
.value()
.clone()
})
}

fn create_built_in_window_expr(
fun: &BuiltInWindowFunction,
args: &[Arc<dyn PhysicalExpr>],
@@ -110,13 +124,21 @@ fn create_built_in_window_expr(
let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
let arg = coerced_args[0].clone();
let data_type = args[0].data_type(input_schema)?;
Arc::new(lag(name, data_type, arg))
let shift_offset = get_scalar_value_from_args(&coerced_args, 1)
.map(|v| v.try_into())
.and_then(|v| v.ok());
let default_value = get_scalar_value_from_args(&coerced_args, 2);
Arc::new(lag(name, data_type, arg, shift_offset, default_value))
}
BuiltInWindowFunction::Lead => {
let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
let arg = coerced_args[0].clone();
let data_type = args[0].data_type(input_schema)?;
Arc::new(lead(name, data_type, arg))
let shift_offset = get_scalar_value_from_args(&coerced_args, 1)
.map(|v| v.try_into())
.and_then(|v| v.ok());
let default_value = get_scalar_value_from_args(&coerced_args, 2);
Arc::new(lead(name, data_type, arg, shift_offset, default_value))
}
BuiltInWindowFunction::NthValue => {
let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
@@ -592,6 +614,47 @@ mod tests {
Ok((input, schema))
}

#[test]
fn test_create_window_exp_lead_no_args() -> Result<()> {
let (_, schema) = create_test_schema(1)?;

let expr = create_window_expr(
&WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
"prev".to_owned(),
&[col("c2", &schema)?],
&[],
&[],
Some(WindowFrame::default()),
schema.as_ref(),
)?;

assert_eq!(expr.name(), "prev");

Ok(())
}

#[test]
fn test_create_window_exp_lead_with_args() -> Result<()> {
let (_, schema) = create_test_schema(1)?;

let expr = create_window_expr(
&WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
"prev".to_owned(),
&[
col("c2", &schema)?,
Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
],
&[],
&[],
Some(WindowFrame::default()),
schema.as_ref(),
)?;

assert_eq!(expr.name(), "prev");

Ok(())
}

#[tokio::test]
async fn window_function() -> Result<()> {
let (input, schema) = create_test_schema(1)?;
27 changes: 27 additions & 0 deletions integration-tests/sqls/simple_window_lead_built_in_functions.sql
@@ -0,0 +1,27 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at

-- http://www.apache.org/licenses/LICENSE-2.0

-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

SELECT
c8,
LEAD(c8) OVER () next_c8,
LEAD(c8, 10, 10) OVER() next_10_c8,
LEAD(c8, 100, 10) OVER() next_out_of_bounds_c8,
LAG(c8) OVER() prev_c8,
LAG(c8, -2, 0) OVER() AS prev_2_c8,
LAG(c8, -200, 10) OVER() AS prev_out_of_bounds_c8

FROM test
ORDER BY c8;
2 changes: 1 addition & 1 deletion integration-tests/test_psql_parity.py
@@ -77,7 +77,7 @@ def generate_csv_from_psql(fname: str):

class TestPsqlParity:
def test_tests_count(self):
assert len(test_files) == 14, "tests are missed"
assert len(test_files) == 15, "tests are missed"

@pytest.mark.parametrize("fname", test_files)
def test_sql_file(self, fname):
