diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index ded606c3b705..c5e8b0e75c83 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -67,70 +67,6 @@ macro_rules! downcast_vec { }}; } -macro_rules! new_builder { - (BooleanBuilder, $len:expr) => { - BooleanBuilder::with_capacity($len) - }; - (StringBuilder, $len:expr) => { - StringBuilder::new() - }; - (LargeStringBuilder, $len:expr) => { - LargeStringBuilder::new() - }; - ($el:ident, $len:expr) => {{ - <$el>::with_capacity($len) - }}; -} - -/// Combines multiple arrays into a single ListArray -/// -/// $ARGS: slice of arrays, each with $ARRAY_TYPE -/// $ARRAY_TYPE: the type of the list elements -/// $BUILDER_TYPE: the type of ArrayBuilder for the list elements -/// -/// Returns: a ListArray where the elements each have the same type as -/// $ARRAY_TYPE and each element have a length of $ARGS.len() -macro_rules! array { - ($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{ - let builder = new_builder!($BUILDER_TYPE, $ARGS[0].len()); - let mut builder = - ListBuilder::<$BUILDER_TYPE>::with_capacity(builder, $ARGS.len()); - - let num_rows = $ARGS[0].len(); - assert!( - $ARGS.iter().all(|a| a.len() == num_rows), - "all arguments must have the same number of rows" - ); - - // for each entry in the array - for index in 0..num_rows { - // for each column - for arg in $ARGS { - match arg.as_any().downcast_ref::<$ARRAY_TYPE>() { - // Copy the source array value into the target ListArray - Some(arr) => { - if arr.is_valid(index) { - builder.values().append_value(arr.value(index)); - } else { - builder.values().append_null(); - } - } - None => match arg.as_any().downcast_ref::() { - Some(arr) => { - for _ in 0..arr.len() { - builder.values().append_null(); - } - } - None => return internal_err!("failed to downcast"), - }, - } - } - builder.append(true); - } - Arc::new(builder.finish()) - }}; -} - /// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. /// /// # Arguments @@ -389,88 +325,46 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { return plan_err!("Array requires at least one argument"); } - let res = match data_type { - DataType::List(..) => { - let row_count = args[0].len(); - let column_count = args.len(); - let mut list_arrays = vec![]; - let mut list_array_lengths = vec![]; - let mut list_valid = BooleanBufferBuilder::new(row_count); - // Construct ListArray per row - for index in 0..row_count { - let mut arrays = vec![]; - let mut array_lengths = vec![]; - let mut valid = BooleanBufferBuilder::new(column_count); - for arg in args { - if arg.as_any().downcast_ref::().is_some() { - array_lengths.push(0); - valid.append(false); - } else { - let list_arr = as_list_array(arg)?; - let arr = list_arr.value(index); - array_lengths.push(arr.len()); - arrays.push(arr); - valid.append(true); - } - } - if arrays.is_empty() { - list_valid.append(false); - list_array_lengths.push(0); - } else { - let buffer = valid.finish(); - // Assume all list arrays have the same data type - let data_type = arrays[0].data_type(); - let field = Arc::new(Field::new("item", data_type.to_owned(), true)); - let elements = arrays.iter().map(|x| x.as_ref()).collect::>(); - let values = compute::concat(elements.as_slice())?; - let list_arr = ListArray::new( - field, - OffsetBuffer::from_lengths(array_lengths), - values, - Some(NullBuffer::new(buffer)), - ); - list_valid.append(true); - list_array_lengths.push(list_arr.len()); - list_arrays.push(list_arr); - } + let mut data = vec![]; + let mut total_len = 0; + for arg in args { + let arg_data = if arg.as_any().is::() { + ArrayData::new_empty(&data_type) + } else { + arg.to_data() + }; + total_len += arg_data.len(); + data.push(arg_data); + } + let mut offsets = Vec::with_capacity(total_len); + offsets.push(0); + + let capacity = Capacities::Array(total_len); + let data_ref = data.iter().collect::>(); + let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity); + + let num_rows = args[0].len(); + for row_idx in 0..num_rows { + for (arr_idx, arg) in args.iter().enumerate() { + if !arg.as_any().is::() + && !arg.is_null(row_idx) + && arg.is_valid(row_idx) + { + mutable.extend(arr_idx, row_idx, row_idx + 1); + } else { + mutable.extend_nulls(1); } - // Construct ListArray for all rows - let buffer = list_valid.finish(); - // Assume all list arrays have the same data type - let data_type = list_arrays[0].data_type(); - let field = Arc::new(Field::new("item", data_type.to_owned(), true)); - let elements = list_arrays - .iter() - .map(|x| x as &dyn Array) - .collect::>(); - let values = compute::concat(elements.as_slice())?; - let list_arr = ListArray::new( - field, - OffsetBuffer::from_lengths(list_array_lengths), - values, - Some(NullBuffer::new(buffer)), - ); - Arc::new(list_arr) - } - DataType::Utf8 => array!(args, StringArray, StringBuilder), - DataType::LargeUtf8 => array!(args, LargeStringArray, LargeStringBuilder), - DataType::Boolean => array!(args, BooleanArray, BooleanBuilder), - DataType::Float32 => array!(args, Float32Array, Float32Builder), - DataType::Float64 => array!(args, Float64Array, Float64Builder), - DataType::Int8 => array!(args, Int8Array, Int8Builder), - DataType::Int16 => array!(args, Int16Array, Int16Builder), - DataType::Int32 => array!(args, Int32Array, Int32Builder), - DataType::Int64 => array!(args, Int64Array, Int64Builder), - DataType::UInt8 => array!(args, UInt8Array, UInt8Builder), - DataType::UInt16 => array!(args, UInt16Array, UInt16Builder), - DataType::UInt32 => array!(args, UInt32Array, UInt32Builder), - DataType::UInt64 => array!(args, UInt64Array, UInt64Builder), - data_type => { - return not_impl_err!("Array is not implemented for type '{data_type:?}'.") } - }; + offsets.push(mutable.len() as i32); + } - Ok(res) + let data = mutable.freeze(); + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) } /// `make_array` SQL function