diff --git a/embassy-sync/src/mutex.rs b/embassy-sync/src/mutex.rs index 08f66e3747..6c2358f216 100644 --- a/embassy-sync/src/mutex.rs +++ b/embassy-sync/src/mutex.rs @@ -1,7 +1,7 @@ //! Async mutex. //! //! This module provides a mutex that can be used to synchronize data between asynchronous tasks. -use core::cell::{RefCell, UnsafeCell}; +use core::cell::{Cell, UnsafeCell}; use core::future::poll_fn; use core::ops::{Deref, DerefMut}; use core::task::Poll; @@ -17,8 +17,39 @@ use crate::waitqueue::WakerRegistration; pub struct TryLockError; struct State { - locked: bool, - waker: WakerRegistration, + locked: Cell, + waker: UnsafeCell, +} + +impl State { + const fn new() -> Self { + Self { + locked: Cell::new(false), + waker: UnsafeCell::new(WakerRegistration::new()), + } + } + + fn lock(&self, waker: &core::task::Waker) -> bool { + if self.locked.replace(true) { + unsafe { (&mut *self.waker.get()).register(waker) }; + false + } else { + true + } + } + + fn try_lock(&self) -> Result<(), TryLockError> { + if self.locked.replace(true) { + Err(TryLockError) + } else { + Ok(()) + } + } + + fn unlock(&self) { + unsafe { (&mut *self.waker.get()).wake() }; + self.locked.set(false); + } } /// Async mutex. @@ -41,7 +72,7 @@ where M: RawMutex, T: ?Sized, { - state: BlockingMutex>, + state: BlockingMutex, inner: UnsafeCell, } @@ -57,10 +88,7 @@ where pub const fn new(value: T) -> Self { Self { inner: UnsafeCell::new(value), - state: BlockingMutex::new(RefCell::new(State { - locked: false, - waker: WakerRegistration::new(), - })), + state: BlockingMutex::new(State::new()), } } } @@ -75,16 +103,7 @@ where /// This will wait for the mutex to be unlocked if it's already locked. pub async fn lock(&self) -> MutexGuard<'_, M, T> { poll_fn(|cx| { - let ready = self.state.lock(|s| { - let mut s = s.borrow_mut(); - if s.locked { - s.waker.register(cx.waker()); - false - } else { - s.locked = true; - true - } - }); + let ready = self.state.lock(|s| s.lock(cx.waker())); if ready { Poll::Ready(MutexGuard { mutex: self }) @@ -99,15 +118,7 @@ where /// /// If the mutex is already locked, this will return an error instead of waiting. pub fn try_lock(&self) -> Result, TryLockError> { - self.state.lock(|s| { - let mut s = s.borrow_mut(); - if s.locked { - Err(TryLockError) - } else { - s.locked = true; - Ok(()) - } - })?; + self.state.lock(|s| s.try_lock())?; Ok(MutexGuard { mutex: self }) } @@ -205,11 +216,7 @@ where T: ?Sized, { fn drop(&mut self) { - self.mutex.state.lock(|s| { - let mut s = unwrap!(s.try_borrow_mut()); - s.locked = false; - s.waker.wake(); - }) + self.mutex.state.lock(|s| s.unlock()) } } @@ -268,7 +275,7 @@ where M: RawMutex, T: ?Sized, { - state: &'a BlockingMutex>, + state: &'a BlockingMutex, value: *mut T, } @@ -319,11 +326,7 @@ where T: ?Sized, { fn drop(&mut self) { - self.state.lock(|s| { - let mut s = unwrap!(s.try_borrow_mut()); - s.locked = false; - s.waker.wake(); - }) + self.state.lock(|s| s.unlock()) } }