diff --git a/tokio/src/time/driver/entry.rs b/tokio/src/time/driver/entry.rs index e0926797fd5..ae99f297a11 100644 --- a/tokio/src/time/driver/entry.rs +++ b/tokio/src/time/driver/entry.rs @@ -437,6 +437,17 @@ impl TimerShared { true_when } + /// Sets the cached time-of-expiration value. + /// + /// SAFETY: Must be called with the driver lock held, and when this entry is + /// not in any timer wheel lists. + pub(super) unsafe fn set_cached_when(&self, when: u64) { + self.driver_state + .0 + .cached_when + .store(when, Ordering::Relaxed); + } + /// Returns the true time-of-expiration value, with relaxed memory ordering. pub(super) fn true_when(&self) -> u64 { self.state.when().expect("Timer already fired") @@ -620,6 +631,10 @@ impl TimerHandle { unsafe { self.inner.as_ref().sync_when() } } + pub(super) unsafe fn set_cached_when(&self, when: u64) { + unsafe { self.inner.as_ref().set_cached_when(when) } + } + pub(super) unsafe fn is_pending(&self) -> bool { unsafe { self.inner.as_ref().state.is_pending() } } diff --git a/tokio/src/time/driver/wheel/mod.rs b/tokio/src/time/driver/wheel/mod.rs index e9df87afabc..2156b24e962 100644 --- a/tokio/src/time/driver/wheel/mod.rs +++ b/tokio/src/time/driver/wheel/mod.rs @@ -118,10 +118,10 @@ impl Wheel { /// Remove `item` from the timing wheel. pub(crate) unsafe fn remove(&mut self, item: NonNull) { unsafe { - if !item.as_ref().might_be_registered() { + let when = item.as_ref().cached_when(); + if when == u64::max_value() { self.pending.remove(item); } else { - let when = item.as_ref().cached_when(); let level = self.level_for(when); self.levels[level].remove_entry(item); @@ -244,6 +244,9 @@ impl Wheel { match unsafe { item.mark_pending(expiration.deadline) } { Ok(()) => { // Item was expired + unsafe { + item.set_cached_when(u64::max_value()); + } self.pending.push_front(item); } Err(expiration_tick) => { diff --git a/tokio/tests/time_sleep.rs b/tokio/tests/time_sleep.rs index d110ec27a8d..0880ee33d3c 100644 --- a/tokio/tests/time_sleep.rs +++ b/tokio/tests/time_sleep.rs @@ -308,3 +308,73 @@ async fn no_out_of_bounds_close_to_max() { fn ms(n: u64) -> Duration { Duration::from_millis(n) } + +#[tokio::test] +async fn drop_after_reschedule_at_new_scheduled_time() { + use futures::poll; + + tokio::time::pause(); + + let start = tokio::time::Instant::now(); + + let mut a = tokio::time::sleep(Duration::from_millis(5)); + let mut b = tokio::time::sleep(Duration::from_millis(5)); + let mut c = tokio::time::sleep(Duration::from_millis(10)); + + let _ = poll!(&mut a); + let _ = poll!(&mut b); + let _ = poll!(&mut c); + + b.reset(start + Duration::from_millis(10)); + a.await; + + drop(b); +} + +#[tokio::test] +async fn drop_from_wake() { + use std::future::Future; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::{Arc, Mutex}; + use std::task::Context; + + let paniced = Arc::new(AtomicBool::new(false)); + let list: Arc>> = Arc::new(Mutex::new(Vec::new())); + + let arc_wake = Arc::new(DropWaker(paniced.clone(), list.clone())); + let arc_wake = futures::task::waker(arc_wake); + + tokio::time::pause(); + + let mut lock = list.lock().unwrap(); + + for _ in 0..100 { + let mut timer = tokio::time::sleep(Duration::from_millis(10)); + + let _ = std::pin::Pin::new(&mut timer).poll(&mut Context::from_waker(&arc_wake)); + + lock.push(timer); + } + + drop(lock); + + tokio::time::sleep(Duration::from_millis(11)).await; + + assert!( + !paniced.load(Ordering::SeqCst), + "paniced when dropping timers" + ); + + #[derive(Clone)] + struct DropWaker(Arc, Arc>>); + + impl futures::task::ArcWake for DropWaker { + fn wake_by_ref(arc_self: &Arc) { + if let Err(_) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + *arc_self.1.lock().expect("panic in lock") = Vec::new() + })) { + arc_self.0.store(true, Ordering::SeqCst); + } + } + } +}