Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite array @> array and array <@ array in sql_expr_to_logical_expr #11155

Merged
merged 4 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub use order_by::rewrite_sort_cols_by_aggs;
/// For example, concatenating arrays `a || b` is represented as
/// `Operator::StringConcat`, but can be implemented by calling a function
/// `array_concat` from the `functions-array` crate.
// This is not used in DataFusion internally, but it is still helpful for downstream projects, so don't remove it.
pub trait FunctionRewrite {
/// Return a human readable name for this rewrite
fn name(&self) -> &str;
Expand Down
2 changes: 0 additions & 2 deletions datafusion/functions-array/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ pub mod repeat;
pub mod replace;
pub mod resize;
pub mod reverse;
pub mod rewrite;
pub mod set_ops;
pub mod sort;
pub mod string;
Expand Down Expand Up @@ -152,7 +151,6 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
}
Ok(()) as Result<()>
})?;
registry.register_function_rewrite(Arc::new(rewrite::ArrayFunctionRewriter {}))?;

Ok(())
}
Expand Down
76 changes: 0 additions & 76 deletions datafusion/functions-array/src/rewrite.rs

This file was deleted.

33 changes: 0 additions & 33 deletions datafusion/physical-expr-common/src/datum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,36 +145,3 @@ fn compare_op_for_nested(
Ok(BooleanArray::new(values, nulls))
}
}

#[cfg(test)]
mod tests {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forgot to remove in previous PR

use arrow::{
array::{make_comparator, Array, BooleanArray, ListArray},
buffer::NullBuffer,
compute::SortOptions,
datatypes::Int32Type,
};

/// Compares two identical nested `ListArray`s row by row with
/// `make_comparator` and checks the resulting `BooleanArray`.
///
/// The equality values come from the comparator and the validity mask is the
/// union of both inputs' null buffers, so the single null row (index 1) must
/// be null in the result and every other row must be `true`.
#[test]
fn test123() {
    // One null row and one row containing a null element.
    let data = vec![
        Some(vec![Some(0), Some(1), Some(2)]),
        None,
        Some(vec![Some(3), None, Some(5)]),
        Some(vec![Some(6), Some(7)]),
    ];
    let a = ListArray::from_iter_primitive::<Int32Type, _, _>(data.clone());
    let b = ListArray::from_iter_primitive::<Int32Type, _, _>(data);

    let cmp = make_comparator(&a, &b, SortOptions::default())
        .expect("list arrays of the same type must be comparable");
    let len = a.len().min(b.len());
    let values = (0..len).map(|i| cmp(i, i).is_eq()).collect();
    let nulls = NullBuffer::union(a.nulls(), b.nulls());
    let result = BooleanArray::new(values, nulls);

    // Assert on the outcome instead of printing it, so a wrong comparison
    // actually fails the test.
    assert_eq!(result.len(), 4);
    assert_eq!(result.null_count(), 1);
    assert!(result.is_null(1), "row 1 is null in both inputs");
    for i in [0, 2, 3] {
        assert!(result.is_valid(i), "row {i} should be valid");
        assert!(result.value(i), "row {i} should compare equal");
    }
}
}
30 changes: 29 additions & 1 deletion datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,38 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
vec![left, right],
)));
} else {
return internal_err!("array_append not found");
return internal_err!("array_prepend not found");
}
}
} else if matches!(op, Operator::AtArrow | Operator::ArrowAt) {
let left_type = left.get_type(schema)?;
let right_type = right.get_type(schema)?;
let left_list_ndims = list_ndims(&left_type);
let right_list_ndims = list_ndims(&right_type);
// if both are list
if left_list_ndims > 0 && right_list_ndims > 0 {
if let Some(udf) =
self.context_provider.get_function_meta("array_has_all")
{
// array1 @> array2 -> array_has_all(array1, array2)
if op == Operator::AtArrow {
return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
udf,
vec![left, right],
)));
// array1 <@ array2 -> array_has_all(array2, array1)
} else {
return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
udf,
vec![right, left],
)));
}
} else {
return internal_err!("array_has_all not found");
}
}
}

Ok(Expr::BinaryExpr(BinaryExpr::new(
Box::new(left),
op,
Expand Down
16 changes: 0 additions & 16 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1225,22 +1225,6 @@ fn select_binary_expr_nested() {
quick_test(sql, expected);
}

// Plans `SELECT left @> right from array` and checks, via `quick_test`, that
// the displayed plan is a projection of the `@>` binary expression over a
// scan of table `array`.
#[test]
fn select_at_arrow_operator() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add an explain test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it is needed since we don't really care about how it looks in plan.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can make sure array_has_all is applied 🤔

let sql = "SELECT left @> right from array";
let expected = "Projection: array.left @> array.right\
\n TableScan: array";
quick_test(sql, expected);
}

/// Plans a query using the `<@` operator on two columns of table `array`
/// and compares the displayed logical plan, via `quick_test`, with the
/// expected projection-over-scan text.
#[test]
fn select_arrow_at_operator() {
    // Expected plan text: the projection line followed by the table scan.
    let expected_plan = "Projection: array.left <@ array.right\
\n TableScan: array";
    quick_test("SELECT left <@ right from array", expected_plan);
}

#[test]
fn select_wildcard_with_groupby() {
quick_test(
Expand Down
22 changes: 22 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -6076,6 +6076,17 @@ select make_array(1,2,3) @> make_array(1,3),
----
true false true false false false true

# Make sure it is rewritten to function array_has_all()
query TT
explain select [1,2,3] @> [1,3];
----
logical_plan
01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))
02)--EmptyRelation
physical_plan
01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))]
02)--PlaceholderRowExec

# array containment operator with scalars #2 (arrow at)
query BBBBBBB
select make_array(1,3) <@ make_array(1,2,3),
Expand All @@ -6088,6 +6099,17 @@ select make_array(1,3) <@ make_array(1,2,3),
----
true false true false false false true

# Make sure it is rewritten to function array_has_all()
query TT
explain select [1,3] <@ [1,2,3];
----
logical_plan
01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))
02)--EmptyRelation
physical_plan
01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))]
02)--PlaceholderRowExec

### Array casting tests


Expand Down