Skip to content

Commit

Permalink
Minor: Support nulls in array_replace, avoid a copy (#8054)
Browse files Browse the repository at this point in the history
* Minor: clean up array_replace

* null test

* remove println

* Fix doc test

* port test to sqllogictest

* Use not_distinct

* Apply suggestions from code review

Co-authored-by: jakevin <jakevingoo@gmail.com>

---------

Co-authored-by: jakevin <jakevingoo@gmail.com>
  • Loading branch information
alamb and jackwener authored Nov 9, 2023
1 parent 91a44c1 commit 93af440
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 63 deletions.
151 changes: 88 additions & 63 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1211,119 +1211,144 @@ array_removement_function!(
"Array_remove_all SQL function"
);

fn general_replace(args: &[ArrayRef], arr_n: Vec<i64>) -> Result<ArrayRef> {
let list_array = as_list_array(&args[0])?;
let from_array = &args[1];
let to_array = &args[2];

/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences
/// of `from_array[i]` with `to_array[i]`.
///
/// The type of each **element** in `list_array` must be the same as the type of
/// `from_array` and `to_array`. This function also handles nested arrays
/// ([`ListArray`] of [`ListArray`]s)
///
/// For example, when called to replace a list array (where each element is a
/// list of int32s), the second and third arguments are int32 arrays, and the
/// fourth argument is the number of occurrences to replace:
///
/// ```text
/// general_replace(
/// [1, 2, 3, 2], 2, 10, 1 ==> [1, 10, 3, 2] (only the first 2 is replaced)
/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced)
/// )
/// ```
fn general_replace(
list_array: &ListArray,
from_array: &ArrayRef,
to_array: &ArrayRef,
arr_n: Vec<i64>,
) -> Result<ArrayRef> {
// Build up the offsets for the final output array
let mut offsets: Vec<i32> = vec![0];
let data_type = list_array.value_type();
let mut values = new_empty_array(&data_type);
let mut new_values = vec![];

for (row_index, (arr, n)) in list_array.iter().zip(arr_n.iter()).enumerate() {
// n is the maximum number of elements to replace in this row
for (row_index, (list_array_row, n)) in
list_array.iter().zip(arr_n.iter()).enumerate()
{
let last_offset: i32 = offsets
.last()
.copied()
.ok_or_else(|| internal_datafusion_err!("offsets should not be empty"))?;
match arr {
Some(arr) => {
let indices = UInt32Array::from(vec![row_index as u32]);
let from_arr = arrow::compute::take(from_array, &indices, None)?;

let eq_array = match from_arr.data_type() {
// arrow_ord::cmp_eq does not support ListArray, so we need to compare it by loop
match list_array_row {
Some(list_array_row) => {
let indices = UInt32Array::from(vec![row_index as u32]);
let from_array_row = arrow::compute::take(from_array, &indices, None)?;
// Compute all positions in `list_array_row` (that is itself an
// array) that are equal to `from_array_row`
let eq_array = match from_array_row.data_type() {
// arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop
DataType::List(_) => {
let from_a = as_list_array(&from_arr)?.value(0);
let list_arr = as_list_array(&arr)?;
// compare each element of the from array
let from_array_row_inner =
as_list_array(&from_array_row)?.value(0);
let list_array_row_inner = as_list_array(&list_array_row)?;

let mut bool_values = vec![];
for arr in list_arr.iter() {
if let Some(a) = arr {
bool_values.push(Some(a.eq(&from_a)));
} else {
return internal_err!(
"Null value is not supported in array_replace"
);
}
}
BooleanArray::from(bool_values)
list_array_row_inner
.iter()
// compare element by element the current row of list_array
.map(|row| row.map(|row| row.eq(&from_array_row_inner)))
.collect::<BooleanArray>()
}
_ => {
let from_arr = Scalar::new(from_arr);
arrow_ord::cmp::eq(&arr, &from_arr)?
let from_arr = Scalar::new(from_array_row);
// use not_distinct so NULL = NULL
arrow_ord::cmp::not_distinct(&list_array_row, &from_arr)?
}
};

// Use MutableArrayData to build the replaced array
let original_data = list_array_row.to_data();
let to_data = to_array.to_data();
let capacity = Capacities::Array(original_data.len() + to_data.len());

// First array is the original array, second array is the element to replace with.
let arrays = vec![arr, to_array.clone()];
let arrays_data = arrays
.iter()
.map(|a| a.to_data())
.collect::<Vec<ArrayData>>();
let arrays_data = arrays_data.iter().collect::<Vec<&ArrayData>>();

let arrays = arrays
.iter()
.map(|arr| arr.as_ref())
.collect::<Vec<&dyn Array>>();
let capacity = Capacities::Array(arrays.iter().map(|a| a.len()).sum());

let mut mutable =
MutableArrayData::with_capacities(arrays_data, false, capacity);
let mut mutable = MutableArrayData::with_capacities(
vec![&original_data, &to_data],
false,
capacity,
);
let original_idx = 0;
let replace_idx = 1;

let mut counter = 0;
for (i, to_replace) in eq_array.iter().enumerate() {
if let Some(to_replace) = to_replace {
if to_replace {
mutable.extend(1, row_index, row_index + 1);
counter += 1;
if counter == *n {
// extend the rest of the array
mutable.extend(0, i + 1, eq_array.len());
break;
}
} else {
mutable.extend(0, i, i + 1);
if let Some(true) = to_replace {
// a match: copy the replacement value for this row from `to_array`
mutable.extend(replace_idx, row_index, row_index + 1);
counter += 1;
// NOTE(review): `counter` is incremented before this comparison, so when
// `*n <= 0` the equality is never reached and every match is replaced —
// confirm this is the intended behavior for non-positive counts.
if counter == *n {
// copy the rest of the original row unchanged, including any matches past n
mutable.extend(original_idx, i + 1, eq_array.len());
break;
}
} else {
return internal_err!("eq_array should not contain None");
// copy original data for false / null matches
mutable.extend(original_idx, i, i + 1);
}
}

let data = mutable.freeze();
let replaced_array = arrow_array::make_array(data);

let v = arrow::compute::concat(&[&values, &replaced_array])?;
values = v;
offsets.push(last_offset + replaced_array.len() as i32);
new_values.push(replaced_array);
}
None => {
// Null list row: repeat the last offset so the row contributes no values
offsets.push(last_offset);
}
}
}

let values = if new_values.is_empty() {
new_empty_array(&data_type)
} else {
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
arrow::compute::concat(&new_values)?
};

Ok(Arc::new(ListArray::try_new(
Arc::new(Field::new("item", data_type, true)),
OffsetBuffer::new(offsets.into()),
values,
None,
list_array.nulls().cloned(),
)?))
}

/// Replaces at most one occurrence of `args[1]` with `args[2]` in each list
/// row of `args[0]`, by delegating to `general_replace` with a per-row limit of 1.
pub fn array_replace(args: &[ArrayRef]) -> Result<ArrayRef> {
general_replace(args, vec![1; args[0].len()])
// replace at most one occurrence in each row
let arr_n = vec![1; args[0].len()];
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
}

/// Replaces up to `args[3][i]` occurrences of `args[1]` with `args[2]` in each
/// list row `i` of `args[0]`, by delegating to `general_replace`.
pub fn array_replace_n(args: &[ArrayRef]) -> Result<ArrayRef> {
let arr = as_int64_array(&args[3])?;
let arr_n = arr.values().to_vec();
general_replace(args, arr_n)
// replace the specified number of occurrences
// NOTE(review): `values()` reads the raw value buffer, so a NULL count in
// `args[3]` is presumably not treated specially here — confirm intended.
let arr_n = as_int64_array(&args[3])?.values().to_vec();
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
}

/// Replaces every occurrence of `args[1]` with `args[2]` in each list row of
/// `args[0]`, by delegating to `general_replace` with a per-row limit of `i64::MAX`.
pub fn array_replace_all(args: &[ArrayRef]) -> Result<ArrayRef> {
general_replace(args, vec![i64::MAX; args[0].len()])
// replace all occurrences (up to "i64::MAX")
let arr_n = vec![i64::MAX; args[0].len()];
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
}

macro_rules! to_string {
Expand Down
31 changes: 31 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,37 @@ select array_replace_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12
[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 12, 13], [22, 23, 24], [11, 12, 13], [11, 12, 13]]
[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]]

# array_replace with null handling

statement ok
create table t as values
(make_array(3, 1, NULL, 3), 3, 4, 2),
(make_array(3, 1, NULL, 3), NULL, 5, 2),
(NULL, 3, 2, 1),
(make_array(3, 1, 3), 3, NULL, 1)
;


# ([3, 1, NULL, 3], 3, 4, 2) => [4, 1, NULL, 4] NULL not matched
# ([3, 1, NULL, 3], NULL, 5, 2) => [3, 1, 5, 3] NULL is replaced with 5
# (NULL, 3, 2, 1) => NULL
# ([3, 1, 3], 3, NULL, 1) => [NULL, 1, 3]

query ?III?
select column1, column2, column3, column4, array_replace_n(column1, column2, column3, column4) from t;
----
[3, 1, , 3] 3 4 2 [4, 1, , 4]
[3, 1, , 3] NULL 5 2 [3, 1, 5, 3]
NULL 3 2 1 NULL
[3, 1, 3] 3 NULL 1 [, 1, 3]



statement ok
drop table t;



## array_to_string (aliases: `list_to_string`, `array_join`, `list_join`)

# array_to_string scalar function #1
Expand Down

0 comments on commit 93af440

Please sign in to comment.