diff --git a/bagua-core-internal/src/lib.rs b/bagua-core-internal/src/lib.rs index 82d964f..9790612 100644 --- a/bagua-core-internal/src/lib.rs +++ b/bagua-core-internal/src/lib.rs @@ -16,10 +16,12 @@ use crate::telemetry::{SCHEDULED_THREAD_POOL, TELEMETRY}; use cpp::cpp; use datatypes::{BaguaBucket, BaguaTensor}; use events::BaguaEventChannel; +use flume::RecvTimeoutError; use hashbrown::{HashMap, HashSet}; use std::collections::VecDeque; use std::fmt::Debug; use std::sync::Arc; +use std::time::Duration; use thiserror::Error; cpp! {{ @@ -120,6 +122,7 @@ pub struct BaguaCommBackend { channels: Arc, managed_ptrs: HashSet, comm_worker: std::thread::JoinHandle<()>, + comm_monitor: std::thread::JoinHandle<()>, } impl BaguaCommBackend { @@ -168,6 +171,10 @@ impl BaguaCommBackend { let channels = Arc::new(BaguaCommOpChannels::new(schedule_channel_cap)); let channels_clone = channels.clone(); + let (monitor_op_start_channel_sender, monitor_op_start_channel_receiver) = + flume::unbounded(); + let (monitor_op_finish_channel_sender, monitor_op_finish_channel_receiver) = + flume::unbounded(); BaguaCommBackend { ordered_buckets: Default::default(), @@ -190,6 +197,7 @@ impl BaguaCommBackend { "worker received scheduled communication operation {:?}", comm_op ); + monitor_op_start_channel_sender.send(comm_op.bucket.clone()); for op in &comm_op.ops { op.execute_background_communication( comm_op.bucket.clone(), @@ -199,6 +207,18 @@ impl BaguaCommBackend { tracing::debug!("comm op executed: {:?}", comm_op); comm_op.event_channel.finish(); tracing::debug!("comm op marked finished: {:?}", comm_op); + monitor_op_finish_channel_sender.send(()); + } + }), + comm_monitor: std::thread::spawn(move || loop { + let op_bucket = monitor_op_start_channel_receiver + .recv() + .expect("monitor cannot receive next comm op bucket"); + match monitor_op_finish_channel_receiver.recv_timeout(Duration::from_secs(300)) { + Ok(_) => {} + Err(_) => { + panic!("{:?} comm op has not finished for 5 min, panic", op_bucket); + } } }), } diff --git a/bagua-core-py/src/lib.rs b/bagua-core-py/src/lib.rs index de8a6d5..1d9e25a 100644 --- a/bagua-core-py/src/lib.rs +++ b/bagua-core-py/src/lib.rs @@ -331,6 +331,14 @@ fn bagua_core(_py: Python, m: &PyModule) -> PyResult<()> { .init(); color_eyre::install().unwrap(); + // panic the whole process when thread panics + let orig_hook = std::panic::take_hook(); + std::panic::set_hook(Box::new(move |panic_info| { + // invoke the default handler and exit the process + orig_hook(panic_info); + std::process::exit(1); + })); + m.add_class::()?; m.add_class::()?; m.add_class::()?;