From 6861b94f5ef3ff7e0d749480f64daae3fcfe7a9e Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 13 Nov 2023 19:52:46 +0800 Subject: [PATCH 1/7] basic one Signed-off-by: jayzhan211 --- .../physical-expr/src/array_expressions.rs | 54 ++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 9b074ff0ee0d..2c0dc3f675d6 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -25,6 +25,7 @@ use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::datatypes::{DataType, Field, UInt64Type}; use arrow::row::{RowConverter, SortField}; +use arrow_array::types::Int64Type; use arrow_buffer::NullBuffer; use arrow_schema::FieldRef; @@ -992,6 +993,56 @@ macro_rules! position { }}; } +fn general_position( + list_array: &GenericListArray, + element_array: &ArrayRef, + from_array: &Int64Array, + arr_n: Vec, +) -> Result { + let mut data = vec![]; + + for (row_index, (list_array_row, n)) in + list_array.iter().zip(arr_n.iter()).enumerate() + { + if let Some(list_array_row) = list_array_row { + let indices = UInt32Array::from(vec![row_index as u32]); + let element_array_row = arrow::compute::take(element_array, &indices, None)?; + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let eq_array = match element_array_row.data_type() { + // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop + DataType::List(_) => { + // compare each element of the from array + let element_array_row_inner = + as_list_array(&element_array_row)?.value(0); + let list_array_row_inner = as_list_array(&list_array_row)?; + + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| row.map(|row| row.eq(&element_array_row_inner))) + .collect::() + } + _ => { + let element_arr = Scalar::new(element_array_row); + // use not_distinct so NULL = NULL + arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? + } + }; + + // Collect `true`s in 1-indexed positions + let indexes = eq_array.iter().positions(|e| e == Some(true)).map(|index| Some(index as u64 + 1)).collect::>(); + data.push(Some(indexes)); + } else { + data.push(None); + } + } + + Ok(Arc::new(ListArray::from_iter_primitive::(data))) + + // Ok(Arc::new(Int64Array::from(vec![Some(1)]))) +} + /// Array_position SQL function pub fn array_position(args: &[ArrayRef]) -> Result { let arr = as_list_array(&args[0])?; @@ -1000,7 +1051,8 @@ pub fn array_position(args: &[ArrayRef]) -> Result { let index = if args.len() == 3 { as_int64_array(&args[2])?.clone() } else { - Int64Array::from_value(0, arr.len()) + return general_position::(arr, element, &Int64Array::from_value(1, arr.len()), vec![1; arr.len()]) + // Int64Array::from_value(0, arr.len()) }; check_datatypes("array_position", &[arr.values(), element])?; From ef38e320b4e731f8eba5c17f74d6ff0aaf445e4e Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 13 Nov 2023 20:24:02 +0800 Subject: [PATCH 2/7] complete n Signed-off-by: jayzhan211 --- datafusion/expr/src/built_in_function.rs | 4 +- .../physical-expr/src/array_expressions.rs | 53 ++++++++++++------- 2 files changed, 36 insertions(+), 21 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 77c64128e156..1771adbe873c 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -568,8 +568,8 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayLength => Ok(UInt64), BuiltinScalarFunction::ArrayNdims => Ok(UInt64), BuiltinScalarFunction::ArrayPopBack => Ok(input_expr_types[0].clone()), - BuiltinScalarFunction::ArrayPosition => Ok(UInt64), - BuiltinScalarFunction::ArrayPositions => { + BuiltinScalarFunction::ArrayPosition + | BuiltinScalarFunction::ArrayPositions => { Ok(List(Arc::new(Field::new("item", UInt64, true)))) } BuiltinScalarFunction::ArrayPrepend => Ok(input_expr_types[1].clone()), diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 2c0dc3f675d6..b91c37184870 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -996,13 +996,16 @@ macro_rules! position { fn general_position( list_array: &GenericListArray, element_array: &ArrayRef, - from_array: &Int64Array, - arr_n: Vec, + from_array: Vec, // 0-indexed + n_array: Vec, ) -> Result { - let mut data = vec![]; + let mut data = Vec::with_capacity(n_array.len()); - for (row_index, (list_array_row, n)) in - list_array.iter().zip(arr_n.iter()).enumerate() + for (row_index, ((list_array_row, &from), &n)) in list_array + .iter() + .zip(from_array.iter()) + .zip(n_array.iter()) + .enumerate() { if let Some(list_array_row) = list_array_row { let indices = UInt32Array::from(vec![row_index as u32]); @@ -1029,18 +1032,25 @@ fn general_position( arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? } }; - + // Collect `true`s in 1-indexed positions - let indexes = eq_array.iter().positions(|e| e == Some(true)).map(|index| Some(index as u64 + 1)).collect::>(); + let indexes = eq_array + .iter() + .skip(from) + .positions(|e| e == Some(true)) + .map(|index| Some(index as u64 + 1)) + .take(n) + .collect::>(); + data.push(Some(indexes)); } else { data.push(None); } } - Ok(Arc::new(ListArray::from_iter_primitive::(data))) - - // Ok(Arc::new(Int64Array::from(vec![Some(1)]))) + Ok(Arc::new( + ListArray::from_iter_primitive::(data), + )) } /// Array_position SQL function @@ -1048,11 +1058,17 @@ pub fn array_position(args: &[ArrayRef]) -> Result { let arr = as_list_array(&args[0])?; let element = &args[1]; + // handle incorrect from_array only 1 to n is accept. + let index = if args.len() == 3 { as_int64_array(&args[2])?.clone() } else { - return general_position::(arr, element, &Int64Array::from_value(1, arr.len()), vec![1; arr.len()]) - // Int64Array::from_value(0, arr.len()) + return general_position::( + arr, + element, + vec![0; arr.len()], + vec![1; arr.len()], + ); }; check_datatypes("array_position", &[arr.values(), element])?; @@ -1121,14 +1137,13 @@ pub fn array_positions(args: &[ArrayRef]) -> Result { let element = &args[1]; check_datatypes("array_positions", &[arr.values(), element])?; - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - positions!(arr, element, $ARRAY_TYPE) - }; - } - let res = call_array_function!(arr.value_type(), true); - Ok(res) + general_position::( + arr, + element, + vec![0; arr.len()], + vec![usize::MAX; arr.len()], + ) } /// For each element of `list_array[i]`, removed up to `arr_n[i]` occurences From f8e40db37c834c79db11f2378bd3c4de9c670316 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 13 Nov 2023 20:32:24 +0800 Subject: [PATCH 3/7] positions done Signed-off-by: jayzhan211 --- .../physical-expr/src/array_expressions.rs | 89 ++++--------------- 1 file changed, 16 insertions(+), 73 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index b91c37184870..f44886655eae 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -996,15 +996,15 @@ macro_rules! position { fn general_position( list_array: &GenericListArray, element_array: &ArrayRef, - from_array: Vec, // 0-indexed - n_array: Vec, + arr_from: Vec, // 0-indexed + arr_n: Vec, ) -> Result { - let mut data = Vec::with_capacity(n_array.len()); + let mut data = Vec::with_capacity(arr_n.len()); for (row_index, ((list_array_row, &from), &n)) in list_array .iter() - .zip(from_array.iter()) - .zip(n_array.iter()) + .zip(arr_from.iter()) + .zip(arr_n.iter()) .enumerate() { if let Some(list_array_row) = list_array_row { @@ -1039,7 +1039,7 @@ fn general_position( .skip(from) .positions(|e| e == Some(true)) .map(|index| Some(index as u64 + 1)) - .take(n) + .take(n as usize) .collect::>(); data.push(Some(indexes)); @@ -1060,75 +1060,18 @@ pub fn array_position(args: &[ArrayRef]) -> Result { // handle incorrect from_array only 1 to n is accept. - let index = if args.len() == 3 { - as_int64_array(&args[2])?.clone() + let arr_n = if args.len() == 3 { + as_int64_array(&args[2])?.values().to_vec() } else { - return general_position::( - arr, - element, - vec![0; arr.len()], - vec![1; arr.len()], - ); + vec![1; arr.len()] }; - check_datatypes("array_position", &[arr.values(), element])?; - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - position!(arr, element, index, $ARRAY_TYPE) - }; - } - let res = call_array_function!(arr.value_type(), true); - - Ok(Arc::new(res)) -} - -macro_rules! positions { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array(&DataType::UInt64), UInt64Array).clone(); - for comp in $ARRAY - .iter() - .zip(element.iter()) - .map(|(arr, el)| match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - let res = child_array - .iter() - .enumerate() - .filter(|(_, x)| *x == el) - .flat_map(|(i, _)| Some((i + 1) as u64)) - .collect::(); - - Ok(res) - } - None => Ok(downcast_arg!( - new_empty_array(&DataType::UInt64), - UInt64Array - ) - .clone()), - }) - .collect::>>()? - { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty",)) - })?; - values = - downcast_arg!(compute::concat(&[&values, &comp,])?.clone(), UInt64Array) - .clone(); - offsets.push(last_offset + comp.len() as i32); - } - - let field = Arc::new(Field::new("item", DataType::UInt64, true)); - - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + general_position::( + arr, + element, + vec![0; arr.len()], + arr_n, + ) } /// Array_positions SQL function @@ -1142,7 +1085,7 @@ pub fn array_positions(args: &[ArrayRef]) -> Result { arr, element, vec![0; arr.len()], - vec![usize::MAX; arr.len()], + vec![i64::MAX; arr.len()], ) } From edefb73a7df153c7ee6a1d7cda41589cf8e98aa7 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 14 Nov 2023 21:43:51 +0800 Subject: [PATCH 4/7] compare_element_to_list Signed-off-by: jayzhan211 --- datafusion/expr/src/built_in_function.rs | 4 +- .../physical-expr/src/array_expressions.rs | 234 +++++++++++++----- datafusion/sqllogictest/test_files/array.slt | 12 +- 3 files changed, 175 insertions(+), 75 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 1771adbe873c..77c64128e156 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -568,8 +568,8 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayLength => Ok(UInt64), BuiltinScalarFunction::ArrayNdims => Ok(UInt64), BuiltinScalarFunction::ArrayPopBack => Ok(input_expr_types[0].clone()), - BuiltinScalarFunction::ArrayPosition - | BuiltinScalarFunction::ArrayPositions => { + BuiltinScalarFunction::ArrayPosition => Ok(UInt64), + BuiltinScalarFunction::ArrayPositions => { Ok(List(Arc::new(Field::new("item", UInt64, true)))) } BuiltinScalarFunction::ArrayPrepend => Ok(input_expr_types[1].clone()), diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index f44886655eae..a4d5dd306c8b 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -132,6 +132,78 @@ macro_rules! array { }}; } +/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. +/// +/// # Arguments +/// +/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared. +/// +/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared. +/// +/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison. +/// +/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality. +/// +/// # Returns +/// +/// Returns a `Result` representing the comparison results. The result may contain an error if there are issues with the computation. +/// +/// # Example +/// +/// ```text +/// compare_element_to_list( +/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false] +/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false] +/// +/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false] +/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false] +/// ) +/// ``` +fn compare_element_to_list( + list_array_row: &dyn Array, + element_array: &dyn Array, + row_index: usize, + eq: bool, +) -> Result { + let indices = UInt32Array::from(vec![row_index as u32]); + let element_array_row = arrow::compute::take(element_array, &indices, None)?; + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let res = match element_array_row.data_type() { + // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop + DataType::List(_) => { + // compare each element of the from array + let element_array_row_inner = as_list_array(&element_array_row)?.value(0); + let list_array_row_inner = as_list_array(list_array_row)?; + + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| { + row.map(|row| { + if eq { + row.eq(&element_array_row_inner) + } else { + row.ne(&element_array_row_inner) + } + }) + }) + .collect::() + } + _ => { + let element_arr = Scalar::new(element_array_row); + // use not_distinct so we can compare NULL + if eq { + arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? + } else { + arrow_ord::cmp::distinct(&list_array_row, &element_arr)? + } + } + }; + + Ok(res) +} + /// Returns the length of a concrete array dimension fn compute_array_length( arr: Option, @@ -954,92 +1026,120 @@ fn general_list_repeat( )?)) } -macro_rules! position { - ($ARRAY:expr, $ELEMENT:expr, $INDEX:expr, $ARRAY_TYPE:ident) => {{ - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - $ARRAY +/// Array_position SQL function +pub fn array_position(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let element_array = &args[1]; + + check_datatypes("array_position", &[list_array.values(), element_array])?; + + let arr_from = if args.len() == 3 { + as_int64_array(&args[2])? + .values() + .to_vec() .iter() - .zip(element.iter()) - .zip($INDEX.iter()) - .map(|((arr, el), i)| { - let index = match i { - Some(i) => { - if i <= 0 { - 0 - } else { - i - 1 - } - } - None => return exec_err!("initial position must not be null"), - }; + .map(|&x| x - 1) + .collect::>() + } else { + vec![0; list_array.len()] + }; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); + // if `start_from` index is out of bounds, return error + for (arr, &from) in list_array.iter().zip(arr_from.iter()) { + if let Some(arr) = arr { + if from < 0 || from as usize >= arr.len() { + return internal_err!("start_from index out of bounds"); + } + } else { + // We will get null if we got null in the array, so we don't need to check + } + } - match child_array - .iter() - .skip(index as usize) - .position(|x| x == el) - { - Some(value) => Ok(Some(value as u64 + index as u64 + 1u64)), - None => Ok(None), - } - } - None => Ok(None), - } - }) - .collect::>()? - }}; + general_position::(list_array, element_array, arr_from) } fn general_position( list_array: &GenericListArray, element_array: &ArrayRef, - arr_from: Vec, // 0-indexed - arr_n: Vec, + arr_from: Vec, // 0-indexed ) -> Result { - let mut data = Vec::with_capacity(arr_n.len()); + let mut data = Vec::with_capacity(list_array.len()); - for (row_index, ((list_array_row, &from), &n)) in list_array - .iter() - .zip(arr_from.iter()) - .zip(arr_n.iter()) - .enumerate() + for (row_index, (list_array_row, &from)) in + list_array.iter().zip(arr_from.iter()).enumerate() { + let from = from as usize; + if let Some(list_array_row) = list_array_row { - let indices = UInt32Array::from(vec![row_index as u32]); - let element_array_row = arrow::compute::take(element_array, &indices, None)?; - // Compute all positions in list_row_array (that is itself an - // array) that are equal to `from_array_row` - let eq_array = match element_array_row.data_type() { - // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - // compare each element of the from array - let element_array_row_inner = - as_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_list_array(&list_array_row)?; + let eq_array = + compare_element_to_list(&list_array_row, element_array, row_index, true)?; - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| row.map(|row| row.eq(&element_array_row_inner))) - .collect::() - } - _ => { - let element_arr = Scalar::new(element_array_row); - // use not_distinct so NULL = NULL - arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? - } - }; + // Collect `true`s in 1-indexed positions + let index = eq_array + .iter() + .skip(from) + .position(|e| e == Some(true)) + .map(|index| (from + index + 1) as u64); + + data.push(index); + } else { + data.push(None); + } + } + + Ok(Arc::new(UInt64Array::from(data))) +} + +/// Array_positions SQL function +pub fn array_positions(args: &[ArrayRef]) -> Result { + let arr = as_list_array(&args[0])?; + let element = &args[1]; + + check_datatypes("array_positions", &[arr.values(), element])?; + + general_positions::(arr, element) +} + +fn general_positions( + list_array: &GenericListArray, + element_array: &ArrayRef, +) -> Result { + let mut data = Vec::with_capacity(list_array.len()); + + for (row_index, list_array_row) in list_array.iter().enumerate() { + if let Some(list_array_row) = list_array_row { + let eq_array = + compare_element_to_list(&list_array_row, element_array, row_index, true)?; + // let indices = UInt32Array::from(vec![row_index as u32]); + // let element_array_row = arrow::compute::take(element_array, &indices, None)?; + // // Compute all positions in list_row_array (that is itself an + // // array) that are equal to `from_array_row` + // let eq_array = match element_array_row.data_type() { + // // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop + // DataType::List(_) => { + // // compare each element of the from array + // let element_array_row_inner = + // as_list_array(&element_array_row)?.value(0); + // let list_array_row_inner = as_list_array(&list_array_row)?; + + // list_array_row_inner + // .iter() + // // compare element by element the current row of list_array + // .map(|row| row.map(|row| row.eq(&element_array_row_inner))) + // .collect::() + // } + // _ => { + // let element_arr = Scalar::new(element_array_row); + // // use not_distinct so NULL = NULL + // arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? + // } + // }; // Collect `true`s in 1-indexed positions let indexes = eq_array .iter() - .skip(from) .positions(|e| e == Some(true)) .map(|index| Some(index as u64 + 1)) - .take(n as usize) .collect::>(); data.push(Some(indexes)); diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 54741afdf83a..af16f6322035 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -695,7 +695,7 @@ select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h' NULL NULL # array_element scalar function #4 (with NULL) -query error +query error select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL); # array_element scalar function #5 (with negative index) @@ -864,11 +864,11 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h', [1, 2, 3, 4] [h, e, l] # array_slice scalar function #8 (with NULL and positive number) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); # array_slice scalar function #9 (with positive number and NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); # array_slice scalar function #10 (with zero-zero) @@ -878,7 +878,7 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), array_slice(make_array('h', [] [] # array_slice scalar function #11 (with NULL-NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL); # array_slice scalar function #12 (with zero and negative number) @@ -888,11 +888,11 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h' [1] [h, e] # array_slice scalar function #13 (with negative number and NULL) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); # array_slice scalar function #14 (with NULL and negative number) -query error +query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3); # array_slice scalar function #15 (with negative indexes) From 8274afe407fd5eb843a2162727a4a11db0b2d35d Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 14 Nov 2023 21:44:12 +0800 Subject: [PATCH 5/7] fmt Signed-off-by: jayzhan211 --- .../physical-expr/src/array_expressions.rs | 55 +++---------------- 1 file changed, 7 insertions(+), 48 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index a4d5dd306c8b..6ac16b29437c 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -25,7 +25,6 @@ use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::datatypes::{DataType, Field, UInt64Type}; use arrow::row::{RowConverter, SortField}; -use arrow_array::types::Int64Type; use arrow_buffer::NullBuffer; use arrow_schema::FieldRef; @@ -34,7 +33,7 @@ use datafusion_common::cast::{ }; use datafusion_common::utils::array_into_list_array; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, + internal_datafusion_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, }; @@ -1110,30 +1109,6 @@ fn general_positions( if let Some(list_array_row) = list_array_row { let eq_array = compare_element_to_list(&list_array_row, element_array, row_index, true)?; - // let indices = UInt32Array::from(vec![row_index as u32]); - // let element_array_row = arrow::compute::take(element_array, &indices, None)?; - // // Compute all positions in list_row_array (that is itself an - // // array) that are equal to `from_array_row` - // let eq_array = match element_array_row.data_type() { - // // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop - // DataType::List(_) => { - // // compare each element of the from array - // let element_array_row_inner = - // as_list_array(&element_array_row)?.value(0); - // let list_array_row_inner = as_list_array(&list_array_row)?; - - // list_array_row_inner - // .iter() - // // compare element by element the current row of list_array - // .map(|row| row.map(|row| row.eq(&element_array_row_inner))) - // .collect::() - // } - // _ => { - // let element_arr = Scalar::new(element_array_row); - // // use not_distinct so NULL = NULL - // arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? - // } - // }; // Collect `true`s in 1-indexed positions let indexes = eq_array @@ -1371,30 +1346,14 @@ fn general_replace( match list_array_row { Some(list_array_row) => { - let indices = UInt32Array::from(vec![row_index as u32]); - let from_array_row = arrow::compute::take(from_array, &indices, None)?; // Compute all positions in list_row_array (that is itself an // array) that are equal to `from_array_row` - let eq_array = match from_array_row.data_type() { - // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - // compare each element of the from array - let from_array_row_inner = - as_list_array(&from_array_row)?.value(0); - let list_array_row_inner = as_list_array(&list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| row.map(|row| row.eq(&from_array_row_inner))) - .collect::() - } - _ => { - let from_arr = Scalar::new(from_array_row); - // use not_distinct so NULL = NULL - arrow_ord::cmp::not_distinct(&list_array_row, &from_arr)? - } - }; + let eq_array = compare_element_to_list( + &list_array_row, + &from_array, + row_index, + true, + )?; // Use MutableArrayData to build the replaced array let original_data = list_array_row.to_data(); From e46a62778500a53cb4b9e8b468e16a174025b05d Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 14 Nov 2023 21:51:12 +0800 Subject: [PATCH 6/7] resolve rebase Signed-off-by: jayzhan211 --- .../physical-expr/src/array_expressions.rs | 70 +++---------------- 1 file changed, 8 insertions(+), 62 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 6ac16b29437c..d42f241ff8dc 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -33,8 +33,8 @@ use datafusion_common::cast::{ }; use datafusion_common::utils::array_into_list_array; use datafusion_common::{ - internal_datafusion_err, internal_err, not_impl_err, plan_err, - DataFusionError, Result, + exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DataFusionError, + Result, }; use itertools::Itertools; @@ -1128,42 +1128,6 @@ fn general_positions( )) } -/// Array_position SQL function -pub fn array_position(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; - - // handle incorrect from_array only 1 to n is accept. - - let arr_n = if args.len() == 3 { - as_int64_array(&args[2])?.values().to_vec() - } else { - vec![1; arr.len()] - }; - - general_position::( - arr, - element, - vec![0; arr.len()], - arr_n, - ) -} - -/// Array_positions SQL function -pub fn array_positions(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; - - check_datatypes("array_positions", &[arr.values(), element])?; - - general_position::( - arr, - element, - vec![0; arr.len()], - vec![i64::MAX; arr.len()], - ) -} - /// For each element of `list_array[i]`, removed up to `arr_n[i]` occurences /// of `element_array[i]`. /// @@ -1198,30 +1162,12 @@ fn general_remove( { match list_array_row { Some(list_array_row) => { - let indices = UInt32Array::from(vec![row_index as u32]); - let element_array_row = - arrow::compute::take(element_array, &indices, None)?; - - let eq_array = match element_array_row.data_type() { - // arrow_ord::cmp::distinct does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - // compare each element of the from array - let element_array_row_inner = - as_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_list_array(&list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| row.map(|row| row.ne(&element_array_row_inner))) - .collect::() - } - _ => { - let from_arr = Scalar::new(element_array_row); - // use distinct so Null = Null is false - arrow_ord::cmp::distinct(&list_array_row, &from_arr)? - } - }; + let eq_array = compare_element_to_list( + &list_array_row, + element_array, + row_index, + false, + )?; // We need to keep at most first n elements as `false`, which represent the elements to remove. let eq_array = if eq_array.false_count() < *n as usize { From 340613cc848fbe5f4648ba56a81db801b06560cc Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 14 Nov 2023 21:54:57 +0800 Subject: [PATCH 7/7] fmt Signed-off-by: jayzhan211 --- datafusion/physical-expr/src/array_expressions.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index d42f241ff8dc..0f50831d5678 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -33,8 +33,8 @@ use datafusion_common::cast::{ }; use datafusion_common::utils::array_into_list_array; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DataFusionError, - Result, + exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, + DataFusionError, Result, }; use itertools::Itertools;