diff --git a/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs b/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs index 0606e7fbdc799..a00d328c843aa 100644 --- a/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs +++ b/crates/bevy_ecs/src/schedule/executor/multi_threaded.rs @@ -1,4 +1,7 @@ -use std::sync::Arc; +use std::{ + any::Any, + sync::{Arc, Mutex}, +}; use bevy_tasks::{ComputeTaskPool, Scope, TaskPool, ThreadExecutor}; use bevy_utils::default; @@ -63,12 +66,18 @@ struct SystemTaskMetadata { is_exclusive: bool, } +/// The result of running a system that is sent across a channel. +struct SystemResult { + system_index: usize, + success: bool, +} + /// Runs the schedule using a thread pool. Non-conflicting systems can run in parallel. pub struct MultiThreadedExecutor { /// Sends system completion events. - sender: Sender, + sender: Sender, /// Receives system completion events. - receiver: Receiver, + receiver: Receiver, /// Metadata for scheduling and running system tasks. system_task_metadata: Vec, /// Union of the accesses of all currently running systems. @@ -77,6 +86,8 @@ pub struct MultiThreadedExecutor { local_thread_running: bool, /// Returns `true` if an exclusive system is running. exclusive_running: bool, + /// The number of systems expected to run. + num_systems: usize, /// The number of systems that are running. num_running_systems: usize, /// The number of systems that have completed. @@ -99,6 +110,10 @@ pub struct MultiThreadedExecutor { unapplied_systems: FixedBitSet, /// Setting when true applies system buffers after all systems have run apply_final_buffers: bool, + /// When set, tells the executor that a thread has panicked. + panic_payload: Arc>>>, + /// When set, stops the executor from running any more systems. + stop_spawning: bool, } impl Default for MultiThreadedExecutor { @@ -148,8 +163,8 @@ impl SystemExecutor for MultiThreadedExecutor { fn run(&mut self, schedule: &mut SystemSchedule, world: &mut World) { // reset counts - let num_systems = schedule.systems.len(); - if num_systems == 0 { + self.num_systems = schedule.systems.len(); + if self.num_systems == 0 { return; } self.num_running_systems = 0; @@ -182,7 +197,7 @@ impl SystemExecutor for MultiThreadedExecutor { // the executor itself is a `Send` future so that it can run // alongside systems that claim the local thread let executor = async { - while self.num_completed_systems < num_systems { + while self.num_completed_systems < self.num_systems { // SAFETY: self.ready_systems does not contain running systems unsafe { self.spawn_system_tasks(scope, systems, &mut conditions, world); @@ -190,15 +205,14 @@ impl SystemExecutor for MultiThreadedExecutor { if self.num_running_systems > 0 { // wait for systems to complete - let index = - self.receiver.recv().await.expect( - "A system has panicked so the executor cannot continue.", - ); - - self.finish_system_and_signal_dependents(index); + if let Ok(result) = self.receiver.recv().await { + self.finish_system_and_handle_dependents(result); + } else { + panic!("Channel closed unexpectedly!"); + } - while let Ok(index) = self.receiver.try_recv() { - self.finish_system_and_signal_dependents(index); + while let Ok(result) = self.receiver.try_recv() { + self.finish_system_and_handle_dependents(result); } self.rebuild_active_access(); @@ -217,11 +231,21 @@ impl SystemExecutor for MultiThreadedExecutor { if self.apply_final_buffers { // Do one final apply buffers after all systems have completed // Commands should be applied while on the scope's thread, not the executor's thread - apply_system_buffers(&self.unapplied_systems, systems, world.get_mut()); + let res = apply_system_buffers(&self.unapplied_systems, systems, world.get_mut()); + if let Err(payload) = res { + let mut panic_payload = self.panic_payload.lock().unwrap(); + *panic_payload = Some(payload); + } self.unapplied_systems.clear(); debug_assert!(self.unapplied_systems.is_clear()); } + // check to see if there was a panic + let mut payload = self.panic_payload.lock().unwrap(); + if let Some(payload) = payload.take() { + std::panic::resume_unwind(payload); + } + debug_assert!(self.ready_systems.is_clear()); debug_assert!(self.running_systems.is_clear()); self.active_access.clear(); @@ -238,6 +262,7 @@ impl MultiThreadedExecutor { sender, receiver, system_task_metadata: Vec::new(), + num_systems: 0, num_running_systems: 0, num_completed_systems: 0, num_dependencies_remaining: Vec::new(), @@ -252,6 +277,8 @@ impl MultiThreadedExecutor { completed_systems: FixedBitSet::new(), unapplied_systems: FixedBitSet::new(), apply_final_buffers: true, + panic_payload: Arc::new(Mutex::new(None)), + stop_spawning: false, } } @@ -438,6 +465,7 @@ impl MultiThreadedExecutor { let system_span = info_span!("system", name = &*system.name()); let sender = self.sender.clone(); + let panic_payload = self.panic_payload.clone(); let task = async move { #[cfg(feature = "trace")] let system_guard = system_span.enter(); @@ -447,14 +475,20 @@ impl MultiThreadedExecutor { })); #[cfg(feature = "trace")] drop(system_guard); - if res.is_err() { - // close the channel to propagate the error to the - // multithreaded executor - sender.close(); - } else { - sender - .try_send(system_index) - .unwrap_or_else(|error| unreachable!("{}", error)); + // tell the executor that the system finished + sender + .try_send(SystemResult { + system_index, + success: res.is_ok(), + }) + .unwrap_or_else(|error| unreachable!("{}", error)); + if let Err(payload) = res { + eprintln!("Encountered a panic in system `{}`!", &*system.name()); + // set the payload to propagate the error + { + let mut panic_payload = panic_payload.lock().unwrap(); + *panic_payload = Some(payload); + } } }; @@ -491,6 +525,7 @@ impl MultiThreadedExecutor { let system_span = info_span!("system", name = &*system.name()); let sender = self.sender.clone(); + let panic_payload = self.panic_payload.clone(); if is_apply_system_buffers(system) { // TODO: avoid allocation let unapplied_systems = self.unapplied_systems.clone(); @@ -498,19 +533,20 @@ impl MultiThreadedExecutor { let task = async move { #[cfg(feature = "trace")] let system_guard = system_span.enter(); - let res = std::panic::catch_unwind(AssertUnwindSafe(|| { - apply_system_buffers(&unapplied_systems, systems, world); - })); + let res = apply_system_buffers(&unapplied_systems, systems, world); #[cfg(feature = "trace")] drop(system_guard); - if res.is_err() { - // close the channel to propagate the error to the - // multithreaded executor - sender.close(); - } else { - sender - .try_send(system_index) - .unwrap_or_else(|error| unreachable!("{}", error)); + // tell the executor that the system finished + sender + .try_send(SystemResult { + system_index, + success: res.is_ok(), + }) + .unwrap_or_else(|error| unreachable!("{}", error)); + if let Err(payload) = res { + // set the payload to propagate the error + let mut panic_payload = panic_payload.lock().unwrap(); + *panic_payload = Some(payload); } }; @@ -526,14 +562,21 @@ impl MultiThreadedExecutor { })); #[cfg(feature = "trace")] drop(system_guard); - if res.is_err() { - // close the channel to propagate the error to the - // multithreaded executor - sender.close(); - } else { - sender - .try_send(system_index) - .unwrap_or_else(|error| unreachable!("{}", error)); + // tell the executor that the system finished + sender + .try_send(SystemResult { + system_index, + success: res.is_ok(), + }) + .unwrap_or_else(|error| unreachable!("{}", error)); + if let Err(payload) = res { + eprintln!( + "Encountered a panic in exclusive system `{}`!", + &*system.name() + ); + // set the payload to propagate the error + let mut panic_payload = panic_payload.lock().unwrap(); + *panic_payload = Some(payload); } }; @@ -546,7 +589,12 @@ impl MultiThreadedExecutor { self.local_thread_running = true; } - fn finish_system_and_signal_dependents(&mut self, system_index: usize) { + fn finish_system_and_handle_dependents(&mut self, result: SystemResult) { + let SystemResult { + system_index, + success, + } = result; + if self.system_task_metadata[system_index].is_exclusive { self.exclusive_running = false; } @@ -561,7 +609,12 @@ impl MultiThreadedExecutor { self.running_systems.set(system_index, false); self.completed_systems.insert(system_index); self.unapplied_systems.insert(system_index); + self.signal_dependents(system_index); + + if !success { + self.stop_spawning_systems(); + } } fn skip_system_and_signal_dependents(&mut self, system_index: usize) { @@ -581,6 +634,13 @@ impl MultiThreadedExecutor { } } + fn stop_spawning_systems(&mut self) { + if !self.stop_spawning { + self.num_systems = self.num_completed_systems + self.num_running_systems; + self.stop_spawning = true; + } + } + fn rebuild_active_access(&mut self) { self.active_access.clear(); for index in self.running_systems.ones() { @@ -595,12 +655,22 @@ fn apply_system_buffers( unapplied_systems: &FixedBitSet, systems: &[SyncUnsafeCell], world: &mut World, -) { +) -> Result<(), Box> { for system_index in unapplied_systems.ones() { // SAFETY: none of these systems are running, no other references exist let system = unsafe { &mut *systems[system_index].get() }; - system.apply_buffers(world); + let res = std::panic::catch_unwind(AssertUnwindSafe(|| { + system.apply_buffers(world); + })); + if let Err(payload) = res { + eprintln!( + "Encountered a panic when applying buffers for system `{}`!", + &*system.name() + ); + return Err(payload); + } } + Ok(()) } fn evaluate_and_fold_conditions(conditions: &mut [BoxedCondition], world: &World) -> bool { diff --git a/crates/bevy_ecs/src/schedule/executor/simple.rs b/crates/bevy_ecs/src/schedule/executor/simple.rs index 2aaf1777ed033..b72f2dbdd694c 100644 --- a/crates/bevy_ecs/src/schedule/executor/simple.rs +++ b/crates/bevy_ecs/src/schedule/executor/simple.rs @@ -1,6 +1,7 @@ #[cfg(feature = "trace")] use bevy_utils::tracing::info_span; use fixedbitset::FixedBitSet; +use std::panic::AssertUnwindSafe; use crate::{ schedule::{BoxedCondition, ExecutorKind, SystemExecutor, SystemSchedule}, @@ -78,9 +79,15 @@ impl SystemExecutor for SimpleExecutor { let system = &mut schedule.systems[system_index]; #[cfg(feature = "trace")] let system_span = info_span!("system", name = &*name).entered(); - system.run((), world); + let res = std::panic::catch_unwind(AssertUnwindSafe(|| { + system.run((), world); + })); #[cfg(feature = "trace")] system_span.exit(); + if let Err(payload) = res { + eprintln!("Encountered a panic in system `{}`!", &*system.name()); + std::panic::resume_unwind(payload); + } system.apply_buffers(world); } diff --git a/crates/bevy_ecs/src/schedule/executor/single_threaded.rs b/crates/bevy_ecs/src/schedule/executor/single_threaded.rs index 61b4520684e68..0bd3c8dca6fc2 100644 --- a/crates/bevy_ecs/src/schedule/executor/single_threaded.rs +++ b/crates/bevy_ecs/src/schedule/executor/single_threaded.rs @@ -1,6 +1,7 @@ #[cfg(feature = "trace")] use bevy_utils::tracing::info_span; use fixedbitset::FixedBitSet; +use std::panic::AssertUnwindSafe; use crate::{ schedule::{ @@ -95,9 +96,15 @@ impl SystemExecutor for SingleThreadedExecutor { } else { #[cfg(feature = "trace")] let system_span = info_span!("system", name = &*name).entered(); - system.run((), world); + let res = std::panic::catch_unwind(AssertUnwindSafe(|| { + system.run((), world); + })); #[cfg(feature = "trace")] system_span.exit(); + if let Err(payload) = res { + eprintln!("Encountered a panic in system `{}`!", &*system.name()); + std::panic::resume_unwind(payload); + } self.unapplied_systems.insert(system_index); } } diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs index 69bb2cea27892..70711efe08877 100644 --- a/crates/bevy_tasks/src/task_pool.rs +++ b/crates/bevy_tasks/src/task_pool.rs @@ -350,7 +350,9 @@ impl TaskPool { let scope_executor: &'env ThreadExecutor<'env> = unsafe { mem::transmute(scope_executor) }; let spawned: ConcurrentQueue> = ConcurrentQueue::unbounded(); // shadow the variable so that the owned value cannot be used for the rest of the function - let spawned: &'env ConcurrentQueue> = unsafe { mem::transmute(&spawned) }; + let spawned: &'env ConcurrentQueue< + FallibleTask>>, + > = unsafe { mem::transmute(&spawned) }; let scope = Scope { executor, @@ -373,7 +375,14 @@ impl TaskPool { let get_results = async { let mut results = Vec::with_capacity(spawned.len()); while let Ok(task) = spawned.pop() { - results.push(task.await.unwrap()); + if let Some(res) = task.await { + match res { + Ok(res) => results.push(res), + Err(payload) => std::panic::resume_unwind(payload), + } + } else { + panic!("Failed to catch panic!"); + } } results }; @@ -571,7 +580,7 @@ pub struct Scope<'scope, 'env: 'scope, T> { executor: &'scope async_executor::Executor<'scope>, external_executor: &'scope ThreadExecutor<'scope>, scope_executor: &'scope ThreadExecutor<'scope>, - spawned: &'scope ConcurrentQueue>, + spawned: &'scope ConcurrentQueue>>>, // make `Scope` invariant over 'scope and 'env scope: PhantomData<&'scope mut &'scope ()>, env: PhantomData<&'env mut &'env ()>, @@ -587,7 +596,10 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> { /// /// For more information, see [`TaskPool::scope`]. pub fn spawn + 'scope + Send>(&self, f: Fut) { - let task = self.executor.spawn(f).fallible(); + let task = self + .executor + .spawn(AssertUnwindSafe(f).catch_unwind()) + .fallible(); // ConcurrentQueue only errors when closed or full, but we never // close and use an unbounded queue, so it is safe to unwrap self.spawned.push(task).unwrap(); @@ -600,7 +612,10 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> { /// /// For more information, see [`TaskPool::scope`]. pub fn spawn_on_scope + 'scope + Send>(&self, f: Fut) { - let task = self.scope_executor.spawn(f).fallible(); + let task = self + .scope_executor + .spawn(AssertUnwindSafe(f).catch_unwind()) + .fallible(); // ConcurrentQueue only errors when closed or full, but we never // close and use an unbounded queue, so it is safe to unwrap self.spawned.push(task).unwrap(); @@ -614,7 +629,10 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> { /// /// For more information, see [`TaskPool::scope`]. pub fn spawn_on_external + 'scope + Send>(&self, f: Fut) { - let task = self.external_executor.spawn(f).fallible(); + let task = self + .external_executor + .spawn(AssertUnwindSafe(f).catch_unwind()) + .fallible(); // ConcurrentQueue only errors when closed or full, but we never // close and use an unbounded queue, so it is safe to unwrap self.spawned.push(task).unwrap();