diff --git a/crates/bevy_tasks/Cargo.toml b/crates/bevy_tasks/Cargo.toml index 5a8ea9ff6b0ef..a088d49c4ce93 100644 --- a/crates/bevy_tasks/Cargo.toml +++ b/crates/bevy_tasks/Cargo.toml @@ -20,6 +20,8 @@ concurrent-queue = { version = "2.0.0", optional = true } [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen-futures = "0.4" +pin-project = "1" +futures-channel = "0.3" [dev-dependencies] web-time = { version = "1.1" } diff --git a/crates/bevy_tasks/src/lib.rs b/crates/bevy_tasks/src/lib.rs index 17cfb348ef2c5..3eb33c6603e70 100644 --- a/crates/bevy_tasks/src/lib.rs +++ b/crates/bevy_tasks/src/lib.rs @@ -8,7 +8,9 @@ mod slice; pub use slice::{ParallelSlice, ParallelSliceMut}; +#[cfg_attr(target_arch = "wasm32", path = "wasm_task.rs")] mod task; + pub use task::Task; #[cfg(all(not(target_arch = "wasm32"), feature = "multi_threaded"))] @@ -19,7 +21,7 @@ pub use task_pool::{Scope, TaskPool, TaskPoolBuilder}; #[cfg(any(target_arch = "wasm32", not(feature = "multi_threaded")))] mod single_threaded_task_pool; #[cfg(any(target_arch = "wasm32", not(feature = "multi_threaded")))] -pub use single_threaded_task_pool::{FakeTask, Scope, TaskPool, TaskPoolBuilder, ThreadExecutor}; +pub use single_threaded_task_pool::{Scope, TaskPool, TaskPoolBuilder, ThreadExecutor}; mod usages; #[cfg(not(target_arch = "wasm32"))] diff --git a/crates/bevy_tasks/src/single_threaded_task_pool.rs b/crates/bevy_tasks/src/single_threaded_task_pool.rs index de7a13891593d..cd6bbc79d63ad 100644 --- a/crates/bevy_tasks/src/single_threaded_task_pool.rs +++ b/crates/bevy_tasks/src/single_threaded_task_pool.rs @@ -1,6 +1,8 @@ use std::sync::Arc; use std::{cell::RefCell, future::Future, marker::PhantomData, mem, rc::Rc}; +use crate::Task; + thread_local! { static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = const { async_executor::LocalExecutor::new() }; } @@ -145,34 +147,33 @@ impl TaskPool { .collect() } - /// Spawns a static future onto the thread pool. The returned Task is a future. It can also be - /// cancelled and "detached" allowing it to continue running without having to be polled by the + /// Spawns a static future onto the thread pool. The returned Task is a future, which can be polled + /// to retrieve the output of the original future. Dropping the task will attempt to cancel it. + /// It can also be "detached", allowing it to continue running without having to be polled by the /// end-user. /// /// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should be used instead. - pub fn spawn(&self, future: impl Future + 'static) -> FakeTask + pub fn spawn(&self, future: impl Future + 'static) -> Task where T: 'static, { #[cfg(target_arch = "wasm32")] - wasm_bindgen_futures::spawn_local(async move { - future.await; - }); + return Task::wrap_future(future); #[cfg(not(target_arch = "wasm32"))] { LOCAL_EXECUTOR.with(|executor| { - let _task = executor.spawn(future); + let task = executor.spawn(future); // Loop until all tasks are done while executor.try_tick() {} - }); - } - FakeTask + Task::new(task) + }) + } } /// Spawns a static future on the JS event loop. This is exactly the same as [`TaskPool::spawn`]. - pub fn spawn_local(&self, future: impl Future + 'static) -> FakeTask + pub fn spawn_local(&self, future: impl Future + 'static) -> Task where T: 'static, { @@ -198,17 +199,6 @@ impl TaskPool { } } -/// An empty task used in single-threaded contexts. -/// -/// This does nothing and is therefore safe, and recommended, to ignore. -#[derive(Debug)] -pub struct FakeTask; - -impl FakeTask { - /// No op on the single threaded task pool - pub fn detach(self) {} -} - /// A `TaskPool` scope for running one or more non-`'static` futures. /// /// For more information, see [`TaskPool::scope`]. diff --git a/crates/bevy_tasks/src/wasm_task.rs b/crates/bevy_tasks/src/wasm_task.rs new file mode 100644 index 0000000000000..47c082516ad2b --- /dev/null +++ b/crates/bevy_tasks/src/wasm_task.rs @@ -0,0 +1,82 @@ +use std::{ + any::Any, + future::{Future, IntoFuture}, + panic::{AssertUnwindSafe, UnwindSafe}, + pin::Pin, + task::Poll, +}; + +use futures_channel::oneshot; + +/// Wraps an asynchronous task, a spawned future. +/// +/// Tasks are also futures themselves and yield the output of the spawned future. +#[derive(Debug)] +pub struct Task(oneshot::Receiver>); + +impl Task { + pub(crate) fn wrap_future(future: impl Future + 'static) -> Self { + let (sender, receiver) = oneshot::channel(); + wasm_bindgen_futures::spawn_local(async move { + // Catch any panics that occur when polling the future so they can + // be propagated back to the task handle. + let value = CatchUnwind(AssertUnwindSafe(future)).await; + let _ = sender.send(value); + }); + Self(receiver.into_future()) + } + + /// When building for Wasm, this method has no effect. + /// This is only included for feature parity with other platforms. + pub fn detach(self) {} + + /// Requests a task to be cancelled and returns a future that suspends until it completes. + /// Returns the output of the future if it has already completed. + /// + /// # Implementation + /// + /// When building for Wasm, it is not possible to cancel tasks, which means this is the same + /// as just awaiting the task. This method is only included for feature parity with other platforms. + pub async fn cancel(self) -> Option { + match self.0.await { + Ok(Ok(value)) => Some(value), + Err(_) => None, + Ok(Err(panic)) => { + // drop this to prevent the panic payload from resuming the panic on drop. + // this also leaks the box but I'm not sure how to avoid that + std::mem::forget(panic); + None + } + } + } +} + +impl Future for Task { + type Output = T; + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + match Pin::new(&mut self.0).poll(cx) { + Poll::Ready(Ok(Ok(value))) => Poll::Ready(value), + // NOTE: Propagating the panic here sorta has parity with the async_executor behavior. + // For those tasks, polling them after a panic returns a `None` which gets `unwrap`ed, so + // using `resume_unwind` here is essentially keeping the same behavior while adding more information. + Poll::Ready(Ok(Err(panic))) => std::panic::resume_unwind(panic), + Poll::Ready(Err(_)) => panic!("Polled a task after it was cancelled"), + Poll::Pending => Poll::Pending, + } + } +} + +type Panic = Box; + +#[pin_project::pin_project] +struct CatchUnwind(#[pin] F); + +impl Future for CatchUnwind { + type Output = Result; + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context) -> Poll { + std::panic::catch_unwind(AssertUnwindSafe(|| self.project().0.poll(cx)))?.map(Ok) + } +}