Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Send bound to streams. #471

Merged
merged 2 commits into from
Mar 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/send_stream_example/.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
HOST=127.0.0.1
PORT=8000
DATABASE_URL="postgres://postgres:password@localhost/axum_exmaple"
23 changes: 23 additions & 0 deletions examples/send_stream_example/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[package]
name = "sea-orm-axum-example"
version = "0.1.0"
authors = ["Sebastian Pütz <seb.puetz@gmail.com>"]
edition = "2021"
publish = false

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[workspace]

[dependencies]
tokio = { version = "1.14", features = ["full"] }
anyhow = "1"
dotenv = "0.15"
futures-util = "0.3"
serde = "1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

[dependencies.sea-orm]
path = "../../" # remove this line in your own project
# version = "^0.5.0"
features = ["macros", "mock", "sqlx-all", "runtime-tokio-rustls", "debug-print"]
default-features = false
1 change: 1 addition & 0 deletions examples/send_stream_example/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Demonstrator for using streaming queries with `tokio::spawn` or in contexts that require `Send` futures.
29 changes: 29 additions & 0 deletions examples/send_stream_example/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
mod post;
mod setup;

use futures_util::StreamExt;
use post::Entity as Post;
use sea_orm::{prelude::*, Database};
use std::env;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
env::set_var("RUST_LOG", "debug");
tracing_subscriber::fmt::init();

dotenv::dotenv().ok();
let db_url = env::var("DATABASE_URL").expect("DATABASE_URL is not set in .env file");
let db = Database::connect(db_url)
.await
.expect("Database connection failed");
let _ = setup::create_post_table(&db);
tokio::task::spawn(async move {
let mut stream = Post::find().stream(&db).await.unwrap();
while let Some(item) = stream.next().await {
let item = item?;
println!("got something: {}", item.text);
}
Ok::<(), anyhow::Error>(())
})
.await?
}
26 changes: 26 additions & 0 deletions examples/send_stream_example/src/post.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//! SeaORM Entity. Generated by sea-orm-codegen 0.3.2

use sea_orm::entity::prelude::*;
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)]
#[sea_orm(table_name = "posts")]
pub struct Model {
#[sea_orm(primary_key)]
#[serde(skip_deserializing)]
pub id: i32,
pub title: String,
#[sea_orm(column_type = "Text")]
pub text: String,
}

#[derive(Copy, Clone, Debug, EnumIter)]
pub enum Relation {}

impl RelationTrait for Relation {
fn def(&self) -> RelationDef {
panic!("No RelationDef")
}
}

impl ActiveModelBehavior for ActiveModel {}
33 changes: 33 additions & 0 deletions examples/send_stream_example/src/setup.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use sea_orm::sea_query::{ColumnDef, TableCreateStatement};
use sea_orm::{error::*, sea_query, ConnectionTrait, DbConn, ExecResult};

async fn create_table(db: &DbConn, stmt: &TableCreateStatement) -> Result<ExecResult, DbErr> {
let builder = db.get_database_backend();
db.execute(builder.build(stmt)).await
}

pub async fn create_post_table(db: &DbConn) -> Result<ExecResult, DbErr> {
let stmt = sea_query::Table::create()
.table(super::post::Entity)
.if_not_exists()
.col(
ColumnDef::new(super::post::Column::Id)
.integer()
.not_null()
.auto_increment()
.primary_key(),
)
.col(
ColumnDef::new(super::post::Column::Title)
.string()
.not_null(),
)
.col(
ColumnDef::new(super::post::Column::Text)
.string()
.not_null(),
)
.to_owned();

create_table(db, &stmt).await
}
4 changes: 2 additions & 2 deletions src/database/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{future::Future, pin::Pin};
#[async_trait::async_trait]
pub trait ConnectionTrait<'a>: Sync {
/// Create a stream for the [QueryResult]
type Stream: Stream<Item = Result<QueryResult, DbErr>>;
type Stream: Stream<Item = Result<QueryResult, DbErr>> + Send;

/// Fetch the database backend as specified in [DbBackend].
/// This depends on feature flags enabled.
Expand All @@ -28,7 +28,7 @@ pub trait ConnectionTrait<'a>: Sync {
fn stream(
&'a self,
stmt: Statement,
) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a>>;
) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a + Send>>;

/// Execute SQL `BEGIN` transaction.
/// Returns a Transaction that can be committed or rolled back
Expand Down
2 changes: 1 addition & 1 deletion src/database/db_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
fn stream(
&'a self,
stmt: Statement,
) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a>> {
) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a + Send>> {
Box::pin(async move {
Ok(match self {
#[cfg(feature = "sqlx-mysql")]
Expand Down
2 changes: 1 addition & 1 deletion src/database/stream/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct QueryStream {
metric_callback: Option<crate::metric::Callback>,
#[borrows(mut conn, stmt, metric_callback)]
#[not_covariant]
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + 'this>>,
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send + 'this>>,
}

#[cfg(feature = "sqlx-mysql")]
Expand Down
89 changes: 42 additions & 47 deletions src/database/stream/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct TransactionStream<'a> {
metric_callback: Option<crate::metric::Callback>,
#[borrows(mut conn, stmt, metric_callback)]
#[not_covariant]
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + 'this>>,
stream: Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + 'this + Send>>,
}

impl<'a> std::fmt::Debug for TransactionStream<'a> {
Expand All @@ -35,62 +35,57 @@ impl<'a> std::fmt::Debug for TransactionStream<'a> {

impl<'a> TransactionStream<'a> {
#[instrument(level = "trace", skip(metric_callback))]
pub(crate) async fn build(
pub(crate) fn build(
conn: MutexGuard<'a, InnerConnection>,
stmt: Statement,
metric_callback: Option<crate::metric::Callback>,
) -> TransactionStream<'a> {
TransactionStreamAsyncBuilder {
TransactionStreamBuilder {
stmt,
conn,
metric_callback,
stream_builder: |conn, stmt, _metric_callback| {
Box::pin(async move {
match conn.deref_mut() {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(c) => {
let query = crate::driver::sqlx_mysql::sqlx_query(stmt);
crate::metric::metric_ok!(_metric_callback, stmt, {
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>>
})
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(c) => {
let query = crate::driver::sqlx_postgres::sqlx_query(stmt);
crate::metric::metric_ok!(_metric_callback, stmt, {
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>>
})
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(c) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(stmt);
crate::metric::metric_ok!(_metric_callback, stmt, {
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>>
})
}
#[cfg(feature = "mock")]
InnerConnection::Mock(c) => c.fetch(stmt),
}
})
stream_builder: |conn, stmt, _metric_callback| match conn.deref_mut() {
#[cfg(feature = "sqlx-mysql")]
InnerConnection::MySql(c) => {
let query = crate::driver::sqlx_mysql::sqlx_query(stmt);
crate::metric::metric_ok!(_metric_callback, stmt, {
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>>
})
}
#[cfg(feature = "sqlx-postgres")]
InnerConnection::Postgres(c) => {
let query = crate::driver::sqlx_postgres::sqlx_query(stmt);
crate::metric::metric_ok!(_metric_callback, stmt, {
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>>
})
}
#[cfg(feature = "sqlx-sqlite")]
InnerConnection::Sqlite(c) => {
let query = crate::driver::sqlx_sqlite::sqlx_query(stmt);
crate::metric::metric_ok!(_metric_callback, stmt, {
Box::pin(
c.fetch(query)
.map_ok(Into::into)
.map_err(crate::sqlx_error_to_query_err),
)
as Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>>
})
}
#[cfg(feature = "mock")]
InnerConnection::Mock(c) => c.fetch(stmt),
},
}
.build()
.await
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/database/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,14 +354,14 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
fn stream(
&'a self,
stmt: Statement,
) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a>> {
) -> Pin<Box<dyn Future<Output = Result<Self::Stream, DbErr>> + 'a + Send>> {
Box::pin(async move {
let conn = self.conn.lock().await;
Ok(crate::TransactionStream::build(
self.conn.lock().await,
conn,
stmt,
self.metric_callback.clone(),
)
.await)
))
})
}

Expand Down
2 changes: 1 addition & 1 deletion src/driver/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ impl MockDatabaseConnection {
pub fn fetch(
&self,
statement: &Statement,
) -> Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>>>> {
) -> Pin<Box<dyn Stream<Item = Result<QueryResult, DbErr>> + Send>> {
match self.query_all(statement.clone()) {
Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(Ok))),
Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())),
Expand Down
20 changes: 11 additions & 9 deletions src/executor/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,9 @@ where
pub async fn stream<'a: 'b, 'b, C>(
self,
db: &'a C,
) -> Result<impl Stream<Item = Result<E::Model, DbErr>> + 'b, DbErr>
) -> Result<impl Stream<Item = Result<E::Model, DbErr>> + 'b + Send, DbErr>
where
C: ConnectionTrait<'a>,
C: ConnectionTrait<'a> + Send,
{
self.into_model().stream(db).await
}
Expand Down Expand Up @@ -329,7 +329,7 @@ where
db: &'a C,
) -> Result<impl Stream<Item = Result<(E::Model, Option<F::Model>), DbErr>> + 'b, DbErr>
where
C: ConnectionTrait<'a>,
C: ConnectionTrait<'a> + Send,
{
self.into_model().stream(db).await
}
Expand Down Expand Up @@ -373,9 +373,9 @@ where
pub async fn stream<'a: 'b, 'b, C>(
self,
db: &'a C,
) -> Result<impl Stream<Item = Result<(E::Model, Option<F::Model>), DbErr>> + 'b, DbErr>
) -> Result<impl Stream<Item = Result<(E::Model, Option<F::Model>), DbErr>> + 'b + Send, DbErr>
where
C: ConnectionTrait<'a>,
C: ConnectionTrait<'a> + Send,
{
self.into_model().stream(db).await
}
Expand Down Expand Up @@ -452,10 +452,11 @@ where
pub async fn stream<'a: 'b, 'b, C>(
self,
db: &'a C,
) -> Result<Pin<Box<dyn Stream<Item = Result<S::Item, DbErr>> + 'b>>, DbErr>
) -> Result<Pin<Box<dyn Stream<Item = Result<S::Item, DbErr>> + 'b + Send>>, DbErr>
where
C: ConnectionTrait<'a>,
C: ConnectionTrait<'a> + Send,
S: 'b,
S::Item: Send,
{
self.into_selector_raw(db).stream(db).await
}
Expand Down Expand Up @@ -737,10 +738,11 @@ where
pub async fn stream<'a: 'b, 'b, C>(
self,
db: &'a C,
) -> Result<Pin<Box<dyn Stream<Item = Result<S::Item, DbErr>> + 'b>>, DbErr>
) -> Result<Pin<Box<dyn Stream<Item = Result<S::Item, DbErr>> + 'b + Send>>, DbErr>
where
C: ConnectionTrait<'a>,
C: ConnectionTrait<'a> + Send,
S: 'b,
S::Item: Send,
{
let stream = db.stream(self.stmt).await?;
Ok(Box::pin(stream.and_then(|row| {
Expand Down