Skip to content

Commit

Permalink
Rewrite array @> array and array <@ array in sql_expr_to_logical_…
Browse files Browse the repository at this point in the history
…expr (#11155)

* rewrite at arrow

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm useless test

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add test

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm test

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 authored Jun 29, 2024
1 parent c80da91 commit 14d3973
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 128 deletions.
1 change: 1 addition & 0 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub use order_by::rewrite_sort_cols_by_aggs;
/// For example, concatenating arrays `a || b` is represented as
/// `Operator::StringConcat`, but can be implemented by calling a function
/// `array_concat` from the `functions-array` crate.
// This is not used by DataFusion internally, but it is still useful for downstream projects, so don't remove it.
pub trait FunctionRewrite {
/// Return a human readable name for this rewrite
fn name(&self) -> &str;
Expand Down
2 changes: 0 additions & 2 deletions datafusion/functions-array/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ pub mod repeat;
pub mod replace;
pub mod resize;
pub mod reverse;
pub mod rewrite;
pub mod set_ops;
pub mod sort;
pub mod string;
Expand Down Expand Up @@ -152,7 +151,6 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
}
Ok(()) as Result<()>
})?;
registry.register_function_rewrite(Arc::new(rewrite::ArrayFunctionRewriter {}))?;

Ok(())
}
Expand Down
76 changes: 0 additions & 76 deletions datafusion/functions-array/src/rewrite.rs

This file was deleted.

33 changes: 0 additions & 33 deletions datafusion/physical-expr-common/src/datum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,36 +145,3 @@ pub fn compare_op_for_nested(
Ok(BooleanArray::new(values, nulls))
}
}

#[cfg(test)]
mod tests {
    use arrow::{
        array::{make_comparator, Array, BooleanArray, ListArray},
        buffer::NullBuffer,
        compute::SortOptions,
        datatypes::Int32Type,
    };

    /// Builds the four-row `List<Int32>` fixture used for both operands:
    /// a fully valid list, a null list, a list containing a null element,
    /// and a shorter fully valid list.
    fn sample_lists() -> ListArray {
        let rows = vec![
            Some(vec![Some(0), Some(1), Some(2)]),
            None,
            Some(vec![Some(3), None, Some(5)]),
            Some(vec![Some(6), Some(7)]),
        ];
        ListArray::from_iter_primitive::<Int32Type, _, _>(rows)
    }

    /// Compares two identical list arrays row by row and checks the result
    /// instead of merely printing it, so a regression actually fails the test.
    #[test]
    fn nested_list_equality_propagates_top_level_nulls() {
        let a = sample_lists();
        let b = sample_lists();

        let cmp = make_comparator(&a, &b, SortOptions::default())
            .expect("list arrays of the same type must be comparable");
        let len = a.len().min(b.len());
        let values = (0..len).map(|i| cmp(i, i).is_eq()).collect();
        // The result validity is the union of both operands' null buffers,
        // so the row that is null in the inputs is null in the output.
        let nulls = NullBuffer::union(a.nulls(), b.nulls());
        let result = BooleanArray::new(values, nulls);

        // Both operands are identical, so every valid row compares equal;
        // the comparator is expected to treat the matching inner null in
        // row 2 as equal as well.
        assert_eq!(
            result.iter().collect::<Vec<_>>(),
            vec![Some(true), None, Some(true), Some(true)]
        );
    }
}
30 changes: 29 additions & 1 deletion datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,38 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
vec![left, right],
)));
} else {
return internal_err!("array_append not found");
return internal_err!("array_prepend not found");
}
}
} else if matches!(op, Operator::AtArrow | Operator::ArrowAt) {
let left_type = left.get_type(schema)?;
let right_type = right.get_type(schema)?;
let left_list_ndims = list_ndims(&left_type);
let right_list_ndims = list_ndims(&right_type);
// only rewrite when both operands are lists
if left_list_ndims > 0 && right_list_ndims > 0 {
if let Some(udf) =
self.context_provider.get_function_meta("array_has_all")
{
// array1 @> array2 -> array_has_all(array1, array2)
if op == Operator::AtArrow {
return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
udf,
vec![left, right],
)));
// array1 <@ array2 -> array_has_all(array2, array1)
} else {
return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
udf,
vec![right, left],
)));
}
} else {
return internal_err!("array_has_all not found");
}
}
}

Ok(Expr::BinaryExpr(BinaryExpr::new(
Box::new(left),
op,
Expand Down
16 changes: 0 additions & 16 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1227,22 +1227,6 @@ fn select_binary_expr_nested() {
quick_test(sql, expected);
}

/// Plans a `@>` expression between two columns and checks that the
/// resulting logical plan is a projection of `array.left @> array.right`
/// over a scan of `array`.
#[test]
fn select_at_arrow_operator() {
    quick_test(
        "SELECT left @> right from array",
        "Projection: array.left @> array.right\n TableScan: array",
    );
}

/// Plans a `<@` expression between two columns and checks that the
/// resulting logical plan is a projection of `array.left <@ array.right`
/// over a scan of `array`.
#[test]
fn select_arrow_at_operator() {
    quick_test(
        "SELECT left <@ right from array",
        "Projection: array.left <@ array.right\n TableScan: array",
    );
}

#[test]
fn select_wildcard_with_groupby() {
quick_test(
Expand Down
22 changes: 22 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -6076,6 +6076,17 @@ select make_array(1,2,3) @> make_array(1,3),
----
true false true false false false true

# Make sure it is rewritten to function array_has_all()
query TT
explain select [1,2,3] @> [1,3];
----
logical_plan
01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))
02)--EmptyRelation
physical_plan
01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))]
02)--PlaceholderRowExec

# array containment operator with scalars #2 (arrow at)
query BBBBBBB
select make_array(1,3) <@ make_array(1,2,3),
Expand All @@ -6088,6 +6099,17 @@ select make_array(1,3) <@ make_array(1,2,3),
----
true false true false false false true

# Make sure it is rewritten to function array_has_all()
query TT
explain select [1,3] <@ [1,2,3];
----
logical_plan
01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))
02)--EmptyRelation
physical_plan
01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))]
02)--PlaceholderRowExec

### Array casting tests


Expand Down

0 comments on commit 14d3973

Please sign in to comment.