port flatten (#9523)
Weijun-H authored Mar 10, 2024
1 parent eebdbe8 commit 37b3ff3
Showing 13 changed files with 145 additions and 114 deletions.
21 changes: 1 addition & 20 deletions datafusion/expr/src/built_in_function.rs
@@ -29,7 +29,7 @@ use crate::type_coercion::functions::data_types;
use crate::{FuncMonotonicity, Signature, TypeSignature, Volatility};

use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit};
use datafusion_common::{exec_err, plan_err, DataFusionError, Result};
use datafusion_common::{plan_err, DataFusionError, Result};

use strum::IntoEnumIterator;
use strum_macros::EnumIter;
@@ -160,8 +160,6 @@ pub enum BuiltinScalarFunction {
ArrayResize,
/// construct an array from columns
MakeArray,
/// Flatten
Flatten,

// struct functions
/// struct
@@ -372,7 +370,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable,
BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable,
BuiltinScalarFunction::ArrayReverse => Volatility::Immutable,
BuiltinScalarFunction::Flatten => Volatility::Immutable,
BuiltinScalarFunction::ArraySlice => Volatility::Immutable,
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
BuiltinScalarFunction::ArrayUnion => Volatility::Immutable,
@@ -475,20 +472,6 @@ impl BuiltinScalarFunction {
// the return type of the built in function.
// Some built-in functions' return type depends on the incoming type.
match self {
BuiltinScalarFunction::Flatten => {
fn get_base_type(data_type: &DataType) -> Result<DataType> {
match data_type {
DataType::List(field) | DataType::FixedSizeList(field, _) if matches!(field.data_type(), DataType::List(_)|DataType::FixedSizeList(_,_ )) => get_base_type(field.data_type()),
DataType::LargeList(field) if matches!(field.data_type(), DataType::LargeList(_)) => get_base_type(field.data_type()),
DataType::Null | DataType::List(_) | DataType::LargeList(_) => Ok(data_type.to_owned()),
DataType::FixedSizeList(field,_ ) => Ok(DataType::List(field.clone())),
_ => exec_err!("Not reachable, data_type should be List, LargeList or FixedSizeList"),
}
}

let data_type = get_base_type(&input_expr_types[0])?;
Ok(data_type)
}
BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArraySort => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayConcat => {
@@ -827,7 +810,6 @@ impl BuiltinScalarFunction {
Signature::array_and_index(self.volatility())
}
BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Flatten => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayDistinct => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayPosition => {
Signature::array_and_element_and_optional_index(self.volatility())
@@ -1391,7 +1373,6 @@ impl BuiltinScalarFunction {
"list_extract",
],
BuiltinScalarFunction::ArrayExcept => &["array_except", "list_except"],
BuiltinScalarFunction::Flatten => &["flatten"],
BuiltinScalarFunction::ArrayPopFront => {
&["array_pop_front", "list_pop_front"]
}
6 changes: 0 additions & 6 deletions datafusion/expr/src/expr_fn.rs
@@ -611,12 +611,6 @@ scalar_expr!(
);

nary_scalar_expr!(ArrayConcat, array_concat, "concatenates arrays.");
scalar_expr!(
Flatten,
flatten,
array,
"flattens an array of arrays into a single array."
);
scalar_expr!(
ArrayElement,
array_element,
69 changes: 69 additions & 0 deletions datafusion/functions-array/src/kernels.rs
@@ -616,3 +616,72 @@ pub fn array_length(args: &[ArrayRef]) -> Result<ArrayRef> {
array_type => exec_err!("array_length does not support type '{array_type:?}'"),
}
}

// Create new offsets that are equivalent to applying `flatten` to the array.
fn get_offsets_for_flatten<O: OffsetSizeTrait>(
offsets: OffsetBuffer<O>,
indexes: OffsetBuffer<O>,
) -> OffsetBuffer<O> {
let buffer = offsets.into_inner();
let offsets: Vec<O> = indexes
.iter()
.map(|i| buffer[i.to_usize().unwrap()])
.collect();
OffsetBuffer::new(offsets.into())
}

fn flatten_internal<O: OffsetSizeTrait>(
list_arr: GenericListArray<O>,
indexes: Option<OffsetBuffer<O>>,
) -> Result<GenericListArray<O>> {
let (field, offsets, values, _) = list_arr.clone().into_parts();
let data_type = field.data_type();

match data_type {
// Recursively get the base offsets for flattened array
DataType::List(_) | DataType::LargeList(_) => {
let sub_list = as_generic_list_array::<O>(&values)?;
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
flatten_internal::<O>(sub_list.clone(), Some(offsets))
} else {
flatten_internal::<O>(sub_list.clone(), Some(offsets))
}
}
// Reach the base level, create a new list array
_ => {
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
let list_arr = GenericListArray::<O>::new(field, offsets, values, None);
Ok(list_arr)
} else {
Ok(list_arr.clone())
}
}
}
}

/// Flatten SQL function
pub fn flatten(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("flatten expects one argument");
}

let array_type = args[0].data_type();
match array_type {
DataType::List(_) => {
let list_arr = as_list_array(&args[0])?;
let flattened_array = flatten_internal::<i32>(list_arr.clone(), None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
}
DataType::LargeList(_) => {
let list_arr = as_large_list_array(&args[0])?;
let flattened_array = flatten_internal::<i64>(list_arr.clone(), None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
}
DataType::Null => Ok(args[0].clone()),
_ => {
exec_err!("flatten does not support type '{array_type:?}'")
}
}
}
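
For context on the ported kernel: `flatten_internal` recurses through the nesting levels, and at each level `get_offsets_for_flatten` composes the outer offsets with the inner ones, so the resulting list points directly at the innermost values. A minimal sketch of that offset composition using plain vectors (an illustration only, not part of the diff; the real code works on arrow `OffsetBuffer`s):

```rust
// Composing offsets removes one level of nesting: take the inner offsets at
// each outer list boundary, which yields offsets directly over the inner values.
fn compose_offsets(inner: &[i32], outer: &[i32]) -> Vec<i32> {
    outer.iter().map(|&i| inner[i as usize]).collect()
}

fn main() {
    // Nested array [[[1, 2], [3]], [[4, 5, 6]]]:
    let inner = vec![0, 2, 3, 6]; // inner-list offsets into the values [1, 2, 3, 4, 5, 6]
    let outer = vec![0, 2, 3];    // outer-list offsets into the inner lists
    // Flattened per row: [[1, 2, 3], [4, 5, 6]]
    assert_eq!(compose_offsets(&inner, &outer), vec![0, 3, 6]);
}
```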
2 changes: 2 additions & 0 deletions datafusion/functions-array/src/lib.rs
@@ -50,6 +50,7 @@ pub mod expr_fn {
pub use super::udf::array_ndims;
pub use super::udf::array_to_string;
pub use super::udf::cardinality;
pub use super::udf::flatten;
pub use super::udf::gen_series;
pub use super::udf::range;
}
@@ -68,6 +69,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
array_has::array_has_any_udf(),
udf::array_empty_udf(),
udf::array_length_udf(),
udf::flatten_udf(),
];
functions.into_iter().try_for_each(|udf| {
let existing_udf = registry.register_udf(udf)?;
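
With the re-export above, `Expr`-building users now get `flatten` from `datafusion_functions_array::expr_fn` rather than from `datafusion_expr`. A minimal sketch, assuming the `make_udf_function!` macro generates `flatten(array: Expr) -> Expr` like the other array functions exported from this module:

```rust
use datafusion_expr::{col, Expr};
use datafusion_functions_array::expr_fn::flatten;

// Builds the logical expression `flatten(nested_col)`, equivalent to the SQL call.
fn flatten_column() -> Expr {
    flatten(col("nested_col"))
}
```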
68 changes: 68 additions & 0 deletions datafusion/functions-array/src/udf.rs
@@ -501,3 +501,71 @@ impl ScalarUDFImpl for ArrayLength {
&self.aliases
}
}

make_udf_function!(
Flatten,
flatten,
array,
"flattens an array of arrays into a single array.",
flatten_udf
);

#[derive(Debug)]
pub(super) struct Flatten {
signature: Signature,
aliases: Vec<String>,
}
impl Flatten {
pub fn new() -> Self {
Self {
signature: Signature::array(Volatility::Immutable),
aliases: vec![String::from("flatten")],
}
}
}

impl ScalarUDFImpl for Flatten {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"flatten"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;
fn get_base_type(data_type: &DataType) -> Result<DataType> {
match data_type {
List(field) | FixedSizeList(field, _)
if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) =>
{
get_base_type(field.data_type())
}
LargeList(field) if matches!(field.data_type(), LargeList(_)) => {
get_base_type(field.data_type())
}
Null | List(_) | LargeList(_) => Ok(data_type.to_owned()),
FixedSizeList(field, _) => Ok(List(field.clone())),
_ => exec_err!(
"Not reachable, data_type should be List, LargeList or FixedSizeList"
),
}
}

let data_type = get_base_type(&arg_types[0])?;
Ok(data_type)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
crate::kernels::flatten(&args).map(ColumnarValue::Array)
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}
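
End to end, the ported UDF behaves like the removed built-in: it recursively flattens nested lists, and its return type is the innermost list type (e.g. `List(List(Int64))` becomes `List(Int64)`). A minimal usage sketch, assuming the array functions (including this UDF) are registered with the default `SessionContext`, e.g. via `register_all`:

```rust
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    // flatten([[1, 2], [3, 4]]) -> [1, 2, 3, 4]
    let df = ctx
        .sql("SELECT flatten(make_array(make_array(1, 2), make_array(3, 4))) AS flat")
        .await?;
    df.show().await?;
    Ok(())
}
```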
71 changes: 0 additions & 71 deletions datafusion/physical-expr/src/array_expressions.rs
@@ -1836,77 +1836,6 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
general_set_op(array1, array2, SetOp::Intersect)
}

// Create new offsets that are equivalent to applying `flatten` to the array.
fn get_offsets_for_flatten<O: OffsetSizeTrait>(
offsets: OffsetBuffer<O>,
indexes: OffsetBuffer<O>,
) -> OffsetBuffer<O> {
let buffer = offsets.into_inner();
let offsets: Vec<O> = indexes
.iter()
.map(|i| buffer[i.to_usize().unwrap()])
.collect();
OffsetBuffer::new(offsets.into())
}

fn flatten_internal<O: OffsetSizeTrait>(
list_arr: GenericListArray<O>,
indexes: Option<OffsetBuffer<O>>,
) -> Result<GenericListArray<O>> {
let (field, offsets, values, _) = list_arr.clone().into_parts();
let data_type = field.data_type();

match data_type {
// Recursively get the base offsets for flattened array
DataType::List(_) | DataType::LargeList(_) => {
let sub_list = as_generic_list_array::<O>(&values)?;
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
flatten_internal::<O>(sub_list.clone(), Some(offsets))
} else {
flatten_internal::<O>(sub_list.clone(), Some(offsets))
}
}
// Reach the base level, create a new list array
_ => {
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
let list_arr = GenericListArray::<O>::new(field, offsets, values, None);
Ok(list_arr)
} else {
Ok(list_arr.clone())
}
}
}
}

/// Flatten SQL function
pub fn flatten(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("flatten expects one argument");
}

let array_type = args[0].data_type();
match array_type {
DataType::List(_) => {
let list_arr = as_list_array(&args[0])?;
let flattened_array = flatten_internal::<i32>(list_arr.clone(), None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
}
DataType::LargeList(_) => {
let list_arr = as_large_list_array(&args[0])?;
let flattened_array = flatten_internal::<i64>(list_arr.clone(), None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
}
DataType::Null => Ok(args[0].clone()),
_ => {
exec_err!("flatten does not support type '{array_type:?}'")
}
}

// Ok(Arc::new(flattened_array) as ArrayRef)
}

/// Splits string at occurrences of delimiter and returns an array of parts
/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]'
pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
3 changes: 0 additions & 3 deletions datafusion/physical-expr/src/functions.rs
@@ -320,9 +320,6 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayExcept => Arc::new(|args| {
make_scalar_function_inner(array_expressions::array_except)(args)
}),
BuiltinScalarFunction::Flatten => {
Arc::new(|args| make_scalar_function_inner(array_expressions::flatten)(args))
}
BuiltinScalarFunction::ArrayPopFront => Arc::new(|args| {
make_scalar_function_inner(array_expressions::array_pop_front)(args)
}),
2 changes: 1 addition & 1 deletion datafusion/proto/proto/datafusion.proto
@@ -659,7 +659,7 @@ enum ScalarFunction {
ArrayRemoveAll = 109;
ArrayReplaceAll = 110;
Nanvl = 111;
Flatten = 112;
// 112 was Flatten
// 113 was IsNan
Iszero = 114;
// 115 was ArrayEmpty
3 changes: 0 additions & 3 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default.

4 changes: 1 addition & 3 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default.
