From 786ec9df806ca6348c53579041c93067975bf940 Mon Sep 17 00:00:00 2001 From: Aaron Robinson <arobins@microsoft.com> Date: Tue, 21 Jun 2022 08:32:23 -0700 Subject: [PATCH] Switch to PREEMPT mode when calling CoTaskMem* APIs (#71031) * Switch to PREEMPT mode when calling CoTaskMem* APIs * Add failfast for COM calls from managed code. Co-authored-by: Jan Kotas <jkotas@microsoft.com> --- src/coreclr/vm/comcallablewrapper.cpp | 7 ++ src/coreclr/vm/comtoclrcall.cpp | 12 +- src/coreclr/vm/ilmarshalers.cpp | 11 +- src/coreclr/vm/olevariant.cpp | 162 ++++++++++++++------------ 4 files changed, 115 insertions(+), 77 deletions(-) diff --git a/src/coreclr/vm/comcallablewrapper.cpp b/src/coreclr/vm/comcallablewrapper.cpp index 9b91cf1a9299d..bd3c48ed443fe 100644 --- a/src/coreclr/vm/comcallablewrapper.cpp +++ b/src/coreclr/vm/comcallablewrapper.cpp @@ -426,6 +426,13 @@ extern "C" PCODE ComPreStubWorker(ComPrestubMethodFrame *pPFrame, UINT64 *pError } else { + if (pThread->PreemptiveGCDisabled()) + { + EEPOLICY_HANDLE_FATAL_ERROR_WITH_MESSAGE( + COR_E_EXECUTIONENGINE, + W("Invalid Program: attempted to call a COM method from managed code.")); + } + // Transition to cooperative GC mode before we start setting up the stub. GCX_COOP(); diff --git a/src/coreclr/vm/comtoclrcall.cpp b/src/coreclr/vm/comtoclrcall.cpp index 7eae976a068ed..7f669709d3197 100644 --- a/src/coreclr/vm/comtoclrcall.cpp +++ b/src/coreclr/vm/comtoclrcall.cpp @@ -488,8 +488,7 @@ extern "C" UINT64 __stdcall COMToCLRWorker(Thread *pThread, ComMethodFrame* pFra // The following code is a transcription of the code that is generated by CreateGenericComCallStub. The // idea is that we needn't really do this work either in static assembly code nor in dynamically // generated code since the benefit/cost ratio is low. There are some minor differences in the below - // code, compared to x86. First, the reentrancy and loader lock checks are optionally compiled into the - // stub on x86, depending on whether or not the corresponding MDAs are active at stub-generation time. + // code, compared to x86. // We must check each time at runtime here because we're using static code. // HRESULT hr = S_OK; @@ -505,8 +504,15 @@ extern "C" UINT64 __stdcall COMToCLRWorker(Thread *pThread, ComMethodFrame* pFra } } + if (pThread->PreemptiveGCDisabled()) + { + EEPOLICY_HANDLE_FATAL_ERROR_WITH_MESSAGE( + COR_E_EXECUTIONENGINE, + W("Invalid Program: attempted to call a COM method from managed code.")); + } + // Attempt to switch GC modes. Note that this is performed manually just like in the x86 stub because - // we have additional checks for shutdown races, MDAs, and thread abort that are performed only when + // we have additional checks for thread abort that are performed only when // g_TrapReturningThreads is set. pThread->m_fPreemptiveGCDisabled.StoreWithoutBarrier(1); if (g_TrapReturningThreads.LoadWithoutBarrier()) diff --git a/src/coreclr/vm/ilmarshalers.cpp b/src/coreclr/vm/ilmarshalers.cpp index 182a434c1e6b9..96914f90d3157 100644 --- a/src/coreclr/vm/ilmarshalers.cpp +++ b/src/coreclr/vm/ilmarshalers.cpp @@ -4348,7 +4348,11 @@ FCIMPL3(void, MngdNativeArrayMarshaler::ConvertSpaceToNative, MngdNativeArrayMar if ( (!ClrSafeInt<SIZE_T>::multiply(cElements, cbElement, cbArray)) || cbArray > MAX_SIZE_FOR_INTEROP) COMPlusThrow(kArgumentException, IDS_EE_STRUCTARRAYTOOLARGE); - *pNativeHome = CoTaskMemAlloc(cbArray); + { + GCX_PREEMP(); + *pNativeHome = CoTaskMemAlloc(cbArray); + } + if (*pNativeHome == NULL) ThrowOutOfMemory(); @@ -4461,7 +4465,10 @@ FCIMPL4(void, MngdNativeArrayMarshaler::ClearNative, MngdNativeArrayMarshaler* p if (*pNativeHome != NULL) { DoClearNativeContents(pThis, pManagedHome, pNativeHome, cElements); - CoTaskMemFree(*pNativeHome); + { + GCX_PREEMP(); + CoTaskMemFree(*pNativeHome); + } } HELPER_METHOD_FRAME_END(); diff --git a/src/coreclr/vm/olevariant.cpp b/src/coreclr/vm/olevariant.cpp index 223c6901c57ee..e4d9174dc3edc 100644 --- a/src/coreclr/vm/olevariant.cpp +++ b/src/coreclr/vm/olevariant.cpp @@ -2198,47 +2198,55 @@ void OleVariant::MarshalLPWSTRRArrayComToOle(BASEARRAYREF *pComArray, void *oleA } CONTRACTL_END; - ASSERT_PROTECTED(pComArray); - LPWSTR *pOle = (LPWSTR *) oleArray; LPWSTR *pOleEnd = pOle + cElements; - STRINGREF *pCom = (STRINGREF *) (*pComArray)->GetDataPtr(); - - while (pOle < pOleEnd) + struct + { + BASEARRAYREF pCom; + STRINGREF stringRef; + } gc; + gc.pCom = *pComArray; + gc.stringRef = NULL; + GCPROTECT_BEGIN(gc) { - // - // We aren't calling anything which might cause a GC, so don't worry about - // the array moving here. - // - - STRINGREF stringRef = *pCom++; - LPWSTR lpwstr; - if (stringRef == NULL) - { - lpwstr = NULL; - } - else + int i = 0; + while (pOle < pOleEnd) { - // Retrieve the length of the string. - int Length = stringRef->GetStringLength(); - int allocLength = (Length + 1) * sizeof(WCHAR); - if (allocLength < Length) - ThrowOutOfMemory(); - - // Allocate the string using CoTaskMemAlloc. - lpwstr = (LPWSTR)CoTaskMemAlloc(allocLength); - if (lpwstr == NULL) - ThrowOutOfMemory(); - - // Copy the COM+ string into the newly allocated LPWSTR. - memcpyNoGCRefs(lpwstr, stringRef->GetBuffer(), (Length + 1) * sizeof(WCHAR)); - lpwstr[Length] = 0; - } + gc.stringRef = *((STRINGREF*)gc.pCom->GetDataPtr() + i); - *pOle++ = lpwstr; + LPWSTR lpwstr; + if (gc.stringRef == NULL) + { + lpwstr = NULL; + } + else + { + // Retrieve the length of the string. + int Length = gc.stringRef->GetStringLength(); + int allocLength = (Length + 1) * sizeof(WCHAR); + if (allocLength < Length) + ThrowOutOfMemory(); + + // Allocate the string using CoTaskMemAlloc. + { + GCX_PREEMP(); + lpwstr = (LPWSTR)CoTaskMemAlloc(allocLength); + } + if (lpwstr == NULL) + ThrowOutOfMemory(); + + // Copy the COM+ string into the newly allocated LPWSTR. + memcpyNoGCRefs(lpwstr, gc.stringRef->GetBuffer(), allocLength); + lpwstr[Length] = W('\0'); + } + + *pOle++ = lpwstr; + i++; + } } + GCPROTECT_END(); } void OleVariant::ClearLPWSTRArray(void *oleArray, SIZE_T cElements, MethodTable *pInterfaceMT, PCODE pManagedMarshalerCode) @@ -2246,12 +2254,13 @@ void OleVariant::ClearLPWSTRArray(void *oleArray, SIZE_T cElements, MethodTable CONTRACTL { NOTHROW; - GC_NOTRIGGER; + GC_TRIGGERS; MODE_ANY; PRECONDITION(CheckPointer(oleArray)); } CONTRACTL_END; + GCX_PREEMP(); LPWSTR *pOle = (LPWSTR *) oleArray; LPWSTR *pOleEnd = pOle + cElements; @@ -2334,48 +2343,56 @@ void OleVariant::MarshalLPSTRRArrayComToOle(BASEARRAYREF *pComArray, void *oleAr } CONTRACTL_END; - ASSERT_PROTECTED(pComArray); - LPSTR *pOle = (LPSTR *) oleArray; LPSTR *pOleEnd = pOle + cElements; - STRINGREF *pCom = (STRINGREF *) (*pComArray)->GetDataPtr(); - - while (pOle < pOleEnd) + struct { - // - // We aren't calling anything which might cause a GC, so don't worry about - // the array moving here. - // - STRINGREF stringRef = *pCom++; - - CoTaskMemHolder<CHAR> lpstr(NULL); - if (stringRef == NULL) - { - lpstr = NULL; - } - else + BASEARRAYREF pCom; + STRINGREF stringRef; + } gc; + gc.pCom = *pComArray; + gc.stringRef = NULL; + GCPROTECT_BEGIN(gc) + { + int i = 0; + while (pOle < pOleEnd) { - // Retrieve the length of the string. - int Length = stringRef->GetStringLength(); - int allocLength = Length * GetMaxDBCSCharByteSize() + 1; - if (allocLength < Length) - ThrowOutOfMemory(); - - // Allocate the string using CoTaskMemAlloc. - lpstr = (LPSTR)CoTaskMemAlloc(allocLength); - if (lpstr == NULL) - ThrowOutOfMemory(); - - // Convert the unicode string to an ansi string. - int bytesWritten = InternalWideToAnsi(stringRef->GetBuffer(), Length, lpstr, allocLength, fBestFitMapping, fThrowOnUnmappableChar); - _ASSERTE(bytesWritten >= 0 && bytesWritten < allocLength); - lpstr[bytesWritten] = 0; - } + gc.stringRef = *((STRINGREF*)gc.pCom->GetDataPtr() + i); - *pOle++ = lpstr; - lpstr.SuppressRelease(); + CoTaskMemHolder<CHAR> lpstr(NULL); + if (gc.stringRef == NULL) + { + lpstr = NULL; + } + else + { + // Retrieve the length of the string. + int Length = gc.stringRef->GetStringLength(); + int allocLength = Length * GetMaxDBCSCharByteSize() + 1; + if (allocLength < Length) + ThrowOutOfMemory(); + + // Allocate the string using CoTaskMemAlloc. + { + GCX_PREEMP(); + lpstr = (LPSTR)CoTaskMemAlloc(allocLength); + } + if (lpstr == NULL) + ThrowOutOfMemory(); + + // Convert the unicode string to an ansi string. + int bytesWritten = InternalWideToAnsi(gc.stringRef->GetBuffer(), Length, lpstr, allocLength, fBestFitMapping, fThrowOnUnmappableChar); + _ASSERTE(bytesWritten >= 0 && bytesWritten < allocLength); + lpstr[bytesWritten] = '\0'; + } + + *pOle++ = lpstr; + i++; + lpstr.SuppressRelease(); + } } + GCPROTECT_END(); } void OleVariant::ClearLPSTRArray(void *oleArray, SIZE_T cElements, MethodTable *pInterfaceMT, PCODE pManagedMarshalerCode) @@ -2383,12 +2400,13 @@ void OleVariant::ClearLPSTRArray(void *oleArray, SIZE_T cElements, MethodTable * CONTRACTL { NOTHROW; - GC_NOTRIGGER; + GC_TRIGGERS; MODE_ANY; PRECONDITION(CheckPointer(oleArray)); } CONTRACTL_END; + GCX_PREEMP(); LPSTR *pOle = (LPSTR *) oleArray; LPSTR *pOleEnd = pOle + cElements; @@ -4640,8 +4658,8 @@ void OleVariant::MarshalArrayRefForSafeArray(SAFEARRAY *pSafeArray, if (!CorTypeInfo::IsPrimitiveType(th.GetInternalCorElementType())) { - _ASSERTE(!strcmp(th.AsMethodTable()->GetDebugClassName(), - "System.Currency")); + _ASSERTE(!strcmp(th.AsMethodTable()->GetDebugClassName(), "System.Currency") + || !strcmp(th.AsMethodTable()->GetDebugClassName(), "System.Decimal")); } } #endif