Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Special casing System.Guid for COM VARIANT marshalling #100377

Merged
merged 11 commits into from
Apr 5, 2024
33 changes: 25 additions & 8 deletions src/coreclr/vm/olevariant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2567,17 +2567,34 @@ void OleVariant::MarshalRecordVariantOleToCom(VARIANT *pOleVariant,
if (!pRecInfo)
COMPlusThrow(kArgumentException, IDS_EE_INVALID_OLE_VARIANT);

LPVOID pvRecord = V_RECORD(pOleVariant);
if (pvRecord == NULL)
{
pComVariant->SetObjRef(NULL);
return;
}

MethodTable* pValueClass = NULL;
{
GCX_PREEMP();
pValueClass = GetMethodTableForRecordInfo(pRecInfo);
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
}

if (pValueClass == NULL)
{
// This value type should have been registered through
// a TLB. CoreCLR doesn't support dynamic type mapping.
COMPlusThrow(kArgumentException, IDS_EE_CANNOT_MAP_TO_MANAGED_VC);
}
_ASSERTE(pValueClass->IsBlittable());

OBJECTREF BoxedValueClass = NULL;
GCPROTECT_BEGIN(BoxedValueClass)
{
LPVOID pvRecord = V_RECORD(pOleVariant);
if (pvRecord)
{
// This value type should have been registered through
// a TLB. CoreCLR doesn't support dynamic type mapping.
COMPlusThrow(kArgumentException, IDS_EE_CANNOT_MAP_TO_MANAGED_VC);
}

// Now that we have a blittable value class, allocate an instance of the
// boxed value class and copy the contents of the record into it.
BoxedValueClass = AllocateObject(pValueClass);
memcpyNoGCRefs(BoxedValueClass->GetData(), (BYTE*)pvRecord, pValueClass->GetNativeSize());
pComVariant->SetObjRef(BoxedValueClass);
}
GCPROTECT_END();
Expand Down
94 changes: 94 additions & 0 deletions src/coreclr/vm/stdinterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,43 @@ HRESULT GetITypeLibForAssembly(_In_ Assembly *pAssembly, _Outptr_ ITypeLib **ppT
return S_OK;
} // HRESULT GetITypeLibForAssembly()

// .NET Framework's mscorlib TLB GUID.
static const GUID s_MscorlibGuid = { 0xBED7F4EA, 0x1A96, 0x11D2, { 0x8F, 0x08, 0x00, 0xA0, 0xC9, 0xA6, 0x18, 0x6D } };

// Hard-coded GUID for System.Guid.
static const GUID s_GuidForSystemGuid = { 0x9C5923E9, 0xDE52, 0x33EA, { 0x88, 0xDE, 0x7E, 0xBC, 0x86, 0x33, 0xB9, 0xCC } };

// There are types that are helpful to provide that facilitate porting from
// .NET Framework to .NET 8+. This function is used to acquire their ITypeInfo.
// This should be used narrowly. Types at a minimum should be blittable.
static bool TryDeferToMscorlib(MethodTable* pClass, ITypeInfo** ppTI)
{
CONTRACTL
{
THROWS;
GC_TRIGGERS;
MODE_PREEMPTIVE;
PRECONDITION(pClass != NULL);
PRECONDITION(pClass->IsBlittable());
PRECONDITION(ppTI != NULL);
}
CONTRACTL_END;

// Marshalling of System.Guid is a common scenario that impacts many teams porting
// code to .NET 8+. Try to load the .NET Framework's TLB to support this scenario.
if (pClass == CoreLibBinder::GetClass(CLASS__GUID))
{
SafeComHolder<ITypeLib> pMscorlibTypeLib = NULL;
if (SUCCEEDED(::LoadRegTypeLib(s_MscorlibGuid, 2, 4, 0, &pMscorlibTypeLib)))
{
if (SUCCEEDED(pMscorlibTypeLib->GetTypeInfoOfGuid(s_GuidForSystemGuid, ppTI)))
return true;
}
}

return false;
}

HRESULT GetITypeInfoForEEClass(MethodTable *pClass, ITypeInfo **ppTI, bool bClassInfo)
{
CONTRACTL
Expand All @@ -625,6 +662,7 @@ HRESULT GetITypeInfoForEEClass(MethodTable *pClass, ITypeInfo **ppTI, bool bClas
GUID clsid;
GUID ciid;
ComMethodTable *pComMT = NULL;
MethodTable* pOriginalClass = pClass;
HRESULT hr = S_OK;
SafeComHolder<ITypeLib> pITLB = NULL;
SafeComHolder<ITypeInfo> pTI = NULL;
Expand Down Expand Up @@ -770,12 +808,68 @@ HRESULT GetITypeInfoForEEClass(MethodTable *pClass, ITypeInfo **ppTI, bool bClas
{
if (!FAILED(hr))
hr = E_FAIL;

if (pOriginalClass->IsValueType() && pOriginalClass->IsBlittable())
{
if (TryDeferToMscorlib(pOriginalClass, ppTI))
hr = S_OK;
}
}

ReturnHR:
return hr;
} // HRESULT GetITypeInfoForEEClass()

// Only a narrow set of types are supported.
// See TryDeferToMscorlib() above.
MethodTable* GetMethodTableForRecordInfo(IRecordInfo* recInfo)
{
CONTRACTL
{
THROWS;
GC_TRIGGERS;
MODE_PREEMPTIVE;
PRECONDITION(recInfo != NULL);
}
CONTRACTL_END;

HRESULT hr;

// Verify the associated TypeLib attribute
SafeComHolder<ITypeInfo> typeInfo;
hr = recInfo->GetTypeInfo(&typeInfo);
if (FAILED(hr))
return NULL;

SafeComHolder<ITypeLib> typeLib;
UINT index;
hr = typeInfo->GetContainingTypeLib(&typeLib, &index);
if (FAILED(hr))
return NULL;

TLIBATTR* attrs;
hr = typeLib->GetLibAttr(&attrs);
if (FAILED(hr))
return NULL;

GUID libGuid = attrs->guid;
typeLib->ReleaseTLibAttr(attrs);
if (s_MscorlibGuid != libGuid)
return NULL;

// Verify the Guid of the associated type
GUID typeGuid;
hr = recInfo->GetGuid(&typeGuid);
if (FAILED(hr))
return NULL;

// Check for supported types.
if (s_GuidForSystemGuid == typeGuid)
return CoreLibBinder::GetClass(CLASS__GUID);

return NULL;
}

// Returns a NON-ADDREF'd ITypeInfo.
HRESULT GetITypeInfoForMT(ComMethodTable *pMT, ITypeInfo **ppTI)
{
Expand Down
3 changes: 3 additions & 0 deletions src/coreclr/vm/stdinterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,7 @@ IErrorInfo *GetSupportedErrorInfo(IUnknown *iface, REFIID riid);
// Helpers to get the ITypeInfo* for a type.
HRESULT GetITypeInfoForEEClass(MethodTable *pMT, ITypeInfo **ppTI, bool bClassInfo = false);

// Gets the MethodTable for the associated IRecordInfo.
MethodTable* GetMethodTableForRecordInfo(IRecordInfo* recInfo);

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,6 @@ public static unsafe void SetAsByrefVariantIndirect(ref this ComVariant variant,
variant.SetAsByrefVariant(ref value);
return;
case VarEnum.VT_RECORD:
// VT_RECORD's are weird in that regardless of is the VT_BYREF flag is set or not
// they have the same internal representation.
variant = ComVariant.CreateRaw(value.VarType | VarEnum.VT_BYREF, value.GetRawDataRef<Record>());
break;
case VarEnum.VT_DECIMAL:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,11 @@ public static unsafe ComVariant CreateRaw<T>(VarEnum vt, T rawValue)
(VarEnum.VT_UNKNOWN or VarEnum.VT_DISPATCH or VarEnum.VT_LPSTR or VarEnum.VT_BSTR or VarEnum.VT_LPWSTR or VarEnum.VT_SAFEARRAY
or VarEnum.VT_CLSID or VarEnum.VT_STREAM or VarEnum.VT_STREAMED_OBJECT or VarEnum.VT_STORAGE or VarEnum.VT_STORED_OBJECT or VarEnum.VT_CF or VT_VERSIONED_STREAM, _) when sizeof(T) == nint.Size => rawValue,
(VarEnum.VT_CY or VarEnum.VT_FILETIME, 8) => rawValue,
(VarEnum.VT_RECORD, _) when sizeof(T) == sizeof(Record) => rawValue,

// VT_RECORDs are weird in that regardless of whether the VT_BYREF flag is set or not
// they have the same internal representation.
(VarEnum.VT_RECORD or VarEnum.VT_RECORD | VarEnum.VT_BYREF, _) when sizeof(T) == sizeof(Record) => rawValue,

_ when vt.HasFlag(VarEnum.VT_BYREF) && sizeof(T) == nint.Size => rawValue,
_ when vt.HasFlag(VarEnum.VT_VECTOR) && sizeof(T) == sizeof(Vector<byte>) => rawValue,
_ when vt.HasFlag(VarEnum.VT_ARRAY) && sizeof(T) == nint.Size => rawValue,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,40 @@ public void GetNativeVariantForObject_String_Success(string obj)
}
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsBuiltInComEnabled))]
public unsafe void GetNativeVariantForObject_Guid_Success()
{
var guid = new Guid("0DD3E51B-3162-4D13-B906-030F402C5BA2");
var v = new Variant();
IntPtr pNative = Marshal.AllocHGlobal(Marshal.SizeOf(v));
try
{
if (PlatformDetection.IsWindowsNanoServer)
{
Assert.Throws<NotSupportedException>(() => Marshal.GetNativeVariantForObject(guid, pNative));
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved
}
else
{
Marshal.GetNativeVariantForObject(guid, pNative);

Variant result = Marshal.PtrToStructure<Variant>(pNative);
Assert.Equal(VarEnum.VT_RECORD, (VarEnum)result.vt);
Assert.NotEqual(nint.Zero, result.pRecInfo); // We should have an IRecordInfo instance.

var expectedBytes = new ReadOnlySpan<byte>(guid.ToByteArray());
var actualBytes = new ReadOnlySpan<byte>((void*)result.bstrVal, expectedBytes.Length);
Assert.Equal(expectedBytes, actualBytes);

object o = Marshal.GetObjectForNativeVariant(pNative);
Assert.Equal(guid, o);
}
}
finally
{
Marshal.FreeHGlobal(pNative);
}
}

[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsBuiltInComEnabled))]
[InlineData(3.14)]
public unsafe void GetNativeVariantForObject_Double_Success(double obj)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,38 @@ public void GetObjectForNativeVariant_InvalidDate_ThrowsArgumentException(double
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsBuiltInComEnabled))]
public void GetObjectForNativeVariant_NoDataForRecord_ThrowsArgumentException()
public void GetObjectForNativeVariant_NoRecordInfo_ThrowsArgumentException()
{
Variant variant = CreateVariant(VT_RECORD, new UnionTypes { _record = new Record { _recordInfo = IntPtr.Zero } });
AssertExtensions.Throws<ArgumentException>(null, () => GetObjectForNativeVariant(variant));
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsBuiltInComEnabled))]
public void GetObjectForNativeVariant_NoRecordData_ReturnsNull()
{
var recordInfo = new RecordInfo();
IntPtr pRecordInfo = Marshal.GetComInterfaceForObject<RecordInfo, IRecordInfo>(recordInfo);
try
{
Variant variant = CreateVariant(VT_RECORD, new UnionTypes
{
_record = new Record
{
_record = IntPtr.Zero,
_recordInfo = pRecordInfo
}
});
Assert.Null(GetObjectForNativeVariant(variant));
}
finally
{
Marshal.Release(pRecordInfo);
}
}

public static IEnumerable<object[]> GetObjectForNativeVariant_NoSuchGuid_TestData()
{
yield return new object[] { typeof(object).GUID };
yield return new object[] { typeof(string).GUID };
yield return new object[] { Guid.Empty };
}
Expand Down
1 change: 1 addition & 0 deletions src/tests/Interop/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ if(CLR_CMAKE_TARGET_WIN32)
add_subdirectory(COM/NativeClients/DefaultInterfaces)
add_subdirectory(COM/NativeClients/Dispatch)
add_subdirectory(COM/NativeClients/Events)
add_subdirectory(COM/NativeClients/MiscTypes)
add_subdirectory(COM/ComWrappers/MockReferenceTrackerRuntime)
add_subdirectory(COM/ComWrappers/WeakReference)

Expand Down
11 changes: 11 additions & 0 deletions src/tests/Interop/COM/Dynamic/BasicTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public void Run()

String();
Date();
SpecialCasedValueTypes();
ComObject();
Null();

Expand Down Expand Up @@ -385,6 +386,16 @@ private void Date()
Variant<DateTime>(val, expected);
}

private void SpecialCasedValueTypes()
{
{
var val = Guid.NewGuid();
var expected = val;
// Pass as variant
Variant<Guid>(val, expected);
}
}

private void ComObject()
{
Type t = Type.GetTypeFromCLSID(Guid.Parse(ServerGuids.BasicTest));
Expand Down
18 changes: 18 additions & 0 deletions src/tests/Interop/COM/NETClients/MiscTypes/App.manifest
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<?xml version="1.0" encoding="utf-8"?>
<assembly manifestVersion="1.0" xmlns="urn:schemas-microsoft-com:asm.v1">
<assemblyIdentity
type="win32"
name="NetClientMiscTypes"
version="1.0.0.0" />

<dependency>
<dependentAssembly>
<!-- RegFree COM -->
<assemblyIdentity
type="win32"
name="COMNativeServer.X"
version="1.0.0.0"/>
</dependentAssembly>
</dependency>

</assembly>
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<!-- Needed for CMakeProjectReference, GC.WaitForPendingFinalizers -->
<RequiresProcessIsolation>true</RequiresProcessIsolation>
<ApplicationManifest>App.manifest</ApplicationManifest>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>
<ItemGroup>
<Compile Include="Program.cs" />
<Compile Include="../../ServerContracts/Server.CoClasses.cs" />
<Compile Include="../../ServerContracts/Server.Contracts.cs" />
<Compile Include="../../ServerContracts/ServerGuids.cs" />
</ItemGroup>
<ItemGroup>
<CMakeProjectReference Include="../../NativeServer/CMakeLists.txt" />
<ProjectReference Include="$(TestLibraryProjectPath)" />
</ItemGroup>
</Project>
Loading
Loading