Skip to content

Commit

Permalink
Merge pull request #156 from SeaQL/transaction
Browse files Browse the repository at this point in the history
Transaction Support
  • Loading branch information
tyt2y3 authored Sep 18, 2021
2 parents 3093cd2 + b30a103 commit 242d16d
Show file tree
Hide file tree
Showing 26 changed files with 648 additions and 148 deletions.
16 changes: 8 additions & 8 deletions examples/rocket_example/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async fn create(conn: Connection<Db>, post_form: Form<post::Model>) -> Flash<Red
text: Set(form.text.to_owned()),
..Default::default()
}
.save(&conn)
.save(&*conn)
.await
.expect("could not insert post");

Expand All @@ -52,7 +52,7 @@ async fn create(conn: Connection<Db>, post_form: Form<post::Model>) -> Flash<Red
#[post("/<id>", data = "<post_form>")]
async fn update(conn: Connection<Db>, id: i32, post_form: Form<post::Model>) -> Flash<Redirect> {
let post: post::ActiveModel = Post::find_by_id(id)
.one(&conn)
.one(&*conn)
.await
.unwrap()
.unwrap()
Expand All @@ -65,7 +65,7 @@ async fn update(conn: Connection<Db>, id: i32, post_form: Form<post::Model>) ->
title: Set(form.title.to_owned()),
text: Set(form.text.to_owned()),
}
.save(&conn)
.save(&*conn)
.await
.expect("could not edit post");

Expand All @@ -83,7 +83,7 @@ async fn list(
let posts_per_page = posts_per_page.unwrap_or(DEFAULT_POSTS_PER_PAGE);
let paginator = Post::find()
.order_by_asc(post::Column::Id)
.paginate(&conn, posts_per_page);
.paginate(&*conn, posts_per_page);
let num_pages = paginator.num_pages().await.ok().unwrap();

let posts = paginator
Expand All @@ -108,7 +108,7 @@ async fn list(
#[get("/<id>")]
async fn edit(conn: Connection<Db>, id: i32) -> Template {
let post: Option<post::Model> = Post::find_by_id(id)
.one(&conn)
.one(&*conn)
.await
.expect("could not find post");

Expand All @@ -123,20 +123,20 @@ async fn edit(conn: Connection<Db>, id: i32) -> Template {
#[delete("/<id>")]
async fn delete(conn: Connection<Db>, id: i32) -> Flash<Redirect> {
let post: post::ActiveModel = Post::find_by_id(id)
.one(&conn)
.one(&*conn)
.await
.unwrap()
.unwrap()
.into();

post.delete(&conn).await.unwrap();
post.delete(&*conn).await.unwrap();

Flash::success(Redirect::to("/"), "Post successfully deleted.")
}

#[delete("/")]
async fn destroy(conn: Connection<Db>) -> Result<()> {
Post::delete_many().exec(&conn).await.unwrap();
Post::delete_many().exec(&*conn).await.unwrap();
Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion examples/rocket_example/src/setup.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use sea_orm::sea_query::{ColumnDef, TableCreateStatement};
use sea_orm::{error::*, sea_query, DbConn, ExecResult};
use sea_orm::{query::*, error::*, sea_query, DbConn, ExecResult};

async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result<ExecResult, DbErr> {
let builder = db.get_database_backend();
Expand Down
53 changes: 40 additions & 13 deletions src/database/connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{error::*, ExecResult, QueryResult, Statement, StatementBuilder};
use std::{pin::Pin, future::Future};
use crate::{DatabaseTransaction, ConnectionTrait, ExecResult, QueryResult, Statement, StatementBuilder, TransactionError, error::*};
use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder};

#[cfg_attr(not(feature = "mock"), derive(Clone))]
Expand Down Expand Up @@ -51,8 +52,9 @@ impl std::fmt::Debug for DatabaseConnection {
}
}

impl DatabaseConnection {
pub fn get_database_backend(&self) -> DbBackend {
#[async_trait::async_trait]
impl ConnectionTrait for DatabaseConnection {
fn get_database_backend(&self) -> DbBackend {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
Expand All @@ -66,7 +68,7 @@ impl DatabaseConnection {
}
}

pub async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await,
Expand All @@ -80,7 +82,7 @@ impl DatabaseConnection {
}
}

pub async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await,
Expand All @@ -94,7 +96,7 @@ impl DatabaseConnection {
}
}

pub async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await,
Expand All @@ -108,21 +110,46 @@ impl DatabaseConnection {
}
}

/// Execute the function inside a transaction.
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(&'c DatabaseTransaction<'_>) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>> + Send + Sync,
T: Send,
E: std::error::Error + Send,
{
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.transaction(_callback).await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(_) => unimplemented!(), //TODO: support transaction in mock connection
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}

#[cfg(feature = "mock")]
fn is_mock_connection(&self) -> bool {
match self {
DatabaseConnection::MockDatabaseConnection(_) => true,
_ => false,
}
}
}

#[cfg(feature = "mock")]
impl DatabaseConnection {
pub fn as_mock_connection(&self) -> &crate::MockDatabaseConnection {
match self {
DatabaseConnection::MockDatabaseConnection(mock_conn) => mock_conn,
_ => panic!("not mock connection"),
}
}

#[cfg(not(feature = "mock"))]
pub fn as_mock_connection(&self) -> Option<bool> {
None
}

#[cfg(feature = "mock")]
pub fn into_transaction_log(self) -> Vec<crate::Transaction> {
pub fn into_transaction_log(&self) -> Vec<crate::Transaction> {
let mut mocker = self.as_mock_connection().get_mocker_mutex().lock().unwrap();
mocker.drain_transaction_log()
}
Expand Down
25 changes: 25 additions & 0 deletions src/database/db_connection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use std::{pin::Pin, future::Future};
use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionError};

#[async_trait::async_trait]
pub trait ConnectionTrait: Sync {
fn get_database_backend(&self) -> DbBackend;

async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr>;

async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr>;

async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;

/// Execute the function inside a transaction.
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(&'c DatabaseTransaction<'_>) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>> + Send + Sync,
T: Send,
E: std::error::Error + Send;

fn is_mock_connection(&self) -> bool {
false
}
}
Loading

0 comments on commit 242d16d

Please sign in to comment.