From 6f988728bb1b9156b78153461c8632e44cb21a21 Mon Sep 17 00:00:00 2001 From: Markus Westerlind Date: Fri, 29 Jan 2021 21:45:29 +0100 Subject: [PATCH] runtime: minimize the amount of duplicated code (#3416) --- tokio/src/runtime/queue.rs | 4 +- tokio/src/runtime/task/core.rs | 24 ++++-- tokio/src/runtime/task/harness.rs | 56 +++++++------- tokio/src/sync/oneshot.rs | 120 ++++++++++++++---------------- 4 files changed, 100 insertions(+), 104 deletions(-) diff --git a/tokio/src/runtime/queue.rs b/tokio/src/runtime/queue.rs index 0fcaad8d7d2..1c7bb230984 100644 --- a/tokio/src/runtime/queue.rs +++ b/tokio/src/runtime/queue.rs @@ -235,7 +235,7 @@ impl Local { // tasks and we are the only producer. self.inner.buffer[i_idx].with_mut(|ptr| unsafe { let ptr = (*ptr).as_ptr(); - (*ptr).header().queue_next.with_mut(|ptr| *ptr = Some(next)); + (*ptr).header().set_next(Some(next)) }); } @@ -610,7 +610,7 @@ fn get_next(header: NonNull) -> Option> { fn set_next(header: NonNull, val: Option>) { unsafe { - header.as_ref().queue_next.with_mut(|ptr| *ptr = val); + header.as_ref().set_next(val); } } diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs index f922deaa67d..9f7ff55fede 100644 --- a/tokio/src/runtime/task/core.rs +++ b/tokio/src/runtime/task/core.rs @@ -249,10 +249,10 @@ impl CoreStage { /// /// The caller must ensure it is safe to mutate the `stage` field. pub(super) fn drop_future_or_output(&self) { - self.stage.with_mut(|ptr| { - // Safety: The caller ensures mutal exclusion to the field. - unsafe { *ptr = Stage::Consumed }; - }); + // Safety: the caller ensures mutual exclusion to the field. + unsafe { + self.set_stage(Stage::Consumed); + } } /// Store the task output @@ -261,10 +261,10 @@ impl CoreStage { /// /// The caller must ensure it is safe to mutate the `stage` field. pub(super) fn store_output(&self, output: super::Result) { - self.stage.with_mut(|ptr| { - // Safety: the caller ensures mutual exclusion to the field. - unsafe { *ptr = Stage::Finished(output) }; - }); + // Safety: the caller ensures mutual exclusion to the field. + unsafe { + self.set_stage(Stage::Finished(output)); + } } /// Take the task output @@ -283,6 +283,10 @@ impl CoreStage { } }) } + + unsafe fn set_stage(&self, stage: Stage) { + self.stage.with_mut(|ptr| *ptr = stage) + } } cfg_rt_multi_thread! { @@ -293,6 +297,10 @@ cfg_rt_multi_thread! { let task = unsafe { RawTask::from_raw(self.into()) }; task.shutdown(); } + + pub(crate) unsafe fn set_next(&self, next: Option>) { + self.queue_next.with_mut(|ptr| *ptr = next); + } } } diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index df1c8ac07bb..7d596e36e1a 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -403,44 +403,44 @@ fn poll_future( snapshot: Snapshot, cx: Context<'_>, ) -> PollFuture { - let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { - struct Guard<'a, T: Future> { - core: &'a CoreStage, - } + if snapshot.is_cancelled() { + PollFuture::Complete(Err(JoinError::cancelled()), snapshot.is_join_interested()) + } else { + let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { + struct Guard<'a, T: Future> { + core: &'a CoreStage, + } - impl Drop for Guard<'_, T> { - fn drop(&mut self) { - self.core.drop_future_or_output(); + impl Drop for Guard<'_, T> { + fn drop(&mut self) { + self.core.drop_future_or_output(); + } } - } - let guard = Guard { core }; + let guard = Guard { core }; - // If the task is cancelled, avoid polling it, instead signalling it - // is complete. - if snapshot.is_cancelled() { - Poll::Ready(Err(JoinError::cancelled())) - } else { let res = guard.core.poll(cx); // prevent the guard from dropping the future mem::forget(guard); - res.map(Ok) - } - })); - match res { - Ok(Poll::Pending) => match header.state.transition_to_idle() { - Ok(snapshot) => { - if snapshot.is_notified() { - PollFuture::Notified - } else { - PollFuture::None + res + })); + match res { + Ok(Poll::Pending) => match header.state.transition_to_idle() { + Ok(snapshot) => { + if snapshot.is_notified() { + PollFuture::Notified + } else { + PollFuture::None + } } + Err(_) => PollFuture::Complete(Err(cancel_task(core)), true), + }, + Ok(Poll::Ready(ok)) => PollFuture::Complete(Ok(ok), snapshot.is_join_interested()), + Err(err) => { + PollFuture::Complete(Err(JoinError::panic(err)), snapshot.is_join_interested()) } - Err(_) => PollFuture::Complete(Err(cancel_task(core)), true), - }, - Ok(Poll::Ready(ok)) => PollFuture::Complete(ok, snapshot.is_join_interested()), - Err(err) => PollFuture::Complete(Err(JoinError::panic(err)), snapshot.is_join_interested()), + } } } diff --git a/tokio/src/sync/oneshot.rs b/tokio/src/sync/oneshot.rs index ece9abaeb64..20d39dc1a1e 100644 --- a/tokio/src/sync/oneshot.rs +++ b/tokio/src/sync/oneshot.rs @@ -84,10 +84,42 @@ struct Inner { value: UnsafeCell>, /// The task to notify when the receiver drops without consuming the value. - tx_task: UnsafeCell>, + tx_task: Task, /// The task to notify when the value is sent. - rx_task: UnsafeCell>, + rx_task: Task, +} + +struct Task(UnsafeCell>); + +impl Task { + unsafe fn will_wake(&self, cx: &mut Context<'_>) -> bool { + self.with_task(|w| w.will_wake(cx.waker())) + } + + unsafe fn with_task(&self, f: F) -> R + where + F: FnOnce(&Waker) -> R, + { + self.0.with(|ptr| { + let waker: *const Waker = (&*ptr).as_ptr(); + f(&*waker) + }) + } + + unsafe fn drop_task(&self) { + self.0.with_mut(|ptr| { + let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); + ptr.drop_in_place(); + }); + } + + unsafe fn set_task(&self, cx: &mut Context<'_>) { + self.0.with_mut(|ptr| { + let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); + ptr.write(cx.waker().clone()); + }); + } } #[derive(Clone, Copy)] @@ -127,8 +159,8 @@ pub fn channel() -> (Sender, Receiver) { let inner = Arc::new(Inner { state: AtomicUsize::new(State::new().as_usize()), value: UnsafeCell::new(None), - tx_task: UnsafeCell::new(MaybeUninit::uninit()), - rx_task: UnsafeCell::new(MaybeUninit::uninit()), + tx_task: Task(UnsafeCell::new(MaybeUninit::uninit())), + rx_task: Task(UnsafeCell::new(MaybeUninit::uninit())), }); let tx = Sender { @@ -188,9 +220,9 @@ impl Sender { }); if !inner.complete() { - return Err(inner - .value - .with_mut(|ptr| unsafe { (*ptr).take() }.unwrap())); + unsafe { + return Err(inner.consume_value().unwrap()); + } } Ok(()) @@ -357,7 +389,7 @@ impl Sender { } if state.is_tx_task_set() { - let will_notify = unsafe { inner.with_tx_task(|w| w.will_wake(cx.waker())) }; + let will_notify = unsafe { inner.tx_task.will_wake(cx) }; if !will_notify { state = State::unset_tx_task(&inner.state); @@ -368,7 +400,7 @@ impl Sender { coop.made_progress(); return Ready(()); } else { - unsafe { inner.drop_tx_task() }; + unsafe { inner.tx_task.drop_task() }; } } } @@ -376,7 +408,7 @@ impl Sender { if !state.is_tx_task_set() { // Attempt to set the task unsafe { - inner.set_tx_task(cx); + inner.tx_task.set_task(cx); } // Update the state @@ -584,7 +616,7 @@ impl Inner { if prev.is_rx_task_set() { // TODO: Consume waker? unsafe { - self.with_rx_task(Waker::wake_by_ref); + self.rx_task.with_task(Waker::wake_by_ref); } } @@ -609,7 +641,7 @@ impl Inner { Ready(Err(RecvError(()))) } else { if state.is_rx_task_set() { - let will_notify = unsafe { self.with_rx_task(|w| w.will_wake(cx.waker())) }; + let will_notify = unsafe { self.rx_task.will_wake(cx) }; // Check if the task is still the same if !will_notify { @@ -625,7 +657,7 @@ impl Inner { None => Ready(Err(RecvError(()))), }; } else { - unsafe { self.drop_rx_task() }; + unsafe { self.rx_task.drop_task() }; } } } @@ -633,7 +665,7 @@ impl Inner { if !state.is_rx_task_set() { // Attempt to set the task unsafe { - self.set_rx_task(cx); + self.rx_task.set_task(cx); } // Update the state @@ -660,7 +692,7 @@ impl Inner { if prev.is_tx_task_set() && !prev.is_complete() { unsafe { - self.with_tx_task(Waker::wake_by_ref); + self.tx_task.with_task(Waker::wake_by_ref); } } } @@ -669,72 +701,28 @@ impl Inner { unsafe fn consume_value(&self) -> Option { self.value.with_mut(|ptr| (*ptr).take()) } - - unsafe fn with_rx_task(&self, f: F) -> R - where - F: FnOnce(&Waker) -> R, - { - self.rx_task.with(|ptr| { - let waker: *const Waker = (&*ptr).as_ptr(); - f(&*waker) - }) - } - - unsafe fn with_tx_task(&self, f: F) -> R - where - F: FnOnce(&Waker) -> R, - { - self.tx_task.with(|ptr| { - let waker: *const Waker = (&*ptr).as_ptr(); - f(&*waker) - }) - } - - unsafe fn drop_rx_task(&self) { - self.rx_task.with_mut(|ptr| { - let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); - ptr.drop_in_place(); - }); - } - - unsafe fn drop_tx_task(&self) { - self.tx_task.with_mut(|ptr| { - let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); - ptr.drop_in_place(); - }); - } - - unsafe fn set_rx_task(&self, cx: &mut Context<'_>) { - self.rx_task.with_mut(|ptr| { - let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); - ptr.write(cx.waker().clone()); - }); - } - - unsafe fn set_tx_task(&self, cx: &mut Context<'_>) { - self.tx_task.with_mut(|ptr| { - let ptr: *mut Waker = (&mut *ptr).as_mut_ptr(); - ptr.write(cx.waker().clone()); - }); - } } unsafe impl Send for Inner {} unsafe impl Sync for Inner {} +fn mut_load(this: &mut AtomicUsize) -> usize { + this.with_mut(|v| *v) +} + impl Drop for Inner { fn drop(&mut self) { - let state = State(self.state.with_mut(|v| *v)); + let state = State(mut_load(&mut self.state)); if state.is_rx_task_set() { unsafe { - self.drop_rx_task(); + self.rx_task.drop_task(); } } if state.is_tx_task_set() { unsafe { - self.drop_tx_task(); + self.tx_task.drop_task(); } } }