-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Minor: Support nulls
in array_replace
, avoid a copy
#8054
Changes from 3 commits
ef69a13
05bdf1d
812bd0b
0e6c5c8
c6dabcf
acbffa0
cafc404
58a6ab9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1222,119 +1222,144 @@ array_removement_function!( | |
"Array_remove_all SQL function" | ||
); | ||
|
||
fn general_replace(args: &[ArrayRef], arr_n: Vec<i64>) -> Result<ArrayRef> { | ||
let list_array = as_list_array(&args[0])?; | ||
let from_array = &args[1]; | ||
let to_array = &args[2]; | ||
|
||
/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurences | ||
/// of `from_array[i]`, `to_array[i]`. | ||
/// | ||
/// The type of each **element** in `list_array` must be the same as the type of | ||
/// `from_array` and `to_array`. This function also handles nested arrays | ||
/// ([`ListArray`] of [`ListArray`]s) | ||
/// | ||
/// For example, whn called to replace a list array (where each element is a | ||
alamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/// list of int32s, the second and third argument are int32 arrays, and the | ||
/// fourth argument is the number of occurrences to replace | ||
/// | ||
/// ``` | ||
/// general_replace( | ||
/// [1, 2, 3, 2], 2, 10, 1 ==> [1, 10, 3, 2] (only the first 2 is replaced) | ||
/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced) | ||
/// ) | ||
/// ``` | ||
/// | ||
fn general_replace( | ||
list_array: &ListArray, | ||
from_array: &ArrayRef, | ||
to_array: &ArrayRef, | ||
arr_n: Vec<i64>, | ||
) -> Result<ArrayRef> { | ||
// Build up the offsets for the final output array | ||
let mut offsets: Vec<i32> = vec![0]; | ||
let data_type = list_array.value_type(); | ||
let mut values = new_empty_array(&data_type); | ||
let mut new_values = vec![]; | ||
|
||
for (row_index, (arr, n)) in list_array.iter().zip(arr_n.iter()).enumerate() { | ||
// n is the number of elements to replace in this row | ||
for (row_index, (list_array_row, n)) in | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I rename these variables to better explain what they are (rather than |
||
list_array.iter().zip(arr_n.iter()).enumerate() | ||
{ | ||
let last_offset: i32 = offsets | ||
.last() | ||
.copied() | ||
.ok_or_else(|| internal_datafusion_err!("offsets should not be empty"))?; | ||
match arr { | ||
Some(arr) => { | ||
let indices = UInt32Array::from(vec![row_index as u32]); | ||
let from_arr = arrow::compute::take(from_array, &indices, None)?; | ||
|
||
let eq_array = match from_arr.data_type() { | ||
// arrow_ord::cmp_eq does not support ListArray, so we need to compare it by loop | ||
match list_array_row { | ||
Some(list_array_row) => { | ||
let indices = UInt32Array::from(vec![row_index as u32]); | ||
let from_array_row = arrow::compute::take(from_array, &indices, None)?; | ||
// Compute all positions in list_row_array (that is itself an | ||
// array) that are equal to `from_array_row` | ||
let eq_array = match from_array_row.data_type() { | ||
// arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop | ||
DataType::List(_) => { | ||
let from_a = as_list_array(&from_arr)?.value(0); | ||
let list_arr = as_list_array(&arr)?; | ||
// compare each element of the from array | ||
let from_array_row_inner = | ||
as_list_array(&from_array_row)?.value(0); | ||
let list_array_row_inner = as_list_array(&list_array_row)?; | ||
|
||
let mut bool_values = vec![]; | ||
for arr in list_arr.iter() { | ||
if let Some(a) = arr { | ||
bool_values.push(Some(a.eq(&from_a))); | ||
} else { | ||
return internal_err!( | ||
"Null value is not supported in array_replace" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know why nulls aren't supported -- all of the kernels support them well, and the code is more concise without the special case |
||
); | ||
} | ||
} | ||
BooleanArray::from(bool_values) | ||
list_array_row_inner | ||
.iter() | ||
// compare element by element the current row of list_array | ||
alamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
.map(|row| row.map(|row| row.eq(&from_array_row_inner))) | ||
.collect::<BooleanArray>() | ||
} | ||
_ => { | ||
let from_arr = Scalar::new(from_arr); | ||
arrow_ord::cmp::eq(&arr, &from_arr)? | ||
let from_arr = Scalar::new(from_array_row); | ||
arrow_ord::cmp::eq(&list_array_row, &from_arr)? | ||
} | ||
}; | ||
|
||
// Use MutableArrayData to build the replaced array | ||
let original_data = list_array_row.to_data(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rather than using a new array with two indexes, I found naming them helped a lot |
||
let to_data = to_array.to_data(); | ||
let capacity = Capacities::Array(original_data.len() + to_data.len()); | ||
|
||
// First array is the original array, second array is the element to replace with. | ||
let arrays = vec![arr, to_array.clone()]; | ||
let arrays_data = arrays | ||
.iter() | ||
.map(|a| a.to_data()) | ||
.collect::<Vec<ArrayData>>(); | ||
let arrays_data = arrays_data.iter().collect::<Vec<&ArrayData>>(); | ||
|
||
let arrays = arrays | ||
.iter() | ||
.map(|arr| arr.as_ref()) | ||
.collect::<Vec<&dyn Array>>(); | ||
let capacity = Capacities::Array(arrays.iter().map(|a| a.len()).sum()); | ||
|
||
let mut mutable = | ||
MutableArrayData::with_capacities(arrays_data, false, capacity); | ||
let mut mutable = MutableArrayData::with_capacities( | ||
vec![&original_data, &to_data], | ||
false, | ||
capacity, | ||
); | ||
let original_idx = 0; | ||
let replace_idx = 1; | ||
|
||
let mut counter = 0; | ||
for (i, to_replace) in eq_array.iter().enumerate() { | ||
if let Some(to_replace) = to_replace { | ||
if to_replace { | ||
mutable.extend(1, row_index, row_index + 1); | ||
counter += 1; | ||
if counter == *n { | ||
// extend the rest of the array | ||
mutable.extend(0, i + 1, eq_array.len()); | ||
break; | ||
} | ||
} else { | ||
mutable.extend(0, i, i + 1); | ||
if let Some(true) = to_replace { | ||
mutable.extend(replace_idx, row_index, row_index + 1); | ||
counter += 1; | ||
if counter == *n { | ||
// copy original data for any matches past n | ||
mutable.extend(original_idx, i + 1, eq_array.len()); | ||
break; | ||
} | ||
} else { | ||
return internal_err!("eq_array should not contain None"); | ||
// copy original data for false / null matches | ||
mutable.extend(original_idx, i, i + 1); | ||
} | ||
} | ||
|
||
let data = mutable.freeze(); | ||
let replaced_array = arrow_array::make_array(data); | ||
|
||
let v = arrow::compute::concat(&[&values, &replaced_array])?; | ||
values = v; | ||
offsets.push(last_offset + replaced_array.len() as i32); | ||
new_values.push(replaced_array); | ||
} | ||
None => { | ||
// Null element results in a null row (no new offsets) | ||
offsets.push(last_offset); | ||
} | ||
} | ||
} | ||
|
||
let values = if new_values.is_empty() { | ||
new_empty_array(&data_type) | ||
} else { | ||
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); | ||
arrow::compute::concat(&new_values)? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The old logic called I think it is probably possible to avoid calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We used to replace MutableArrayData with ListArray, as it is more straightforward. However, perhaps we can also construct ListArray using MutableArrayData and still maintain readability. Particularly in this case, doing so can enhance performance. We should certainly consider building with MutableArrayData! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, what I was getting at is that the contents of |
||
}; | ||
|
||
Ok(Arc::new(ListArray::try_new( | ||
Arc::new(Field::new("item", data_type, true)), | ||
OffsetBuffer::new(offsets.into()), | ||
values, | ||
None, | ||
list_array.nulls().cloned(), | ||
)?)) | ||
} | ||
|
||
pub fn array_replace(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
general_replace(args, vec![1; args[0].len()]) | ||
// replace at most one occurence for each element | ||
let arr_n = vec![1; args[0].len()]; | ||
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) | ||
} | ||
|
||
pub fn array_replace_n(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
let arr = as_int64_array(&args[3])?; | ||
let arr_n = arr.values().to_vec(); | ||
general_replace(args, arr_n) | ||
// replace the specified number of occurences | ||
let arr_n = as_int64_array(&args[3])?.values().to_vec(); | ||
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) | ||
} | ||
|
||
pub fn array_replace_all(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
general_replace(args, vec![i64::MAX; args[0].len()]) | ||
// replace all occurences (up to "i64::MAX") | ||
let arr_n = vec![i64::MAX; args[0].len()]; | ||
general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n) | ||
} | ||
|
||
macro_rules! to_string { | ||
|
@@ -1822,6 +1847,8 @@ pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef | |
mod tests { | ||
use super::*; | ||
use arrow::datatypes::Int64Type; | ||
use arrow::util::pretty::pretty_format_columns; | ||
use arrow_array::types::Int32Type; | ||
use datafusion_common::cast::as_uint64_array; | ||
|
||
#[test] | ||
|
@@ -2763,6 +2790,52 @@ mod tests { | |
); | ||
} | ||
|
||
#[test] | ||
fn test_array_replace_with_null() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the null cases were not covered, I added new coverage. Maybe this style of test could help as we clean up the other aray functions There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would prefer having test in slt file if possible, we can cleanup test in the follow on PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. slt is a good idea -- I will move the tests There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in c6dabcf |
||
// ([3, 1, NULL, 3], 3, 4, 2) => [4, 1, NULL, 4] NULL not matched | ||
// ([3, 1, NULL, 3], NULL, 5, 2) => [3, 1, NULL, 3] NULL not replaced (not eq) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why null is not replaced with 5 => [3, 1, 5, 3]? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that arrow_ord::cmp::eq([3,1,null,3], null) return [null,null,null,null] which I expect to see [false,false,true,false] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe it's reasonable for arrow-rs to handle null comparisons There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is because in SQL, If we want the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I would prefer having There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, sounds good. I will update this PR There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in acbffa0 |
||
// ([NULL], 3, 2, 1) => NULL | ||
// ([3, 1, 3], 3, NULL, 1) => [NULL, 1 3] | ||
let list_array = ListArray::from_iter_primitive::<Int32Type, _, _>([ | ||
Some(vec![Some(3), Some(1), None, Some(3)]), | ||
Some(vec![Some(3), Some(1), None, Some(3)]), | ||
None, | ||
Some(vec![Some(3), Some(1), Some(3)]), | ||
]); | ||
|
||
let from = Int32Array::from(vec![Some(3), None, Some(3), Some(3)]); | ||
let to = Int32Array::from(vec![Some(4), Some(5), Some(2), None]); | ||
let n = vec![2, 2, 1, 1]; | ||
|
||
let expected = ListArray::from_iter_primitive::<Int32Type, _, _>([ | ||
Some(vec![Some(4), Some(1), None, Some(4)]), | ||
Some(vec![Some(3), Some(1), None, Some(3)]), | ||
None, | ||
Some(vec![None, Some(1), Some(3)]), | ||
]); | ||
|
||
let list_array = Arc::new(list_array) as ArrayRef; | ||
let from = Arc::new(from) as ArrayRef; | ||
let to = Arc::new(to) as ArrayRef; | ||
let expected = Arc::new(expected) as ArrayRef; | ||
|
||
let replaced = general_replace( | ||
as_list_array(&list_array).unwrap(), | ||
&from as &ArrayRef, | ||
&to, | ||
n, | ||
) | ||
.unwrap(); | ||
assert_eq!( | ||
&replaced, | ||
&expected, | ||
"\n\n{}\n\n{}\n\n{}", | ||
pretty_format_columns("input", &[Arc::new(list_array) as _]).unwrap(), | ||
pretty_format_columns("replaced", &[replaced.clone()]).unwrap(), | ||
pretty_format_columns("expected", &[expected.clone()]).unwrap(), | ||
); | ||
} | ||
|
||
#[test] | ||
fn test_nested_array_replace() { | ||
// array_replace( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, having a clear description of what the function does, especially when it is mind bending like
general_replace
helps a lot. I also find that writing such a description often helps improve the code as well as results in better namingFor example, I struggled to explain how this function worked when it took
&[ArrayRef]
because then the explanations were in terms fofarg[0]
,arg[1]
, andarg[2]
. Giving the parameters names made things much clearer