diff --git a/sqlx-core/src/ext/async_stream.rs b/sqlx-core/src/ext/async_stream.rs index c08c0d9c74..b6a84f5ca6 100644 --- a/sqlx-core/src/ext/async_stream.rs +++ b/sqlx-core/src/ext/async_stream.rs @@ -1,70 +1,131 @@ use std::future::Future; use std::pin::Pin; +use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; -use futures_channel::mpsc; use futures_core::future::BoxFuture; use futures_core::stream::Stream; -use futures_util::{pin_mut, FutureExt, SinkExt}; +use futures_util::FutureExt; use crate::error::Error; pub struct TryAsyncStream<'a, T> { - receiver: mpsc::Receiver>, - future: BoxFuture<'a, Result<(), Error>>, + yielder: Yielder, + future: BoxFuture<'a, ()>, } impl<'a, T> TryAsyncStream<'a, T> { pub fn new(f: F) -> Self where - F: FnOnce(mpsc::Sender>) -> Fut + Send, + F: FnOnce(Yielder) -> Fut + Send, Fut: 'a + Future> + Send, T: 'a + Send, { - let (mut sender, receiver) = mpsc::channel(0); + let yielder = Yielder::new(); - let future = f(sender.clone()); + let future = f(yielder.duplicate()); let future = async move { if let Err(error) = future.await { - let _ = sender.send(Err(error)).await; + yielder.error(error).await; } - Ok(()) + () } .fuse() .boxed(); - Self { future, receiver } + Self { future, yielder } } } -impl<'a, T> Stream for TryAsyncStream<'a, T> { - type Item = Result; +pub struct Yielder { + value: Arc>>>, +} - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let future = &mut self.future; - pin_mut!(future); +impl Yielder { + fn new() -> Self { + Yielder { + value: Arc::new(Mutex::new(None)), + } + } + + fn duplicate(&self) -> Self { + Yielder { + value: self.value.clone(), + } + } + + pub async fn r#yield(&self, val: T) { + let replaced = self + .value + .lock() + .expect("BUG: panicked while holding a lock") + .replace(Ok(val)); + + debug_assert!( + replaced.is_none(), + "BUG: previously yielded value not taken" + ); + + let mut yielded = false; + + // Allows control flow to escape the generating future and return to the stream. + futures_util::future::poll_fn(|_cx| { + if !yielded { + yielded = true; + Poll::Pending + } else { + Poll::Ready(()) + } + }) + .await + } - // the future is fused so its safe to call forever - // the future advances our "stream" - // the future should be polled in tandem with the stream receiver - let _ = future.poll(cx); + fn error(&self, err: Error) { + let replaced = self + .value + .lock() + .expect("BUG: panicked while holding a lock") + .replace(Err(err)); - let receiver = &mut self.receiver; - pin_mut!(receiver); + debug_assert!( + replaced.is_none(), + "BUG: previously yielded value not taken" + ); + } + + fn take(&self) -> Option> { + self.value + .lock() + .expect("BUG: panicked while holding a lock") + .take() + } +} - // then we check to see if we have anything to return - receiver.poll_next(cx) +impl<'a, T> Stream for TryAsyncStream<'a, T> { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.future.poll_unpin(cx) { + Poll::Ready(()) => { + // Future returned without yielding another value. + Poll::Ready(None) + } + Poll::Pending => self + .yielder + .take() + .map_or(Poll::Pending, |val| Poll::Ready(Some(val))), + } } } #[macro_export] macro_rules! try_stream { ($($block:tt)*) => { - crate::ext::async_stream::TryAsyncStream::new(move |mut sender| async move { + crate::ext::async_stream::TryAsyncStream::new(move |yielder| async move { macro_rules! r#yield { ($v:expr) => {{ - let _ = futures_util::sink::SinkExt::send(&mut sender, Ok($v)).await; + yielder.r#yield(v).await; }} }