Skip to content

Commit

Permalink
Convert Binary Operator StringConcat to Function for array_concat
Browse files Browse the repository at this point in the history
…, `array_append` and `array_prepend` (#8636)

* reuse function for string concat

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* remove casting in string concat

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add test

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* operator to function rewrite

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix explain

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add more test

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add column cases

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* presever name

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* Update datafusion/optimizer/src/analyzer/rewrite_expr.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

* rename

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
jayzhan211 and alamb authored Jan 5, 2024
1 parent 821db54 commit 4e4059a
Show file tree
Hide file tree
Showing 7 changed files with 371 additions and 11 deletions.
2 changes: 0 additions & 2 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -667,8 +667,6 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
string_concat_internal_coercion(from_type, &LargeUtf8)
}
// TODO: cast between array elements (#6558)
(List(_), from_type) | (from_type, List(_)) => Some(from_type.to_owned()),
_ => None,
})
}
Expand Down
6 changes: 6 additions & 0 deletions datafusion/optimizer/src/analyzer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

pub mod count_wildcard_rule;
pub mod inline_table_scan;
pub mod rewrite_expr;
pub mod subquery;
pub mod type_coercion;

Expand All @@ -37,6 +38,8 @@ use log::debug;
use std::sync::Arc;
use std::time::Instant;

use self::rewrite_expr::OperatorToFunction;

/// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make
/// the plan valid prior to the rest of the DataFusion optimization process.
///
Expand Down Expand Up @@ -72,6 +75,9 @@ impl Analyzer {
pub fn new() -> Self {
let rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>> = vec![
Arc::new(InlineTableScan::new()),
// OperatorToFunction should be run before TypeCoercion, since it rewrite based on the argument types (List or Scalar),
// and TypeCoercion may cast the argument types from Scalar to List.
Arc::new(OperatorToFunction::new()),
Arc::new(TypeCoercion::new()),
Arc::new(CountWildcardRule::new()),
];
Expand Down
321 changes: 321 additions & 0 deletions datafusion/optimizer/src/analyzer/rewrite_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,321 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Analyzer rule for to replace operators with function calls (e.g `||` to array_concat`)

use std::sync::Arc;

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::TreeNodeRewriter;
use datafusion_common::utils::list_ndims;
use datafusion_common::DFSchema;
use datafusion_common::DFSchemaRef;
use datafusion_common::Result;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::utils::merge_schema;
use datafusion_expr::BuiltinScalarFunction;
use datafusion_expr::Operator;
use datafusion_expr::ScalarFunctionDefinition;
use datafusion_expr::{BinaryExpr, Expr, LogicalPlan};

use super::AnalyzerRule;

#[derive(Default)]
pub struct OperatorToFunction {}

impl OperatorToFunction {
pub fn new() -> Self {
Self {}
}
}

impl AnalyzerRule for OperatorToFunction {
fn name(&self) -> &str {
"operator_to_function"
}

fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
analyze_internal(&plan)
}
}

fn analyze_internal(plan: &LogicalPlan) -> Result<LogicalPlan> {
// optimize child plans first
let new_inputs = plan
.inputs()
.iter()
.map(|p| analyze_internal(p))
.collect::<Result<Vec<_>>>()?;

// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let mut schema = merge_schema(new_inputs.iter().collect());

if let LogicalPlan::TableScan(ts) = plan {
let source_schema =
DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?;
schema.merge(&source_schema);
}

let mut expr_rewrite = OperatorToFunctionRewriter {
schema: Arc::new(schema),
};

let new_expr = plan
.expressions()
.into_iter()
.map(|expr| {
// ensure names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
rewrite_preserving_name(expr, &mut expr_rewrite)
})
.collect::<Result<Vec<_>>>()?;

plan.with_new_exprs(new_expr, &new_inputs)
}

pub(crate) struct OperatorToFunctionRewriter {
pub(crate) schema: DFSchemaRef,
}

impl TreeNodeRewriter for OperatorToFunctionRewriter {
type N = Expr;

fn mutate(&mut self, expr: Expr) -> Result<Expr> {
match expr {
Expr::BinaryExpr(BinaryExpr {
ref left,
op,
ref right,
}) => {
if let Some(fun) = rewrite_array_concat_operator_to_func_for_column(
left.as_ref(),
op,
right.as_ref(),
self.schema.as_ref(),
)?
.or_else(|| {
rewrite_array_concat_operator_to_func(
left.as_ref(),
op,
right.as_ref(),
)
}) {
// Convert &Box<Expr> -> Expr
let left = (**left).clone();
let right = (**right).clone();
return Ok(Expr::ScalarFunction(ScalarFunction {
func_def: ScalarFunctionDefinition::BuiltIn(fun),
args: vec![left, right],
}));
}

Ok(expr)
}
_ => Ok(expr),
}
}
}

/// Summary of the logic below:
///
/// 1) array || array -> array concat
///
/// 2) array || scalar -> array append
///
/// 3) scalar || array -> array prepend
///
/// 4) (arry concat, array append, array prepend) || array -> array concat
///
/// 5) (arry concat, array append, array prepend) || scalar -> array append
fn rewrite_array_concat_operator_to_func(
left: &Expr,
op: Operator,
right: &Expr,
) -> Option<BuiltinScalarFunction> {
// Convert `Array StringConcat Array` to ScalarFunction::ArrayConcat

if op != Operator::StringConcat {
return None;
}

match (left, right) {
// Chain concat operator (a || b) || array,
// (arry concat, array append, array prepend) || array -> array concat
(
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat),
args: _left_args,
}),
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
args: _right_args,
}),
)
| (
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend),
args: _left_args,
}),
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
args: _right_args,
}),
)
| (
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend),
args: _left_args,
}),
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
args: _right_args,
}),
) => Some(BuiltinScalarFunction::ArrayConcat),
// Chain concat operator (a || b) || scalar,
// (arry concat, array append, array prepend) || scalar -> array append
(
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat),
args: _left_args,
}),
_scalar,
)
| (
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend),
args: _left_args,
}),
_scalar,
)
| (
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend),
args: _left_args,
}),
_scalar,
) => Some(BuiltinScalarFunction::ArrayAppend),
// array || array -> array concat
(
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
args: _left_args,
}),
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
args: _right_args,
}),
) => Some(BuiltinScalarFunction::ArrayConcat),
// array || scalar -> array append
(
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
args: _left_args,
}),
_right_scalar,
) => Some(BuiltinScalarFunction::ArrayAppend),
// scalar || array -> array prepend
(
_left_scalar,
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
args: _right_args,
}),
) => Some(BuiltinScalarFunction::ArrayPrepend),

_ => None,
}
}

/// Summary of the logic below:
///
/// 1) (arry concat, array append, array prepend) || column -> (array append, array concat)
///
/// 2) column1 || column2 -> (array prepend, array append, array concat)
fn rewrite_array_concat_operator_to_func_for_column(
left: &Expr,
op: Operator,
right: &Expr,
schema: &DFSchema,
) -> Result<Option<BuiltinScalarFunction>> {
if op != Operator::StringConcat {
return Ok(None);
}

match (left, right) {
// Column cases:
// 1) array_prepend/append/concat || column
(
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend),
args: _left_args,
}),
Expr::Column(c),
)
| (
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend),
args: _left_args,
}),
Expr::Column(c),
)
| (
Expr::ScalarFunction(ScalarFunction {
func_def:
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat),
args: _left_args,
}),
Expr::Column(c),
) => {
let d = schema.field_from_column(c)?.data_type();
let ndim = list_ndims(d);
match ndim {
0 => Ok(Some(BuiltinScalarFunction::ArrayAppend)),
_ => Ok(Some(BuiltinScalarFunction::ArrayConcat)),
}
}
// 2) select column1 || column2
(Expr::Column(c1), Expr::Column(c2)) => {
let d1 = schema.field_from_column(c1)?.data_type();
let d2 = schema.field_from_column(c2)?.data_type();
let ndim1 = list_ndims(d1);
let ndim2 = list_ndims(d2);
match (ndim1, ndim2) {
(0, _) => Ok(Some(BuiltinScalarFunction::ArrayPrepend)),
(_, 0) => Ok(Some(BuiltinScalarFunction::ArrayAppend)),
_ => Ok(Some(BuiltinScalarFunction::ArrayConcat)),
}
}
_ => Ok(None),
}
}
11 changes: 2 additions & 9 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ mod kernels;
use std::hash::{Hash, Hasher};
use std::{any::Any, sync::Arc};

use crate::array_expressions::{
array_append, array_concat, array_has_all, array_prepend,
};
use crate::array_expressions::array_has_all;
use crate::expressions::datum::{apply, apply_cmp};
use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
use crate::physical_expr::down_cast_any_ref;
Expand Down Expand Up @@ -598,12 +596,7 @@ impl BinaryExpr {
BitwiseXor => bitwise_xor_dyn(left, right),
BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
StringConcat => match (left_data_type, right_data_type) {
(DataType::List(_), DataType::List(_)) => array_concat(&[left, right]),
(DataType::List(_), _) => array_append(&[left, right]),
(_, DataType::List(_)) => array_prepend(&[left, right]),
_ => binary_string_array_op!(left, right, concat_elements),
},
StringConcat => binary_string_array_op!(left, right, concat_elements),
AtArrow => array_has_all(&[left, right]),
ArrowAt => array_has_all(&[right, left]),
}
Expand Down
2 changes: 2 additions & 0 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
StackEntry::Operator(op) => {
let right = eval_stack.pop().unwrap();
let left = eval_stack.pop().unwrap();

let expr = Expr::BinaryExpr(BinaryExpr::new(
Box::new(left),
op,
Box::new(right),
));

eval_stack.push(expr);
}
}
Expand Down
Loading

0 comments on commit 4e4059a

Please sign in to comment.