Skip to content

Commit

Permalink
Replace macro with function for array_position and array_positions (
Browse files Browse the repository at this point in the history
#8170)

* basic one

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* complete n

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* positions done

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* compare_element_to_list

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fmt

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* resolve rebase

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fmt

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 authored Nov 15, 2023
1 parent e1c2f95 commit 7c2c2f0
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 153 deletions.
309 changes: 162 additions & 147 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,78 @@ macro_rules! array {
}};
}

/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array.
///
/// # Arguments
///
/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared.
///
/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared.
///
/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison.
///
/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality.
///
/// # Returns
///
/// Returns a `Result<BooleanArray>` representing the comparison results. The result may contain an error if there are issues with the computation.
///
/// # Example
///
/// ```text
/// compare_element_to_list(
/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false]
/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false]
///
/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false]
/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false]
/// )
/// ```
fn compare_element_to_list(
list_array_row: &dyn Array,
element_array: &dyn Array,
row_index: usize,
eq: bool,
) -> Result<BooleanArray> {
let indices = UInt32Array::from(vec![row_index as u32]);
let element_array_row = arrow::compute::take(element_array, &indices, None)?;
// Compute all positions in list_row_array (that is itself an
// array) that are equal to `from_array_row`
let res = match element_array_row.data_type() {
// arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop
DataType::List(_) => {
// compare each element of the from array
let element_array_row_inner = as_list_array(&element_array_row)?.value(0);
let list_array_row_inner = as_list_array(list_array_row)?;

list_array_row_inner
.iter()
// compare element by element the current row of list_array
.map(|row| {
row.map(|row| {
if eq {
row.eq(&element_array_row_inner)
} else {
row.ne(&element_array_row_inner)
}
})
})
.collect::<BooleanArray>()
}
_ => {
let element_arr = Scalar::new(element_array_row);
// use not_distinct so we can compare NULL
if eq {
arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)?
} else {
arrow_ord::cmp::distinct(&list_array_row, &element_arr)?
}
}
};

Ok(res)
}

/// Returns the length of a concrete array dimension
fn compute_array_length(
arr: Option<ArrayRef>,
Expand Down Expand Up @@ -1005,114 +1077,68 @@ fn general_list_repeat(
)?))
}

macro_rules! position {
($ARRAY:expr, $ELEMENT:expr, $INDEX:expr, $ARRAY_TYPE:ident) => {{
let element = downcast_arg!($ELEMENT, $ARRAY_TYPE);
$ARRAY
.iter()
.zip(element.iter())
.zip($INDEX.iter())
.map(|((arr, el), i)| {
let index = match i {
Some(i) => {
if i <= 0 {
0
} else {
i - 1
}
}
None => return exec_err!("initial position must not be null"),
};

match arr {
Some(arr) => {
let child_array = downcast_arg!(arr, $ARRAY_TYPE);

match child_array
.iter()
.skip(index as usize)
.position(|x| x == el)
{
Some(value) => Ok(Some(value as u64 + index as u64 + 1u64)),
None => Ok(None),
}
}
None => Ok(None),
}
})
.collect::<Result<UInt64Array>>()?
}};
}

/// Array_position SQL function
pub fn array_position(args: &[ArrayRef]) -> Result<ArrayRef> {
let arr = as_list_array(&args[0])?;
let element = &args[1];
let list_array = as_list_array(&args[0])?;
let element_array = &args[1];

let index = if args.len() == 3 {
as_int64_array(&args[2])?.clone()
check_datatypes("array_position", &[list_array.values(), element_array])?;

let arr_from = if args.len() == 3 {
as_int64_array(&args[2])?
.values()
.to_vec()
.iter()
.map(|&x| x - 1)
.collect::<Vec<_>>()
} else {
Int64Array::from_value(0, arr.len())
vec![0; list_array.len()]
};

check_datatypes("array_position", &[arr.values(), element])?;
macro_rules! array_function {
($ARRAY_TYPE:ident) => {
position!(arr, element, index, $ARRAY_TYPE)
};
// if `start_from` index is out of bounds, return error
for (arr, &from) in list_array.iter().zip(arr_from.iter()) {
if let Some(arr) = arr {
if from < 0 || from as usize >= arr.len() {
return internal_err!("start_from index out of bounds");
}
} else {
// We will get null if we got null in the array, so we don't need to check
}
}
let res = call_array_function!(arr.value_type(), true);

Ok(Arc::new(res))
general_position::<i32>(list_array, element_array, arr_from)
}

macro_rules! positions {
($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{
let element = downcast_arg!($ELEMENT, $ARRAY_TYPE);
let mut offsets: Vec<i32> = vec![0];
let mut values =
downcast_arg!(new_empty_array(&DataType::UInt64), UInt64Array).clone();
for comp in $ARRAY
.iter()
.zip(element.iter())
.map(|(arr, el)| match arr {
Some(arr) => {
let child_array = downcast_arg!(arr, $ARRAY_TYPE);
let res = child_array
.iter()
.enumerate()
.filter(|(_, x)| *x == el)
.flat_map(|(i, _)| Some((i + 1) as u64))
.collect::<UInt64Array>();
fn general_position<OffsetSize: OffsetSizeTrait>(
list_array: &GenericListArray<OffsetSize>,
element_array: &ArrayRef,
arr_from: Vec<i64>, // 0-indexed
) -> Result<ArrayRef> {
let mut data = Vec::with_capacity(list_array.len());

Ok(res)
}
None => Ok(downcast_arg!(
new_empty_array(&DataType::UInt64),
UInt64Array
)
.clone()),
})
.collect::<Result<Vec<UInt64Array>>>()?
{
let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
DataFusionError::Internal(format!("offsets should not be empty",))
})?;
values =
downcast_arg!(compute::concat(&[&values, &comp,])?.clone(), UInt64Array)
.clone();
offsets.push(last_offset + comp.len() as i32);
}
for (row_index, (list_array_row, &from)) in
list_array.iter().zip(arr_from.iter()).enumerate()
{
let from = from as usize;

let field = Arc::new(Field::new("item", DataType::UInt64, true));
if let Some(list_array_row) = list_array_row {
let eq_array =
compare_element_to_list(&list_array_row, element_array, row_index, true)?;

Arc::new(ListArray::try_new(
field,
OffsetBuffer::new(offsets.into()),
Arc::new(values),
None,
)?)
}};
// Collect `true`s in 1-indexed positions
let index = eq_array
.iter()
.skip(from)
.position(|e| e == Some(true))
.map(|index| (from + index + 1) as u64);

data.push(index);
} else {
data.push(None);
}
}

Ok(Arc::new(UInt64Array::from(data)))
}

/// Array_positions SQL function
Expand All @@ -1121,14 +1147,37 @@ pub fn array_positions(args: &[ArrayRef]) -> Result<ArrayRef> {
let element = &args[1];

check_datatypes("array_positions", &[arr.values(), element])?;
macro_rules! array_function {
($ARRAY_TYPE:ident) => {
positions!(arr, element, $ARRAY_TYPE)
};

general_positions::<i32>(arr, element)
}

fn general_positions<OffsetSize: OffsetSizeTrait>(
list_array: &GenericListArray<OffsetSize>,
element_array: &ArrayRef,
) -> Result<ArrayRef> {
let mut data = Vec::with_capacity(list_array.len());

for (row_index, list_array_row) in list_array.iter().enumerate() {
if let Some(list_array_row) = list_array_row {
let eq_array =
compare_element_to_list(&list_array_row, element_array, row_index, true)?;

// Collect `true`s in 1-indexed positions
let indexes = eq_array
.iter()
.positions(|e| e == Some(true))
.map(|index| Some(index as u64 + 1))
.collect::<Vec<_>>();

data.push(Some(indexes));
} else {
data.push(None);
}
}
let res = call_array_function!(arr.value_type(), true);

Ok(res)
Ok(Arc::new(
ListArray::from_iter_primitive::<UInt64Type, _, _>(data),
))
}

/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurences
Expand Down Expand Up @@ -1165,30 +1214,12 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
{
match list_array_row {
Some(list_array_row) => {
let indices = UInt32Array::from(vec![row_index as u32]);
let element_array_row =
arrow::compute::take(element_array, &indices, None)?;

let eq_array = match element_array_row.data_type() {
// arrow_ord::cmp::distinct does not support ListArray, so we need to compare it by loop
DataType::List(_) => {
// compare each element of the from array
let element_array_row_inner =
as_list_array(&element_array_row)?.value(0);
let list_array_row_inner = as_list_array(&list_array_row)?;

list_array_row_inner
.iter()
// compare element by element the current row of list_array
.map(|row| row.map(|row| row.ne(&element_array_row_inner)))
.collect::<BooleanArray>()
}
_ => {
let from_arr = Scalar::new(element_array_row);
// use distinct so Null = Null is false
arrow_ord::cmp::distinct(&list_array_row, &from_arr)?
}
};
let eq_array = compare_element_to_list(
&list_array_row,
element_array,
row_index,
false,
)?;

// We need to keep at most first n elements as `false`, which represent the elements to remove.
let eq_array = if eq_array.false_count() < *n as usize {
Expand Down Expand Up @@ -1313,30 +1344,14 @@ fn general_replace(

match list_array_row {
Some(list_array_row) => {
let indices = UInt32Array::from(vec![row_index as u32]);
let from_array_row = arrow::compute::take(from_array, &indices, None)?;
// Compute all positions in list_row_array (that is itself an
// array) that are equal to `from_array_row`
let eq_array = match from_array_row.data_type() {
// arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop
DataType::List(_) => {
// compare each element of the from array
let from_array_row_inner =
as_list_array(&from_array_row)?.value(0);
let list_array_row_inner = as_list_array(&list_array_row)?;

list_array_row_inner
.iter()
// compare element by element the current row of list_array
.map(|row| row.map(|row| row.eq(&from_array_row_inner)))
.collect::<BooleanArray>()
}
_ => {
let from_arr = Scalar::new(from_array_row);
// use not_distinct so NULL = NULL
arrow_ord::cmp::not_distinct(&list_array_row, &from_arr)?
}
};
let eq_array = compare_element_to_list(
&list_array_row,
&from_array,
row_index,
true,
)?;

// Use MutableArrayData to build the replaced array
let original_data = list_array_row.to_data();
Expand Down
Loading

0 comments on commit 7c2c2f0

Please sign in to comment.