diff --git a/Cargo.toml b/Cargo.toml index 8344c3fa2..8796652b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ tracing = { version = "0.1", default-features = false, features = ["attributes", rust_decimal = { version = "1", default-features = false, optional = true } bigdecimal = { version = "0.3", default-features = false, optional = true } sea-orm-macros = { version = "0.10.3", path = "sea-orm-macros", default-features = false, optional = true } -sea-query = { version = "0.28", features = ["thread-safe"] } +sea-query = { version = "0.28.3", features = ["thread-safe"] } sea-query-binder = { version = "0.3", default-features = false, optional = true } sea-strum = { version = "0.23", default-features = false, features = ["derive", "sea-orm"] } serde = { version = "1.0", default-features = false } diff --git a/sea-orm-codegen/Cargo.toml b/sea-orm-codegen/Cargo.toml index 40f5480d3..7357a9eea 100644 --- a/sea-orm-codegen/Cargo.toml +++ b/sea-orm-codegen/Cargo.toml @@ -17,7 +17,7 @@ name = "sea_orm_codegen" path = "src/lib.rs" [dependencies] -sea-query = { version = "0.28", default-features = false, features = ["thread-safe"] } +sea-query = { version = "0.28.3", default-features = false, features = ["thread-safe"] } syn = { version = "1", default-features = false } quote = { version = "1", default-features = false } heck = { version = "0.3", default-features = false } diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index 2a34a2863..c15e24f12 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -1423,8 +1423,8 @@ mod tests { vec![ Transaction::from_sql_and_values( DbBackend::Postgres, - r#"UPDATE "fruit" SET WHERE "fruit"."id" = $1 RETURNING "id", "name", "cake_id""#, - vec![1i32.into()], + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit" WHERE "fruit"."id" = $1 LIMIT $2"#, + vec![1i32.into(), 1u64.into()], ), Transaction::from_sql_and_values( DbBackend::Postgres, diff --git a/src/executor/update.rs b/src/executor/update.rs index 9e37cf4d2..3a42129db 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -1,9 +1,8 @@ use crate::{ error::*, ActiveModelTrait, ColumnTrait, ConnectionTrait, EntityTrait, IntoActiveModel, - Iterable, PrimaryKeyTrait, SelectModel, SelectorRaw, Statement, UpdateMany, UpdateOne, + Iterable, PrimaryKeyTrait, SelectModel, SelectorRaw, UpdateMany, UpdateOne, }; use sea_query::{Expr, FromValueTuple, Query, UpdateStatement}; -use std::future::Future; /// Defines an update operation #[derive(Clone, Debug)] @@ -13,7 +12,7 @@ pub struct Updater { } /// The result of an update operation on an ActiveModel -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Default)] pub struct UpdateResult { /// The rows affected by the update operation pub rows_affected: u64, @@ -29,8 +28,9 @@ where ::Model: IntoActiveModel, C: ConnectionTrait, { - // so that self is dropped before entering await - exec_update_and_return_updated(self.query, self.model, db).await + Updater::new(self.query) + .exec_update_and_return_updated(self.model, db) + .await } } @@ -39,12 +39,11 @@ where E: EntityTrait, { /// Execute an update operation on multiple ActiveModels - pub fn exec(self, db: &'a C) -> impl Future> + '_ + pub async fn exec(self, db: &'a C) -> Result where C: ConnectionTrait, { - // so that self is dropped before entering await - exec_update_only(self.query, db) + Updater::new(self.query).exec(db).await } } @@ -64,24 +63,76 @@ impl Updater { } /// Execute an update operation - pub fn exec(self, db: &C) -> impl Future> + '_ + pub async fn exec(self, db: &C) -> Result where C: ConnectionTrait, { + if self.is_noop() { + return Ok(UpdateResult::default()); + } let builder = db.get_database_backend(); - exec_update(builder.build(&self.query), db, self.check_record_exists) + let statement = builder.build(&self.query); + let result = db.execute(statement).await?; + if self.check_record_exists && result.rows_affected() == 0 { + return Err(DbErr::RecordNotFound( + "None of the database rows are affected".to_owned(), + )); + } + Ok(UpdateResult { + rows_affected: result.rows_affected(), + }) } -} -async fn exec_update_only(query: UpdateStatement, db: &C) -> Result -where - C: ConnectionTrait, -{ - Updater::new(query).exec(db).await + async fn exec_update_and_return_updated( + mut self, + model: A, + db: &C, + ) -> Result<::Model, DbErr> + where + A: ActiveModelTrait, + C: ConnectionTrait, + { + type Entity = ::Entity; + type Model = as EntityTrait>::Model; + type Column = as EntityTrait>::Column; + + if self.is_noop() { + return find_updated_model_by_id(model, db).await; + } + + match db.support_returning() { + true => { + let returning = Query::returning() + .exprs(Column::::iter().map(|c| c.select_as(Expr::col(c)))); + self.query.returning(returning); + let db_backend = db.get_database_backend(); + let found: Option> = SelectorRaw::>>::from_statement( + db_backend.build(&self.query), + ) + .one(db) + .await?; + // If we got `None` then we are updating a row that does not exist. + match found { + Some(model) => Ok(model), + None => Err(DbErr::RecordNotFound( + "None of the database rows are affected".to_owned(), + )), + } + } + false => { + // If we updating a row that does not exist then an error will be thrown here. + self.check_record_exists().exec(db).await?; + find_updated_model_by_id(model, db).await + } + } + } + + fn is_noop(&self) -> bool { + self.query.get_values().is_empty() + } } -async fn exec_update_and_return_updated( - mut query: UpdateStatement, +async fn find_updated_model_by_id( model: A, db: &C, ) -> Result<::Model, DbErr> @@ -90,63 +141,20 @@ where C: ConnectionTrait, { type Entity = ::Entity; - type Model = as EntityTrait>::Model; - type Column = as EntityTrait>::Column; type ValueType = < as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType; - match db.support_returning() { - true => { - let returning = - Query::returning().exprs(Column::::iter().map(|c| c.select_as(Expr::col(c)))); - query.returning(returning); - let db_backend = db.get_database_backend(); - let found: Option> = - SelectorRaw::>>::from_statement(db_backend.build(&query)) - .one(db) - .await?; - // If we got `None` then we are updating a row that does not exist. - match found { - Some(model) => Ok(model), - None => Err(DbErr::RecordNotFound( - "None of the database rows are affected".to_owned(), - )), - } - } - false => { - // If we updating a row that does not exist then an error will be thrown here. - Updater::new(query).check_record_exists().exec(db).await?; - let primary_key_value = match model.get_primary_key_value() { - Some(val) => ValueType::::from_value_tuple(val), - None => return Err(DbErr::UpdateGetPrimaryKey), - }; - let found = Entity::::find_by_id(primary_key_value).one(db).await?; - // If we cannot select the updated row from db by the cached primary key - match found { - Some(model) => Ok(model), - None => Err(DbErr::RecordNotFound( - "Failed to find updated item".to_owned(), - )), - } - } - } -} -async fn exec_update( - statement: Statement, - db: &C, - check_record_exists: bool, -) -> Result -where - C: ConnectionTrait, -{ - let result = db.execute(statement).await?; - if check_record_exists && result.rows_affected() == 0 { - return Err(DbErr::RecordNotFound( - "None of the database rows are affected".to_owned(), - )); + let primary_key_value = match model.get_primary_key_value() { + Some(val) => ValueType::::from_value_tuple(val), + None => return Err(DbErr::UpdateGetPrimaryKey), + }; + let found = Entity::::find_by_id(primary_key_value).one(db).await?; + // If we cannot select the updated row from db by the cached primary key + match found { + Some(model) => Ok(model), + None => Err(DbErr::RecordNotFound( + "Failed to find updated item".to_owned(), + )), } - Ok(UpdateResult { - rows_affected: result.rows_affected(), - }) } #[cfg(test)] @@ -157,15 +165,20 @@ mod tests { #[smol_potat::test] async fn update_record_not_found_1() -> Result<(), DbErr> { + let updated_cake = cake::Model { + id: 1, + name: "Cheese Cake".to_owned(), + }; + let db = MockDatabase::new(DbBackend::Postgres) .append_query_results([ - vec![cake::Model { - id: 1, - name: "Cheese Cake".to_owned(), - }], + vec![updated_cake.clone()], vec![], vec![], vec![], + vec![updated_cake.clone()], + vec![updated_cake.clone()], + vec![updated_cake.clone()], ]) .append_exec_results([MockExecResult { last_insert_id: 0, @@ -181,7 +194,7 @@ mod tests { assert_eq!( cake::ActiveModel { name: Set("Cheese Cake".to_owned()), - ..model.into_active_model() + ..model.clone().into_active_model() } .update(&db) .await?, @@ -223,7 +236,7 @@ mod tests { assert_eq!( Update::one(cake::ActiveModel { name: Set("Cheese Cake".to_owned()), - ..model.into_active_model() + ..model.clone().into_active_model() }) .exec(&db) .await, @@ -241,6 +254,28 @@ mod tests { Ok(UpdateResult { rows_affected: 0 }) ); + assert_eq!( + updated_cake.clone().into_active_model().save(&db).await?, + updated_cake.clone().into_active_model() + ); + + assert_eq!( + updated_cake.clone().into_active_model().update(&db).await?, + updated_cake + ); + + assert_eq!( + cake::Entity::update(updated_cake.clone().into_active_model()) + .exec(&db) + .await?, + updated_cake + ); + + assert_eq!( + cake::Entity::update_many().exec(&db).await?.rows_affected, + 0 + ); + assert_eq!( db.into_transaction_log(), [ @@ -269,6 +304,21 @@ mod tests { r#"UPDATE "cake" SET "name" = $1 WHERE "cake"."id" = $2"#, ["Cheese Cake".into(), 2i32.into()] ), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#, + [1.into(), 1u64.into()] + ), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#, + [1.into(), 1u64.into()] + ), + Transaction::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "cake"."id", "cake"."name" FROM "cake" WHERE "cake"."id" = $1 LIMIT $2"#, + [1.into(), 1u64.into()] + ), ] );