diff --git a/Cargo.toml b/Cargo.toml index 4025c43..6bee628 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,8 @@ exclude = ["/.*"] [features] # Adds support for executors optimized for use in static variables. static = [] +# Adds support for executors optimized for use in main(). +main_executor = ["event-listener", "async-io"] [dependencies] async-task = "4.4.0" @@ -24,6 +26,8 @@ concurrent-queue = "2.5.0" fastrand = "2.0.0" futures-lite = { version = "2.0.0", default-features = false } slab = "0.4.4" +event-listener = {version = "5.1.0", optional = true } +async-io = {version = "2.1.0", optional = true } [target.'cfg(target_family = "wasm")'.dependencies] futures-lite = { version = "2.0.0", default-features = false, features = ["std"] } @@ -38,6 +42,10 @@ fastrand = "2.0.0" futures-lite = "2.0.0" once_cell = "1.16.0" +[[example]] +name = "thread_pool" +required-features = ["main_executor"] + [[bench]] name = "executor" harness = false diff --git a/examples/thread_pool.rs b/examples/thread_pool.rs new file mode 100644 index 0000000..1a70046 --- /dev/null +++ b/examples/thread_pool.rs @@ -0,0 +1,18 @@ +//! An example of using with_thread_pool. + +use std::sync::Arc; + +use async_executor::{with_thread_pool, Executor}; + +async fn async_main(_ex: &Executor<'_>) -> Result<(), Box> { + println!("Hello, world!"); + Ok(()) +} + +fn main() -> Result<(), Box> { + // create executor + let ex = Arc::new(Executor::new()); + + // run executor on thread pool + with_thread_pool(&ex, || async_io::block_on(async_main(&ex))) +} diff --git a/src/lib.rs b/src/lib.rs index 2ec014a..9b5e52e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,12 +55,19 @@ use slab::Slab; #[cfg(feature = "static")] mod static_executors; +#[cfg(feature = "main_executor")] +mod main_executor; + #[doc(no_inline)] pub use async_task::{FallibleTask, Task}; + #[cfg(feature = "static")] #[cfg_attr(docsrs, doc(cfg(any(feature = "static"))))] pub use static_executors::*; +#[cfg(feature = "main_executor")] +pub use main_executor::*; + /// An async executor. /// /// # Examples diff --git a/src/main_executor.rs b/src/main_executor.rs new file mode 100644 index 0000000..b1183dd --- /dev/null +++ b/src/main_executor.rs @@ -0,0 +1,119 @@ +use std::rc::Rc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::thread; + +use event_listener::Event; + +use crate::{Executor, LocalExecutor}; + +/// Wait for the executor to stop. +pub(crate) struct WaitForStop { + /// Whether or not we need to stop. + stopped: AtomicBool, + + /// Wait for the stop. + events: Event, +} + +impl WaitForStop { + /// Create a new wait for stop. + #[inline] + pub(crate) fn new() -> Self { + Self { + stopped: AtomicBool::new(false), + events: Event::new(), + } + } + + /// Wait for the event to stop. + #[inline] + pub(crate) async fn wait(&self) { + loop { + if self.stopped.load(Ordering::Relaxed) { + return; + } + + event_listener::listener!(&self.events => listener); + + if self.stopped.load(Ordering::Acquire) { + return; + } + + listener.await; + } + } + + /// Stop the waiter. + #[inline] + pub(crate) fn stop(&self) { + self.stopped.store(true, Ordering::SeqCst); + self.events.notify_additional(usize::MAX); + } +} + +/// Something that can be set up as an executor. +pub trait MainExecutor: Sized { + /// Create this type and pass it into `main`. + fn with_main T>(f: F) -> T; +} + +impl MainExecutor for Arc> { + #[inline] + fn with_main T>(f: F) -> T { + let ex = Arc::new(Executor::new()); + with_thread_pool(&ex, || f(&ex)) + } +} + +impl MainExecutor for Executor<'_> { + #[inline] + fn with_main T>(f: F) -> T { + let ex = Executor::new(); + with_thread_pool(&ex, || f(&ex)) + } +} + +impl MainExecutor for Rc> { + #[inline] + fn with_main T>(f: F) -> T { + f(&Rc::new(LocalExecutor::new())) + } +} + +impl MainExecutor for LocalExecutor<'_> { + fn with_main T>(f: F) -> T { + f(&LocalExecutor::new()) + } +} + +/// Run a function that takes an `Executor` inside of a thread pool. +#[inline] +fn with_thread_pool(ex: &Executor<'_>, f: impl FnOnce() -> T) -> T { + let stopper = WaitForStop::new(); + + // Create a thread for each CPU. + thread::scope(|scope| { + let num_threads = thread::available_parallelism().map_or(1, |num| num.get()); + for i in 0..num_threads { + let ex = &ex; + let stopper = &stopper; + + thread::Builder::new() + .name(format!("smol-macros-{i}")) + .spawn_scoped(scope, || { + async_io::block_on(ex.run(stopper.wait())); + }) + .expect("failed to spawn thread"); + } + + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)); + + stopper.stop(); + + match result { + Ok(value) => value, + Err(err) => std::panic::resume_unwind(err), + } + }) +}