Skip to content

Commit

Permalink
Add Cast operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Ten0 committed May 14, 2024
1 parent 56a39c0 commit 074fb91
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 1 deletion.
152 changes: 152 additions & 0 deletions diesel/src/expression/cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
//! SQL `CAST(expr AS sql_type)` expression support

use crate::expression::{AppearsOnTable, Expression, SelectableExpression, ValidGrouping};
use crate::query_source::aliasing::{AliasSource, FieldAliasMapper};
use crate::result::QueryResult;
use crate::{query_builder, query_source, sql_types};

use std::marker::PhantomData;

pub(crate) mod private {
use super::*;

#[derive(Debug, Clone, Copy, diesel::query_builder::QueryId, sql_types::DieselNumericOps)]
pub struct Cast<E, ST> {
pub(super) expr: E,
pub(super) sql_type: PhantomData<ST>,
}
}
pub(crate) use private::Cast;

impl<E, ST> Cast<E, ST> {
pub(crate) fn new(expr: E) -> Self {
Self {
expr,
sql_type: PhantomData,
}
}
}

impl<E, ST, GroupByClause> ValidGrouping<GroupByClause> for Cast<E, ST>
where
E: ValidGrouping<GroupByClause>,
{
type IsAggregate = E::IsAggregate;
}

impl<E, ST, QS> SelectableExpression<QS> for Cast<E, ST>
where
Cast<E, ST>: AppearsOnTable<QS>,
E: SelectableExpression<QS>,
{
}

impl<E, ST, QS> AppearsOnTable<QS> for Cast<E, ST>
where
Cast<E, ST>: Expression,
E: AppearsOnTable<QS>,
{
}

impl<E, ST> Expression for Cast<E, ST>
where
E: Expression,
ST: sql_types::SingleValue,
{
type SqlType = ST;
}
impl<E, ST, DB> query_builder::QueryFragment<DB> for Cast<E, ST>
where
E: query_builder::QueryFragment<DB>,
DB: diesel::backend::Backend,
ST: KnownCastSqlTypeName<DB>,
{
fn walk_ast<'b>(&'b self, mut out: query_builder::AstPass<'_, 'b, DB>) -> QueryResult<()> {
out.push_sql("CAST(");
self.expr.walk_ast(out.reborrow())?;
out.push_sql(" AS ");
out.push_sql(ST::sql_type_name());
out.push_sql(")");
Ok(())
}
}

/// We know what to write as `sql_type` in the `CAST(expr AS sql_type)` SQL for
/// `Self`
///
/// That is what is returned by `Self::sql_type_name()`
#[diagnostic::on_unimplemented(
note = "In order to use `CAST`, it is necessary that Diesel knows how to express the name \
of this type in the given backend.",
note = "This can be PRed into Diesel if the type is a standard SQL type."
)]
pub trait KnownCastSqlTypeName<DB> {
/// What to write as `sql_type` in the `CAST(expr AS sql_type)` SQL for
/// `Self`
fn sql_type_name() -> &'static str;
}
impl<ST, DB> KnownCastSqlTypeName<DB> for sql_types::Nullable<ST>
where
ST: KnownCastSqlTypeName<DB>,
{
fn sql_type_name() -> &'static str {
<ST as KnownCastSqlTypeName<DB>>::sql_type_name()
}
}

macro_rules! type_name {
($($backend: ty: $backend_feature: literal { $($type: ident => $val: literal,)+ })*) => {
$(
$(
#[cfg(feature = $backend_feature)]
impl KnownCastSqlTypeName<$backend> for sql_types::$type {
fn sql_type_name() -> &'static str {
$val
}
}
)*
)*
};
}
type_name! {
diesel::pg::Pg: "postgres_backend" {
Int4 => "int4",
Int8 => "int8",
Text => "text",
}
diesel::mysql::Mysql: "mysql_backend" {
Int4 => "integer",
Int8 => "integer",
Text => "text",
}
diesel::sqlite::Sqlite: "sqlite" {
Int4 => "integer",
Int8 => "bigint",
Text => "text",
}
}

impl<S, E, ST> FieldAliasMapper<S> for Cast<E, ST>
where
S: AliasSource,
E: FieldAliasMapper<S>,
{
type Out = Cast<<E as FieldAliasMapper<S>>::Out, ST>;
fn map(self, alias: &query_source::Alias<S>) -> Self::Out {
Cast {
expr: self.expr.map(alias),
sql_type: self.sql_type,
}
}
}

/// Marker trait: this SQL type (`Self`) can be casted to the target SQL type
/// (`ST`) using `CAST(expr AS target_sql_type)`
pub trait CastsTo<ST> {}
impl<ST1, ST2> CastsTo<sql_types::Nullable<ST2>> for sql_types::Nullable<ST1> where ST1: CastsTo<ST2>
{}

impl CastsTo<sql_types::Int8> for sql_types::Int4 {}
impl CastsTo<sql_types::Int4> for sql_types::Int8 {}
impl CastsTo<sql_types::Text> for sql_types::Int4 {}
impl CastsTo<sql_types::Text> for sql_types::Int8 {}
4 changes: 4 additions & 0 deletions diesel/src/expression/helper_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ pub type NotBetween<Lhs, Lower, Upper> = Grouped<
/// [`lhs.concat(rhs)`](crate::expression_methods::TextExpressionMethods::concat())
pub type Concat<Lhs, Rhs> = Grouped<super::operators::Concat<Lhs, AsExpr<Rhs, Lhs>>>;

/// The return type of
/// [`expr.cast_to<ST>()`](crate::expression_methods::AsExpression::cast_to())
pub type Cast<Expr, ST> = super::cast::Cast<Expr, ST>;

/// The return type of
/// [`expr.desc()`](crate::expression_methods::ExpressionMethods::desc())
pub type Desc<Expr> = super::operators::Desc<Expr>;
Expand Down
1 change: 1 addition & 0 deletions diesel/src/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub(crate) mod nullable;
#[macro_use]
pub(crate) mod operators;
mod case_when;
pub mod cast;
pub(crate) mod select_by;
mod sql_literal;
pub(crate) mod subselect;
Expand Down
43 changes: 42 additions & 1 deletion diesel/src/expression_methods/global_expression_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::dsl;
use crate::expression::array_comparison::{AsInExpression, In, NotIn};
use crate::expression::grouped::Grouped;
use crate::expression::operators::*;
use crate::expression::{assume_not_null, nullable, AsExpression, Expression};
use crate::expression::{assume_not_null, cast, nullable, AsExpression, Expression};
use crate::sql_types::{SingleValue, SqlType};

/// Methods present on all expressions, except tuples
Expand Down Expand Up @@ -437,6 +437,47 @@ pub trait ExpressionMethods: Expression + Sized {
))
}

/// Generates a `CAST(expr AS sql_type)` expression
///
/// It is necessary that the expression's SQL type can be casted to the
/// target SQL type (represented by the [`CastsTo`](cast::CastsTo) trait),
/// and that we know how the corresponding SQL type is named for the
/// specific backend (represented by the
/// [`KnownCastSqlTypeName`](cast::KnownCastSqlTypeName) trait).
///
/// # Example
///
/// ```rust
/// # include!("../doctest_setup.rs");
/// #
/// # fn main() {
/// # run_test().unwrap();
/// # }
/// #
/// # fn run_test() -> QueryResult<()> {
/// # use schema::animals::dsl::*;
/// # let connection = &mut establish_connection();
/// #
/// use diesel::sql_types;
///
/// let data = diesel::select(
/// 12_i32
/// .into_sql::<sql_types::Int4>()
/// .cast::<sql_types::Text>(),
/// )
/// .first::<String>(connection)?;
/// assert_eq!("12", data);
/// # Ok(())
/// # }
/// ```
fn cast<ST>(self) -> dsl::Cast<Self, ST>
where
ST: SingleValue,
Self::SqlType: cast::CastsTo<ST>,
{
cast::Cast::new(self)
}

/// Creates a SQL `DESC` expression, representing this expression in
/// descending order.
///
Expand Down

0 comments on commit 074fb91

Please sign in to comment.