Skip to content

Commit

Permalink
fix(core): avoid excess wakeups in try_stream!()
Browse files Browse the repository at this point in the history
fixes #2834
  • Loading branch information
abonander committed Nov 2, 2023
1 parent 82fadce commit 87f671d
Showing 1 changed file with 86 additions and 25 deletions.
111 changes: 86 additions & 25 deletions sqlx-core/src/ext/async_stream.rs
Original file line number Diff line number Diff line change
@@ -1,70 +1,131 @@
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};

use futures_channel::mpsc;
use futures_core::future::BoxFuture;
use futures_core::stream::Stream;
use futures_util::{pin_mut, FutureExt, SinkExt};
use futures_util::FutureExt;

use crate::error::Error;

pub struct TryAsyncStream<'a, T> {
receiver: mpsc::Receiver<Result<T, Error>>,
future: BoxFuture<'a, Result<(), Error>>,
yielder: Yielder<T>,
future: BoxFuture<'a, ()>,
}

impl<'a, T> TryAsyncStream<'a, T> {
pub fn new<F, Fut>(f: F) -> Self
where
F: FnOnce(mpsc::Sender<Result<T, Error>>) -> Fut + Send,
F: FnOnce(Yielder<T>) -> Fut + Send,
Fut: 'a + Future<Output = Result<(), Error>> + Send,
T: 'a + Send,
{
let (mut sender, receiver) = mpsc::channel(0);
let yielder = Yielder::new();

let future = f(sender.clone());
let future = f(yielder.duplicate());
let future = async move {
if let Err(error) = future.await {
let _ = sender.send(Err(error)).await;
yielder.error(error).await;

Check failure on line 29 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Check (async-std, rustls)

`()` is not a future

Check failure on line 29 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Unit Test (tokio, rustls)

`()` is not a future

Check failure on line 29 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / CLI Unit Test

`()` is not a future

Check failure on line 29 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Build SQLx CLI

`()` is not a future

Check failure on line 29 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Unit Test (tokio, native-tls)

`()` is not a future

Check failure on line 29 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / CLI Binaries (ubuntu-latest)

`()` is not a future
}

Ok(())
()
}
.fuse()
.boxed();

Self { future, receiver }
Self { future, yielder }
}
}

impl<'a, T> Stream for TryAsyncStream<'a, T> {
type Item = Result<T, Error>;
pub struct Yielder<T> {
value: Arc<Mutex<Option<Result<T, Error>>>>,
}

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let future = &mut self.future;
pin_mut!(future);
impl<T> Yielder<T> {
fn new() -> Self {
Yielder {
value: Arc::new(Mutex::new(None)),
}
}

fn duplicate(&self) -> Self {
Yielder {
value: self.value.clone(),
}
}

pub async fn r#yield(&self, val: T) {
let replaced = self
.value
.lock()
.expect("BUG: panicked while holding a lock")
.replace(Ok(val));

debug_assert!(
replaced.is_none(),
"BUG: previously yielded value not taken"
);

let mut yielded = false;

// Allows control flow to escape the generating future and return to the stream.
futures_util::future::poll_fn(|_cx| {
if !yielded {
yielded = true;
Poll::Pending
} else {
Poll::Ready(())
}
})
.await
}

// the future is fused so its safe to call forever
// the future advances our "stream"
// the future should be polled in tandem with the stream receiver
let _ = future.poll(cx);
fn error(&self, err: Error) {
let replaced = self
.value
.lock()
.expect("BUG: panicked while holding a lock")
.replace(Err(err));

let receiver = &mut self.receiver;
pin_mut!(receiver);
debug_assert!(
replaced.is_none(),
"BUG: previously yielded value not taken"
);
}

fn take(&self) -> Option<Result<T, Error>> {
self.value
.lock()
.expect("BUG: panicked while holding a lock")
.take()
}
}

// then we check to see if we have anything to return
receiver.poll_next(cx)
impl<'a, T> Stream for TryAsyncStream<'a, T> {
type Item = Result<T, Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.future.poll_unpin(cx) {
Poll::Ready(()) => {
// Future returned without yielding another value.
Poll::Ready(None)
}
Poll::Pending => self
.yielder
.take()
.map_or(Poll::Pending, |val| Poll::Ready(Some(val))),
}
}
}

#[macro_export]
macro_rules! try_stream {
($($block:tt)*) => {
crate::ext::async_stream::TryAsyncStream::new(move |mut sender| async move {
crate::ext::async_stream::TryAsyncStream::new(move |yielder| async move {
macro_rules! r#yield {
($v:expr) => {{
let _ = futures_util::sink::SinkExt::send(&mut sender, Ok($v)).await;
yielder.r#yield(v).await;

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Check (async-std, rustls)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Check (async-std, rustls)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Check (async-std, native-tls)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Check (async-std, native-tls)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Unit Test (tokio, rustls)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Unit Test (tokio, rustls)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Check (tokio, native-tls)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Check (tokio, native-tls)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Check (async-std, none)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Check (async-std, none)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / CLI Unit Test

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / CLI Unit Test

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Build SQLx CLI

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Build SQLx CLI

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Unit Test (tokio, native-tls)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Unit Test (tokio, native-tls)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Unit Test (tokio, none)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / Unit Test (tokio, none)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / CLI Binaries (ubuntu-latest)

cannot find value `v` in this scope

Check failure on line 128 in sqlx-core/src/ext/async_stream.rs

View workflow job for this annotation

GitHub Actions / CLI Binaries (ubuntu-latest)

cannot find value `v` in this scope
}}
}

Expand Down

0 comments on commit 87f671d

Please sign in to comment.