Skip to content

Commit

Permalink
Prevent task from respawning while in the timer queue
Browse files Browse the repository at this point in the history
  • Loading branch information
bugadani committed Dec 10, 2024
1 parent 515666d commit b49e2a3
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 15 deletions.
14 changes: 12 additions & 2 deletions embassy-executor/src/raw/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub(crate) struct TaskHeader {
}

/// This is essentially a `&'static TaskStorage<F>` where the type of the future has been erased.
#[derive(Clone, Copy)]
#[derive(Clone, Copy, PartialEq)]
pub struct TaskRef {
ptr: NonNull<TaskHeader>,
}
Expand All @@ -72,6 +72,16 @@ impl TaskRef {
}
}

/// # Safety
///
/// The result of this function must only be compared
/// for equality, or stored, but not used.
pub const unsafe fn dangling() -> Self {
Self {
ptr: NonNull::dangling(),
}
}

pub(crate) fn header(self) -> &'static TaskHeader {
unsafe { self.ptr.as_ref() }
}
Expand All @@ -97,7 +107,7 @@ impl TaskRef {
/// This functions should only be called by the timer queue implementation, before
/// enqueueing the timer item.
#[cfg(feature = "integrated-timers")]
pub unsafe fn timer_enqueue(&self) -> bool {
pub unsafe fn timer_enqueue(&self) -> timer_queue::TimerEnqueueOperation {
self.header().state.timer_enqueue()
}

Expand Down
25 changes: 22 additions & 3 deletions embassy-executor/src/raw/state_atomics.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use core::sync::atomic::{AtomicU32, Ordering};

#[cfg(feature = "integrated-timers")]
use super::timer_queue::TimerEnqueueOperation;

/// Task is spawned (has a future)
pub(crate) const STATE_SPAWNED: u32 = 1 << 0;
/// Task is in the executor run queue
Expand Down Expand Up @@ -56,11 +59,27 @@ impl State {
state & STATE_SPAWNED != 0
}

/// Mark the task as timer-queued. Return whether it was newly queued (i.e. not queued before)
/// Mark the task as timer-queued. Return whether it can be enqueued.
#[cfg(feature = "integrated-timers")]
#[inline(always)]
pub fn timer_enqueue(&self) -> bool {
self.state.fetch_or(STATE_TIMER_QUEUED, Ordering::Relaxed) & STATE_TIMER_QUEUED == 0
pub fn timer_enqueue(&self) -> TimerEnqueueOperation {
if self
.state
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| {
// If not started, ignore it
if state & STATE_SPAWNED == 0 {
None
} else {
// Mark it as enqueued
Some(state | STATE_TIMER_QUEUED)
}
})
.is_ok()
{
TimerEnqueueOperation::Enqueue
} else {
TimerEnqueueOperation::Ignore
}
}

/// Unmark the task as timer-queued.
Expand Down
27 changes: 24 additions & 3 deletions embassy-executor/src/raw/state_atomics_arm.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
use core::arch::asm;
use core::sync::atomic::{compiler_fence, AtomicBool, AtomicU32, Ordering};

#[cfg(feature = "integrated-timers")]
use super::timer_queue::TimerEnqueueOperation;

// Must be kept in sync with the layout of `State`!
pub(crate) const STATE_SPAWNED: u32 = 1 << 0;
pub(crate) const STATE_RUN_QUEUED: u32 = 1 << 8;
#[cfg(feature = "integrated-timers")]
pub(crate) const STATE_TIMER_QUEUED: u32 = 1 << 16;

#[repr(C, align(4))]
pub(crate) struct State {
Expand Down Expand Up @@ -87,11 +92,27 @@ impl State {
r
}

/// Mark the task as timer-queued. Return whether it was newly queued (i.e. not queued before)
/// Mark the task as timer-queued. Return whether it can be enqueued.
#[cfg(feature = "integrated-timers")]
#[inline(always)]
pub fn timer_enqueue(&self) -> bool {
!self.timer_queued.swap(true, Ordering::Relaxed)
pub fn timer_enqueue(&self) -> TimerEnqueueOperation {
if self
.as_u32()
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |state| {
// If not started, ignore it
if state & STATE_SPAWNED == 0 {
None
} else {
// Mark it as enqueued
Some(state | STATE_TIMER_QUEUED)
}
})
.is_ok()
{
TimerEnqueueOperation::Enqueue
} else {
TimerEnqueueOperation::Ignore
}
}

/// Unmark the task as timer-queued.
Expand Down
18 changes: 13 additions & 5 deletions embassy-executor/src/raw/state_critical_section.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ use core::cell::Cell;

use critical_section::Mutex;

#[cfg(feature = "integrated-timers")]
use super::timer_queue::TimerEnqueueOperation;

/// Task is spawned (has a future)
pub(crate) const STATE_SPAWNED: u32 = 1 << 0;
/// Task is in the executor run queue
Expand Down Expand Up @@ -73,14 +76,19 @@ impl State {
})
}

/// Mark the task as timer-queued. Return whether it was newly queued (i.e. not queued before)
/// Mark the task as timer-queued. Return whether it can be enqueued.
#[cfg(feature = "integrated-timers")]
#[inline(always)]
pub fn timer_enqueue(&self) -> bool {
pub fn timer_enqueue(&self) -> TimerEnqueueOperation {
self.update(|s| {
let ok = *s & STATE_TIMER_QUEUED == 0;
*s |= STATE_TIMER_QUEUED;
ok
// FIXME: we need to split SPAWNED into two phases, to prevent enqueueing a task that is
// just being spawned, because its executor pointer may still be changing.
if *s & STATE_SPAWNED == STATE_SPAWNED {
*s |= STATE_TIMER_QUEUED;
TimerEnqueueOperation::Enqueue
} else {
TimerEnqueueOperation::Ignore
}
})
}

Expand Down
13 changes: 13 additions & 0 deletions embassy-executor/src/raw/timer_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ use super::TaskRef;
/// An item in the timer queue.
pub struct TimerQueueItem {
/// The next item in the queue.
///
/// If this field contains `Some`, the item is in the queue. The last item in the queue has a
/// value of `Some(dangling_pointer)`
pub next: Cell<Option<TaskRef>>,

/// The time at which this item expires.
Expand All @@ -23,3 +26,13 @@ impl TimerQueueItem {
}
}
}

/// The operation to perform after `timer_enqueue` is called.
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum TimerEnqueueOperation {
/// Enqueue the task.
Enqueue,
/// Update the task's expiration time.
Ignore,
}
12 changes: 12 additions & 0 deletions embassy-time-queue-driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ extern "Rust" {

/// Schedule the given waker to be woken at `at`.
pub fn schedule_wake(at: u64, waker: &Waker) {
#[cfg(feature = "integrated-timers")]
{
// The very first thing we must do, before we even access the timer queue, is to
// mark the task a TIMER_QUEUED. This ensures that the task that is being scheduled
// can not be respawn while we are accessing the timer queue.
let task = embassy_executor::raw::task_from_waker(waker);
if unsafe { task.timer_enqueue() } == embassy_executor::raw::timer_queue::TimerEnqueueOperation::Ignore {
// We are not allowed to enqueue the task in the timer queue. This is because the
// task is not spawned, and so it makes no sense to schedule it.
return;
}
}
unsafe { _embassy_time_schedule_wake(at, waker) }
}

Expand Down
13 changes: 11 additions & 2 deletions embassy-time-queue-driver/src/queue_integrated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@ impl TimerQueue {
/// a new alarm for that time.
pub fn schedule_wake(&mut self, at: u64, p: TaskRef) -> bool {
let item = p.timer_queue_item();
if unsafe { p.timer_enqueue() } {
if item.next.get().is_none() {
// If not in the queue, add it and update.
let prev = self.head.replace(Some(p));
item.next.set(prev);
item.next.set(if prev.is_none() {
Some(unsafe { TaskRef::dangling() })
} else {
prev
});
item.expires_at.set(at);
true
} else if at <= item.expires_at.get() {
Expand Down Expand Up @@ -65,13 +69,18 @@ impl TimerQueue {
fn retain(&self, mut f: impl FnMut(TaskRef) -> bool) {
let mut prev = &self.head;
while let Some(p) = prev.get() {
if unsafe { p == TaskRef::dangling() } {
// prev was the last item, stop
break;
}
let item = p.timer_queue_item();
if f(p) {
// Skip to next
prev = &item.next;
} else {
// Remove it
prev.set(item.next.get());
item.next.set(None);
unsafe { p.timer_dequeue() };
}
}
Expand Down

0 comments on commit b49e2a3

Please sign in to comment.