diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/StringMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/StringMarshallingTests.cs index 175296047f845..c359be79f2967 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/StringMarshallingTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/StringMarshallingTests.cs @@ -22,6 +22,9 @@ public unsafe partial class StringMarshallingTests [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_utf16_marshalling")] public static partial void* NewIUtf16Marshalling(); + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_string_marshalling_override")] + public static partial void* NewStringMarshallingOverride(); + [GeneratedComClass] internal partial class Utf8MarshalledClass : IUTF8Marshalling { @@ -107,5 +110,26 @@ public void RcwToCcw() customUtf16ComObject.SetString("Set from COM object"); Assert.Equal(customUtf16.GetString(), customUtf16ComObject.GetString()); } + + [Fact] + public void MarshalAsAndMarshalUsingOverrideStringMarshalling() + { + var ptr = NewStringMarshallingOverride(); + var cw = new StrategyBasedComWrappers(); + var obj = cw.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None); + var stringMarshallingOverride = (IStringMarshallingOverride)obj; + Assert.Equal("Your string: MyUtf8String", stringMarshallingOverride.StringMarshallingUtf8("MyUtf8String")); + Assert.Equal("Your string: MyLPWStrString", stringMarshallingOverride.MarshalAsLPWString("MyLPWStrString")); + Assert.Equal("Your string: MyUtf16String", stringMarshallingOverride.MarshalUsingUtf16("MyUtf16String")); + + // Make sure the shadowing methods generated for the derived interface also follow the rules + var stringMarshallingOverrideDerived = (IStringMarshallingOverrideDerived)obj; + Assert.Equal("Your string: MyUtf8String", stringMarshallingOverrideDerived.StringMarshallingUtf8("MyUtf8String")); + Assert.Equal("Your string: MyLPWStrString", stringMarshallingOverrideDerived.MarshalAsLPWString("MyLPWStrString")); + Assert.Equal("Your string: MyUtf16String", stringMarshallingOverrideDerived.MarshalUsingUtf16("MyUtf16String")); + Assert.Equal("Your string 2: MyUtf8String", stringMarshallingOverrideDerived.StringMarshallingUtf8_2("MyUtf8String")); + Assert.Equal("Your string 2: MyLPWStrString", stringMarshallingOverrideDerived.MarshalAsLPWString_2("MyLPWStrString")); + Assert.Equal("Your string 2: MyUtf16String", stringMarshallingOverrideDerived.MarshalUsingUtf16_2("MyUtf16String")); + } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/StringMarshallingOverride.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/StringMarshallingOverride.cs new file mode 100644 index 0000000000000..df01a1045700b --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/StringMarshallingOverride.cs @@ -0,0 +1,215 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +using System.Text; +using System.Threading.Tasks; +using SharedTypes.ComInterfaces; +using static System.Runtime.InteropServices.ComWrappers; + +namespace NativeExports.ComInterfaceGenerator +{ + public unsafe partial class StringMarshallingOverride + { + [UnmanagedCallersOnly(EntryPoint = "new_string_marshalling_override")] + public static void* CreateStringMarshallingOverrideObject() + { + MyComWrapper cw = new(); + var myObject = new Implementation(); + nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None); + return (void*)ptr; + } + + class MyComWrapper : ComWrappers + { + static void* _s_comInterfaceVTable = null; + static void* S_VTable + { + get + { + if (_s_comInterfaceVTable != null) + return _s_comInterfaceVTable; + void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(GetAndSetInt), sizeof(void*) * 6); + GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease); + vtable[0] = (void*)fpQueryInterface; + vtable[1] = (void*)fpAddReference; + vtable[2] = (void*)fpRelease; + vtable[3] = (delegate* unmanaged)&Implementation.ABI.StringMarshallingUtf8; + vtable[4] = (delegate* unmanaged)&Implementation.ABI.MarshalAsLPWStr; + vtable[5] = (delegate* unmanaged)&Implementation.ABI.MarshalUsingUtf16; + _s_comInterfaceVTable = vtable; + return _s_comInterfaceVTable; + } + } + + static void* _s_derivedVTable = null; + static void* S_DerivedVTable + { + get + { + if (_s_comInterfaceVTable != null) + return _s_comInterfaceVTable; + void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(GetAndSetInt), sizeof(void*) * 9); + GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease); + vtable[0] = (void*)fpQueryInterface; + vtable[1] = (void*)fpAddReference; + vtable[2] = (void*)fpRelease; + vtable[3] = (delegate* unmanaged)&Implementation.ABI.StringMarshallingUtf8; + vtable[4] = (delegate* unmanaged)&Implementation.ABI.MarshalAsLPWStr; + vtable[5] = (delegate* unmanaged)&Implementation.ABI.MarshalUsingUtf16; + vtable[6] = (delegate* unmanaged)&Implementation.ABI.StringMarshallingUtf8_2; + vtable[7] = (delegate* unmanaged)&Implementation.ABI.MarshalAsLPWStr_2; + vtable[8] = (delegate* unmanaged)&Implementation.ABI.MarshalUsingUtf16_2; + _s_comInterfaceVTable = vtable; + return _s_comInterfaceVTable; + } + } + + protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) + { + if (obj is IStringMarshallingOverrideDerived) + { + ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(Implementation), sizeof(ComInterfaceEntry) * 2); + comInterfaceEntry[0].IID = new Guid(IStringMarshallingOverrideDerived._guid); + comInterfaceEntry[0].Vtable = (nint)S_DerivedVTable; + comInterfaceEntry[1].IID = new Guid(IStringMarshallingOverride._guid); + comInterfaceEntry[1].Vtable = (nint)S_VTable; + count = 2; + return comInterfaceEntry; + } + if (obj is IStringMarshallingOverride) + { + ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(Implementation), sizeof(ComInterfaceEntry)); + comInterfaceEntry->IID = new Guid(IStringMarshallingOverride._guid); + comInterfaceEntry->Vtable = (nint)S_VTable; + count = 1; + return comInterfaceEntry; + } + count = 0; + return null; + } + + protected override object? CreateObject(nint externalComObject, CreateObjectFlags flags) => throw new NotImplementedException(); + protected override void ReleaseObjects(IEnumerable objects) => throw new NotImplementedException(); + } + + partial class Implementation : IStringMarshallingOverride, IStringMarshallingOverrideDerived + { + string _data = "Your string: "; + string IStringMarshallingOverride.StringMarshallingUtf8(string input) => _data + input; + string IStringMarshallingOverride.MarshalAsLPWString(string input) => _data + input; + string IStringMarshallingOverride.MarshalUsingUtf16(string input) => _data + input; + + string _data2 = "Your string 2: "; + string IStringMarshallingOverrideDerived.StringMarshallingUtf8_2(string input) => _data2 + input; + string IStringMarshallingOverrideDerived.MarshalAsLPWString_2(string input) => _data2 + input; + string IStringMarshallingOverrideDerived.MarshalUsingUtf16_2(string input) => _data2 + input; + + // Provides function pointers in the COM format to use in COM VTables + public static class ABI + { + [UnmanagedCallersOnly] + public static int StringMarshallingUtf8(void* @this, byte* input, byte** output) + { + try + { + string inputStr = Utf8StringMarshaller.ConvertToManaged(input); + string currValue = ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).StringMarshallingUtf8(inputStr); + *output = Utf8StringMarshaller.ConvertToUnmanaged(currValue); + return 0; + } + catch (Exception e) + { + return e.HResult; + } + } + + [UnmanagedCallersOnly] + public static int MarshalAsLPWStr(void* @this, ushort* input, ushort** output) + { + try + { + string inputStr = Utf16StringMarshaller.ConvertToManaged(input); + string currValue = ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).MarshalAsLPWString(inputStr); + *output = Utf16StringMarshaller.ConvertToUnmanaged(currValue); + return 0; + } + catch (Exception e) + { + return e.HResult; + } + } + + [UnmanagedCallersOnly] + public static int MarshalUsingUtf16(void* @this, ushort* input, ushort** output) + { + try + { + string inputStr = Utf16StringMarshaller.ConvertToManaged(input); + string currValue = ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).MarshalUsingUtf16(inputStr); + *output = Utf16StringMarshaller.ConvertToUnmanaged(currValue); + return 0; + } + catch (Exception e) + { + return e.HResult; + } + } + + [UnmanagedCallersOnly] + public static int StringMarshallingUtf8_2(void* @this, byte* input, byte** output) + { + try + { + string inputStr = Utf8StringMarshaller.ConvertToManaged(input); + string currValue = ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).StringMarshallingUtf8_2(inputStr); + *output = Utf8StringMarshaller.ConvertToUnmanaged(currValue); + return 0; + } + catch (Exception e) + { + return e.HResult; + } + } + + [UnmanagedCallersOnly] + public static int MarshalAsLPWStr_2(void* @this, ushort* input, ushort** output) + { + try + { + string inputStr = Utf16StringMarshaller.ConvertToManaged(input); + string currValue = ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).MarshalAsLPWString_2(inputStr); + *output = Utf16StringMarshaller.ConvertToUnmanaged(currValue); + return 0; + } + catch (Exception e) + { + return e.HResult; + } + } + + [UnmanagedCallersOnly] + public static int MarshalUsingUtf16_2(void* @this, ushort* input, ushort** output) + { + try + { + string inputStr = Utf16StringMarshaller.ConvertToManaged(input); + string currValue = ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).MarshalUsingUtf16_2(inputStr); + *output = Utf16StringMarshaller.ConvertToUnmanaged(currValue); + return 0; + } + catch (Exception e) + { + return e.HResult; + } + } + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStringMarshallingOverride.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStringMarshallingOverride.cs new file mode 100644 index 0000000000000..cd7c8f620dc93 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStringMarshallingOverride.cs @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +using System.Text; +using System.Threading.Tasks; + +namespace SharedTypes.ComInterfaces +{ + [GeneratedComInterface(StringMarshalling = System.Runtime.InteropServices.StringMarshalling.Utf8)] + [Guid(_guid)] + internal partial interface IStringMarshallingOverride + { + public const string _guid = "5146B7DB-0588-469B-B8E5-B38090A2FC15"; + string StringMarshallingUtf8(string input); + + [return: MarshalAs(UnmanagedType.LPWStr)] + string MarshalAsLPWString([MarshalAs(UnmanagedType.LPWStr)] string input); + + [return: MarshalUsing(typeof(Utf16StringMarshaller))] + string MarshalUsingUtf16([MarshalUsing(typeof(Utf16StringMarshaller))] string input); + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStringMarshallingOverrideDerived.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStringMarshallingOverrideDerived.cs new file mode 100644 index 0000000000000..079db461a66bd --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStringMarshallingOverrideDerived.cs @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices.Marshalling; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading.Tasks; + +namespace SharedTypes.ComInterfaces +{ + [GeneratedComInterface(StringMarshalling = StringMarshalling.Utf8)] + [Guid(_guid)] + internal partial interface IStringMarshallingOverrideDerived : IStringMarshallingOverride + { + public new const string _guid = "3AFFE3FD-D11E-4195-8250-0C73321977A0"; + string StringMarshallingUtf8_2(string input); + + [return: MarshalAs(UnmanagedType.LPWStr)] + string MarshalAsLPWString_2([MarshalAs(UnmanagedType.LPWStr)] string input); + + [return: MarshalUsing(typeof(Utf16StringMarshaller))] + string MarshalUsingUtf16_2([MarshalUsing(typeof(Utf16StringMarshaller))] string input); + } +}