Skip to content

Commit

Permalink
Remove class constraint from Interlocked.{Compare}Exchange (#104558)
Browse files Browse the repository at this point in the history
* Remove class constraint from Interlocked.{Compare}Exchange

Today `Interlocked.CompareExchange<T>` and `Interlocked.Exchange<T>` support only reference type `T`s. Now that we have corresponding {Compare}Exchange methods that support types of size 1, 2, 4, and 8, we can remove the constraint and support any `T` that's either a reference type, a primitive type, or an enum type, making the generic overload more useful and avoiding consumers needing to choose less-than-ideal types just because of the need for atomicity with Interlocked.{Compare}Exchange.

---------

Co-authored-by: Michal Strehovský <MichalStrehovsky@users.noreply.github.com>
  • Loading branch information
stephentoub and MichalStrehovsky authored Jul 19, 2024
1 parent 6aa2862 commit 41e02e5
Show file tree
Hide file tree
Showing 62 changed files with 973 additions and 579 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,6 @@ public static long Exchange(ref long location1, long value)
[return: NotNullIfNotNull(nameof(location1))]
[MethodImpl(MethodImplOptions.InternalCall)]
private static extern object? ExchangeObject([NotNullIfNotNull(nameof(value))] ref object? location1, object? value);

// The below whole method reduces to a single call to Exchange(ref object, object) but
// the JIT thinks that it will generate more native code than it actually does.

/// <summary>Sets a variable of the specified type <typeparamref name="T"/> to a specified value and returns the original value, as an atomic operation.</summary>
/// <param name="location1">The variable to set to the specified value.</param>
/// <param name="value">The value to which the <paramref name="location1"/> parameter is set.</param>
/// <returns>The original value of <paramref name="location1"/>.</returns>
/// <exception cref="NullReferenceException">The address of location1 is a null pointer.</exception>
/// <typeparam name="T">The type to be used for <paramref name="location1"/> and <paramref name="value"/>. This type must be a reference type.</typeparam>
[Intrinsic]
[return: NotNullIfNotNull(nameof(location1))]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static T Exchange<T>([NotNullIfNotNull(nameof(value))] ref T location1, T value) where T : class? =>
Unsafe.As<T>(Exchange(ref Unsafe.As<T, object?>(ref location1), value));
#endregion

#region CompareExchange
Expand Down Expand Up @@ -183,29 +168,6 @@ public static long CompareExchange(ref long location1, long value, long comparan
[MethodImpl(MethodImplOptions.InternalCall)]
[return: NotNullIfNotNull(nameof(location1))]
private static extern object? CompareExchangeObject(ref object? location1, object? value, object? comparand);

// Note that getILIntrinsicImplementationForInterlocked() in vm\jitinterface.cpp replaces
// the body of the following method with the following IL:
// ldarg.0
// ldarg.1
// ldarg.2
// call System.Threading.Interlocked::CompareExchange(ref Object, Object, Object)
// ret
// The workaround is no longer strictly necessary now that we have Unsafe.As but it does
// have the advantage of being less sensitive to JIT's inliner decisions.

/// <summary>Compares two instances of the specified reference type <typeparamref name="T"/> for reference equality and, if they are equal, replaces the first one.</summary>
/// <param name="location1">The destination, whose value is compared by reference with <paramref name="comparand"/> and possibly replaced.</param>
/// <param name="value">The value that replaces the destination value if the comparison by reference results in equality.</param>
/// <param name="comparand">The object that is compared by reference to the value at <paramref name="location1"/>.</param>
/// <returns>The original value in <paramref name="location1"/>.</returns>
/// <exception cref="NullReferenceException">The address of <paramref name="location1"/> is a null pointer.</exception>
/// <typeparam name="T">The type to be used for <paramref name="location1"/>, <paramref name="value"/>, and <paramref name="comparand"/>. This type must be a reference type.</typeparam>
[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[return: NotNullIfNotNull(nameof(location1))]
public static T CompareExchange<T>(ref T location1, T value, T comparand) where T : class? =>
Unsafe.As<T>(CompareExchange(ref Unsafe.As<T, object?>(ref location1), value, comparand));
#endregion

#region Add
Expand Down
8 changes: 4 additions & 4 deletions src/coreclr/jit/importercalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4130,11 +4130,7 @@ GenTree* Compiler::impIntrinsic(CORINFO_CLASS_HANDLE clsHnd,
case NI_System_Threading_Interlocked_Exchange:
case NI_System_Threading_Interlocked_ExchangeAdd:
{
assert(callType != TYP_STRUCT);
assert(sig->numArgs == 2);

var_types retType = JITtype2varType(sig->retType);
assert((genTypeSize(retType) >= 4) || (ni == NI_System_Threading_Interlocked_Exchange));

if (genTypeSize(retType) > TARGET_POINTER_SIZE)
{
Expand All @@ -4159,6 +4155,10 @@ GenTree* Compiler::impIntrinsic(CORINFO_CLASS_HANDLE clsHnd,
break;
}

assert(callType != TYP_STRUCT);
assert(sig->numArgs == 2);
assert((genTypeSize(retType) >= 4) || (ni == NI_System_Threading_Interlocked_Exchange));

GenTree* op2 = impPopStack().val;
GenTree* op1 = impPopStack().val;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,6 @@ public static long CompareExchange(ref long location1, long value, long comparan
#endif
}

[Intrinsic]
[return: NotNullIfNotNull(nameof(location1))]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static T CompareExchange<T>(ref T location1, T value, T comparand) where T : class?
{
return Unsafe.As<T>(CompareExchange(ref Unsafe.As<T, object?>(ref location1), value, comparand));
}

[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[return: NotNullIfNotNull(nameof(location1))]
Expand Down Expand Up @@ -92,16 +84,6 @@ public static long Exchange(ref long location1, long value)
#endif
}

[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[return: NotNullIfNotNull(nameof(location1))]
public static T Exchange<T>([NotNullIfNotNull(nameof(value))] ref T location1, T value) where T : class?
{
if (Unsafe.IsNullRef(ref location1))
ThrowHelper.ThrowNullReferenceException();
return Unsafe.As<T>(RuntimeImports.InterlockedExchange(ref Unsafe.As<T, object?>(ref location1), value));
}

[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[return: NotNullIfNotNull(nameof(location1))]
Expand Down
12 changes: 6 additions & 6 deletions src/coreclr/tools/Common/TypeSystem/IL/NativeAotILProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,6 @@ private static MethodIL TryGetIntrinsicMethodIL(MethodDesc method)

switch (owningType.Name)
{
case "Interlocked":
{
if (owningType.Namespace == "System.Threading")
return InterlockedIntrinsics.EmitIL(method);
}
break;
case "Unsafe":
{
if (owningType.Namespace == "System.Runtime.CompilerServices")
Expand Down Expand Up @@ -108,6 +102,12 @@ private static MethodIL TryGetPerInstantiationIntrinsicMethodIL(MethodDesc metho

switch (owningType.Name)
{
case "Interlocked":
{
if (owningType.Namespace == "System.Threading")
return InterlockedIntrinsics.EmitIL(method);
}
break;
case "Activator":
{
TypeSystemContext context = owningType.Context;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public static MethodIL EmitIL(
MethodDesc method)
{
Debug.Assert(((MetadataType)method.OwningType).Name == "Interlocked");
Debug.Assert(!method.IsGenericMethodDefinition);

if (method.HasInstantiation && method.Name == "CompareExchange")
{
Expand All @@ -30,22 +31,45 @@ public static MethodIL EmitIL(
if (compilationModuleGroup.ContainsType(method.OwningType))
#endif // READYTORUN
{
TypeDesc objectType = method.Context.GetWellKnownType(WellKnownType.Object);
MethodDesc compareExchangeObject = method.OwningType.GetKnownMethod("CompareExchange",
new MethodSignature(
MethodSignatureFlags.Static,
genericParameterCount: 0,
returnType: objectType,
parameters: new TypeDesc[] { objectType.MakeByRefType(), objectType, objectType }));

ILEmitter emit = new ILEmitter();
ILCodeStream codeStream = emit.NewCodeStream();
codeStream.EmitLdArg(0);
codeStream.EmitLdArg(1);
codeStream.EmitLdArg(2);
codeStream.Emit(ILOpcode.call, emit.NewToken(compareExchangeObject));
codeStream.Emit(ILOpcode.ret);
return emit.Link(method);
// Rewrite the generic Interlocked.CompareExchange<T> to be a call to one of the non-generic overloads.
TypeDesc ceArgType = null;

TypeDesc tType = method.Instantiation[0];
if (!tType.IsValueType)
{
ceArgType = method.Context.GetWellKnownType(WellKnownType.Object);
}
else if ((tType.IsPrimitive || tType.IsEnum) && (tType.UnderlyingType.Category is not (TypeFlags.Single or TypeFlags.Double)))
{
int size = tType.GetElementSize().AsInt;
Debug.Assert(size is 1 or 2 or 4 or 8);
ceArgType = size switch
{
1 => method.Context.GetWellKnownType(WellKnownType.Byte),
2 => method.Context.GetWellKnownType(WellKnownType.UInt16),
4 => method.Context.GetWellKnownType(WellKnownType.Int32),
_ => method.Context.GetWellKnownType(WellKnownType.Int64),
};
}

if (ceArgType is not null)
{
MethodDesc compareExchangeNonGeneric = method.OwningType.GetKnownMethod("CompareExchange",
new MethodSignature(
MethodSignatureFlags.Static,
genericParameterCount: 0,
returnType: ceArgType,
parameters: [ceArgType.MakeByRefType(), ceArgType, ceArgType]));

ILEmitter emit = new ILEmitter();
ILCodeStream codeStream = emit.NewCodeStream();
codeStream.EmitLdArg(0);
codeStream.EmitLdArg(1);
codeStream.EmitLdArg(2);
codeStream.Emit(ILOpcode.call, emit.NewToken(compareExchangeNonGeneric));
codeStream.Emit(ILOpcode.ret);
return emit.Link(method);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,6 @@ private MethodIL TryGetIntrinsicMethodIL(MethodDesc method)
return UnsafeIntrinsics.EmitIL(method);
}

if (mdType.Name == "Interlocked" && mdType.Namespace == "System.Threading")
{
return InterlockedIntrinsics.EmitIL(_compilationModuleGroup, method);
}

return null;
}

Expand All @@ -114,6 +109,11 @@ private MethodIL TryGetPerInstantiationIntrinsicMethodIL(MethodDesc method)
return TryGetIntrinsicMethodILForActivator(method);
}

if (mdType.Name == "Interlocked" && mdType.Namespace == "System.Threading")
{
return InterlockedIntrinsics.EmitIL(_compilationModuleGroup, method);
}

return null;
}

Expand Down
4 changes: 4 additions & 0 deletions src/coreclr/vm/corelib.h
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,10 @@ DEFINE_METHOD(MEMORY_MARSHAL, GET_ARRAY_DATA_REFERENCE_MDARRAY, GetArrayDa
DEFINE_CLASS(INTERLOCKED, Threading, Interlocked)
DEFINE_METHOD(INTERLOCKED, COMPARE_EXCHANGE_T, CompareExchange, GM_RefT_T_T_RetT)
DEFINE_METHOD(INTERLOCKED, COMPARE_EXCHANGE_OBJECT,CompareExchange, SM_RefObject_Object_Object_RetObject)
DEFINE_METHOD(INTERLOCKED, COMPARE_EXCHANGE_BYTE, CompareExchange, SM_RefByte_Byte_Byte_RetByte)
DEFINE_METHOD(INTERLOCKED, COMPARE_EXCHANGE_USHRT, CompareExchange, SM_RefUShrt_UShrt_UShrt_RetUShrt)
DEFINE_METHOD(INTERLOCKED, COMPARE_EXCHANGE_INT, CompareExchange, SM_RefInt_Int_Int_RetInt)
DEFINE_METHOD(INTERLOCKED, COMPARE_EXCHANGE_LONG, CompareExchange, SM_RefLong_Long_Long_RetLong)

DEFINE_CLASS(RAW_DATA, CompilerServices, RawData)
DEFINE_FIELD(RAW_DATA, DATA, Data)
Expand Down
91 changes: 71 additions & 20 deletions src/coreclr/vm/jitinterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7154,28 +7154,79 @@ bool getILIntrinsicImplementationForInterlocked(MethodDesc * ftn,
if (ftn->GetMemberDef() != CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_T)->GetMemberDef())
return false;

// Get MethodDesc for non-generic System.Threading.Interlocked.CompareExchange()
MethodDesc* cmpxchgObject = CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_OBJECT);

// Setup up the body of the method
static BYTE il[] = {
CEE_LDARG_0,
CEE_LDARG_1,
CEE_LDARG_2,
CEE_CALL,0,0,0,0,
CEE_RET
};

// Get the token for non-generic System.Threading.Interlocked.CompareExchange(), and patch [target]
mdMethodDef cmpxchgObjectToken = cmpxchgObject->GetMemberDef();
il[4] = (BYTE)((int)cmpxchgObjectToken >> 0);
il[5] = (BYTE)((int)cmpxchgObjectToken >> 8);
il[6] = (BYTE)((int)cmpxchgObjectToken >> 16);
il[7] = (BYTE)((int)cmpxchgObjectToken >> 24);
// Determine the type of the generic T method parameter
_ASSERTE(ftn->HasMethodInstantiation());
_ASSERTE(ftn->GetNumGenericMethodArgs() == 1);
TypeHandle typeHandle = ftn->GetMethodInstantiation()[0];

// Setup up the body of the CompareExchange methods; the method token will be patched on first use.
static BYTE il[5][9] =
{
{ CEE_LDARG_0, CEE_LDARG_1, CEE_LDARG_2, CEE_CALL, 0, 0, 0, 0, CEE_RET }, // object
{ CEE_LDARG_0, CEE_LDARG_1, CEE_LDARG_2, CEE_CALL, 0, 0, 0, 0, CEE_RET }, // byte
{ CEE_LDARG_0, CEE_LDARG_1, CEE_LDARG_2, CEE_CALL, 0, 0, 0, 0, CEE_RET }, // ushort
{ CEE_LDARG_0, CEE_LDARG_1, CEE_LDARG_2, CEE_CALL, 0, 0, 0, 0, CEE_RET }, // int
{ CEE_LDARG_0, CEE_LDARG_1, CEE_LDARG_2, CEE_CALL, 0, 0, 0, 0, CEE_RET }, // long
};

// Based on the generic method parameter, determine which overload of CompareExchange
// to delegate to, or if we can't handle the type at all.
int ilIndex;
MethodDesc* cmpxchgMethod;
if (!typeHandle.IsValueType())
{
ilIndex = 0;
cmpxchgMethod = CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_OBJECT);
}
else
{
CorElementType elementType = typeHandle.GetVerifierCorElementType();
if (!CorTypeInfo::IsPrimitiveType(elementType) ||
elementType == ELEMENT_TYPE_R4 ||
elementType == ELEMENT_TYPE_R8)
{
return false;
}
else
{
switch (typeHandle.GetSize())
{
case 1:
ilIndex = 1;
cmpxchgMethod = CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_BYTE);
break;

case 2:
ilIndex = 2;
cmpxchgMethod = CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_USHRT);
break;

case 4:
ilIndex = 3;
cmpxchgMethod = CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_INT);
break;

case 8:
ilIndex = 4;
cmpxchgMethod = CoreLibBinder::GetMethod(METHOD__INTERLOCKED__COMPARE_EXCHANGE_LONG);
break;

default:
_ASSERT(!"Unexpected primitive type size");
return false;
}
}
}

mdMethodDef cmpxchgToken = cmpxchgMethod->GetMemberDef();
il[ilIndex][4] = (BYTE)((int)cmpxchgToken >> 0);
il[ilIndex][5] = (BYTE)((int)cmpxchgToken >> 8);
il[ilIndex][6] = (BYTE)((int)cmpxchgToken >> 16);
il[ilIndex][7] = (BYTE)((int)cmpxchgToken >> 24);

// Initialize methInfo
methInfo->ILCode = const_cast<BYTE*>(il);
methInfo->ILCodeSize = sizeof(il);
methInfo->ILCode = const_cast<BYTE*>(il[ilIndex]);
methInfo->ILCodeSize = sizeof(il[ilIndex]);
methInfo->maxStack = 3;
methInfo->EHcount = 0;
methInfo->options = (CorInfoOptions)0;
Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/vm/metasig.h
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,8 @@ DEFINE_METASIG_T(SM(RefDec_RetVoid, r(g(DECIMAL)), v))

DEFINE_METASIG(GM(RefT_T_T_RetT, IMAGE_CEE_CS_CALLCONV_DEFAULT, 1, r(M(0)) M(0) M(0), M(0)))
DEFINE_METASIG(SM(RefObject_Object_Object_RetObject, r(j) j j, j))
DEFINE_METASIG(SM(RefByte_Byte_Byte_RetByte, r(b) b b, b))
DEFINE_METASIG(SM(RefUShrt_UShrt_UShrt_RetUShrt, r(H) H H, H))

DEFINE_METASIG_T(SM(RefCleanupWorkListElement_RetVoid, r(C(CLEANUP_WORK_LIST_ELEMENT)), v))
DEFINE_METASIG_T(SM(RefCleanupWorkListElement_SafeHandle_RetIntPtr, r(C(CLEANUP_WORK_LIST_ELEMENT)) C(SAFE_HANDLE), I))
Expand Down
12 changes: 6 additions & 6 deletions src/libraries/Common/src/System/Net/StreamBuffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,15 @@ private sealed class ResettableValueTaskSource : IValueTaskSource

private ManualResetValueTaskSourceCore<bool> _waitSource; // mutable struct, do not make this readonly
private CancellationTokenRegistration _waitSourceCancellation;
private int _hasWaiter;
private bool _hasWaiter;

ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) => _waitSource.GetStatus(token);

void IValueTaskSource.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) => _waitSource.OnCompleted(continuation, state, token, flags);

void IValueTaskSource.GetResult(short token)
{
Debug.Assert(_hasWaiter == 0);
Debug.Assert(!_hasWaiter);

// Clean up the registration. This will wait for any in-flight cancellation to complete.
_waitSourceCancellation.Dispose();
Expand All @@ -312,7 +312,7 @@ void IValueTaskSource.GetResult(short token)

public void SignalWaiter()
{
if (Interlocked.Exchange(ref _hasWaiter, 0) == 1)
if (Interlocked.Exchange(ref _hasWaiter, false))
{
_waitSource.SetResult(true);
}
Expand All @@ -322,21 +322,21 @@ private void CancelWaiter(CancellationToken cancellationToken)
{
Debug.Assert(cancellationToken.IsCancellationRequested);

if (Interlocked.Exchange(ref _hasWaiter, 0) == 1)
if (Interlocked.Exchange(ref _hasWaiter, false))
{
_waitSource.SetException(ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException(cancellationToken)));
}
}

public void Reset()
{
if (_hasWaiter != 0)
if (_hasWaiter)
{
throw new InvalidOperationException("Concurrent use is not supported");
}

_waitSource.Reset();
Volatile.Write(ref _hasWaiter, 1);
Volatile.Write(ref _hasWaiter, true);
}

public void Wait()
Expand Down
Loading

0 comments on commit 41e02e5

Please sign in to comment.