From 9ddb41adfe43510d5572f144a3888bef8aa92964 Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Tue, 14 May 2024 18:35:31 +0200 Subject: [PATCH] Add Cast operator --- diesel/src/expression/cast.rs | 152 ++++++++++++++++++ diesel/src/expression/helper_types.rs | 4 + diesel/src/expression/mod.rs | 1 + .../global_expression_methods.rs | 43 ++++- 4 files changed, 199 insertions(+), 1 deletion(-) create mode 100644 diesel/src/expression/cast.rs diff --git a/diesel/src/expression/cast.rs b/diesel/src/expression/cast.rs new file mode 100644 index 000000000000..c638979fdfd6 --- /dev/null +++ b/diesel/src/expression/cast.rs @@ -0,0 +1,152 @@ +//! SQL `CAST(expr AS sql_type)` expression support + +use crate::expression::{AppearsOnTable, Expression, SelectableExpression, ValidGrouping}; +use crate::query_source::aliasing::{AliasSource, FieldAliasMapper}; +use crate::result::QueryResult; +use crate::{query_builder, query_source, sql_types}; + +use std::marker::PhantomData; + +pub(crate) mod private { + use super::*; + + #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId, sql_types::DieselNumericOps)] + pub struct Cast { + pub(super) expr: E, + pub(super) sql_type: PhantomData, + } +} +pub(crate) use private::Cast; + +impl Cast { + pub(crate) fn new(expr: E) -> Self { + Self { + expr, + sql_type: PhantomData, + } + } +} + +impl ValidGrouping for Cast +where + E: ValidGrouping, +{ + type IsAggregate = E::IsAggregate; +} + +impl SelectableExpression for Cast +where + Cast: AppearsOnTable, + E: SelectableExpression, +{ +} + +impl AppearsOnTable for Cast +where + Cast: Expression, + E: AppearsOnTable, +{ +} + +impl Expression for Cast +where + E: Expression, + ST: sql_types::SingleValue, +{ + type SqlType = ST; +} +impl query_builder::QueryFragment for Cast +where + E: query_builder::QueryFragment, + DB: diesel::backend::Backend, + ST: KnownCastSqlTypeName, +{ + fn walk_ast<'b>(&'b self, mut out: query_builder::AstPass<'_, 'b, DB>) -> QueryResult<()> { + out.push_sql("CAST("); + self.expr.walk_ast(out.reborrow())?; + out.push_sql(" AS "); + out.push_sql(ST::sql_type_name()); + out.push_sql(")"); + Ok(()) + } +} + +/// We know what to write as `sql_type` in the `CAST(expr AS sql_type)` SQL for +/// `Self` +/// +/// That is what is returned by `Self::sql_type_name()` +#[diagnostic::on_unimplemented( + note = "In order to use `CAST`, it is necessary that Diesel knows how to express the name \ + of this type in the given backend.", + note = "This can be PRed into Diesel if the type is a standard SQL type." +)] +pub trait KnownCastSqlTypeName { + /// What to write as `sql_type` in the `CAST(expr AS sql_type)` SQL for + /// `Self` + fn sql_type_name() -> &'static str; +} +impl KnownCastSqlTypeName for sql_types::Nullable +where + ST: KnownCastSqlTypeName, +{ + fn sql_type_name() -> &'static str { + >::sql_type_name() + } +} + +macro_rules! type_name { + ($($backend: ty: $backend_feature: literal { $($type: ident => $val: literal,)+ })*) => { + $( + $( + #[cfg(feature = $backend_feature)] + impl KnownCastSqlTypeName<$backend> for sql_types::$type { + fn sql_type_name() -> &'static str { + $val + } + } + )* + )* + }; +} +type_name! { + diesel::pg::Pg: "postgres_backend" { + Int4 => "int4", + Int8 => "int8", + Text => "text", + } + diesel::mysql::Mysql: "mysql_backend" { + Int4 => "integer", + Int8 => "integer", + Text => "char", + } + diesel::sqlite::Sqlite: "sqlite" { + Int4 => "integer", + Int8 => "bigint", + Text => "text", + } +} + +impl FieldAliasMapper for Cast +where + S: AliasSource, + E: FieldAliasMapper, +{ + type Out = Cast<>::Out, ST>; + fn map(self, alias: &query_source::Alias) -> Self::Out { + Cast { + expr: self.expr.map(alias), + sql_type: self.sql_type, + } + } +} + +/// Marker trait: this SQL type (`Self`) can be casted to the target SQL type +/// (`ST`) using `CAST(expr AS target_sql_type)` +pub trait CastsTo {} +impl CastsTo> for sql_types::Nullable where ST1: CastsTo +{} + +impl CastsTo for sql_types::Int4 {} +impl CastsTo for sql_types::Int8 {} +impl CastsTo for sql_types::Int4 {} +impl CastsTo for sql_types::Int8 {} diff --git a/diesel/src/expression/helper_types.rs b/diesel/src/expression/helper_types.rs index 8c27e7bfd51f..4d91223e5b67 100644 --- a/diesel/src/expression/helper_types.rs +++ b/diesel/src/expression/helper_types.rs @@ -90,6 +90,10 @@ pub type NotBetween = Grouped< /// [`lhs.concat(rhs)`](crate::expression_methods::TextExpressionMethods::concat()) pub type Concat = Grouped>>; +/// The return type of +/// [`expr.cast()`](crate::expression_methods::ExpressionMethods::cast()) +pub type Cast = super::cast::Cast; + /// The return type of /// [`expr.desc()`](crate::expression_methods::ExpressionMethods::desc()) pub type Desc = super::operators::Desc; diff --git a/diesel/src/expression/mod.rs b/diesel/src/expression/mod.rs index 97787238d4bd..b23dca181da6 100644 --- a/diesel/src/expression/mod.rs +++ b/diesel/src/expression/mod.rs @@ -37,6 +37,7 @@ pub(crate) mod nullable; #[macro_use] pub(crate) mod operators; mod case_when; +pub mod cast; pub(crate) mod select_by; mod sql_literal; pub(crate) mod subselect; diff --git a/diesel/src/expression_methods/global_expression_methods.rs b/diesel/src/expression_methods/global_expression_methods.rs index 893867a0bf8f..f819774be5cb 100644 --- a/diesel/src/expression_methods/global_expression_methods.rs +++ b/diesel/src/expression_methods/global_expression_methods.rs @@ -2,7 +2,7 @@ use crate::dsl; use crate::expression::array_comparison::{AsInExpression, In, NotIn}; use crate::expression::grouped::Grouped; use crate::expression::operators::*; -use crate::expression::{assume_not_null, nullable, AsExpression, Expression}; +use crate::expression::{assume_not_null, cast, nullable, AsExpression, Expression}; use crate::sql_types::{SingleValue, SqlType}; /// Methods present on all expressions, except tuples @@ -437,6 +437,47 @@ pub trait ExpressionMethods: Expression + Sized { )) } + /// Generates a `CAST(expr AS sql_type)` expression + /// + /// It is necessary that the expression's SQL type can be casted to the + /// target SQL type (represented by the [`CastsTo`](cast::CastsTo) trait), + /// and that we know how the corresponding SQL type is named for the + /// specific backend (represented by the + /// [`KnownCastSqlTypeName`](cast::KnownCastSqlTypeName) trait). + /// + /// # Example + /// + /// ```rust + /// # include!("../doctest_setup.rs"); + /// # + /// # fn main() { + /// # run_test().unwrap(); + /// # } + /// # + /// # fn run_test() -> QueryResult<()> { + /// # use schema::animals::dsl::*; + /// # let connection = &mut establish_connection(); + /// # + /// use diesel::sql_types; + /// + /// let data = diesel::select( + /// 12_i32 + /// .into_sql::() + /// .cast::(), + /// ) + /// .first::(connection)?; + /// assert_eq!("12", data); + /// # Ok(()) + /// # } + /// ``` + fn cast(self) -> dsl::Cast + where + ST: SingleValue, + Self::SqlType: cast::CastsTo, + { + cast::Cast::new(self) + } + /// Creates a SQL `DESC` expression, representing this expression in /// descending order. ///