diff --git a/CHANGELOG.md b/CHANGELOG.md index 2077dd3e..48ad1d90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ## Added +- `CAST` function [MR188](https://github.com/Qrlew/qrlew/pull/188) ## [0.5.1] - 2023-11-19 ## Added diff --git a/src/data_type/function.rs b/src/data_type/function.rs index d194fc42..b0745e68 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -1155,10 +1155,75 @@ Conversion function /// Builds the cast operator pub fn cast(into: DataType) -> impl Function { - // TODO Only cast as text is working for now match into { DataType::Text(t) if t == data_type::Text::full() => { - Pointwise::univariate(DataType::Any, DataType::text(), |v| v.to_string().into()) + Pointwise::univariate( + //DataType::Any, + DataType::Any, + DataType::text(), + |v| v.to_string().into()) + } + DataType::Float(f) if f == data_type::Float::full() => { + Pointwise::univariate( + DataType::text(), + DataType::float(), + |v| v.to_string().parse::<f64>().unwrap().into() + ) + } + DataType::Integer(i) if i == data_type::Integer::full() => { + Pointwise::univariate( + DataType::text(), + DataType::integer(), + |v| v.to_string().parse::<i64>().unwrap().into() + ) + } + DataType::Boolean(b) if b == data_type::Boolean::full() => { + Pointwise::univariate( + DataType::text(), + DataType::boolean(), + |v| { + let true_list = vec![ + "t".to_string(), "tr".to_string(), "tru".to_string(), "true".to_string(), + "y".to_string(), "ye".to_string(), "yes".to_string(), + "on".to_string(), + "1".to_string() + ]; + let false_list = vec![ + "f".to_string(), "fa".to_string(), "fal".to_string(), "fals".to_string(), "false".to_string(), + "n".to_string(), "no".to_string(), + "off".to_string(), + "0".to_string() + ]; + if true_list.contains(&v.to_string().to_lowercase()) { + true.into() + } else if false_list.contains(&v.to_string().to_lowercase()) { + false.into() + } else {
panic!() + } + } + ) + } + DataType::Date(d) if d == data_type::Date::full() => { + Pointwise::univariate( + DataType::text(), + DataType::date(), + |v| todo!() + ) + } + DataType::DateTime(d) if d == data_type::DateTime::full() => { + Pointwise::univariate( + DataType::text(), + DataType::date_time(), + |v| todo!() + ) + } + DataType::Time(t) if t == data_type::Time::full() => { + Pointwise::univariate( + DataType::text(), + DataType::time(), + |v| todo!() + ) } _ => todo!(), } @@ -1978,6 +2043,7 @@ mod tests { super::{value::Value, Struct}, *, }; + use chrono; #[test] fn test_argument_conversion() { @@ -3293,4 +3359,74 @@ mod tests { ]) ); } + + #[test] + fn test_cast_as_text() { + println!("Test cast as text"); + let fun = cast(DataType::text()); + println!("type = {}", fun); + println!("domain = {}", fun.domain()); + println!("co_domain = {}", fun.co_domain()); + println!("data_type = {}", fun.data_type()); + + let set = DataType::integer_values([1, 3, 4]); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::text_values(["1".to_string(), "3".to_string(), "4".to_string()])); + + let set = DataType::integer_values([1, 3, 4]); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::text_values(["1".to_string(), "3".to_string(), "4".to_string()])); + + let set = DataType::date_value(chrono::NaiveDate::from_ymd_opt(2015, 6, 3).unwrap()); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::text_values(["2015-06-03".to_string()])); + } + + #[test] + fn test_cast_as_float() { + println!("Test cast as float"); + let fun = cast(DataType::float()); + println!("type = {}", fun); + println!("domain = {}", fun.domain()); + println!("co_domain = {}", fun.co_domain()); + println!("data_type = {}", fun.data_type()); + + let set = DataType::text_values(["1.5".to_string(), "3".to_string(), "4.555".to_string()]); + let im = 
fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::float_values([1.5, 3., 4.555])); + } + + #[test] + fn test_cast_as_integer() { + println!("\nTest cast as integer"); + let fun = cast(DataType::integer()); + println!("type = {}", fun); + println!("domain = {}", fun.domain()); + println!("co_domain = {}", fun.co_domain()); + println!("data_type = {}", fun.data_type()); + + let set = DataType::text_values(["1".to_string(), "3".to_string(), "4".to_string()]); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::integer_values([1, 3, 4])); + } + + #[test] + fn test_cast_to_boolean() { + println!("\nTest cast as boolean"); + let fun = cast(DataType::boolean()); + println!("type = {}", fun); + println!("domain = {}", fun.domain()); + println!("co_domain = {}", fun.co_domain()); + println!("data_type = {}", fun.data_type()); + + let set = DataType::text_values(["1".to_string(), "tru".to_string()]); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::boolean_value(true)); + } } diff --git a/src/expr/function.rs b/src/expr/function.rs index 5cd37204..e5e191a2 100644 --- a/src/expr/function.rs +++ b/src/expr/function.rs @@ -53,13 +53,16 @@ pub enum Function { CastAsText, CastAsFloat, CastAsInteger, + CastAsBoolean, CastAsDateTime, + CastAsDate, + CastAsTime, Least, Greatest, Rtrim, Ltrim, Substr, - SubstrWithSize, + SubstrWithSize } #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] @@ -119,7 +122,10 @@ impl Function { | Function::CastAsText | Function::CastAsFloat | Function::CastAsInteger + | Function::CastAsBoolean | Function::CastAsDateTime + | Function::CastAsDate + | Function::CastAsTime // Binary Functions | Function::Pow | Function::Position @@ -179,7 +185,10 @@ impl Function { | Function::CastAsText | Function::CastAsFloat | Function::CastAsInteger - | Function::CastAsDateTime => Arity::Unary, + | 
Function::CastAsBoolean + | Function::CastAsDateTime + | Function::CastAsDate + | Function::CastAsTime => Arity::Unary, // Binary Function Function::Pow | Function::Position @@ -260,7 +269,10 @@ impl fmt::Display for Function { Function::CastAsText => "cast_as_text", Function::CastAsInteger => "cast_as_integer", Function::CastAsFloat => "cast_as_float", + Function::CastAsBoolean => "cast_as_boolean", Function::CastAsDateTime => "cast_as_date_time", + Function::CastAsDate => "cast_as_date", + Function::CastAsTime => "cast_as_time", // Binary Functions Function::Pow => "pow", Function::Position => "position", diff --git a/src/expr/implementation.rs b/src/expr/implementation.rs index 15076fc3..9d313549 100644 --- a/src/expr/implementation.rs +++ b/src/expr/implementation.rs @@ -81,9 +81,12 @@ function_implementations!( { match x { Function::CastAsText => Arc::new(function::cast(DataType::text())), - Function::CastAsInteger => Arc::new(function::cast(DataType::integer())), - Function::CastAsFloat => Arc::new(function::cast(DataType::float())), - Function::CastAsDateTime => Arc::new(function::cast(DataType::date_time())), + Function::CastAsInteger => Arc::new(Optional::new(function::cast(DataType::integer()))), + Function::CastAsFloat => Arc::new(Optional::new(function::cast(DataType::float()))), + Function::CastAsBoolean => Arc::new(Optional::new(function::cast(DataType::boolean()))), + Function::CastAsDateTime => Arc::new(Optional::new(function::cast(DataType::date_time()))), + Function::CastAsDate => Arc::new(Optional::new(function::cast(DataType::date()))), + Function::CastAsTime => Arc::new(Optional::new(function::cast(DataType::time()))), Function::Concat(n) => Arc::new(function::concat(n)), Function::Random(n) => Arc::new(function::random(Mutex::new(OsRng))), //TODO change this initialization Function::Coalesce => Arc::new(function::coalesce()), diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 7d717a2c..105a294e 100644 --- a/src/expr/mod.rs +++ 
b/src/expr/mod.rs @@ -281,7 +281,10 @@ impl_unary_function_constructors!( CastAsText, CastAsInteger, CastAsFloat, - CastAsDateTime + CastAsBoolean, + CastAsDateTime, + CastAsDate, + CastAsTime ); // TODO Complete that /// Implement binary function constructors @@ -2862,4 +2865,133 @@ mod tests { DataType::optional(DataType::text()) ); } + + #[test] + fn test_cast_integer_text() { + println!("integer => text"); + let expression = Expr::cast_as_text( + Expr::col("col1".to_string()) + ); + println!("expression = {}", expression); + println!("expression domain = {}", expression.domain()); + println!("expression co domain = {}", expression.co_domain()); + println!("expression data type = {}", expression.data_type()); + let set = DataType::structured([ + ("col1", DataType::integer_values([1, 2, 3])), + ]); + println!( + "expression super image = {}", + expression.super_image(&set).unwrap() + ); + assert_eq!( + expression.super_image(&set).unwrap(), + DataType::text_values(["1".to_string(), "2".to_string(), "3".to_string()]) + ); + + println!("\ntext => integer"); + let expression = Expr::cast_as_integer( + Expr::col("col1".to_string()) + ); + println!("expression = {}", expression); + println!("expression domain = {}", expression.domain()); + println!("expression co domain = {}", expression.co_domain()); + println!("expression data type = {}", expression.data_type()); + let set = DataType::structured([ + ("col1", DataType::text_values(["1".to_string(), "2".to_string(), "3".to_string()])), + ]); + println!( + "expression super image = {}", + expression.super_image(&set).unwrap() + ); + assert_eq!( + expression.super_image(&set).unwrap(), + DataType::integer_values([1, 2, 3]) + ); + } + + #[test] + fn test_cast_float_text() { + println!("float => text"); + let expression = Expr::cast_as_text( + Expr::col("col1".to_string()) + ); + println!("expression = {}", expression); + println!("expression domain = {}", expression.domain()); + println!("expression co domain = {}", 
expression.co_domain()); + println!("expression data type = {}", expression.data_type()); + let set = DataType::structured([ + ("col1", DataType::float_values([1.1, 2., 3.5])), + ]); + println!( + "expression super image = {}", + expression.super_image(&set).unwrap() + ); + assert_eq!( + expression.super_image(&set).unwrap(), + DataType::text_values(["1.1".to_string(), "2".to_string(), "3.5".to_string()]) + ); + + println!("\ntext => float"); + let expression = Expr::cast_as_float( + Expr::col("col1".to_string()) + ); + println!("expression = {}", expression); + println!("expression domain = {}", expression.domain()); + println!("expression co domain = {}", expression.co_domain()); + println!("expression data type = {}", expression.data_type()); + let set = DataType::structured([ + ("col1", DataType::text_values(["1.1".to_string(), "2".to_string(), "3.5".to_string()])), + ]); + println!( + "expression super image = {}", + expression.super_image(&set).unwrap() + ); + assert_eq!( + expression.super_image(&set).unwrap(), + DataType::float_values([1.1, 2., 3.5]) + ); + } + + #[test] + fn test_cast_boolean_text() { + println!("boolean => text"); + let expression = Expr::cast_as_text( + Expr::col("col1".to_string()) + ); + println!("expression = {}", expression); + println!("expression domain = {}", expression.domain()); + println!("expression co domain = {}", expression.co_domain()); + println!("expression data type = {}", expression.data_type()); + let set = DataType::structured([ + ("col1", DataType::boolean_values([true, false])), + ]); + println!( + "expression super image = {}", + expression.super_image(&set).unwrap() + ); + assert_eq!( + expression.super_image(&set).unwrap(), + DataType::text_values(["true".to_string(), "false".to_string()]) + ); + + println!("\ntext => boolean"); + let expression = Expr::cast_as_boolean( + Expr::col("col1".to_string()) + ); + println!("expression = {}", expression); + println!("expression domain = {}", expression.domain()); + 
println!("expression co domain = {}", expression.co_domain()); + println!("expression data type = {}", expression.data_type()); + let set = DataType::structured([ + ("col1", DataType::text_values(["n".to_string(), "fa".to_string(), "off".to_string()])), + ]); + println!( + "expression super image = {}", + expression.super_image(&set).unwrap() + ); + assert_eq!( + expression.super_image(&set).unwrap(), + DataType::boolean_value(false) + ); + } } diff --git a/src/expr/sql.rs b/src/expr/sql.rs index 30271157..b0dc1ae3 100644 --- a/src/expr/sql.rs +++ b/src/expr/sql.rs @@ -224,9 +224,36 @@ impl<'a> expr::Visitor<'a, ast::Expr> for FromExprVisitor { data_type: DataType::text().into(), format: None, }, - expr::function::Function::CastAsFloat => todo!(), - expr::function::Function::CastAsInteger => todo!(), - expr::function::Function::CastAsDateTime => todo!(), + expr::function::Function::CastAsFloat => ast::Expr::Cast { + expr: arguments[0].clone().into(), + data_type: DataType::float().into(), + format: None, + }, + expr::function::Function::CastAsInteger => ast::Expr::Cast { + expr: arguments[0].clone().into(), + data_type: DataType::integer().into(), + format: None, + }, + expr::function::Function::CastAsBoolean => ast::Expr::Cast { + expr: arguments[0].clone().into(), + data_type: DataType::boolean().into(), + format: None, + }, + expr::function::Function::CastAsDateTime => ast::Expr::Cast { + expr: arguments[0].clone().into(), + data_type: DataType::date_time().into(), + format: None, + }, + expr::function::Function::CastAsDate => ast::Expr::Cast { + expr: arguments[0].clone().into(), + data_type: DataType::date().into(), + format: None, + }, + expr::function::Function::CastAsTime => ast::Expr::Cast { + expr: arguments[0].clone().into(), + data_type: DataType::time().into(), + format: None, + }, } } // TODO implement this properly @@ -528,4 +555,38 @@ mod tests { println!("ast::expr = {gen_expr}"); assert_eq!(gen_expr, parse_expr("substr(a, 0, 5)").unwrap()); } + 
#[test] + fn test_cast() { + let str_expr = "cast(a as varchar)"; + let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); + let expr = Expr::try_from(&ast_expr).unwrap(); + println!("expr = {}", expr); + let gen_expr = ast::Expr::from(&expr); + println!("ast::expr = {gen_expr}"); + assert_eq!(ast_expr, gen_expr); + + let str_expr = "cast(a as bigint)"; + let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); + let expr = Expr::try_from(&ast_expr).unwrap(); + println!("expr = {}", expr); + let gen_expr = ast::Expr::from(&expr); + println!("ast::expr = {gen_expr}"); + assert_eq!(ast_expr, gen_expr); + + let str_expr = "cast(a as boolean)"; + let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); + let expr = Expr::try_from(&ast_expr).unwrap(); + println!("expr = {}", expr); + let gen_expr = ast::Expr::from(&expr); + println!("ast::expr = {gen_expr}"); + assert_eq!(ast_expr, gen_expr); + + let str_expr = "cast(a as float)"; + let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); + let expr = Expr::try_from(&ast_expr).unwrap(); + println!("expr = {}", expr); + let gen_expr = ast::Expr::from(&expr); + println!("ast::expr = {gen_expr}"); + assert_eq!(ast_expr, gen_expr); + } } diff --git a/src/sql/expr.rs b/src/sql/expr.rs index 98360dd0..4a390660 100644 --- a/src/sql/expr.rs +++ b/src/sql/expr.rs @@ -4,7 +4,7 @@ use super::{Error, Result}; use crate::{ - builder::{With, WithContext, WithoutContext}, + builder::{WithContext, WithoutContext}, expr::{identifier::Identifier, Expr, Value}, hierarchy::{Hierarchy, Path}, visitor::{self, Acceptor, Dependencies, Visited}, @@ -268,6 +268,13 @@ pub trait Visitor<'a, T: Clone> { fn in_list(&self, expr: T, list: Vec) -> T; fn trim(&self, expr: T, trim_where: &Option, trim_what: Option) -> T; fn substring(&self, expr: T, substring_from: Option, substring_for: Option) -> T; + fn cast_as_text(&self, expr: T) -> T; + fn cast_as_float(&self, expr: T) -> T; + fn cast_as_integer(&self, expr: T) -> T; + fn cast_as_boolean(&self, 
expr: T) -> T; + fn cast_as_date_time(&self, expr: T) -> T; + fn cast_as_date(&self, expr: T) -> T; + fn cast_as_time(&self, expr: T) -> T; } // For the visitor to be more convenient, we create a few auxiliary objects @@ -373,7 +380,80 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { expr, data_type, format: _, - } => todo!(), + } => match data_type { + //Text + ast::DataType::Character(_) + | ast::DataType::Char(_) + | ast::DataType::CharacterVarying(_) + |ast::DataType::CharVarying(_) + | ast::DataType::Varchar(_) + | ast::DataType::Nvarchar(_) + | ast::DataType::Uuid + | ast::DataType::CharacterLargeObject(_) + | ast::DataType::CharLargeObject(_) + | ast::DataType::Clob(_) + | ast::DataType::Text + | ast::DataType::String(_) => self.cast_as_text(dependencies.get(expr).clone()), + // Integer + //Bytes + ast::DataType::Binary(_) + | ast::DataType::Varbinary(_) + | ast::DataType::Blob(_) + | ast::DataType::Bytes(_) => todo!(), + //Float + ast::DataType::Numeric(_) + | ast::DataType::Decimal(_) + | ast::DataType::BigNumeric(_) + | ast::DataType::BigDecimal(_) + | ast::DataType::Dec(_) + | ast::DataType::Float(_) + | ast::DataType::Float4 + | ast::DataType::Float64 + | ast::DataType::Real + | ast::DataType::Float8 + | ast::DataType::Double + | ast::DataType::DoublePrecision => self.cast_as_float(dependencies.get(expr).clone()), + // Integer + ast::DataType::TinyInt(_) + | ast::DataType::UnsignedTinyInt(_) + | ast::DataType::Int2(_) + | ast::DataType::UnsignedInt2(_) + | ast::DataType::SmallInt(_) + | ast::DataType::UnsignedSmallInt(_) + | ast::DataType::MediumInt(_) + | ast::DataType::UnsignedMediumInt(_) + | ast::DataType::Int(_) + | ast::DataType::Int4(_) + | ast::DataType::Int64 + | ast::DataType::Integer(_) + | ast::DataType::UnsignedInt(_) + | ast::DataType::UnsignedInt4(_) + | ast::DataType::UnsignedInteger(_) + | ast::DataType::BigInt(_) + | ast::DataType::UnsignedBigInt(_) + | ast::DataType::Int8(_) + | 
ast::DataType::UnsignedInt8(_) => self.cast_as_integer(dependencies.get(expr).clone()), + // Boolean + ast::DataType::Bool + | ast::DataType::Boolean => self.cast_as_boolean(dependencies.get(expr).clone()), + // Date + ast::DataType::Date => self.cast_as_date(dependencies.get(expr).clone()), + // Time + ast::DataType::Time(_, _) => self.cast_as_time(dependencies.get(expr).clone()), + // DateTime + ast::DataType::Datetime(_) + | ast::DataType::Timestamp(_, _) => self.cast_as_date_time(dependencies.get(expr).clone()), + + ast::DataType::Interval => todo!(), + ast::DataType::JSON => todo!(), + ast::DataType::Regclass => todo!(), + ast::DataType::Bytea => todo!(), + ast::DataType::Custom(_, _) => todo!(), + ast::DataType::Array(_) => todo!(), + ast::DataType::Enum(_) => todo!(), + ast::DataType::Set(_) => todo!(), + ast::DataType::Struct(_) => todo!(), + }, ast::Expr::TryCast { expr, data_type, @@ -640,6 +720,34 @@ impl<'a> Visitor<'a, String> for DisplayVisitor { .unwrap_or("".to_string()), ) } + + fn cast_as_text(&self, expr: String) -> String { + format!("CAST ({} AS TEXT)", expr) + } + + fn cast_as_float(&self, expr: String) -> String { + format!("CAST ({} AS FLOAT)", expr) + } + + fn cast_as_integer(&self, expr: String) -> String { + format!("CAST ({} AS INTEGER)", expr) + } + + fn cast_as_boolean(&self, expr: String) -> String { + format!("CAST ({} AS BOOLEAN)", expr) + } + + fn cast_as_date_time(&self, expr: String) -> String { + format!("CAST ({} AS DATETIME)", expr) + } + + fn cast_as_date(&self, expr: String) -> String { + format!("CAST ({} AS DATE)", expr) + } + + fn cast_as_time(&self, expr: String) -> String { + format!("CAST ({} AS TIME)", expr) + } } /// A simple ast::Expr -> Expr conversion Visitor @@ -918,6 +1026,34 @@ impl<'a> Visitor<'a, Result<Expr>> for TryIntoExprVisitor<'a> { }) .unwrap_or(Ok(Expr::substr(expr.clone()?, substring_from.clone()?))) } + + fn cast_as_text(&self, expr: Result<Expr>) -> Result<Expr> { + Ok(Expr::cast_as_text(expr.clone()?)) + } + + fn
cast_as_float(&self, expr: Result<Expr>) -> Result<Expr> { + Ok(Expr::cast_as_float(expr.clone()?)) + } + + fn cast_as_integer(&self, expr: Result<Expr>) -> Result<Expr> { + Ok(Expr::cast_as_integer(expr.clone()?)) + } + + fn cast_as_boolean(&self, expr: Result<Expr>) -> Result<Expr> { + Ok(Expr::cast_as_boolean(expr.clone()?)) + } + + fn cast_as_date_time(&self, expr: Result<Expr>) -> Result<Expr> { + Ok(Expr::cast_as_date_time(expr.clone()?)) + } + + fn cast_as_date(&self, expr: Result<Expr>) -> Result<Expr> { + Ok(Expr::cast_as_date(expr.clone()?)) + } + + fn cast_as_time(&self, expr: Result<Expr>) -> Result<Expr> { + Ok(Expr::cast_as_time(expr.clone()?)) + } } /// Based on the TryIntoExprVisitor implement the TryFrom trait diff --git a/src/sql/mod.rs b/src/sql/mod.rs index c5869080..2242198a 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -139,4 +139,25 @@ mod tests { relation.display_dot(); } } + + #[test] + fn test_cast_queries() { + let mut database = postgresql::test_database(); + + for query in [ + "SELECT CAST(a AS text) FROM table_1", // float => text + "SELECT CAST(b AS text) FROM table_1", // integer => text + "SELECT CAST(c AS text) FROM table_1", // date => text + "SELECT CAST(z AS text) FROM table_2", // text => text + "SELECT CAST(x AS float) FROM table_2", // integer => float + "SELECT CAST('true' AS boolean) FROM table_2", // text => boolean + ] { + let res1 = database.query(query).unwrap(); + let relation = Relation::try_from(parse(query).unwrap().with(&database.relations())).unwrap(); + let relation_query: &str = &ast::Query::from(&relation).to_string(); + println!("{query} => {relation_query}"); + let res2 = database.query(relation_query).unwrap(); + assert_eq!(res1, res2); + } + } }