diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index ed0cfddc3827..7eca326386fc 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -15,8 +15,14 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + +use arrow_schema::TimeUnit; use regex::Regex; -use sqlparser::{ast, keywords::ALL_KEYWORDS}; +use sqlparser::{ + ast::{self, Ident, ObjectName, TimezoneInfo}, + keywords::ALL_KEYWORDS, +}; /// `Dialect` to use for Unparsing /// @@ -36,8 +42,8 @@ pub trait Dialect: Send + Sync { true } - // Does the dialect use TIMESTAMP to represent Date64 rather than DATETIME? - // E.g. Trino, Athena and Dremio does not have DATETIME data type + /// Does the dialect use TIMESTAMP to represent Date64 rather than DATETIME? + /// E.g. Trino, Athena and Dremio does not have DATETIME data type fn use_timestamp_for_date64(&self) -> bool { false } @@ -46,23 +52,50 @@ pub trait Dialect: Send + Sync { IntervalStyle::PostgresVerbose } - // Does the dialect use DOUBLE PRECISION to represent Float64 rather than DOUBLE? - // E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE + /// Does the dialect use DOUBLE PRECISION to represent Float64 rather than DOUBLE? + /// E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { sqlparser::ast::DataType::Double } - // The SQL type to use for Arrow Utf8 unparsing - // Most dialects use VARCHAR, but some, like MySQL, require CHAR + /// The SQL type to use for Arrow Utf8 unparsing + /// Most dialects use VARCHAR, but some, like MySQL, require CHAR fn utf8_cast_dtype(&self) -> ast::DataType { ast::DataType::Varchar(None) } - // The SQL type to use for Arrow LargeUtf8 unparsing - // Most dialects use TEXT, but some, like MySQL, require CHAR + /// The SQL type to use for Arrow LargeUtf8 unparsing + /// Most dialects use TEXT, but some, like MySQL, require CHAR fn large_utf8_cast_dtype(&self) -> ast::DataType { ast::DataType::Text } + + /// The date field extract style to use: `DateFieldExtractStyle` + fn date_field_extract_style(&self) -> DateFieldExtractStyle { + DateFieldExtractStyle::DatePart + } + + /// The SQL type to use for Arrow Int64 unparsing + /// Most dialects use BigInt, but some, like MySQL, require SIGNED + fn int64_cast_dtype(&self) -> ast::DataType { + ast::DataType::BigInt(None) + } + + /// The SQL type to use for Timestamp unparsing + /// Most dialects use Timestamp, but some, like MySQL, require Datetime + /// Some dialects like Dremio does not support WithTimeZone and requires always Timestamp + fn timestamp_cast_dtype( + &self, + _time_unit: &TimeUnit, + tz: &Option>, + ) -> ast::DataType { + let tz_info = match tz { + Some(_) => TimezoneInfo::WithTimeZone, + None => TimezoneInfo::None, + }; + + ast::DataType::Timestamp(None, tz_info) + } } /// `IntervalStyle` to use for unparsing @@ -80,6 +113,19 @@ pub enum IntervalStyle { MySQL, } +/// Datetime subfield extraction style for unparsing +/// +/// `` +/// Different DBMSs follow different standards; popular ones are: +/// date_part('YEAR', date '2001-02-16') +/// EXTRACT(YEAR from date '2001-02-16') +/// Some DBMSs, like Postgres, support both, whereas others like MySQL require EXTRACT. +#[derive(Clone, Copy, PartialEq)] +pub enum DateFieldExtractStyle { + DatePart, + Extract, +} + pub struct DefaultDialect {} impl Dialect for DefaultDialect { @@ -133,6 +179,22 @@ impl Dialect for MySqlDialect { fn large_utf8_cast_dtype(&self) -> ast::DataType { ast::DataType::Char(None) } + + fn date_field_extract_style(&self) -> DateFieldExtractStyle { + DateFieldExtractStyle::Extract + } + + fn int64_cast_dtype(&self) -> ast::DataType { + ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![]) + } + + fn timestamp_cast_dtype( + &self, + _time_unit: &TimeUnit, + _tz: &Option>, + ) -> ast::DataType { + ast::DataType::Datetime(None) + } } pub struct SqliteDialect {} @@ -151,6 +213,10 @@ pub struct CustomDialect { float64_ast_dtype: sqlparser::ast::DataType, utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, + date_field_extract_style: DateFieldExtractStyle, + int64_cast_dtype: ast::DataType, + timestamp_cast_dtype: ast::DataType, + timestamp_tz_cast_dtype: ast::DataType, } impl Default for CustomDialect { @@ -163,6 +229,13 @@ impl Default for CustomDialect { float64_ast_dtype: sqlparser::ast::DataType::Double, utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, + date_field_extract_style: DateFieldExtractStyle::DatePart, + int64_cast_dtype: ast::DataType::BigInt(None), + timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), + timestamp_tz_cast_dtype: ast::DataType::Timestamp( + None, + TimezoneInfo::WithTimeZone, + ), } } } @@ -206,6 +279,26 @@ impl Dialect for CustomDialect { fn large_utf8_cast_dtype(&self) -> ast::DataType { self.large_utf8_cast_dtype.clone() } + + fn date_field_extract_style(&self) -> DateFieldExtractStyle { + self.date_field_extract_style + } + + fn int64_cast_dtype(&self) -> ast::DataType { + self.int64_cast_dtype.clone() + } + + fn timestamp_cast_dtype( + &self, + _time_unit: &TimeUnit, + tz: &Option>, + ) -> ast::DataType { + if tz.is_some() { + self.timestamp_tz_cast_dtype.clone() + } else { + self.timestamp_cast_dtype.clone() + } + } } /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern @@ -230,6 +323,10 @@ pub struct CustomDialectBuilder { float64_ast_dtype: sqlparser::ast::DataType, utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, + date_field_extract_style: DateFieldExtractStyle, + int64_cast_dtype: ast::DataType, + timestamp_cast_dtype: ast::DataType, + timestamp_tz_cast_dtype: ast::DataType, } impl Default for CustomDialectBuilder { @@ -248,6 +345,13 @@ impl CustomDialectBuilder { float64_ast_dtype: sqlparser::ast::DataType::Double, utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, + date_field_extract_style: DateFieldExtractStyle::DatePart, + int64_cast_dtype: ast::DataType::BigInt(None), + timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), + timestamp_tz_cast_dtype: ast::DataType::Timestamp( + None, + TimezoneInfo::WithTimeZone, + ), } } @@ -260,6 +364,10 @@ impl CustomDialectBuilder { float64_ast_dtype: self.float64_ast_dtype, utf8_cast_dtype: self.utf8_cast_dtype, large_utf8_cast_dtype: self.large_utf8_cast_dtype, + date_field_extract_style: self.date_field_extract_style, + int64_cast_dtype: self.int64_cast_dtype, + timestamp_cast_dtype: self.timestamp_cast_dtype, + timestamp_tz_cast_dtype: self.timestamp_tz_cast_dtype, } } @@ -293,6 +401,7 @@ impl CustomDialectBuilder { self } + /// Customize the dialect with a specific SQL type for Float64 casting: DOUBLE, DOUBLE PRECISION, etc. pub fn with_float64_ast_dtype( mut self, float64_ast_dtype: sqlparser::ast::DataType, @@ -301,11 +410,13 @@ impl CustomDialectBuilder { self } + /// Customize the dialect with a specific SQL type for Utf8 casting: VARCHAR, CHAR, etc. pub fn with_utf8_cast_dtype(mut self, utf8_cast_dtype: ast::DataType) -> Self { self.utf8_cast_dtype = utf8_cast_dtype; self } + /// Customize the dialect with a specific SQL type for LargeUtf8 casting: TEXT, CHAR, etc. pub fn with_large_utf8_cast_dtype( mut self, large_utf8_cast_dtype: ast::DataType, @@ -313,4 +424,30 @@ impl CustomDialectBuilder { self.large_utf8_cast_dtype = large_utf8_cast_dtype; self } + + /// Customize the dialect with a specific date field extract style listed in `DateFieldExtractStyle` + pub fn with_date_field_extract_style( + mut self, + date_field_extract_style: DateFieldExtractStyle, + ) -> Self { + self.date_field_extract_style = date_field_extract_style; + self + } + + /// Customize the dialect with a specific SQL type for Int64 casting: BigInt, SIGNED, etc. + pub fn with_int64_cast_dtype(mut self, int64_cast_dtype: ast::DataType) -> Self { + self.int64_cast_dtype = int64_cast_dtype; + self + } + + /// Customize the dialect with a specific SQL type for Timestamp casting: Timestamp, Datetime, etc. + pub fn with_timestamp_cast_dtype( + mut self, + timestamp_cast_dtype: ast::DataType, + timestamp_tz_cast_dtype: ast::DataType, + ) -> Self { + self.timestamp_cast_dtype = timestamp_cast_dtype; + self.timestamp_tz_cast_dtype = timestamp_tz_cast_dtype; + self + } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 2f7854c1a183..f4ea44f37d78 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -16,6 +16,13 @@ // under the License. use core::fmt; + +use datafusion_expr::ScalarUDF; +use sqlparser::ast::Value::SingleQuotedString; +use sqlparser::ast::{ + self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, Interval, + TimezoneInfo, UnaryOperator, +}; use std::sync::Arc; use std::{fmt::Display, vec}; @@ -28,12 +35,6 @@ use arrow_array::types::{ }; use arrow_array::{Date32Array, Date64Array, PrimitiveArray}; use arrow_schema::DataType; -use sqlparser::ast::Value::SingleQuotedString; -use sqlparser::ast::{ - self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, Interval, - TimezoneInfo, UnaryOperator, -}; - use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, Column, Result, ScalarValue, @@ -43,7 +44,7 @@ use datafusion_expr::{ Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast, }; -use super::dialect::IntervalStyle; +use super::dialect::{DateFieldExtractStyle, IntervalStyle}; use super::Unparser; /// DataFusion's Exprs can represent either an `Expr` or an `OrderByExpr` @@ -149,6 +150,12 @@ impl Unparser<'_> { Expr::ScalarFunction(ScalarFunction { func, args }) => { let func_name = func.name(); + if let Some(expr) = + self.scalar_function_to_sql_overrides(func_name, func, args) + { + return Ok(expr); + } + let args = args .iter() .map(|e| { @@ -545,6 +552,38 @@ impl Unparser<'_> { } } + fn scalar_function_to_sql_overrides( + &self, + func_name: &str, + _func: &Arc, + args: &[Expr], + ) -> Option { + if func_name.to_lowercase() == "date_part" + && self.dialect.date_field_extract_style() == DateFieldExtractStyle::Extract + && args.len() == 2 + { + let date_expr = self.expr_to_sql(&args[1]).ok()?; + + if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &args[0] { + let field = match field.to_lowercase().as_str() { + "year" => ast::DateTimeField::Year, + "month" => ast::DateTimeField::Month, + "day" => ast::DateTimeField::Day, + "hour" => ast::DateTimeField::Hour, + "minute" => ast::DateTimeField::Minute, + "second" => ast::DateTimeField::Second, + _ => return None, + }; + + return Some(ast::Expr::Extract { + field, + expr: Box::new(date_expr), + }); + } + } + None + } + fn ast_type_for_date64_in_cast(&self) -> ast::DataType { if self.dialect.use_timestamp_for_date64() { ast::DataType::Timestamp(None, ast::TimezoneInfo::None) @@ -1105,6 +1144,131 @@ impl Unparser<'_> { } } + /// MySQL requires INTERVAL sql to be in the format: INTERVAL 1 YEAR + INTERVAL 1 MONTH + INTERVAL 1 DAY etc + /// `` + /// Interval sequence can't be wrapped in brackets - (INTERVAL 1 YEAR + INTERVAL 1 MONTH ...) so we need to generate + /// a single INTERVAL expression so it works correct for interval substraction cases + /// MySQL supports the DAY_MICROSECOND unit type (format is DAYS HOURS:MINUTES:SECONDS.MICROSECONDS), but it is not supported by sqlparser + /// so we calculate the best single interval to represent the provided duration + fn interval_to_mysql_expr( + &self, + months: i32, + days: i32, + microseconds: i64, + ) -> Result { + // MONTH only + if months != 0 && days == 0 && microseconds == 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + months.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Month), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } else if months != 0 { + return not_impl_err!("Unsupported Interval scalar with both Month and DayTime for IntervalStyle::MySQL"); + } + + // DAY only + if microseconds == 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + days.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Day), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + // calculate the best single interval to represent the provided days and microseconds + + let microseconds = microseconds + (days as i64 * 24 * 60 * 60 * 1_000_000); + + if microseconds % 1_000_000 != 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + microseconds.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Microsecond), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + let secs = microseconds / 1_000_000; + + if secs % 60 != 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + secs.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Second), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + let mins = secs / 60; + + if mins % 60 != 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + mins.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Minute), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + let hours = mins / 60; + + if hours % 24 != 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + hours.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Hour), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + let days = hours / 24; + + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + days.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Day), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + Ok(ast::Expr::Interval(interval)) + } + fn interval_scalar_to_sql(&self, v: &ScalarValue) -> Result { match self.dialect.interval_style() { IntervalStyle::PostgresVerbose => { @@ -1127,10 +1291,7 @@ impl Unparser<'_> { } // If the interval standard is SQLStandard, implement a simple unparse logic IntervalStyle::SQLStandard => match v { - ScalarValue::IntervalYearMonth(v) => { - let Some(v) = v else { - return Ok(ast::Expr::Value(ast::Value::Null)); - }; + ScalarValue::IntervalYearMonth(Some(v)) => { let interval = Interval { value: Box::new(ast::Expr::Value( ast::Value::SingleQuotedString(v.to_string()), @@ -1142,10 +1303,7 @@ impl Unparser<'_> { }; Ok(ast::Expr::Interval(interval)) } - ScalarValue::IntervalDayTime(v) => { - let Some(v) = v else { - return Ok(ast::Expr::Value(ast::Value::Null)); - }; + ScalarValue::IntervalDayTime(Some(v)) => { let days = v.days; let secs = v.milliseconds / 1_000; let mins = secs / 60; @@ -1168,11 +1326,7 @@ impl Unparser<'_> { }; Ok(ast::Expr::Interval(interval)) } - ScalarValue::IntervalMonthDayNano(v) => { - let Some(v) = v else { - return Ok(ast::Expr::Value(ast::Value::Null)); - }; - + ScalarValue::IntervalMonthDayNano(Some(v)) => { if v.months >= 0 && v.days == 0 && v.nanoseconds == 0 { let interval = Interval { value: Box::new(ast::Expr::Value( @@ -1184,10 +1338,7 @@ impl Unparser<'_> { fractional_seconds_precision: None, }; Ok(ast::Expr::Interval(interval)) - } else if v.months == 0 - && v.days >= 0 - && v.nanoseconds % 1_000_000 == 0 - { + } else if v.months == 0 && v.nanoseconds % 1_000_000 == 0 { let days = v.days; let secs = v.nanoseconds / 1_000_000_000; let mins = secs / 60; @@ -1214,11 +1365,29 @@ impl Unparser<'_> { not_impl_err!("Unsupported IntervalMonthDayNano scalar with both Month and DayTime for IntervalStyle::SQLStandard") } } - _ => Ok(ast::Expr::Value(ast::Value::Null)), + _ => not_impl_err!( + "Unsupported ScalarValue for Interval conversion: {v:?}" + ), + }, + IntervalStyle::MySQL => match v { + ScalarValue::IntervalYearMonth(Some(v)) => { + self.interval_to_mysql_expr(*v, 0, 0) + } + ScalarValue::IntervalDayTime(Some(v)) => { + self.interval_to_mysql_expr(0, v.days, v.milliseconds as i64 * 1_000) + } + ScalarValue::IntervalMonthDayNano(Some(v)) => { + if v.nanoseconds % 1_000 != 0 { + return not_impl_err!( + "Unsupported IntervalMonthDayNano scalar with nanoseconds precision for IntervalStyle::MySQL" + ); + } + self.interval_to_mysql_expr(v.months, v.days, v.nanoseconds / 1_000) + } + _ => not_impl_err!( + "Unsupported ScalarValue for Interval conversion: {v:?}" + ), }, - IntervalStyle::MySQL => { - not_impl_err!("Unsupported interval scalar for IntervalStyle::MySQL") - } } } @@ -1231,7 +1400,7 @@ impl Unparser<'_> { DataType::Int8 => Ok(ast::DataType::TinyInt(None)), DataType::Int16 => Ok(ast::DataType::SmallInt(None)), DataType::Int32 => Ok(ast::DataType::Integer(None)), - DataType::Int64 => Ok(ast::DataType::BigInt(None)), + DataType::Int64 => Ok(self.dialect.int64_cast_dtype()), DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)), DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)), DataType::UInt32 => Ok(ast::DataType::UnsignedInteger(None)), @@ -1241,13 +1410,8 @@ impl Unparser<'_> { } DataType::Float32 => Ok(ast::DataType::Float(None)), DataType::Float64 => Ok(self.dialect.float64_ast_dtype()), - DataType::Timestamp(_, tz) => { - let tz_info = match tz { - Some(_) => TimezoneInfo::WithTimeZone, - None => TimezoneInfo::None, - }; - - Ok(ast::DataType::Timestamp(None, tz_info)) + DataType::Timestamp(time_unit, tz) => { + Ok(self.dialect.timestamp_cast_dtype(time_unit, tz)) } DataType::Date32 => Ok(ast::DataType::Date), DataType::Date64 => Ok(self.ast_type_for_date64_in_cast()), @@ -1335,6 +1499,7 @@ mod tests { use arrow::datatypes::TimeUnit; use arrow::datatypes::{Field, Schema}; use arrow_schema::DataType::Int8; + use ast::ObjectName; use datafusion_common::TableReference; use datafusion_expr::{ case, col, cube, exists, grouping_set, interval_datetime_lit, @@ -1885,6 +2050,11 @@ mod tests { IntervalStyle::SQLStandard, "INTERVAL '1 12:0:0.000' DAY TO SECOND", ), + ( + interval_month_day_nano_lit("-1.5 DAY"), + IntervalStyle::SQLStandard, + "INTERVAL '-1 -12:0:0.000' DAY TO SECOND", + ), ( interval_month_day_nano_lit("1.51234 DAY"), IntervalStyle::SQLStandard, @@ -1949,6 +2119,46 @@ mod tests { IntervalStyle::PostgresVerbose, r#"INTERVAL '1 YEARS 7 MONS 0 DAYS 0 HOURS 0 MINS 0.00 SECS'"#, ), + ( + interval_year_month_lit("1 YEAR 1 MONTH"), + IntervalStyle::MySQL, + r#"INTERVAL 13 MONTH"#, + ), + ( + interval_month_day_nano_lit("1 YEAR -1 MONTH"), + IntervalStyle::MySQL, + r#"INTERVAL 11 MONTH"#, + ), + ( + interval_month_day_nano_lit("15 DAY"), + IntervalStyle::MySQL, + r#"INTERVAL 15 DAY"#, + ), + ( + interval_month_day_nano_lit("-40 HOURS"), + IntervalStyle::MySQL, + r#"INTERVAL -40 HOUR"#, + ), + ( + interval_datetime_lit("-1.5 DAY 1 HOUR"), + IntervalStyle::MySQL, + "INTERVAL -35 HOUR", + ), + ( + interval_datetime_lit("1000000 DAY 1.5 HOUR 10 MINUTE 20 SECOND"), + IntervalStyle::MySQL, + r#"INTERVAL 86400006020 SECOND"#, + ), + ( + interval_year_month_lit("0 DAY 0 HOUR"), + IntervalStyle::MySQL, + r#"INTERVAL 0 DAY"#, + ), + ( + interval_month_day_nano_lit("-1296000000 SECOND"), + IntervalStyle::MySQL, + r#"INTERVAL -15000 DAY"#, + ), ]; for (value, style, expected) in tests { @@ -1994,4 +2204,119 @@ mod tests { } Ok(()) } + + #[test] + fn custom_dialect_with_date_field_extract_style() -> Result<()> { + for (extract_style, unit, expected) in [ + ( + DateFieldExtractStyle::DatePart, + "YEAR", + "date_part('YEAR', x)", + ), + ( + DateFieldExtractStyle::Extract, + "YEAR", + "EXTRACT(YEAR FROM x)", + ), + ( + DateFieldExtractStyle::DatePart, + "MONTH", + "date_part('MONTH', x)", + ), + ( + DateFieldExtractStyle::Extract, + "MONTH", + "EXTRACT(MONTH FROM x)", + ), + ( + DateFieldExtractStyle::DatePart, + "DAY", + "date_part('DAY', x)", + ), + (DateFieldExtractStyle::Extract, "DAY", "EXTRACT(DAY FROM x)"), + ] { + let dialect = CustomDialectBuilder::new() + .with_date_field_extract_style(extract_style) + .build(); + + let unparser = Unparser::new(&dialect); + let expr = ScalarUDF::new_from_impl( + datafusion_functions::datetime::date_part::DatePartFunc::new(), + ) + .call(vec![Expr::Literal(ScalarValue::new_utf8(unit)), col("x")]); + + let ast = unparser.expr_to_sql(&expr)?; + let actual = format!("{}", ast); + + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn custom_dialect_with_int64_cast_dtype() -> Result<()> { + let default_dialect = CustomDialectBuilder::new().build(); + let mysql_dialect = CustomDialectBuilder::new() + .with_int64_cast_dtype(ast::DataType::Custom( + ObjectName(vec![Ident::new("SIGNED")]), + vec![], + )) + .build(); + + for (dialect, identifier) in + [(default_dialect, "BIGINT"), (mysql_dialect, "SIGNED")] + { + let unparser = Unparser::new(&dialect); + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Int64, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"CAST(a AS {identifier})"#); + + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn custom_dialect_with_teimstamp_cast_dtype() -> Result<()> { + let default_dialect = CustomDialectBuilder::new().build(); + let mysql_dialect = CustomDialectBuilder::new() + .with_timestamp_cast_dtype( + ast::DataType::Datetime(None), + ast::DataType::Datetime(None), + ) + .build(); + + let timestamp = DataType::Timestamp(TimeUnit::Nanosecond, None); + let timestamp_with_tz = + DataType::Timestamp(TimeUnit::Nanosecond, Some("+08:00".into())); + + for (dialect, data_type, identifier) in [ + (&default_dialect, ×tamp, "TIMESTAMP"), + ( + &default_dialect, + ×tamp_with_tz, + "TIMESTAMP WITH TIME ZONE", + ), + (&mysql_dialect, ×tamp, "DATETIME"), + (&mysql_dialect, ×tamp_with_tz, "DATETIME"), + ] { + let unparser = Unparser::new(dialect); + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: data_type.clone(), + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"CAST(a AS {identifier})"#); + + assert_eq!(actual, expected); + } + Ok(()) + } }