Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use actual thread local queues instead of using a RwLock #93

Merged
merged 22 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ concurrent-queue = "2.0.0"
fastrand = "2.0.0"
futures-lite = { version = "2.0.0", default-features = false }
slab = "0.4.4"
thread_local = { git = "https://github.com/james7132/thread_local-rs", branch = "fix-iter-ub" }
james7132 marked this conversation as resolved.
Show resolved Hide resolved
atomic-waker = "1.0"

[target.'cfg(target_family = "wasm")'.dependencies]
futures-lite = { version = "2.0.0", default-features = false, features = ["std"] }
Expand Down
121 changes: 78 additions & 43 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,16 @@ use std::marker::PhantomData;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, RwLock, TryLockError};
use std::sync::{Arc, Mutex, TryLockError};
use std::task::{Poll, Waker};

use async_lock::OnceCell;
use async_task::{Builder, Runnable};
use atomic_waker::AtomicWaker;
use concurrent_queue::ConcurrentQueue;
use futures_lite::{future, prelude::*};
use slab::Slab;
use thread_local::ThreadLocal;

#[doc(no_inline)]
pub use async_task::Task;
Expand Down Expand Up @@ -266,8 +268,17 @@ impl<'a> Executor<'a> {
fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
let state = self.state().clone();

// TODO: If possible, push into the current local queue and notify the ticker.
move |runnable| {
move |mut runnable| {
// If possible, push into the current local queue and notify the ticker.
if let Some(local) = state.local_queue.get() {
james7132 marked this conversation as resolved.
Show resolved Hide resolved
runnable = if let Err(err) = local.queue.push(runnable) {
err.into_inner()
} else {
local.waker.wake();
return;
james7132 marked this conversation as resolved.
Show resolved Hide resolved
}
}
// If the local queue is full, fallback to pushing onto the global injector queue.
state.queue.push(runnable).unwrap();
state.notify();
james7132 marked this conversation as resolved.
Show resolved Hide resolved
}
Expand Down Expand Up @@ -510,7 +521,11 @@ struct State {
queue: ConcurrentQueue<Runnable>,

/// Local queues created by runners.
local_queues: RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>,
///
/// If possible, tasks are scheduled onto the local queue, and will only defer
/// to the global queue when it is full, or when the task is being scheduled from
james7132 marked this conversation as resolved.
Show resolved Hide resolved
/// a thread without a runner.
local_queue: ThreadLocal<LocalQueue>,

/// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
notified: AtomicBool,
Expand All @@ -527,7 +542,7 @@ impl State {
fn new() -> State {
State {
queue: ConcurrentQueue::unbounded(),
local_queues: RwLock::new(Vec::new()),
local_queue: ThreadLocal::new(),
notified: AtomicBool::new(true),
sleepers: Mutex::new(Sleepers {
count: 0,
Expand Down Expand Up @@ -654,6 +669,12 @@ impl Ticker<'_> {
///
/// Returns `false` if the ticker was already sleeping and unnotified.
fn sleep(&mut self, waker: &Waker) -> bool {
self.state
.local_queue
.get_or_default()
.waker
.register(waker);

let mut sleepers = self.state.sleepers.lock().unwrap();

match self.sleeping {
Expand Down Expand Up @@ -692,7 +713,14 @@ impl Ticker<'_> {

/// Waits for the next runnable task to run.
async fn runnable(&mut self) -> Runnable {
self.runnable_with(|| self.state.queue.pop().ok()).await
self.runnable_with(|| {
self.state
.local_queue
.get()
.and_then(|local| local.queue.pop().ok())
.or_else(|| self.state.queue.pop().ok())
})
.await
}

/// Waits for the next runnable task to run, given a function that searches for a task.
Expand Down Expand Up @@ -754,9 +782,6 @@ struct Runner<'a> {
/// Inner ticker.
ticker: Ticker<'a>,

/// The local queue.
local: Arc<ConcurrentQueue<Runnable>>,

/// Bumped every time a runnable task is found.
ticks: usize,
}
Expand All @@ -767,38 +792,34 @@ impl Runner<'_> {
let runner = Runner {
state,
ticker: Ticker::new(state),
local: Arc::new(ConcurrentQueue::bounded(512)),
ticks: 0,
};
state
.local_queues
.write()
.unwrap()
.push(runner.local.clone());
runner
}

/// Waits for the next runnable task to run.
async fn runnable(&mut self, rng: &mut fastrand::Rng) -> Runnable {
let local = self.state.local_queue.get_or_default();

let runnable = self
.ticker
.runnable_with(|| {
// Try the local queue.
if let Ok(r) = self.local.pop() {
if let Ok(r) = local.queue.pop() {
return Some(r);
}

// Try stealing from the global queue.
if let Ok(r) = self.state.queue.pop() {
steal(&self.state.queue, &self.local);
steal(&self.state.queue, &local.queue);
return Some(r);
}

// Try stealing from other runners.
let local_queues = self.state.local_queues.read().unwrap();
let local_queues = &self.state.local_queue;

// Pick a random starting point in the iterator list and rotate the list.
let n = local_queues.len();
let n = local_queues.iter().count();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a cold operation? It seems like this would take a while.

Copy link
Contributor Author

@james7132 james7132 Feb 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is one part I'm not so sure about. Generally this shouldn't be under contention, since the cost to spin up new threads is going to be higher than it is to scan over the entire container, unless you have literally thousands of threads. It otherwise is just a scan through fairly small buckets.

We could use an atomic counter to track how many there are, but since you can't remove items from the ThreadLocal, there will be residual thread locals from currently unused threads (as thread IDs are reused), that may get out of sync.

let start = rng.usize(..n);
let iter = local_queues
.iter()
Expand All @@ -807,12 +828,12 @@ impl Runner<'_> {
.take(n);

// Remove this runner's local queue.
let iter = iter.filter(|local| !Arc::ptr_eq(local, &self.local));
let iter = iter.filter(|other| !core::ptr::eq(*other, local));

// Try stealing from each local queue in the list.
for local in iter {
steal(local, &self.local);
if let Ok(r) = self.local.pop() {
for other in iter {
steal(&other.queue, &local.queue);
if let Ok(r) = local.queue.pop() {
return Some(r);
}
}
Expand All @@ -826,7 +847,7 @@ impl Runner<'_> {

if self.ticks % 64 == 0 {
// Steal tasks from the global queue to ensure fair task scheduling.
steal(&self.state.queue, &self.local);
steal(&self.state.queue, &local.queue);
}

runnable
Expand All @@ -836,15 +857,13 @@ impl Runner<'_> {
impl Drop for Runner<'_> {
fn drop(&mut self) {
// Remove the local queue.
self.state
.local_queues
.write()
.unwrap()
.retain(|local| !Arc::ptr_eq(local, &self.local));

// Re-schedule remaining tasks in the local queue.
while let Ok(r) = self.local.pop() {
r.schedule();
if let Some(local) = self.state.local_queue.get() {
// Re-schedule remaining tasks in the local queue.
for r in local.queue.try_iter() {
// Explicitly reschedule the runnable back onto the global
// queue to avoid rescheduling onto the local one.
self.state.queue.push(r).unwrap();
}
}
}
}
Expand Down Expand Up @@ -904,18 +923,13 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_
}

/// Debug wrapper for the local runners.
struct LocalRunners<'a>(&'a RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>);
struct LocalRunners<'a>(&'a ThreadLocal<LocalQueue>);

impl fmt::Debug for LocalRunners<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.0.try_read() {
Ok(lock) => f
.debug_list()
.entries(lock.iter().map(|queue| queue.len()))
.finish(),
Err(TryLockError::WouldBlock) => f.write_str("<locked>"),
Err(TryLockError::Poisoned(_)) => f.write_str("<poisoned>"),
}
f.debug_list()
.entries(self.0.iter().map(|local| local.queue.len()))
.finish()
}
}

Expand All @@ -935,11 +949,32 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_
f.debug_struct(name)
.field("active", &ActiveTasks(&state.active))
.field("global_tasks", &state.queue.len())
.field("local_runners", &LocalRunners(&state.local_queues))
.field("local_runners", &LocalRunners(&state.local_queue))
.field("sleepers", &SleepCount(&state.sleepers))
.finish()
}

/// A queue local to each thread.
///
/// Its Default implementation is used for initializing each
/// thread's queue via `ThreadLocal::get_or_default`.
///
/// The local queue *must* be flushed, and all pending runnables
/// rescheduled onto the global queue when a runner is dropped.
struct LocalQueue {
    // Bounded queue of runnables scheduled on this thread; overflow falls
    // back to the global injector queue (see `schedule`).
    queue: ConcurrentQueue<Runnable>,
    // Waker registered by this thread's ticker so a local push can wake it.
    waker: AtomicWaker,
}

impl Default for LocalQueue {
fn default() -> Self {
Self {
queue: ConcurrentQueue::bounded(512),
waker: AtomicWaker::new(),
}
}
}

/// Runs a closure when dropped.
struct CallOnDrop<F: FnMut()>(F);

Expand Down
Loading