From f012ccaec3d36f6aa9de8a43e85858047d0f209d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Wed, 19 Jan 2022 20:43:40 +0100 Subject: [PATCH 1/2] Add Send bound to streams. --- Cargo.toml | 2 +- examples/send_stream_example/.env | 3 + examples/send_stream_example/Cargo.toml | 23 +++++ examples/send_stream_example/README.md | 1 + examples/send_stream_example/src/main.rs | 29 ++++++ examples/send_stream_example/src/post.rs | 26 ++++++ examples/send_stream_example/src/setup.rs | 33 +++++++ src/database/connection.rs | 4 +- src/database/db_connection.rs | 2 +- src/database/stream/query.rs | 2 +- src/database/stream/transaction.rs | 103 ++++++++++++---------- src/database/transaction.rs | 5 +- src/driver/mock.rs | 2 +- src/executor/select.rs | 20 +++-- 14 files changed, 192 insertions(+), 63 deletions(-) create mode 100644 examples/send_stream_example/.env create mode 100644 examples/send_stream_example/Cargo.toml create mode 100644 examples/send_stream_example/README.md create mode 100644 examples/send_stream_example/src/main.rs create mode 100644 examples/send_stream_example/src/post.rs create mode 100644 examples/send_stream_example/src/setup.rs diff --git a/Cargo.toml b/Cargo.toml index 0665ae23e..0863d5eb5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } sqlx = { version = "^0.5", optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true } -ouroboros = "0.14" +ouroboros = { git = "https://github.com/sebpuetz/ouroboros", branch = "send-builder-functions" } url = "^2.2" once_cell = "1.8" diff --git a/examples/send_stream_example/.env b/examples/send_stream_example/.env new file mode 100644 index 000000000..fb7fcfb5b --- /dev/null +++ b/examples/send_stream_example/.env @@ -0,0 +1,3 @@ +HOST=127.0.0.1 +PORT=8000 +DATABASE_URL="postgres://postgres:password@localhost/axum_exmaple" \ No newline at end of file diff --git a/examples/send_stream_example/Cargo.toml b/examples/send_stream_example/Cargo.toml new file mode 100644 index 000000000..e20b627e7 --- /dev/null +++ b/examples/send_stream_example/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "sea-orm-axum-example" +version = "0.1.0" +authors = ["Sebastian Pütz "] +edition = "2021" +publish = false + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[workspace] + +[dependencies] +tokio = { version = "1.14", features = ["full"] } +anyhow = "1" +dotenv = "0.15" +futures-util = "0.3" +serde = "1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +[dependencies.sea-orm] +path = "../../" # remove this line in your own project +# version = "^0.5.0" +features = ["macros", "mock", "sqlx-all", "runtime-tokio-rustls", "debug-print"] +default-features = false diff --git a/examples/send_stream_example/README.md b/examples/send_stream_example/README.md new file mode 100644 index 000000000..df30b8be1 --- /dev/null +++ b/examples/send_stream_example/README.md @@ -0,0 +1 @@ +Demonstrator for using streaming queries with `tokio::spawn` or in contexts that require `Send` futures. \ No newline at end of file diff --git a/examples/send_stream_example/src/main.rs b/examples/send_stream_example/src/main.rs new file mode 100644 index 000000000..9d7f664dd --- /dev/null +++ b/examples/send_stream_example/src/main.rs @@ -0,0 +1,29 @@ +mod post; +mod setup; + +use futures_util::StreamExt; +use post::Entity as Post; +use sea_orm::{prelude::*, Database}; +use std::env; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + env::set_var("RUST_LOG", "debug"); + tracing_subscriber::fmt::init(); + + dotenv::dotenv().ok(); + let db_url = env::var("DATABASE_URL").expect("DATABASE_URL is not set in .env file"); + let db = Database::connect(db_url) + .await + .expect("Database connection failed"); + let _ = setup::create_post_table(&db); + tokio::task::spawn(async move { + let mut stream = Post::find().stream(&db).await.unwrap(); + while let Some(item) = stream.next().await { + let item = item?; + println!("got something: {}", item.text); + } + Ok::<(), anyhow::Error>(()) + }) + .await? +} diff --git a/examples/send_stream_example/src/post.rs b/examples/send_stream_example/src/post.rs new file mode 100644 index 000000000..3bb4d6a33 --- /dev/null +++ b/examples/send_stream_example/src/post.rs @@ -0,0 +1,26 @@ +//! SeaORM Entity. Generated by sea-orm-codegen 0.3.2 + +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "posts")] +pub struct Model { + #[sea_orm(primary_key)] + #[serde(skip_deserializing)] + pub id: i32, + pub title: String, + #[sea_orm(column_type = "Text")] + pub text: String, +} + +#[derive(Copy, Clone, Debug, EnumIter)] +pub enum Relation {} + +impl RelationTrait for Relation { + fn def(&self) -> RelationDef { + panic!("No RelationDef") + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/examples/send_stream_example/src/setup.rs b/examples/send_stream_example/src/setup.rs new file mode 100644 index 000000000..04677af46 --- /dev/null +++ b/examples/send_stream_example/src/setup.rs @@ -0,0 +1,33 @@ +use sea_orm::sea_query::{ColumnDef, TableCreateStatement}; +use sea_orm::{error::*, sea_query, ConnectionTrait, DbConn, ExecResult}; + +async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result { + let builder = db.get_database_backend(); + db.execute(builder.build(stmt)).await +} + +pub async fn create_post_table(db: &DbConn) -> Result { + let stmt = sea_query::Table::create() + .table(super::post::Entity) + .if_not_exists() + .col( + ColumnDef::new(super::post::Column::Id) + .integer() + .not_null() + .auto_increment() + .primary_key(), + ) + .col( + ColumnDef::new(super::post::Column::Title) + .string() + .not_null(), + ) + .col( + ColumnDef::new(super::post::Column::Text) + .string() + .not_null(), + ) + .to_owned(); + + create_table(db, &stmt).await +} diff --git a/src/database/connection.rs b/src/database/connection.rs index 8c409999b..eb41fe18c 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -9,7 +9,7 @@ use std::{future::Future, pin::Pin}; #[async_trait::async_trait] pub trait ConnectionTrait<'a>: Sync { /// Create a stream for the [QueryResult] - type Stream: Stream>; + type Stream: Stream> + Send; /// Fetch the database backend as specified in [DbBackend]. /// This depends on feature flags enabled. @@ -28,7 +28,7 @@ pub trait ConnectionTrait<'a>: Sync { fn stream( &'a self, stmt: Statement, - ) -> Pin> + 'a>>; + ) -> Pin> + 'a + Send>>; /// Execute SQL `BEGIN` transaction. /// Returns a Transaction that can be committed or rolled back diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 900167378..0183bcd8f 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -155,7 +155,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection { fn stream( &'a self, stmt: Statement, - ) -> Pin> + 'a>> { + ) -> Pin> + 'a + Send>> { Box::pin(async move { Ok(match self { #[cfg(feature = "sqlx-mysql")] diff --git a/src/database/stream/query.rs b/src/database/stream/query.rs index 98f1142a3..e0f606f4c 100644 --- a/src/database/stream/query.rs +++ b/src/database/stream/query.rs @@ -24,7 +24,7 @@ pub struct QueryStream { metric_callback: Option, #[borrows(mut conn, stmt, metric_callback)] #[not_covariant] - stream: Pin> + 'this>>, + stream: Pin> + Send + 'this>>, } #[cfg(feature = "sqlx-mysql")] diff --git a/src/database/stream/transaction.rs b/src/database/stream/transaction.rs index daa912ef5..d198f54f4 100644 --- a/src/database/stream/transaction.rs +++ b/src/database/stream/transaction.rs @@ -2,7 +2,7 @@ use std::{ops::DerefMut, pin::Pin, task::Poll}; -use futures::Stream; +use futures::{FutureExt, Stream}; #[cfg(feature = "sqlx-dep")] use futures::TryStreamExt; @@ -24,7 +24,7 @@ pub struct TransactionStream<'a> { metric_callback: Option, #[borrows(mut conn, stmt, metric_callback)] #[not_covariant] - stream: Pin> + 'this>>, + stream: Pin> + 'this + Send>>, } impl<'a> std::fmt::Debug for TransactionStream<'a> { @@ -33,61 +33,72 @@ impl<'a> std::fmt::Debug for TransactionStream<'a> { } } +type PinStream<'s> = Pin> + Send + 's>>; + impl<'a> TransactionStream<'a> { + fn stream_builder<'s>( + conn: &'s mut MutexGuard<'a, InnerConnection>, + stmt: &'s Statement, + _metric_callback: &'s Option, + ) -> Pin> + 's + Send>> { + async move { + match conn.deref_mut() { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + crate::metric::metric_ok!(_metric_callback, stmt, { + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + as Pin> + Send>> + }) + } + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + crate::metric::metric_ok!(_metric_callback, stmt, { + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + as Pin> + Send>> + }) + } + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + crate::metric::metric_ok!(_metric_callback, stmt, { + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err), + ) + as Pin> + Send>> + }) + } + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => c.fetch(stmt), + } + } + .boxed() + } + #[instrument(level = "trace", skip(metric_callback))] pub(crate) async fn build( conn: MutexGuard<'a, InnerConnection>, stmt: Statement, metric_callback: Option, ) -> TransactionStream<'a> { + let stream_builder = Self::stream_builder; + TransactionStreamAsyncBuilder { stmt, conn, metric_callback, - stream_builder: |conn, stmt, _metric_callback| { - Box::pin(async move { - match conn.deref_mut() { - #[cfg(feature = "sqlx-mysql")] - InnerConnection::MySql(c) => { - let query = crate::driver::sqlx_mysql::sqlx_query(stmt); - crate::metric::metric_ok!(_metric_callback, stmt, { - Box::pin( - c.fetch(query) - .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err), - ) - as Pin>>> - }) - } - #[cfg(feature = "sqlx-postgres")] - InnerConnection::Postgres(c) => { - let query = crate::driver::sqlx_postgres::sqlx_query(stmt); - crate::metric::metric_ok!(_metric_callback, stmt, { - Box::pin( - c.fetch(query) - .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err), - ) - as Pin>>> - }) - } - #[cfg(feature = "sqlx-sqlite")] - InnerConnection::Sqlite(c) => { - let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); - crate::metric::metric_ok!(_metric_callback, stmt, { - Box::pin( - c.fetch(query) - .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err), - ) - as Pin>>> - }) - } - #[cfg(feature = "mock")] - InnerConnection::Mock(c) => c.fetch(stmt), - } - }) - }, + stream_builder, } .build() .await diff --git a/src/database/transaction.rs b/src/database/transaction.rs index e1f85536c..c8690cb04 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -354,10 +354,11 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { fn stream( &'a self, stmt: Statement, - ) -> Pin> + 'a>> { + ) -> Pin> + 'a + Send>> { Box::pin(async move { + let conn = self.conn.lock().await; Ok(crate::TransactionStream::build( - self.conn.lock().await, + conn, stmt, self.metric_callback.clone(), ) diff --git a/src/driver/mock.rs b/src/driver/mock.rs index cfdd22c57..8a163e91d 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -148,7 +148,7 @@ impl MockDatabaseConnection { pub fn fetch( &self, statement: &Statement, - ) -> Pin>>> { + ) -> Pin> + Send>> { match self.query_all(statement.clone()) { Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(Ok))), Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())), diff --git a/src/executor/select.rs b/src/executor/select.rs index 20067aceb..6a8bd3fa1 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -273,9 +273,9 @@ where pub async fn stream<'a: 'b, 'b, C>( self, db: &'a C, - ) -> Result> + 'b, DbErr> + ) -> Result> + 'b + Send, DbErr> where - C: ConnectionTrait<'a>, + C: ConnectionTrait<'a> + Send, { self.into_model().stream(db).await } @@ -329,7 +329,7 @@ where db: &'a C, ) -> Result), DbErr>> + 'b, DbErr> where - C: ConnectionTrait<'a>, + C: ConnectionTrait<'a> + Send, { self.into_model().stream(db).await } @@ -373,9 +373,9 @@ where pub async fn stream<'a: 'b, 'b, C>( self, db: &'a C, - ) -> Result), DbErr>> + 'b, DbErr> + ) -> Result), DbErr>> + 'b + Send, DbErr> where - C: ConnectionTrait<'a>, + C: ConnectionTrait<'a> + Send, { self.into_model().stream(db).await } @@ -452,10 +452,11 @@ where pub async fn stream<'a: 'b, 'b, C>( self, db: &'a C, - ) -> Result> + 'b>>, DbErr> + ) -> Result> + 'b + Send>>, DbErr> where - C: ConnectionTrait<'a>, + C: ConnectionTrait<'a> + Send, S: 'b, + S::Item: Send, { self.into_selector_raw(db).stream(db).await } @@ -737,10 +738,11 @@ where pub async fn stream<'a: 'b, 'b, C>( self, db: &'a C, - ) -> Result> + 'b>>, DbErr> + ) -> Result> + 'b + Send>>, DbErr> where - C: ConnectionTrait<'a>, + C: ConnectionTrait<'a> + Send, S: 'b, + S::Item: Send, { let stream = db.stream(self.stmt).await?; Ok(Box::pin(stream.and_then(|row| { From 949e3115f5d38eb8247dfa0bd714a0c49efc3cb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Wed, 19 Jan 2022 23:54:16 +0100 Subject: [PATCH 2/2] Make TransactionStream::build sync --- Cargo.toml | 2 +- src/database/stream/transaction.rs | 42 +++++++++--------------------- src/database/transaction.rs | 3 +-- 3 files changed, 15 insertions(+), 32 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0863d5eb5..0665ae23e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } sqlx = { version = "^0.5", optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true } -ouroboros = { git = "https://github.com/sebpuetz/ouroboros", branch = "send-builder-functions" } +ouroboros = "0.14" url = "^2.2" once_cell = "1.8" diff --git a/src/database/stream/transaction.rs b/src/database/stream/transaction.rs index d198f54f4..ad1c81818 100644 --- a/src/database/stream/transaction.rs +++ b/src/database/stream/transaction.rs @@ -2,7 +2,7 @@ use std::{ops::DerefMut, pin::Pin, task::Poll}; -use futures::{FutureExt, Stream}; +use futures::Stream; #[cfg(feature = "sqlx-dep")] use futures::TryStreamExt; @@ -33,16 +33,18 @@ impl<'a> std::fmt::Debug for TransactionStream<'a> { } } -type PinStream<'s> = Pin> + Send + 's>>; - impl<'a> TransactionStream<'a> { - fn stream_builder<'s>( - conn: &'s mut MutexGuard<'a, InnerConnection>, - stmt: &'s Statement, - _metric_callback: &'s Option, - ) -> Pin> + 's + Send>> { - async move { - match conn.deref_mut() { + #[instrument(level = "trace", skip(metric_callback))] + pub(crate) fn build( + conn: MutexGuard<'a, InnerConnection>, + stmt: Statement, + metric_callback: Option, + ) -> TransactionStream<'a> { + TransactionStreamBuilder { + stmt, + conn, + metric_callback, + stream_builder: |conn, stmt, _metric_callback| match conn.deref_mut() { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(c) => { let query = crate::driver::sqlx_mysql::sqlx_query(stmt); @@ -81,27 +83,9 @@ impl<'a> TransactionStream<'a> { } #[cfg(feature = "mock")] InnerConnection::Mock(c) => c.fetch(stmt), - } - } - .boxed() - } - - #[instrument(level = "trace", skip(metric_callback))] - pub(crate) async fn build( - conn: MutexGuard<'a, InnerConnection>, - stmt: Statement, - metric_callback: Option, - ) -> TransactionStream<'a> { - let stream_builder = Self::stream_builder; - - TransactionStreamAsyncBuilder { - stmt, - conn, - metric_callback, - stream_builder, + }, } .build() - .await } } diff --git a/src/database/transaction.rs b/src/database/transaction.rs index c8690cb04..80db1c376 100644 --- a/src/database/transaction.rs +++ b/src/database/transaction.rs @@ -361,8 +361,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction { conn, stmt, self.metric_callback.clone(), - ) - .await) + )) }) }