-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace macro with function for array_position
and array_positions
#8170
Changes from all commits
6861b94
ef38e32
f8e40db
edefb73
8274afe
e46a627
340613c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -131,6 +131,78 @@ macro_rules! array { | |
}}; | ||
} | ||
|
||
/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared. | ||
/// | ||
/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared. | ||
/// | ||
/// * `row_index` - The index of the row in `element_array` to compare against; callers pass the index of the row that `list_array_row` was taken from. | |
/// | ||
/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality. | ||
/// | ||
/// # Returns | ||
/// | ||
/// Returns a `Result<BooleanArray>` representing the comparison results. The result may contain an error if there are issues with the computation. | ||
/// | ||
/// # Example | ||
/// | ||
/// ```text | ||
/// compare_element_to_list( | ||
/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false] | ||
/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false] | ||
/// | ||
/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false] | ||
/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false] | ||
/// ) | ||
/// ``` | ||
fn compare_element_to_list( | ||
list_array_row: &dyn Array, | ||
element_array: &dyn Array, | ||
row_index: usize, | ||
eq: bool, | ||
) -> Result<BooleanArray> { | ||
let indices = UInt32Array::from(vec![row_index as u32]); | ||
let element_array_row = arrow::compute::take(element_array, &indices, None)?; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This will always be a single row Array as indices have a single value 🤔 I wonder if you could call […]. I think you could instead call […]. (The inline code references in this review comment were lost in extraction.) |
||
// Compare every element of `list_array_row` (which is itself an | |
// array) with `element_array_row`, for equality or inequality depending on `eq` | |
let res = match element_array_row.data_type() { | ||
// arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop | ||
DataType::List(_) => { | ||
// compare each entry of the list row with the inner array of the element row | |
let element_array_row_inner = as_list_array(&element_array_row)?.value(0); | ||
let list_array_row_inner = as_list_array(list_array_row)?; | ||
|
||
list_array_row_inner | ||
.iter() | ||
// compare element by element the current row of list_array | ||
.map(|row| { | ||
row.map(|row| { | ||
if eq { | ||
row.eq(&element_array_row_inner) | ||
} else { | ||
row.ne(&element_array_row_inner) | ||
} | ||
}) | ||
}) | ||
.collect::<BooleanArray>() | ||
} | ||
_ => { | ||
let element_arr = Scalar::new(element_array_row); | ||
// use not_distinct so we can compare NULL | ||
if eq { | ||
arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? | ||
} else { | ||
arrow_ord::cmp::distinct(&list_array_row, &element_arr)? | ||
} | ||
} | ||
}; | ||
|
||
Ok(res) | ||
} | ||
|
||
/// Returns the length of a concrete array dimension | ||
fn compute_array_length( | ||
arr: Option<ArrayRef>, | ||
|
@@ -953,114 +1025,68 @@ fn general_list_repeat( | |
)?)) | ||
} | ||
|
||
macro_rules! position { | ||
($ARRAY:expr, $ELEMENT:expr, $INDEX:expr, $ARRAY_TYPE:ident) => {{ | ||
let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); | ||
$ARRAY | ||
.iter() | ||
.zip(element.iter()) | ||
.zip($INDEX.iter()) | ||
.map(|((arr, el), i)| { | ||
let index = match i { | ||
Some(i) => { | ||
if i <= 0 { | ||
0 | ||
} else { | ||
i - 1 | ||
} | ||
} | ||
None => return exec_err!("initial position must not be null"), | ||
}; | ||
|
||
match arr { | ||
Some(arr) => { | ||
let child_array = downcast_arg!(arr, $ARRAY_TYPE); | ||
|
||
match child_array | ||
.iter() | ||
.skip(index as usize) | ||
.position(|x| x == el) | ||
{ | ||
Some(value) => Ok(Some(value as u64 + index as u64 + 1u64)), | ||
None => Ok(None), | ||
} | ||
} | ||
None => Ok(None), | ||
} | ||
}) | ||
.collect::<Result<UInt64Array>>()? | ||
}}; | ||
} | ||
|
||
/// Array_position SQL function | ||
pub fn array_position(args: &[ArrayRef]) -> Result<ArrayRef> { | ||
let arr = as_list_array(&args[0])?; | ||
let element = &args[1]; | ||
let list_array = as_list_array(&args[0])?; | ||
let element_array = &args[1]; | ||
|
||
check_datatypes("array_position", &[list_array.values(), element_array])?; | ||
|
||
let index = if args.len() == 3 { | ||
as_int64_array(&args[2])?.clone() | ||
let arr_from = if args.len() == 3 { | ||
as_int64_array(&args[2])? | ||
.values() | ||
.to_vec() | ||
.iter() | ||
.map(|&x| x - 1) | ||
.collect::<Vec<_>>() | ||
} else { | ||
Int64Array::from_value(0, arr.len()) | ||
vec![0; list_array.len()] | ||
}; | ||
|
||
check_datatypes("array_position", &[arr.values(), element])?; | ||
macro_rules! array_function { | ||
($ARRAY_TYPE:ident) => { | ||
position!(arr, element, index, $ARRAY_TYPE) | ||
}; | ||
// if `start_from` index is out of bounds, return error | ||
for (arr, &from) in list_array.iter().zip(arr_from.iter()) { | ||
if let Some(arr) = arr { | ||
if from < 0 || from as usize >= arr.len() { | ||
return internal_err!("start_from index out of bounds"); | ||
} | ||
} else { | ||
// We will get null if we got null in the array, so we don't need to check | ||
} | ||
} | ||
let res = call_array_function!(arr.value_type(), true); | ||
|
||
Ok(Arc::new(res)) | ||
general_position::<i32>(list_array, element_array, arr_from) | ||
} | ||
|
||
macro_rules! positions { | ||
($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ | ||
let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); | ||
let mut offsets: Vec<i32> = vec![0]; | ||
let mut values = | ||
downcast_arg!(new_empty_array(&DataType::UInt64), UInt64Array).clone(); | ||
for comp in $ARRAY | ||
.iter() | ||
.zip(element.iter()) | ||
.map(|(arr, el)| match arr { | ||
Some(arr) => { | ||
let child_array = downcast_arg!(arr, $ARRAY_TYPE); | ||
let res = child_array | ||
.iter() | ||
.enumerate() | ||
.filter(|(_, x)| *x == el) | ||
.flat_map(|(i, _)| Some((i + 1) as u64)) | ||
.collect::<UInt64Array>(); | ||
fn general_position<OffsetSize: OffsetSizeTrait>( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: if we only use i32 as OffsetSize, do we really need the Generics? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. no, but we will need it if we want to extend it for large list, I'm just lazy to remove it for now :) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice extension 👍 |
||
list_array: &GenericListArray<OffsetSize>, | ||
element_array: &ArrayRef, | ||
arr_from: Vec<i64>, // 0-indexed | ||
) -> Result<ArrayRef> { | ||
let mut data = Vec::with_capacity(list_array.len()); | ||
|
||
Ok(res) | ||
} | ||
None => Ok(downcast_arg!( | ||
new_empty_array(&DataType::UInt64), | ||
UInt64Array | ||
) | ||
.clone()), | ||
}) | ||
.collect::<Result<Vec<UInt64Array>>>()? | ||
{ | ||
let last_offset: i32 = offsets.last().copied().ok_or_else(|| { | ||
DataFusionError::Internal(format!("offsets should not be empty",)) | ||
})?; | ||
values = | ||
downcast_arg!(compute::concat(&[&values, &comp,])?.clone(), UInt64Array) | ||
.clone(); | ||
offsets.push(last_offset + comp.len() as i32); | ||
} | ||
for (row_index, (list_array_row, &from)) in | ||
list_array.iter().zip(arr_from.iter()).enumerate() | ||
{ | ||
let from = from as usize; | ||
|
||
let field = Arc::new(Field::new("item", DataType::UInt64, true)); | ||
if let Some(list_array_row) = list_array_row { | ||
let eq_array = | ||
compare_element_to_list(&list_array_row, element_array, row_index, true)?; | ||
|
||
Arc::new(ListArray::try_new( | ||
field, | ||
OffsetBuffer::new(offsets.into()), | ||
Arc::new(values), | ||
None, | ||
)?) | ||
}}; | ||
// Collect `true`s in 1-indexed positions | ||
let index = eq_array | ||
.iter() | ||
.skip(from) | ||
.position(|e| e == Some(true)) | ||
.map(|index| (from + index + 1) as u64); | ||
|
||
data.push(index); | ||
} else { | ||
data.push(None); | ||
} | ||
} | ||
|
||
Ok(Arc::new(UInt64Array::from(data))) | ||
} | ||
|
||
/// Array_positions SQL function | ||
|
@@ -1069,14 +1095,37 @@ pub fn array_positions(args: &[ArrayRef]) -> Result<ArrayRef> { | |
let element = &args[1]; | ||
|
||
check_datatypes("array_positions", &[arr.values(), element])?; | ||
macro_rules! array_function { | ||
($ARRAY_TYPE:ident) => { | ||
positions!(arr, element, $ARRAY_TYPE) | ||
}; | ||
|
||
general_positions::<i32>(arr, element) | ||
} | ||
|
||
fn general_positions<OffsetSize: OffsetSizeTrait>( | ||
list_array: &GenericListArray<OffsetSize>, | ||
element_array: &ArrayRef, | ||
) -> Result<ArrayRef> { | ||
let mut data = Vec::with_capacity(list_array.len()); | ||
|
||
for (row_index, list_array_row) in list_array.iter().enumerate() { | ||
if let Some(list_array_row) = list_array_row { | ||
let eq_array = | ||
compare_element_to_list(&list_array_row, element_array, row_index, true)?; | ||
|
||
// Collect `true`s in 1-indexed positions | ||
let indexes = eq_array | ||
.iter() | ||
.positions(|e| e == Some(true)) | ||
.map(|index| Some(index as u64 + 1)) | ||
.collect::<Vec<_>>(); | ||
|
||
data.push(Some(indexes)); | ||
} else { | ||
data.push(None); | ||
} | ||
} | ||
let res = call_array_function!(arr.value_type(), true); | ||
|
||
Ok(res) | ||
Ok(Arc::new( | ||
ListArray::from_iter_primitive::<UInt64Type, _, _>(data), | ||
)) | ||
} | ||
|
||
/// For each element of `list_array[i]`, removes up to `arr_n[i]` occurrences | |
|
@@ -1113,30 +1162,12 @@ fn general_remove<OffsetSize: OffsetSizeTrait>( | |
{ | ||
match list_array_row { | ||
Some(list_array_row) => { | ||
let indices = UInt32Array::from(vec![row_index as u32]); | ||
let element_array_row = | ||
arrow::compute::take(element_array, &indices, None)?; | ||
|
||
let eq_array = match element_array_row.data_type() { | ||
// arrow_ord::cmp::distinct does not support ListArray, so we need to compare it by loop | ||
DataType::List(_) => { | ||
// compare each element of the from array | ||
let element_array_row_inner = | ||
as_list_array(&element_array_row)?.value(0); | ||
let list_array_row_inner = as_list_array(&list_array_row)?; | ||
|
||
list_array_row_inner | ||
.iter() | ||
// compare element by element the current row of list_array | ||
.map(|row| row.map(|row| row.ne(&element_array_row_inner))) | ||
.collect::<BooleanArray>() | ||
} | ||
_ => { | ||
let from_arr = Scalar::new(element_array_row); | ||
// use distinct so Null = Null is false | ||
arrow_ord::cmp::distinct(&list_array_row, &from_arr)? | ||
} | ||
}; | ||
let eq_array = compare_element_to_list( | ||
&list_array_row, | ||
element_array, | ||
row_index, | ||
false, | ||
)?; | ||
|
||
// We need to keep at most first n elements as `false`, which represent the elements to remove. | ||
let eq_array = if eq_array.false_count() < *n as usize { | ||
|
@@ -1261,30 +1292,14 @@ fn general_replace( | |
|
||
match list_array_row { | ||
Some(list_array_row) => { | ||
let indices = UInt32Array::from(vec![row_index as u32]); | ||
let from_array_row = arrow::compute::take(from_array, &indices, None)?; | ||
// Compute all positions in list_row_array (that is itself an | ||
// array) that are equal to `from_array_row` | ||
let eq_array = match from_array_row.data_type() { | ||
// arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop | ||
DataType::List(_) => { | ||
// compare each element of the from array | ||
let from_array_row_inner = | ||
as_list_array(&from_array_row)?.value(0); | ||
let list_array_row_inner = as_list_array(&list_array_row)?; | ||
|
||
list_array_row_inner | ||
.iter() | ||
// compare element by element the current row of list_array | ||
.map(|row| row.map(|row| row.eq(&from_array_row_inner))) | ||
.collect::<BooleanArray>() | ||
} | ||
_ => { | ||
let from_arr = Scalar::new(from_array_row); | ||
// use not_distinct so NULL = NULL | ||
arrow_ord::cmp::not_distinct(&list_array_row, &from_arr)? | ||
} | ||
}; | ||
let eq_array = compare_element_to_list( | ||
&list_array_row, | ||
&from_array, | ||
row_index, | ||
true, | ||
)?; | ||
|
||
// Use MutableArrayData to build the replaced array | ||
let original_data = list_array_row.to_data(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
❤️