Skip to content

Commit

Permalink
Switch to PREEMPT mode when calling CoTaskMem* APIs (#71031)
Browse files Browse the repository at this point in the history
* Switch to PREEMPT mode when calling CoTaskMem* APIs

* Add failfast for COM calls from managed code.

Co-authored-by: Jan Kotas <jkotas@microsoft.com>
  • Loading branch information
AaronRobinsonMSFT and jkotas authored Jun 21, 2022
1 parent 3affc68 commit 786ec9d
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 77 deletions.
7 changes: 7 additions & 0 deletions src/coreclr/vm/comcallablewrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,13 @@ extern "C" PCODE ComPreStubWorker(ComPrestubMethodFrame *pPFrame, UINT64 *pError
}
else
{
if (pThread->PreemptiveGCDisabled())
{
EEPOLICY_HANDLE_FATAL_ERROR_WITH_MESSAGE(
COR_E_EXECUTIONENGINE,
W("Invalid Program: attempted to call a COM method from managed code."));
}

// Transition to cooperative GC mode before we start setting up the stub.
GCX_COOP();

Expand Down
12 changes: 9 additions & 3 deletions src/coreclr/vm/comtoclrcall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,7 @@ extern "C" UINT64 __stdcall COMToCLRWorker(Thread *pThread, ComMethodFrame* pFra
// The following code is a transcription of the code that is generated by CreateGenericComCallStub. The
// idea is that we needn't really do this work either in static assembly code nor in dynamically
// generated code since the benefit/cost ratio is low. There are some minor differences in the below
// code, compared to x86. First, the reentrancy and loader lock checks are optionally compiled into the
// stub on x86, depending on whether or not the corresponding MDAs are active at stub-generation time.
// code, compared to x86.
// We must check each time at runtime here because we're using static code.
//
HRESULT hr = S_OK;
Expand All @@ -505,8 +504,15 @@ extern "C" UINT64 __stdcall COMToCLRWorker(Thread *pThread, ComMethodFrame* pFra
}
}

if (pThread->PreemptiveGCDisabled())
{
EEPOLICY_HANDLE_FATAL_ERROR_WITH_MESSAGE(
COR_E_EXECUTIONENGINE,
W("Invalid Program: attempted to call a COM method from managed code."));
}

// Attempt to switch GC modes. Note that this is performed manually just like in the x86 stub because
// we have additional checks for shutdown races, MDAs, and thread abort that are performed only when
// we have additional checks for thread abort that are performed only when
// g_TrapReturningThreads is set.
pThread->m_fPreemptiveGCDisabled.StoreWithoutBarrier(1);
if (g_TrapReturningThreads.LoadWithoutBarrier())
Expand Down
11 changes: 9 additions & 2 deletions src/coreclr/vm/ilmarshalers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4348,7 +4348,11 @@ FCIMPL3(void, MngdNativeArrayMarshaler::ConvertSpaceToNative, MngdNativeArrayMar
if ( (!ClrSafeInt<SIZE_T>::multiply(cElements, cbElement, cbArray)) || cbArray > MAX_SIZE_FOR_INTEROP)
COMPlusThrow(kArgumentException, IDS_EE_STRUCTARRAYTOOLARGE);

*pNativeHome = CoTaskMemAlloc(cbArray);
{
GCX_PREEMP();
*pNativeHome = CoTaskMemAlloc(cbArray);
}

if (*pNativeHome == NULL)
ThrowOutOfMemory();

Expand Down Expand Up @@ -4461,7 +4465,10 @@ FCIMPL4(void, MngdNativeArrayMarshaler::ClearNative, MngdNativeArrayMarshaler* p
if (*pNativeHome != NULL)
{
DoClearNativeContents(pThis, pManagedHome, pNativeHome, cElements);
CoTaskMemFree(*pNativeHome);
{
GCX_PREEMP();
CoTaskMemFree(*pNativeHome);
}
}

HELPER_METHOD_FRAME_END();
Expand Down
162 changes: 90 additions & 72 deletions src/coreclr/vm/olevariant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2198,60 +2198,69 @@ void OleVariant::MarshalLPWSTRRArrayComToOle(BASEARRAYREF *pComArray, void *oleA
}
CONTRACTL_END;

ASSERT_PROTECTED(pComArray);

LPWSTR *pOle = (LPWSTR *) oleArray;
LPWSTR *pOleEnd = pOle + cElements;

STRINGREF *pCom = (STRINGREF *) (*pComArray)->GetDataPtr();

while (pOle < pOleEnd)
struct
{
BASEARRAYREF pCom;
STRINGREF stringRef;
} gc;
gc.pCom = *pComArray;
gc.stringRef = NULL;
GCPROTECT_BEGIN(gc)
{
//
// We aren't calling anything which might cause a GC, so don't worry about
// the array moving here.
//

STRINGREF stringRef = *pCom++;

LPWSTR lpwstr;
if (stringRef == NULL)
{
lpwstr = NULL;
}
else
int i = 0;
while (pOle < pOleEnd)
{
// Retrieve the length of the string.
int Length = stringRef->GetStringLength();
int allocLength = (Length + 1) * sizeof(WCHAR);
if (allocLength < Length)
ThrowOutOfMemory();

// Allocate the string using CoTaskMemAlloc.
lpwstr = (LPWSTR)CoTaskMemAlloc(allocLength);
if (lpwstr == NULL)
ThrowOutOfMemory();

// Copy the COM+ string into the newly allocated LPWSTR.
memcpyNoGCRefs(lpwstr, stringRef->GetBuffer(), (Length + 1) * sizeof(WCHAR));
lpwstr[Length] = 0;
}
gc.stringRef = *((STRINGREF*)gc.pCom->GetDataPtr() + i);

*pOle++ = lpwstr;
LPWSTR lpwstr;
if (gc.stringRef == NULL)
{
lpwstr = NULL;
}
else
{
// Retrieve the length of the string.
int Length = gc.stringRef->GetStringLength();
int allocLength = (Length + 1) * sizeof(WCHAR);
if (allocLength < Length)
ThrowOutOfMemory();

// Allocate the string using CoTaskMemAlloc.
{
GCX_PREEMP();
lpwstr = (LPWSTR)CoTaskMemAlloc(allocLength);
}
if (lpwstr == NULL)
ThrowOutOfMemory();

// Copy the COM+ string into the newly allocated LPWSTR.
memcpyNoGCRefs(lpwstr, gc.stringRef->GetBuffer(), allocLength);
lpwstr[Length] = W('\0');
}

*pOle++ = lpwstr;
i++;
}
}
GCPROTECT_END();
}

void OleVariant::ClearLPWSTRArray(void *oleArray, SIZE_T cElements, MethodTable *pInterfaceMT, PCODE pManagedMarshalerCode)
{
CONTRACTL
{
NOTHROW;
GC_NOTRIGGER;
GC_TRIGGERS;
MODE_ANY;
PRECONDITION(CheckPointer(oleArray));
}
CONTRACTL_END;

GCX_PREEMP();
LPWSTR *pOle = (LPWSTR *) oleArray;
LPWSTR *pOleEnd = pOle + cElements;

Expand Down Expand Up @@ -2334,61 +2343,70 @@ void OleVariant::MarshalLPSTRRArrayComToOle(BASEARRAYREF *pComArray, void *oleAr
}
CONTRACTL_END;

ASSERT_PROTECTED(pComArray);

LPSTR *pOle = (LPSTR *) oleArray;
LPSTR *pOleEnd = pOle + cElements;

STRINGREF *pCom = (STRINGREF *) (*pComArray)->GetDataPtr();

while (pOle < pOleEnd)
struct
{
//
// We aren't calling anything which might cause a GC, so don't worry about
// the array moving here.
//
STRINGREF stringRef = *pCom++;

CoTaskMemHolder<CHAR> lpstr(NULL);
if (stringRef == NULL)
{
lpstr = NULL;
}
else
BASEARRAYREF pCom;
STRINGREF stringRef;
} gc;
gc.pCom = *pComArray;
gc.stringRef = NULL;
GCPROTECT_BEGIN(gc)
{
int i = 0;
while (pOle < pOleEnd)
{
// Retrieve the length of the string.
int Length = stringRef->GetStringLength();
int allocLength = Length * GetMaxDBCSCharByteSize() + 1;
if (allocLength < Length)
ThrowOutOfMemory();

// Allocate the string using CoTaskMemAlloc.
lpstr = (LPSTR)CoTaskMemAlloc(allocLength);
if (lpstr == NULL)
ThrowOutOfMemory();

// Convert the unicode string to an ansi string.
int bytesWritten = InternalWideToAnsi(stringRef->GetBuffer(), Length, lpstr, allocLength, fBestFitMapping, fThrowOnUnmappableChar);
_ASSERTE(bytesWritten >= 0 && bytesWritten < allocLength);
lpstr[bytesWritten] = 0;
}
gc.stringRef = *((STRINGREF*)gc.pCom->GetDataPtr() + i);

*pOle++ = lpstr;
lpstr.SuppressRelease();
CoTaskMemHolder<CHAR> lpstr(NULL);
if (gc.stringRef == NULL)
{
lpstr = NULL;
}
else
{
// Retrieve the length of the string.
int Length = gc.stringRef->GetStringLength();
int allocLength = Length * GetMaxDBCSCharByteSize() + 1;
if (allocLength < Length)
ThrowOutOfMemory();

// Allocate the string using CoTaskMemAlloc.
{
GCX_PREEMP();
lpstr = (LPSTR)CoTaskMemAlloc(allocLength);
}
if (lpstr == NULL)
ThrowOutOfMemory();

// Convert the unicode string to an ansi string.
int bytesWritten = InternalWideToAnsi(gc.stringRef->GetBuffer(), Length, lpstr, allocLength, fBestFitMapping, fThrowOnUnmappableChar);
_ASSERTE(bytesWritten >= 0 && bytesWritten < allocLength);
lpstr[bytesWritten] = '\0';
}

*pOle++ = lpstr;
i++;
lpstr.SuppressRelease();
}
}
GCPROTECT_END();
}

void OleVariant::ClearLPSTRArray(void *oleArray, SIZE_T cElements, MethodTable *pInterfaceMT, PCODE pManagedMarshalerCode)
{
CONTRACTL
{
NOTHROW;
GC_NOTRIGGER;
GC_TRIGGERS;
MODE_ANY;
PRECONDITION(CheckPointer(oleArray));
}
CONTRACTL_END;

GCX_PREEMP();
LPSTR *pOle = (LPSTR *) oleArray;
LPSTR *pOleEnd = pOle + cElements;

Expand Down Expand Up @@ -4640,8 +4658,8 @@ void OleVariant::MarshalArrayRefForSafeArray(SAFEARRAY *pSafeArray,

if (!CorTypeInfo::IsPrimitiveType(th.GetInternalCorElementType()))
{
_ASSERTE(!strcmp(th.AsMethodTable()->GetDebugClassName(),
"System.Currency"));
_ASSERTE(!strcmp(th.AsMethodTable()->GetDebugClassName(), "System.Currency")
|| !strcmp(th.AsMethodTable()->GetDebugClassName(), "System.Decimal"));
}
}
#endif
Expand Down

0 comments on commit 786ec9d

Please sign in to comment.