Skip to content

Commit

Permalink
ColumnTrait::into_returning_expr
Browse files Browse the repository at this point in the history
  • Loading branch information
billy1624 committed Mar 7, 2024
1 parent 64b813a commit 31bf0c8
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 13 deletions.
9 changes: 8 additions & 1 deletion src/entity/column.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{EntityName, Iden, IdenStatic, IntoSimpleExpr, Iterable};
use crate::{DbBackend, EntityName, Iden, IdenStatic, IntoSimpleExpr, Iterable};
use sea_query::{
Alias, BinOper, DynIden, Expr, IntoIden, SeaRc, SelectStatement, SimpleExpr, Value,
};
Expand Down Expand Up @@ -247,6 +247,13 @@ pub trait ColumnTrait: IdenStatic + Iterable + FromStr {
Expr::expr(self.into_simple_expr())
}

/// Construct a returning [`Expr`].
fn into_returning_expr(self, db_backend: DbBackend) -> Expr {
match db_backend {
_ => Expr::col(self),
}
}

/// Cast column expression used in select statement.
/// It only cast database enum as text if it's an enum column.
fn select_as(&self, expr: Expr) -> SimpleExpr {
Expand Down
13 changes: 8 additions & 5 deletions src/executor/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,12 @@ where
// so that self is dropped before entering await
let mut query = self.query;
if db.support_returning() && <A::Entity as EntityTrait>::PrimaryKey::iter().count() > 0 {
let returning = Query::returning().exprs(
<A::Entity as EntityTrait>::PrimaryKey::iter()
.map(|c| c.into_column().select_as(Expr::col(c.into_column_ref()))),
);
let db_backend = db.get_database_backend();
let returning =
Query::returning().exprs(<A::Entity as EntityTrait>::PrimaryKey::iter().map(|c| {
c.into_column()
.select_as(c.into_column().into_returning_expr(db_backend))
}));
query.returning(returning);
}
Inserter::<A>::new(self.primary_key, query).exec(db)
Expand Down Expand Up @@ -275,7 +277,8 @@ where
let found = match db.support_returning() {
true => {
let returning = Query::returning().exprs(
<A::Entity as EntityTrait>::Column::iter().map(|c| c.select_as(Expr::col(c))),
<A::Entity as EntityTrait>::Column::iter()
.map(|c| c.select_as(c.into_returning_expr(db_backend))),
);
insert_statement.returning(returning);
SelectorRaw::<SelectModel<<A::Entity as EntityTrait>::Model>>::from_statement(
Expand Down
14 changes: 8 additions & 6 deletions src/executor/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,11 @@ impl Updater {

match db.support_returning() {
true => {
let returning = Query::returning()
.exprs(Column::<A>::iter().map(|c| c.select_as(Expr::col(c))));
self.query.returning(returning);
let db_backend = db.get_database_backend();
let returning = Query::returning().exprs(
Column::<A>::iter().map(|c| c.select_as(c.into_returning_expr(db_backend))),
);
self.query.returning(returning);
let found: Option<Model<A>> = SelectorRaw::<SelectModel<Model<A>>>::from_statement(
db_backend.build(&self.query),
)
Expand Down Expand Up @@ -148,10 +149,11 @@ impl Updater {

match db.support_returning() {
true => {
let returning =
Query::returning().exprs(E::Column::iter().map(|c| c.select_as(Expr::col(c))));
self.query.returning(returning);
let db_backend = db.get_database_backend();
let returning = Query::returning().exprs(
E::Column::iter().map(|c| c.select_as(c.into_returning_expr(db_backend))),
);
self.query.returning(returning);
let models: Vec<E::Model> = SelectorRaw::<SelectModel<E::Model>>::from_statement(
db_backend.build(&self.query),
)
Expand Down
4 changes: 3 additions & 1 deletion tests/returning_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ async fn main() -> Result<(), DbErr> {
])
.and_where(Column::Id.eq(1));

let returning = Query::returning().columns([Column::Id, Column::Name, Column::ProfitMargin]);
let columns = [Column::Id, Column::Name, Column::ProfitMargin];
let returning =
Query::returning().exprs(columns.into_iter().map(|c| c.into_returning_expr(builder)));

create_tables(db).await?;

Expand Down

0 comments on commit 31bf0c8

Please sign in to comment.