Skip to content

Commit

Permalink
time: Fix race condition in timer drop
Browse files Browse the repository at this point in the history
Dropping a timer on the millisecond that it was scheduled for, when it was on
the pending list, could result in a panic previously, as we did not record the
pending-list state in cached_when.

Hopefully fixes: ZcashFoundation/zebra#1452
  • Loading branch information
Bryan Donlan committed Dec 7, 2020
1 parent 0707f4c commit 62be632
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 2 deletions.
15 changes: 15 additions & 0 deletions tokio/src/time/driver/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,17 @@ impl TimerShared {
true_when
}

/// Sets the cached time-of-expiration value.
///
/// SAFETY: Must be called with the driver lock held, and when this entry is
/// not in any timer wheel lists.
pub(super) unsafe fn set_cached_when(&self, when: u64) {
self.driver_state
.0
.cached_when
.store(when, Ordering::Relaxed);
}

/// Returns the true time-of-expiration value, with relaxed memory ordering.
pub(super) fn true_when(&self) -> u64 {
self.state.when().expect("Timer already fired")
Expand Down Expand Up @@ -620,6 +631,10 @@ impl TimerHandle {
unsafe { self.inner.as_ref().sync_when() }
}

pub(super) unsafe fn set_cached_when(&self, when: u64) {
unsafe { self.inner.as_ref().set_cached_when(when) }
}

pub(super) unsafe fn is_pending(&self) -> bool {
unsafe { self.inner.as_ref().state.is_pending() }
}
Expand Down
7 changes: 5 additions & 2 deletions tokio/src/time/driver/wheel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ impl Wheel {
/// Remove `item` from the timing wheel.
pub(crate) unsafe fn remove(&mut self, item: NonNull<TimerShared>) {
unsafe {
if !item.as_ref().might_be_registered() {
let when = item.as_ref().cached_when();
if when == u64::max_value() {
self.pending.remove(item);
} else {
let when = item.as_ref().cached_when();
let level = self.level_for(when);

self.levels[level].remove_entry(item);
Expand Down Expand Up @@ -244,6 +244,9 @@ impl Wheel {
match unsafe { item.mark_pending(expiration.deadline) } {
Ok(()) => {
// Item was expired
unsafe {
item.set_cached_when(u64::max_value());
}
self.pending.push_front(item);
}
Err(expiration_tick) => {
Expand Down
70 changes: 70 additions & 0 deletions tokio/tests/time_sleep.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,73 @@ async fn no_out_of_bounds_close_to_max() {
fn ms(n: u64) -> Duration {
Duration::from_millis(n)
}

#[tokio::test]
async fn drop_after_reschedule_at_new_scheduled_time() {
use futures::poll;

tokio::time::pause();

let start = tokio::time::Instant::now();

let mut a = tokio::time::sleep(Duration::from_millis(5));
let mut b = tokio::time::sleep(Duration::from_millis(5));
let mut c = tokio::time::sleep(Duration::from_millis(10));

let _ = poll!(&mut a);
let _ = poll!(&mut b);
let _ = poll!(&mut c);

b.reset(start + Duration::from_millis(10));
a.await;

drop(b);
}

#[tokio::test]
async fn drop_from_wake() {
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::task::Context;

let paniced = Arc::new(AtomicBool::new(false));
let list: Arc<Mutex<Vec<tokio::time::Sleep>>> = Arc::new(Mutex::new(Vec::new()));

let arc_wake = Arc::new(DropWaker(paniced.clone(), list.clone()));
let arc_wake = futures::task::waker(arc_wake);

tokio::time::pause();

let mut lock = list.lock().unwrap();

for _ in 0..100 {
let mut timer = tokio::time::sleep(Duration::from_millis(10));

let _ = std::pin::Pin::new(&mut timer).poll(&mut Context::from_waker(&arc_wake));

lock.push(timer);
}

drop(lock);

tokio::time::sleep(Duration::from_millis(11)).await;

assert!(
!paniced.load(Ordering::SeqCst),
"paniced when dropping timers"
);

#[derive(Clone)]
struct DropWaker(Arc<AtomicBool>, Arc<Mutex<Vec<tokio::time::Sleep>>>);

impl futures::task::ArcWake for DropWaker {
fn wake_by_ref(arc_self: &Arc<Self>) {
if let Err(_) = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
*arc_self.1.lock().expect("panic in lock") = Vec::new()
})) {
arc_self.0.store(true, Ordering::SeqCst);
}
}
}
}

0 comments on commit 62be632

Please sign in to comment.