diff --git a/Cargo.toml b/Cargo.toml index 5d3f7bb..6920042 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ categories = ["asynchronous", "concurrency"] exclude = ["/.*"] [dependencies] -async-lock = "3.0.0" async-task = "4.4.0" concurrent-queue = "2.0.0" fastrand = "2.0.0" diff --git a/src/lib.rs b/src/lib.rs index 7d91a59..0c51fd5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,7 +25,12 @@ //! future::block_on(ex.run(task)); //! ``` -#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)] +#![warn( + missing_docs, + missing_debug_implementations, + rust_2018_idioms, + clippy::undocumented_unsafe_blocks +)] #![doc( html_favicon_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png" )] @@ -37,11 +42,10 @@ use std::fmt; use std::marker::PhantomData; use std::panic::{RefUnwindSafe, UnwindSafe}; use std::rc::Rc; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering}; use std::sync::{Arc, Mutex, RwLock, TryLockError}; use std::task::{Poll, Waker}; -use async_lock::OnceCell; use async_task::{Builder, Runnable}; use concurrent_queue::ConcurrentQueue; use futures_lite::{future, prelude::*}; @@ -76,13 +80,15 @@ pub use async_task::Task; /// ``` pub struct Executor<'a> { /// The executor state. - state: OnceCell>, + state: AtomicPtr, /// Makes the `'a` lifetime invariant. _marker: PhantomData>, } +// SAFETY: Executor stores no thread local state that can be accessed via other thread. unsafe impl Send for Executor<'_> {} +// SAFETY: Executor internally synchronizes all of it's operations internally. unsafe impl Sync for Executor<'_> {} impl UnwindSafe for Executor<'_> {} @@ -106,7 +112,7 @@ impl<'a> Executor<'a> { /// ``` pub const fn new() -> Executor<'a> { Executor { - state: OnceCell::new(), + state: AtomicPtr::new(std::ptr::null_mut()), _marker: PhantomData, } } @@ -231,7 +237,7 @@ impl<'a> Executor<'a> { // Remove the task from the set of active tasks when the future finishes. let entry = active.vacant_entry(); let index = entry.key(); - let state = self.state().clone(); + let state = self.state_as_arc(); let future = async move { let _guard = CallOnDrop(move || drop(state.active.lock().unwrap().try_remove(index))); future.await @@ -361,7 +367,7 @@ impl<'a> Executor<'a> { /// Returns a function that schedules a runnable task when it gets woken up. fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static { - let state = self.state().clone(); + let state = self.state_as_arc(); // TODO: If possible, push into the current local queue and notify the ticker. move |runnable| { @@ -370,34 +376,73 @@ impl<'a> Executor<'a> { } } - /// Returns a reference to the inner state. - fn state(&self) -> &Arc { - #[cfg(not(target_family = "wasm"))] - { - return self.state.get_or_init_blocking(|| Arc::new(State::new())); + /// Returns a pointer to the inner state. + #[inline] + fn state_ptr(&self) -> *const State { + #[cold] + fn alloc_state(atomic_ptr: &AtomicPtr) -> *mut State { + let state = Arc::new(State::new()); + // TODO: Switch this to use cast_mut once the MSRV can be bumped past 1.65 + let ptr = Arc::into_raw(state) as *mut State; + if let Err(actual) = atomic_ptr.compare_exchange( + std::ptr::null_mut(), + ptr, + Ordering::AcqRel, + Ordering::Acquire, + ) { + // SAFETY: This was just created from Arc::into_raw. + drop(unsafe { Arc::from_raw(ptr) }); + actual + } else { + ptr + } } - // Some projects use this on WASM for some reason. In this case get_or_init_blocking - // doesn't work. Just poll the future once and panic if there is contention. - #[cfg(target_family = "wasm")] - future::block_on(future::poll_once( - self.state.get_or_init(|| async { Arc::new(State::new()) }), - )) - .expect("encountered contention on WASM") + let mut ptr = self.state.load(Ordering::Acquire); + if ptr.is_null() { + ptr = alloc_state(&self.state); + } + ptr + } + + /// Returns a reference to the inner state. + #[inline] + fn state(&self) -> &State { + // SAFETY: So long as an Executor lives, it's state pointer will always be valid + // when accessed through state_ptr. + unsafe { &*self.state_ptr() } + } + + // Clones the inner state Arc + #[inline] + fn state_as_arc(&self) -> Arc { + // SAFETY: So long as an Executor lives, it's state pointer will always be a valid + // Arc when accessed through state_ptr. + let arc = unsafe { Arc::from_raw(self.state_ptr()) }; + let clone = arc.clone(); + std::mem::forget(arc); + clone } } impl Drop for Executor<'_> { fn drop(&mut self) { - if let Some(state) = self.state.get() { - let mut active = state.active.lock().unwrap_or_else(|e| e.into_inner()); - for w in active.drain() { - w.wake(); - } - drop(active); + let ptr = *self.state.get_mut(); + if ptr.is_null() { + return; + } + + // SAFETY: As ptr is not null, it was allocated via Arc::new and converted + // via Arc::into_raw in state_ptr. + let state = unsafe { Arc::from_raw(ptr) }; - while state.queue.pop().is_ok() {} + let mut active = state.active.lock().unwrap_or_else(|e| e.into_inner()); + for w in active.drain() { + w.wake(); } + drop(active); + + while state.queue.pop().is_ok() {} } } @@ -718,9 +763,7 @@ impl Sleepers { fn update(&mut self, id: usize, waker: &Waker) -> bool { for item in &mut self.wakers { if item.0 == id { - if !item.1.will_wake(waker) { - item.1.clone_from(waker); - } + item.1.clone_from(waker); return false; } } @@ -1006,21 +1049,24 @@ fn steal(src: &ConcurrentQueue, dest: &ConcurrentQueue) { /// Debug implementation for `Executor` and `LocalExecutor`. fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result { // Get a reference to the state. - let state = match executor.state.get() { - Some(state) => state, - None => { - // The executor has not been initialized. - struct Uninitialized; - - impl fmt::Debug for Uninitialized { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("") - } + let ptr = executor.state.load(Ordering::Acquire); + if ptr.is_null() { + // The executor has not been initialized. + struct Uninitialized; + + impl fmt::Debug for Uninitialized { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("") } - - return f.debug_tuple(name).field(&Uninitialized).finish(); } - }; + + return f.debug_tuple(name).field(&Uninitialized).finish(); + } + + // SAFETY: If the state pointer is not null, it must have been + // allocated properly by Arc::new and converted via Arc::into_raw + // in state_ptr. + let state = unsafe { &*ptr }; /// Debug wrapper for the number of active tasks. struct ActiveTasks<'a>(&'a Mutex>);