Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace macro with function for array_position and array_positions #8170

Merged
merged 7 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
309 changes: 162 additions & 147 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,78 @@ macro_rules! array {
}};
}

/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

///
/// # Arguments
///
/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared.
///
/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared.
///
/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison.
///
/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality.
///
/// # Returns
///
/// Returns a `Result<BooleanArray>` representing the comparison results. The result may contain an error if there are issues with the computation.
///
/// # Example
///
/// ```text
/// compare_element_to_list(
/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false]
/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false]
///
/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false]
/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false]
/// )
/// ```
fn compare_element_to_list(
list_array_row: &dyn Array,
element_array: &dyn Array,
row_index: usize,
eq: bool,
) -> Result<BooleanArray> {
// Pull out the one element to compare against: `take` with a single index
// produces a one-row array holding `element_array[row_index]`.
// NOTE(review): `row_index as u32` is a plain cast and would wrap for
// indices above `u32::MAX` — confirm callers never exceed that.
let indices = UInt32Array::from(vec![row_index as u32]);
let element_array_row = arrow::compute::take(element_array, &indices, None)?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will always be a single row Array as indices have a single value 🤔 I wonder if you could call

I think you could instead call list_array_row.values().slice() and find the relevant row to compare against those values 🤔

// Build a mask over the elements of `list_array_row` (that is itself an
// array): each entry records whether that element equals — or, when `eq`
// is false, differs from — the single value in `element_array_row`.
let res = match element_array_row.data_type() {
// arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop
DataType::List(_) => {
// The element is itself a list: unwrap its single value (index 0) and
// view the current row as a list array so its entries can be iterated.
let element_array_row_inner = as_list_array(&element_array_row)?.value(0);
let list_array_row_inner = as_list_array(list_array_row)?;

list_array_row_inner
.iter()
// Compare each entry of the current row, as a whole nested array,
// with the element. `row.map` leaves a `None` entry as `None`, so it
// is presumably collected as a null slot in the mask.
.map(|row| {
row.map(|row| {
if eq {
row.eq(&element_array_row_inner)
} else {
row.ne(&element_array_row_inner)
}
})
})
.collect::<BooleanArray>()
}
_ => {
// Wrap the one-row array in `Scalar` so the kernel compares it against
// every element of `list_array_row`.
let element_arr = Scalar::new(element_array_row);
// use not_distinct so we can compare NULL; `distinct` is used for the
// inequality case (`eq == false`)
if eq {
arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)?
} else {
arrow_ord::cmp::distinct(&list_array_row, &element_arr)?
}
}
};

Ok(res)
}

/// Returns the length of a concrete array dimension
fn compute_array_length(
arr: Option<ArrayRef>,
Expand Down Expand Up @@ -953,114 +1025,68 @@ fn general_list_repeat(
)?))
}

/// Expands to a `UInt64Array` with one entry per row of `$ARRAY`: the 1-based
/// position of the first element equal to the matching entry of `$ELEMENT`,
/// searching from the matching 1-based start position in `$INDEX`.
///
/// * `$ARRAY` - list array whose rows are searched.
/// * `$ELEMENT` - array of search values, downcast to `$ARRAY_TYPE`.
/// * `$INDEX` - array of 1-based start positions; values `<= 0` start the
///   search at the first element.
/// * `$ARRAY_TYPE` - concrete array type used for both downcasts.
///
/// A null start position produces an execution error. A null list row, or a
/// row without a match, produces a null entry. The expansion uses `?`, so it
/// must be invoked inside a function returning `Result`.
macro_rules! position {
    ($ARRAY:expr, $ELEMENT:expr, $INDEX:expr, $ARRAY_TYPE:ident) => {{
        let element = downcast_arg!($ELEMENT, $ARRAY_TYPE);
        $ARRAY
            .iter()
            .zip(element.iter())
            .zip($INDEX.iter())
            .map(|((list_row, needle), start)| {
                // Turn the 1-based start position into a 0-based number of
                // elements to skip; a null start is rejected up front.
                let skipped = match start {
                    None => return exec_err!("initial position must not be null"),
                    Some(one_based) if one_based <= 0 => 0,
                    Some(one_based) => one_based - 1,
                };

                // A null list row has nothing to search.
                let list_row = match list_row {
                    Some(list_row) => list_row,
                    None => return Ok(None),
                };

                let child_array = downcast_arg!(list_row, $ARRAY_TYPE);
                let found = child_array
                    .iter()
                    .skip(skipped as usize)
                    .position(|x| x == needle);

                // `found` counts from the first non-skipped element: add the
                // skipped prefix back, then 1 to make the result 1-based.
                Ok(found.map(|offset| offset as u64 + skipped as u64 + 1u64))
            })
            .collect::<Result<UInt64Array>>()?
    }};
}

/// Array_position SQL function
pub fn array_position(args: &[ArrayRef]) -> Result<ArrayRef> {
let arr = as_list_array(&args[0])?;
let element = &args[1];
let list_array = as_list_array(&args[0])?;
let element_array = &args[1];

check_datatypes("array_position", &[list_array.values(), element_array])?;

let index = if args.len() == 3 {
as_int64_array(&args[2])?.clone()
let arr_from = if args.len() == 3 {
as_int64_array(&args[2])?
.values()
.to_vec()
.iter()
.map(|&x| x - 1)
.collect::<Vec<_>>()
} else {
Int64Array::from_value(0, arr.len())
vec![0; list_array.len()]
};

check_datatypes("array_position", &[arr.values(), element])?;
macro_rules! array_function {
($ARRAY_TYPE:ident) => {
position!(arr, element, index, $ARRAY_TYPE)
};
// if `start_from` index is out of bounds, return error
for (arr, &from) in list_array.iter().zip(arr_from.iter()) {
if let Some(arr) = arr {
if from < 0 || from as usize >= arr.len() {
return internal_err!("start_from index out of bounds");
}
} else {
// We will get null if we got null in the array, so we don't need to check
}
}
let res = call_array_function!(arr.value_type(), true);

Ok(Arc::new(res))
general_position::<i32>(list_array, element_array, arr_from)
}

macro_rules! positions {
($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{
let element = downcast_arg!($ELEMENT, $ARRAY_TYPE);
let mut offsets: Vec<i32> = vec![0];
let mut values =
downcast_arg!(new_empty_array(&DataType::UInt64), UInt64Array).clone();
for comp in $ARRAY
.iter()
.zip(element.iter())
.map(|(arr, el)| match arr {
Some(arr) => {
let child_array = downcast_arg!(arr, $ARRAY_TYPE);
let res = child_array
.iter()
.enumerate()
.filter(|(_, x)| *x == el)
.flat_map(|(i, _)| Some((i + 1) as u64))
.collect::<UInt64Array>();
fn general_position<OffsetSize: OffsetSizeTrait>(
Copy link
Contributor

@Veeupup Veeupup Nov 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: if we only use i32 as OffsetSize, do we really need the Generics?

Copy link
Contributor Author

@jayzhan211 jayzhan211 Nov 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, but we will need it if we want to extend it for large list, I'm just lazy to remove it for now :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice extension 👍

list_array: &GenericListArray<OffsetSize>,
element_array: &ArrayRef,
arr_from: Vec<i64>, // 0-indexed
) -> Result<ArrayRef> {
let mut data = Vec::with_capacity(list_array.len());

Ok(res)
}
None => Ok(downcast_arg!(
new_empty_array(&DataType::UInt64),
UInt64Array
)
.clone()),
})
.collect::<Result<Vec<UInt64Array>>>()?
{
let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
DataFusionError::Internal(format!("offsets should not be empty",))
})?;
values =
downcast_arg!(compute::concat(&[&values, &comp,])?.clone(), UInt64Array)
.clone();
offsets.push(last_offset + comp.len() as i32);
}
for (row_index, (list_array_row, &from)) in
list_array.iter().zip(arr_from.iter()).enumerate()
{
let from = from as usize;

let field = Arc::new(Field::new("item", DataType::UInt64, true));
if let Some(list_array_row) = list_array_row {
let eq_array =
compare_element_to_list(&list_array_row, element_array, row_index, true)?;

Arc::new(ListArray::try_new(
field,
OffsetBuffer::new(offsets.into()),
Arc::new(values),
None,
)?)
}};
// Collect `true`s in 1-indexed positions
let index = eq_array
.iter()
.skip(from)
.position(|e| e == Some(true))
.map(|index| (from + index + 1) as u64);

data.push(index);
} else {
data.push(None);
}
}

Ok(Arc::new(UInt64Array::from(data)))
}

/// Array_positions SQL function
Expand All @@ -1069,14 +1095,37 @@ pub fn array_positions(args: &[ArrayRef]) -> Result<ArrayRef> {
let element = &args[1];

check_datatypes("array_positions", &[arr.values(), element])?;
macro_rules! array_function {
($ARRAY_TYPE:ident) => {
positions!(arr, element, $ARRAY_TYPE)
};

general_positions::<i32>(arr, element)
}

fn general_positions<OffsetSize: OffsetSizeTrait>(
list_array: &GenericListArray<OffsetSize>,
element_array: &ArrayRef,
) -> Result<ArrayRef> {
let mut data = Vec::with_capacity(list_array.len());

for (row_index, list_array_row) in list_array.iter().enumerate() {
if let Some(list_array_row) = list_array_row {
let eq_array =
compare_element_to_list(&list_array_row, element_array, row_index, true)?;

// Collect `true`s in 1-indexed positions
let indexes = eq_array
.iter()
.positions(|e| e == Some(true))
.map(|index| Some(index as u64 + 1))
.collect::<Vec<_>>();

data.push(Some(indexes));
} else {
data.push(None);
}
}
let res = call_array_function!(arr.value_type(), true);

Ok(res)
Ok(Arc::new(
ListArray::from_iter_primitive::<UInt64Type, _, _>(data),
))
}

/// For each element of `list_array[i]`, remove up to `arr_n[i]` occurrences
Expand Down Expand Up @@ -1113,30 +1162,12 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
{
match list_array_row {
Some(list_array_row) => {
let indices = UInt32Array::from(vec![row_index as u32]);
let element_array_row =
arrow::compute::take(element_array, &indices, None)?;

let eq_array = match element_array_row.data_type() {
// arrow_ord::cmp::distinct does not support ListArray, so we need to compare it by loop
DataType::List(_) => {
// compare each element of the from array
let element_array_row_inner =
as_list_array(&element_array_row)?.value(0);
let list_array_row_inner = as_list_array(&list_array_row)?;

list_array_row_inner
.iter()
// compare element by element the current row of list_array
.map(|row| row.map(|row| row.ne(&element_array_row_inner)))
.collect::<BooleanArray>()
}
_ => {
let from_arr = Scalar::new(element_array_row);
// use distinct so Null = Null is false
arrow_ord::cmp::distinct(&list_array_row, &from_arr)?
}
};
let eq_array = compare_element_to_list(
&list_array_row,
element_array,
row_index,
false,
)?;

// We need to keep at most first n elements as `false`, which represent the elements to remove.
let eq_array = if eq_array.false_count() < *n as usize {
Expand Down Expand Up @@ -1261,30 +1292,14 @@ fn general_replace(

match list_array_row {
Some(list_array_row) => {
let indices = UInt32Array::from(vec![row_index as u32]);
let from_array_row = arrow::compute::take(from_array, &indices, None)?;
// Compute all positions in list_row_array (that is itself an
// array) that are equal to `from_array_row`
let eq_array = match from_array_row.data_type() {
// arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop
DataType::List(_) => {
// compare each element of the from array
let from_array_row_inner =
as_list_array(&from_array_row)?.value(0);
let list_array_row_inner = as_list_array(&list_array_row)?;

list_array_row_inner
.iter()
// compare element by element the current row of list_array
.map(|row| row.map(|row| row.eq(&from_array_row_inner)))
.collect::<BooleanArray>()
}
_ => {
let from_arr = Scalar::new(from_array_row);
// use not_distinct so NULL = NULL
arrow_ord::cmp::not_distinct(&list_array_row, &from_arr)?
}
};
let eq_array = compare_element_to_list(
&list_array_row,
&from_array,
row_index,
true,
)?;

// Use MutableArrayData to build the replaced array
let original_data = list_array_row.to_data();
Expand Down
Loading