Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Cast operator #4024

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions diesel/src/expression/cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
//! SQL `CAST(expr AS sql_type)` expression support

use crate::expression::{AppearsOnTable, Expression, SelectableExpression, ValidGrouping};
use crate::query_source::aliasing::{AliasSource, FieldAliasMapper};
use crate::result::QueryResult;
use crate::{query_builder, query_source, sql_types};

use std::marker::PhantomData;

pub(crate) mod private {
use super::*;

#[derive(Debug, Clone, Copy, diesel::query_builder::QueryId, sql_types::DieselNumericOps)]
pub struct Cast<E, ST> {
pub(super) expr: E,
pub(super) sql_type: PhantomData<ST>,
}
}
pub(crate) use private::Cast;

impl<E, ST> Cast<E, ST> {
pub(crate) fn new(expr: E) -> Self {
Self {
expr,
sql_type: PhantomData,
}
}
}

impl<E, ST, GroupByClause> ValidGrouping<GroupByClause> for Cast<E, ST>
where
E: ValidGrouping<GroupByClause>,
{
type IsAggregate = E::IsAggregate;
}

impl<E, ST, QS> SelectableExpression<QS> for Cast<E, ST>
where
Cast<E, ST>: AppearsOnTable<QS>,
E: SelectableExpression<QS>,
{
}

impl<E, ST, QS> AppearsOnTable<QS> for Cast<E, ST>
where
Cast<E, ST>: Expression,
E: AppearsOnTable<QS>,
{
}

impl<E, ST> Expression for Cast<E, ST>
where
E: Expression,
ST: sql_types::SingleValue,
{
type SqlType = ST;
}
impl<E, ST, DB> query_builder::QueryFragment<DB> for Cast<E, ST>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
impl<E, ST, DB> query_builder::QueryFragment<DB> for Cast<E, ST>
impl<E, ST, DB> query_builder::QueryFragment<DB> for Cast<E, ST>

where
E: query_builder::QueryFragment<DB>,
DB: diesel::backend::Backend,
ST: KnownCastSqlTypeName<DB>,
{
fn walk_ast<'b>(&'b self, mut out: query_builder::AstPass<'_, 'b, DB>) -> QueryResult<()> {
out.push_sql("CAST(");
self.expr.walk_ast(out.reborrow())?;
out.push_sql(" AS ");
out.push_sql(ST::sql_type_name());
out.push_sql(")");
Ok(())
}
}

/// We know what to write as `sql_type` in the `CAST(expr AS sql_type)` SQL for
/// `Self`
///
/// That is what is returned by `Self::sql_type_name()`
#[diagnostic::on_unimplemented(
note = "In order to use `CAST`, it is necessary that Diesel knows how to express the name \
of this type in the given backend.",
note = "This can be PRed into Diesel if the type is a standard SQL type."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We likely want to change this message to something more active like: If you run into this error message and believe that this cast should be supported open a PR that adds that trait implementation there (with link to the source or file path)

)]
pub trait KnownCastSqlTypeName<DB> {
/// What to write as `sql_type` in the `CAST(expr AS sql_type)` SQL for
/// `Self`
fn sql_type_name() -> &'static str;
}
impl<ST, DB> KnownCastSqlTypeName<DB> for sql_types::Nullable<ST>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
impl<ST, DB> KnownCastSqlTypeName<DB> for sql_types::Nullable<ST>
impl<ST, DB> KnownCastSqlTypeName<DB> for sql_types::Nullable<ST>

where
ST: KnownCastSqlTypeName<DB>,
{
fn sql_type_name() -> &'static str {
<ST as KnownCastSqlTypeName<DB>>::sql_type_name()
}
}

macro_rules! type_name {
($($backend: ty: $backend_feature: literal { $($type: ident => $val: literal,)+ })*) => {
$(
$(
#[cfg(feature = $backend_feature)]
impl KnownCastSqlTypeName<$backend> for sql_types::$type {
fn sql_type_name() -> &'static str {
$val
}
}
)*
)*
};
}
type_name! {
diesel::pg::Pg: "postgres_backend" {
Int4 => "int4",
Int8 => "int8",
Text => "text",
}
diesel::mysql::Mysql: "mysql_backend" {
Int4 => "integer",
Int8 => "integer",
Text => "char",
}
diesel::sqlite::Sqlite: "sqlite" {
Int4 => "integer",
Int8 => "bigint",
Text => "text",
}
}

impl<S, E, ST> FieldAliasMapper<S> for Cast<E, ST>
where
S: AliasSource,
E: FieldAliasMapper<S>,
{
type Out = Cast<<E as FieldAliasMapper<S>>::Out, ST>;
fn map(self, alias: &query_source::Alias<S>) -> Self::Out {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fn map(self, alias: &query_source::Alias<S>) -> Self::Out {
fn map(self, alias: &query_source::Alias<S>) -> Self::Out {

Cast {
expr: self.expr.map(alias),
sql_type: self.sql_type,
}
}
}

/// Marker trait: this SQL type (`Self`) can be casted to the target SQL type
/// (`ST`) using `CAST(expr AS target_sql_type)`
pub trait CastsTo<ST> {}
impl<ST1, ST2> CastsTo<sql_types::Nullable<ST2>> for sql_types::Nullable<ST1> where ST1: CastsTo<ST2>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
impl<ST1, ST2> CastsTo<sql_types::Nullable<ST2>> for sql_types::Nullable<ST1> where ST1: CastsTo<ST2>
impl<ST1, ST2> CastsTo<sql_types::Nullable<ST2>> for sql_types::Nullable<ST1> where ST1: CastsTo<ST2>

{}

impl CastsTo<sql_types::Int8> for sql_types::Int4 {}
impl CastsTo<sql_types::Int4> for sql_types::Int8 {}
impl CastsTo<sql_types::Text> for sql_types::Int4 {}
impl CastsTo<sql_types::Text> for sql_types::Int8 {}
4 changes: 4 additions & 0 deletions diesel/src/expression/helper_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ pub type NotBetween<Lhs, Lower, Upper> = Grouped<
/// [`lhs.concat(rhs)`](crate::expression_methods::TextExpressionMethods::concat())
pub type Concat<Lhs, Rhs> = Grouped<super::operators::Concat<Lhs, AsExpr<Rhs, Lhs>>>;

/// The return type of
/// [`expr.cast<ST>()`](crate::expression_methods::ExpressionMethods::cast())
pub type Cast<Expr, ST> = super::cast::Cast<Expr, ST>;

/// The return type of
/// [`expr.desc()`](crate::expression_methods::ExpressionMethods::desc())
pub type Desc<Expr> = super::operators::Desc<Expr>;
Expand Down
1 change: 1 addition & 0 deletions diesel/src/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub(crate) mod nullable;
#[macro_use]
pub(crate) mod operators;
mod case_when;
pub mod cast;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to keep the module private and just reexport the necessary items from the expression module.

pub(crate) mod select_by;
mod sql_literal;
pub(crate) mod subselect;
Expand Down
43 changes: 42 additions & 1 deletion diesel/src/expression_methods/global_expression_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::dsl;
use crate::expression::array_comparison::{AsInExpression, In, NotIn};
use crate::expression::grouped::Grouped;
use crate::expression::operators::*;
use crate::expression::{assume_not_null, nullable, AsExpression, Expression};
use crate::expression::{assume_not_null, cast, nullable, AsExpression, Expression};
use crate::sql_types::{SingleValue, SqlType};

/// Methods present on all expressions, except tuples
Expand Down Expand Up @@ -437,6 +437,47 @@ pub trait ExpressionMethods: Expression + Sized {
))
}

/// Generates a `CAST(expr AS sql_type)` expression
///
/// It is necessary that the expression's SQL type can be casted to the
/// target SQL type (represented by the [`CastsTo`](cast::CastsTo) trait),
/// and that we know how the corresponding SQL type is named for the
/// specific backend (represented by the
/// [`KnownCastSqlTypeName`](cast::KnownCastSqlTypeName) trait).
///
/// # Example
///
/// ```rust
/// # include!("../doctest_setup.rs");
/// #
/// # fn main() {
/// # run_test().unwrap();
/// # }
/// #
/// # fn run_test() -> QueryResult<()> {
/// # use schema::animals::dsl::*;
/// # let connection = &mut establish_connection();
/// #
/// use diesel::sql_types;
///
/// let data = diesel::select(
/// 12_i32
/// .into_sql::<sql_types::Int4>()
/// .cast::<sql_types::Text>(),
/// )
/// .first::<String>(connection)?;
/// assert_eq!("12", data);
/// # Ok(())
/// # }
/// ```
fn cast<ST>(self) -> dsl::Cast<Self, ST>
where
ST: SingleValue,
Self::SqlType: cast::CastsTo<ST>,
{
cast::Cast::new(self)
}

/// Creates a SQL `DESC` expression, representing this expression in
/// descending order.
///
Expand Down
Loading