diff --git a/tokio/tests/rt_metrics.rs b/tokio/tests/rt_metrics.rs index dac8860b345..7adee28f917 100644 --- a/tokio/tests/rt_metrics.rs +++ b/tokio/tests/rt_metrics.rs @@ -2,7 +2,8 @@ #![warn(rust_2018_idioms)] #![cfg(all(feature = "full", not(target_os = "wasi"), target_has_atomic = "64"))] -use std::sync::{Arc, Barrier}; +use std::sync::mpsc; +use std::time::Duration; use tokio::runtime::Runtime; #[test] @@ -68,36 +69,51 @@ fn global_queue_depth_multi_thread() { let rt = threaded(); let metrics = rt.metrics(); - let barrier1 = Arc::new(Barrier::new(3)); - let barrier2 = Arc::new(Barrier::new(3)); - - // Spawn a task per runtime worker to block it. - for _ in 0..2 { - let barrier1 = barrier1.clone(); - let barrier2 = barrier2.clone(); - rt.spawn(async move { - barrier1.wait(); - barrier2.wait(); - }); - } - - barrier1.wait(); + for _ in 0..10 { + if let Ok(_blocking_tasks) = try_block_threaded(&rt) { + for i in 0..10 { + assert_eq!(i, metrics.global_queue_depth()); + rt.spawn(async {}); + } - let mut fail: Option = None; - for i in 0..10 { - let depth = metrics.global_queue_depth(); - if i != depth { - fail = Some(format!("{i} is not equal to {depth}")); - break; + return; } - rt.spawn(async {}); } - barrier2.wait(); + panic!("exhausted every try to block the runtime"); +} - if let Some(fail) = fail { - panic!("{fail}"); +fn try_block_threaded(rt: &Runtime) -> Result>, mpsc::RecvTimeoutError> { + let (tx, rx) = mpsc::channel(); + + let blocking_tasks = (0..rt.metrics().num_workers()) + .map(|_| { + let tx = tx.clone(); + let (task, barrier) = mpsc::channel(); + + // Spawn a task per runtime worker to block it. + rt.spawn(async move { + tx.send(()).unwrap(); + barrier.recv().ok(); + }); + + task + }) + .collect(); + + // Make sure the previously spawned tasks are blocking the runtime by + // receiving a message from each blocking task. + // + // If this times out we were unsuccessful in blocking the runtime and hit + // a deadlock instead (which might happen and is expected behaviour). + for _ in 0..rt.metrics().num_workers() { + rx.recv_timeout(Duration::from_secs(1))?; } + + // Return senders of the mpsc channels used for blocking the runtime as a + // surrogate handle for the tasks. Sending a message or dropping the senders + // will unblock the runtime. + Ok(blocking_tasks) } fn current_thread() -> Runtime {