Rwlock downgrade #128219

Merged: 10 commits, Nov 18, 2024
74 changes: 70 additions & 4 deletions library/std/src/sync/rwlock.rs
@@ -4,10 +4,10 @@ mod tests;
use crate::cell::UnsafeCell;
use crate::fmt;
use crate::marker::PhantomData;
use crate::mem::ManuallyDrop;
use crate::mem::{ManuallyDrop, forget};
use crate::ops::{Deref, DerefMut};
use crate::ptr::NonNull;
use crate::sync::{LockResult, TryLockError, TryLockResult, poison};
use crate::sync::{LockResult, PoisonError, TryLockError, TryLockResult, poison};
use crate::sys::sync as sys;

/// A reader-writer lock
@@ -574,8 +574,12 @@ impl<T> From<T> for RwLock<T> {

impl<'rwlock, T: ?Sized> RwLockReadGuard<'rwlock, T> {
/// Creates a new instance of `RwLockReadGuard<T>` from a `RwLock<T>`.
// SAFETY: if and only if `lock.inner.read()` (or `lock.inner.try_read()`) has been
// successfully called from the same thread before instantiating this object.
///
/// # Safety
///
/// This function is safe if and only if the same thread has successfully and safely called
/// `lock.inner.read()`, `lock.inner.try_read()`, or `lock.inner.downgrade()` before
/// instantiating this object.
unsafe fn new(lock: &'rwlock RwLock<T>) -> LockResult<RwLockReadGuard<'rwlock, T>> {
poison::map_result(lock.poison.borrow(), |()| RwLockReadGuard {
data: unsafe { NonNull::new_unchecked(lock.data.get()) },
@@ -957,6 +961,68 @@ impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> {
None => Err(orig),
}
}

/// Downgrades a write-locked `RwLockWriteGuard` into a read-locked [`RwLockReadGuard`].
///
/// This method will atomically change the state of the [`RwLock`] from exclusive mode into
/// shared mode. This means that it is impossible for a writing thread to get in between a
/// thread calling `downgrade` and the same thread reading whatever it wrote while it had the
/// [`RwLock`] in write mode.
///
/// Note that since we have the `RwLockWriteGuard`, we know that the [`RwLock`] is already
/// locked for writing, so this method cannot fail.
///
/// # Example
///
/// ```
/// #![feature(rwlock_downgrade)]
/// use std::sync::{Arc, RwLock, RwLockWriteGuard};
///
/// // The inner value starts as 0.
/// let rw = Arc::new(RwLock::new(0));
///
/// // Put the lock in write mode.
/// let mut main_write_guard = rw.write().unwrap();
///
/// let evil = rw.clone();
/// let handle = std::thread::spawn(move || {
/// // This will not return until the main thread drops the `main_read_guard`.
/// let mut evil_guard = evil.write().unwrap();
///
/// assert_eq!(*evil_guard, 1);
/// *evil_guard = 2;
/// });
///
/// // After spawning the writer thread, set the inner value to 1.
/// *main_write_guard = 1;
///
/// // Atomically downgrade the write guard into a read guard.
/// let main_read_guard = RwLockWriteGuard::downgrade(main_write_guard);
///
/// // Since `downgrade` is atomic, the writer thread cannot have set the inner value to 2.
/// assert_eq!(*main_read_guard, 1, "`downgrade` was not atomic");
///
/// // Clean up everything now
/// drop(main_read_guard);
/// handle.join().unwrap();
///
/// let final_check = rw.read().unwrap();
/// assert_eq!(*final_check, 2);
/// ```
#[unstable(feature = "rwlock_downgrade", issue = "128203")]
pub fn downgrade(s: Self) -> RwLockReadGuard<'a, T> {
let lock = s.lock;

// We don't want to call the destructor since that calls `write_unlock`.
forget(s);

// SAFETY: We take ownership of a write guard, so we must already have the `RwLock` in write
// mode, satisfying the `downgrade` contract.
unsafe { lock.inner.downgrade() };

// SAFETY: We have just successfully called `downgrade`, so we fulfill the safety contract.
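// A poison error here can only reflect a panic that happened before this write guard was
// acquired (nothing can poison the lock while we hold it), so hand the read guard back either way.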
unsafe { RwLockReadGuard::new(lock).unwrap_or_else(PoisonError::into_inner) }
}
}

impl<'a, T: ?Sized> MappedRwLockWriteGuard<'a, T> {
105 changes: 105 additions & 0 deletions library/std/src/sync/rwlock/tests.rs
@@ -501,3 +501,108 @@ fn panic_while_mapping_write_unlocked_poison() {

drop(lock);
}

#[test]
fn test_downgrade_basic() {
let r = RwLock::new(());

let write_guard = r.write().unwrap();
let _read_guard = RwLockWriteGuard::downgrade(write_guard);
}

#[test]
fn test_downgrade_observe() {
// Taken from the test `test_rwlock_downgrade` from:
// https://github.com/Amanieu/parking_lot/blob/master/src/rwlock.rs

const W: usize = 20;
const N: usize = 100;

// This test spawns `W` writer threads, each of which increments a counter `N` times and checks
// that the value it wrote has not changed after downgrading.

let rw = Arc::new(RwLock::new(0));

// Spawn the writers that will do `W * N` operations and checks.
let handles: Vec<_> = (0..W)
.map(|_| {
let rw = rw.clone();
thread::spawn(move || {
for _ in 0..N {
// Increment the counter.
let mut write_guard = rw.write().unwrap();
*write_guard += 1;
let cur_val = *write_guard;

// Downgrade the lock to read mode, where the value protected cannot be modified.
let read_guard = RwLockWriteGuard::downgrade(write_guard);
assert_eq!(cur_val, *read_guard);
}
})
})
.collect();

for handle in handles {
handle.join().unwrap();
}

assert_eq!(*rw.read().unwrap(), W * N);
}

#[test]
fn test_downgrade_atomic() {
const NEW_VALUE: i32 = -1;

// This test checks that `downgrade` is atomic, meaning as soon as a write lock has been
// downgraded, the lock must be in read mode and no other threads can take the write lock to
// modify the protected value.

// `W` is the number of evil writer threads.
const W: usize = 20;
let rwlock = Arc::new(RwLock::new(0));

// Spawns many evil writer threads that will try and write to the locked value before the
// initial writer (who has the exclusive lock) can read after it downgrades.
// If the `RwLock` behaves correctly, then the initial writer should read the value it wrote
// itself as no other thread should be able to mutate the protected value.

// Put the lock in write mode, causing all future threads trying to access it to go to sleep.
let mut main_write_guard = rwlock.write().unwrap();

// Spawn all of the evil writer threads. They will each increment the protected value by 1.
let handles: Vec<_> = (0..W)
.map(|_| {
let rwlock = rwlock.clone();
thread::spawn(move || {
// Will go to sleep since the main thread initially has the write lock.
let mut evil_guard = rwlock.write().unwrap();
*evil_guard += 1;
})
})
.collect();

// Wait for a good amount of time so that evil threads go to sleep.
// Note: this is not strictly necessary...
let eternity = crate::time::Duration::from_millis(42);
thread::sleep(eternity);

// Once everyone is asleep, set the value to `NEW_VALUE`.
*main_write_guard = NEW_VALUE;

// Atomically downgrade the write guard into a read guard.
let main_read_guard = RwLockWriteGuard::downgrade(main_write_guard);

// If the above is not atomic, then it would be possible for an evil thread to get in front of
// this read and change the value to be non-negative.
assert_eq!(*main_read_guard, NEW_VALUE, "`downgrade` was not atomic");

// Drop the main read guard and allow the evil writer threads to start incrementing.
drop(main_read_guard);

for handle in handles {
handle.join().unwrap();
}

let final_check = rwlock.read().unwrap();
assert_eq!(*final_check, W as i32 + NEW_VALUE);
}
52 changes: 47 additions & 5 deletions library/std/src/sys/sync/rwlock/futex.rs
@@ -18,6 +18,7 @@ pub struct RwLock {
const READ_LOCKED: Primitive = 1;
const MASK: Primitive = (1 << 30) - 1;
const WRITE_LOCKED: Primitive = MASK;
const DOWNGRADE: Primitive = READ_LOCKED.wrapping_sub(WRITE_LOCKED); // READ_LOCKED - WRITE_LOCKED
const MAX_READERS: Primitive = MASK - 1;
const READERS_WAITING: Primitive = 1 << 30;
const WRITERS_WAITING: Primitive = 1 << 31;
@@ -53,6 +54,24 @@ fn is_read_lockable(state: Primitive) -> bool {
state & MASK < MAX_READERS && !has_readers_waiting(state) && !has_writers_waiting(state)
}

#[inline]
fn is_read_lockable_after_wakeup(state: Primitive) -> bool {
// We make a special case for checking if we can read-lock _after_ a reader thread that went to
// sleep has been woken up by a call to `downgrade`.
//
// `downgrade` will wake up all readers and place the lock in read mode. Thus, there should be
// no readers waiting and the lock should be read-locked (not write-locked or unlocked).
//
// Note that we do not check if any writers are waiting. This is because a call to `downgrade`
// implies that the caller wants other readers to read the value protected by the lock. If we
// did not allow readers to acquire the lock before writers after a `downgrade`, then only the
// original writer would be able to read the value, thus defeating the purpose of `downgrade`.
state & MASK < MAX_READERS
&& !has_readers_waiting(state)
&& !is_write_locked(state)
&& !is_unlocked(state)
}

#[inline]
fn has_reached_max_readers(state: Primitive) -> bool {
state & MASK == MAX_READERS
@@ -84,6 +103,9 @@ impl RwLock {
}
}

/// # Safety
///
/// The `RwLock` must be read-locked (N readers) in order to call this.
#[inline]
pub unsafe fn read_unlock(&self) {
let state = self.state.fetch_sub(READ_LOCKED, Release) - READ_LOCKED;
@@ -100,11 +122,13 @@ impl RwLock {

#[cold]
fn read_contended(&self) {
let mut has_slept = false;
let mut state = self.spin_read();

loop {
// If we can lock it, lock it.
if is_read_lockable(state) {
// If we have just been woken up, first check for a `downgrade` call.
// Otherwise, if we can read-lock it, lock it.
if (has_slept && is_read_lockable_after_wakeup(state)) || is_read_lockable(state) {
match self.state.compare_exchange_weak(state, state + READ_LOCKED, Acquire, Relaxed)
{
Ok(_) => return, // Locked!
@@ -116,9 +140,7 @@ impl RwLock {
}

// Check for overflow.
if has_reached_max_readers(state) {
panic!("too many active read locks on RwLock");
}
assert!(!has_reached_max_readers(state), "too many active read locks on RwLock");

// Make sure the readers waiting bit is set before we go to sleep.
if !has_readers_waiting(state) {
@@ -132,6 +154,7 @@ impl RwLock {

// Wait for the state to change.
futex_wait(&self.state, state | READERS_WAITING, None);
has_slept = true;

// Spin again after waking up.
state = self.spin_read();
@@ -152,6 +175,9 @@ impl RwLock {
}
}

/// # Safety
///
/// The `RwLock` must be write-locked (single writer) in order to call this.
#[inline]
pub unsafe fn write_unlock(&self) {
let state = self.state.fetch_sub(WRITE_LOCKED, Release) - WRITE_LOCKED;
@@ -163,6 +189,22 @@ impl RwLock {
}
}

/// # Safety
///
/// The `RwLock` must be write-locked (single writer) in order to call this.
#[inline]
pub unsafe fn downgrade(&self) {
// Removes all write bits and adds a single read bit.
let state = self.state.fetch_add(DOWNGRADE, Release);
debug_assert!(is_write_locked(state), "RwLock must be write locked to call `downgrade`");

if has_readers_waiting(state) {
// Since we had the exclusive lock, nobody else can unset this bit.
self.state.fetch_sub(READERS_WAITING, Relaxed);
futex_wake_all(&self.state);
}
}
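
The user-visible effect of this wake-up path can be sketched at the `std` level. The following is an illustrative nightly-only program, not part of this diff, in which a reader blocked in `read()` is released by the writer's `downgrade` rather than by a full unlock:

#![feature(rwlock_downgrade)]
use std::sync::{Arc, RwLock, RwLockWriteGuard};
use std::thread;
use std::time::Duration;

fn main() {
    let lock = Arc::new(RwLock::new(0));
    let mut write_guard = lock.write().unwrap();

    let reader = {
        let lock = Arc::clone(&lock);
        thread::spawn(move || {
            // Blocks because the lock is write-locked (on futex platforms this
            // parks the thread in `read_contended`).
            assert_eq!(*lock.read().unwrap(), 7);
        })
    };

    // Give the reader a chance to go to sleep; not required for correctness.
    thread::sleep(Duration::from_millis(50));

    *write_guard = 7;

    // The wake-up happens inside `downgrade`: the blocked reader resumes and its
    // read guard coexists with the one returned here.
    let read_guard = RwLockWriteGuard::downgrade(write_guard);
    assert_eq!(*read_guard, 7);

    reader.join().unwrap();
}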

#[cold]
fn write_contended(&self) {
let mut state = self.spin_write();
Expand Down
5 changes: 5 additions & 0 deletions library/std/src/sys/sync/rwlock/no_threads.rs
@@ -62,4 +62,9 @@ impl RwLock {
pub unsafe fn write_unlock(&self) {
assert_eq!(self.mode.replace(0), -1);
}

#[inline]
pub unsafe fn downgrade(&self) {
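// `mode` is -1 while write-locked and `n > 0` while held by `n` readers, so downgrading
// the single writer leaves exactly one reader.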
assert_eq!(self.mode.replace(1), -1);
}
}