diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 01d495ee7f6b..515df2a970a4 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -131,6 +131,78 @@ macro_rules! array { }}; } +/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. +/// +/// # Arguments +/// +/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared. +/// +/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared. +/// +/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison. +/// +/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality. +/// +/// # Returns +/// +/// Returns a `Result` representing the comparison results. The result may contain an error if there are issues with the computation. +/// +/// # Example +/// +/// ```text +/// compare_element_to_list( +/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false] +/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false] +/// +/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false] +/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false] +/// ) +/// ``` +fn compare_element_to_list( + list_array_row: &dyn Array, + element_array: &dyn Array, + row_index: usize, + eq: bool, +) -> Result { + let indices = UInt32Array::from(vec![row_index as u32]); + let element_array_row = arrow::compute::take(element_array, &indices, None)?; + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let res = match element_array_row.data_type() { + // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop + DataType::List(_) => { + // compare each element of the from array + let element_array_row_inner = as_list_array(&element_array_row)?.value(0); + let list_array_row_inner = as_list_array(list_array_row)?; + + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| { + row.map(|row| { + if eq { + row.eq(&element_array_row_inner) + } else { + row.ne(&element_array_row_inner) + } + }) + }) + .collect::() + } + _ => { + let element_arr = Scalar::new(element_array_row); + // use not_distinct so we can compare NULL + if eq { + arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? + } else { + arrow_ord::cmp::distinct(&list_array_row, &element_arr)? + } + } + }; + + Ok(res) +} + /// Returns the length of a concrete array dimension fn compute_array_length( arr: Option, @@ -1005,114 +1077,68 @@ fn general_list_repeat( )?)) } -macro_rules! position { - ($ARRAY:expr, $ELEMENT:expr, $INDEX:expr, $ARRAY_TYPE:ident) => {{ - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - $ARRAY - .iter() - .zip(element.iter()) - .zip($INDEX.iter()) - .map(|((arr, el), i)| { - let index = match i { - Some(i) => { - if i <= 0 { - 0 - } else { - i - 1 - } - } - None => return exec_err!("initial position must not be null"), - }; - - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - - match child_array - .iter() - .skip(index as usize) - .position(|x| x == el) - { - Some(value) => Ok(Some(value as u64 + index as u64 + 1u64)), - None => Ok(None), - } - } - None => Ok(None), - } - }) - .collect::>()? - }}; -} - /// Array_position SQL function pub fn array_position(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; + let list_array = as_list_array(&args[0])?; + let element_array = &args[1]; - let index = if args.len() == 3 { - as_int64_array(&args[2])?.clone() + check_datatypes("array_position", &[list_array.values(), element_array])?; + + let arr_from = if args.len() == 3 { + as_int64_array(&args[2])? + .values() + .to_vec() + .iter() + .map(|&x| x - 1) + .collect::>() } else { - Int64Array::from_value(0, arr.len()) + vec![0; list_array.len()] }; - check_datatypes("array_position", &[arr.values(), element])?; - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - position!(arr, element, index, $ARRAY_TYPE) - }; + // if `start_from` index is out of bounds, return error + for (arr, &from) in list_array.iter().zip(arr_from.iter()) { + if let Some(arr) = arr { + if from < 0 || from as usize >= arr.len() { + return internal_err!("start_from index out of bounds"); + } + } else { + // We will get null if we got null in the array, so we don't need to check + } } - let res = call_array_function!(arr.value_type(), true); - Ok(Arc::new(res)) + general_position::(list_array, element_array, arr_from) } -macro_rules! positions { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array(&DataType::UInt64), UInt64Array).clone(); - for comp in $ARRAY - .iter() - .zip(element.iter()) - .map(|(arr, el)| match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - let res = child_array - .iter() - .enumerate() - .filter(|(_, x)| *x == el) - .flat_map(|(i, _)| Some((i + 1) as u64)) - .collect::(); +fn general_position( + list_array: &GenericListArray, + element_array: &ArrayRef, + arr_from: Vec, // 0-indexed +) -> Result { + let mut data = Vec::with_capacity(list_array.len()); - Ok(res) - } - None => Ok(downcast_arg!( - new_empty_array(&DataType::UInt64), - UInt64Array - ) - .clone()), - }) - .collect::>>()? - { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty",)) - })?; - values = - downcast_arg!(compute::concat(&[&values, &comp,])?.clone(), UInt64Array) - .clone(); - offsets.push(last_offset + comp.len() as i32); - } + for (row_index, (list_array_row, &from)) in + list_array.iter().zip(arr_from.iter()).enumerate() + { + let from = from as usize; - let field = Arc::new(Field::new("item", DataType::UInt64, true)); + if let Some(list_array_row) = list_array_row { + let eq_array = + compare_element_to_list(&list_array_row, element_array, row_index, true)?; - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + // Collect `true`s in 1-indexed positions + let index = eq_array + .iter() + .skip(from) + .position(|e| e == Some(true)) + .map(|index| (from + index + 1) as u64); + + data.push(index); + } else { + data.push(None); + } + } + + Ok(Arc::new(UInt64Array::from(data))) } /// Array_positions SQL function @@ -1121,14 +1147,37 @@ pub fn array_positions(args: &[ArrayRef]) -> Result { let element = &args[1]; check_datatypes("array_positions", &[arr.values(), element])?; - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - positions!(arr, element, $ARRAY_TYPE) - }; + + general_positions::(arr, element) +} + +fn general_positions( + list_array: &GenericListArray, + element_array: &ArrayRef, +) -> Result { + let mut data = Vec::with_capacity(list_array.len()); + + for (row_index, list_array_row) in list_array.iter().enumerate() { + if let Some(list_array_row) = list_array_row { + let eq_array = + compare_element_to_list(&list_array_row, element_array, row_index, true)?; + + // Collect `true`s in 1-indexed positions + let indexes = eq_array + .iter() + .positions(|e| e == Some(true)) + .map(|index| Some(index as u64 + 1)) + .collect::>(); + + data.push(Some(indexes)); + } else { + data.push(None); + } } - let res = call_array_function!(arr.value_type(), true); - Ok(res) + Ok(Arc::new( + ListArray::from_iter_primitive::(data), + )) } /// For each element of `list_array[i]`, removed up to `arr_n[i]` occurences @@ -1165,30 +1214,12 @@ fn general_remove( { match list_array_row { Some(list_array_row) => { - let indices = UInt32Array::from(vec![row_index as u32]); - let element_array_row = - arrow::compute::take(element_array, &indices, None)?; - - let eq_array = match element_array_row.data_type() { - // arrow_ord::cmp::distinct does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - // compare each element of the from array - let element_array_row_inner = - as_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_list_array(&list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| row.map(|row| row.ne(&element_array_row_inner))) - .collect::() - } - _ => { - let from_arr = Scalar::new(element_array_row); - // use distinct so Null = Null is false - arrow_ord::cmp::distinct(&list_array_row, &from_arr)? - } - }; + let eq_array = compare_element_to_list( + &list_array_row, + element_array, + row_index, + false, + )?; // We need to keep at most first n elements as `false`, which represent the elements to remove. let eq_array = if eq_array.false_count() < *n as usize { @@ -1313,30 +1344,14 @@ fn general_replace( match list_array_row { Some(list_array_row) => { - let indices = UInt32Array::from(vec![row_index as u32]); - let from_array_row = arrow::compute::take(from_array, &indices, None)?; // Compute all positions in list_row_array (that is itself an // array) that are equal to `from_array_row` - let eq_array = match from_array_row.data_type() { - // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - // compare each element of the from array - let from_array_row_inner = - as_list_array(&from_array_row)?.value(0); - let list_array_row_inner = as_list_array(&list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| row.map(|row| row.eq(&from_array_row_inner))) - .collect::() - } - _ => { - let from_arr = Scalar::new(from_array_row); - // use not_distinct so NULL = NULL - arrow_ord::cmp::not_distinct(&list_array_row, &from_arr)? - } - }; + let eq_array = compare_element_to_list( + &list_array_row, + &from_array, + row_index, + true, + )?; // Use MutableArrayData to build the replaced array let original_data = list_array_row.to_data(); diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 92013f37d36c..67cabb0988fd 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -702,7 +702,7 @@ select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h' NULL NULL # array_element scalar function #4 (with NULL) -query error +query error select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL); # array_element scalar function #5 (with negative index) @@ -871,11 +871,11 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h', [1, 2, 3, 4] [h, e, l] # array_slice scalar function #8 (with NULL and positive number) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); # array_slice scalar function #9 (with positive number and NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); # array_slice scalar function #10 (with zero-zero) @@ -885,7 +885,7 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), array_slice(make_array('h', [] [] # array_slice scalar function #11 (with NULL-NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL); # array_slice scalar function #12 (with zero and negative number) @@ -895,11 +895,11 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h' [1] [h, e] # array_slice scalar function #13 (with negative number and NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); # array_slice scalar function #14 (with NULL and negative number) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3); # array_slice scalar function #15 (with negative indexes)