From 04ca4ff7e1360b67dba2a48ab70d61e7f0cb942e Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sat, 29 Jun 2024 20:14:15 +0800 Subject: [PATCH] Rewrite `array @> array` and `array <@ array` in sql_expr_to_logical_expr (#11155) * rewrite at arrow Signed-off-by: jayzhan211 * rm useless test Signed-off-by: jayzhan211 * add test Signed-off-by: jayzhan211 * rm test Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/expr/src/expr_rewriter/mod.rs | 1 + datafusion/functions-array/src/lib.rs | 2 - datafusion/functions-array/src/rewrite.rs | 76 -------------------- datafusion/physical-expr-common/src/datum.rs | 33 --------- datafusion/sql/src/expr/mod.rs | 30 +++++++- datafusion/sql/tests/sql_integration.rs | 16 ----- datafusion/sqllogictest/test_files/array.slt | 22 ++++++ 7 files changed, 52 insertions(+), 128 deletions(-) delete mode 100644 datafusion/functions-array/src/rewrite.rs diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 1441374bdba3d..024e4a0ceae51 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -43,6 +43,7 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// For example, concatenating arrays `a || b` is represented as /// `Operator::ArrowAt`, but can be implemented by calling a function /// `array_concat` from the `functions-array` crate. +// This is not used in datafusion internally, but it is still helpful for downstream project so don't remove it. pub trait FunctionRewrite { /// Return a human readable name for this rewrite fn name(&self) -> &str; diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index b2fcb5717b3a5..543b7a60277ed 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -46,7 +46,6 @@ pub mod repeat; pub mod replace; pub mod resize; pub mod reverse; -pub mod rewrite; pub mod set_ops; pub mod sort; pub mod string; @@ -152,7 +151,6 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { } Ok(()) as Result<()> })?; - registry.register_function_rewrite(Arc::new(rewrite::ArrayFunctionRewriter {}))?; Ok(()) } diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs deleted file mode 100644 index 28bc2d5e43730..0000000000000 --- a/datafusion/functions-array/src/rewrite.rs +++ /dev/null @@ -1,76 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Rewrites for using Array Functions - -use crate::array_has::array_has_all; -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::Transformed; -use datafusion_common::DFSchema; -use datafusion_common::Result; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::{BinaryExpr, Expr, Operator}; - -/// Rewrites expressions into function calls to array functions -pub(crate) struct ArrayFunctionRewriter {} - -impl FunctionRewrite for ArrayFunctionRewriter { - fn name(&self) -> &str { - "ArrayFunctionRewriter" - } - - fn rewrite( - &self, - expr: Expr, - _schema: &DFSchema, - _config: &ConfigOptions, - ) -> Result> { - let transformed = match expr { - // array1 @> array2 -> array_has_all(array1, array2) - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::AtArrow - && is_func(&left, "make_array") - && is_func(&right, "make_array") => - { - Transformed::yes(array_has_all(*left, *right)) - } - - // array1 <@ array2 -> array_has_all(array2, array1) - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::ArrowAt - && is_func(&left, "make_array") - && is_func(&right, "make_array") => - { - Transformed::yes(array_has_all(*right, *left)) - } - - _ => Transformed::no(expr), - }; - Ok(transformed) - } -} - -/// Returns true if expr is a function call to the specified named function. -/// Returns false otherwise. -fn is_func(expr: &Expr, func_name: &str) -> bool { - let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else { - return false; - }; - - func.name() == func_name -} diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index 96c903180ed98..790e742c42211 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -145,36 +145,3 @@ pub fn compare_op_for_nested( Ok(BooleanArray::new(values, nulls)) } } - -#[cfg(test)] -mod tests { - use arrow::{ - array::{make_comparator, Array, BooleanArray, ListArray}, - buffer::NullBuffer, - compute::SortOptions, - datatypes::Int32Type, - }; - - #[test] - fn test123() { - let data = vec![ - Some(vec![Some(0), Some(1), Some(2)]), - None, - Some(vec![Some(3), None, Some(5)]), - Some(vec![Some(6), Some(7)]), - ]; - let a = ListArray::from_iter_primitive::(data); - let data = vec![ - Some(vec![Some(0), Some(1), Some(2)]), - None, - Some(vec![Some(3), None, Some(5)]), - Some(vec![Some(6), Some(7)]), - ]; - let b = ListArray::from_iter_primitive::(data); - let cmp = make_comparator(&a, &b, SortOptions::default()).unwrap(); - let len = a.len().min(b.len()); - let values = (0..len).map(|i| cmp(i, i).is_eq()).collect(); - let nulls = NullBuffer::union(a.nulls(), b.nulls()); - println!("res: {:?}", BooleanArray::new(values, nulls)); - } -} diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index a8af37ee6a37d..b1182b35ec95a 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -150,10 +150,38 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { vec![left, right], ))); } else { - return internal_err!("array_append not found"); + return internal_err!("array_prepend not found"); + } + } + } else if matches!(op, Operator::AtArrow | Operator::ArrowAt) { + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + let left_list_ndims = list_ndims(&left_type); + let right_list_ndims = list_ndims(&right_type); + // if both are list + if left_list_ndims > 0 && right_list_ndims > 0 { + if let Some(udf) = + self.context_provider.get_function_meta("array_has_all") + { + // array1 @> array2 -> array_has_all(array1, array2) + if op == Operator::AtArrow { + return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![left, right], + ))); + // array1 <@ array2 -> array_has_all(array2, array1) + } else { + return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![right, left], + ))); + } + } else { + return internal_err!("array_has_all not found"); } } } + Ok(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), op, diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index e72a439b323b0..ec623a956186b 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -1227,22 +1227,6 @@ fn select_binary_expr_nested() { quick_test(sql, expected); } -#[test] -fn select_at_arrow_operator() { - let sql = "SELECT left @> right from array"; - let expected = "Projection: array.left @> array.right\ - \n TableScan: array"; - quick_test(sql, expected); -} - -#[test] -fn select_arrow_at_operator() { - let sql = "SELECT left <@ right from array"; - let expected = "Projection: array.left <@ array.right\ - \n TableScan: array"; - quick_test(sql, expected); -} - #[test] fn select_wildcard_with_groupby() { quick_test( diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 77d1a9da1f552..7917f1d78da8e 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -6076,6 +6076,17 @@ select make_array(1,2,3) @> make_array(1,3), ---- true false true false false false true +# Make sure it is rewritten to function array_has_all() +query TT +explain select [1,2,3] @> [1,3]; +---- +logical_plan +01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3))) +02)--EmptyRelation +physical_plan +01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))] +02)--PlaceholderRowExec + # array containment operator with scalars #2 (arrow at) query BBBBBBB select make_array(1,3) <@ make_array(1,2,3), @@ -6088,6 +6099,17 @@ select make_array(1,3) <@ make_array(1,2,3), ---- true false true false false false true +# Make sure it is rewritten to function array_has_all() +query TT +explain select [1,3] <@ [1,2,3]; +---- +logical_plan +01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3))) +02)--EmptyRelation +physical_plan +01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))] +02)--PlaceholderRowExec + ### Array casting tests