diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 64550aabf424..deb4372baa32 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1211,119 +1211,144 @@ array_removement_function!( "Array_remove_all SQL function" ); -fn general_replace(args: &[ArrayRef], arr_n: Vec<i64>) -> Result<ArrayRef> { - let list_array = as_list_array(&args[0])?; - let from_array = &args[1]; - let to_array = &args[2]; - +/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences +/// of `from_array[i]` with `to_array[i]`. +/// +/// The type of each **element** in `list_array` must be the same as the type of +/// `from_array` and `to_array`. This function also handles nested arrays +/// ([`ListArray`] of [`ListArray`]s) +/// +/// For example, when called to replace a list array (where each element is a +/// list of int32s), the second and third arguments are int32 arrays, and the +/// fourth argument is the number of occurrences to replace +/// +/// ```text +/// general_replace( +/// [1, 2, 3, 2], 2, 10, 1 ==> [1, 10, 3, 2] (only the first 2 is replaced) +/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced) +/// ) +/// ``` +fn general_replace( + list_array: &ListArray, + from_array: &ArrayRef, + to_array: &ArrayRef, + arr_n: Vec<i64>, +) -> Result<ArrayRef> { + // Build up the offsets for the final output array let mut offsets: Vec<i32> = vec![0]; let data_type = list_array.value_type(); - let mut values = new_empty_array(&data_type); + let mut new_values = vec![]; - for (row_index, (arr, n)) in list_array.iter().zip(arr_n.iter()).enumerate() { + // n is the number of elements to replace in this row + for (row_index, (list_array_row, n)) in + list_array.iter().zip(arr_n.iter()).enumerate() + { let last_offset: i32 = offsets .last() .copied() .ok_or_else(|| internal_datafusion_err!("offsets should not be empty"))?; - match arr { - Some(arr) => { - let indices = 
UInt32Array::from(vec![row_index as u32]); - let from_arr = arrow::compute::take(from_array, &indices, None)?; - let eq_array = match from_arr.data_type() { - // arrow_ord::cmp_eq does not support ListArray, so we need to compare it by loop + match list_array_row { + Some(list_array_row) => { + let indices = UInt32Array::from(vec![row_index as u32]); + let from_array_row = arrow::compute::take(from_array, &indices, None)?; + // Compute all positions in `list_array_row` (that is itself an + // array) that are equal to `from_array_row` + let eq_array = match from_array_row.data_type() { + // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop DataType::List(_) => { - let from_a = as_list_array(&from_arr)?.value(0); - let list_arr = as_list_array(&arr)?; + // compare each element of the from array + let from_array_row_inner = + as_list_array(&from_array_row)?.value(0); + let list_array_row_inner = as_list_array(&list_array_row)?; - let mut bool_values = vec![]; - for arr in list_arr.iter() { - if let Some(a) = arr { - bool_values.push(Some(a.eq(&from_a))); - } else { - return internal_err!( - "Null value is not supported in array_replace" - ); - } - } - BooleanArray::from(bool_values) + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| row.map(|row| row.eq(&from_array_row_inner))) + .collect::<BooleanArray>() } _ => { - let from_arr = Scalar::new(from_arr); - arrow_ord::cmp::eq(&arr, &from_arr)? + let from_arr = Scalar::new(from_array_row); + // use not_distinct so NULL = NULL + arrow_ord::cmp::not_distinct(&list_array_row, &from_arr)? } }; // Use MutableArrayData to build the replaced array + let original_data = list_array_row.to_data(); + let to_data = to_array.to_data(); + let capacity = Capacities::Array(original_data.len() + to_data.len()); + // First array is the original array, second array is the element to replace with. 
- let arrays = vec![arr, to_array.clone()]; - let arrays_data = arrays - .iter() - .map(|a| a.to_data()) - .collect::<Vec<ArrayData>>(); - let arrays_data = arrays_data.iter().collect::<Vec<&ArrayData>>(); - - let arrays = arrays - .iter() - .map(|arr| arr.as_ref()) - .collect::<Vec<&dyn Array>>(); - let capacity = Capacities::Array(arrays.iter().map(|a| a.len()).sum()); - - let mut mutable = - MutableArrayData::with_capacities(arrays_data, false, capacity); + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &to_data], + false, + capacity, + ); + let original_idx = 0; + let replace_idx = 1; let mut counter = 0; for (i, to_replace) in eq_array.iter().enumerate() { - if let Some(to_replace) = to_replace { - if to_replace { - mutable.extend(1, row_index, row_index + 1); - counter += 1; - if counter == *n { - // extend the rest of the array - mutable.extend(0, i + 1, eq_array.len()); - break; - } - } else { - mutable.extend(0, i, i + 1); + if let Some(true) = to_replace { + mutable.extend(replace_idx, row_index, row_index + 1); + counter += 1; + if counter == *n { + // copy original data for any matches past n + mutable.extend(original_idx, i + 1, eq_array.len()); + break; } } else { - return internal_err!("eq_array should not contain None"); + // copy original data for false / null matches + mutable.extend(original_idx, i, i + 1); } } let data = mutable.freeze(); let replaced_array = arrow_array::make_array(data); - let v = arrow::compute::concat(&[&values, &replaced_array])?; - values = v; offsets.push(last_offset + replaced_array.len() as i32); + new_values.push(replaced_array); } None => { + // Null row stays null: repeat the last offset so the row adds no values offsets.push(last_offset); } } } + let values = if new_values.is_empty() { + new_empty_array(&data_type) + } else { + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + arrow::compute::concat(&new_values)? 
+ }; + Ok(Arc::new(ListArray::try_new( Arc::new(Field::new("item", data_type, true)), OffsetBuffer::new(offsets.into()), values, - None, + list_array.nulls().cloned(), )?)) } pub fn array_replace(args: &[ArrayRef]) -> Result<ArrayRef> { - general_replace(args, vec![1; args[0].len()]) + // replace at most one occurrence for each element + let arr_n = vec![1; args[0].len()]; + general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) } pub fn array_replace_n(args: &[ArrayRef]) -> Result<ArrayRef> { - let arr = as_int64_array(&args[3])?; - let arr_n = arr.values().to_vec(); - general_replace(args, arr_n) + // replace the specified number of occurrences + let arr_n = as_int64_array(&args[3])?.values().to_vec(); + general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) } pub fn array_replace_all(args: &[ArrayRef]) -> Result<ArrayRef> { - general_replace(args, vec![i64::MAX; args[0].len()]) + // replace all occurrences (up to "i64::MAX") + let arr_n = vec![i64::MAX; args[0].len()]; + general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) } macro_rules! 
to_string { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 85218efb5e14..c57369c167f4 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1720,6 +1720,37 @@ select array_replace_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12 [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [11, 12, 13], [11, 12, 13]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]] +# array_replace with null handling + +statement ok +create table t as values + (make_array(3, 1, NULL, 3), 3, 4, 2), + (make_array(3, 1, NULL, 3), NULL, 5, 2), + (NULL, 3, 2, 1), + (make_array(3, 1, 3), 3, NULL, 1) +; + + +# ([3, 1, NULL, 3], 3, 4, 2) => [4, 1, NULL, 4] NULL not matched +# ([3, 1, NULL, 3], NULL, 5, 2) => [3, 1, 5, 3] NULL is replaced with 5 +# (NULL, 3, 2, 1) => NULL +# ([3, 1, 3], 3, NULL, 1) => [NULL, 1, 3] + +query ?III? 
+select column1, column2, column3, column4, array_replace_n(column1, column2, column3, column4) from t; +---- +[3, 1, , 3] 3 4 2 [4, 1, , 4] +[3, 1, , 3] NULL 5 2 [3, 1, 5, 3] +NULL 3 2 1 NULL +[3, 1, 3] 3 NULL 1 [, 1, 3] + + + +statement ok +drop table t; + + + ## array_to_string (aliases: `list_to_string`, `array_join`, `list_join`) # array_to_string scalar function #1