diff --git a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs index 72a3882a32..8995e69c51 100644 --- a/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs +++ b/uniffi_bindgen/src/bindings/kotlin/gen_kotlin/mod.rs @@ -196,6 +196,7 @@ pub struct KotlinWrapper<'a> { ci: &'a ComponentInterface, type_helper_code: String, type_imports: BTreeSet, + has_async_fns: bool, } impl<'a> KotlinWrapper<'a> { @@ -208,6 +209,7 @@ impl<'a> KotlinWrapper<'a> { ci, type_helper_code, type_imports, + has_async_fns: ci.has_async_fns(), } } @@ -216,6 +218,10 @@ impl<'a> KotlinWrapper<'a> { .iter_types() .map(|t| KotlinCodeOracle.find(t)) .filter_map(|ct| ct.initialization_fn()) + .chain( + self.has_async_fns + .then(|| "uniffiRustFutureContinuationCallback.register".into()), + ) .collect() } @@ -301,7 +307,9 @@ impl KotlinCodeOracle { FfiType::ForeignExecutorHandle => "USize".to_string(), FfiType::ForeignExecutorCallback => "UniFfiForeignExecutorCallback".to_string(), FfiType::RustFutureHandle => "Pointer".to_string(), - FfiType::RustFutureContinuation => "UniFffiRustFutureContinutationType".to_string(), + FfiType::RustFutureContinuationCallback => { + "UniFffiRustFutureContinuationCallbackType".to_string() + } FfiType::RustFutureContinuationData => "USize".to_string(), } } diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt index 42126a43e4..2a4405f377 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/Async.kt @@ -6,10 +6,14 @@ internal const val UNIFFI_RUST_FUTURE_POLL_MAYBE_READY = 1.toShort() internal val uniffiContinuationHandleMap = UniFfiHandleMap>() // FFI type for Rust future continuations -internal object uniffiRustFutureContinuation: UniFffiRustFutureContinutationType { +internal object uniffiRustFutureContinuationCallback: UniFffiRustFutureContinuationCallbackType { override fun callback(continuationHandle: USize, pollResult: Short) { uniffiContinuationHandleMap.remove(continuationHandle)?.resume(pollResult) } + + internal fun register(lib: _UniFFILib) { + lib.{{ ci.ffi_rust_future_continuation_callback_set().name() }}(this) + } } internal suspend fun uniffiRustCallAsync( @@ -23,7 +27,6 @@ internal suspend fun uniffiRustCallAsync( val pollResult = suspendCancellableCoroutine { continuation -> _UniFFILib.INSTANCE.{{ ci.ffi_rust_future_poll().name() }}( rustFuture, - uniffiRustFutureContinuation, uniffiContinuationHandleMap.insert(continuation) ) } @@ -36,3 +39,4 @@ internal suspend fun uniffiRustCallAsync( _UniFFILib.INSTANCE.{{ ci.ffi_rust_future_free().name() }}(rustFuture) } } + diff --git a/uniffi_bindgen/src/bindings/kotlin/templates/Helpers.kt b/uniffi_bindgen/src/bindings/kotlin/templates/Helpers.kt index 4a9605ef73..382a5f7413 100644 --- a/uniffi_bindgen/src/bindings/kotlin/templates/Helpers.kt +++ b/uniffi_bindgen/src/bindings/kotlin/templates/Helpers.kt @@ -156,6 +156,6 @@ internal class UniFfiHandleMap { } // FFI type for Rust future continuations -internal interface UniFffiRustFutureContinutationType : com.sun.jna.Callback { +internal interface UniFffiRustFutureContinuationCallbackType : com.sun.jna.Callback { fun callback(continuationHandle: USize, pollResult: Short); } diff --git a/uniffi_bindgen/src/bindings/python/gen_python/mod.rs b/uniffi_bindgen/src/bindings/python/gen_python/mod.rs index be4eefc6d9..0fd99cab02 100644 --- a/uniffi_bindgen/src/bindings/python/gen_python/mod.rs +++ b/uniffi_bindgen/src/bindings/python/gen_python/mod.rs @@ -312,7 +312,7 @@ impl PythonCodeOracle { FfiType::ForeignExecutorHandle => "ctypes.c_size_t".to_string(), FfiType::ForeignExecutorCallback => "_UNIFFI_FOREIGN_EXECUTOR_CALLBACK_T".to_string(), FfiType::RustFutureHandle => "ctypes.c_void_p".to_string(), - FfiType::RustFutureContinuation => "_UNIFFI_FUTURE_CONTINUATION_T".to_string(), + FfiType::RustFutureContinuationCallback => "_UNIFFI_FUTURE_CONTINUATION_T".to_string(), FfiType::RustFutureContinuationData => "ctypes.c_size_t".to_string(), } } diff --git a/uniffi_bindgen/src/bindings/python/templates/Async.py b/uniffi_bindgen/src/bindings/python/templates/Async.py index 5518e9b340..06e1a54605 100644 --- a/uniffi_bindgen/src/bindings/python/templates/Async.py +++ b/uniffi_bindgen/src/bindings/python/templates/Async.py @@ -2,13 +2,13 @@ _UNIFFI_RUST_FUTURE_POLL_READY = 0 _UNIFFI_RUST_FUTURE_POLL_MAYBE_READY = 1 -# Stores futures for _uniffi_continuation_func +# Stores futures for _uniffi_continuation_callback _UniffiContinuationPointerManager = _UniffiPointerManager() # Continuation callback for async functions # lift the return value or error and resolve the future, causing the async function to resume. @_UNIFFI_FUTURE_CONTINUATION_T -def _uniffi_continuation_func(future_ptr, poll_code): +def _uniffi_continuation_callback(future_ptr, poll_code): (eventloop, future) = _UniffiContinuationPointerManager.release_pointer(future_ptr) eventloop.call_soon_threadsafe(_uniffi_set_future_result, future, poll_code) @@ -25,7 +25,6 @@ async def _uniffi_rust_call_async(rust_future, ffi_complete, lift_func, error_ff future = eventloop.create_future() _UniffiLib.{{ ci.ffi_rust_future_poll().name() }}( rust_future, - _uniffi_continuation_func, _UniffiContinuationPointerManager.new_pointer((eventloop, future)), ) poll_code = await future @@ -37,3 +36,5 @@ async def _uniffi_rust_call_async(rust_future, ffi_complete, lift_func, error_ff ) finally: _UniffiLib.{{ ci.ffi_rust_future_free().name() }}(rust_future) + +_UniffiLib.{{ ci.ffi_rust_future_continuation_callback_set().name() }}(_uniffi_continuation_callback) diff --git a/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs b/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs index 22055fc6fe..1f7260d00b 100644 --- a/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs +++ b/uniffi_bindgen/src/bindings/ruby/gen_ruby/mod.rs @@ -163,7 +163,7 @@ mod filters { unimplemented!("Foreign executors are not implemented") } FfiType::RustFutureHandle - | FfiType::RustFutureContinuation + | FfiType::RustFutureContinuationCallback | FfiType::RustFutureContinuationData => { unimplemented!("Async functions are not implemented") } diff --git a/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs b/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs index 2e2d938c9d..0855bcb282 100644 --- a/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs +++ b/uniffi_bindgen/src/bindings/swift/gen_swift/mod.rs @@ -342,6 +342,7 @@ pub struct SwiftWrapper<'a> { ci: &'a ComponentInterface, type_helper_code: String, type_imports: BTreeSet, + has_async_fns: bool, } impl<'a> SwiftWrapper<'a> { pub fn new(config: Config, ci: &'a ComponentInterface) -> Self { @@ -353,6 +354,7 @@ impl<'a> SwiftWrapper<'a> { ci, type_helper_code, type_imports, + has_async_fns: ci.has_async_fns(), } } @@ -365,6 +367,10 @@ impl<'a> SwiftWrapper<'a> { .iter_types() .map(|t| SwiftCodeOracle.find(t)) .filter_map(|ct| ct.initialization_fn()) + .chain( + self.has_async_fns + .then(|| "uniffiInitContinuationCallback".into()), + ) .collect() } } @@ -463,7 +469,7 @@ impl SwiftCodeOracle { FfiType::ForeignCallback => "ForeignCallback".into(), FfiType::ForeignExecutorHandle => "Int".into(), FfiType::ForeignExecutorCallback => "ForeignExecutorCallback".into(), - FfiType::RustFutureContinuation => "UniFfiRustFutureContinuation".into(), + FfiType::RustFutureContinuationCallback => "UniFfiRustFutureContinuation".into(), FfiType::RustFutureHandle | FfiType::RustFutureContinuationData => { "UnsafeMutableRawPointer".into() } @@ -475,7 +481,7 @@ impl SwiftCodeOracle { FfiType::ForeignCallback | FfiType::ForeignExecutorCallback | FfiType::RustFutureHandle - | FfiType::RustFutureContinuation + | FfiType::RustFutureContinuationCallback | FfiType::RustFutureContinuationData => { format!("{} _Nonnull", self.ffi_type_label_raw(ffi_type)) } @@ -560,7 +566,9 @@ pub mod filters { FfiType::ForeignCallback => "ForeignCallback _Nonnull".into(), FfiType::ForeignExecutorCallback => "UniFfiForeignExecutorCallback _Nonnull".into(), FfiType::ForeignExecutorHandle => "size_t".into(), - FfiType::RustFutureContinuation => "UniFfiRustFutureContinuation _Nonnull".into(), + FfiType::RustFutureContinuationCallback => { + "UniFfiRustFutureContinuation _Nonnull".into() + } FfiType::RustFutureHandle | FfiType::RustFutureContinuationData => { "void* _Nonnull".into() } diff --git a/uniffi_bindgen/src/bindings/swift/templates/Async.swift b/uniffi_bindgen/src/bindings/swift/templates/Async.swift index 5409218a7f..0a2ed6dee9 100644 --- a/uniffi_bindgen/src/bindings/swift/templates/Async.swift +++ b/uniffi_bindgen/src/bindings/swift/templates/Async.swift @@ -20,7 +20,6 @@ internal func uniffiRustCallAsync( pollResult = await withUnsafeContinuation { {{ ci.ffi_rust_future_poll().name() }}( rustFuture, - uniffiFutureContinuation, ContinuationHolder($0).toOpaque() ) } @@ -37,7 +36,7 @@ internal func uniffiRustCallAsync( // Callback handlers for an async calls. These are invoked by Rust when the future is ready. They // lift the return value or error and resume the suspended function. -fileprivate func uniffiFutureContinuation(ptr: UnsafeMutableRawPointer, pollResult: Int8) { +fileprivate func uniffiFutureContinuationCallback(ptr: UnsafeMutableRawPointer, pollResult: Int8) { ContinuationHolder.fromOpaque(ptr).resume(pollResult) } @@ -62,3 +61,7 @@ class ContinuationHolder { return Unmanaged.fromOpaque(ptr).takeRetainedValue() } } + +fileprivate func uniffiInitContinuationCallback() { + {{ ci.ffi_rust_future_continuation_callback_set().name() }}(uniffiFutureContinuationCallback) +} diff --git a/uniffi_bindgen/src/interface/ffi.rs b/uniffi_bindgen/src/interface/ffi.rs index dea7a6995b..d18aaf8262 100644 --- a/uniffi_bindgen/src/interface/ffi.rs +++ b/uniffi_bindgen/src/interface/ffi.rs @@ -57,7 +57,7 @@ pub enum FfiType { /// Pointer to a Rust future RustFutureHandle, /// Continuation function for a Rust future - RustFutureContinuation, + RustFutureContinuationCallback, RustFutureContinuationData, // TODO: you can imagine a richer structural typesystem here, e.g. `Ref` or something. // We don't need that yet and it's possible we never will, so it isn't here for now. diff --git a/uniffi_bindgen/src/interface/mod.rs b/uniffi_bindgen/src/interface/mod.rs index ee80f3ead8..ee0de67172 100644 --- a/uniffi_bindgen/src/interface/mod.rs +++ b/uniffi_bindgen/src/interface/mod.rs @@ -428,6 +428,24 @@ impl ComponentInterface { } } + /// Builtin FFI function to set the Rust Future continuation callback + pub fn ffi_rust_future_continuation_callback_set(&self) -> FfiFunction { + FfiFunction { + name: format!( + "ffi_{}_rust_future_continuation_callback_set", + self.ffi_namespace() + ), + arguments: vec![FfiArgument { + name: "callback".to_owned(), + type_: FfiType::RustFutureContinuationCallback, + }], + return_type: None, + is_async: false, + has_rust_call_status_arg: false, + is_object_free_function: false, + } + } + /// Builtin FFI function to poll a Rust future. pub fn ffi_rust_future_poll(&self) -> FfiFunction { FfiFunction { @@ -438,11 +456,6 @@ impl ComponentInterface { name: "handle".to_owned(), type_: FfiType::RustFutureHandle, }, - // Continuation to call when the future can make progress - FfiArgument { - name: "continuation".into(), - type_: FfiType::RustFutureContinuation, - }, // Data to pass to the continuation FfiArgument { name: "uniffi_callback".to_owned(), @@ -612,6 +625,7 @@ impl ComponentInterface { /// List all FFI functions definitions for async functionality. pub fn iter_futures_ffi_function_definitons(&self) -> impl Iterator { [ + self.ffi_rust_future_continuation_callback_set(), self.ffi_rust_future_poll(), self.ffi_rust_future_cancel(), self.ffi_rust_future_free(), diff --git a/uniffi_bindgen/src/scaffolding/mod.rs b/uniffi_bindgen/src/scaffolding/mod.rs index 5264023925..8271d70efe 100644 --- a/uniffi_bindgen/src/scaffolding/mod.rs +++ b/uniffi_bindgen/src/scaffolding/mod.rs @@ -85,7 +85,7 @@ mod filters { FfiType::ForeignBytes => "::uniffi::ForeignBytes".into(), FfiType::ForeignCallback => "::uniffi::ForeignCallback".into(), FfiType::RustFutureHandle => "::uniffi::RustFutureHandle".into(), - FfiType::RustFutureContinuation => "::uniffi::RustFutureContinuation".into(), + FfiType::RustFutureContinuationCallback => "::uniffi::RustFutureContinuation".into(), FfiType::RustFutureContinuationData => "*const ()".into(), FfiType::ForeignExecutorHandle => "::uniffi::ForeignExecutorHandle".into(), FfiType::ForeignExecutorCallback => "::uniffi::ForeignExecutorCallback".into(), diff --git a/uniffi_core/src/ffi/foreigncallbacks.rs b/uniffi_core/src/ffi/foreigncallbacks.rs index ac2463cd8e..ffdf3aa958 100644 --- a/uniffi_core/src/ffi/foreigncallbacks.rs +++ b/uniffi_core/src/ffi/foreigncallbacks.rs @@ -10,7 +10,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; -use crate::{ForeignExecutorHandle, RustBuffer, RustTaskCallback}; +use crate::{ForeignExecutorHandle, RustBuffer, RustFuturePoll, RustTaskCallback}; /// ForeignCallback is the Rust representation of a foreign language function. /// It is the basis for all callbacks interfaces. It is registered exactly once per callback interface, @@ -56,12 +56,21 @@ pub type ForeignExecutorCallback = extern "C" fn( task_data: *const (), ) -> i8; +/// Foreign callback that's passed to [rust_future_poll] +/// +/// The Rust side of things calls this when the foreign side should call [rust_future_poll] again +/// to continue progress on the future. +pub type RustFutureContinuationCallback = extern "C" fn(callback_data: *const (), RustFuturePoll); + /// Store a [ForeignCallback] pointer pub(crate) struct ForeignCallbackCell(AtomicUsize); /// Store a [ForeignExecutorCallback] pointer pub(crate) struct ForeignExecutorCallbackCell(AtomicUsize); +/// Store a [RustFutureContinuationCallback] pointer +pub(crate) struct RustFutureContinuationCallbackCell(AtomicUsize); + /// Macro to define foreign callback types as well as the callback cell. macro_rules! impl_foreign_callback_cell { ($callback_type:ident, $cell_type:ident) => { @@ -101,3 +110,7 @@ macro_rules! impl_foreign_callback_cell { impl_foreign_callback_cell!(ForeignCallback, ForeignCallbackCell); impl_foreign_callback_cell!(ForeignExecutorCallback, ForeignExecutorCallbackCell); +impl_foreign_callback_cell!( + RustFutureContinuationCallback, + RustFutureContinuationCallbackCell +); diff --git a/uniffi_core/src/ffi/rustfuture.rs b/uniffi_core/src/ffi/rustfuture.rs index 336c985534..c7cfd90485 100644 --- a/uniffi_core/src/ffi/rustfuture.rs +++ b/uniffi_core/src/ffi/rustfuture.rs @@ -10,6 +10,8 @@ //! //! We implement async foreign functions using a simplified version of the Future API: //! +//! 0. At startup, register a [RustFutureContinuationCallback] by calling +//! rust_future_continuation_callback_set. //! 1. Call the scaffolding function to get a [RustFutureHandle] //! 2a. In a loop: //! - Call [rust_future_poll] @@ -73,18 +75,18 @@ //! [`Waker`]: https://doc.rust-lang.org/std/task/struct.Waker.html //! [`RawWaker`]: https://doc.rust-lang.org/std/task/struct.RawWaker.html -use crate::{rust_call_with_out_status, FfiConverter, FfiDefault, RustCallStatus}; +use crate::{ + rust_call_with_out_status, FfiConverter, FfiDefault, RustCallStatus, + RustFutureContinuationCallback, RustFutureContinuationCallbackCell, +}; use std::{ - cell::UnsafeCell, future::Future, marker::PhantomData, + mem, ops::Deref, panic, pin::Pin, - sync::{ - atomic::{AtomicU8, Ordering}, - Arc, Mutex, - }, + sync::{Arc, Mutex}, task::{Context, Poll, Wake}, }; @@ -102,11 +104,18 @@ pub enum RustFuturePoll { #[repr(transparent)] pub struct RustFutureHandle(*const ()); -/// Foreign callback that's passed to [rust_future_poll] -/// -/// The Rust side of things calls this when the foreign side should call [rust_future_poll] and -/// continue progress on the future. -pub type RustFutureContinuation = extern "C" fn(callback_data: *const (), status: RustFuturePoll); +/// Stores the global continuation callback +static RUST_FUTURE_CONTINUATION_CALLBACK_CELL: RustFutureContinuationCallbackCell = + RustFutureContinuationCallbackCell::new(); + +/// Set the global RustFutureContinuationCallback. +pub fn rust_future_continuation_callback_set(callback: RustFutureContinuationCallback) { + RUST_FUTURE_CONTINUATION_CALLBACK_CELL.set(callback); +} + +fn call_continuation(data: *const (), poll_code: RustFuturePoll) { + RUST_FUTURE_CONTINUATION_CALLBACK_CELL.get()(data, poll_code) +} // === Public FFI API === @@ -146,13 +155,9 @@ where /// # Safety /// /// The [RustFutureHandle] must not previously have been passed to [rust_future_free] -pub unsafe fn rust_future_poll( - handle: RustFutureHandle, - continuation: RustFutureContinuation, - data: *const (), -) { +pub unsafe fn rust_future_poll(handle: RustFutureHandle, data: *const ()) { let future = &*(handle.0 as *mut Arc); - future.clone().ffi_poll(continuation, data) + future.clone().ffi_poll(data) } /// Cancel a Rust future @@ -186,7 +191,7 @@ pub unsafe fn rust_future_complete( ) -> T { let future = &*(handle.0 as *mut Arc); let mut return_value = T::ffi_default(); - let out_return = std::mem::transmute::<&mut T, &mut ()>(&mut return_value); + let out_return = mem::transmute::<&mut T, &mut ()>(&mut return_value); future.ffi_complete(out_return, out_status); return_value } @@ -202,135 +207,68 @@ pub unsafe fn rust_future_free(handle: RustFutureHandle) { future.ffi_free() } -/// Thread-safe storage for a RustFutureContinuation +/// Thread-safe storage for [RustFutureContinuationCallback] data /// -/// The basic guarantee is that all continuations passed to [Self::store] are called exactly once -/// (assuming that [Self::try_call_continuation] is called after the last store). This enables us to -/// uphold the [rust_future_poll] guarantee. +/// The basic guarantee is that all data pointers passed in are passed out exactly once to the +/// foreign continuation callback. This enables us to uphold the [rust_future_poll] guarantee. /// -/// AtomicContinuationCell uses atomic trickery to make all operations thread-safe but non-blocking. -struct AtomicContinuationCell { - state: AtomicU8, - stored: UnsafeCell>, +/// [ContinuationDataCell] also tracks cancellation, which is closely tied to continuation data. +enum ContinuationDataCell { + Empty, + Cancelled, + Set(*const ()), } -impl AtomicContinuationCell { - /// Lock bit - const STATE_LOCK: u8 = 1 << 0; - /// Bit signalling that we should call the continuation - const STATE_NEEDS_CALL: u8 = 1 << 1; - /// Bit signalling that the RustFuture has been cancelled - const STATE_CANCELLED: u8 = 1 << 2; - +impl ContinuationDataCell { fn new() -> Self { - Self { - state: AtomicU8::new(0), - stored: UnsafeCell::new(None), - } - } - - /// Try to take a lock, optionally setting the other bits - fn try_lock(&self, extra_bits: u8) -> bool { - let prev_state = self - .state - .fetch_or(Self::STATE_LOCK | extra_bits, Ordering::Acquire); - (prev_state & Self::STATE_LOCK) == 0 - } - - /// Release a lock, calling any stored continuation - fn unlock_and_call(&self) { - self.call_continuation_unchecked(); - self.state.fetch_and( - !(Self::STATE_LOCK | Self::STATE_NEEDS_CALL), - Ordering::Release, - ); + Self::Empty } - /// Release a lock with the intention of keeping a stored continuation - /// - /// However, if another thread set the STATE_NEEDS_CALL or STATE_READY bit, then instead call - /// the stored continuation for them. - fn unlock_and_store(&self, new_continuation: RustFutureContinuation, data: *const ()) { - // Set the continuation - let stored = unsafe { &mut *self.stored.get() }; - if stored.is_some() { - log::error!("AtomicContinuationCell::unlock_and_store: continuation already set"); - self.call_continuation_unchecked(); + /// Store new continuation data + fn store(&mut self, data: *const ()) { + // If we're cancelled, then call the continuation immediately rather than storing it + if matches!(self, Self::Cancelled) { + call_continuation(data, RustFuturePoll::Ready); + return; } - *stored = Some((new_continuation, data)); - - match self - .state - .compare_exchange(Self::STATE_LOCK, 0, Ordering::Release, Ordering::Relaxed) - { - // Success! - Ok(_) => (), - Err(_) => { - // Another thread set the STATE_NEEDS_CALL or STATE_READY bit, so we should call the - // continuation for them. - self.call_continuation_unchecked(); - // We can now unlock unconditionally - self.state.fetch_and( - !(Self::STATE_LOCK | Self::STATE_NEEDS_CALL), - Ordering::Release, + + match mem::replace(self, Self::Set(data)) { + Self::Empty => (), + Self::Cancelled => unreachable!(), + Self::Set(old_data) => { + log::error!( + "store: observed Self::Set state, is poll() being called from multiple threads at once?" ); + call_continuation(old_data, RustFuturePoll::Ready); } } } - // Take the data out of self.continuation. If it was set, then call the continuation. - // - // Only call this if you have the lock - fn call_continuation_unchecked(&self) { - let stored = unsafe { &mut *self.stored.get() }; - if let Some((continuation, data)) = stored.take() { - continuation(data, self.poll_code()); + fn send(&mut self) { + if matches!(self, Self::Cancelled) { + return; } - } - fn try_call_continuation(&self, cancelled: bool) { - let extra_bits = if cancelled { - Self::STATE_NEEDS_CALL | Self::STATE_CANCELLED - } else { - Self::STATE_NEEDS_CALL - }; - if self.try_lock(extra_bits) { - self.unlock_and_call(); + if let Self::Set(old_data) = mem::replace(self, Self::Empty) { + call_continuation(old_data, RustFuturePoll::MaybeReady); } } - fn store(&self, continuation: RustFutureContinuation, data: *const ()) { - if self.try_lock(0) { - self.unlock_and_store(continuation, data); - } else { - // Failed to acquire the lock - // - If the other thread was calling `try_call_continuation`, that means they locked us out - // just before we could store the continuation. - // - If the other thread was calling `store`, then something weird happened and - // there's already a stored continuation. - // - // In either case, call the continuation now. - continuation(data, self.poll_code()); - } - } - - fn poll_code(&self) -> RustFuturePoll { - if self.state.load(Ordering::Relaxed) & Self::STATE_CANCELLED == 0 { - RustFuturePoll::MaybeReady - } else { - RustFuturePoll::Ready + fn cancel(&mut self) { + if let Self::Set(old_data) = mem::replace(self, Self::Cancelled) { + call_continuation(old_data, RustFuturePoll::Ready); } } fn is_cancelled(&self) -> bool { - self.state.load(Ordering::Relaxed) & Self::STATE_CANCELLED != 0 + matches!(self, Self::Cancelled) } } -// AtomicContinuationCell is Send + Sync as long the previous code is working correctly. +// ContinuationDataCell is Send + Sync as long we handle the *const () pointer correctly -unsafe impl Send for AtomicContinuationCell {} -unsafe impl Sync for AtomicContinuationCell {} +unsafe impl Send for ContinuationDataCell {} +unsafe impl Sync for ContinuationDataCell {} /// Wraps the actual future we're polling struct WrappedFuture @@ -444,7 +382,7 @@ where // This Mutex should never block if our code is working correctly, since there should not be // multiple threads calling [Self::poll] and/or [Self::complete] at the same time. future: Mutex>, - continuation: AtomicContinuationCell, + continuation_data: Mutex, // UT is used as the generic parameter for FfiConverter. // Let's model this with PhantomData as a function that inputs a UT value. _phantom: PhantomData ()>, @@ -460,30 +398,34 @@ where fn new(future: F, _tag: UT) -> Arc { Arc::new(Self { future: Mutex::new(WrappedFuture::new(future)), - continuation: AtomicContinuationCell::new(), + continuation_data: Mutex::new(ContinuationDataCell::new()), _phantom: PhantomData, }) } - fn poll(self: Arc, new_continuation: RustFutureContinuation, data: *const ()) { - let ready = self.continuation.is_cancelled() || { + fn poll(self: Arc, data: *const ()) { + let ready = self.is_cancelled() || { let mut locked = self.future.lock().unwrap(); let waker: std::task::Waker = Arc::clone(&self).into(); locked.poll(&mut Context::from_waker(&waker)) }; if ready { - new_continuation(data, RustFuturePoll::Ready); + call_continuation(data, RustFuturePoll::Ready) } else { - self.continuation.store(new_continuation, data); + self.continuation_data.lock().unwrap().store(data); } } + fn is_cancelled(&self) -> bool { + self.continuation_data.lock().unwrap().is_cancelled() + } + fn wake(&self) { - self.continuation.try_call_continuation(false) + self.continuation_data.lock().unwrap().send(); } fn cancel(&self) { - self.continuation.try_call_continuation(true); + self.continuation_data.lock().unwrap().cancel(); } fn complete(&self, return_value: &mut T::ReturnType, call_status: &mut RustCallStatus) { @@ -494,8 +436,8 @@ where } fn free(self: Arc) { - // Call any leftover continuation callbacks now - self.continuation.try_call_continuation(true); + // Call cancel() to send any leftover data to the continuation callback + self.continuation_data.lock().unwrap().cancel(); // Ensure we drop our inner future, releasing all held references self.future.lock().unwrap().free(); } @@ -523,7 +465,7 @@ where /// unnamable. #[doc(hidden)] trait RustFutureFfi { - fn ffi_poll(self: Arc, continuation: RustFutureContinuation, data: *const ()); + fn ffi_poll(self: Arc, data: *const ()); fn ffi_cancel(&self); unsafe fn ffi_complete(&self, out_return: &mut (), call_status: &mut RustCallStatus); fn ffi_free(self: Arc); @@ -536,8 +478,8 @@ where T: FfiConverter + Send + 'static, UT: Send + 'static, { - fn ffi_poll(self: Arc, continuation: RustFutureContinuation, data: *const ()) { - self.poll(continuation, data) + fn ffi_poll(self: Arc, data: *const ()) { + self.poll(data) } fn ffi_cancel(&self) { @@ -624,10 +566,14 @@ mod tests { fn poll(rust_future: &Arc) -> OnceCell { let cell = OnceCell::new(); let cell_ptr = &cell as *const OnceCell as *const (); - rust_future.clone().ffi_poll(poll_continuation, cell_ptr); + rust_future.clone().ffi_poll(cell_ptr); cell } + fn setup_continuation_callback() { + RUST_FUTURE_CONTINUATION_CALLBACK_CELL.set(poll_continuation); + } + extern "C" fn poll_continuation(data: *const (), code: RustFuturePoll) { let cell = unsafe { &*(data as *const OnceCell) }; cell.set(code).expect("Error setting OnceCell"); @@ -647,6 +593,7 @@ mod tests { #[test] fn test_success() { + setup_continuation_callback(); let (sender, rust_future) = channel(); // Test polling the rust future before it's ready @@ -676,6 +623,7 @@ mod tests { #[test] fn test_error() { + setup_continuation_callback(); let (sender, rust_future) = channel(); let continuation_result = poll(&rust_future); @@ -703,6 +651,7 @@ mod tests { // reference to the RustFuture #[test] fn test_cancel() { + setup_continuation_callback(); let (_sender, rust_future) = channel(); let continuation_result = poll(&rust_future); @@ -723,6 +672,7 @@ mod tests { // reference to the RustFuture #[test] fn test_release_future() { + setup_continuation_callback(); let (sender, rust_future) = channel(); // Create a weak reference to the channel to use to check if rust_future has dropped its // future. @@ -745,6 +695,7 @@ mod tests { // This shouldn't happen in practice, but it seems like good defensive programming #[test] fn test_complete_with_stored_continuation() { + setup_continuation_callback(); let (_sender, rust_future) = channel(); let continuation_result = poll(&rust_future); diff --git a/uniffi_macros/src/setup_scaffolding.rs b/uniffi_macros/src/setup_scaffolding.rs index 24d972f712..313331cc84 100644 --- a/uniffi_macros/src/setup_scaffolding.rs +++ b/uniffi_macros/src/setup_scaffolding.rs @@ -22,6 +22,8 @@ pub fn setup_scaffolding(namespace: String) -> Result { let reexport_hack_ident = format_ident!("{module_path}_uniffi_reexport_hack"); let ffi_foreign_executor_callback_set_ident = format_ident!("ffi_{module_path}_foreign_executor_callback_set"); + let ffi_rust_future_continuation_callback_set = + format_ident!("ffi_{module_path}_rust_future_continuation_callback_set"); let ffi_rust_future_poll = format_ident!("ffi_{module_path}_rust_future_poll"); let ffi_rust_future_cancel = format_ident!("ffi_{module_path}_rust_future_cancel"); let ffi_rust_future_free = format_ident!("ffi_{module_path}_rust_future_free"); @@ -100,12 +102,15 @@ pub fn setup_scaffolding(namespace: String) -> Result { #[allow(clippy::missing_safety_doc, missing_docs)] #[doc(hidden)] #[no_mangle] - pub unsafe extern "C" fn #ffi_rust_future_poll( - handle: ::uniffi::RustFutureHandle, - continuation: ::uniffi::RustFutureContinuation, - data: *const () - ) { - ::uniffi::ffi::rust_future_poll(handle, continuation, data); + pub unsafe extern "C" fn #ffi_rust_future_continuation_callback_set(callback: ::uniffi::RustFutureContinuationCallback) { + ::uniffi::ffi::rust_future_continuation_callback_set(callback); + } + + #[allow(clippy::missing_safety_doc, missing_docs)] + #[doc(hidden)] + #[no_mangle] + pub unsafe extern "C" fn #ffi_rust_future_poll(handle: ::uniffi::RustFutureHandle, data: *const ()) { + ::uniffi::ffi::rust_future_poll(handle, data); } #[allow(clippy::missing_safety_doc, missing_docs)]