Skip to content

Commit

Permalink
Mutex: Only track a single locked flag
Browse files Browse the repository at this point in the history
  • Loading branch information
bugadani committed Dec 18, 2024
1 parent 2ba7faf commit 59bb8f2
Showing 1 changed file with 41 additions and 38 deletions.
79 changes: 41 additions & 38 deletions embassy-sync/src/mutex.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Async mutex.
//!
//! This module provides a mutex that can be used to synchronize data between asynchronous tasks.
use core::cell::{RefCell, UnsafeCell};
use core::cell::{Cell, UnsafeCell};
use core::future::poll_fn;
use core::ops::{Deref, DerefMut};
use core::task::Poll;
Expand All @@ -17,8 +17,39 @@ use crate::waitqueue::WakerRegistration;
pub struct TryLockError;

struct State {
locked: bool,
waker: WakerRegistration,
locked: Cell<bool>,
waker: UnsafeCell<WakerRegistration>,
}

impl State {
const fn new() -> Self {
Self {
locked: Cell::new(false),
waker: UnsafeCell::new(WakerRegistration::new()),
}
}

fn lock(&self, waker: &core::task::Waker) -> bool {
if self.locked.replace(true) {
unsafe { (&mut *self.waker.get()).register(waker) };
false
} else {
true
}
}

fn try_lock(&self) -> Result<(), TryLockError> {
if self.locked.replace(true) {
Err(TryLockError)
} else {
Ok(())
}
}

fn unlock(&self) {
unsafe { (&mut *self.waker.get()).wake() };
self.locked.set(false);
}
}

/// Async mutex.
Expand All @@ -41,7 +72,7 @@ where
M: RawMutex,
T: ?Sized,
{
state: BlockingMutex<M, RefCell<State>>,
state: BlockingMutex<M, State>,
inner: UnsafeCell<T>,
}

Expand All @@ -57,10 +88,7 @@ where
pub const fn new(value: T) -> Self {
Self {
inner: UnsafeCell::new(value),
state: BlockingMutex::new(RefCell::new(State {
locked: false,
waker: WakerRegistration::new(),
})),
state: BlockingMutex::new(State::new()),
}
}
}
Expand All @@ -75,16 +103,7 @@ where
/// This will wait for the mutex to be unlocked if it's already locked.
pub async fn lock(&self) -> MutexGuard<'_, M, T> {
poll_fn(|cx| {
let ready = self.state.lock(|s| {
let mut s = s.borrow_mut();
if s.locked {
s.waker.register(cx.waker());
false
} else {
s.locked = true;
true
}
});
let ready = self.state.lock(|s| s.lock(cx.waker()));

if ready {
Poll::Ready(MutexGuard { mutex: self })
Expand All @@ -99,15 +118,7 @@ where
///
/// If the mutex is already locked, this will return an error instead of waiting.
pub fn try_lock(&self) -> Result<MutexGuard<'_, M, T>, TryLockError> {
self.state.lock(|s| {
let mut s = s.borrow_mut();
if s.locked {
Err(TryLockError)
} else {
s.locked = true;
Ok(())
}
})?;
self.state.lock(|s| s.try_lock())?;

Ok(MutexGuard { mutex: self })
}
Expand Down Expand Up @@ -205,11 +216,7 @@ where
T: ?Sized,
{
fn drop(&mut self) {
self.mutex.state.lock(|s| {
let mut s = unwrap!(s.try_borrow_mut());
s.locked = false;
s.waker.wake();
})
self.mutex.state.lock(|s| s.unlock())
}
}

Expand Down Expand Up @@ -268,7 +275,7 @@ where
M: RawMutex,
T: ?Sized,
{
state: &'a BlockingMutex<M, RefCell<State>>,
state: &'a BlockingMutex<M, State>,
value: *mut T,
}

Expand Down Expand Up @@ -319,11 +326,7 @@ where
T: ?Sized,
{
fn drop(&mut self) {
self.state.lock(|s| {
let mut s = unwrap!(s.try_borrow_mut());
s.locked = false;
s.waker.wake();
})
self.state.lock(|s| s.unlock())
}
}

Expand Down

0 comments on commit 59bb8f2

Please sign in to comment.