Skip to content

Commit

Permalink
Merge pull request rust-lang#4027 from RalfJung/no-clones
Browse files Browse the repository at this point in the history
sync support: dont implicitly clone inside the general sync machinery
  • Loading branch information
RalfJung authored Nov 11, 2024
2 parents 7137683 + ad16dc4 commit 7c7371a
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 41 deletions.
38 changes: 21 additions & 17 deletions src/tools/miri/src/concurrency/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,16 @@ impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
/// Helper for lazily initialized `alloc_extra.sync` data:
/// this forces an immediate init.
fn lazy_sync_init<T: 'static>(
&mut self,
/// Return a reference to the data in the machine state.
fn lazy_sync_init<'a, T: 'static>(
&'a mut self,
primitive: &MPlaceTy<'tcx>,
init_offset: Size,
data: T,
) -> InterpResult<'tcx> {
) -> InterpResult<'tcx, &'a T>
where
'tcx: 'a,
{
let this = self.eval_context_mut();

let (alloc, offset, _) = this.ptr_get_alloc_id(primitive.ptr(), 0)?;
Expand All @@ -239,7 +243,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
&init_field,
AtomicWriteOrd::Relaxed,
)?;
interp_ok(())
interp_ok(this.get_alloc_extra(alloc)?.get_sync::<T>(offset).unwrap())
}

/// Helper for lazily initialized `alloc_extra.sync` data:
Expand All @@ -248,15 +252,17 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
/// and stores that in `alloc_extra.sync`.
/// - Otherwise, calls `new_data` to initialize the primitive.
///
/// The return value is a *clone* of the stored data, so if you intend to mutate it
/// better wrap everything into an `Rc`.
fn lazy_sync_get_data<T: 'static + Clone>(
&mut self,
/// Return a reference to the data in the machine state.
fn lazy_sync_get_data<'a, T: 'static>(
&'a mut self,
primitive: &MPlaceTy<'tcx>,
init_offset: Size,
missing_data: impl FnOnce() -> InterpResult<'tcx, T>,
new_data: impl FnOnce(&mut MiriInterpCx<'tcx>) -> InterpResult<'tcx, T>,
) -> InterpResult<'tcx, T> {
) -> InterpResult<'tcx, &'a T>
where
'tcx: 'a,
{
let this = self.eval_context_mut();

// Check if this is already initialized. Needs to be atomic because we can race with another
Expand All @@ -280,17 +286,15 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// or else it has been moved illegally.
let (alloc, offset, _) = this.ptr_get_alloc_id(primitive.ptr(), 0)?;
let (alloc_extra, _machine) = this.get_alloc_extra_mut(alloc)?;
if let Some(data) = alloc_extra.get_sync::<T>(offset) {
interp_ok(data.clone())
} else {
// Due to borrow checker reasons, we have to do the lookup twice.
if alloc_extra.get_sync::<T>(offset).is_none() {
let data = missing_data()?;
alloc_extra.sync.insert(offset, Box::new(data.clone()));
interp_ok(data)
alloc_extra.sync.insert(offset, Box::new(data));
}
interp_ok(alloc_extra.get_sync::<T>(offset).unwrap())
} else {
let data = new_data(this)?;
this.lazy_sync_init(primitive, init_offset, data.clone())?;
interp_ok(data)
this.lazy_sync_init(primitive, init_offset, data)
}
}

Expand Down Expand Up @@ -326,7 +330,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {

#[inline]
/// Get the id of the thread that currently owns this lock.
fn mutex_get_owner(&mut self, mutex_ref: &MutexRef) -> ThreadId {
fn mutex_get_owner(&self, mutex_ref: &MutexRef) -> ThreadId {
mutex_ref.0.borrow().owner.unwrap()
}

Expand Down
16 changes: 13 additions & 3 deletions src/tools/miri/src/shims/unix/macos/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ enum MacOsUnfairLock {

impl<'tcx> EvalContextExtPriv<'tcx> for crate::MiriInterpCx<'tcx> {}
trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
fn os_unfair_lock_get_data(
&mut self,
fn os_unfair_lock_get_data<'a>(
&'a mut self,
lock_ptr: &OpTy<'tcx>,
) -> InterpResult<'tcx, MacOsUnfairLock> {
) -> InterpResult<'tcx, &'a MacOsUnfairLock>
where
'tcx: 'a,
{
let this = self.eval_context_mut();
let lock = this.deref_pointer(lock_ptr)?;
this.lazy_sync_get_data(
Expand Down Expand Up @@ -68,6 +71,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
);
return interp_ok(());
};
let mutex_ref = mutex_ref.clone();

if this.mutex_is_locked(&mutex_ref) {
if this.mutex_get_owner(&mutex_ref) == this.active_thread() {
Expand Down Expand Up @@ -97,6 +101,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
this.write_scalar(Scalar::from_bool(false), dest)?;
return interp_ok(());
};
let mutex_ref = mutex_ref.clone();

if this.mutex_is_locked(&mutex_ref) {
// Contrary to the blocking lock function, this does not check for
Expand All @@ -119,6 +124,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
"attempted to unlock an os_unfair_lock not owned by the current thread".to_owned()
));
};
let mutex_ref = mutex_ref.clone();

// Now, unlock.
if this.mutex_unlock(&mutex_ref)?.is_none() {
Expand Down Expand Up @@ -147,6 +153,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
"called os_unfair_lock_assert_owner on an os_unfair_lock not owned by the current thread".to_owned()
));
};
let mutex_ref = mutex_ref.clone();

if !this.mutex_is_locked(&mutex_ref)
|| this.mutex_get_owner(&mutex_ref) != this.active_thread()
{
Expand All @@ -167,6 +175,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// The lock is poisoned, who knows who owns it... we'll pretend: someone else.
return interp_ok(());
};
let mutex_ref = mutex_ref.clone();

if this.mutex_is_locked(&mutex_ref)
&& this.mutex_get_owner(&mutex_ref) == this.active_thread()
{
Expand Down
45 changes: 27 additions & 18 deletions src/tools/miri/src/shims/unix/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ fn mutex_create<'tcx>(
fn mutex_get_data<'tcx, 'a>(
ecx: &'a mut MiriInterpCx<'tcx>,
mutex_ptr: &OpTy<'tcx>,
) -> InterpResult<'tcx, PthreadMutex> {
) -> InterpResult<'tcx, &'a PthreadMutex>
where
'tcx: 'a,
{
let mutex = ecx.deref_pointer(mutex_ptr)?;
ecx.lazy_sync_get_data(
&mutex,
Expand Down Expand Up @@ -259,10 +262,13 @@ fn rwlock_init_offset<'tcx>(ecx: &MiriInterpCx<'tcx>) -> InterpResult<'tcx, Size
interp_ok(offset)
}

fn rwlock_get_data<'tcx>(
ecx: &mut MiriInterpCx<'tcx>,
fn rwlock_get_data<'tcx, 'a>(
ecx: &'a mut MiriInterpCx<'tcx>,
rwlock_ptr: &OpTy<'tcx>,
) -> InterpResult<'tcx, PthreadRwLock> {
) -> InterpResult<'tcx, &'a PthreadRwLock>
where
'tcx: 'a,
{
let rwlock = ecx.deref_pointer(rwlock_ptr)?;
ecx.lazy_sync_get_data(
&rwlock,
Expand Down Expand Up @@ -389,10 +395,13 @@ fn cond_create<'tcx>(
interp_ok(data)
}

fn cond_get_data<'tcx>(
ecx: &mut MiriInterpCx<'tcx>,
fn cond_get_data<'tcx, 'a>(
ecx: &'a mut MiriInterpCx<'tcx>,
cond_ptr: &OpTy<'tcx>,
) -> InterpResult<'tcx, PthreadCondvar> {
) -> InterpResult<'tcx, &'a PthreadCondvar>
where
'tcx: 'a,
{
let cond = ecx.deref_pointer(cond_ptr)?;
ecx.lazy_sync_get_data(
&cond,
Expand Down Expand Up @@ -498,7 +507,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
) -> InterpResult<'tcx> {
let this = self.eval_context_mut();

let mutex = mutex_get_data(this, mutex_op)?;
let mutex = mutex_get_data(this, mutex_op)?.clone();

let ret = if this.mutex_is_locked(&mutex.mutex_ref) {
let owner_thread = this.mutex_get_owner(&mutex.mutex_ref);
Expand Down Expand Up @@ -535,7 +544,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
fn pthread_mutex_trylock(&mut self, mutex_op: &OpTy<'tcx>) -> InterpResult<'tcx, Scalar> {
let this = self.eval_context_mut();

let mutex = mutex_get_data(this, mutex_op)?;
let mutex = mutex_get_data(this, mutex_op)?.clone();

interp_ok(Scalar::from_i32(if this.mutex_is_locked(&mutex.mutex_ref) {
let owner_thread = this.mutex_get_owner(&mutex.mutex_ref);
Expand All @@ -561,7 +570,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
fn pthread_mutex_unlock(&mut self, mutex_op: &OpTy<'tcx>) -> InterpResult<'tcx, Scalar> {
let this = self.eval_context_mut();

let mutex = mutex_get_data(this, mutex_op)?;
let mutex = mutex_get_data(this, mutex_op)?.clone();

if let Some(_old_locked_count) = this.mutex_unlock(&mutex.mutex_ref)? {
// The mutex was locked by the current thread.
Expand Down Expand Up @@ -589,8 +598,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let this = self.eval_context_mut();

// Reading the field also has the side-effect that we detect double-`destroy`
// since we make the field unint below.
let mutex = mutex_get_data(this, mutex_op)?;
// since we make the field uninit below.
let mutex = mutex_get_data(this, mutex_op)?.clone();

if this.mutex_is_locked(&mutex.mutex_ref) {
throw_ub_format!("destroyed a locked mutex");
Expand Down Expand Up @@ -697,7 +706,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let this = self.eval_context_mut();

// Reading the field also has the side-effect that we detect double-`destroy`
// since we make the field unint below.
// since we make the field uninit below.
let id = rwlock_get_data(this, rwlock_op)?.id;

if this.rwlock_is_locked(id) {
Expand Down Expand Up @@ -822,8 +831,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
) -> InterpResult<'tcx> {
let this = self.eval_context_mut();

let data = cond_get_data(this, cond_op)?;
let mutex_ref = mutex_get_data(this, mutex_op)?.mutex_ref;
let data = *cond_get_data(this, cond_op)?;
let mutex_ref = mutex_get_data(this, mutex_op)?.mutex_ref.clone();

this.condvar_wait(
data.id,
Expand All @@ -846,8 +855,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
) -> InterpResult<'tcx> {
let this = self.eval_context_mut();

let data = cond_get_data(this, cond_op)?;
let mutex_ref = mutex_get_data(this, mutex_op)?.mutex_ref;
let data = *cond_get_data(this, cond_op)?;
let mutex_ref = mutex_get_data(this, mutex_op)?.mutex_ref.clone();

// Extract the timeout.
let duration = match this
Expand Down Expand Up @@ -884,7 +893,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let this = self.eval_context_mut();

// Reading the field also has the side-effect that we detect double-`destroy`
// since we make the field unint below.
// since we make the field uninit below.
let id = cond_get_data(this, cond_op)?.id;
if this.condvar_is_awaited(id) {
throw_ub_format!("destroying an awaited conditional variable");
Expand Down
9 changes: 6 additions & 3 deletions src/tools/miri/src/shims/windows/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
// Windows sync primitives are pointer sized.
// We only use the first 4 bytes for the id.

fn init_once_get_data(
&mut self,
fn init_once_get_data<'a>(
&'a mut self,
init_once_ptr: &OpTy<'tcx>,
) -> InterpResult<'tcx, WindowsInitOnce> {
) -> InterpResult<'tcx, &'a WindowsInitOnce>
where
'tcx: 'a,
{
let this = self.eval_context_mut();

let init_once = this.deref_pointer(init_once_ptr)?;
Expand Down

0 comments on commit 7c7371a

Please sign in to comment.