diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f99f6547..f6b86897 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: - 5432:5432 mssql: - image: mcr.microsoft.com/mssql/server:2019-latest + image: mcr.microsoft.com/mssql/server:2019-CU28-ubuntu-20.04 env: # Set the SA password SA_PASSWORD: "Strong@Passw0rd" @@ -43,14 +43,27 @@ jobs: ports: - 1433:1433 options: >- - --health-cmd "/opt/mssql-tools/bin/sqlcmd -S localhost -U SA -P 'Strong@Passw0rd' -Q 'SELECT 1'" + --health-cmd "/opt/mssql-tools18/bin/sqlcmd -C -S localhost -U SA -P 'Strong@Passw0rd' -Q 'SELECT 1'" --health-interval 10s --health-timeout 5s --health-retries 5 - + mysql: + image: mysql:8.0 + env: + # The MySQL docker container requires these environment variables to be set + # so we can create and migrate the test database. + # See: https://hub.docker.com/_/mysql + MYSQL_DATABASE: qrlew_mysql_test + MYSQL_ROOT_PASSWORD: qrlew_test + ports: + # Opens port 3306 on service container and host + # https://docs.github.com/en/actions/using-containerized-services/about-service-containers + - 3306:3306 + # Before continuing, verify the mysql container is reachable from the ubuntu host + options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 steps: - uses: actions/checkout@v3 - name: Build - run: cargo build --features mssql,bigquery --verbose + run: cargo build --features mssql,bigquery,mysql --verbose - name: Run tests - run: cargo test --features mssql,bigquery --verbose + run: cargo test --features mssql,bigquery,mysql --verbose diff --git a/CHANGELOG.md b/CHANGELOG.md index 80b82e97..f98dd79f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.9.24] - 2024-09-27 +### Fixed +- mssql and bigquery translator +### Added +- mysql, databricks, hive, redshift translators +- mysql io connection for testing +- tool to get tables prefix + ## [0.9.23] - 2024-07-9 ### Fixed - fixing noise multiplier of the gaussian dp event which should be independent from the sensitivity. 
diff --git a/Cargo.toml b/Cargo.toml index 815ab729..bd64d8af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Nicolas Grislain "] name = "qrlew" -version = "0.9.23" +version = "0.9.24" edition = "2021" description = "Sarus Qrlew Engine" documentation = "https://docs.rs/qrlew" @@ -42,10 +42,15 @@ wiremock = { version = "0.6", optional = true } tempfile = { version = "3.6.0", optional = true } yup-oauth2 = { version = "9.0", optional = true } +# mysql dependencies +mysql = { version = "25.0.1", optional = true } +r2d2_mysql = { version = "25.0.0", optional = true } + [features] # Use SQLite for tests and examples sqlite = ["dep:rusqlite"] mssql = ["dep:sqlx", "dep:tokio"] +mysql = ["dep:mysql", "dep:r2d2_mysql"] bigquery = ["dep:gcp-bigquery-client", "dep:wiremock", "dep:tempfile", "dep:yup-oauth2", "dep:tokio"] # Tests checked_injections = [] diff --git a/src/dialect_translation/bigquery.rs b/src/dialect_translation/bigquery.rs index 48e28bef..4a6e43f0 100644 --- a/src/dialect_translation/bigquery.rs +++ b/src/dialect_translation/bigquery.rs @@ -55,6 +55,14 @@ impl RelationToQueryTranslator for BigQueryTranslator { kind: ast::CastKind::Cast, } } + fn cast_as_float(&self, expr: ast::Expr) -> ast::Expr { + ast::Expr::Cast { + expr: Box::new(expr), + data_type: ast::DataType::Float64, + format: None, + kind: ast::CastKind::Cast, + } + } fn substr(&self, exprs: Vec) -> ast::Expr { assert!(exprs.len() == 2); function_builder("SUBSTR", exprs, false) @@ -89,6 +97,51 @@ impl RelationToQueryTranslator for BigQueryTranslator { }) .collect() } + /// It converts EXTRACT(epoch FROM column) into + /// UNIX_SECONDS(CAST(col AS TIMESTAMP)) + fn extract_epoch(&self, expr: ast::Expr) -> ast::Expr { + let cast = ast::Expr::Cast { + expr: Box::new(expr), + data_type: ast::DataType::Timestamp(None, ast::TimezoneInfo::None), + format: None, + kind: ast::CastKind::Cast, + }; + function_builder("UNIX_SECONDS", vec![cast], false) + } + fn set_operation( + &self, + with: Vec, + operator: ast::SetOperator, + quantifier: ast::SetQuantifier, + left: ast::Select, + right: ast::Select, + ) -> ast::Query { + // UNION in big query must use a quantifier that can be either + // ALL or Distinct. 
+ let translated_quantifier = match quantifier { + ast::SetQuantifier::All => ast::SetQuantifier::All, + _ => ast::SetQuantifier::Distinct, + }; + ast::Query { + with: (!with.is_empty()).then_some(ast::With { + recursive: false, + cte_tables: with, + }), + body: Box::new(ast::SetExpr::SetOperation { + op: operator, + set_quantifier: translated_quantifier, + left: Box::new(ast::SetExpr::Select(Box::new(left))), + right: Box::new(ast::SetExpr::Select(Box::new(right))), + }), + order_by: vec![], + limit: None, + offset: None, + fetch: None, + locks: vec![], + limit_by: vec![], + for_clause: None, + } + } } impl QueryToRelationTranslator for BigQueryTranslator { diff --git a/src/dialect_translation/databricks.rs b/src/dialect_translation/databricks.rs new file mode 100644 index 00000000..d1bcd6d0 --- /dev/null +++ b/src/dialect_translation/databricks.rs @@ -0,0 +1,61 @@ +use super::{function_builder, QueryToRelationTranslator, RelationToQueryTranslator}; +use sqlparser::{ast, dialect::DatabricksDialect}; + +use crate::expr::{self}; + +#[derive(Clone, Copy)] +pub struct DatabricksTranslator; + +impl RelationToQueryTranslator for DatabricksTranslator { + fn identifier(&self, value: &expr::Identifier) -> Vec { + value + .iter() + .map(|r| ast::Ident::with_quote('`', r)) + .collect() + } + + fn first(&self, expr: ast::Expr) -> ast::Expr { + expr + } + + fn var(&self, expr: ast::Expr) -> ast::Expr { + function_builder("VARIANCE", vec![expr], false) + } + + fn cast_as_text(&self, expr: ast::Expr) -> ast::Expr { + ast::Expr::Cast { + expr: Box::new(expr), + data_type: ast::DataType::String(None), + format: None, + kind: ast::CastKind::Cast, + } + } + fn cast_as_float(&self, expr: ast::Expr) -> ast::Expr { + function_builder("FLOAT", vec![expr], false) + } + /// It converts EXTRACT(epoch FROM column) into + /// UNIX_TIMESTAMP(col) + fn extract_epoch(&self, expr: ast::Expr) -> ast::Expr { + function_builder("UNIX_TIMESTAMP", vec![expr], false) + } + + fn format_float_value(&self, value: f64) -> ast::Expr { + let max_precision = 37; + let formatted = if value.abs() < 1e-10 || value.abs() > 1e10 { + // If the value is too small or too large, switch to scientific notation + format!("{:.precision$e}", value, precision = max_precision) + } else { + // Otherwise, use the default float formatting with the specified precision + format!("{}", value) + }; + ast::Expr::Value(ast::Value::Number(formatted, false)) + } +} + +impl QueryToRelationTranslator for DatabricksTranslator { + type D = DatabricksDialect; + + fn dialect(&self) -> Self::D { + DatabricksDialect {} + } +} diff --git a/src/dialect_translation/hive.rs b/src/dialect_translation/hive.rs index 8b137891..31444de5 100644 --- a/src/dialect_translation/hive.rs +++ b/src/dialect_translation/hive.rs @@ -1 +1,155 @@ +use super::{function_builder, QueryToRelationTranslator, RelationToQueryTranslator}; +use crate::{ + expr::{self}, + relation::{Join, Variant as _}, +}; +use sqlparser::{ast, dialect::HiveDialect}; +#[derive(Clone, Copy)] +pub struct HiveTranslator; + +// Using the same translations as in bigquery since it should be similar. +// HiveTranslator is not well tested at the moment. 
+impl RelationToQueryTranslator for HiveTranslator { + fn identifier(&self, value: &expr::Identifier) -> Vec { + value + .iter() + .map(|r| ast::Ident::with_quote('`', r)) + .collect() + } + + fn cte(&self, name: ast::Ident, _columns: Vec, query: ast::Query) -> ast::Cte { + ast::Cte { + alias: ast::TableAlias { + name, + columns: vec![], + }, + query: Box::new(query), + from: None, + materialized: None, + } + } + fn first(&self, expr: ast::Expr) -> ast::Expr { + expr + } + + fn mean(&self, expr: ast::Expr) -> ast::Expr { + function_builder("AVG", vec![expr], false) + } + + fn var(&self, expr: ast::Expr) -> ast::Expr { + function_builder("VARIANCE", vec![expr], false) + } + + fn std(&self, expr: ast::Expr) -> ast::Expr { + function_builder("STDDEV", vec![expr], false) + } + /// Converting LOG to LOG10 + fn log(&self, expr: ast::Expr) -> ast::Expr { + function_builder("LOG10", vec![expr], false) + } + fn cast_as_text(&self, expr: ast::Expr) -> ast::Expr { + ast::Expr::Cast { + expr: Box::new(expr), + data_type: ast::DataType::String(None), + format: None, + kind: ast::CastKind::Cast, + } + } + fn cast_as_float(&self, expr: ast::Expr) -> ast::Expr { + ast::Expr::Cast { + expr: Box::new(expr), + data_type: ast::DataType::Float64, + format: None, + kind: ast::CastKind::Cast, + } + } + fn substr(&self, exprs: Vec) -> ast::Expr { + assert!(exprs.len() == 2); + function_builder("SUBSTR", exprs, false) + } + fn substr_with_size(&self, exprs: Vec) -> ast::Expr { + assert!(exprs.len() == 3); + function_builder("SUBSTR", exprs, false) + } + /// Converting MD5(X) to TO_HEX(MD5(X)) + fn md5(&self, expr: ast::Expr) -> ast::Expr { + let md5_function = function_builder("MD5", vec![expr], false); + function_builder("TO_HEX", vec![md5_function], false) + } + fn random(&self) -> ast::Expr { + function_builder("RAND", vec![], false) + } + fn join_projection(&self, join: &Join) -> Vec { + join.left() + .schema() + .iter() + .map(|f| self.expr(&expr::Expr::qcol(Join::left_name(), f.name()))) + .chain( + join.right() + .schema() + .iter() + .map(|f| self.expr(&expr::Expr::qcol(Join::right_name(), f.name()))), + ) + .zip(join.schema().iter()) + .map(|(expr, field)| ast::SelectItem::ExprWithAlias { + expr, + alias: field.name().into(), + }) + .collect() + } + /// It converts EXTRACT(epoch FROM column) into + /// UNIX_SECONDS(CAST(col AS TIMESTAMP)) + fn extract_epoch(&self, expr: ast::Expr) -> ast::Expr { + let cast = ast::Expr::Cast { + expr: Box::new(expr), + data_type: ast::DataType::Timestamp(None, ast::TimezoneInfo::None), + format: None, + kind: ast::CastKind::Cast, + }; + function_builder("UNIX_SECONDS", vec![cast], false) + } + + fn set_operation( + &self, + with: Vec, + operator: ast::SetOperator, + quantifier: ast::SetQuantifier, + left: ast::Select, + right: ast::Select, + ) -> ast::Query { + // UNION in big query must use a quantifier that can be either + // ALL or Distinct. 
+ let translated_quantifier = match quantifier {
+ ast::SetQuantifier::All => ast::SetQuantifier::All,
+ _ => ast::SetQuantifier::Distinct,
+ };
+ ast::Query {
+ with: (!with.is_empty()).then_some(ast::With {
+ recursive: false,
+ cte_tables: with,
+ }),
+ body: Box::new(ast::SetExpr::SetOperation {
+ op: operator,
+ set_quantifier: translated_quantifier,
+ left: Box::new(ast::SetExpr::Select(Box::new(left))),
+ right: Box::new(ast::SetExpr::Select(Box::new(right))),
+ }),
+ order_by: vec![],
+ limit: None,
+ offset: None,
+ fetch: None,
+ locks: vec![],
+ limit_by: vec![],
+ for_clause: None,
+ }
+ }
+}
+
+impl QueryToRelationTranslator for HiveTranslator {
+ type D = HiveDialect;
+
+ fn dialect(&self) -> Self::D {
+ HiveDialect {}
+ }
+}
diff --git a/src/dialect_translation/mod.rs b/src/dialect_translation/mod.rs index a2d23043..54298c6a 100644 --- a/src/dialect_translation/mod.rs +++ b/src/dialect_translation/mod.rs @@ -23,10 +23,12 @@ use crate::{ use paste::paste;
pub mod bigquery;
+pub mod databricks;
pub mod hive;
pub mod mssql;
pub mod mysql;
pub mod postgresql;
+pub mod redshiftsql;
pub mod sqlite;
// TODO: Add translation errors
@@ -268,6 +270,36 @@ macro_rules! relation_to_query_translator_trait_constructor { }) }
+ /// Build a set operation
+ fn set_operation(
+ &self,
+ with: Vec<ast::Cte>,
+ operator: ast::SetOperator,
+ quantifier: ast::SetQuantifier,
+ left: ast::Select,
+ right: ast::Select,
+ ) -> ast::Query {
+ ast::Query {
+ with: (!with.is_empty()).then_some(ast::With {
+ recursive: false,
+ cte_tables: with,
+ }),
+ body: Box::new(ast::SetExpr::SetOperation {
+ op: operator,
+ set_quantifier: quantifier,
+ left: Box::new(ast::SetExpr::Select(Box::new(left))),
+ right: Box::new(ast::SetExpr::Select(Box::new(right))),
+ }),
+ order_by: vec![],
+ limit: None,
+ offset: None,
+ fetch: None,
+ locks: vec![],
+ limit_by: vec![],
+ for_clause: None,
+ }
+ }
+
 fn cte( &self, name: ast::Ident, @@ -358,9 +390,7 @@ macro_rules! relation_to_query_translator_trait_constructor { ast::Expr::Value(ast::Value::Number(format!("{}", **i), false)) } expr::Value::Enum(_) => todo!(),
- expr::Value::Float(f) => {
- ast::Expr::Value(ast::Value::Number(format!("{}", **f), false))
- }
+ expr::Value::Float(f) => self.format_float_value(**f),
 expr::Value::Text(t) => { ast::Expr::Value(ast::Value::SingleQuotedString(format!("{}", **t))) } @@ -388,6 +418,10 @@ macro_rules! relation_to_query_translator_trait_constructor { } }
+ fn format_float_value(&self, value: f64) -> ast::Expr {
+ ast::Expr::Value(ast::Value::Number(format!("{}", value), false))
+ }
+
 fn function( &self, function: &expr::function::Function, @@ -735,10 +769,24 @@ pub trait QueryToRelationTranslator { fn dialect(&self) -> Self::D;
+ // It checks that the identifier conforms to the dialect and
+ // returns a result with expr::Identifier.
+ fn try_identifier(&self, ident: &ast::Ident) -> Result { + if let Some(quote_style) = ident.quote_style { + assert!(self.dialect().is_delimited_identifier_start(quote_style)); + } + Ok(expr::Identifier::from(ident)) + } + // It converts ast Expressions to sarus expressions fn try_expr(&self, expr: &ast::Expr, context: &Hierarchy) -> Result { match expr { ast::Expr::Function(func) => self.try_function(func, context), + ast::Expr::Identifier(ident) => { + // checking the identifier + let _ = self.try_identifier(ident)?; + expr::Expr::try_from(expr.with(context)) + } _ => expr::Expr::try_from(expr.with(context)), } } diff --git a/src/dialect_translation/mssql.rs b/src/dialect_translation/mssql.rs index 5d4731ef..206514ab 100644 --- a/src/dialect_translation/mssql.rs +++ b/src/dialect_translation/mssql.rs @@ -52,6 +52,10 @@ impl RelationToQueryTranslator for MsSqlTranslator { function_builder("RAND", vec![check_sum], false) } + fn char_length(&self, expr: ast::Expr) -> ast::Expr { + function_builder("LEN", vec![expr], false) + } + /// Converting MD5(X) to CONVERT(VARCHAR(MAX), HASHBYTES('MD5', X), 2) fn md5(&self, expr: ast::Expr) -> ast::Expr { // Construct HASHBYTES('MD5', X) @@ -131,18 +135,15 @@ impl RelationToQueryTranslator for MsSqlTranslator { function_builder("DATEDIFF", vec![second, unix, expr], false) } - // used during onboarding in order to have datetime with a proper format. - // This is not needed when we will remove the cast in string of the datetime - // during the onboarding - // CAST(col AS VARCHAR/TEXT) -> CONVERT(VARCHAR, col, 126) - - // TODO: some functions are not supported yet. - // EXTRACT(epoch FROM column) -> DATEDIFF(SECOND, '19700101', column) - // Concat(a, b) has to take at least 2 args, it can take empty string as well. - // onboarding, charset query: SELECT DISTINCT REGEXP_SPLIT_TO_TABLE(anon_2.name ,'') AS "regexp_split" ... - // onboarding, sampling, remove WHERE RAND(). - // onboarding CAST(col AS Boolean) -> CAST(col AS BIT) - // onboarding Literal True/Fale -> 1/0. + fn concat(&self, exprs: Vec) -> ast::Expr { + let literal = ast::Expr::Value(ast::Value::SingleQuotedString("".to_string())); + let expanded_exprs: Vec<_> = exprs + .iter() + .cloned() + .chain(std::iter::once(literal)) + .collect(); + function_builder("CONCAT", expanded_exprs, false) + } /// MSSQL queries don't support LIMIT but TOP in the SELECT statement instated fn query( @@ -156,11 +157,36 @@ impl RelationToQueryTranslator for MsSqlTranslator { limit: Option, offset: Option, ) -> ast::Query { - let top = limit.map(|e| ast::Top { + // A TOP can not be used + // in the same query or sub-query as a OFFSET. + let top = limit.filter(|_| offset.is_none()).map(|e| ast::Top { with_ties: false, percent: false, quantity: Some(ast::TopQuantity::Expr(e)), }); + + // ORDER BY clause is invalid in views, inline functions, derived tables, + // subqueries, and common table expressions, unless TOP, OFFSET or + // FOR XML is also specified. 
+ let new_offset = match (order_by.is_empty(), &offset, &top) {
+ (false, None, None) => Some(ast::Offset {
+ value: ast::Expr::Value(ast::Value::Number("0".to_string(), false)),
+ rows: ast::OffsetRows::Rows,
+ }),
+ _ => offset.map(|o| ast::Offset {
+ value: o.value,
+ rows: ast::OffsetRows::Rows,
+ }),
+ };
+
+ let translated_projection: Vec<ast::SelectItem> = projection
+ .iter()
+ .map(case_from_boolean_select_item)
+ .collect();
+ let translated_selection: Option<ast::Expr> = selection
+ .and_then(none_from_where_random)
+ .and_then(boolean_expr_from_identifier);
+
 ast::Query { with: (!with.is_empty()).then_some(ast::With { recursive: false, @@ -169,11 +195,11 @@ body: Box::new(ast::SetExpr::Select(Box::new(ast::Select { distinct: None, top,
- projection,
+ projection: translated_projection,
 into: None, from: vec![from], lateral_views: vec![],
- selection,
+ selection: translated_selection,
 group_by, cluster_by: vec![], distribute_by: vec![], @@ -187,7 +213,7 @@ }))), order_by, limit: None,
- offset: offset,
+ offset: new_offset,
 fetch: None, locks: vec![], limit_by: vec![], @@ -247,6 +273,18 @@ options: None, } }
+
+ fn format_float_value(&self, value: f64) -> ast::Expr {
+ let max_precision = 37;
+ let formatted = if value.abs() < 1e-10 || value.abs() > 1e10 {
+ // If the value is too small or too large, switch to scientific notation
+ format!("{:.precision$e}", value, precision = max_precision)
+ } else {
+ // Otherwise, use the default float formatting with the specified precision
+ format!("{}", value)
+ };
+ ast::Expr::Value(ast::Value::Number(formatted, false))
+ }
 }
 impl QueryToRelationTranslator for MsSqlTranslator { @@ -369,6 +407,128 @@ fn extract_hashbyte_expression_if_valid(func_arg: &ast::FunctionArg) -> Option<ast::Expr>
+fn case_from_boolean_select_item(select_item: &ast::SelectItem) -> ast::SelectItem {
+ match select_item {
+ ast::SelectItem::ExprWithAlias { expr, alias } => ast::SelectItem::ExprWithAlias {
+ expr: case_from_boolean_expr(expr),
+ alias: alias.clone(),
+ },
+ ast::SelectItem::UnnamedExpr(expr) => {
+ ast::SelectItem::UnnamedExpr(case_from_boolean_expr(expr))
+ }
+ _ => select_item.clone(),
+ }
+}
+
+fn case_from_boolean_expr(expr: &ast::Expr) -> ast::Expr {
+ match expr {
+ ast::Expr::UnaryOp { op, expr } => case_from_not_unary_op(op, expr),
+ ast::Expr::BinaryOp { op, left, right } => case_from_bool_binary_op(op, left, right),
+ _ => expr.clone(),
+ }
+}
+
+fn case_from_not_unary_op(op: &ast::UnaryOperator, expr: &Box<ast::Expr>) -> ast::Expr {
+ match op {
+ ast::UnaryOperator::Not => {
+ // NOT( some_bool_expr ) -> CASE WHEN some_bool_expr THEN 0 ELSE 1
+ let when_expr = vec![expr.as_ref().clone()];
+ let then_expr = vec![ast::Expr::Value(ast::Value::Number("0".to_string(), false))];
+ let else_expr = Box::new(ast::Expr::Value(ast::Value::Number("1".to_string(), false)));
+ ast::Expr::Case {
+ operand: None,
+ conditions: when_expr,
+ results: then_expr,
+ else_result: Some(else_expr),
+ }
+ }
+ _ => ast::Expr::UnaryOp {
+ op: op.clone(),
+ expr: expr.clone(),
+ },
+ }
+}
+
+// converting any boolean binary operation into CASE WHEN expr THEN 1 ELSE 0
+fn case_from_bool_binary_op(
+ op: &ast::BinaryOperator,
+ left: &Box<ast::Expr>,
+ right: &Box<ast::Expr>,
+) -> ast::Expr {
+ let expr = ast::Expr::BinaryOp {
+ left: left.clone(),
+ op: op.clone(),
+ right: right.clone(),
+ };
+ let when_expr = vec![expr.clone()];
+ let true_expr = vec![ast::Expr::Value(ast::Value::Number("1".to_string(), false))];
+ let false_expr =
Box::new(ast::Expr::Value(ast::Value::Number("0".to_string(), false)));
+
+ match op {
+ ast::BinaryOperator::Gt
+ | ast::BinaryOperator::Lt
+ | ast::BinaryOperator::GtEq
+ | ast::BinaryOperator::LtEq
+ | ast::BinaryOperator::Eq
+ | ast::BinaryOperator::NotEq
+ | ast::BinaryOperator::And
+ | ast::BinaryOperator::Or
+ | ast::BinaryOperator::Xor
+ | ast::BinaryOperator::BitwiseOr
+ | ast::BinaryOperator::BitwiseAnd
+ | ast::BinaryOperator::BitwiseXor => ast::Expr::Case {
+ operand: None,
+ conditions: when_expr,
+ results: true_expr,
+ else_result: Some(false_expr),
+ },
+ _ => expr,
+ }
+}
+
+// WHERE expression modifications:
+
+/// Often sampling queries use WHERE RAND(CHECKSUM(NEWID())) < x but in mssql
+/// this doesn't associate a random value with each row.
+/// Use this approach instead to sample:
+/// https://www.sqlservercentral.com/forums/topic/whats-the-best-way-to-get-a-sample-set-of-a-big-table-without-primary-key#post-1948778
+/// Careful!! If a RAND function is found the WHERE will be set to None.
+fn none_from_where_random(expr: ast::Expr) -> Option<ast::Expr> {
+ if has_rand_func(&expr) {
+ None
+ } else {
+ Some(expr)
+ }
+}
+
+// It checks recursively if the Expr is a RAND function.
+fn has_rand_func(expr: &ast::Expr) -> bool {
+ match expr {
+ ast::Expr::Function(func) => {
+ let ast::Function { name, .. } = func;
+ let rand_func_name = ast::ObjectName(vec![ast::Ident::from("RAND")]);
+ &rand_func_name == name
+ }
+ ast::Expr::BinaryOp { left, .. } => has_rand_func(left.as_ref()),
+ ast::Expr::Nested(expr) => has_rand_func(expr.as_ref()),
+ _ => false,
+ }
+}
+
+// In MSSQL, WHERE col is not accepted.
+// This function converts WHERE col -> WHERE col=1
+fn boolean_expr_from_identifier(expr: ast::Expr) -> Option<ast::Expr> {
+ match expr {
+ ast::Expr::Identifier(_) => Some(ast::Expr::BinaryOp {
+ left: Box::new(expr),
+ op: ast::BinaryOperator::Eq,
+ right: Box::new(ast::Expr::Value(ast::Value::Number("1".to_string(), false))),
+ }),
+ _ => Some(expr),
+ }
+}
+
 // method to override DataType -> ast::DataType
 fn translate_data_type(dtype: DataType) -> ast::DataType { match dtype { @@ -391,6 +551,7 @@ mod tests { builder::{Ready, With}, data_type::DataType, dialect_translation::RelationWithTranslator,
+ display::Dot,
 expr::Expr, io::{mssql, Database as _}, namer, @@ -399,6 +560,21 @@ }; use std::sync::Arc;
+ #[test]
+ fn test_coalesce() {
+ let mut database = mssql::test_database();
+ let relations = database.relations();
+
+ let query = "SELECT COALESCE(a) FROM table_1 LIMIT 30";
+
+ let relation = Relation::try_from(With::with(&parse(query).unwrap(), &relations)).unwrap();
+ relation.display_dot().unwrap();
+ let rel_with_traslator = RelationWithTranslator(&relation, MsSqlTranslator);
+ let translated_query = &ast::Query::from(rel_with_traslator).to_string()[..];
+ println!("{}", translated_query);
+ let _ = database.query(translated_query).unwrap();
+ }
+
 #[test] fn test_limit() { let mut database = mssql::test_database(); @@ -415,6 +591,38 @@ let _ = database.query(translated_query).unwrap(); }
+ #[test]
+ fn test_not() {
+ let mut database = mssql::test_database();
+ let relations = database.relations();
+
+ let query = "SELECT NOT (a IS NULL) AS col FROM table_1";
+
+ let relation = Relation::try_from(With::with(&parse(query).unwrap(), &relations)).unwrap();
+ relation.display_dot().unwrap();
+
+ let rel_with_traslator = RelationWithTranslator(&relation, MsSqlTranslator);
+ let translated_query = ast::Query::from(rel_with_traslator);
+ println!("{}", translated_query);
+
let _ = database.query(&translated_query.to_string()[..]).unwrap();
+ }
+
+ #[test]
+ fn test_where_rand() {
+ let mut database = mssql::test_database();
+ let relations = database.relations();
+
+ let query = "SELECT * FROM table_2 WHERE (RANDOM()) < (0.5)";
+
+ let relation = Relation::try_from(With::with(&parse(query).unwrap(), &relations)).unwrap();
+ relation.display_dot().unwrap();
+
+ let rel_with_traslator = RelationWithTranslator(&relation, MsSqlTranslator);
+ let translated_query = ast::Query::from(rel_with_traslator);
+ println!("{}", translated_query);
+ let _ = database.query(&translated_query.to_string()).unwrap();
+ }
+
 #[test] fn test_cast() { let mut database = mssql::test_database();
diff --git a/src/dialect_translation/mysql.rs b/src/dialect_translation/mysql.rs index 8b137891..f0ee586d 100644 --- a/src/dialect_translation/mysql.rs +++ b/src/dialect_translation/mysql.rs @@ -1 +1,472 @@
+use super::{function_builder, QueryToRelationTranslator, RelationToQueryTranslator, Result};
+use crate::{
+ data_type::DataTyped as _,
+ expr::{self},
+ hierarchy::Hierarchy,
+ relation::{Table, Variant as _},
+ DataType, WithoutContext as _,
+};
+use sqlparser::{ast, dialect::MySqlDialect};
+#[derive(Clone, Copy)]
+pub struct MySqlTranslator;
+
+impl RelationToQueryTranslator for MySqlTranslator {
+ fn first(&self, expr: ast::Expr) -> ast::Expr {
+ expr
+ }
+
+ fn mean(&self, expr: ast::Expr) -> ast::Expr {
+ function_builder("AVG", vec![expr], false)
+ }
+
+ fn var(&self, expr: ast::Expr) -> ast::Expr {
+ function_builder("VARIANCE", vec![expr], false)
+ }
+
+ fn std(&self, expr: ast::Expr) -> ast::Expr {
+ function_builder("STDDEV", vec![expr], false)
+ }
+
+ fn identifier(&self, value: &expr::Identifier) -> Vec<ast::Ident> {
+ value
+ .iter()
+ .map(|r| ast::Ident::with_quote('`', r))
+ .collect()
+ }
+ fn insert(&self, prefix: &str, table: &Table) -> ast::Statement {
+ ast::Statement::Insert(ast::Insert {
+ or: None,
+ into: true,
+ table_name: ast::ObjectName(self.identifier(&(table.path().clone().into()))),
+ table_alias: None,
+ columns: table
+ .schema()
+ .iter()
+ .map(|f| self.identifier(&(f.name().into()))[0].clone())
+ .collect(),
+ overwrite: false,
+ source: Some(Box::new(ast::Query {
+ with: None,
+ body: Box::new(ast::SetExpr::Values(ast::Values {
+ explicit_row: false,
+ rows: vec![(1..=table.schema().len())
+ .map(|_| ast::Expr::Value(ast::Value::Placeholder(format!("{prefix}"))))
+ .collect()],
+ })),
+ order_by: vec![],
+ limit: None,
+ limit_by: vec![],
+ offset: None,
+ fetch: None,
+ locks: vec![],
+ for_clause: None,
+ })),
+ partitioned: None,
+ after_columns: vec![],
+ table: false,
+ on: None,
+ returning: None,
+ ignore: false,
+ replace_into: false,
+ priority: None,
+ insert_alias: None,
+ })
+ }
+ fn create(&self, table: &Table) -> ast::Statement {
+ ast::Statement::CreateTable {
+ or_replace: false,
+ temporary: false,
+ external: false,
+ global: None,
+ if_not_exists: false,
+ transient: false,
+ name: ast::ObjectName(self.identifier(&(table.path().clone().into()))),
+ columns: table
+ .schema()
+ .iter()
+ .map(|f| ast::ColumnDef {
+ name: self.identifier(&(f.name().into()))[0].clone(),
+ // Need to override some conversions
+ data_type: { translate_data_type(f.data_type()) },
+ collation: None,
+ options: if let DataType::Optional(_) = f.data_type() {
+ vec![]
+ } else {
+ vec![ast::ColumnOptionDef {
+ name: None,
+ option: ast::ColumnOption::NotNull,
+ }]
+ },
+ })
+ .collect(),
+ constraints: vec![],
+ hive_distribution: ast::HiveDistributionStyle::NONE,
+
hive_formats: None, + table_properties: vec![], + with_options: vec![], + file_format: None, + location: None, + query: None, + without_rowid: false, + like: None, + clone: None, + engine: None, + default_charset: None, + collation: None, + on_commit: None, + on_cluster: None, + order_by: None, + strict: false, + comment: None, + auto_increment_offset: None, + partition_by: None, + cluster_by: None, + options: None, + } + } + fn random(&self) -> ast::Expr { + function_builder("RAND", vec![], false) + } + /// Converting LOG to LOG10 + fn log(&self, expr: ast::Expr) -> ast::Expr { + function_builder("LOG10", vec![expr], false) + } + fn cast_as_text(&self, expr: ast::Expr) -> ast::Expr { + ast::Expr::Cast { + expr: Box::new(expr), + data_type: ast::DataType::Char(None), + format: None, + kind: ast::CastKind::Cast, + } + } + fn extract_epoch(&self, expr: ast::Expr) -> ast::Expr { + function_builder("UNIX_TIMESTAMP", vec![expr], false) + } + /// For mysql CAST(expr AS INTEGER) should be converted to + /// CAST(expr AS SIGNED [INTEGER]) which produces a BigInt value. + /// CONVERT can be also used as CONVERT(expr, SIGNED) + /// however ast::DataType doesn't support SIGNED [INTEGER]. + /// We fix it by creating a function CONVERT(expr, SIGNED). + fn cast_as_integer(&self, expr: ast::Expr) -> ast::Expr { + let signed = ast::Expr::Identifier(ast::Ident { + value: "SIGNED".to_string(), + quote_style: None, + }); + function_builder("CONVERT", vec![expr, signed], false) + } + + // encode(source, 'escape') -> source + // encode(source, 'hex') -> hex(source) + // encode(source, 'base64') -> to_base64(source) + fn encode(&self, exprs: Vec) -> ast::Expr { + assert_eq!(exprs.len(), 2); + let source = exprs[0].clone(); + match &exprs[1] { + ast::Expr::Value(ast::Value::SingleQuotedString(s)) if s == &"hex".to_string() => { + function_builder("HEX", vec![source], false) + } + ast::Expr::Value(ast::Value::SingleQuotedString(s)) if s == &"base64".to_string() => { + function_builder("TO_BASE64", vec![source], false) + } + _ => source, + } + } + + // decode(source, 'hex') -> CONVERT(unhex(source) USING utf8mb4) + // decode(source, 'escape') -> CONVERT(source USING utf8mb4) + // decode(source, 'base64') -> CONVERT(from_base64(source) USING utf8mb4) + fn decode(&self, exprs: Vec) -> ast::Expr { + assert_eq!(exprs.len(), 2); + let source = exprs[0].clone(); + let binary_expr = match &exprs[1] { + ast::Expr::Value(ast::Value::SingleQuotedString(s)) if s == &"hex".to_string() => { + function_builder("UNHEX", vec![source], false) + } + ast::Expr::Value(ast::Value::SingleQuotedString(s)) if s == &"base64".to_string() => { + function_builder("FROM_BASE64", vec![source], false) + } + _ => source, + }; + let char_enc = ast::ObjectName(vec![ast::Ident { + value: "utf8mb4".to_string(), + quote_style: None, + }]); + ast::Expr::Convert { + expr: Box::new(binary_expr), + data_type: None, + charset: Some(char_enc), + target_before_value: false, + styles: vec![], + } + } +} + +impl QueryToRelationTranslator for MySqlTranslator { + type D = MySqlDialect; + + fn dialect(&self) -> Self::D { + MySqlDialect {} + } + + fn try_function( + &self, + func: &ast::Function, + context: &Hierarchy, + ) -> Result { + let function_name: &str = &func.name.0.iter().next().unwrap().value.to_lowercase()[..]; + let converted = self.try_function_args(func.args.clone(), context)?; + + match function_name { + "log" => self.try_ln(func, context), + "log10" => self.try_log(func, context), + "convert" => self.try_md5(func, context), + "unhex" => 
try_encode_decode(converted, EncodeDecodeFormat::Hex), + "from_base64" => try_encode_decode(converted, EncodeDecodeFormat::Base64), + "hex" => try_encode(converted, EncodeDecodeFormat::Hex), + "to_base64" => try_encode(converted, EncodeDecodeFormat::Base64), + _ => { + let expr = ast::Expr::Function(func.clone()); + expr::Expr::try_from(expr.with(context)) + } + } + } +} + +// method to override DataType -> ast::DataType +fn translate_data_type(dtype: DataType) -> ast::DataType { + match dtype { + DataType::Text(_) => ast::DataType::Varchar(Some(ast::CharacterLength::IntegerLength { + length: 255, + unit: None, + })), + //DataType::Boolean(_) => Boolean should be displayed as BIT for MSSQL, + // SQLParser doesn't support the BIT DataType (mssql equivalent of bool) + DataType::Optional(o) => translate_data_type(o.data_type().clone()), + _ => dtype.into(), + } +} +enum EncodeDecodeFormat { + Hex, + Base64, +} + +// unhex(source) -> encode(decode(source, 'hex'), 'escape') +// from_base64(source) -> encode(decode(source, 'base_64'), 'escape') +fn try_encode_decode(exprs: Vec, format: EncodeDecodeFormat) -> Result { + assert_eq!(exprs.len(), 1); + let format = match format { + EncodeDecodeFormat::Hex => expr::Expr::val("hex".to_string()), + EncodeDecodeFormat::Base64 => expr::Expr::val("base64".to_string()), + }; + let decode = expr::Expr::decode(exprs[0].clone(), format); + let escape = expr::Expr::val("escape".to_string()); + Ok(expr::Expr::encode(decode, escape)) +} + +// hex(source) -> encode(source, 'hex') +// to_base64(source) -> encode(source, 'base64') +fn try_encode(exprs: Vec, format: EncodeDecodeFormat) -> Result { + assert_eq!(exprs.len(), 1); + let format = match format { + EncodeDecodeFormat::Hex => expr::Expr::val("hex".to_string()), + EncodeDecodeFormat::Base64 => expr::Expr::val("base64".to_string()), + }; + Ok(expr::Expr::encode(exprs[0].clone(), format)) +} + +#[cfg(test)] +#[cfg(feature = "mysql")] +mod tests { + use itertools::Itertools as _; + + use super::*; + use crate::{ + dialect_translation::{postgresql::PostgreSqlTranslator, RelationWithTranslator}, + io::{mysql, postgresql, Database as _}, + relation::Relation, + sql::{self, relation::QueryWithRelations}, + }; + + fn try_from_mssql_query( + mysql_query: &str, + relations: Hierarchy>, + ) -> Relation { + let parsed_query = + sql::relation::parse_with_dialect(mysql_query, MySqlTranslator.dialect()).unwrap(); + // let parsed_query = parse(mysql_query).unwrap(); + let query_with_translator = ( + QueryWithRelations::new(&parsed_query, &relations), + MySqlTranslator, + ); + Relation::try_from(query_with_translator).unwrap() + } + + #[test] + fn test_unhex() { + let mut mysql_database = mysql::test_database(); + let mut psql_database = postgresql::test_database(); + let relations = mysql_database.relations(); + + let initial_mysql_query = "SELECT unhex('50726976617465') FROM table_2 LIMIT 1"; + let rel = try_from_mssql_query(initial_mysql_query, relations); + + let rel_with_traslator = RelationWithTranslator(&rel, PostgreSqlTranslator); + let psql_query = &ast::Query::from(rel_with_traslator).to_string()[..]; + + let rel_with_traslator = RelationWithTranslator(&rel, MySqlTranslator); + let mysql_query = &ast::Query::from(rel_with_traslator).to_string()[..]; + + println!("{}", initial_mysql_query); + println!("{}", mysql_query); + println!("{}", psql_query); + let res_initial_mysql = mysql_database + .query(initial_mysql_query) + .unwrap() + .iter() + .map(ToString::to_string) + .join("\n"); + + let res_mysql = 
mysql_database + .query(mysql_query) + .unwrap() + .iter() + .map(ToString::to_string) + .join("\n"); + let res_psql = psql_database + .query(psql_query) + .unwrap() + .iter() + .map(ToString::to_string) + .join("\n"); + assert_eq!(res_mysql, "(Private)".to_string()); + assert_eq!(res_mysql, res_psql); + assert_eq!(res_mysql, res_initial_mysql) + } + + #[test] + fn test_hex() { + let mut mysql_database = mysql::test_database(); + let mut psql_database = postgresql::test_database(); + let relations = mysql_database.relations(); + + let initial_mysql_query = "SELECT hex('Private') FROM table_2 LIMIT 1"; + let rel = try_from_mssql_query(initial_mysql_query, relations); + + let rel_with_traslator = RelationWithTranslator(&rel, PostgreSqlTranslator); + let psql_query = &ast::Query::from(rel_with_traslator).to_string()[..]; + + let rel_with_traslator = RelationWithTranslator(&rel, MySqlTranslator); + let mysql_query = &ast::Query::from(rel_with_traslator).to_string()[..]; + + println!("{}", initial_mysql_query); + println!("{}", mysql_query); + println!("{}", psql_query); + let res_initial_mysql = mysql_database + .query(initial_mysql_query) + .unwrap() + .iter() + .map(ToString::to_string) + .join("\n"); + + let res_mysql = mysql_database + .query(mysql_query) + .unwrap() + .iter() + .map(ToString::to_string) + .join("\n"); + let res_psql = psql_database + .query(psql_query) + .unwrap() + .iter() + .map(ToString::to_string) + .join("\n"); + assert_eq!(res_mysql, "(50726976617465)".to_string()); + assert_eq!(res_mysql, res_psql); + assert_eq!(res_mysql, res_initial_mysql) + } + + #[test] + fn test_from_base64() { + let mut mysql_database = mysql::test_database(); + let mut psql_database = postgresql::test_database(); + let relations = mysql_database.relations(); + + let initial_mysql_query = "SELECT from_base64('YWJj') FROM table_2 LIMIT 1"; + let rel = try_from_mssql_query(initial_mysql_query, relations); + + let rel_with_traslator = RelationWithTranslator(&rel, PostgreSqlTranslator); + let psql_query = &ast::Query::from(rel_with_traslator).to_string()[..]; + + let rel_with_traslator = RelationWithTranslator(&rel, MySqlTranslator); + let mysql_query = &ast::Query::from(rel_with_traslator).to_string()[..]; + + println!("{}", initial_mysql_query); + println!("{}", mysql_query); + println!("{}", psql_query); + let res_initial_mysql = mysql_database + .query(initial_mysql_query) + .unwrap() + .iter() + .map(ToString::to_string) + .join("\n"); + + let res_mysql = mysql_database + .query(mysql_query) + .unwrap() + .iter() + .map(ToString::to_string) + .join("\n"); + let res_psql = psql_database + .query(psql_query) + .unwrap() + .iter() + .map(ToString::to_string) + .join("\n"); + assert_eq!(res_mysql, res_psql); + assert_eq!(res_mysql, res_initial_mysql); + assert_eq!(res_mysql, "(abc)".to_string()); + } + + #[test] + fn test_to_base64() { + let mut mysql_database = mysql::test_database(); + let mut psql_database = postgresql::test_database(); + let relations = mysql_database.relations(); + + let initial_mysql_query = "SELECT TO_BASE64('abc') FROM table_2 LIMIT 1"; + let rel = try_from_mssql_query(initial_mysql_query, relations); + + let rel_with_traslator = RelationWithTranslator(&rel, PostgreSqlTranslator); + let psql_query = &ast::Query::from(rel_with_traslator).to_string()[..]; + + let rel_with_traslator = RelationWithTranslator(&rel, MySqlTranslator); + let mysql_query = &ast::Query::from(rel_with_traslator).to_string()[..]; + + println!("{}", initial_mysql_query); + println!("{}", 
mysql_query); + println!("{}", psql_query); + let res_initial_mysql = mysql_database + .query(initial_mysql_query) + .unwrap() + .iter() + .map(ToString::to_string) + .join("\n"); + + let res_mysql = mysql_database + .query(mysql_query) + .unwrap() + .iter() + .map(ToString::to_string) + .join("\n"); + let res_psql = psql_database + .query(psql_query) + .unwrap() + .iter() + .map(ToString::to_string) + .join("\n"); + assert_eq!(res_mysql, res_psql); + assert_eq!(res_mysql, res_initial_mysql); + assert_eq!(res_mysql, "(YWJj)".to_string()); + } +} diff --git a/src/dialect_translation/postgresql.rs b/src/dialect_translation/postgresql.rs index 1f8e806d..0bd2c0d1 100644 --- a/src/dialect_translation/postgresql.rs +++ b/src/dialect_translation/postgresql.rs @@ -1,5 +1,15 @@ +use crate::{ + expr::{self, Identifier}, + hierarchy::Hierarchy, + WithoutContext as _, +}; + use super::{function_builder, QueryToRelationTranslator, RelationToQueryTranslator}; -use sqlparser::{ast, dialect::PostgreSqlDialect}; +use crate::sql::{Error, Result}; +use sqlparser::{ + ast, + dialect::{Dialect, PostgreSqlDialect}, +}; #[derive(Clone, Copy)] pub struct PostgreSqlTranslator; @@ -80,6 +90,42 @@ impl QueryToRelationTranslator for PostgreSqlTranslator { fn dialect(&self) -> Self::D { PostgreSqlDialect {} } + + fn try_identifier(&self, ident: &ast::Ident) -> Result { + if let Some(quote_style) = ident.quote_style { + let identifier_quote_style = self.dialect().identifier_quote_style(""); + if identifier_quote_style != Some(quote_style) { + return Err(Error::Other(format!( + "Wrong quoting of {} Identifier", + ident + ))); + }; + } + Ok(expr::Identifier::from(ident)) + } + + // Fail if non postgres functions + fn try_function( + &self, + func: &ast::Function, + context: &Hierarchy, + ) -> Result { + let function_name: &str = &func.name.0.iter().next().unwrap().value.to_lowercase()[..]; + + match function_name { + "rand" | "unhex" | "from_hex" | "choose" | "newid" | "dayname" | "date_format" + | "quarter" | "datetime_diff" | "date" | "from_unixtime" | "unix_timestamp" => { + Err(Error::ParsingError(format!( + "`{}` is not a postgres function", + function_name.to_uppercase() + ))) + } + _ => { + let expr = ast::Expr::Function(func.clone()); + expr::Expr::try_from(expr.with(context)) + } + } + } } #[cfg(test)] diff --git a/src/dialect_translation/redshiftsql.rs b/src/dialect_translation/redshiftsql.rs new file mode 100644 index 00000000..9249be1d --- /dev/null +++ b/src/dialect_translation/redshiftsql.rs @@ -0,0 +1,84 @@ +use super::{function_builder, QueryToRelationTranslator, RelationToQueryTranslator}; +use sqlparser::{ast, dialect::RedshiftSqlDialect}; + +#[derive(Clone, Copy)] +pub struct RedshiftSqlTranslator; + +// Copied from postgres since it is very similar. +impl RelationToQueryTranslator for RedshiftSqlTranslator { + fn first(&self, expr: ast::Expr) -> ast::Expr { + expr + } + + fn mean(&self, expr: ast::Expr) -> ast::Expr { + function_builder("AVG", vec![expr], false) + } + + fn var(&self, expr: ast::Expr) -> ast::Expr { + function_builder("VARIANCE", vec![expr], false) + } + + fn std(&self, expr: ast::Expr) -> ast::Expr { + function_builder("STDDEV", vec![expr], false) + } + + fn trunc(&self, exprs: Vec) -> ast::Expr { + // TRUNC in postgres has a problem: + // In TRUNC(double_precision_number, precision) if precision is specified it fails + // If it is not specified it passes considering precision = 0. 
+ // SELECT TRUNC(CAST (0.12 AS DOUBLE PRECISION), 0) fails + // SELECT TRUNC(CAST (0.12 AS DOUBLE PRECISION)) passes. + // Here we check precision, if it is 0 we remove it (such that the precision is implicit). + let func_args_list = ast::FunctionArgumentList { + duplicate_treatment: None, + args: exprs + .into_iter() + .filter_map(|e| { + (e != ast::Expr::Value(ast::Value::Number("0".to_string(), false))) + .then_some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))) + }) + .collect(), + clauses: vec![], + }; + ast::Expr::Function(ast::Function { + name: ast::ObjectName(vec![ast::Ident::from("TRUNC")]), + args: ast::FunctionArguments::List(func_args_list), + over: None, + filter: None, + null_treatment: None, + within_group: vec![], + }) + } + + fn round(&self, exprs: Vec) -> ast::Expr { + // Same as TRUNC + // what if I wanted to do round(0, 0) + let func_args_list = ast::FunctionArgumentList { + duplicate_treatment: None, + args: exprs + .into_iter() + .filter_map(|e| { + (e != ast::Expr::Value(ast::Value::Number("0".to_string(), false))) + .then_some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))) + }) + .collect(), + clauses: vec![], + }; + ast::Expr::Function(ast::Function { + name: ast::ObjectName(vec![ast::Ident::from("ROUND")]), + args: ast::FunctionArguments::List(func_args_list), + over: None, + filter: None, + null_treatment: None, + within_group: vec![], + }) + } +} + +impl QueryToRelationTranslator for RedshiftSqlTranslator { + type D = RedshiftSqlDialect; + + fn dialect(&self) -> Self::D { + RedshiftSqlDialect {} + } +} diff --git a/src/io/mod.rs b/src/io/mod.rs index 66708b45..dd3d5aac 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -11,6 +11,8 @@ pub mod bigquery; #[cfg(feature = "mssql")] pub mod mssql; +#[cfg(feature = "mysql")] +pub mod mysql; pub mod postgresql; #[cfg(feature = "sqlite")] pub mod sqlite; diff --git a/src/io/mssql.rs b/src/io/mssql.rs index c804e09d..39f6358e 100644 --- a/src/io/mssql.rs +++ b/src/io/mssql.rs @@ -171,6 +171,13 @@ impl Database { Schema::empty() .with(("a", DataType::float_interval(0., 10.))) .with(("b", DataType::optional(DataType::float_interval(-1., 1.)))) + .with(( + "c", + DataType::text_values([ + "12/04/2001 13:03:43".into(), + "12/05/2011 13:03:43".into(), + ]), + )) .with(("d", DataType::integer_interval(0, 10))), ) .build(), @@ -182,7 +189,7 @@ impl Database { Schema::empty() .with(("x", DataType::integer_interval(0, 100))) .with(("y", DataType::optional(DataType::text()))) - .with(("z", DataType::text_values(["Foo".into(), "Bar".into()]))), //Can't push these? why? + .with(("z", DataType::text_values(["Foo".into(), "Bar".into()]))), ) .build(), TableBuilder::new() diff --git a/src/io/mysql.rs b/src/io/mysql.rs new file mode 100644 index 00000000..a231794a --- /dev/null +++ b/src/io/mysql.rs @@ -0,0 +1,398 @@ +//! An object creating a docker container and releasing it after use +//! 
+
+use super::{Database as DatabaseTrait, Error, Result, DATA_GENERATION_SEED};
+use crate::{
+ data_type::{
+ self,
+ generator::Generator,
+ value::{self, Value},
+ DataTyped,
+ },
+ dialect_translation::mysql::MySqlTranslator,
+ namer,
+ relation::{Table, Variant as _},
+};
+use std::{
+ env, fmt, ops::Deref as _, process::Command, str::FromStr, string::FromUtf8Error, sync::Mutex,
+ thread, time,
+};
+
+use chrono::{Datelike, Timelike as _};
+use colored::Colorize;
+use mysql::{prelude::*, OptsBuilder, Value as MySqlValue};
+use r2d2::Pool;
+use r2d2_mysql::MySqlConnectionManager;
+use rand::{rngs::StdRng, SeedableRng};
+
+const DB: &str = "qrlew_mysql_test";
+const PORT: usize = 3306;
+const USER: &str = "root";
+const PASSWORD: &str = "qrlew_test";
+
+/// Converts mysql errors to io errors
+impl From<mysql::Error> for Error {
+ fn from(err: mysql::Error) -> Self {
+ Error::Other(err.to_string())
+ }
+}
+
+impl From<FromUtf8Error> for Error {
+ fn from(value: FromUtf8Error) -> Self {
+ Error::Other(value.to_string())
+ }
+}
+
+pub struct Database {
+ name: String,
+ tables: Vec<Table>,
+ pool: Pool<MySqlConnectionManager>,
+ drop: bool,
+}
+
+/// Only one pool
+pub static MYSQL_POOL: Mutex<Option<Pool<MySqlConnectionManager>>> = Mutex::new(None);
+/// Only one thread starts a container
+pub static MYSQL_CONTAINER: Mutex<bool> = Mutex::new(false);
+
+impl Database {
+ fn port() -> usize {
+ match env::var("MYSQL_PORT") {
+ Ok(port) => usize::from_str(&port).unwrap_or(PORT),
+ Err(_) => PORT,
+ }
+ }
+
+ fn user() -> String {
+ env::var("MYSQL_USER").unwrap_or_else(|_| USER.into())
+ }
+
+ fn password() -> String {
+ env::var("MYSQL_PASSWORD").unwrap_or_else(|_| PASSWORD.into())
+ }
+
+ /// Try to build a pool from an existing DB
+ /// A MySQL instance must exist
+ /// `docker run --name qrlew-test -p 3306:3306 -e MYSQL_ROOT_PASSWORD=qrlew_test -d mysql`
+ fn build_pool_from_existing() -> Result<Pool<MySqlConnectionManager>> {
+ let opts = OptsBuilder::new()
+ .ip_or_hostname(Some("localhost"))
+ .tcp_port(Database::port() as u16)
+ .user(Some(&Database::user()))
+ .pass(Some(&Database::password()))
+ .db_name(Some(DB));
+ let manager = MySqlConnectionManager::new(OptsBuilder::from(opts));
+ Ok(r2d2::Pool::builder().max_size(10).build(manager)?)
+ }
+
+ /// Try to build a pool from a DB in a container
+ fn build_pool_from_container(name: String) -> Result<Pool<MySqlConnectionManager>> {
+ let mut mysql_container = MYSQL_CONTAINER.lock().unwrap();
+ if !*mysql_container {
+ // A new container will be started
+ *mysql_container = true;
+ // Other threads will wait for this to be ready
+ let name = namer::new_name(name);
+ let port = PORT + namer::new_id("mysql-port");
+
+ // Test the connection and launch a test instance if necessary
+ if !Command::new("docker")
+ .arg("start")
+ .arg(&name)
+ .status()?
+ .success() + { + log::debug!("Starting the DB"); + // If the container does not exist, start a new container + // Run: `docker run --name test-db -e MYSQL_ROOT_PASSWORD=test -d mysql` + let output = Command::new("docker") + .arg("run") + .arg("--name") + .arg(&name) + .arg("-d") + .arg("--rm") + .arg("-e") + .arg(format!("MYSQL_ROOT_PASSWORD={PASSWORD}")) + .arg("-p") + .arg(format!("{}:3306", port)) + .arg("mysql") + .output()?; + log::info!("{:?}", output); + log::info!("Waiting for the DB to start"); + // Wait for the DB to be ready + loop { + let output = Command::new("docker") + .arg("exec") + .arg(&name) + .arg("mysqladmin") + .arg("--user=root") + .arg(format!("--password={}", PASSWORD)) + .arg("ping") + .output()?; + if output.status.success() + && String::from_utf8_lossy(&output.stdout).contains("mysqld is alive") + { + break; + } + thread::sleep(time::Duration::from_millis(200)); + log::info!("Waiting..."); + } + log::info!("{}", "DB ready".red()); + } + let opts = OptsBuilder::new() + .ip_or_hostname(Some("localhost")) + .tcp_port(port as u16) + .user(Some(&Database::user())) + .pass(Some(&Database::password())) + .db_name(Some(DB)); + let manager = MySqlConnectionManager::new(OptsBuilder::from(opts)); + let pool = r2d2::Pool::builder().max_size(10).build(manager)?; + + // Ensure database exists + let mut conn = pool.get()?; + conn.query_drop(format!("CREATE DATABASE IF NOT EXISTS `{}`;", DB))?; + Ok(pool) + } else { + Database::build_pool_from_existing() + } + } +} + +impl fmt::Debug for Database { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Database") + .field("name", &self.name) + .field("tables", &self.tables) + .finish() + } +} + +impl DatabaseTrait for Database { + fn new(name: String, tables: Vec
<Table>) -> Result<Self> {
+ let mut mysql_pool = MYSQL_POOL.lock().unwrap();
+ if mysql_pool.is_none() {
+ *mysql_pool = Some(
+ Database::build_pool_from_existing()
+ .or_else(|_| Database::build_pool_from_container(name.clone()))?,
+ );
+ }
+ let pool = mysql_pool.as_ref().unwrap().clone();
+ let mut conn = pool.get()?;
+ conn.query_drop(format!("CREATE DATABASE IF NOT EXISTS `{}`", DB))?;
+ conn.query_drop(format!("USE `{}`", DB))?;
+ let table_names: Vec<String> = conn.query("SHOW TABLES")?;
+ if table_names.is_empty() {
+ Database {
+ name,
+ tables: vec![],
+ pool,
+ drop: false,
+ }
+ .with_tables(tables)
+ } else {
+ Ok(Database {
+ name,
+ tables,
+ pool,
+ drop: false,
+ })
+ }
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+
+ fn tables(&self) -> &[Table] {
+ &self.tables
+ }
+
+ fn tables_mut(&mut self) -> &mut Vec<Table>
{ + &mut self.tables + } + + fn create_table(&mut self, table: &Table) -> Result { + let mut conn = self.pool.get()?; + let query = table.create(MySqlTranslator).to_string(); + conn.query_drop(&query)?; + Ok(0) + } + + fn insert_data(&mut self, table: &Table) -> Result<()> { + let mut rng = StdRng::seed_from_u64(DATA_GENERATION_SEED); + let size = Database::MAX_SIZE.min(table.size().generate(&mut rng) as usize); + let mut conn = self.pool.get()?; + let query = table.insert("?", MySqlTranslator).to_string(); + for _ in 0..size { + let structured: value::Struct = + table.schema().data_type().generate(&mut rng).try_into()?; + let values: Result> = structured + .into_iter() + .map(|(_, v)| MySqlValue::try_from((**v).clone())) + .collect(); + let values = values?; + conn.exec_drop(&query, values)?; + } + Ok(()) + } + + fn query(&mut self, query: &str) -> Result> { + let mut conn = self.pool.get()?; + let result: Vec = conn.query(query)?; + let rows: Result> = result + .into_iter() + .map(|row| { + let values: Result> = row + .unwrap() + .into_iter() + .map(|v| Value::try_from(v)) + .collect(); + Ok(value::List::from_iter(values?)) + }) + .collect(); + Ok(rows?) + } +} + +impl Drop for Database { + fn drop(&mut self) { + if self.drop { + Command::new("docker") + .arg("rm") + .arg("--force") + .arg(self.name()) + .status() + .expect("Deleted container"); + } + } +} + +impl TryFrom for MySqlValue { + type Error = Error; + + fn try_from(value: Value) -> Result { + match value { + Value::Boolean(b) => Ok(MySqlValue::from(b.deref())), + Value::Integer(i) => Ok(MySqlValue::from(i.deref())), + Value::Float(f) => Ok(MySqlValue::from(f.deref())), + Value::Text(t) => Ok(MySqlValue::from(t.deref())), + Value::Optional(o) => o + .as_ref() + .map(|v| MySqlValue::try_from((**v).clone())) + .transpose() + .map(|o| o.unwrap_or(MySqlValue::NULL)), + Value::Date(d) => Ok(MySqlValue::Date( + d.year() as u16, + d.month() as u8, + d.day() as u8, + 0, + 0, + 0, + 0, + )), + Value::Time(t) => Ok(MySqlValue::Time( + false, + 0, + t.hour() as u8, + t.minute() as u8, + t.second() as u8, + 0, + )), + Value::DateTime(dt) => { + let dt_naive = dt.deref(); + let date = dt_naive.date(); + let time = dt_naive.time(); + Ok(MySqlValue::Date( + date.year() as u16, + date.month() as u8, + date.day() as u8, + time.hour() as u8, + time.minute() as u8, + time.second() as u8, + 0, + )) + } + Value::Id(i) => Ok(MySqlValue::from(i.deref())), + _ => Err(Error::other(value)), + } + } +} + +impl TryFrom for Value { + type Error = Error; + + fn try_from(value: MySqlValue) -> Result { + match value { + MySqlValue::NULL => Ok(Value::Optional(data_type::value::Optional::new(None))), + MySqlValue::Int(i) => Ok(Value::Integer(i.into())), + MySqlValue::UInt(u) => Ok(Value::Integer((u as i64).into())), + MySqlValue::Float(f) => Ok(Value::Float((f as f64).into())), + MySqlValue::Double(d) => Ok(Value::Float(d.into())), + MySqlValue::Bytes(bytes) => { + let s = String::from_utf8(bytes)?; + Ok(Value::Text(s.into())) + } + MySqlValue::Date(year, month, day, hour, min, sec, _) => { + let dt = chrono::NaiveDate::from_ymd_opt(year as i32, month as u32, day as u32) + .ok_or_else(|| Error::other("Invalid date"))?; + if hour == 0 && min == 0 && sec == 0 { + Ok(Value::Date(dt.into())) + } else { + let time = chrono::NaiveTime::from_hms_opt(hour as u32, min as u32, sec as u32) + .ok_or_else(|| Error::other("Invalid time"))?; + Ok(Value::DateTime(chrono::NaiveDateTime::new(dt, time).into())) + } + } + MySqlValue::Time(neg, days, hours, mins, secs, _) => { + 
let total_secs = (((((days * 24) + u32::from(hours)) * 60 + u32::from(mins)) * 60 + + u32::from(secs)) as i64) + * if neg { -1 } else { 1 }; + let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt( + total_secs.abs() as u32, + 0, + ) + .ok_or_else(|| Error::other("Invalid time"))?; + Ok(Value::Time(time.into())) + } + } + } +} + +pub fn test_database() -> Database { + Database::new(DB.into(), Database::test_tables()).expect("Database") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn database_display() -> Result<()> { + let mut database = test_database(); + for query in [ + "SELECT count(a), 1+sum(a), d FROM table_1 GROUP BY d", + "SELECT AVG(x) as a FROM table_2", + "SELECT 1+count(y) as a, sum(1+x) as b FROM table_2", + "SELECT * FROM (SELECT * FROM table_1) as cte", + "SELECT * FROM table_2", + ] { + println!("\n{query}"); + for row in database.query(query)? { + println!("{}", row); + } + } + Ok(()) + } + + #[test] + fn database_test() -> Result<()> { + let mut database = test_database(); + println!("Pool {}", database.pool.max_size()); + assert!(!database.eq("SELECT * FROM table_1", "SELECT * FROM table_2")); + assert!(database.eq( + "SELECT * FROM table_1", + "SELECT * FROM (SELECT * FROM table_1) as cte" + )); + Ok(()) + } +} diff --git a/src/relation/sql.rs b/src/relation/sql.rs index 0dcea561..55bf60f9 100644 --- a/src/relation/sql.rs +++ b/src/relation/sql.rs @@ -91,35 +91,6 @@ fn select_from_query(query: ast::Query) -> ast::Select { } } -/// Build a set operation -fn set_operation( - with: Vec, - operator: ast::SetOperator, - quantifier: ast::SetQuantifier, - left: ast::Select, - right: ast::Select, -) -> ast::Query { - ast::Query { - with: (!with.is_empty()).then_some(ast::With { - recursive: false, - cte_tables: with, - }), - body: Box::new(ast::SetExpr::SetOperation { - op: operator, - set_quantifier: quantifier, - left: Box::new(ast::SetExpr::Select(Box::new(left))), - right: Box::new(ast::SetExpr::Select(Box::new(right))), - }), - order_by: vec![], - limit: None, - offset: None, - fetch: None, - locks: vec![], - limit_by: vec![], - for_clause: None, - } -} - impl<'a, T: RelationToQueryTranslator> Visitor<'a, ast::Query> for FromRelationVisitor { fn table(&self, table: &'a Table) -> ast::Query { self.translator.query( @@ -341,7 +312,7 @@ impl<'a, T: RelationToQueryTranslator> Visitor<'a, ast::Query> for FromRelationV .iter() .map(|field| self.translator.identifier(&(field.name().into()))[0].clone()) .collect(), - set_operation( + self.translator.set_operation( vec![], set.operator.clone().into(), set.quantifier.clone().into(), diff --git a/src/sql/mod.rs b/src/sql/mod.rs index 66fa48a4..8e1a01c4 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -97,11 +97,17 @@ impl From for Error { } } +impl From for Error { + fn from(err: crate::expr::Error) -> Self { + Error::Other(err.to_string()) + } +} + pub type Result = result::Result; // Import a few functions pub use expr::{parse_expr, parse_expr_with_dialect}; -pub use relation::{parse, parse_with_dialect}; +pub use relation::{parse, parse_with_dialect, tables_prefix}; #[cfg(test)] mod tests { diff --git a/src/sql/relation.rs b/src/sql/relation.rs index edbedc61..ccc41c65 100644 --- a/src/sql/relation.rs +++ b/src/sql/relation.rs @@ -914,6 +914,58 @@ impl<'a, T: QueryToRelationTranslator + Copy + Clone> TryFrom<(QueryWithRelation } } +/// Implement conversion to Identifier +impl<'a, T: QueryToRelationTranslator + Copy + Clone> TryFrom<(&'a ast::ObjectName, T)> + for Identifier +{ + type Error = Error; + 
+ fn try_from(value: (&'a ast::ObjectName, T)) -> result::Result { + let (object_name, translator) = value; + let checked_identifiers: Result> = object_name + .0 + .iter() + .map(|i| { + let binding = translator.try_identifier(i)?; + let head = binding.head()?; + Ok(head.to_string()) + }) + .collect(); + let checked_identifiers = checked_identifiers?; + Ok(checked_identifiers + .iter() + .fold(Identifier::empty(), |acc, x| acc.with(x.to_string()))) + } +} + +struct FullyQualifiedTableNames(Vec); + +impl<'a, T: QueryToRelationTranslator + Copy + Clone> TryFrom<(&'a str, T)> + for FullyQualifiedTableNames +{ + type Error = Error; + + fn try_from(value: (&str, T)) -> result::Result { + let (query, translator) = value; + let parsed_query = parse_with_dialect(query, translator.dialect())?; + // Visit the query to get query names + let query_names = parsed_query.accept(IntoQueryNamesVisitor); + // get only base tables names + let tables_names: Result> = query_names + .iter() + .filter_map(|((_, object_name), pointed_query)| { + if pointed_query.is_none() { + Some(object_name) + } else { + None + } + }) + .map(|object_name| Identifier::try_from((object_name, translator))) + .collect(); + Ok(FullyQualifiedTableNames(tables_names?)) + } +} + /// It creates a new hierarchy with Identifier for which the last part of their /// path is not ambiguous. The new hierarchy will contain one-element paths fn last(columns: &Hierarchy) -> Hierarchy { @@ -953,20 +1005,95 @@ pub fn parse(query: &str) -> Result { parse_with_dialect(query, GenericDialect) } +/// Given a query and a translator it provides the prefix of base tables in the +/// query. It fails if tables identifiers are wrongly quoted. +pub fn tables_prefix( + query: &str, + translator: T, +) -> Result> { + let tables = FullyQualifiedTableNames::try_from((query, translator))?; + tables.0.iter().map(|f| Ok(f.head()?.to_string())).collect() +} + #[cfg(test)] mod tests { - use std::sync::Arc; - - use colored::Colorize; - use super::*; use crate::{ builder::Ready, data_type::{DataType, DataTyped, Variant}, + dialect_translation::{bigquery::BigQueryTranslator, mssql::MsSqlTranslator}, display::Dot, io::{postgresql, Database}, relation::schema::Schema, }; + use colored::Colorize; + use std::sync::Arc; + + #[test] + fn test_tables_prefix() { + let query_str = "SELECT * FROM my_db.my_sch.my_tab"; + let tables = tables_prefix(query_str, PostgreSqlTranslator).unwrap(); + assert_eq!(tables, vec!["my_db".to_string()]); + + let query_str = r#"SELECT * FROM "my_db"."my_sch"."my_tab""#; + let tables = tables_prefix(query_str, PostgreSqlTranslator).unwrap(); + assert_eq!(tables, vec!["my_db".to_string()]); + + let query_str = r#"SELECT * FROM `my_db`.`my_sch`.`my_tab`"#; + let tables = tables_prefix(query_str, PostgreSqlTranslator); + + assert!(tables.is_err()); + + let query_str = r#"SELECT * FROM `my_db`.`my_sch`.`my_tab`"#; + let tables = tables_prefix(query_str, MsSqlTranslator); + assert!(tables.is_err()); + + let query_str = r#"SELECT * FROM `my_db`.`my_sch`.`my_tab`"#; + let tables = tables_prefix(query_str, BigQueryTranslator).unwrap(); + assert_eq!(tables, vec!["my_db".to_string()]); + + let query_str = r#"SELECT * FROM "my_db"."my_sch"."my_tab""#; + let tables = tables_prefix(query_str, MsSqlTranslator).unwrap(); + assert_eq!(tables, vec!["my_db".to_string()]); + + let query_str = r#"SELECT * FROM [my_db].[my_sch].[my_tab]"#; + let tables = tables_prefix(query_str, MsSqlTranslator).unwrap(); + assert_eq!(tables, vec!["my_db".to_string()]); + + let query_str 
= "SELECT * FROM a.b.c.d.e.f"; + let tables = tables_prefix(query_str, PostgreSqlTranslator).unwrap(); + assert_eq!(tables, vec!["a".to_string()]); + + let query_str = "SELECT * FROM (SELECT * FROM my_db.my_sch.my_tab) AS t1"; + let tables = tables_prefix(query_str, PostgreSqlTranslator).unwrap(); + assert_eq!(tables, vec!["my_db".to_string()]); + + let query_str = r#" + SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM my_db.my_sch.my_tab) AS my_tab1) AS my_tab2) AS my_tab3) AS my_tab4 + "#; + let tables = tables_prefix(query_str, PostgreSqlTranslator).unwrap(); + assert_eq!(tables, vec!["my_db".to_string()]); + + let query_str = r#" + WITH my_tab1 AS (SELECT * FROM my_db.my_sch.my_tab), + my_tab2 AS (SELECT * FROM my_tab1), + my_tab3 AS (SELECT * FROM my_tab2), + my_tab4 AS (SELECT * FROM my_tab3) + SELECT * FROM my_tab4 + "#; + let tables = tables_prefix(query_str, PostgreSqlTranslator).unwrap(); + assert_eq!(tables, vec!["my_db".to_string()]); + + let query_str = r#" + WITH my_tab1 AS (SELECT * FROM my_db.my_sch.my_tab), + my_tab2 AS (SELECT * FROM my_tab1), + my_tab3 AS (SELECT * FROM my_tab2), + my_tab4 AS (SELECT * FROM my_other_db.my_other_sch.my_other_tab) + SELECT * FROM my_tab3 JOIN my_tab4 USING(id) + "#; + let tables = tables_prefix(query_str, PostgreSqlTranslator).unwrap(); + assert_eq!(tables, vec!["my_db".to_string(), "my_other_db".to_string()]); + } #[test] fn test_map_from_query() { diff --git a/tests/integration.rs b/tests/integration.rs index 6263fb85..35c91adb 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -107,11 +107,7 @@ const QUERIES: &[&str] = &[ SELECT * FROM t1 INNER JOIN t2 ON t1.d = t2.x INNER JOIN table_2 ON t1.d=table_2.x ORDER BY t1.a, t2.x, t2.y, t2.z LIMIT 17", "WITH t1 AS (SELECT a,d FROM table_1), t2 AS (SELECT * FROM table_2) - SELECT * FROM t1 INNER JOIN t2 ON t1.d = t2.x INNER JOIN table_2 ON t1.d=table_2.x ORDER BY t1.a, t2.x, t2.y, t2.z OFFSET 5", - "WITH t1 AS (SELECT a,d FROM table_1), - t2 AS (SELECT * FROM table_2) SELECT * FROM t1 INNER JOIN t2 ON t1.d = t2.x INNER JOIN table_2 ON t1.d=table_2.x ORDER BY t1.a, t2.x, t2.y, t2.z LIMIT 17 OFFSET 5", - "SELECT CASE a WHEN 5 THEN 0 ELSE a END FROM table_1", "SELECT CASE WHEN a < 5 THEN 0 WHEN a < 3 THEN 3 ELSE a END FROM table_1", "SELECT CASE WHEN a < 5 THEN 0 WHEN a < 3 THEN 3 END FROM table_1", @@ -136,25 +132,28 @@ const QUERIES: &[&str] = &[ "SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY my_count", "SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY my_count", // DISTINCT - "SELECT DISTINCT COUNT(*) FROM table_1 GROUP BY d", - "SELECT DISTINCT c, d FROM table_1", + "SELECT DISTINCT COUNT(*) FROM table_1 GROUP BY d", // fails with sqlite + "SELECT DISTINCT c, d FROM table_1", // fails with sqlite "SELECT c, COUNT(DISTINCT d) AS count_d, SUM(DISTINCT d) AS sum_d FROM table_1 GROUP BY c ORDER BY c", "SELECT SUM(DISTINCT a) AS s1 FROM table_1 GROUP BY c HAVING COUNT(*) > 5;", // using joins "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 INNER JOIN t2 USING(a)", "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 LEFT JOIN t2 USING(a)", "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 RIGHT JOIN t2 USING(a)", - "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT 
* FROM t1 FULL JOIN t2 USING(a)", // natural joins - "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL INNER JOIN t2", - "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL LEFT JOIN t2", - "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL RIGHT JOIN t2", - "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL FULL JOIN t2", - "SELECT a, SUM(a) FROM table_1 GROUP BY a" + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL INNER JOIN t2", // fails with sqlite + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL LEFT JOIN t2", // fails with sqlite + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL RIGHT JOIN t2", // fails with sqlite + "SELECT a, SUM(a) FROM table_1 GROUP BY a", + "SELECT SUBSTRING(z FROM 1 FOR 2) AS m, COUNT(*) AS my_count FROM table_2 GROUP BY z;", // fails with sqlite ]; #[cfg(feature = "sqlite")] -const SQLITE_QUERIES: &[&str] = &["SELECT AVG(b) as n, count(b) as d FROM table_1"]; +const SQLITE_QUERIES: &[&str] = &[ + "SELECT AVG(b) as n, count(b) as d FROM table_1", + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 FULL JOIN t2 USING(a)", + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL FULL JOIN t2", +]; #[cfg(feature = "sqlite")] #[test] @@ -165,7 +164,7 @@ fn test_on_sqlite() { println!("schema {} = {}", tab, tab.schema()); } for &query in SQLITE_QUERIES.iter().chain(QUERIES) { - assert!(test_rewritten_eq(&mut database, query)); + assert!(test_rewritten_eq(&mut database, query)) } } // This should work: https://www.db-fiddle.com/f/ouKSHjkEk29zWY5PN2YmjZ/10 @@ -177,11 +176,15 @@ const POSTGRESQL_QUERIES: &[&str] = &[ "SELECT CONCAT(x,y,z) FROM table_2 LIMIT 11", "SELECT CHAR_LENGTH(z) AS char_length FROM table_2 LIMIT 1", "SELECT POSITION('o' IN z) AS char_length FROM table_2 LIMIT 5", - "SELECT SUBSTRING(z FROM 1 FOR 2) AS m, COUNT(*) AS my_count FROM table_2 GROUP BY z;", "SELECT z AS age1, SUM(x) AS s1 FROM table_2 WHERE z IS NOT NULL GROUP BY z;", "SELECT COUNT(*) AS c1 FROM table_2 WHERE y ILIKE '%ab%';", "SELECT z, CASE WHEN z IS Null THEN 'Null' ELSE 'NotNull' END AS case_age, COUNT(*) AS c1 FROM table_2 GROUP BY z;", r#"SELECT "Id", NORMAL_COL, "Na.Me" FROM "MY SPECIAL TABLE""#, + "WITH t1 AS (SELECT a,d FROM table_1), + t2 AS (SELECT * FROM table_2) + SELECT * FROM t1 INNER JOIN t2 ON t1.d = t2.x INNER JOIN table_2 ON t1.d=table_2.x ORDER BY t1.a, t2.x, t2.y, t2.z OFFSET 5", + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 FULL JOIN t2 USING(a)", + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL FULL JOIN t2", // This fails consistency tests due to numeric errors. 
It could be fixed with Round // but in psql round(arg, precision) fails if arg is a double precision type // "SELECT @@ -210,83 +213,16 @@ fn test_on_postgresql() { } #[cfg(feature = "mssql")] -const MSSQL_QUERIES: &[&str] = &[ - "SELECT RANDOM(), * FROM table_2", - "SELECT AVG(x) as a FROM table_2", - "SELECT 1+count(y) as a, sum(1+x) as b FROM table_2", - "SELECT 1+SUM(a), count(b) FROM table_1", - // Some WHERE - "SELECT 1+SUM(a), count(b) FROM table_1 WHERE a>4", - "SELECT SUM(a), count(b) FROM table_1 WHERE a>4", - // Some GROUP BY - "SELECT 1+SUM(a), count(b) FROM table_1 GROUP BY d", - "SELECT count(b) FROM table_1 GROUP BY CEIL(d)", - "SELECT CEIL(d) AS d_ceiled, count(b) FROM table_1 GROUP BY CEIL(d)", - // "SELECT CEIL(d) AS d_ceiled, count(b) FROM table_1 GROUP BY d_ceiled", - // Some WHERE and GROUP BY - "SELECT 1+SUM(a), count(b) FROM table_1 WHERE d>4 GROUP BY d", - "SELECT 1+SUM(a), count(b), d FROM table_1 GROUP BY d", - "SELECT sum(a) FROM table_1 JOIN table_2 ON table_1.d = table_2.x", - "WITH t1 AS (SELECT a,d FROM table_1), - t2 AS (SELECT * FROM table_2) - SELECT sum(a) FROM t1 JOIN t2 ON t1.d = t2.x", - "WITH t1 AS (SELECT a,d FROM table_1 WHERE a>4), - t2 AS (SELECT * FROM table_2) - SELECT max(a), sum(d) FROM t1 INNER JOIN t2 ON t1.d = t2.x CROSS JOIN table_2", - // The ORDER BY clause is invalid in views, inline functions, derived tables, subqueries, and common table expressions, unless TOP, OFFSET or FOR XML is also specified. - // "WITH t1 AS (SELECT a,d FROM table_1), - // t2 AS (SELECT * FROM table_2) - // SELECT * FROM t1 INNER JOIN t2 ON t1.d = t2.x INNER JOIN table_2 ON t1.d=table_2.x ORDER BY t1.a, t2.x, t2.y, t2.z", - // Test LIMIT - // Test LIMIT - "WITH t1 AS (SELECT a,d FROM table_1), - t2 AS (SELECT * FROM table_2) - SELECT * FROM t1 INNER JOIN t2 ON t1.d = t2.x INNER JOIN table_2 ON t1.d=table_2.x ORDER BY t1.a, t2.x, t2.y, t2.z LIMIT 17", - "SELECT CASE a WHEN 5 THEN 0 ELSE a END FROM table_1", - "SELECT CASE WHEN a < 5 THEN 0 WHEN a < 3 THEN 3 ELSE a END FROM table_1", - "SELECT CASE WHEN a < 5 THEN 0 WHEN a < 3 THEN 3 END FROM table_1", - // Test UNION - "SELECT 1*a FROM table_1 UNION SELECT 1*x FROM table_2", - // Test no UNION with CTEs - "WITH t1 AS (SELECT a,d FROM table_1), - t2 AS (SELECT x,y FROM table_2) - SELECT * FROM t1", - // Test UNION with CTEs - "WITH t1 AS (SELECT 1*a, 1*d FROM table_1), - t2 AS (SELECT 0.1*x as a, 2*x as b FROM table_2) - SELECT * FROM t1 UNION SELECT * FROM t2", +const PSQL_QUERIES_FOR_MSSQL_DB: &[&str] = &[ // Some joins - "SELECT * FROM order_table LEFT JOIN item_table on id=order_id WHERE price>10", - "SELECT SUBSTRING(z FROM 1 FOR 2) AS m, COUNT(*) AS my_count FROM table_2 GROUP BY z;", "SELECT z AS age1, SUM(x) AS s1 FROM table_2 WHERE z IS NOT NULL GROUP BY z;", "SELECT z, CASE WHEN z IS Null THEN 0 ELSE 1 END AS case_age, COUNT(*) AS c1 FROM table_2 GROUP BY z;", "SELECT z, CASE WHEN z IS Null THEN CAST('A' AS VARCHAR(10)) ELSE CAST('B' AS VARCHAR(10)) END AS case_age, COUNT(*) AS c1 FROM table_2 GROUP BY z;", - // Some string functions - //"SELECT UPPER(z) FROM table_2 LIMIT 5", - //"SELECT LOWER(z) FROM table_2 LIMIT 5", - // ORDER BY - // The ORDER BY clause is invalid in views, inline functions, derived tables, subqueries, and common table expressions, unless TOP, OFFSET or FOR XML is also specified. 
- //"SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY d", - //"SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY d DESC", - //"SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY my_count", - //"SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY my_count", - // DISTINCT - // Some bug somewhere. Error not informative: panicked at src/expr/mod.rs:1029:18: Option::unwrap()` on a `None` value - //"SELECT DISTINCT COUNT(*) FROM table_1 GROUP BY d", - //"SELECT DISTINCT c, d FROM table_1", - //"SELECT c, COUNT(DISTINCT d) AS count_d, SUM(DISTINCT d) AS sum_d FROM table_1 GROUP BY c ORDER BY c", - //"SELECT SUM(DISTINCT a) AS s1 FROM table_1 GROUP BY c HAVING COUNT(*) > 5;", - // using joins - //"WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 INNER JOIN t2 USING(a)", - // "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 LEFT JOIN t2 USING(a)", - // "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 RIGHT JOIN t2 USING(a)", - // "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 FULL JOIN t2 USING(a)", - // // natural joins - // "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL INNER JOIN t2", - // "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL LEFT JOIN t2", - // "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL RIGHT JOIN t2", - // "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL FULL JOIN t2", - r#"SELECT "Id", NORMAL_COL, "Na.Me" FROM "MY SPECIAL TABLE""#, + "WITH t1 AS (SELECT a,d FROM table_1), + t2 AS (SELECT * FROM table_2) + SELECT * FROM t1 INNER JOIN t2 ON t1.d = t2.x INNER JOIN table_2 ON t1.d=table_2.x ORDER BY t1.a, t2.x, t2.y, t2.z OFFSET 5", + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 FULL JOIN t2 USING(a)", + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL FULL JOIN t2", ]; #[cfg(feature = "mssql")] @@ -302,12 +238,29 @@ fn test_on_mssql() { println!("schema {} = {}", tab, tab.schema()); } // TODO We should pass the QUERIES list too - for &query in MSSQL_QUERIES.iter() { + for &query in QUERIES.iter().chain(PSQL_QUERIES_FOR_MSSQL_DB) { println!("TESTING QUERY: {}", query); test_execute(&mut database, query, MsSqlTranslator); } } +#[cfg(feature = "bigquery")] +const PSQL_QUERIES_FOR_BIGQUERY_DB: &[&str] = &[ + "SELECT AVG(b) as n, count(b) as d FROM table_1", + "SELECT MD5(z) FROM table_2 LIMIT 10", + "SELECT CONCAT(x,y,z) FROM table_2 LIMIT 11", + "SELECT CHAR_LENGTH(z) AS char_length FROM table_2 LIMIT 1", + "SELECT z AS age1, SUM(x) AS s1 FROM table_2 WHERE z IS NOT NULL GROUP BY z;", + "SELECT COUNT(*) AS c1 FROM table_2 WHERE y LIKE '%Ba%';", + "SELECT z, CASE WHEN z IS Null THEN 'Null' ELSE 'NotNull' END AS case_age, COUNT(*) AS c1 FROM table_2 GROUP BY z;", + "SELECT RANDOM(), * FROM table_2", + "SELECT z, CASE WHEN z IS Null THEN 
CAST('A' AS VARCHAR(10)) ELSE CAST('B' AS VARCHAR(10)) END AS case_age, COUNT(*) AS c1 FROM table_2 GROUP BY z;", + "SELECT z, CASE WHEN z IS Null THEN 0 ELSE 1 END AS case_age, COUNT(*) AS c1 FROM table_2 GROUP BY z;", + r#"SELECT "Id", NORMAL_COL, "Na.Me" FROM MY_SPECIAL_TABLE"#, + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 FULL JOIN t2 USING(a)", + "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL FULL JOIN t2", +]; + #[cfg(feature = "bigquery")] #[test] fn test_on_bigquery() { @@ -320,94 +273,37 @@ fn test_on_bigquery() { for tab in database.tables() { println!("schema {} = {}", tab, tab.schema()); } - let queries_for_bq = [ - "SELECT AVG(b) as n, count(b) as d FROM table_1", - "SELECT MD5(z) FROM table_2 LIMIT 10", - "SELECT CONCAT(x,y,z) FROM table_2 LIMIT 11", - "SELECT CHAR_LENGTH(z) AS char_length FROM table_2 LIMIT 1", - //"SELECT POSITION('o' IN z) AS char_length FROM table_2 LIMIT 5", - "SELECT SUBSTRING(z FROM 1 FOR 2) AS m, COUNT(*) AS my_count FROM table_2 GROUP BY z;", - "SELECT z AS age1, SUM(x) AS s1 FROM table_2 WHERE z IS NOT NULL GROUP BY z;", - "SELECT COUNT(*) AS c1 FROM table_2 WHERE y LIKE '%Ba%';", - "SELECT z, CASE WHEN z IS Null THEN 'Null' ELSE 'NotNull' END AS case_age, COUNT(*) AS c1 FROM table_2 GROUP BY z;", - "SELECT RANDOM(), * FROM table_2", - "SELECT AVG(x) as a FROM table_2", - "SELECT 1+count(y) as a, sum(1+x) as b FROM table_2", - "SELECT 1+SUM(a), count(b) FROM table_1", - // Some WHERE - "SELECT 1+SUM(a), count(b) FROM table_1 WHERE a>4", - "SELECT SUM(a), count(b) FROM table_1 WHERE a>4", - // Some GROUP BY - "SELECT 1+SUM(a), count(b) FROM table_1 GROUP BY d", - "SELECT count(b) FROM table_1 GROUP BY CEIL(d)", - "SELECT CEIL(d) AS d_ceiled, count(b) FROM table_1 GROUP BY CEIL(d)", - // "SELECT CEIL(d) AS d_ceiled, count(b) FROM table_1 GROUP BY d_ceiled", - // Some WHERE and GROUP BY - "SELECT 1+SUM(a), count(b) FROM table_1 WHERE d>4 GROUP BY d", - "SELECT 1+SUM(a), count(b), d FROM table_1 GROUP BY d", - "SELECT sum(a) FROM table_1 JOIN table_2 ON table_1.d = table_2.x", - "WITH t1 AS (SELECT a,d FROM table_1), - t2 AS (SELECT * FROM table_2) - SELECT sum(a) FROM t1 JOIN t2 ON t1.d = t2.x", - "WITH t1 AS (SELECT a,d FROM table_1 WHERE a>4), - t2 AS (SELECT * FROM table_2) - SELECT max(a), sum(d) FROM t1 INNER JOIN t2 ON t1.d = t2.x CROSS JOIN table_2", - "WITH t1 AS (SELECT a,d FROM table_1), - t2 AS (SELECT * FROM table_2) - SELECT * FROM t1 INNER JOIN t2 ON t1.d = t2.x INNER JOIN table_2 ON t1.d=table_2.x ORDER BY t1.a, t2.x, t2.y, t2.z", - // Test LIMIT - "WITH t1 AS (SELECT a,d FROM table_1), - t2 AS (SELECT * FROM table_2) - SELECT * FROM t1 INNER JOIN t2 ON t1.d = t2.x INNER JOIN table_2 ON t1.d=table_2.x ORDER BY t1.a, t2.x, t2.y, t2.z LIMIT 17", - "SELECT CASE a WHEN 5 THEN 0 ELSE a END FROM table_1", - "SELECT CASE WHEN a < 5 THEN 0 WHEN a < 3 THEN 3 ELSE a END FROM table_1", - "SELECT CASE WHEN a < 5 THEN 0 WHEN a < 3 THEN 3 END FROM table_1", - // Test UNION - // "SELECT 1*a FROM table_1 UNION SELECT 1*x FROM table_2", - // Test no UNION with CTEs - "WITH t1 AS (SELECT a,d FROM table_1), - t2 AS (SELECT x,y FROM table_2) - SELECT * FROM t1", - // Test UNION with CTEs - // "WITH t1 AS (SELECT 1*a, 1*d FROM table_1), - // t2 AS (SELECT 0.1*x as a, 2*x as b FROM table_2) - // SELECT * FROM t1 UNION SELECT * FROM t2", - // Some joins - "SELECT * FROM order_table LEFT JOIN 
item_table on id=order_id WHERE price>10", - "SELECT SUBSTRING(z FROM 1 FOR 2) AS m, COUNT(*) AS my_count FROM table_2 GROUP BY z;", - "SELECT z AS age1, SUM(x) AS s1 FROM table_2 WHERE z IS NOT NULL GROUP BY z;", - "SELECT z, CASE WHEN z IS Null THEN 0 ELSE 1 END AS case_age, COUNT(*) AS c1 FROM table_2 GROUP BY z;", - "SELECT z, CASE WHEN z IS Null THEN CAST('A' AS VARCHAR(10)) ELSE CAST('B' AS VARCHAR(10)) END AS case_age, COUNT(*) AS c1 FROM table_2 GROUP BY z;", - "SELECT UPPER(z) FROM table_2 LIMIT 5", - "SELECT LOWER(z) FROM table_2 LIMIT 5", - // ORDER BY - "SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY d", - "SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY d DESC", - "SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY my_count", - "SELECT d, COUNT(*) AS my_count FROM table_1 GROUP BY d ORDER BY my_count", - // DISTINCT - "SELECT DISTINCT COUNT(*) FROM table_1 GROUP BY d", - "SELECT DISTINCT c, d FROM table_1", - "SELECT c, COUNT(DISTINCT d) AS count_d, SUM(DISTINCT d) AS sum_d FROM table_1 GROUP BY c ORDER BY c", - "SELECT SUM(DISTINCT a) AS s1 FROM table_1 GROUP BY c HAVING COUNT(*) > 5;", - // using joins - "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 INNER JOIN t2 USING(a)", - "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 LEFT JOIN t2 USING(a)", - "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 RIGHT JOIN t2 USING(a)", - "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7) SELECT * FROM t1 FULL JOIN t2 USING(a)", - // natural joins - "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL INNER JOIN t2", - "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL LEFT JOIN t2", - "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL RIGHT JOIN t2", - "WITH t1 AS (SELECT a, b, c FROM table_1 WHERE a > 5), t2 AS (SELECT a, d, c FROM table_1 WHERE a < 7 LIMIT 10) SELECT * FROM t1 NATURAL FULL JOIN t2", - r#"SELECT "Id", NORMAL_COL, "Na.Me" FROM MY_SPECIAL_TABLE"#, - ]; - for &query in queries_for_bq.iter() { + + for &query in QUERIES.iter().chain(PSQL_QUERIES_FOR_BIGQUERY_DB) { println!("TESTING QUERY: {}", query); test_execute(&mut database, query, BigQueryTranslator); } } +#[cfg(feature = "mysql")] +const PSQL_QUERIES_FOR_MYSQL_DB: &[&str] = &[ + "SELECT CAST(d AS INTEGER) FROM table_1", + "SELECT EXTRACT(EPOCH FROM c) FROM table_1", + "SELECT CAST(d AS TEXT) FROM table_1", +]; + +#[cfg(feature = "mysql")] +#[test] +fn test_on_mysql() { + use qrlew::{dialect_translation::mysql::MySqlTranslator, io::mysql}; + + let mut database = mysql::test_database(); + println!("database {} = {}", database.name(), database.relations()); + for tab in database.tables() { + println!("schema {} = {}", tab, tab.schema()); + } + + for &query in PSQL_QUERIES_FOR_MYSQL_DB.iter().chain(QUERIES) { + println!("TESTING QUERY: {}", query); + test_execute(&mut database, query, MySqlTranslator); + } +} + #[test] fn test_distinct_aggregates() { let mut database = postgresql::test_database();