Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only work-steal in the main loop #12

Open
wants to merge 1 commit into
base: rustc
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rayon-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ num_cpus = "1.2"
crossbeam-channel = "0.5.0"
crossbeam-deque = "0.8.1"
crossbeam-utils = "0.8.0"
smallvec = "1.11.0"

[dev-dependencies]
rand = "0.8"
Expand Down
24 changes: 21 additions & 3 deletions rayon-core/src/broadcast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::registry::{Registry, WorkerThread};
use crate::scope::ScopeLatch;
use std::fmt;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

mod test;
Expand Down Expand Up @@ -100,13 +101,22 @@ where
OP: Fn(BroadcastContext<'_>) -> R + Sync,
R: Send,
{
let current_thread = WorkerThread::current();
let current_thread_addr = current_thread as usize;
let started = &AtomicBool::new(false);
let f = move |injected: bool| {
debug_assert!(injected);

// Mark as started if we are on the thread that initiated the broadcast.
if current_thread_addr == WorkerThread::current() as usize {
started.store(true, Ordering::Relaxed);
}

BroadcastContext::with(&op)
};

let n_threads = registry.num_threads();
let current_thread = WorkerThread::current().as_ref();
let current_thread = current_thread.as_ref();
let tlv = crate::tlv::get();
let latch = ScopeLatch::with_count(n_threads, current_thread);
let jobs: Vec<_> = (0..n_threads)
Expand All @@ -116,8 +126,16 @@ where

registry.inject_broadcast(job_refs);

let current_thread_job_id = current_thread
.and_then(|worker| (registry.id() == worker.registry.id()).then(|| worker))
.map(|worker| jobs[worker.index].as_job_ref().id());

// Wait for all jobs to complete, then collect the results, maybe propagating a panic.
latch.wait(current_thread);
latch.wait(
current_thread,
|| started.load(Ordering::Relaxed),
|job| Some(job.id()) == current_thread_job_id,
);
jobs.into_iter().map(|job| job.into_result()).collect()
}

Expand All @@ -133,7 +151,7 @@ where
{
let job = ArcJob::new({
let registry = Arc::clone(registry);
move || {
move |_| {
registry.catch_unwind(|| BroadcastContext::with(&op));
registry.terminate(); // (*) permit registry to terminate now
}
Expand Down
2 changes: 2 additions & 0 deletions rayon-core/src/broadcast/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ fn spawn_broadcast_self() {
}

#[test]
#[ignore]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn broadcast_mutual() {
let count = AtomicUsize::new(0);
Expand Down Expand Up @@ -97,6 +98,7 @@ fn spawn_broadcast_mutual() {
}

#[test]
#[ignore]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn broadcast_mutual_sleepy() {
let count = AtomicUsize::new(0);
Expand Down
40 changes: 26 additions & 14 deletions rayon-core/src/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ pub(super) trait Job {
unsafe fn execute(this: *const ());
}

#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
pub(super) struct JobRefId {
pointer: usize,
}

/// Effectively a Job trait object. Each JobRef **must** be executed
/// exactly once, or else data may leak.
///
Expand Down Expand Up @@ -54,11 +59,11 @@ impl JobRef {
}
}

/// Returns an opaque handle that can be saved and compared,
/// without making `JobRef` itself `Copy + Eq`.
#[inline]
pub(super) fn id(&self) -> impl Eq {
(self.pointer, self.execute_fn)
pub(super) fn id(&self) -> JobRefId {
JobRefId {
pointer: self.pointer as usize,
}
}

#[inline]
Expand Down Expand Up @@ -102,8 +107,13 @@ where
JobRef::new(self)
}

pub(super) unsafe fn run_inline(self, stolen: bool) -> R {
self.func.into_inner().unwrap()(stolen)
pub(super) unsafe fn run_inline(&self, stolen: bool) {
let func = (*self.func.get()).take().unwrap();
(*self.result.get()) = match unwind::halt_unwinding(|| func(stolen)) {
Ok(x) => JobResult::Ok(x),
Err(x) => JobResult::Panic(x),
};
Latch::set(&self.latch);
}

pub(super) unsafe fn into_result(self) -> R {
Expand Down Expand Up @@ -136,15 +146,15 @@ where
/// (Probably `StackJob` should be refactored in a similar fashion.)
pub(super) struct HeapJob<BODY>
where
BODY: FnOnce() + Send,
BODY: FnOnce(JobRefId) + Send,
{
job: BODY,
tlv: Tlv,
}

impl<BODY> HeapJob<BODY>
where
BODY: FnOnce() + Send,
BODY: FnOnce(JobRefId) + Send,
{
pub(super) fn new(tlv: Tlv, job: BODY) -> Box<Self> {
Box::new(HeapJob { job, tlv })
Expand All @@ -168,27 +178,28 @@ where

impl<BODY> Job for HeapJob<BODY>
where
BODY: FnOnce() + Send,
BODY: FnOnce(JobRefId) + Send,
{
unsafe fn execute(this: *const ()) {
let pointer = this as usize;
let this = Box::from_raw(this as *mut Self);
tlv::set(this.tlv);
(this.job)();
(this.job)(JobRefId { pointer });
}
}

/// Represents a job stored in an `Arc` -- like `HeapJob`, but may
/// be turned into multiple `JobRef`s and called multiple times.
pub(super) struct ArcJob<BODY>
where
BODY: Fn() + Send + Sync,
BODY: Fn(JobRefId) + Send + Sync,
{
job: BODY,
}

impl<BODY> ArcJob<BODY>
where
BODY: Fn() + Send + Sync,
BODY: Fn(JobRefId) + Send + Sync,
{
pub(super) fn new(job: BODY) -> Arc<Self> {
Arc::new(ArcJob { job })
Expand All @@ -212,11 +223,12 @@ where

impl<BODY> Job for ArcJob<BODY>
where
BODY: Fn() + Send + Sync,
BODY: Fn(JobRefId) + Send + Sync,
{
unsafe fn execute(this: *const ()) {
let pointer = this as usize;
let this = Arc::from_raw(this as *mut Self);
(this.job)();
(this.job)(JobRefId { pointer });
}
}

Expand Down
82 changes: 28 additions & 54 deletions rayon-core/src/join/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::job::JobRef;
use crate::job::StackJob;
use crate::latch::SpinLatch;
use crate::registry::{self, WorkerThread};
use crate::tlv::{self, Tlv};
use crate::registry;
use crate::tlv;
use crate::unwind;
use std::any::Any;
use std::sync::atomic::{AtomicBool, Ordering};

use crate::FnContext;

Expand Down Expand Up @@ -135,68 +136,41 @@ where
// Create virtual wrapper for task b; this all has to be
// done here so that the stack frame can keep it all live
// long enough.
let job_b = StackJob::new(tlv, call_b(oper_b), SpinLatch::new(worker_thread));
let job_b_started = AtomicBool::new(false);
let job_b = StackJob::new(
tlv,
|migrated| {
job_b_started.store(true, Ordering::Relaxed);
call_b(oper_b)(migrated)
},
SpinLatch::new(worker_thread),
);
let job_b_ref = job_b.as_job_ref();
let job_b_id = job_b_ref.id();
worker_thread.push(job_b_ref);

// Execute task a; hopefully b gets stolen in the meantime.
let status_a = unwind::halt_unwinding(call_a(oper_a, injected));
let result_a = match status_a {
Ok(v) => v,
Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err, tlv),
};

// Now that task A has finished, try to pop job B from the
// local stack. It may already have been popped by job A; it
// may also have been stolen. There may also be some tasks
// pushed on top of it in the stack, and we will have to pop
// those off to get to it.
while !job_b.latch.probe() {
if let Some(job) = worker_thread.take_local_job() {
if job_b_id == job.id() {
// Found it! Let's run it.
//
// Note that this could panic, but it's ok if we unwind here.

// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
tlv::set(tlv);

let result_b = job_b.run_inline(injected);
return (result_a, result_b);
} else {
worker_thread.execute(job);
}
} else {
// Local deque is empty. Time to steal from other
// threads.
worker_thread.wait_until(&job_b.latch);
debug_assert!(job_b.latch.probe());
break;
}
}
// Wait for job B or execute it if it's in the local queue.
worker_thread.wait_for_jobs::<_, false>(
&job_b.latch,
|| job_b_started.load(Ordering::Relaxed),
|job| job.id() == job_b_id,
|job: JobRef| {
debug_assert_eq!(job.id(), job_b_id);
job_b.run_inline(injected);
},
);

// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
tlv::set(tlv);

let result_a = match status_a {
Ok(v) => v,
Err(err) => unwind::resume_unwinding(err),
};

(result_a, job_b.into_result())
})
}

/// If job A panics, we still cannot return until we are sure that job
/// B is complete. This is because it may contain references into the
/// enclosing stack frame(s).
#[cold] // cold path
unsafe fn join_recover_from_panic(
worker_thread: &WorkerThread,
job_b_latch: &SpinLatch<'_>,
err: Box<dyn Any + Send>,
tlv: Tlv,
) -> ! {
worker_thread.wait_until(job_b_latch);

// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
tlv::set(tlv);

unwind::resume_unwinding(err)
}
1 change: 1 addition & 0 deletions rayon-core/src/join/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ fn join_context_both() {
}

#[test]
#[ignore]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn join_context_neither() {
// If we're already in a 1-thread pool, neither job should be stolen.
Expand Down
5 changes: 0 additions & 5 deletions rayon-core/src/latch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,6 @@ impl<'r> SpinLatch<'r> {
..SpinLatch::new(thread)
}
}

#[inline]
pub(super) fn probe(&self) -> bool {
self.core_latch.probe()
}
}

impl<'r> AsCoreLatch for SpinLatch<'r> {
Expand Down
Loading