Skip to content

Commit

Permalink
Add simplification rules for the CONCAT function (#3684)
Browse files Browse the repository at this point in the history
* simpl concat

Signed-off-by: remzi <13716567376yh@gmail.com>

* update after type coercion

Signed-off-by: remzi <13716567376yh@gmail.com>

Signed-off-by: remzi <13716567376yh@gmail.com>
  • Loading branch information
HaoYang670 authored Oct 11, 2022
1 parent 0cf5630 commit ac1631a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
68 changes: 68 additions & 0 deletions datafusion/optimizer/src/simplify_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -878,12 +878,56 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
out_expr.rewrite(self)?
}

// concat
ScalarFunction {
fun: BuiltinScalarFunction::Concat,
args,
} => {
let mut new_args = Vec::with_capacity(args.len());
let mut contiguous_scalar = "".to_string();
for e in args {
match e {
// All literals have been converted to Utf8 or LargeUtf8 in type_coercion.
// Concatenate it with `contiguous_scalar`.
Expr::Literal(
ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x),
) => {
if let Some(s) = x {
contiguous_scalar += &s;
}
}
// If the arg is not a literal, we should first push the current `contiguous_scalar`
// to the `new_args` (if it is not empty) and reset it to empty string.
// Then pushing this arg to the `new_args`.
e => {
if !contiguous_scalar.is_empty() {
new_args.push(Expr::Literal(ScalarValue::Utf8(Some(
contiguous_scalar.clone(),
))));
contiguous_scalar = "".to_string();
}
new_args.push(e);
}
}
}
if !contiguous_scalar.is_empty() {
new_args
.push(Expr::Literal(ScalarValue::Utf8(Some(contiguous_scalar))));
}

ScalarFunction {
fun: BuiltinScalarFunction::Concat,
args: new_args,
}
}

// concat_ws
ScalarFunction {
fun: BuiltinScalarFunction::ConcatWithSeparator,
args,
} => {
match &args[..] {
// concat_ws(null, ..) --> null
[Expr::Literal(sp), ..] if sp.is_null() => {
Expr::Literal(ScalarValue::Utf8(None))
}
Expand Down Expand Up @@ -1352,6 +1396,30 @@ mod tests {
}
}

#[test]
fn test_simplify_concat() {
fn build_concat_expr(args: &[Expr]) -> Expr {
Expr::ScalarFunction {
fun: BuiltinScalarFunction::Concat,
args: args.to_vec(),
}
}

let null = Expr::Literal(ScalarValue::Utf8(None));
let expr = build_concat_expr(&[
null.clone(),
col("c0"),
lit("hello "),
null.clone(),
lit("rust"),
col("c1"),
lit(""),
null,
]);
let expected = build_concat_expr(&[col("c0"), lit("hello rust"), col("c1")]);
assert_eq!(simplify(expr), expected)
}

// ------------------------------
// --- ConstEvaluator tests -----
// ------------------------------
Expand Down
13 changes: 13 additions & 0 deletions datafusion/optimizer/tests/integration-test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,19 @@ fn between_date64_plus_interval() -> Result<()> {
Ok(())
}

#[test]
fn concat_literals() -> Result<()> {
let sql = "SELECT concat(true, col_int32, false, null, 'hello', col_utf8, 12, 3.4) \
AS col
FROM test";
let plan = test_sql(sql)?;
let expected =
"Projection: concat(Utf8(\"1\"), CAST(test.col_int32 AS Utf8), Utf8(\"0hello\"), test.col_utf8, Utf8(\"123.4\")) AS col\
\n TableScan: test projection=[col_int32, col_utf8]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())
}

fn test_sql(sql: &str) -> Result<LogicalPlan> {
// parse the SQL
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
Expand Down

0 comments on commit ac1631a

Please sign in to comment.