Skip to content

Commit

Permalink
time: Fix race condition in timer drop (#3229)
Browse files Browse the repository at this point in the history
Dropping a timer on the millisecond that it was scheduled for, when it was on
the pending list, could result in a panic previously, as we did not record the
pending-list state in cached_when.

Hopefully fixes: ZcashFoundation/zebra#1452
  • Loading branch information
bdonlan authored Dec 9, 2020
1 parent fc7a4b3 commit 9706ca9
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 9 deletions.
24 changes: 17 additions & 7 deletions tokio/src/time/driver/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,17 @@ impl TimerShared {
true_when
}

/// Sets the cached time-of-expiration value.
///
/// SAFETY: Must be called with the driver lock held, and when this entry is
/// not in any timer wheel lists.
unsafe fn set_cached_when(&self, when: u64) {
self.driver_state
.0
.cached_when
.store(when, Ordering::Relaxed);
}

/// Returns the true time-of-expiration value, with relaxed memory ordering.
pub(super) fn true_when(&self) -> u64 {
self.state.when().expect("Timer already fired")
Expand Down Expand Up @@ -643,14 +654,13 @@ impl TimerHandle {
/// After returning Ok, the entry must be added to the pending list.
pub(super) unsafe fn mark_pending(&self, not_after: u64) -> Result<(), u64> {
match self.inner.as_ref().state.mark_pending(not_after) {
Ok(()) => Ok(()),
Ok(()) => {
// mark this as being on the pending queue in cached_when
self.inner.as_ref().set_cached_when(u64::max_value());
Ok(())
}
Err(tick) => {
self.inner
.as_ref()
.driver_state
.0
.cached_when
.store(tick, Ordering::Relaxed);
self.inner.as_ref().set_cached_when(tick);
Err(tick)
}
}
Expand Down
4 changes: 2 additions & 2 deletions tokio/src/time/driver/wheel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ impl Wheel {
/// Remove `item` from the timing wheel.
pub(crate) unsafe fn remove(&mut self, item: NonNull<TimerShared>) {
unsafe {
if !item.as_ref().might_be_registered() {
let when = item.as_ref().cached_when();
if when == u64::max_value() {
self.pending.remove(item);
} else {
let when = item.as_ref().cached_when();
let level = self.level_for(when);

self.levels[level].remove_entry(item);
Expand Down
72 changes: 72 additions & 0 deletions tokio/tests/time_sleep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,75 @@ async fn no_out_of_bounds_close_to_max() {
fn ms(n: u64) -> Duration {
Duration::from_millis(n)
}

#[tokio::test]
async fn drop_after_reschedule_at_new_scheduled_time() {
use futures::poll;

tokio::time::pause();

let start = tokio::time::Instant::now();

let mut a = tokio::time::sleep(Duration::from_millis(5));
let mut b = tokio::time::sleep(Duration::from_millis(5));
let mut c = tokio::time::sleep(Duration::from_millis(10));

let _ = poll!(&mut a);
let _ = poll!(&mut b);
let _ = poll!(&mut c);

b.reset(start + Duration::from_millis(10));
a.await;

drop(b);
}

#[tokio::test]
async fn drop_from_wake() {
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::task::Context;

let panicked = Arc::new(AtomicBool::new(false));
let list: Arc<Mutex<Vec<tokio::time::Sleep>>> = Arc::new(Mutex::new(Vec::new()));

let arc_wake = Arc::new(DropWaker(panicked.clone(), list.clone()));
let arc_wake = futures::task::waker(arc_wake);

tokio::time::pause();

let mut lock = list.lock().unwrap();

for _ in 0..100 {
let mut timer = tokio::time::sleep(Duration::from_millis(10));

let _ = std::pin::Pin::new(&mut timer).poll(&mut Context::from_waker(&arc_wake));

lock.push(timer);
}

drop(lock);

tokio::time::sleep(Duration::from_millis(11)).await;

assert!(
!panicked.load(Ordering::SeqCst),
"paniced when dropping timers"
);

#[derive(Clone)]
struct DropWaker(Arc<AtomicBool>, Arc<Mutex<Vec<tokio::time::Sleep>>>);

impl futures::task::ArcWake for DropWaker {
fn wake_by_ref(arc_self: &Arc<Self>) {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
*arc_self.1.lock().expect("panic in lock") = Vec::new()
}));

if result.is_err() {
arc_self.0.store(true, Ordering::SeqCst);
}
}
}
}

0 comments on commit 9706ca9

Please sign in to comment.