From 1721dae53c769651a95b01ce40d098de8f0c0688 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Wed, 21 Jul 2021 10:31:52 +0800 Subject: [PATCH] fix 226 - fix concat - fix concat_ws - fix random - add unit tests --- datafusion/src/logical_plan/expr.rs | 126 ++++++++++++++++++++++++---- 1 file changed, 109 insertions(+), 17 deletions(-) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 2eee140f47fe..c81302cd9c6e 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1343,8 +1343,8 @@ impl Literal for ScalarValue { } macro_rules! make_literal { - ($TYPE:ty, $SCALAR:ident) => { - #[allow(missing_docs)] + ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { + #[doc = $DOC] impl Literal for $TYPE { fn lit(&self) -> Expr { Expr::Literal(ScalarValue::$SCALAR(Some(self.clone()))) @@ -1353,27 +1353,55 @@ macro_rules! make_literal { }; } -make_literal!(bool, Boolean); -make_literal!(f32, Float32); -make_literal!(f64, Float64); -make_literal!(i8, Int8); -make_literal!(i16, Int16); -make_literal!(i32, Int32); -make_literal!(i64, Int64); -make_literal!(u8, UInt8); -make_literal!(u16, UInt16); -make_literal!(u32, UInt32); -make_literal!(u64, UInt64); +make_literal!(bool, Boolean, "literal expression containing a bool"); +make_literal!(f32, Float32, "literal expression containing an f32"); +make_literal!(f64, Float64, "literal expression containing an f64"); +make_literal!(i8, Int8, "literal expression containing an i8"); +make_literal!(i16, Int16, "literal expression containing an i16"); +make_literal!(i32, Int32, "literal expression containing an i32"); +make_literal!(i64, Int64, "literal expression containing an i64"); +make_literal!(u8, UInt8, "literal expression containing a u8"); +make_literal!(u16, UInt16, "literal expression containing a u16"); +make_literal!(u32, UInt32, "literal expression containing a u32"); +make_literal!(u64, UInt64, "literal expression containing a u64"); /// Create a literal expression pub fn lit(n: T) -> Expr { n.lit() } +/// Concatenates the text representations of all the arguments. NULL arguments are ignored. +pub fn concat(args: &[Expr]) -> Expr { + Expr::ScalarFunction { + fun: functions::BuiltinScalarFunction::Concat, + args: args.to_vec(), + } +} + +/// Concatenates all but the first argument, with separators. +/// The first argument is used as the separator string, and should not be NULL. +/// Other NULL arguments are ignored. +pub fn concat_ws(sep: impl Into, values: &[Expr]) -> Expr { + let mut args = vec![lit(sep.into())]; + args.extend_from_slice(values); + Expr::ScalarFunction { + fun: functions::BuiltinScalarFunction::ConcatWithSeparator, + args, + } +} + +/// Returns a random value in the range 0.0 <= x < 1.0 +pub fn random() -> Expr { + Expr::ScalarFunction { + fun: functions::BuiltinScalarFunction::Random, + args: vec![], + } +} + /// Create an convenience function representing a unary scalar function macro_rules! unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => { - #[allow(missing_docs)] + #[doc = "this scalar function is not documented yet"] pub fn $FUNC(e: Expr) -> Expr { Expr::ScalarFunction { fun: functions::BuiltinScalarFunction::$ENUM, @@ -1397,7 +1425,6 @@ unary_scalar_expr!(Floor, floor); unary_scalar_expr!(Ceil, ceil); unary_scalar_expr!(Now, now); unary_scalar_expr!(Round, round); -unary_scalar_expr!(Random, random); unary_scalar_expr!(Trunc, trunc); unary_scalar_expr!(Abs, abs); unary_scalar_expr!(Signum, signum); @@ -1413,8 +1440,6 @@ unary_scalar_expr!(Btrim, btrim); unary_scalar_expr!(CharacterLength, character_length); unary_scalar_expr!(CharacterLength, length); unary_scalar_expr!(Chr, chr); -unary_scalar_expr!(Concat, concat); -unary_scalar_expr!(ConcatWithSeparator, concat_ws); unary_scalar_expr!(InitCap, initcap); unary_scalar_expr!(Left, left); unary_scalar_expr!(Lower, lower); @@ -1941,4 +1966,71 @@ mod tests { fn make_field(relation: &str, column: &str) -> DFField { DFField::new(Some(relation), column, DataType::Int8, false) } + + macro_rules! test_unary_scalar_expr { + ($ENUM:ident, $FUNC:ident) => {{ + if let Expr::ScalarFunction { fun, args } = $FUNC(col("tableA.a")) { + let name = functions::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(1, args.len()); + } else { + assert!(false, "unexpected"); + } + }}; + } + + #[test] + fn scalar_function_definitions() { + test_unary_scalar_expr!(Sqrt, sqrt); + test_unary_scalar_expr!(Sin, sin); + test_unary_scalar_expr!(Cos, cos); + test_unary_scalar_expr!(Tan, tan); + test_unary_scalar_expr!(Asin, asin); + test_unary_scalar_expr!(Acos, acos); + test_unary_scalar_expr!(Atan, atan); + test_unary_scalar_expr!(Floor, floor); + test_unary_scalar_expr!(Ceil, ceil); + test_unary_scalar_expr!(Now, now); + test_unary_scalar_expr!(Round, round); + test_unary_scalar_expr!(Trunc, trunc); + test_unary_scalar_expr!(Abs, abs); + test_unary_scalar_expr!(Signum, signum); + test_unary_scalar_expr!(Exp, exp); + test_unary_scalar_expr!(Log2, log2); + test_unary_scalar_expr!(Log10, log10); + test_unary_scalar_expr!(Ln, ln); + test_unary_scalar_expr!(Ascii, ascii); + test_unary_scalar_expr!(BitLength, bit_length); + test_unary_scalar_expr!(Btrim, btrim); + test_unary_scalar_expr!(CharacterLength, character_length); + test_unary_scalar_expr!(CharacterLength, length); + test_unary_scalar_expr!(Chr, chr); + test_unary_scalar_expr!(InitCap, initcap); + test_unary_scalar_expr!(Left, left); + test_unary_scalar_expr!(Lower, lower); + test_unary_scalar_expr!(Lpad, lpad); + test_unary_scalar_expr!(Ltrim, ltrim); + test_unary_scalar_expr!(MD5, md5); + test_unary_scalar_expr!(OctetLength, octet_length); + test_unary_scalar_expr!(RegexpMatch, regexp_match); + test_unary_scalar_expr!(RegexpReplace, regexp_replace); + test_unary_scalar_expr!(Replace, replace); + test_unary_scalar_expr!(Repeat, repeat); + test_unary_scalar_expr!(Reverse, reverse); + test_unary_scalar_expr!(Right, right); + test_unary_scalar_expr!(Rpad, rpad); + test_unary_scalar_expr!(Rtrim, rtrim); + test_unary_scalar_expr!(SHA224, sha224); + test_unary_scalar_expr!(SHA256, sha256); + test_unary_scalar_expr!(SHA384, sha384); + test_unary_scalar_expr!(SHA512, sha512); + test_unary_scalar_expr!(SplitPart, split_part); + test_unary_scalar_expr!(StartsWith, starts_with); + test_unary_scalar_expr!(Strpos, strpos); + test_unary_scalar_expr!(Substr, substr); + test_unary_scalar_expr!(ToHex, to_hex); + test_unary_scalar_expr!(Translate, translate); + test_unary_scalar_expr!(Trim, trim); + test_unary_scalar_expr!(Upper, upper); + } }