Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move concat, concat_ws, ends_with, initcap to datafusion-functions #10089

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ use crate::{
physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner},
variable::{VarProvider, VarType},
};
use crate::{functions, functions_aggregate, functions_array};

#[cfg(feature = "array_expressions")]
use crate::functions_array;
use crate::{functions, functions_aggregate};

use arrow::datatypes::{DataType, SchemaRef};
use arrow::record_batch::RecordBatch;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use datafusion::assert_batches_eq;
use datafusion_common::DFSchema;
use datafusion_expr::expr::Alias;
use datafusion_expr::{approx_median, cast, ExprSchemable};
use datafusion_functions::unicode::expr_fn::character_length;
use datafusion_functions_array::expr_fn::array_to_string;

fn test_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Expand Down
28 changes: 27 additions & 1 deletion datafusion/core/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,30 @@ fn timestamp_nano_ts_utc_predicates() {
assert_eq!(expected, format!("{plan:?}"));
}

#[test]
fn concat_literals() -> Result<()> {
let sql = "SELECT concat(true, col_int32, false, null, 'hello', col_utf8, 12, 3.4) \
AS col
FROM test";
let expected =
"Projection: concat(Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"falsehello\"), test.col_utf8, Utf8(\"123.4\")) AS col\
\n TableScan: test projection=[col_int32, col_utf8]";
quick_test(sql, expected);
Ok(())
}

#[test]
fn concat_ws_literals() -> Result<()> {
let sql = "SELECT concat_ws('-', true, col_int32, false, null, 'hello', col_utf8, 12, '', 3.4) \
AS col
FROM test";
let expected =
"Projection: concat_ws(Utf8(\"-\"), Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"false-hello\"), test.col_utf8, Utf8(\"12--3.4\")) AS col\
\n TableScan: test projection=[col_int32, col_utf8]";
quick_test(sql, expected);
Ok(())
}

fn quick_test(sql: &str, expected_plan: &str) {
let plan = test_sql(sql).unwrap();
assert_eq!(expected_plan, format!("{:?}", plan));
Expand All @@ -97,7 +121,9 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
// create a logical query plan
let context_provider = MyContextProvider::default()
.with_udf(datetime::now())
.with_udf(datafusion_functions::core::arrow_cast());
.with_udf(datafusion_functions::core::arrow_cast())
.with_udf(datafusion_functions::string::concat())
.with_udf(datafusion_functions::string::concat_ws());
let sql_to_rel = SqlToRel::new(&context_provider);
let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();

Expand Down
99 changes: 94 additions & 5 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use datafusion_expr::{
expr, table_scan, Cast, ColumnarValue, Expr, ExprSchemable, LogicalPlan,
LogicalPlanBuilder, ScalarUDF, Volatility,
};
use datafusion_functions::math;
use datafusion_functions::{math, string};
use datafusion_optimizer::optimizer::Optimizer;
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions};
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
Expand Down Expand Up @@ -217,7 +217,7 @@ fn fold_and_simplify() {
let info: MyInfo = schema().into();

// What will it do with the expression `concat('foo', 'bar') == 'foobar')`?
let expr = concat(&[lit("foo"), lit("bar")]).eq(lit("foobar"));
let expr = concat(vec![lit("foo"), lit("bar")]).eq(lit("foobar"));

// Since datafusion applies both simplification *and* rewriting
// some expressions can be entirely simplified
Expand Down Expand Up @@ -364,13 +364,13 @@ fn test_const_evaluator() {
#[test]
fn test_const_evaluator_scalar_functions() {
// concat("foo", "bar") --> "foobar"
let expr = call_fn("concat", vec![lit("foo"), lit("bar")]).unwrap();
let expr = string::expr_fn::concat(vec![lit("foo"), lit("bar")]);
test_evaluate(expr, lit("foobar"));

// ensure arguments are also constant folded
// concat("foo", concat("bar", "baz")) --> "foobarbaz"
let concat1 = call_fn("concat", vec![lit("bar"), lit("baz")]).unwrap();
let expr = call_fn("concat", vec![lit("foo"), concat1]).unwrap();
let concat1 = string::expr_fn::concat(vec![lit("bar"), lit("baz")]);
let expr = string::expr_fn::concat(vec![lit("foo"), concat1]);
test_evaluate(expr, lit("foobarbaz"));

// Check non string arguments
Expand Down Expand Up @@ -569,3 +569,92 @@ fn test_simplify_power() {
test_simplify(expr, expected)
}
}

#[test]
fn test_simplify_concat_ws() {
let null = lit(ScalarValue::Utf8(None));
// the delimiter is not a literal
{
let expr = concat_ws(col("c"), vec![lit("a"), null.clone(), lit("b")]);
let expected = concat_ws(col("c"), vec![lit("a"), lit("b")]);
test_simplify(expr, expected);
}

// the delimiter is an empty string
{
let expr = concat_ws(lit(""), vec![col("a"), lit("c"), lit("b")]);
let expected = concat(vec![col("a"), lit("cb")]);
test_simplify(expr, expected);
}

// the delimiter is a not-empty string
{
let expr = concat_ws(
lit("-"),
vec![
null.clone(),
col("c0"),
lit("hello"),
null.clone(),
lit("rust"),
col("c1"),
lit(""),
lit(""),
null,
],
);
let expected = concat_ws(
lit("-"),
vec![col("c0"), lit("hello-rust"), col("c1"), lit("-")],
);
test_simplify(expr, expected)
}
}

#[test]
fn test_simplify_concat_ws_with_null() {
let null = lit(ScalarValue::Utf8(None));
// null delimiter -> null
{
let expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]);
test_simplify(expr, null.clone());
}

// filter out null args
{
let expr = concat_ws(lit("|"), vec![col("c1"), null.clone(), col("c2")]);
let expected = concat_ws(lit("|"), vec![col("c1"), col("c2")]);
test_simplify(expr, expected);
}

// nested test
{
let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]);
let expr = concat_ws(lit("|"), vec![sub_expr, col("c3")]);
test_simplify(expr, concat_ws(lit("|"), vec![col("c3")]));
}

// null delimiter (nested)
{
let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]);
let expr = concat_ws(sub_expr, vec![col("c3"), col("c4")]);
test_simplify(expr, null);
}
}

#[test]
fn test_simplify_concat() {
let null = lit(ScalarValue::Utf8(None));
let expr = concat(vec![
null.clone(),
col("c0"),
lit("hello "),
null.clone(),
lit("rust"),
col("c1"),
lit(""),
null,
]);
let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]);
test_simplify(expr, expected)
}
90 changes: 1 addition & 89 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::str::FromStr;
use std::sync::OnceLock;

use crate::type_coercion::functions::data_types;
use crate::{FuncMonotonicity, Signature, TypeSignature, Volatility};
use crate::{FuncMonotonicity, Signature, Volatility};

use arrow::datatypes::DataType;
use datafusion_common::{plan_err, DataFusionError, Result};
Expand All @@ -39,15 +39,6 @@ pub enum BuiltinScalarFunction {
// math functions
/// coalesce
Coalesce,
// string functions
/// concat
Concat,
/// concat_ws
ConcatWithSeparator,
/// ends_with
EndsWith,
/// initcap
InitCap,
}

/// Maps the sql function name to `BuiltinScalarFunction`
Expand Down Expand Up @@ -101,10 +92,6 @@ impl BuiltinScalarFunction {
match self {
// Immutable scalar builtins
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
BuiltinScalarFunction::Concat => Volatility::Immutable,
BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable,
BuiltinScalarFunction::EndsWith => Volatility::Immutable,
BuiltinScalarFunction::InitCap => Volatility::Immutable,
}
}

Expand All @@ -117,8 +104,6 @@ impl BuiltinScalarFunction {
/// 1. Perform additional checks on `input_expr_types` that are beyond the scope of `TypeSignature` validation.
/// 2. Deduce the output `DataType` based on the provided `input_expr_types`.
pub fn return_type(self, input_expr_types: &[DataType]) -> Result<DataType> {
use DataType::*;

// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.

Expand All @@ -130,43 +115,18 @@ impl BuiltinScalarFunction {
let coerced_types = data_types(input_expr_types, &self.signature());
coerced_types.map(|types| types[0].clone())
}
BuiltinScalarFunction::Concat => Ok(Utf8),
BuiltinScalarFunction::ConcatWithSeparator => Ok(Utf8),
BuiltinScalarFunction::InitCap => {
utf8_to_str_type(&input_expr_types[0], "initcap")
}
BuiltinScalarFunction::EndsWith => Ok(Boolean),
}
}

/// Return the argument [`Signature`] supported by this function
pub fn signature(&self) -> Signature {
use DataType::*;
use TypeSignature::*;
// note: the physical expression must accept the type returned by this function or the execution panics.

// for now, the list is small, as we do not have many built-in functions.
match self {
BuiltinScalarFunction::Concat
| BuiltinScalarFunction::ConcatWithSeparator => {
Signature::variadic(vec![Utf8], self.volatility())
}
BuiltinScalarFunction::Coalesce => {
Signature::variadic_equal(self.volatility())
}
BuiltinScalarFunction::InitCap => {
Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility())
}

BuiltinScalarFunction::EndsWith => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![Utf8, LargeUtf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
],
self.volatility(),
),
}
}

Expand All @@ -182,11 +142,6 @@ impl BuiltinScalarFunction {
match self {
// conditional functions
BuiltinScalarFunction::Coalesce => &["coalesce"],

BuiltinScalarFunction::Concat => &["concat"],
BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"],
BuiltinScalarFunction::EndsWith => &["ends_with"],
BuiltinScalarFunction::InitCap => &["initcap"],
}
}
}
Expand All @@ -208,49 +163,6 @@ impl FromStr for BuiltinScalarFunction {
}
}

/// Creates a function to identify the optimal return type of a string function given
/// the type of its first argument.
///
/// If the input type is `LargeUtf8` or `LargeBinary` the return type is
/// `$largeUtf8Type`,
///
/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`,
macro_rules! get_optimal_return_type {
($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
Ok(match arg_type {
// LargeBinary inputs are automatically coerced to Utf8
DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
// Binary inputs are automatically coerced to Utf8
DataType::Utf8 | DataType::Binary => $utf8Type,
DataType::Null => DataType::Null,
DataType::Dictionary(_, value_type) => match **value_type {
DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
DataType::Utf8 | DataType::Binary => $utf8Type,
DataType::Null => DataType::Null,
_ => {
return plan_err!(
"The {} function can only accept strings, but got {:?}.",
name.to_uppercase(),
**value_type
);
}
},
data_type => {
return plan_err!(
"The {} function can only accept strings, but got {:?}.",
name.to_uppercase(),
data_type
);
}
})
}
};
}

// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size.
get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading