diff --git a/datafusion/src/physical_plan/expressions/lead_lag.rs b/datafusion/src/physical_plan/expressions/lead_lag.rs index 352d97c1e116..d1f6c197a186 100644 --- a/datafusion/src/physical_plan/expressions/lead_lag.rs +++ b/datafusion/src/physical_plan/expressions/lead_lag.rs @@ -21,11 +21,13 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::window_functions::PartitionEvaluator; use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr}; +use crate::scalar::ScalarValue; use arrow::array::ArrayRef; -use arrow::compute::kernels::window::shift; +use arrow::compute::cast; use arrow::datatypes::{DataType, Field}; use arrow::record_batch::RecordBatch; use std::any::Any; +use std::ops::Neg; use std::ops::Range; use std::sync::Arc; @@ -36,6 +38,7 @@ pub struct WindowShift { data_type: DataType, shift_offset: i64, expr: Arc, + default_value: Option, } /// lead() window function @@ -43,12 +46,15 @@ pub fn lead( name: String, data_type: DataType, expr: Arc, + shift_offset: Option, + default_value: Option, ) -> WindowShift { WindowShift { name, data_type, - shift_offset: -1, + shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1), expr, + default_value, } } @@ -57,12 +63,15 @@ pub fn lag( name: String, data_type: DataType, expr: Arc, + shift_offset: Option, + default_value: Option, ) -> WindowShift { WindowShift { name, data_type, - shift_offset: 1, + shift_offset: shift_offset.unwrap_or(1), expr, + default_value, } } @@ -98,6 +107,7 @@ impl BuiltInWindowFunctionExpr for WindowShift { Ok(Box::new(WindowShiftEvaluator { shift_offset: self.shift_offset, values, + default_value: self.default_value.clone(), })) } } @@ -105,13 +115,63 @@ impl BuiltInWindowFunctionExpr for WindowShift { pub(crate) struct WindowShiftEvaluator { shift_offset: i64, values: Vec, + default_value: Option, +} + +fn create_empty_array( + value: &Option, + data_type: &DataType, + size: usize, +) -> Result { + use arrow::array::new_null_array; + let array = value + .as_ref() + .map(|scalar| scalar.to_array_of_size(size)) + .unwrap_or_else(|| new_null_array(data_type, size)); + if array.data_type() != data_type { + cast(&array, data_type).map_err(DataFusionError::ArrowError) + } else { + Ok(array) + } +} + +// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value +fn shift_with_default_value( + array: &ArrayRef, + offset: i64, + value: &Option, +) -> Result { + use arrow::compute::concat; + + let value_len = array.len() as i64; + if offset == 0 { + Ok(arrow::array::make_array(array.data_ref().clone())) + } else if offset == i64::MIN || offset.abs() >= value_len { + create_empty_array(value, array.data_type(), array.len()) + } else { + let slice_offset = (-offset).clamp(0, value_len) as usize; + let length = array.len() - offset.abs() as usize; + let slice = array.slice(slice_offset, length); + + // Generate array with remaining `null` items + let nulls = offset.abs() as usize; + let default_values = create_empty_array(value, slice.data_type(), nulls)?; + // Concatenate both arrays, add nulls after if shift > 0 else before + if offset > 0 { + concat(&[default_values.as_ref(), slice.as_ref()]) + .map_err(DataFusionError::ArrowError) + } else { + concat(&[slice.as_ref(), default_values.as_ref()]) + .map_err(DataFusionError::ArrowError) + } + } } impl PartitionEvaluator for WindowShiftEvaluator { fn evaluate_partition(&self, partition: Range) -> Result { let value = &self.values[0]; let value = value.slice(partition.start, partition.end - partition.start); - shift(value.as_ref(), self.shift_offset).map_err(DataFusionError::ArrowError) + shift_with_default_value(&value, self.shift_offset, &self.default_value) } } @@ -142,6 +202,8 @@ mod tests { "lead".to_owned(), DataType::Float32, Arc::new(Column::new("c3", 0)), + None, + None, ), vec![ Some(-2), @@ -162,6 +224,8 @@ mod tests { "lead".to_owned(), DataType::Float32, Arc::new(Column::new("c3", 0)), + None, + None, ), vec![ None, @@ -176,6 +240,28 @@ mod tests { .iter() .collect::(), )?; + + test_i32_result( + lag( + "lead".to_owned(), + DataType::Int32, + Arc::new(Column::new("c3", 0)), + None, + Some(ScalarValue::Int32(Some(100))), + ), + vec![ + Some(100), + Some(1), + Some(-2), + Some(3), + Some(-4), + Some(5), + Some(-6), + Some(7), + ] + .iter() + .collect::(), + )?; Ok(()) } } diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs index ffd8f20064f7..c8387bbd71e1 100644 --- a/datafusion/src/physical_plan/type_coercion.rs +++ b/datafusion/src/physical_plan/type_coercion.rs @@ -128,13 +128,11 @@ fn get_valid_types( } vec![(0..*number).map(|i| current_types[i].clone()).collect()] } - Signature::OneOf(types) => { - let mut r = vec![]; - for s in types { - r.extend(get_valid_types(s, current_types)?); - } - r - } + Signature::OneOf(types) => types + .iter() + .filter_map(|t| get_valid_types(t, current_types).ok()) + .flatten() + .collect::>(), }; Ok(valid_types) @@ -367,4 +365,27 @@ mod tests { Ok(()) } + + #[test] + fn test_get_valid_types_one_of() -> Result<()> { + let signature = Signature::OneOf(vec![Signature::Any(1), Signature::Any(2)]); + + let invalid_types = get_valid_types( + &signature, + &[DataType::Int32, DataType::Int32, DataType::Int32], + )?; + assert_eq!(invalid_types.len(), 0); + + let args = vec![DataType::Int32, DataType::Int32]; + let valid_types = get_valid_types(&signature, &args)?; + assert_eq!(valid_types.len(), 1); + assert_eq!(valid_types[0], args); + + let args = vec![DataType::Int32]; + let valid_types = get_valid_types(&signature, &args)?; + assert_eq!(valid_types.len(), 1); + assert_eq!(valid_types[0], args); + + Ok(()) + } } diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index 99805b6d2941..e2b460644479 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -201,10 +201,16 @@ pub(super) fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature { | BuiltInWindowFunction::DenseRank | BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => Signature::Any(0), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue => Signature::Any(1), + BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { + Signature::OneOf(vec![ + Signature::Any(1), + Signature::Any(2), + Signature::Any(3), + ]) + } + BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { + Signature::Any(1) + } BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]), BuiltInWindowFunction::NthValue => Signature::Any(2), } diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index 1b783782e164..a1f4b7ace530 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -32,6 +32,7 @@ use crate::physical_plan::{ Accumulator, AggregateExpr, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, WindowExpr, }; +use crate::scalar::ScalarValue; use arrow::compute::concat; use arrow::{ array::ArrayRef, @@ -96,6 +97,19 @@ pub fn create_window_expr( }) } +fn get_scalar_value_from_args( + args: &[Arc], + index: usize, +) -> Option { + args.get(index).map(|v| { + v.as_any() + .downcast_ref::() + .unwrap() + .value() + .clone() + }) +} + fn create_built_in_window_expr( fun: &BuiltInWindowFunction, args: &[Arc], @@ -110,13 +124,21 @@ fn create_built_in_window_expr( let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?; let arg = coerced_args[0].clone(); let data_type = args[0].data_type(input_schema)?; - Arc::new(lag(name, data_type, arg)) + let shift_offset = get_scalar_value_from_args(&coerced_args, 1) + .map(|v| v.try_into()) + .and_then(|v| v.ok()); + let default_value = get_scalar_value_from_args(&coerced_args, 2); + Arc::new(lag(name, data_type, arg, shift_offset, default_value)) } BuiltInWindowFunction::Lead => { let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?; let arg = coerced_args[0].clone(); let data_type = args[0].data_type(input_schema)?; - Arc::new(lead(name, data_type, arg)) + let shift_offset = get_scalar_value_from_args(&coerced_args, 1) + .map(|v| v.try_into()) + .and_then(|v| v.ok()); + let default_value = get_scalar_value_from_args(&coerced_args, 2); + Arc::new(lead(name, data_type, arg, shift_offset, default_value)) } BuiltInWindowFunction::NthValue => { let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?; @@ -592,6 +614,47 @@ mod tests { Ok((input, schema)) } + #[test] + fn test_create_window_exp_lead_no_args() -> Result<()> { + let (_, schema) = create_test_schema(1)?; + + let expr = create_window_expr( + &WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + "prev".to_owned(), + &[col("c2", &schema)?], + &[], + &[], + Some(WindowFrame::default()), + schema.as_ref(), + )?; + + assert_eq!(expr.name(), "prev"); + + Ok(()) + } + + #[test] + fn test_create_window_exp_lead_with_args() -> Result<()> { + let (_, schema) = create_test_schema(1)?; + + let expr = create_window_expr( + &WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + "prev".to_owned(), + &[ + col("c2", &schema)?, + Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), + ], + &[], + &[], + Some(WindowFrame::default()), + schema.as_ref(), + )?; + + assert_eq!(expr.name(), "prev"); + + Ok(()) + } + #[tokio::test] async fn window_function() -> Result<()> { let (input, schema) = create_test_schema(1)?; diff --git a/integration-tests/sqls/simple_window_lead_built_in_functions.sql b/integration-tests/sqls/simple_window_lead_built_in_functions.sql new file mode 100644 index 000000000000..67df05b68c1a --- /dev/null +++ b/integration-tests/sqls/simple_window_lead_built_in_functions.sql @@ -0,0 +1,27 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at + +-- http://www.apache.org/licenses/LICENSE-2.0 + +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +SELECT + c8, + LEAD(c8) OVER () next_c8, + LEAD(c8, 10, 10) OVER() next_10_c8, + LEAD(c8, 100, 10) OVER() next_out_of_bounds_c8, + LAG(c8) OVER() prev_c8, + LAG(c8, -2, 0) OVER() AS prev_2_c8, + LAG(c8, -200, 10) OVER() AS prev_out_of_bounds_c8 + +FROM test +ORDER BY c8; diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py index a160d3e320ce..a85a2c2f4b37 100644 --- a/integration-tests/test_psql_parity.py +++ b/integration-tests/test_psql_parity.py @@ -77,7 +77,7 @@ def generate_csv_from_psql(fname: str): class TestPsqlParity: def test_tests_count(self): - assert len(test_files) == 14, "tests are missed" + assert len(test_files) == 15, "tests are missed" @pytest.mark.parametrize("fname", test_files) def test_sql_file(self, fname):