From 8a4632824f836b67cc953d566029f7dcbd827788 Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Tue, 30 Aug 2022 09:39:24 -0700 Subject: [PATCH] feat(maitake): add `wait::Semaphore` (#301) A semaphore is a useful synchronization type for things like rate-limiting async tasks. It could also be used as a lower-level primitive to implement things like read-write locks[^1] and channels[^2]. This branch adds an asynchronous semaphore implementation in `maitake::wait`. The implementation is based on the implementation I wrote for Tokio in tokio-rs/tokio#2325, with some code simplified a bit as it was not necessary to maintain Tokio's required API surface. Closes #299 [^1]: a rwlock can be modeled by a semaphore with _n_ permits (where _n_ is the maximum number of concurrent readers); each reader must acquire a single permit, while a writer must acquire _n_ permits). [^2]: a bounded MPSC channel of capacity _n_ can be implemented using a semaphore with _n_ permits, where each producer must acquire a single permit to write, and every time a message is consumed, the reader releases a permit to the writers. --- maitake/src/wait.rs | 7 +- maitake/src/wait/semaphore.rs | 1041 +++++++++++++++++++++++++++ maitake/src/wait/semaphore/loom.rs | 165 +++++ maitake/src/wait/semaphore/tests.rs | 127 ++++ 4 files changed, 1339 insertions(+), 1 deletion(-) create mode 100644 maitake/src/wait/semaphore.rs create mode 100644 maitake/src/wait/semaphore/loom.rs create mode 100644 maitake/src/wait/semaphore/tests.rs diff --git a/maitake/src/wait.rs b/maitake/src/wait.rs index 6e274d73..b2dfad7e 100644 --- a/maitake/src/wait.rs +++ b/maitake/src/wait.rs @@ -1,21 +1,26 @@ //! Waiter cells and queues to allow tasks to wait for notifications. //! -//! This module implements three types of structure for waiting: +//! This module implements the following primitives for waiting: //! //! - [`WaitCell`], which stores a *single* waiting task //! - [`WaitQueue`], a queue of waiting tasks, which are woken in first-in, //! first-out order //! - [`WaitMap`], a set of waiting tasks associated with keys, in which a task //! can be woken by its key +//! - [`Semaphore`]: an asynchronous [counting semaphore], for limiting the +//! number of tasks which may run concurrently pub(crate) mod cell; pub mod map; pub mod queue; +pub mod semaphore; pub use self::cell::WaitCell; #[doc(inline)] pub use self::map::WaitMap; #[doc(inline)] pub use self::queue::WaitQueue; +#[doc(inline)] +pub use self::semaphore::Semaphore; use core::task::Poll; diff --git a/maitake/src/wait/semaphore.rs b/maitake/src/wait/semaphore.rs new file mode 100644 index 00000000..68d35d06 --- /dev/null +++ b/maitake/src/wait/semaphore.rs @@ -0,0 +1,1041 @@ +//! An asynchronous [counting semaphore]. +//! +//! A semaphore limits the number of tasks which may execute concurrently. See +//! the [`Semaphore`] type's documentation for details. +//! +//! 
[counting semaphore]: https://en.wikipedia.org/wiki/Semaphore_(programming) +use crate::{ + loom::{ + cell::UnsafeCell, + sync::{ + atomic::{AtomicUsize, Ordering::*}, + spin::{Mutex, MutexGuard}, + }, + }, + wait::{self, WaitResult}, +}; +use cordyceps::{ + list::{self, List}, + Linked, +}; +use core::{ + cmp, + future::Future, + marker::PhantomPinned, + pin::Pin, + ptr::{self, NonNull}, + task::{Context, Poll, Waker}, +}; +#[cfg(any(test, feature = "tracing-01", feature = "tracing-02"))] +use mycelium_util::fmt; +use mycelium_util::sync::CachePadded; +use pin_project::{pin_project, pinned_drop}; + +#[cfg(all(test, loom))] +mod loom; +#[cfg(all(test, not(loom)))] +mod tests; + +/// An asynchronous [counting semaphore]. +/// +/// A semaphore is a synchronization primitive that limits the number of tasks +/// that may run concurrently. It consists of a count of _permits_, which tasks +/// may [`acquire`] in order to execute in some context. When a task acquires a +/// permit from the semaphore, the count of permits held by the semaphore is +/// decreased. When no permits remain in the semaphore, any task that wishes to +/// acquire a permit must (asynchronously) wait until another task has released +/// a permit. +/// +/// The [`Permit`] type is a RAII guard representing one or more permits +/// acquired from a `Semaphore`. When a [`Permit`] is dropped, the permits it +/// represents are released back to the `Semaphore`, potentially allowing a +/// waiting task to acquire them. +/// +/// # Fairness +/// +/// This semaphore is _fair_: as permits become available, they are assigned to +/// waiting tasks in the order that those tasks requested permits (first-in, +/// first-out). This means that all tasks waiting to acquire permits will +/// eventually be allowed to progress, and a single task cannot starve the +/// semaphore of permits (provided that permits are eventually released). The +/// semaphore remains fair even when a call to `acquire` requests more than one +/// permit at a time. +/// +/// # Examples +/// +/// Using a semaphore to limit concurrency: +/// +/// ``` +/// # use std as alloc; +/// use maitake::{scheduler::Scheduler, wait::Semaphore}; +/// use alloc::sync::Arc; +/// +/// let scheduler = Scheduler::new(); +/// // Allow 4 tasks to run concurrently at a time. +/// let semaphore = Arc::new(Semaphore::new(4)); +/// +/// for _ in 0..8 { +/// // Clone the `Arc` around the semaphore. +/// let semaphore = semaphore.clone(); +/// scheduler.spawn(async move { +/// // Acquire a permit from the semaphore, returning a RAII guard that +/// // releases the permit back to the semaphore when dropped. +/// // +/// // If all 4 permits have been acquired, the calling task will yield, +/// // and it will be woken when another task releases a permit. +/// let _permit = semaphore +/// .acquire(1) +/// .await +/// .expect("semaphore will not be closed"); +/// +/// // do some work... +/// }); +/// } +/// +/// scheduler.tick(); +/// ``` +/// +/// A semaphore may also be used to cause a task to run once all of a set of +/// tasks have completed. If we want some task _B_ to run only after a fixed +/// number _n_ of tasks _A_ have run, we can have task _B_ try to acquire _n_ +/// permits from a semaphore with 0 permits, and have each task _A_ add one +/// permit to the semaphore when it completes. 
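+/// This pattern is sometimes described as a "countdown latch": task _B_
+/// counts down the completions of the _A_ tasks via the permits they add.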
+/// +/// For example: +/// +/// ``` +/// # use std as alloc; +/// use maitake::{scheduler::Scheduler, wait::Semaphore}; +/// use alloc::sync::Arc; +/// +/// // How many tasks will we be waiting for the completion of? +/// const TASKS: usize = 4; +/// +/// let scheduler = Scheduler::new(); +/// +/// // Create the semaphore with 0 permits. +/// let semaphore = Arc::new(Semaphore::new(0)); +/// +/// // Spawn the "B" task that will wait for the 4 "A" tasks to complete. +/// scheduler.spawn({ +/// let semaphore = semaphore.clone(); +/// async move { +/// println!("Task B starting..."); +/// +/// // Since the semaphore is created with 0 permits, this will +/// // wait until all 4 "A" tasks have completed. +/// let _permit = semaphore +/// .acquire(TASKS) +/// .await +/// .expect("semaphore will not be closed"); +/// +/// // ... do some work ... +/// +/// println!("Task B done!"); +/// } +/// }); +/// +/// for i in 0..TASKS { +/// let semaphore = semaphore.clone(); +/// scheduler.spawn(async move { +/// println!("Task A {i} starting..."); +/// +/// // Add a single permit to the semaphore. Once all 4 tasks have +/// // completed, the semaphore will have the 4 permits required to +/// // wake the "B" task. +/// semaphore.add_permits(1); +/// +/// // ... do some work ... +/// +/// println!("Task A {i} done"); +/// }); +/// } +/// +/// scheduler.tick(); +/// ``` +/// +/// [counting semaphore]: https://en.wikipedia.org/wiki/Semaphore_(programming) +/// [`acquire`]: Semaphore::acquire +#[derive(Debug)] +pub struct Semaphore { + /// The number of permits in the semaphore (or [`usize::MAX] if the + /// semaphore is closed. + permits: CachePadded, + + /// The queue of tasks waiting to acquire permits. + /// + /// A spinlock (from `mycelium_util`) is used here, in order to support + /// `no_std` platforms; when running `loom` tests, a `loom` mutex is used + /// instead to simulate the spinlock, because loom doesn't play nice with + /// real spinlocks. + waiters: Mutex, +} + +/// A [RAII guard] representing one or more permits acquired from a +/// [`Semaphore`]. +/// +/// When the `Permit` is dropped, the permits it represents are released back to +/// the [`Semaphore`], potentially waking another task. +/// +/// This type is returned by the [`Semaphore::acquire`] and +/// [`Semaphore::try_acquire`] methods. +/// +/// [RAII guard]: https://rust-unofficial.github.io/patterns/patterns/behavioural/RAII.html +#[derive(Debug)] +#[must_use = "dropping a `Permit` releases the acquired permits back to the `Semaphore`"] +pub struct Permit<'sem> { + permits: usize, + semaphore: &'sem Semaphore, +} + +/// The future returned by the [`Semaphore::acquire`] method. +#[derive(Debug)] +#[pin_project(PinnedDrop)] +#[must_use = "futures do nothing unless `.await`ed or `poll`ed"] +pub struct Acquire<'sem> { + semaphore: &'sem Semaphore, + queued: bool, + permits: usize, + #[pin] + waiter: Waiter, +} + +/// Errors returned by [`Semaphore::try_acquire`]. + +#[derive(Debug, PartialEq, Eq)] +pub enum TryAcquireError { + /// The semaphore has been [closed], so additional permits cannot be + /// acquired. + /// + /// [closed]: Semaphore::close + Closed, + /// The semaphore does not currently have enough permits to satisfy the + /// request. + InsufficientPermits, +} + +/// The semaphore's queue of waiters. This is the portion of the semaphore's +/// state stored inside the lock. +#[derive(Debug)] +struct SemQueue { + /// The linked list of waiters. 
+ /// + /// # Safety + /// + /// This is protected by a mutex; the mutex *must* be acquired when + /// manipulating the linked list, OR when manipulating waiter nodes that may + /// be linked into the list. If a node is known to not be linked, it is safe + /// to modify that node (such as by waking the stored [`Waker`]) without + /// holding the lock; otherwise, it may be modified through the list, so the + /// lock must be held when modifying the + /// node. + queue: List, + + /// Has the semaphore closed? + /// + /// This is tracked inside of the locked state to avoid a potential race + /// condition where the semaphore closes while trying to lock the wait queue. + closed: bool, +} + +#[derive(Debug)] +#[pin_project] +struct Waiter { + #[pin] + node: UnsafeCell, + + remaining_permits: RemainingPermits, +} + +/// The number of permits needed before this waiter can be woken. +/// +/// When this value reaches zero, the waiter has acquired all its needed +/// permits and can be woken. If this value is `usize::max`, then the waiter +/// has not yet been linked into the semaphore queue. +#[derive(Debug)] +struct RemainingPermits(AtomicUsize); + +#[derive(Debug)] +struct Node { + links: list::Links, + waker: Option, + + // This type is !Unpin due to the heuristic from: + // + _pin: PhantomPinned, +} + +// === impl Semaphore === + +impl Semaphore { + /// The maximum number of permits a `Semaphore` may contain. + pub const MAX_PERMITS: usize = usize::MAX - 1; + + const CLOSED: usize = usize::MAX; + + loom_const_fn! { + /// Returns a new `Semaphore` with `permits` permits available. + /// + /// # Panics + /// + /// If `permits` is less than [`MAX_PERMITS`] ([`usize::MAX`] - 1). + /// + /// [`MAX_PERMITS`]: Self::MAX_PERMITS + #[must_use] + pub fn new(permits: usize) -> Self { + assert!( + permits <= Self::MAX_PERMITS, + "a semaphore may not have more than Semaphore::MAX_PERMITS permits", + ); + Self { + permits: CachePadded::new(AtomicUsize::new(permits)), + waiters: Mutex::new(SemQueue { + queue: List::new(), + closed: false, + }), + } + } + } + + /// Returns the number of permits currently available in this semaphore, or + /// 0 if the semaphore is [closed]. + /// + /// [closed]: Semaphore::close + pub fn available_permits(&self) -> usize { + let permits = self.permits.load(Acquire); + if permits == Self::CLOSED { + return 0; + } + + permits + } + + /// Acquire `permits` permits from the `Semaphore`, waiting asynchronously + /// if there are insufficient permits currently available. + /// + /// # Returns + /// + /// - `Ok(`[`Permit`]`)` with the requested number of permits, if the + /// permits were acquired. + /// - `Err(`[`Closed`]`)` if the semaphore was [closed]. + /// + /// # Cancellation + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. If an [`Acquire`] future is dropped before it completes, + /// the task will lose its place in the queue. + /// + /// [`Closed`]: crate::wait::Closed + /// [closed]: Semaphore::close + pub fn acquire(&self, permits: usize) -> Acquire<'_> { + Acquire { + semaphore: self, + queued: false, + permits, + waiter: Waiter { + node: UnsafeCell::new(Node { + links: list::Links::new(), + waker: None, + _pin: PhantomPinned, + }), + remaining_permits: RemainingPermits(AtomicUsize::new(permits)), + }, + } + } + + /// Add `permits` new permits to the semaphore. + /// + /// This permanently increases the number of permits available in the + /// semaphore. 
The permit count can be permanently *decreased* by calling + /// [`acquire`] or [`try_acquire`], and [`forget`]ting the returned [`Permit`]. + /// + /// # Panics + /// + /// If adding `permits` permits would cause the permit count to overflow + /// [`MAX_PERMITS`] ([`usize::MAX`] - 1). + /// + /// [`acquire`]: Self::acquire + /// [`try_acquire`]: Self::try_acquire + /// [`forget`]: Permit::forget + /// [`MAX_PERMITS`]: Self::MAX_PERMITS + #[inline(always)] + pub fn add_permits(&self, permits: usize) { + if permits == 0 { + return; + } + + self.add_permits_locked(permits, self.waiters.lock()); + } + + /// Try to acquire `permits` permits from the `Semaphore`, without waiting + /// for additional permits to become available. + /// + /// # Returns + /// + /// - `Ok(`[`Permit`]`)` with the requested number of permits, if the + /// permits were acquired. + /// - `Err(`[`TryAcquireError::Closed`]`)` if the semaphore was [closed]. + /// - `Err(`[`TryAcquireError::InsufficientPermits`]`)` if the semaphore had + /// fewer than `permits` permits available. + /// + /// [`Closed`]: crate::wait::Closed + /// [closed]: Semaphore::close + pub fn try_acquire(&self, permits: usize) -> Result, TryAcquireError> { + trace!(permits, "Semaphore::try_acquire"); + self.try_acquire_inner(permits).map(|_| Permit { + permits, + semaphore: self, + }) + } + + /// Closes the semaphore. + /// + /// This wakes all tasks currently waiting on the semaphore, and prevents + /// new permits from being acquired. + pub fn close(&self) { + let mut waiters = self.waiters.lock(); + self.permits.store(Self::CLOSED, Release); + waiters.closed = true; + while let Some(waiter) = waiters.queue.pop_back() { + if let Some(waker) = Waiter::take_waker(waiter, &mut waiters.queue) { + waker.wake(); + } + } + } + + fn poll_acquire( + &self, + mut node: Pin<&mut Waiter>, + permits: usize, + queued: bool, + cx: &mut Context<'_>, + ) -> Poll> { + trace!( + waiter = ?fmt::ptr(node.as_mut()), + permits, + queued, + "Semaphore::poll_acquire" + ); + // the total number of permits we've acquired so far. + let mut acquired_permits = 0; + let waiter = node.as_mut().project(); + + // how many permits are currently needed? + let needed_permits = if queued { + waiter.remaining_permits.remaining() + } else { + permits + }; + + // okay, let's try to consume the requested number of permits from the + // semaphore. + let mut sem_curr = self.permits.load(Relaxed); + let mut lock = None; + let mut waiters = loop { + // semaphore has closed + if sem_curr == Self::CLOSED { + return wait::closed(); + } + + // the total number of permits currently available to this waiter + // are the number it has acquired so far plus all the permits + // in the semaphore. + let available_permits = sem_curr + acquired_permits; + let mut remaining = 0; + let mut sem_next = sem_curr; + let can_acquire = if available_permits >= needed_permits { + // there are enough permits available to satisfy this request. + + // the semaphore's next state will be the current number of + // permits less the amount we have to take from it to satisfy + // request. + sem_next -= needed_permits - acquired_permits; + needed_permits + } else { + // the number of permits available in the semaphore is less than + // number we want to acquire. take all the currently available + // permits. + sem_next = 0; + // how many permits do we still need to acquire? 
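+                //
+                // (for example, if this waiter needs 5 permits in total, has
+                // already acquired 2 of them, and the semaphore currently
+                // holds 1, then after draining the semaphore it will still
+                // need (5 - 2) - 1 = 2 more permits before it can be woken.)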
+ remaining = (needed_permits - acquired_permits) - sem_curr; + sem_curr + }; + + if remaining > 0 && lock.is_none() { + // we weren't able to acquire enough permits on this poll, so + // the waiter will probably need to be queued, so we must lock + // the wait queue. + // + // this has to happen *before* the CAS that sets the new value + // of the semaphore's permits counter. if we subtracted the + // permits before acquiring the lock, additional permits might + // be added to the semaphore while we were waiting to lock the + // wait queue, and we would miss acquiring those permits. + // therefore, we lock the queue now. + lock = Some(self.waiters.lock()); + } + + if let Err(actual) = test_dbg!(self.permits.compare_exchange( + test_dbg!(sem_curr), + test_dbg!(sem_next), + AcqRel, + Acquire + )) { + // the semaphore was updated while we were trying to acquire + // permits. + sem_curr = actual; + continue; + } + + // okay, we took some permits from the semaphore. + acquired_permits += can_acquire; + // did we acquire all the permits we needed? + if test_dbg!(remaining) == 0 { + if !queued { + // the wasn't already in the queue, so we won't need to + // remove it --- we're done! + trace!( + waiter = ?fmt::ptr(node.as_mut()), + permits, + queued, + "Semaphore::poll_acquire -> all permits acquired; done" + ); + return Poll::Ready(Ok(())); + } else { + // we acquired all the permits we needed, but the waiter was + // already in the queue, so we need to dequeue it. we may + // have already acquired the lock on a previous CAS attempt + // that failed, but if not, grab it now. + break lock.unwrap_or_else(|| self.waiters.lock()); + } + } + + // we updated the semaphore, and will need to wait to acquire + // additional permits. + break lock.expect("we should have acquired the lock before trying to wait"); + }; + + if waiters.closed { + trace!( + waiter = ?fmt::ptr(node.as_mut()), + permits, + queued, + "Semaphore::poll_acquire -> semaphore closed" + ); + return wait::closed(); + } + + // add permits to the waiter, returning whether we added enough to wake + // it. + if waiter.remaining_permits.add(&mut acquired_permits) { + trace!( + waiter = ?fmt::ptr(node.as_mut()), + permits, + queued, + "Semaphore::poll_acquire -> remaining permits acquired; done" + ); + // if there are permits left over after waking the node, give the + // remaining permits back to the semaphore, potentially assigning + // them to the next waiter in the queue. + self.add_permits_locked(acquired_permits, waiters); + return Poll::Ready(Ok(())); + } + + debug_assert_eq!( + acquired_permits, 0, + "if we are enqueueing a waiter, we must have used all the acquired permits" + ); + + // we need to wait --- register the polling task's waker, and enqueue + // node. + let node_ptr = unsafe { NonNull::from(Pin::into_inner_unchecked(node)) }; + Waiter::with_node(node_ptr, &mut waiters.queue, |node| { + let will_wake = node + .waker + .as_ref() + .map_or(false, |waker| waker.will_wake(cx.waker())); + if !will_wake { + node.waker = Some(cx.waker().clone()) + } + }); + + // if the waiter is not already in the queue, add it now. 
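+        //
+        // note that new waiters are pushed to the *front* of the queue, while
+        // `add_permits_locked` assigns permits starting from the *back*, so
+        // waiters are woken in first-in, first-out order.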
+ if !queued { + waiters.queue.push_front(node_ptr); + trace!( + waiter = ?node_ptr, + permits, + queued, + "Semaphore::poll_acquire -> enqueued" + ); + } + + Poll::Pending + } + + #[inline(never)] + fn add_permits_locked(&self, mut permits: usize, mut waiters: MutexGuard<'_, SemQueue>) { + trace!(permits, "Semaphore::add_permits"); + if waiters.closed { + trace!( + permits, + "Semaphore::add_permits -> already closed; doing nothing" + ); + return; + } + + let mut drained_queue = false; + while permits > 0 { + // peek the last waiter in the queue to add permits to it; we may not + // be popping it from the queue if there are not enough permits to + // wake that waiter. + match waiters.queue.back() { + Some(waiter) => { + // try to add enough permits to wake this waiter. if we + // can't, break --- we should be out of permits. + if !waiter.project_ref().remaining_permits.add(&mut permits) { + debug_assert_eq!(permits, 0); + break; + } + } + None => { + // we've emptied the queue. all done! + drained_queue = true; + break; + } + }; + + // okay, we added enough permits to wake this waiter. + let waiter = waiters + .queue + .pop_back() + .expect("if `back()` returned `Some`, `pop_back()` will also return `Some`"); + let waker = Waiter::take_waker(waiter, &mut waiters.queue); + trace!(?waiter, ?waker, permits, "Semaphore::add_permits -> waking"); + if let Some(waker) = waker { + // TODO(eliza): wake in batches outside the lock. + waker.wake(); + } + } + + if permits > 0 && drained_queue { + trace!( + permits, + "Semaphore::add_permits -> queue drained, assigning remaining permits to semaphore" + ); + // we drained the queue, but there are still permits left --- add + // them to the semaphore. + let prev = self.permits.fetch_add(permits, Release); + assert!( + prev + permits <= Self::MAX_PERMITS, + "semaphore overflow adding {permits} permits to {prev}; max permits: {}", + Self::MAX_PERMITS + ); + } + } + + /// Drop an `Acquire` future. + /// + /// This is factored out into a method on `Semaphore`, because the same code + /// is run when dropping an `Acquire` future or an `AcquireOwned` future. + fn drop_acquire(&self, waiter: Pin<&mut Waiter>, permits: usize, queued: bool) { + // If the future is completed, there is no node in the wait list, so we + // can skip acquiring the lock. + if !queued { + return; + } + + // This is where we ensure safety. The future is being dropped, + // which means we must ensure that the waiter entry is no longer stored + // in the linked list. + let mut waiters = self.waiters.lock(); + + let acquired_permits = permits - waiter.remaining_permits.remaining(); + + // Safety: we have locked the wait list. + unsafe { + // remove the entry from the list + let node = NonNull::from(Pin::into_inner_unchecked(waiter)); + waiters.queue.remove(node) + }; + + if acquired_permits > 0 { + self.add_permits_locked(acquired_permits, waiters); + } + } + + /// Try to acquire permits from the semaphore without waiting. + /// + /// This method is factored out because it's identical between the + /// `try_acquire` and `try_acquire_owned` methods, which behave identically + /// but return different permit types. + fn try_acquire_inner(&self, permits: usize) -> Result<(), TryAcquireError> { + let mut available = self.permits.load(Relaxed); + loop { + // are there enough permits to satisfy the request? 
+ match available { + Self::CLOSED => { + trace!(permits, "Semaphore::try_acquire -> closed"); + return Err(TryAcquireError::Closed); + } + available if available < permits => { + trace!( + permits, + available, + "Semaphore::try_acquire -> insufficient permits" + ); + return Err(TryAcquireError::InsufficientPermits); + } + _ => {} + } + + let remaining = available - permits; + match self + .permits + .compare_exchange_weak(available, remaining, AcqRel, Acquire) + { + Ok(_) => { + trace!(permits, remaining, "Semaphore::try_acquire -> acquired"); + return Ok(()); + } + Err(actual) => available = actual, + } + } + } +} + +// === impl Acquire === + +impl<'sem> Future for Acquire<'sem> { + type Output = WaitResult>; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let poll = this + .semaphore + .poll_acquire(this.waiter, *this.permits, *this.queued, cx) + .map_ok(|_| Permit { + permits: *this.permits, + semaphore: this.semaphore, + }); + *this.queued = poll.is_pending(); + poll + } +} + +#[pinned_drop] +impl PinnedDrop for Acquire<'_> { + fn drop(self: Pin<&mut Self>) { + let this = self.project(); + trace!(?this.queued, "Acquire::drop"); + this.semaphore + .drop_acquire(this.waiter, *this.permits, *this.queued) + } +} + +// safety: the `Acquire` future is not automatically `Sync` because the `Waiter` +// node contains an `UnsafeCell`, which is not `Sync`. this impl is safe because +// the `Acquire` future will only access this `UnsafeCell` when mutably borrowed +// (when polling or dropping the future), so the future itself is safe to share +// immutably between threads. +unsafe impl Sync for Acquire<'_> {} + +// === impl Permit === + +impl Permit<'_> { + /// Forget this permit, dropping it *without* returning the number of + /// acquired permits to the semaphore. + /// + /// This permanently decreases the number of permits in the semaphore by + /// [`self.permits()`](Self::permits). + pub fn forget(mut self) { + self.permits = 0; + } + + /// Returns the count of semaphore permits owned by this `Permit`. + #[inline] + #[must_use] + pub fn permits(&self) -> usize { + self.permits + } +} + +impl Drop for Permit<'_> { + fn drop(&mut self) { + trace!(?self.permits, "Permit::drop"); + self.semaphore.add_permits(self.permits); + } +} + +// === Owned variants when `Arc` is available === + +feature! { + #![feature = "alloc"] + + use alloc::sync::Arc; + + /// Future returned from [`Semaphore::acquire_owned()`]. + /// + /// This is identical to the [`Acquire`] future, except that it takes an + /// [`Arc`] reference to the [`Semaphore`], allowing the returned future to + /// live for the `'static` lifetime, and returns an [`OwnedPermit`] (rather + /// than a [`Permit`]), which is also valid for the `'static` lifetime. + #[derive(Debug)] + #[pin_project(PinnedDrop)] + #[must_use = "futures do nothing unless `.await`ed or `poll`ed"] + pub struct AcquireOwned { + semaphore: Arc, + queued: bool, + permits: usize, + #[pin] + waiter: Waiter, + } + + /// An owned [RAII guard] representing one or more permits acquired from a + /// [`Semaphore`]. + /// + /// When the `OwnedPermit` is dropped, the permits it represents are + /// released back to the [`Semaphore`], potentially waking another task. + /// + /// This type is identical to the [`Permit`] type, except that it holds an + /// [`Arc`] clone of the [`Semaphore`], rather than borrowing it. This + /// allows the guard to be valid for the `'static` lifetime. 
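+    /// For example, an `OwnedPermit` may be held by a spawned task that must
+    /// be `'static`, where a [`Permit`] borrowing the [`Semaphore`] could not
+    /// be used.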
+ /// + /// This type is returned by the [`Semaphore::acquire_owned`] and + /// [`Semaphore::try_acquire_owned`] methods. + /// + /// [RAII guard]: https://rust-unofficial.github.io/patterns/patterns/behavioural/RAII.html + #[derive(Debug)] + #[must_use = "dropping an `OwnedPermit` releases the acquired permits back to the `Semaphore`"] + pub struct OwnedPermit { + permits: usize, + semaphore: Arc, + } + + impl Semaphore { + /// Acquire `permits` permits from the `Semaphore`, waiting asynchronously + /// if there are insufficient permits currently available, and returning + /// an [`OwnedPermit`]. + /// + /// This method behaves identically to [`acquire`], except that it + /// requires the `Semaphore` to be wrapped in an [`Arc`], and returns an + /// [`OwnedPermit`] which clones the [`Arc`] rather than borrowing the + /// semaphore. This allows the returned [`OwnedPermit`] to be valid for + /// the `'static` lifetime. + /// + /// # Returns + /// + /// - `Ok(`[`OwnedPermit`]`)` with the requested number of permits, if the + /// permits were acquired. + /// - `Err(`[`Closed`]`)` if the semaphore was [closed]. + /// + /// # Cancellation + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. If an [`AcquireOwned`] future is dropped before it + /// completes, the task will lose its place in the queue. + /// + /// [`acquire`]: Semaphore::acquire + /// [`Closed`]: crate::wait::Closed + /// [closed]: Semaphore::close + pub fn acquire_owned(self: &Arc, permits: usize) -> AcquireOwned { + AcquireOwned { + semaphore: self.clone(), + queued: false, + permits, + waiter: Waiter::new(permits), + } + } + + /// Try to acquire `permits` permits from the `Semaphore`, without waiting + /// for additional permits to become available, and returning an [`OwnedPermit`]. + /// + /// This method behaves identically to [`try_acquire`], except that it + /// requires the `Semaphore` to be wrapped in an [`Arc`], and returns an + /// [`OwnedPermit`] which clones the [`Arc`] rather than borrowing the + /// semaphore. This allows the returned [`OwnedPermit`] to be valid for + /// the `'static` lifetime. + /// + /// # Returns + /// + /// - `Ok(`[`OwnedPermit`]`)` with the requested number of permits, if the + /// permits were acquired. + /// - `Err(`[`TryAcquireError::Closed`]`)` if the semaphore was [closed]. + /// - `Err(`[`TryAcquireError::InsufficientPermits`]`)` if the semaphore + /// had fewer than `permits` permits available. + /// + /// + /// [`try_acquire`]: Semaphore::try_acquire + /// [`Closed`]: crate::wait::Closed + /// [closed]: Semaphore::close + pub fn try_acquire_owned(self: &Arc, permits: usize) -> Result { + trace!(permits, "Semaphore::try_acquire_owned"); + self.try_acquire_inner(permits).map(|_| OwnedPermit { + permits, + semaphore: self.clone(), + }) + } + } + + // === impl AcquireOwned === + + impl Future for AcquireOwned { + type Output = WaitResult; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let poll = this + .semaphore + .poll_acquire(this.waiter, *this.permits, *this.queued, cx) + .map_ok(|_| OwnedPermit { + permits: *this.permits, + // TODO(eliza): might be nice to not have to bump the + // refcount here... 
+ semaphore: this.semaphore.clone(), + }); + *this.queued = poll.is_pending(); + poll + } + } + + #[pinned_drop] + impl PinnedDrop for AcquireOwned { + fn drop(mut self: Pin<&mut Self>) { + let this = self.project(); + trace!(?this.queued, "AcquireOwned::drop"); + this.semaphore + .drop_acquire(this.waiter, *this.permits, *this.queued) + } + } + + // safety: this is safe for the same reasons as the `Sync` impl for the + // `Acquire` future. + unsafe impl Sync for AcquireOwned {} + + // === impl OwnedPermit === + + impl OwnedPermit { + /// Forget this permit, dropping it *without* returning the number of + /// acquired permits to the semaphore. + /// + /// This permanently decreases the number of permits in the semaphore by + /// [`self.permits()`](Self::permits). + pub fn forget(mut self) { + self.permits = 0; + } + + /// Returns the count of semaphore permits owned by this `OwnedPermit`. + #[inline] + #[must_use] + pub fn permits(&self) -> usize { + self.permits + } + } + + impl Drop for OwnedPermit { + fn drop(&mut self) { + trace!(?self.permits, "OwnedPermit::drop"); + self.semaphore.add_permits(self.permits); + } + } + +} + +// === impl Waiter === + +impl Waiter { + fn new(permits: usize) -> Self { + Self { + node: UnsafeCell::new(Node { + links: list::Links::new(), + waker: None, + _pin: PhantomPinned, + }), + remaining_permits: RemainingPermits(AtomicUsize::new(permits)), + } + } + + #[inline(always)] + #[cfg_attr(loom, track_caller)] + fn take_waker(this: NonNull, list: &mut List) -> Option { + Self::with_node(this, list, |node| node.waker.take()) + } + + /// # Safety + /// + /// This is only safe to call while the list is locked. The dummy `_list` + /// parameter ensures this method is only called while holding the lock, so + /// this can be safe. + /// + /// Of course, that must be the *same* list that this waiter is a member of, + /// and currently, there is no way to ensure that... + #[inline(always)] + #[cfg_attr(loom, track_caller)] + fn with_node( + mut this: NonNull, + _list: &mut List, + f: impl FnOnce(&mut Node) -> T, + ) -> T { + unsafe { + // safety: this is only called while holding the lock on the queue, + // so it's safe to mutate the waiter. + this.as_mut().node.with_mut(|node| f(&mut *node)) + } + } +} + +unsafe impl Linked> for Waiter { + type Handle = NonNull; + + fn into_ptr(r: Self::Handle) -> NonNull { + r + } + + unsafe fn from_ptr(ptr: NonNull) -> Self::Handle { + ptr + } + + unsafe fn links(target: NonNull) -> NonNull> { + // Safety: using `ptr::addr_of!` avoids creating a temporary + // reference, which stacked borrows dislikes. + let node = ptr::addr_of!((*target.as_ptr()).node); + (*node).with_mut(|node| { + let links = ptr::addr_of_mut!((*node).links); + // Safety: since the `target` pointer is `NonNull`, we can assume + // that pointers to its members are also not null, making this use + // of `new_unchecked` fine. + NonNull::new_unchecked(links) + }) + } +} + +// === impl RemainingPermits === + +impl RemainingPermits { + /// Add an acquisition of permits to the waiter, returning whether or not + /// the waiter has acquired enough permits to be woken. + #[inline] + #[cfg_attr(loom, track_caller)] + fn add(&self, permits: &mut usize) -> bool { + let mut curr = self.0.load(Relaxed); + loop { + let taken = cmp::min(curr, *permits); + let remaining = curr - taken; + match self + .0 + .compare_exchange_weak(curr, remaining, AcqRel, Acquire) + { + // added the permits to the waiter! 
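+                //
+                // (for example, if the waiter still needs 3 permits
+                // (`curr == 3`) and 5 permits are being handed out
+                // (`*permits == 5`), we take 3 of them, leave `*permits == 2`
+                // to be assigned to other waiters, and return `true`, since
+                // the waiter now has all the permits it needs.)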
+ Ok(_) => { + *permits -= taken; + return remaining == 0; + } + Err(actual) => curr = actual, + } + } + } + + #[inline] + fn remaining(&self) -> usize { + self.0.load(Acquire) + } +} diff --git a/maitake/src/wait/semaphore/loom.rs b/maitake/src/wait/semaphore/loom.rs new file mode 100644 index 00000000..8d1daf2f --- /dev/null +++ b/maitake/src/wait/semaphore/loom.rs @@ -0,0 +1,165 @@ +use super::*; +use crate::loom::{ + self, future, + sync::{ + atomic::{AtomicUsize, Ordering::SeqCst}, + Arc, + }, + thread, +}; + +#[test] +fn basically_works() { + const TASKS: usize = 2; + + async fn task((ref sem, ref count): &(Semaphore, AtomicUsize)) { + let permit = sem.acquire(1).await.unwrap(); + let actual = count.fetch_add(1, SeqCst); + assert!(actual <= TASKS - 1); + + let actual = count.fetch_sub(1, SeqCst); + assert!(actual <= TASKS); + drop(permit); + } + + loom::model(|| { + let sem = Arc::new((Semaphore::new(TASKS), AtomicUsize::new(0))); + let threads = (0..TASKS) + .map(|_| { + let sem = sem.clone(); + thread::spawn(move || { + future::block_on(task(&sem)); + }) + }) + .collect::>(); + + future::block_on(task(&sem)); + + for t in threads { + t.join().unwrap(); + } + }) +} + +#[test] +fn release_on_drop() { + loom::model(|| { + let sem = Arc::new(Semaphore::new(1)); + + let thread = thread::spawn({ + let sem = sem.clone(); + move || { + let _permit = future::block_on(sem.acquire(1)).unwrap(); + } + }); + + let permit = future::block_on(sem.acquire(1)).unwrap(); + drop(permit); + thread.join().unwrap(); + }) +} + +#[test] +fn close() { + loom::model(|| { + let sem = Arc::new(Semaphore::new(1)); + let threads: Vec<_> = (0..2) + .map(|_| { + thread::spawn({ + let sem = sem.clone(); + move || -> Result<(), ()> { + for _ in 0..2 { + let _permit = future::block_on(sem.acquire(1)).map_err(|_| ())?; + } + Ok(()) + } + }) + }) + .collect(); + + sem.close(); + + for thread in threads { + let _ = thread.join().unwrap(); + } + }) +} + +#[test] +fn concurrent_close() { + fn run(sem: Arc) -> impl FnOnce() -> Result<(), ()> { + move || { + let permit = future::block_on(sem.acquire(1)).map_err(|_| ())?; + drop(permit); + sem.close(); + Ok(()) + } + } + + loom::model(|| { + let sem = Arc::new(Semaphore::new(1)); + let threads: Vec<_> = (0..2).map(|_| thread::spawn(run(sem.clone()))).collect(); + let _ = run(sem)(); + + for thread in threads { + let _ = thread.join().unwrap(); + } + }) +} + +#[test] +fn concurrent_cancel() { + use futures_util::future::FutureExt; + fn run(sem: &Arc) -> impl FnOnce() { + let sem = sem.clone(); + move || { + future::block_on(async move { + // poll two `acquire` futures immediately and then cancel + // them, regardless of whether or not they complete. 
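+                //
+                // (dropping an `Acquire` future that may have been queued
+                // exercises the cancellation path: its waiter node must be
+                // unlinked from the queue, and any partially-acquired permits
+                // must be returned to the semaphore.)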
+ let _permit1 = { + let acquire = sem.acquire(1); + acquire.now_or_never() + }; + let _permit2 = { + let acquire = sem.acquire(1); + acquire.now_or_never() + }; + }) + } + } + + loom::model(|| { + let sem = Arc::new(Semaphore::new(0)); + + let thread1 = thread::spawn(run(&sem)); + let thread2 = thread::spawn(run(&sem)); + let thread3 = thread::spawn(run(&sem)); + + thread1.join().unwrap(); + sem.add_permits(10); + thread2.join().unwrap(); + thread3.join().unwrap(); + }) +} + +#[test] +fn drop_permits_while_acquiring() { + loom::model(|| { + let sem = Arc::new(Semaphore::new(4)); + let permit1 = sem + .try_acquire(3) + .expect("semaphore has 4 permits, so we should acquire 3"); + let thread1 = thread::spawn({ + let sem = sem.clone(); + move || { + let _permit = future::block_on(sem.acquire(2)).unwrap(); + assert_eq!(sem.available_permits(), 2); + } + }); + + drop(permit1); + trace!("dropped permit 1"); + thread1.join().unwrap(); + assert_eq!(sem.available_permits(), 4); + }) +} diff --git a/maitake/src/wait/semaphore/tests.rs b/maitake/src/wait/semaphore/tests.rs new file mode 100644 index 00000000..ee9663f8 --- /dev/null +++ b/maitake/src/wait/semaphore/tests.rs @@ -0,0 +1,127 @@ +use super::*; + +fn assert_send_sync() {} + +#[test] +fn semaphore_is_send_and_sync() { + assert_send_sync::(); +} + +#[test] +fn permit_is_send_and_sync() { + assert_send_sync::>(); +} + +#[test] +fn acquire_is_send_and_sync() { + assert_send_sync::>(); +} + +#[cfg(feature = "alloc")] +mod alloc { + use super::*; + use crate::scheduler::Scheduler; + use ::alloc::sync::Arc; + use core::sync::atomic::AtomicBool; + + #[test] + fn owned_permit_is_send_and_sync() { + assert_send_sync::(); + } + + #[test] + fn acquire_owned_is_send_and_sync() { + assert_send_sync::(); + } + + #[test] + fn basic_concurrency_limit() { + const TASKS: usize = 8; + const CONCURRENCY_LIMIT: usize = 4; + crate::util::trace_init(); + + let scheduler = Scheduler::new(); + let semaphore = Arc::new(Semaphore::new(CONCURRENCY_LIMIT)); + let running = Arc::new(AtomicUsize::new(0)); + let completed = Arc::new(AtomicUsize::new(0)); + + for _ in 0..TASKS { + let semaphore = semaphore.clone(); + let running = running.clone(); + let completed = completed.clone(); + scheduler.spawn(async move { + let permit = semaphore + .acquire(1) + .await + .expect("semaphore will not be closed"); + assert!(test_dbg!(running.fetch_add(1, Relaxed)) < CONCURRENCY_LIMIT); + + crate::future::yield_now().await; + drop(permit); + + assert!(test_dbg!(running.fetch_sub(1, Relaxed)) <= CONCURRENCY_LIMIT); + completed.fetch_add(1, Relaxed); + }); + } + + while completed.load(Relaxed) < TASKS { + scheduler.tick(); + assert!(test_dbg!(running.load(Relaxed)) <= CONCURRENCY_LIMIT); + } + } + + #[test] + fn countdown() { + const TASKS: usize = 4; + + let scheduler = Scheduler::new(); + let semaphore = Arc::new(Semaphore::new(0)); + let a_done = Arc::new(AtomicUsize::new(0)); + let b_done = Arc::new(AtomicBool::new(false)); + + scheduler.spawn({ + let semaphore = semaphore.clone(); + let b_done = b_done.clone(); + let a_done = a_done.clone(); + async move { + tracing_02::info!("Task B starting..."); + + // Since the semaphore is created with 0 permits, this will + // wait until all 4 "A" tasks have completed. + let _permit = semaphore + .acquire(TASKS) + .await + .expect("semaphore will not be closed"); + assert_eq!(a_done.load(Relaxed), TASKS); + + // ... do some work ... 
+ + tracing_02::info!("Task B done!"); + b_done.store(true, Relaxed); + } + }); + + for i in 0..TASKS { + let semaphore = semaphore.clone(); + let a_done = a_done.clone(); + scheduler.spawn(async move { + tracing_02::info!("Task A {i} starting..."); + + crate::future::yield_now().await; + + a_done.fetch_add(1, Relaxed); + semaphore.add_permits(1); + + // ... do some work ... + tracing_02::info!("Task A {i} done"); + }); + } + + while !b_done.load(Relaxed) { + scheduler.tick(); + } + + assert_eq!(a_done.load(Relaxed), TASKS); + assert!(b_done.load(Relaxed)); + } +}