diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ddd09e7..68e93323 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- Implemented `Coalesce` [MR178](https://github.com/Qrlew/qrlew/pull/178) ## [0.4.10] - 2023-11-09 ### Fixed diff --git a/src/data_type/function.rs b/src/data_type/function.rs index 8de1c4f2..4023ba23 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -1047,6 +1047,57 @@ impl Function for InList { } } +#[derive(Clone, Debug)] +pub struct Coalesce; + +impl fmt::Display for Coalesce { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "coalesce") + } +} + +impl Function for Coalesce { + fn domain(&self) -> DataType { + DataType::from(data_type::Struct::from_data_types(&[ + DataType::Any, + DataType::Any, + ])) + } + + fn super_image(&self, set: &DataType) -> Result { + if !set.is_subset_of(&self.domain()) { + Err(Error::set_out_of_range(set, self.domain())) + } else { + if let DataType::Struct(struct_data_type) = set { + let data_type_1 = struct_data_type.field_from_index(0).1.as_ref().clone(); + let data_type_2 = struct_data_type.field_from_index(1).1.as_ref().clone(); + + Ok( + if let DataType::Optional(o) = data_type_1 { + o.data_type().super_union(&data_type_2)? + } else { + data_type_1 + } + ) + } else { + Err(Error::set_out_of_range(set, self.domain())) + } + } + } + + fn value(&self, arg: &Value) -> Result { + if let Value::Struct(struct_values) = arg { + if struct_values.field_from_index(0).1 == Arc::new(Value::none()) { + Ok(struct_values.field_from_index(1).1.as_ref().clone()) + } else { + Ok(struct_values.field_from_index(0).1.as_ref().clone()) + } + } else { + Err(Error::argument_out_of_range(arg, self.domain())) + } + } +} + /* We list here all the functions to expose */ @@ -1590,6 +1641,11 @@ pub fn in_list() -> impl Function { )) } +// Coalesce function +pub fn coalesce() -> impl Function { + Coalesce +} + /* Aggregation functions */ @@ -2943,4 +2999,46 @@ mod tests { .unwrap() ); } + + + #[test] + fn test_coalesce() { + println!("Test coalesce"); + let fun = coalesce(); + println!("type = {}", fun); + println!("domain = {}", fun.domain()); + println!("co_domain = {}", fun.co_domain()); + + let set = DataType::from(Struct::from_data_types(&[ + DataType::integer(), + DataType::text() + ])); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::integer()); + + let set = DataType::from(Struct::from_data_types(&[ + DataType::optional(DataType::integer()), + DataType::text() + ])); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::text()); + + let set = DataType::from(Struct::from_data_types(&[ + DataType::optional(DataType::integer_interval(1, 5)), + DataType::integer_value(20) + ])); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert_eq!(im, DataType::integer_interval(1, 5).super_union(&DataType::integer_value(20)).unwrap()); + + let set = DataType::from(Struct::from_data_types(&[ + DataType::optional(DataType::integer()), + DataType::optional(DataType::text()) + ])); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::optional(DataType::text())); + } } diff --git a/src/expr/function.rs b/src/expr/function.rs index 7144c842..d7a791bc 100644 --- a/src/expr/function.rs +++ b/src/expr/function.rs @@ -32,6 +32,7 @@ pub enum Function { BitwiseAnd, BitwiseXor, InList, + Coalesce, // Functions Exp, Ln, @@ -120,6 +121,7 @@ impl Function { | Function::Position | Function::Least | Function::Greatest + | Function::Coalesce // Ternary Function | Function::Case // Nary Function @@ -171,7 +173,7 @@ impl Function { | Function::CastAsInteger | Function::CastAsDateTime => Arity::Unary, // Binary Function - Function::Pow | Function::Position | Function::Least | Function::Greatest => { + Function::Pow | Function::Position | Function::Least | Function::Greatest | Function::Coalesce => { Arity::Nary(2) } // Ternary Function @@ -251,6 +253,7 @@ impl fmt::Display for Function { Function::Position => "position", Function::Least => "least", Function::Greatest => "greatest", + Function::Coalesce => "coalesce", // Ternary Functions Function::Case => "case", // Nary Functions diff --git a/src/expr/implementation.rs b/src/expr/implementation.rs index 2350a4ad..747552ae 100644 --- a/src/expr/implementation.rs +++ b/src/expr/implementation.rs @@ -83,6 +83,7 @@ function_implementations!( Function::CastAsDateTime => Arc::new(function::cast(DataType::date_time())), Function::Concat(n) => Arc::new(function::concat(n)), Function::Random(n) => Arc::new(function::random(Mutex::new(OsRng))), //TODO change this initialization + Function::Coalesce => Arc::new(function::coalesce()), _ => unreachable!(), } } diff --git a/src/expr/mod.rs b/src/expr/mod.rs index f65812af..4d797aa7 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -332,7 +332,8 @@ impl_binary_function_constructors!( Position, InList, Least, - Greatest + Greatest, + Coalesce ); /// Implement ternary function constructors @@ -2745,4 +2746,41 @@ mod tests { let v = Value::try_from(x).unwrap(); assert_eq!(v, Value::from(-5.)); } + + #[test] + fn test_coalesce() { + let expression = Expr::coalesce( + Expr::col("x".to_string()), + Expr::col("y".to_string()), + ); + println!("\nexpression = {}", expression); + println!("expression data type = {}", expression.data_type()); + let set = DataType::structured([ + ("x", DataType::float_interval(1., 10.)), + ("y", DataType::float_values([-2., 0.5])), + ]); + println!( + "expression super image = {}", + expression.super_image(&set).unwrap() + ); + assert_eq!(expression.super_image(&set).unwrap(), DataType::float_interval(1., 10.)); + + let expression = Expr::coalesce( + Expr::col("column_a".to_string()), + Expr::val(20.) + ); + println!("\nexpression = {}", expression); + println!("expression data type = {}", expression.data_type()); + let set = DataType::structured([ + ("column_a", DataType::optional(DataType::float_interval(0., 5.))), + ]); + println!( + "expression super image = {}", + expression.super_image(&set).unwrap() + ); + assert_eq!( + expression.super_image(&set).unwrap(), + DataType::float_interval(0., 5.).super_union(&DataType::float_value(20.)).unwrap() + ); + } } diff --git a/src/expr/rewriting.rs b/src/expr/rewriting.rs index 04e74493..d6f9509f 100644 --- a/src/expr/rewriting.rs +++ b/src/expr/rewriting.rs @@ -21,7 +21,7 @@ impl Expr { /// Gaussian noise based on [Box Muller transform](https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform) pub fn add_gaussian_noise(self, sigma: f64) -> Self { Expr::plus( - self, + Expr::coalesce(self, Expr::val(0.)), Expr::multiply(Expr::val(sigma), Expr::gaussian_noise()), ) } diff --git a/src/expr/sql.rs b/src/expr/sql.rs index 40fa4246..3ab37da7 100644 --- a/src/expr/sql.rs +++ b/src/expr/sql.rs @@ -179,7 +179,8 @@ impl<'a> expr::Visitor<'a, ast::Expr> for FromExprVisitor { | expr::function::Function::Upper | expr::function::Function::Random(_) | expr::function::Function::Least - | expr::function::Function::Greatest => ast::Expr::Function(ast::Function { + | expr::function::Function::Greatest + | expr::function::Function::Coalesce => ast::Expr::Function(ast::Function { name: ast::ObjectName(vec![ast::Ident::new(function.to_string())]), args: arguments .into_iter() @@ -471,4 +472,13 @@ mod tests { println!("ast::expr = {}", gen_expr.to_string()); assert_eq!(gen_expr.to_string(), "a IN (4, 5)".to_string(),); } + + #[test] + fn test_coalesce() { + let str_expr = "Coalesce(a, 5)"; + let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); + println!("ast::expr = {ast_expr}"); + println!("ast::expr = {:?}", ast_expr); + assert_eq!(ast_expr.to_string(), str_expr.to_string(),); + } }