diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index f2d6c30d7cb7b..4165e7a77ec8a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -242,7 +242,8 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M AttributeData? generatedComAttribute = null; foreach (var attr in symbol.ContainingType.GetAttributes()) { - if (generatedComAttribute is not null && attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute) + if (generatedComAttribute is null + && attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute) { generatedComAttribute = attr; } @@ -256,8 +257,23 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M generatorDiagnostics.ReportConfigurationNotSupported(lcidConversionAttr, nameof(TypeNames.LCIDConversionAttribute)); } + var generatedComInterfaceAttributeData = new InteropAttributeCompilationData(); + if (generatedComAttribute is not null) + { + var args = generatedComAttribute.NamedArguments.ToImmutableDictionary(); + generatedComInterfaceAttributeData = generatedComInterfaceAttributeData.WithValuesFromNamedArguments(args); + } // Create the stub. - var signatureContext = SignatureContext.Create(symbol, DefaultMarshallingInfoParser.Create(environment, generatorDiagnostics, symbol, new InteropAttributeCompilationData(), generatedComAttribute), environment, typeof(VtableIndexStubGenerator).Assembly); + var signatureContext = SignatureContext.Create( + symbol, + DefaultMarshallingInfoParser.Create( + environment, + generatorDiagnostics, + symbol, + generatedComInterfaceAttributeData, + generatedComAttribute), + environment, + typeof(VtableIndexStubGenerator).Assembly); if (!symbol.MethodImplementationFlags.HasFlag(MethodImplAttributes.PreserveSig)) { diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs index 410166893a514..7629c68e9713c 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs @@ -73,7 +73,7 @@ public interface IMarshallingInfoAttributeParser } /// - /// A provider of marshalling info based only on the managed type any any previously parsed use-site attribute information + /// A provider of marshalling info based only on the managed type and any previously parsed use-site attribute information /// public interface ITypeBasedMarshallingInfoProvider { diff --git a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs index 5f82435829d89..c218273a37848 100644 --- a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs +++ b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs @@ -370,6 +370,8 @@ public GeneratedComClassAttribute() { } public partial class GeneratedComInterfaceAttribute : System.Attribute { public GeneratedComInterfaceAttribute() { } + public StringMarshalling StringMarshalling { get { throw null; } set { } } + public Type? StringMarshallingCustomType { get { throw null; } set { } } } [System.CLSCompliantAttribute(false)] public partial interface IComExposedClass diff --git a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/GeneratedComInterfaceAttribute.cs b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/GeneratedComInterfaceAttribute.cs index 7d81e174d9666..cbe1a1508ba96 100644 --- a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/GeneratedComInterfaceAttribute.cs +++ b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/GeneratedComInterfaceAttribute.cs @@ -6,5 +6,26 @@ namespace System.Runtime.InteropServices.Marshalling [AttributeUsage(AttributeTargets.Interface)] public class GeneratedComInterfaceAttribute : Attribute { + /// + /// Gets or sets how to marshal string arguments to all methods on the interface. + /// If the attributed interface inherits from another interface with , + /// it must have the same values for and . + /// + /// + /// If this field is set to a value other than , + /// must not be specified. + /// + public StringMarshalling StringMarshalling { get; set; } + + /// + /// Gets or sets the used to control how string arguments are marshalled for all methods on the interface. + /// If the attributed interface inherits from another interface with , + /// it must have the same values for and . + /// + /// + /// If this field is specified, must not be specified + /// or must be set to . + /// + public Type? StringMarshallingCustomType { get; set; } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/StringMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/StringMarshallingTests.cs new file mode 100644 index 0000000000000..175296047f845 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/StringMarshallingTests.cs @@ -0,0 +1,111 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using SharedTypes.ComInterfaces; +using Xunit; + +namespace ComInterfaceGenerator.Tests +{ + public unsafe partial class StringMarshallingTests + { + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_utf8_marshalling")] + public static partial void* NewIUtf8Marshalling(); + + [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_utf16_marshalling")] + public static partial void* NewIUtf16Marshalling(); + + [GeneratedComClass] + internal partial class Utf8MarshalledClass : IUTF8Marshalling + { + string _data = "Hello, World!"; + + public string GetString() => _data; + public void SetString(string value) => _data = value; + } + + [GeneratedComClass] + internal partial class Utf16MarshalledClass : IUTF16Marshalling + { + string _data = "Hello, World!"; + + public string GetString() => _data; + public void SetString(string value) => _data = value; + } + + [GeneratedComClass] + internal partial class CustomUtf16MarshalledClass : ICustomStringMarshallingUtf16 + { + string _data = "Hello, World!"; + + public string GetString() => _data; + public void SetString(string value) => _data = value; + } + + [Fact] + public void ValidateStringMarshallingRCW() + { + var cw = new StrategyBasedComWrappers(); + var utf8 = NewIUtf8Marshalling(); + IUTF8Marshalling obj8 = (IUTF8Marshalling)cw.GetOrCreateObjectForComInstance((nint)utf8, CreateObjectFlags.None); + string value = obj8.GetString(); + Assert.Equal("Hello, World!", value); + obj8.SetString("TestString"); + value = obj8.GetString(); + Assert.Equal("TestString", value); + + var utf16 = NewIUtf16Marshalling(); + IUTF16Marshalling obj16 = (IUTF16Marshalling)cw.GetOrCreateObjectForComInstance((nint)utf16, CreateObjectFlags.None); + Assert.Equal("Hello, World!", obj16.GetString()); + obj16.SetString("TestString"); + Assert.Equal("TestString", obj16.GetString()); + + var utf16custom = NewIUtf16Marshalling(); + ICustomStringMarshallingUtf16 objCustom = (ICustomStringMarshallingUtf16)cw.GetOrCreateObjectForComInstance((nint)utf16custom, CreateObjectFlags.None); + Assert.Equal("Hello, World!", objCustom.GetString()); + objCustom.SetString("TestString"); + Assert.Equal("TestString", objCustom.GetString()); + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/85795", TargetFrameworkMonikers.Any)] + public void RcwToCcw() + { + var cw = new StrategyBasedComWrappers(); + + var utf8 = new Utf8MarshalledClass(); + var utf8ComInstance = cw.GetOrCreateComInterfaceForObject(utf8, CreateComInterfaceFlags.None); + var utf8ComObject = (IUTF8Marshalling)cw.GetOrCreateObjectForComInstance(utf8ComInstance, CreateObjectFlags.None); + Assert.Equal(utf8.GetString(), utf8ComObject.GetString()); + utf8.SetString("Set from CLR object"); + Assert.Equal(utf8.GetString(), utf8ComObject.GetString()); + utf8ComObject.SetString("Set from COM object"); + Assert.Equal(utf8.GetString(), utf8ComObject.GetString()); + + var utf16 = new Utf16MarshalledClass(); + var utf16ComInstance = cw.GetOrCreateComInterfaceForObject(utf16, CreateComInterfaceFlags.None); + var utf16ComObject = (IUTF16Marshalling)cw.GetOrCreateObjectForComInstance(utf16ComInstance, CreateObjectFlags.None); + Assert.Equal(utf16.GetString(), utf16ComObject.GetString()); + utf16.SetString("Set from CLR object"); + Assert.Equal(utf16.GetString(), utf16ComObject.GetString()); + utf16ComObject.SetString("Set from COM object"); + Assert.Equal(utf16.GetString(), utf16ComObject.GetString()); + + var customUtf16 = new CustomUtf16MarshalledClass(); + var customUtf16ComInstance = cw.GetOrCreateComInterfaceForObject(customUtf16, CreateComInterfaceFlags.None); + var customUtf16ComObject = (ICustomStringMarshallingUtf16)cw.GetOrCreateObjectForComInstance(customUtf16ComInstance, CreateObjectFlags.None); + Assert.Equal(customUtf16.GetString(), customUtf16ComObject.GetString()); + customUtf16.SetString("Set from CLR object"); + Assert.Equal(customUtf16.GetString(), customUtf16ComObject.GetString()); + customUtf16ComObject.SetString("Set from COM object"); + Assert.Equal(customUtf16.GetString(), customUtf16ComObject.GetString()); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/StringMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/StringMarshalling.cs new file mode 100644 index 0000000000000..9059f53c4cc8a --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/StringMarshalling.cs @@ -0,0 +1,201 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +using System.Text; +using System.Threading.Tasks; +using SharedTypes.ComInterfaces; +using static System.Runtime.InteropServices.ComWrappers; + +namespace NativeExports.ComInterfaceGenerator +{ + public unsafe class StringMarshalling + { + [UnmanagedCallersOnly(EntryPoint = "new_utf8_marshalling")] + public static void* CreateUtf8ComObject() + { + MyComWrapper cw = new(); + var myObject = new Utf8Implementation(); + nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None); + + return (void*)ptr; + } + + [UnmanagedCallersOnly(EntryPoint = "new_utf16_marshalling")] + public static void* CreateUtf16ComObject() + { + MyComWrapper cw = new(); + var myObject = new Utf16Implementation(); + nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None); + + return (void*)ptr; + } + + class MyComWrapper : ComWrappers + { + static void* _s_comInterface1VTable = null; + static void* _s_comInterface2VTable = null; + static void* S_Utf8VTable + { + get + { + if (_s_comInterface1VTable != null) + return _s_comInterface1VTable; + void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(GetAndSetInt), sizeof(void*) * 5); + GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease); + vtable[0] = (void*)fpQueryInterface; + vtable[1] = (void*)fpAddReference; + vtable[2] = (void*)fpRelease; + vtable[3] = (delegate* unmanaged)&Utf8Implementation.ABI.GetStringUtf8; + vtable[4] = (delegate* unmanaged)&Utf8Implementation.ABI.SetStringUtf8; + _s_comInterface1VTable = vtable; + return _s_comInterface1VTable; + } + } + static void* S_Utf16VTable + { + get + { + if (_s_comInterface2VTable != null) + return _s_comInterface2VTable; + void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(GetAndSetInt), sizeof(void*) * 5); + GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease); + vtable[0] = (void*)fpQueryInterface; + vtable[1] = (void*)fpAddReference; + vtable[2] = (void*)fpRelease; + vtable[3] = (delegate* unmanaged)&Utf16Implementation.ABI.GetStringUtf16; + vtable[4] = (delegate* unmanaged)&Utf16Implementation.ABI.SetStringUtf16; + _s_comInterface2VTable = vtable; + return _s_comInterface2VTable; + } + } + + protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) + { + if (obj is IUTF8Marshalling) + { + ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(Utf8Implementation), sizeof(ComInterfaceEntry)); + comInterfaceEntry->IID = new Guid(IUTF8Marshalling._guid); + comInterfaceEntry->Vtable = (nint)S_Utf8VTable; + count = 1; + return comInterfaceEntry; + } + else if (obj is IUTF16Marshalling) + { + ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(Utf16Implementation), sizeof(ComInterfaceEntry)); + comInterfaceEntry->IID = new Guid(IUTF16Marshalling._guid); + comInterfaceEntry->Vtable = (nint)S_Utf16VTable; + count = 1; + return comInterfaceEntry; + } + count = 0; + return null; + } + + protected override object? CreateObject(nint externalComObject, CreateObjectFlags flags) => throw new NotImplementedException(); + protected override void ReleaseObjects(IEnumerable objects) => throw new NotImplementedException(); + } + + class Utf8Implementation : IUTF8Marshalling + { + string _data = "Hello, World!"; + + string IUTF8Marshalling.GetString() + { + return _data; + } + void IUTF8Marshalling.SetString(string x) + { + _data = x; + } + + // Provides function pointers in the COM format to use in COM VTables + public static class ABI + { + [UnmanagedCallersOnly] + public static int GetStringUtf8(void* @this, byte** value) + { + try + { + string currValue = ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).GetString(); + *value = Utf8StringMarshaller.ConvertToUnmanaged(currValue); + return 0; + } + catch (Exception e) + { + return e.HResult; + } + } + + [UnmanagedCallersOnly] + public static int SetStringUtf8(void* @this, byte* newValue) + { + try + { + string value = Utf8StringMarshaller.ConvertToManaged(newValue); + ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).SetString(value); + return 0; + } + catch (Exception e) + { + return e.HResult; + } + } + } + } + + class Utf16Implementation : IUTF16Marshalling + { + string _data = "Hello, World!"; + + string IUTF16Marshalling.GetString() + { + return _data; + } + void IUTF16Marshalling.SetString(string x) + { + _data = x; + } + + // Provides function pointers in the COM format to use in COM VTables + public static class ABI + { + [UnmanagedCallersOnly] + public static int GetStringUtf16(void* @this, ushort** value) + { + try + { + string currValue = ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).GetString(); + *value = Utf16StringMarshaller.ConvertToUnmanaged(currValue); + return 0; + } + catch (Exception e) + { + return e.HResult; + } + } + + [UnmanagedCallersOnly] + public static int SetStringUtf16(void* @this, ushort* newValue) + { + try + { + string value = Utf16StringMarshaller.ConvertToManaged(newValue); + ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).SetString(value); + return 0; + } + catch (Exception e) + { + return e.HResult; + } + } + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/ICustomStringMarshallingUtf16.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/ICustomStringMarshallingUtf16.cs new file mode 100644 index 0000000000000..d792b62d6d5a9 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/ICustomStringMarshallingUtf16.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; + +namespace SharedTypes.ComInterfaces +{ + [Guid(_guid)] + [GeneratedComInterface(StringMarshalling = StringMarshalling.Custom, StringMarshallingCustomType = typeof(Utf16StringMarshaller))] + internal partial interface ICustomStringMarshallingUtf16 + { + public string GetString(); + + public void SetString(string value); + + public const string _guid = "E11D5F3E-DD57-41A6-A59E-7D110551A760"; + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IUTF16Marshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IUTF16Marshalling.cs new file mode 100644 index 0000000000000..2ef5534aa6b23 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IUTF16Marshalling.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; + +namespace SharedTypes.ComInterfaces +{ + [Guid(_guid)] + [GeneratedComInterface(StringMarshalling = StringMarshalling.Utf16)] + internal partial interface IUTF16Marshalling + { + public string GetString(); + + public void SetString(string value); + + public const string _guid = "E11D5F3E-DD57-41A6-A59E-7D110551A760"; + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IUTF8Marshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IUTF8Marshalling.cs new file mode 100644 index 0000000000000..2689425abd506 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IUTF8Marshalling.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; + +namespace SharedTypes.ComInterfaces +{ + [Guid(_guid)] + [GeneratedComInterface(StringMarshalling = StringMarshalling.Utf8)] + internal partial interface IUTF8Marshalling + { + public string GetString(); + + public void SetString(string value); + + public const string _guid = "E11D5F3E-DD57-41A6-A59E-7D110551A760"; + } +}