diff --git a/diesel/src/expression/case_when.rs b/diesel/src/expression/case_when.rs new file mode 100644 index 000000000000..e5c4df80279a --- /dev/null +++ b/diesel/src/expression/case_when.rs @@ -0,0 +1,415 @@ +use crate::expression::grouped::Grouped; +use crate::expression::{helper_types, Expression}; +use crate::sql_types::{BoolOrNullableBool, SqlType}; +use diesel_derives::{DieselNumericOps, QueryId, ValidGrouping}; + +use super::{AsExpression, TypedExpressionType}; + +/// Creates a SQL `CASE WHEN ... END` expression +/// +/// # Example +/// +/// ``` +/// # include!("../doctest_setup.rs"); +/// # +/// # fn main() { +/// # use schema::users::dsl::*; +/// # let connection = &mut establish_connection(); +/// use diesel::dsl::case_when; +/// +/// let users_with_name: Vec<(i32, Option)> = users +/// .select((id, case_when(name.eq("Sean"), id))) +/// .load(connection) +/// .unwrap(); +/// +/// assert_eq!(&[(1, Some(1)), (2, None)], users_with_name.as_slice()); +/// # } +/// ``` +/// +/// # `ELSE` clause +/// ``` +/// # include!("../doctest_setup.rs"); +/// # +/// # fn main() { +/// # use schema::users::dsl::*; +/// # let connection = &mut establish_connection(); +/// use diesel::dsl::case_when; +/// +/// let users_with_name: Vec<(i32, i32)> = users +/// .select((id, case_when(name.eq("Sean"), id).otherwise(0))) +/// .load(connection) +/// .unwrap(); +/// +/// assert_eq!(&[(1, 1), (2, 0)], users_with_name.as_slice()); +/// # } +/// ``` +/// +/// Note that the SQL types of the `case_when` and `else` expressions should +/// be equal. This includes whether they are wrapped in +/// [`Nullable`](crate::sql_types::Nullable), so you may need to call +/// [`nullable`](crate::expression_methods::NullableExpressionMethods::nullable) +/// on one of them. +/// +/// # More `WHEN` branches +/// ``` +/// # include!("../doctest_setup.rs"); +/// # +/// # fn main() { +/// # use schema::users::dsl::*; +/// # let connection = &mut establish_connection(); +/// use diesel::dsl::case_when; +/// +/// let users_with_name: Vec<(i32, Option)> = users +/// .select((id, case_when(name.eq("Sean"), id).when(name.eq("Tess"), 2))) +/// .load(connection) +/// .unwrap(); +/// +/// assert_eq!(&[(1, Some(1)), (2, Some(2))], users_with_name.as_slice()); +/// # } +/// ``` +pub fn case_when(condition: C, if_true: T) -> helper_types::case_when +where + C: Expression, + ::SqlType: BoolOrNullableBool, + T: AsExpression, + ST: SqlType + TypedExpressionType, +{ + CaseWhen { + whens: CaseWhenConditionsLeaf { + when: Grouped(condition), + then: Grouped(if_true.as_expression()), + }, + else_expr: NoElseExpression, + } +} + +/// A SQL `CASE WHEN ... END` expression +#[derive(Debug, Clone, Copy, QueryId, DieselNumericOps, ValidGrouping)] +pub struct CaseWhen { + whens: Whens, + else_expr: E, +} + +impl CaseWhen { + /// Add an additional `WHEN ... THEN ...` branch to the `CASE` expression + /// + /// See the [`case_when`] documentation for more details. + pub fn when(self, condition: C, if_true: T) -> helper_types::When + where + Self: CaseWhenTypesExtractor, + C: Expression, + ::SqlType: BoolOrNullableBool, + T: AsExpression<::OutputExpressionSpecifiedSqlType>, + { + CaseWhen { + whens: CaseWhenConditionsIntermediateNode { + first_whens: self.whens, + last_when: CaseWhenConditionsLeaf { + when: Grouped(condition), + then: Grouped(if_true.as_expression()), + }, + }, + else_expr: self.else_expr, + } + } +} + +impl CaseWhen { + /// Sets the `ELSE` branch of the `CASE` expression + /// + /// It is named this way because `else` is a reserved keyword in Rust + /// + /// See the [`case_when`] documentation for more details. + pub fn otherwise(self, if_no_other_branch_matched: E) -> helper_types::Otherwise + where + Self: CaseWhenTypesExtractor, + E: AsExpression<::OutputExpressionSpecifiedSqlType>, + { + CaseWhen { + whens: self.whens, + else_expr: ElseExpression { + expr: Grouped(if_no_other_branch_matched.as_expression()), + }, + } + } +} + +pub(crate) use non_public_types::*; +mod non_public_types { + use super::CaseWhen; + + use diesel_derives::{QueryId, ValidGrouping}; + + use crate::expression::{ + AppearsOnTable, Expression, SelectableExpression, TypedExpressionType, + }; + use crate::query_builder::{AstPass, QueryFragment}; + use crate::query_source::aliasing; + use crate::sql_types::{BoolOrNullableBool, IntoNullable, SqlType}; + + #[derive(Debug, Clone, Copy, QueryId, ValidGrouping)] + pub struct CaseWhenConditionsLeaf { + pub(super) when: W, + pub(super) then: T, + } + + #[derive(Debug, Clone, Copy, QueryId, ValidGrouping)] + pub struct CaseWhenConditionsIntermediateNode { + pub(super) first_whens: Whens, + pub(super) last_when: CaseWhenConditionsLeaf, + } + + pub trait CaseWhenConditions { + type OutputExpressionSpecifiedSqlType: SqlType + TypedExpressionType; + } + impl CaseWhenConditions for CaseWhenConditionsLeaf + where + ::SqlType: SqlType + TypedExpressionType, + { + type OutputExpressionSpecifiedSqlType = T::SqlType; + } + // This intentionally doesn't re-check inner `Whens` here, because this trait is + // only used to allow expression SQL type inference for `.when` calls so we + // want to make it as lightweight as possible for fast compilation. Actual + // guarantees are provided by the other implementations below + impl CaseWhenConditions for CaseWhenConditionsIntermediateNode + where + ::SqlType: SqlType + TypedExpressionType, + { + type OutputExpressionSpecifiedSqlType = T::SqlType; + } + + #[derive(Debug, Clone, Copy, QueryId, ValidGrouping)] + pub struct NoElseExpression; + #[derive(Debug, Clone, Copy, QueryId, ValidGrouping)] + pub struct ElseExpression { + pub(super) expr: E, + } + + /// Largely internal trait used to define the [`When`] and [`Otherwise`] + /// type aliases + /// + /// It should typically not be needed in user code unless writing extremely + /// generic functions + pub trait CaseWhenTypesExtractor { + /// The + /// This may not be the actual output expression type: if there is no + /// `else` it will be made `Nullable` + type OutputExpressionSpecifiedSqlType: SqlType + TypedExpressionType; + type Whens; + type Else; + } + impl CaseWhenTypesExtractor for CaseWhen + where + Whens: CaseWhenConditions, + { + type OutputExpressionSpecifiedSqlType = Whens::OutputExpressionSpecifiedSqlType; + type Whens = Whens; + type Else = E; + } + + impl SelectableExpression for CaseWhen, NoElseExpression> + where + CaseWhen, NoElseExpression>: AppearsOnTable, + W: SelectableExpression, + T: SelectableExpression, + { + } + + impl SelectableExpression + for CaseWhen, ElseExpression> + where + CaseWhen, ElseExpression>: AppearsOnTable, + W: SelectableExpression, + T: SelectableExpression, + E: SelectableExpression, + { + } + + impl SelectableExpression + for CaseWhen, E> + where + Self: AppearsOnTable, + W: SelectableExpression, + T: SelectableExpression, + CaseWhen: SelectableExpression, + { + } + + impl AppearsOnTable for CaseWhen, NoElseExpression> + where + CaseWhen, NoElseExpression>: Expression, + W: AppearsOnTable, + T: AppearsOnTable, + { + } + + impl AppearsOnTable for CaseWhen, ElseExpression> + where + CaseWhen, ElseExpression>: Expression, + W: AppearsOnTable, + T: AppearsOnTable, + E: AppearsOnTable, + { + } + + impl AppearsOnTable + for CaseWhen, E> + where + Self: Expression, + W: AppearsOnTable, + T: AppearsOnTable, + CaseWhen: AppearsOnTable, + { + } + + impl Expression for CaseWhen, NoElseExpression> + where + W: Expression, + ::SqlType: BoolOrNullableBool, + T: Expression, + ::SqlType: IntoNullable, + <::SqlType as IntoNullable>::Nullable: SqlType + TypedExpressionType, + { + type SqlType = <::SqlType as IntoNullable>::Nullable; + } + impl Expression for CaseWhen, ElseExpression> + where + W: Expression, + ::SqlType: BoolOrNullableBool, + T: Expression, + { + type SqlType = T::SqlType; + } + impl Expression for CaseWhen, E> + where + CaseWhen, E>: Expression, + CaseWhen: Expression< + SqlType = , E> as Expression>::SqlType, + >, + { + type SqlType = , E> as Expression>::SqlType; + } + + impl QueryFragment for CaseWhen + where + DB: crate::backend::Backend, + Whens: QueryFragment, + E: QueryFragment, + { + fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> crate::QueryResult<()> { + out.push_sql("CASE"); + self.whens.walk_ast(out.reborrow())?; + self.else_expr.walk_ast(out.reborrow())?; + out.push_sql(" END"); + Ok(()) + } + } + + impl QueryFragment for CaseWhenConditionsLeaf + where + DB: crate::backend::Backend, + W: QueryFragment, + T: QueryFragment, + { + fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> crate::QueryResult<()> { + out.push_sql(" WHEN "); + self.when.walk_ast(out.reborrow())?; + out.push_sql(" THEN "); + self.then.walk_ast(out.reborrow())?; + Ok(()) + } + } + + impl QueryFragment for CaseWhenConditionsIntermediateNode + where + DB: crate::backend::Backend, + Whens: QueryFragment, + W: QueryFragment, + T: QueryFragment, + { + fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> crate::QueryResult<()> { + self.first_whens.walk_ast(out.reborrow())?; + self.last_when.walk_ast(out.reborrow())?; + Ok(()) + } + } + + impl QueryFragment for NoElseExpression + where + DB: crate::backend::Backend, + { + fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> crate::result::QueryResult<()> { + let _ = out; + Ok(()) + } + } + impl QueryFragment for ElseExpression + where + E: QueryFragment, + DB: crate::backend::Backend, + { + fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> crate::result::QueryResult<()> { + out.push_sql(" ELSE "); + self.expr.walk_ast(out.reborrow())?; + Ok(()) + } + } + + impl aliasing::FieldAliasMapper for CaseWhen + where + S: aliasing::AliasSource, + Conditions: aliasing::FieldAliasMapper, + E: aliasing::FieldAliasMapper, + { + type Out = CaseWhen< + >::Out, + >::Out, + >; + fn map(self, alias: &aliasing::Alias) -> Self::Out { + CaseWhen { + whens: self.whens.map(alias), + else_expr: self.else_expr.map(alias), + } + } + } + + impl aliasing::FieldAliasMapper for CaseWhenConditionsLeaf + where + S: aliasing::AliasSource, + W: aliasing::FieldAliasMapper, + T: aliasing::FieldAliasMapper, + { + type Out = CaseWhenConditionsLeaf< + >::Out, + >::Out, + >; + fn map(self, alias: &aliasing::Alias) -> Self::Out { + CaseWhenConditionsLeaf { + when: self.when.map(alias), + then: self.then.map(alias), + } + } + } + + impl aliasing::FieldAliasMapper + for CaseWhenConditionsIntermediateNode + where + S: aliasing::AliasSource, + W: aliasing::FieldAliasMapper, + T: aliasing::FieldAliasMapper, + Whens: aliasing::FieldAliasMapper, + { + type Out = CaseWhenConditionsIntermediateNode< + >::Out, + >::Out, + >::Out, + >; + fn map(self, alias: &aliasing::Alias) -> Self::Out { + CaseWhenConditionsIntermediateNode { + first_whens: self.first_whens.map(alias), + last_when: self.last_when.map(alias), + } + } + } +} diff --git a/diesel/src/expression/helper_types.rs b/diesel/src/expression/helper_types.rs index fee01a8904aa..5104b93a0c05 100644 --- a/diesel/src/expression/helper_types.rs +++ b/diesel/src/expression/helper_types.rs @@ -5,6 +5,7 @@ use super::array_comparison::{AsInExpression, In, NotIn}; use super::grouped::Grouped; use super::select_by::SelectBy; use super::{AsExpression, Expression}; +use crate::expression; use crate::expression_methods::PreferredBoolSqlType; use crate::sql_types; @@ -120,6 +121,27 @@ pub type Like = Grouped = Grouped>>>; +/// The return type of [`case_when()`](expression::case_when::case_when) +#[allow(non_camel_case_types)] +pub type case_when::SqlType> = expression::case_when::CaseWhen< + expression::case_when::CaseWhenConditionsLeaf, Grouped>>, + expression::case_when::NoElseExpression, +>; +/// The return type of [`case_when(...).when(...)`](expression::case_when::CaseWhen::when) +pub type When = expression::case_when::CaseWhen< + expression::case_when::CaseWhenConditionsIntermediateNode< + Grouped, + Grouped::OutputExpressionSpecifiedSqlType>>, + ::Whens, + >, + ::Else, +>; +/// The return type of [`case_when(...).otherwise(...)`](expression::case_when::CaseWhen::otherwise) +pub type Otherwise = expression::case_when::CaseWhen< + ::Whens, + expression::case_when::ElseExpression::OutputExpressionSpecifiedSqlType>>>, +>; + /// Represents the return type of [`.as_select()`](crate::prelude::SelectableHelper::as_select) pub type AsSelect = SelectBy; diff --git a/diesel/src/expression/mod.rs b/diesel/src/expression/mod.rs index 48b7fe942222..f54d5244a470 100644 --- a/diesel/src/expression/mod.rs +++ b/diesel/src/expression/mod.rs @@ -36,6 +36,7 @@ mod not; pub(crate) mod nullable; #[macro_use] pub(crate) mod operators; +mod case_when; pub(crate) mod select_by; mod sql_literal; pub(crate) mod subselect; @@ -51,6 +52,8 @@ pub use self::operators::Concat; pub(crate) mod dsl { use crate::dsl::SqlTypeOf; + #[doc(inline)] + pub use super::case_when::*; #[doc(inline)] pub use super::count::*; #[doc(inline)]