From 786ec9df806ca6348c53579041c93067975bf940 Mon Sep 17 00:00:00 2001
From: Aaron Robinson <arobins@microsoft.com>
Date: Tue, 21 Jun 2022 08:32:23 -0700
Subject: [PATCH] Switch to PREEMPT mode when calling CoTaskMem* APIs (#71031)

* Switch to PREEMPT mode when calling CoTaskMem* APIs

* Add failfast for COM calls from managed code.

Co-authored-by: Jan Kotas <jkotas@microsoft.com>
---
 src/coreclr/vm/comcallablewrapper.cpp |   7 ++
 src/coreclr/vm/comtoclrcall.cpp       |  12 +-
 src/coreclr/vm/ilmarshalers.cpp       |  11 +-
 src/coreclr/vm/olevariant.cpp         | 162 ++++++++++++++------------
 4 files changed, 115 insertions(+), 77 deletions(-)

diff --git a/src/coreclr/vm/comcallablewrapper.cpp b/src/coreclr/vm/comcallablewrapper.cpp
index 9b91cf1a9299d..bd3c48ed443fe 100644
--- a/src/coreclr/vm/comcallablewrapper.cpp
+++ b/src/coreclr/vm/comcallablewrapper.cpp
@@ -426,6 +426,13 @@ extern "C" PCODE ComPreStubWorker(ComPrestubMethodFrame *pPFrame, UINT64 *pError
     }
     else
     {
+        if (pThread->PreemptiveGCDisabled())
+        {
+            EEPOLICY_HANDLE_FATAL_ERROR_WITH_MESSAGE(
+                COR_E_EXECUTIONENGINE,
+                W("Invalid Program: attempted to call a COM method from managed code."));
+        }
+
         // Transition to cooperative GC mode before we start setting up the stub.
         GCX_COOP();
 
diff --git a/src/coreclr/vm/comtoclrcall.cpp b/src/coreclr/vm/comtoclrcall.cpp
index 7eae976a068ed..7f669709d3197 100644
--- a/src/coreclr/vm/comtoclrcall.cpp
+++ b/src/coreclr/vm/comtoclrcall.cpp
@@ -488,8 +488,7 @@ extern "C" UINT64 __stdcall COMToCLRWorker(Thread *pThread, ComMethodFrame* pFra
     // The following code is a transcription of the code that is generated by CreateGenericComCallStub.  The
     // idea is that we needn't really do this work either in static assembly code nor in dynamically
     // generated code since the benefit/cost ratio is low.  There are some minor differences in the below
-    // code, compared to x86.  First, the reentrancy and loader lock checks are optionally compiled into the
-    // stub on x86, depending on whether or not the corresponding MDAs are active at stub-generation time.
+    // code, compared to x86.
     // We must check each time at runtime here because we're using static code.
     //
     HRESULT hr = S_OK;
@@ -505,8 +504,15 @@ extern "C" UINT64 __stdcall COMToCLRWorker(Thread *pThread, ComMethodFrame* pFra
         }
     }
 
+    if (pThread->PreemptiveGCDisabled())
+    {
+        EEPOLICY_HANDLE_FATAL_ERROR_WITH_MESSAGE(
+            COR_E_EXECUTIONENGINE,
+            W("Invalid Program: attempted to call a COM method from managed code."));
+    }
+
     // Attempt to switch GC modes.  Note that this is performed manually just like in the x86 stub because
-    // we have additional checks for shutdown races, MDAs, and thread abort that are performed only when
+    // we have additional checks for thread abort that are performed only when
     // g_TrapReturningThreads is set.
     pThread->m_fPreemptiveGCDisabled.StoreWithoutBarrier(1);
     if (g_TrapReturningThreads.LoadWithoutBarrier())
diff --git a/src/coreclr/vm/ilmarshalers.cpp b/src/coreclr/vm/ilmarshalers.cpp
index 182a434c1e6b9..96914f90d3157 100644
--- a/src/coreclr/vm/ilmarshalers.cpp
+++ b/src/coreclr/vm/ilmarshalers.cpp
@@ -4348,7 +4348,11 @@ FCIMPL3(void, MngdNativeArrayMarshaler::ConvertSpaceToNative, MngdNativeArrayMar
         if ( (!ClrSafeInt<SIZE_T>::multiply(cElements, cbElement, cbArray)) || cbArray > MAX_SIZE_FOR_INTEROP)
             COMPlusThrow(kArgumentException, IDS_EE_STRUCTARRAYTOOLARGE);
 
-        *pNativeHome = CoTaskMemAlloc(cbArray);
+        {
+            GCX_PREEMP();
+            *pNativeHome = CoTaskMemAlloc(cbArray);
+        }
+
         if (*pNativeHome == NULL)
             ThrowOutOfMemory();
 
@@ -4461,7 +4465,10 @@ FCIMPL4(void, MngdNativeArrayMarshaler::ClearNative, MngdNativeArrayMarshaler* p
     if (*pNativeHome != NULL)
     {
         DoClearNativeContents(pThis, pManagedHome, pNativeHome, cElements);
-        CoTaskMemFree(*pNativeHome);
+        {
+            GCX_PREEMP();
+            CoTaskMemFree(*pNativeHome);
+        }
     }
 
     HELPER_METHOD_FRAME_END();
diff --git a/src/coreclr/vm/olevariant.cpp b/src/coreclr/vm/olevariant.cpp
index 223c6901c57ee..e4d9174dc3edc 100644
--- a/src/coreclr/vm/olevariant.cpp
+++ b/src/coreclr/vm/olevariant.cpp
@@ -2198,47 +2198,55 @@ void OleVariant::MarshalLPWSTRRArrayComToOle(BASEARRAYREF *pComArray, void *oleA
     }
     CONTRACTL_END;
 
-    ASSERT_PROTECTED(pComArray);
-
     LPWSTR *pOle = (LPWSTR *) oleArray;
     LPWSTR *pOleEnd = pOle + cElements;
 
-    STRINGREF *pCom = (STRINGREF *) (*pComArray)->GetDataPtr();
-
-    while (pOle < pOleEnd)
+    struct
+    {
+        BASEARRAYREF pCom;
+        STRINGREF stringRef;
+    } gc;
+    gc.pCom = *pComArray;
+    gc.stringRef = NULL;
+    GCPROTECT_BEGIN(gc)
     {
-        //
-        // We aren't calling anything which might cause a GC, so don't worry about
-        // the array moving here.
-        //
-
-        STRINGREF stringRef = *pCom++;
 
-        LPWSTR lpwstr;
-        if (stringRef == NULL)
-        {
-            lpwstr = NULL;
-        }
-        else
+        int i = 0;
+        while (pOle < pOleEnd)
         {
-            // Retrieve the length of the string.
-            int Length = stringRef->GetStringLength();
-            int allocLength = (Length + 1) * sizeof(WCHAR);
-            if (allocLength < Length)
-                ThrowOutOfMemory();
-
-            // Allocate the string using CoTaskMemAlloc.
-            lpwstr = (LPWSTR)CoTaskMemAlloc(allocLength);
-            if (lpwstr == NULL)
-                ThrowOutOfMemory();
-
-            // Copy the COM+ string into the newly allocated LPWSTR.
-            memcpyNoGCRefs(lpwstr, stringRef->GetBuffer(), (Length + 1) * sizeof(WCHAR));
-            lpwstr[Length] = 0;
-        }
+            gc.stringRef = *((STRINGREF*)gc.pCom->GetDataPtr() + i);
 
-        *pOle++ = lpwstr;
+            LPWSTR lpwstr;
+            if (gc.stringRef == NULL)
+            {
+                lpwstr = NULL;
+            }
+            else
+            {
+                // Retrieve the length of the string.
+                int Length = gc.stringRef->GetStringLength();
+                int allocLength = (Length + 1) * sizeof(WCHAR);
+                if (allocLength < Length)
+                    ThrowOutOfMemory();
+
+                // Allocate the string using CoTaskMemAlloc.
+                {
+                    GCX_PREEMP();
+                    lpwstr = (LPWSTR)CoTaskMemAlloc(allocLength);
+                }
+                if (lpwstr == NULL)
+                    ThrowOutOfMemory();
+
+                // Copy the COM+ string into the newly allocated LPWSTR.
+                memcpyNoGCRefs(lpwstr, gc.stringRef->GetBuffer(), allocLength);
+                lpwstr[Length] = W('\0');
+            }
+
+            *pOle++ = lpwstr;
+            i++;
+        }
     }
+    GCPROTECT_END();
 }
 
 void OleVariant::ClearLPWSTRArray(void *oleArray, SIZE_T cElements, MethodTable *pInterfaceMT, PCODE pManagedMarshalerCode)
@@ -2246,12 +2254,13 @@ void OleVariant::ClearLPWSTRArray(void *oleArray, SIZE_T cElements, MethodTable
     CONTRACTL
     {
         NOTHROW;
-        GC_NOTRIGGER;
+        GC_TRIGGERS;
         MODE_ANY;
         PRECONDITION(CheckPointer(oleArray));
     }
     CONTRACTL_END;
 
+    GCX_PREEMP();
     LPWSTR *pOle = (LPWSTR *) oleArray;
     LPWSTR *pOleEnd = pOle + cElements;
 
@@ -2334,48 +2343,56 @@ void OleVariant::MarshalLPSTRRArrayComToOle(BASEARRAYREF *pComArray, void *oleAr
     }
     CONTRACTL_END;
 
-    ASSERT_PROTECTED(pComArray);
-
     LPSTR *pOle = (LPSTR *) oleArray;
     LPSTR *pOleEnd = pOle + cElements;
 
-    STRINGREF *pCom = (STRINGREF *) (*pComArray)->GetDataPtr();
-
-    while (pOle < pOleEnd)
+    struct
     {
-        //
-        // We aren't calling anything which might cause a GC, so don't worry about
-        // the array moving here.
-        //
-        STRINGREF stringRef = *pCom++;
-
-        CoTaskMemHolder<CHAR> lpstr(NULL);
-        if (stringRef == NULL)
-        {
-            lpstr = NULL;
-        }
-        else
+        BASEARRAYREF pCom;
+        STRINGREF stringRef;
+    } gc;
+    gc.pCom = *pComArray;
+    gc.stringRef = NULL;
+    GCPROTECT_BEGIN(gc)
+    {
+        int i = 0;
+        while (pOle < pOleEnd)
         {
-            // Retrieve the length of the string.
-            int Length = stringRef->GetStringLength();
-            int allocLength = Length * GetMaxDBCSCharByteSize() + 1;
-            if (allocLength < Length)
-                ThrowOutOfMemory();
-
-            // Allocate the string using CoTaskMemAlloc.
-            lpstr = (LPSTR)CoTaskMemAlloc(allocLength);
-            if (lpstr == NULL)
-                ThrowOutOfMemory();
-
-            // Convert the unicode string to an ansi string.
-            int bytesWritten = InternalWideToAnsi(stringRef->GetBuffer(), Length, lpstr, allocLength, fBestFitMapping, fThrowOnUnmappableChar);
-            _ASSERTE(bytesWritten >= 0 && bytesWritten < allocLength);
-            lpstr[bytesWritten] = 0;
-        }
+            gc.stringRef = *((STRINGREF*)gc.pCom->GetDataPtr() + i);
 
-        *pOle++ = lpstr;
-        lpstr.SuppressRelease();
+            CoTaskMemHolder<CHAR> lpstr(NULL);
+            if (gc.stringRef == NULL)
+            {
+                lpstr = NULL;
+            }
+            else
+            {
+                // Retrieve the length of the string.
+                int Length = gc.stringRef->GetStringLength();
+                int allocLength = Length * GetMaxDBCSCharByteSize() + 1;
+                if (allocLength < Length)
+                    ThrowOutOfMemory();
+
+                // Allocate the string using CoTaskMemAlloc.
+                {
+                    GCX_PREEMP();
+                    lpstr = (LPSTR)CoTaskMemAlloc(allocLength);
+                }
+                if (lpstr == NULL)
+                    ThrowOutOfMemory();
+
+                // Convert the unicode string to an ansi string.
+                int bytesWritten = InternalWideToAnsi(gc.stringRef->GetBuffer(), Length, lpstr, allocLength, fBestFitMapping, fThrowOnUnmappableChar);
+                _ASSERTE(bytesWritten >= 0 && bytesWritten < allocLength);
+                lpstr[bytesWritten] = '\0';
+            }
+
+            *pOle++ = lpstr;
+            i++;
+            lpstr.SuppressRelease();
+        }
     }
+    GCPROTECT_END();
 }
 
 void OleVariant::ClearLPSTRArray(void *oleArray, SIZE_T cElements, MethodTable *pInterfaceMT, PCODE pManagedMarshalerCode)
@@ -2383,12 +2400,13 @@ void OleVariant::ClearLPSTRArray(void *oleArray, SIZE_T cElements, MethodTable *
     CONTRACTL
     {
         NOTHROW;
-        GC_NOTRIGGER;
+        GC_TRIGGERS;
         MODE_ANY;
         PRECONDITION(CheckPointer(oleArray));
     }
     CONTRACTL_END;
 
+    GCX_PREEMP();
     LPSTR *pOle = (LPSTR *) oleArray;
     LPSTR *pOleEnd = pOle + cElements;
 
@@ -4640,8 +4658,8 @@ void OleVariant::MarshalArrayRefForSafeArray(SAFEARRAY *pSafeArray,
 
             if (!CorTypeInfo::IsPrimitiveType(th.GetInternalCorElementType()))
             {
-                _ASSERTE(!strcmp(th.AsMethodTable()->GetDebugClassName(),
-                                 "System.Currency"));
+                _ASSERTE(!strcmp(th.AsMethodTable()->GetDebugClassName(), "System.Currency")
+                        || !strcmp(th.AsMethodTable()->GetDebugClassName(), "System.Decimal"));
             }
         }
 #endif