diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 89c5cfcce647..f6789ea5966d 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -29,7 +29,7 @@ use crate::type_coercion::functions::data_types; use crate::{FuncMonotonicity, Signature, TypeSignature, Volatility}; use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; -use datafusion_common::{exec_err, plan_err, DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -160,8 +160,6 @@ pub enum BuiltinScalarFunction { ArrayResize, /// construct an array from columns MakeArray, - /// Flatten - Flatten, // struct functions /// struct @@ -372,7 +370,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable, BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable, BuiltinScalarFunction::ArrayReverse => Volatility::Immutable, - BuiltinScalarFunction::Flatten => Volatility::Immutable, BuiltinScalarFunction::ArraySlice => Volatility::Immutable, BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable, BuiltinScalarFunction::ArrayUnion => Volatility::Immutable, @@ -475,20 +472,6 @@ impl BuiltinScalarFunction { // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. 
match self { - BuiltinScalarFunction::Flatten => { - fn get_base_type(data_type: &DataType) -> Result { - match data_type { - DataType::List(field) | DataType::FixedSizeList(field, _) if matches!(field.data_type(), DataType::List(_)|DataType::FixedSizeList(_,_ )) => get_base_type(field.data_type()), - DataType::LargeList(field) if matches!(field.data_type(), DataType::LargeList(_)) => get_base_type(field.data_type()), - DataType::Null | DataType::List(_) | DataType::LargeList(_) => Ok(data_type.to_owned()), - DataType::FixedSizeList(field,_ ) => Ok(DataType::List(field.clone())), - _ => exec_err!("Not reachable, data_type should be List, LargeList or FixedSizeList"), - } - } - - let data_type = get_base_type(&input_expr_types[0])?; - Ok(data_type) - } BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArraySort => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayConcat => { @@ -827,7 +810,6 @@ impl BuiltinScalarFunction { Signature::array_and_index(self.volatility()) } BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()), - BuiltinScalarFunction::Flatten => Signature::array(self.volatility()), BuiltinScalarFunction::ArrayDistinct => Signature::array(self.volatility()), BuiltinScalarFunction::ArrayPosition => { Signature::array_and_element_and_optional_index(self.volatility()) @@ -1391,7 +1373,6 @@ impl BuiltinScalarFunction { "list_extract", ], BuiltinScalarFunction::ArrayExcept => &["array_except", "list_except"], - BuiltinScalarFunction::Flatten => &["flatten"], BuiltinScalarFunction::ArrayPopFront => { &["array_pop_front", "list_pop_front"] } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index ef1b1c45042c..4090f9da6e0b 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -611,12 +611,6 @@ scalar_expr!( ); nary_scalar_expr!(ArrayConcat, array_concat, "concatenates arrays."); -scalar_expr!( - Flatten, - flatten, - array, - 
"flattens an array of arrays into a single array." -); scalar_expr!( ArrayElement, array_element, diff --git a/datafusion/functions-array/src/kernels.rs b/datafusion/functions-array/src/kernels.rs index 3138843feb58..ad96d232aa4a 100644 --- a/datafusion/functions-array/src/kernels.rs +++ b/datafusion/functions-array/src/kernels.rs @@ -616,3 +616,72 @@ pub fn array_length(args: &[ArrayRef]) -> Result { array_type => exec_err!("array_length does not support type '{array_type:?}'"), } } + +// Create new offsets that are equivalent to flattening the array. +fn get_offsets_for_flatten( + offsets: OffsetBuffer, + indexes: OffsetBuffer, +) -> OffsetBuffer { + let buffer = offsets.into_inner(); + let offsets: Vec = indexes + .iter() + .map(|i| buffer[i.to_usize().unwrap()]) + .collect(); + OffsetBuffer::new(offsets.into()) +} + +fn flatten_internal( + list_arr: GenericListArray, + indexes: Option>, +) -> Result> { + let (field, offsets, values, _) = list_arr.clone().into_parts(); + let data_type = field.data_type(); + + match data_type { + // Recursively get the base offsets for flattened array + DataType::List(_) | DataType::LargeList(_) => { + let sub_list = as_generic_list_array::(&values)?; + if let Some(indexes) = indexes { + let offsets = get_offsets_for_flatten(offsets, indexes); + flatten_internal::(sub_list.clone(), Some(offsets)) + } else { + flatten_internal::(sub_list.clone(), Some(offsets)) + } + } + // Reach the base level, create a new list array + _ => { + if let Some(indexes) = indexes { + let offsets = get_offsets_for_flatten(offsets, indexes); + let list_arr = GenericListArray::::new(field, offsets, values, None); + Ok(list_arr) + } else { + Ok(list_arr.clone()) + } + } + } +} + +/// Flatten SQL function +pub fn flatten(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("flatten expects one argument"); + } + + let array_type = args[0].data_type(); + match array_type { + DataType::List(_) => { + let list_arr = 
as_list_array(&args[0])?; + let flattened_array = flatten_internal::(list_arr.clone(), None)?; + Ok(Arc::new(flattened_array) as ArrayRef) + } + DataType::LargeList(_) => { + let list_arr = as_large_list_array(&args[0])?; + let flattened_array = flatten_internal::(list_arr.clone(), None)?; + Ok(Arc::new(flattened_array) as ArrayRef) + } + DataType::Null => Ok(args[0].clone()), + _ => { + exec_err!("flatten does not support type '{array_type:?}'") + } + } +} diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index 6f0f2beca75c..73055966ee46 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -50,6 +50,7 @@ pub mod expr_fn { pub use super::udf::array_ndims; pub use super::udf::array_to_string; pub use super::udf::cardinality; + pub use super::udf::flatten; pub use super::udf::gen_series; pub use super::udf::range; } @@ -68,6 +69,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { array_has::array_has_any_udf(), udf::array_empty_udf(), udf::array_length_udf(), + udf::flatten_udf(), ]; functions.into_iter().try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/udf.rs index 8dc4f722c08a..b2c310e1701d 100644 --- a/datafusion/functions-array/src/udf.rs +++ b/datafusion/functions-array/src/udf.rs @@ -501,3 +501,71 @@ impl ScalarUDFImpl for ArrayLength { &self.aliases } } + +make_udf_function!( + Flatten, + flatten, + array, + "flattens an array of arrays into a single array.", + flatten_udf +); + +#[derive(Debug)] +pub(super) struct Flatten { + signature: Signature, + aliases: Vec, +} +impl Flatten { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec![String::from("flatten")], + } + } +} + +impl ScalarUDFImpl for Flatten { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "flatten" + } + + fn 
signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + fn get_base_type(data_type: &DataType) -> Result { + match data_type { + List(field) | FixedSizeList(field, _) + if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) => + { + get_base_type(field.data_type()) + } + LargeList(field) if matches!(field.data_type(), LargeList(_)) => { + get_base_type(field.data_type()) + } + Null | List(_) | LargeList(_) => Ok(data_type.to_owned()), + FixedSizeList(field, _) => Ok(List(field.clone())), + _ => exec_err!( + "Not reachable, data_type should be List, LargeList or FixedSizeList" + ), + } + } + + let data_type = get_base_type(&arg_types[0])?; + Ok(data_type) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + crate::kernels::flatten(&args).map(ColumnarValue::Array) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 495de01c7615..5be72b0559d3 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1836,77 +1836,6 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result { general_set_op(array1, array2, SetOp::Intersect) } -// Create new offsets that are euqiavlent to `flatten` the array. 
-fn get_offsets_for_flatten( - offsets: OffsetBuffer, - indexes: OffsetBuffer, -) -> OffsetBuffer { - let buffer = offsets.into_inner(); - let offsets: Vec = indexes - .iter() - .map(|i| buffer[i.to_usize().unwrap()]) - .collect(); - OffsetBuffer::new(offsets.into()) -} - -fn flatten_internal( - list_arr: GenericListArray, - indexes: Option>, -) -> Result> { - let (field, offsets, values, _) = list_arr.clone().into_parts(); - let data_type = field.data_type(); - - match data_type { - // Recursively get the base offsets for flattened array - DataType::List(_) | DataType::LargeList(_) => { - let sub_list = as_generic_list_array::(&values)?; - if let Some(indexes) = indexes { - let offsets = get_offsets_for_flatten(offsets, indexes); - flatten_internal::(sub_list.clone(), Some(offsets)) - } else { - flatten_internal::(sub_list.clone(), Some(offsets)) - } - } - // Reach the base level, create a new list array - _ => { - if let Some(indexes) = indexes { - let offsets = get_offsets_for_flatten(offsets, indexes); - let list_arr = GenericListArray::::new(field, offsets, values, None); - Ok(list_arr) - } else { - Ok(list_arr.clone()) - } - } - } -} - -/// Flatten SQL function -pub fn flatten(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("flatten expects one argument"); - } - - let array_type = args[0].data_type(); - match array_type { - DataType::List(_) => { - let list_arr = as_list_array(&args[0])?; - let flattened_array = flatten_internal::(list_arr.clone(), None)?; - Ok(Arc::new(flattened_array) as ArrayRef) - } - DataType::LargeList(_) => { - let list_arr = as_large_list_array(&args[0])?; - let flattened_array = flatten_internal::(list_arr.clone(), None)?; - Ok(Arc::new(flattened_array) as ArrayRef) - } - DataType::Null => Ok(args[0].clone()), - _ => { - exec_err!("flatten does not support type '{array_type:?}'") - } - } - - // Ok(Arc::new(flattened_array) as ArrayRef) -} - /// Splits string at occurrences of delimiter and returns an array of 
parts /// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]' pub fn string_to_array(args: &[ArrayRef]) -> Result { diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 776f6315a405..12727920dd33 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -320,9 +320,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayExcept => Arc::new(|args| { make_scalar_function_inner(array_expressions::array_except)(args) }), - BuiltinScalarFunction::Flatten => { - Arc::new(|args| make_scalar_function_inner(array_expressions::flatten)(args)) - } BuiltinScalarFunction::ArrayPopFront => Arc::new(|args| { make_scalar_function_inner(array_expressions::array_pop_front)(args) }), diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d1fef7c1ceae..206c43822bbb 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -659,7 +659,7 @@ enum ScalarFunction { ArrayRemoveAll = 109; ArrayReplaceAll = 110; Nanvl = 111; - Flatten = 112; + // 112 was Flatten // 113 was IsNan Iszero = 114; // 115 was ArrayEmpty diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index e4da28ed44ec..022d72ce212c 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22414,7 +22414,6 @@ impl serde::Serialize for ScalarFunction { Self::ArrayRemoveAll => "ArrayRemoveAll", Self::ArrayReplaceAll => "ArrayReplaceAll", Self::Nanvl => "Nanvl", - Self::Flatten => "Flatten", Self::Iszero => "Iszero", Self::ArrayPopBack => "ArrayPopBack", Self::StringToArray => "StringToArray", @@ -22537,7 +22536,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayRemoveAll", "ArrayReplaceAll", "Nanvl", - "Flatten", "Iszero", "ArrayPopBack", "StringToArray", @@ -22689,7 +22687,6 @@ impl<'de> serde::Deserialize<'de> for 
ScalarFunction { "ArrayRemoveAll" => Ok(ScalarFunction::ArrayRemoveAll), "ArrayReplaceAll" => Ok(ScalarFunction::ArrayReplaceAll), "Nanvl" => Ok(ScalarFunction::Nanvl), - "Flatten" => Ok(ScalarFunction::Flatten), "Iszero" => Ok(ScalarFunction::Iszero), "ArrayPopBack" => Ok(ScalarFunction::ArrayPopBack), "StringToArray" => Ok(ScalarFunction::StringToArray), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 30b76c16bc91..bccc77875f7c 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2747,7 +2747,7 @@ pub enum ScalarFunction { ArrayRemoveAll = 109, ArrayReplaceAll = 110, Nanvl = 111, - Flatten = 112, + /// 112 was Flatten /// 113 was IsNan Iszero = 114, /// 115 was ArrayEmpty @@ -2876,7 +2876,6 @@ impl ScalarFunction { ScalarFunction::ArrayRemoveAll => "ArrayRemoveAll", ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll", ScalarFunction::Nanvl => "Nanvl", - ScalarFunction::Flatten => "Flatten", ScalarFunction::Iszero => "Iszero", ScalarFunction::ArrayPopBack => "ArrayPopBack", ScalarFunction::StringToArray => "StringToArray", @@ -2993,7 +2992,6 @@ impl ScalarFunction { "ArrayRemoveAll" => Some(Self::ArrayRemoveAll), "ArrayReplaceAll" => Some(Self::ArrayReplaceAll), "Nanvl" => Some(Self::Nanvl), - "Flatten" => Some(Self::Flatten), "Iszero" => Some(Self::Iszero), "ArrayPopBack" => Some(Self::ArrayPopBack), "StringToArray" => Some(Self::StringToArray), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ece3caa09475..8222be1fe9bc 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -56,8 +56,8 @@ use datafusion_expr::{ concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, date_trunc, degrees, digest, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, find_in_set, flatten, floor, 
from_unixtime, gcd, initcap, iszero, lcm, - left, levenshtein, ln, log, log10, log2, + factorial, find_in_set, floor, from_unixtime, gcd, initcap, iszero, lcm, left, + levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, nanvl, now, octet_length, overlay, pi, power, radians, random, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256, sha384, @@ -483,7 +483,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayExcept => Self::ArrayExcept, ScalarFunction::ArrayDistinct => Self::ArrayDistinct, ScalarFunction::ArrayElement => Self::ArrayElement, - ScalarFunction::Flatten => Self::Flatten, ScalarFunction::ArrayPopFront => Self::ArrayPopFront, ScalarFunction::ArrayPopBack => Self::ArrayPopBack, ScalarFunction::ArrayPosition => Self::ArrayPosition, @@ -1784,9 +1783,6 @@ pub fn parse_expr( ScalarFunction::ArrowTypeof => { Ok(arrow_typeof(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Flatten => { - Ok(flatten(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::StringToArray => Ok(string_to_array( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 43c8d7e4b299..f28fb5466734 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1462,7 +1462,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayExcept => Self::ArrayExcept, BuiltinScalarFunction::ArrayDistinct => Self::ArrayDistinct, BuiltinScalarFunction::ArrayElement => Self::ArrayElement, - BuiltinScalarFunction::Flatten => Self::Flatten, BuiltinScalarFunction::ArrayPopFront => Self::ArrayPopFront, BuiltinScalarFunction::ArrayPopBack => Self::ArrayPopBack, BuiltinScalarFunction::ArrayPosition => Self::ArrayPosition, diff --git 
a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index ef9f2b27aa15..fb9f2967553f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -599,6 +599,7 @@ async fn roundtrip_expr_api() -> Result<()> { ), array_empty(array(vec![lit(1), lit(2), lit(3)])), array_length(array(vec![lit(1), lit(2), lit(3)])), + flatten(array(vec![lit(1), lit(2), lit(3)])), ]; // ensure expressions created with the expr api can be round tripped