
Commit

m: New impl
Signed-off-by: John Nunley <dev@notgull.net>
notgull committed May 14, 2024
1 parent a80088e commit 881056f
Showing 2 changed files with 167 additions and 48 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
@@ -19,7 +19,6 @@ exclude = ["/.*"]
static = []

[dependencies]
ahash = "0.8.11"
async-task = "4.4.0"
concurrent-queue = "2.5.0"
fastrand = "2.0.0"
214 changes: 167 additions & 47 deletions src/lib.rs
@@ -39,16 +39,17 @@
)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]

use std::cell::{Cell, RefCell};
use std::cmp::Reverse;
use std::collections::VecDeque;
use std::fmt;
use std::marker::PhantomData;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock, TryLockError};
use std::task::{Poll, Waker};
use std::thread::{self, ThreadId};

use ahash::AHashMap;
use async_task::{Builder, Runnable};
use concurrent_queue::ConcurrentQueue;
use futures_lite::{future, prelude::*};
@@ -355,8 +356,8 @@ impl<'a> Executor<'a> {
.local_queues
.read()
.unwrap()
.get(&thread_id())
.and_then(|list| list.first())
.get(thread_id())
.and_then(|list| list.as_ref())
{
match local_queue.queue.push(runnable) {
Ok(()) => {
@@ -692,8 +693,9 @@ struct State {

/// Local queues created by runners.
///
/// These are keyed by the thread that the runner originated in.
local_queues: RwLock<AHashMap<ThreadId, Vec<Arc<LocalQueue>>>>,
/// These are keyed by the thread that the runner originated in. See the `thread_id` function
/// for more information.
local_queues: RwLock<Vec<Option<Arc<LocalQueue>>>>,

/// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
notified: AtomicBool,
@@ -710,7 +712,7 @@ impl State {
const fn new() -> State {
State {
queue: ConcurrentQueue::unbounded(),
local_queues: RwLock::new(AHashMap::new()),
local_queues: RwLock::new(Vec::new()),
notified: AtomicBool::new(true),
sleepers: Mutex::new(Sleepers {
count: 0,
@@ -1025,7 +1027,9 @@ struct Runner<'a> {
ticker: Ticker<'a>,

/// The ID of the thread we originated from.
origin_id: ThreadId,
///
/// This is `None` if we don't own the local runner for this thread.
origin_id: Option<usize>,

/// The local queue.
local: Arc<LocalQueue>,
@@ -1041,23 +1045,42 @@ impl Runner<'_> {
let runner_id = ID_GENERATOR.fetch_add(1, Ordering::SeqCst);

let origin_id = thread_id();
let runner = Runner {
let mut runner = Runner {
state,
ticker: Ticker::for_runner(state, runner_id),
local: Arc::new(LocalQueue {
queue: ConcurrentQueue::bounded(512),
runner_id,
}),
ticks: 0,
origin_id,
origin_id: Some(origin_id),
};
state

// If this is the highest thread ID this executor has seen, make more slots.
let mut local_queues = state.local_queues.write().unwrap();
if local_queues.len() <= origin_id {
local_queues.resize_with(origin_id + 1, || None);
}

// Try to reserve the thread-local slot.
match state
.local_queues
.write()
.unwrap()
.entry(origin_id)
.or_default()
.push(runner.local.clone());
.get_mut(origin_id)
.unwrap()
{
slot @ None => {
// We won the race, insert our queue.
*slot = Some(runner.local.clone());
}

Some(_) => {
// We lost the race, indicate we don't own this ID.
runner.origin_id = None;
}
}

runner
}

@@ -1085,8 +1108,8 @@ impl Runner<'_> {
let start = rng.usize(..n);
let iter = local_queues
.iter()
.flat_map(|(_, list)| list)
.chain(local_queues.iter().flat_map(|(_, list)| list))
.filter_map(|list| list.as_ref())
.chain(local_queues.iter().filter_map(|list| list.as_ref()))
.skip(start)
.take(n);

@@ -1120,13 +1143,15 @@ impl Runner<'_> {
impl Drop for Runner<'_> {
fn drop(&mut self) {
// Remove the local queue.
self.state
.local_queues
.write()
.unwrap()
.get_mut(&self.origin_id)
.unwrap()
.retain(|local| !Arc::ptr_eq(local, &self.local));
if let Some(origin_id) = self.origin_id {
*self
.state
.local_queues
.write()
.unwrap()
.get_mut(origin_id)
.unwrap() = None;
}

// Re-schedule remaining tasks in the local queue.
while let Ok(r) = self.local.queue.pop() {
@@ -1206,25 +1231,7 @@ fn debug_state(state: &State, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result {
}
}

/// Debug wrapper for the local runners.
struct LocalRunners<'a>(&'a RwLock<AHashMap<ThreadId, Vec<Arc<LocalQueue>>>>);

impl fmt::Debug for LocalRunners<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0.try_read() {
Ok(lock) => f
.debug_list()
.entries(
lock.iter()
.flat_map(|(_, list)| list)
.map(|queue| queue.queue.len()),
)
.finish(),
Err(TryLockError::WouldBlock) => f.write_str("<locked>"),
Err(TryLockError::Poisoned(_)) => f.write_str("<poisoned>"),
}
}
}
// TODO: Add wrapper for thread-local queues.

/// Debug wrapper for the sleepers.
struct SleepCount<'a>(&'a Mutex<Sleepers>);
@@ -1242,18 +1249,131 @@ fn debug_state(state: &State, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct(name)
.field("active", &ActiveTasks(&state.active))
.field("global_tasks", &state.queue.len())
.field("local_runners", &LocalRunners(&state.local_queues))
.field("sleepers", &SleepCount(&state.sleepers))
.finish()
}

fn thread_id() -> ThreadId {
fn thread_id() -> usize {
// TODO: This strategy does not work on WASM; figure out a better way!

/// An allocator for thread IDs.
struct Allocator {
/// The next thread ID to yield.
free_id: usize,

/// The list of thread IDs that have been released.
///
/// This exists to defend against the case where a user spawns a million threads, calls this
/// function on each of them, and then drops them all. Repeating that pattern would quickly
/// exhaust the available thread-ID space, so we reuse thread IDs once their threads have been
/// dropped.
///
/// We prefer lower thread IDs, as larger thread IDs require more memory in the constant-time
/// addressing strategy we use for thread-specific storage.
///
/// This is only `None` at program startup; it is an `Option` solely to allow const initialization.
///
/// TODO(notgull): make an entry in the "useful features" table for this
released_ids: Option<VecDeque<Reverse<usize>>>,
}

impl Allocator {
/// Run a closure with the address allocator.
fn with<R>(f: impl FnOnce(&mut Allocator) -> R) -> R {
static ALLOCATOR: Mutex<Allocator> = Mutex::new(Allocator {
free_id: 0,
released_ids: None,
});

f(&mut ALLOCATOR.lock().unwrap_or_else(|x| x.into_inner()))
}

/// Get the queue for released IDs.
fn released_ids(&mut self) -> &mut VecDeque<Reverse<usize>> {
self.released_ids.get_or_insert_with(VecDeque::default)
}

/// Get the newest ID.
fn alloc(&mut self) -> usize {
// If we can, reuse an existing ID.
if let Some(Reverse(id)) = self.released_ids().pop_front() {
return id;
}

// Increment our ID counter.
let id = self.free_id;
self.free_id = self
.free_id
.checked_add(1)
.expect("took up all available thread-ID space");
id
}

/// Free an ID that was previously allocated.
fn free(&mut self, id: usize) {
self.released_ids().push_front(Reverse(id));
}
}

thread_local! {
static ID: ThreadId = thread::current().id();
/// The unique ID for this thread.
static THREAD_ID: Cell<Option<usize>> = const { Cell::new(None) };
}

thread_local! {
/// A destructor that frees this ID before the thread exits.
///
/// This is separate from `THREAD_ID` so that accessing it does not involve a thread-local
/// destructor.
static THREAD_GUARD: RefCell<Option<ThreadGuard>> = const { RefCell::new(None) };
}

struct ThreadGuard(usize);

impl Drop for ThreadGuard {
fn drop(&mut self) {
// DEADLOCK: The `Allocator` lock is only ever held here and during the first call to `thread_id` on a given thread.
Allocator::with(|alloc| {
// De-allocate the ID.
alloc.free(self.0);
});
}
}

/// Fast path for getting the thread ID.
#[inline]
fn get() -> usize {
// Try to use the cached thread ID.
THREAD_ID.with(|thread_id| {
if let Some(thread_id) = thread_id.get() {
return thread_id;
}

// Use the slow path.
get_slow(thread_id)
})
}

/// Slow path for getting the thread ID.
#[cold]
fn get_slow(slot: &Cell<Option<usize>>) -> usize {
// Allocate a new thread ID.
let id = Allocator::with(|alloc| alloc.alloc());

// Store the thread ID.
let old = slot.replace(Some(id));
debug_assert!(old.is_none());

// Store the destructor.
THREAD_GUARD.with(|guard| {
*guard.borrow_mut() = Some(ThreadGuard(id));
});

// Return the ID.
id
}

ID.try_with(|id| *id)
.unwrap_or_else(|_| thread::current().id())
get()
}

/// Runs a closure when dropped.
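The heart of the change is the new `thread_id` above: instead of using `std::thread::ThreadId` as a hash-map key, each thread is assigned a small, reusable `usize` index by a process-global allocator, cached in a thread-local on the fast path, and returned to the allocator by a guard when the thread exits. Below is a minimal, self-contained sketch of that pattern. It is not the crate's code: the names `IdAllocator`, `IdGuard`, `MY_ID`, `ID_GUARD`, and `compact_thread_id` are invented for illustration, and for brevity it reuses freed IDs in LIFO order rather than preferring the lowest free ID as the commit's comments describe.

use std::cell::Cell;
use std::sync::Mutex;

/// Process-global allocator that hands out small, reusable thread indices.
struct IdAllocator {
    /// The next never-used index.
    next: usize,
    /// Indices whose threads have exited; these are handed out again first.
    free: Vec<usize>,
}

static IDS: Mutex<IdAllocator> = Mutex::new(IdAllocator {
    next: 0,
    free: Vec::new(),
});

thread_local! {
    /// Cached index for the current thread (fast path).
    static MY_ID: Cell<Option<usize>> = const { Cell::new(None) };
    /// Returns this thread's index to the allocator when the thread exits.
    static ID_GUARD: Cell<Option<IdGuard>> = const { Cell::new(None) };
}

struct IdGuard(usize);

impl Drop for IdGuard {
    fn drop(&mut self) {
        // Recycle the index so long-lived programs do not exhaust the ID space.
        IDS.lock().unwrap_or_else(|e| e.into_inner()).free.push(self.0);
    }
}

/// Returns a dense, reusable index for the calling thread.
fn compact_thread_id() -> usize {
    MY_ID.with(|slot| {
        if let Some(id) = slot.get() {
            return id;
        }

        // Slow path: allocate an index, cache it, and register the guard
        // that recycles it when this thread exits.
        let mut ids = IDS.lock().unwrap_or_else(|e| e.into_inner());
        let id = match ids.free.pop() {
            Some(id) => id,
            None => {
                let id = ids.next;
                ids.next += 1;
                id
            }
        };
        drop(ids);

        slot.set(Some(id));
        ID_GUARD.with(|guard| guard.set(Some(IdGuard(id))));
        id
    })
}

fn main() {
    let main_id = compact_thread_id();
    let child_id = std::thread::spawn(compact_thread_id).join().unwrap();
    assert_ne!(main_id, child_id);
    println!("main = {main_id}, child = {child_id}");
}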

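Those compact IDs are what allow `State::local_queues` to become a `RwLock<Vec<Option<Arc<LocalQueue>>>>` indexed by thread ID instead of a `RwLock<AHashMap<ThreadId, Vec<Arc<LocalQueue>>>>`, which is also why the `ahash` dependency is dropped from Cargo.toml. The sketch below illustrates that slot-per-thread pattern under the same assumptions as above: grow the table to cover a new ID, claim the slot if it is empty, and clear it on drop. The `Slots` type and its `try_claim`, `release`, and `get` methods are illustrative stand-ins, not the crate's API.

use std::sync::{Arc, RwLock};

/// Stand-in for the executor's per-thread queue; the real type wraps a
/// concurrent queue of runnables.
struct LocalQueue;

/// Local-queue slots indexed directly by the compact thread ID, replacing the
/// previous hash map keyed by `ThreadId`.
struct Slots {
    local_queues: RwLock<Vec<Option<Arc<LocalQueue>>>>,
}

impl Slots {
    const fn new() -> Self {
        Slots {
            local_queues: RwLock::new(Vec::new()),
        }
    }

    /// Try to claim the slot for `thread_id`. Returns `true` if the caller now
    /// owns the slot and must clear it when it is dropped.
    fn try_claim(&self, thread_id: usize, queue: Arc<LocalQueue>) -> bool {
        let mut slots = self.local_queues.write().unwrap();

        // If this is the highest thread ID seen so far, grow the table.
        if slots.len() <= thread_id {
            slots.resize_with(thread_id + 1, || None);
        }

        match &mut slots[thread_id] {
            // We won the race: install our queue.
            slot @ None => {
                *slot = Some(queue);
                true
            }
            // Another runner on this thread beat us to it.
            Some(_) => false,
        }
    }

    /// Release a previously claimed slot (the runner's `Drop` would call this).
    fn release(&self, thread_id: usize) {
        self.local_queues.write().unwrap()[thread_id] = None;
    }

    /// Read-side lookup used when scheduling: plain indexing, no hashing.
    fn get(&self, thread_id: usize) -> Option<Arc<LocalQueue>> {
        self.local_queues
            .read()
            .unwrap()
            .get(thread_id)
            .and_then(|slot| slot.clone())
    }
}

fn main() {
    static SLOTS: Slots = Slots::new();
    let claimed = SLOTS.try_claim(0, Arc::new(LocalQueue));
    assert!(claimed && SLOTS.get(0).is_some());
    SLOTS.release(0);
    assert!(SLOTS.get(0).is_none());
}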