diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs
index f2d6c30d7cb7b..4165e7a77ec8a 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs
@@ -242,7 +242,8 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
AttributeData? generatedComAttribute = null;
foreach (var attr in symbol.ContainingType.GetAttributes())
{
- if (generatedComAttribute is not null && attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)
+ if (generatedComAttribute is null
+ && attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)
{
generatedComAttribute = attr;
}
@@ -256,8 +257,23 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
generatorDiagnostics.ReportConfigurationNotSupported(lcidConversionAttr, nameof(TypeNames.LCIDConversionAttribute));
}
+ var generatedComInterfaceAttributeData = new InteropAttributeCompilationData();
+ if (generatedComAttribute is not null)
+ {
+ var args = generatedComAttribute.NamedArguments.ToImmutableDictionary();
+ generatedComInterfaceAttributeData = generatedComInterfaceAttributeData.WithValuesFromNamedArguments(args);
+ }
// Create the stub.
- var signatureContext = SignatureContext.Create(symbol, DefaultMarshallingInfoParser.Create(environment, generatorDiagnostics, symbol, new InteropAttributeCompilationData(), generatedComAttribute), environment, typeof(VtableIndexStubGenerator).Assembly);
+ var signatureContext = SignatureContext.Create(
+ symbol,
+ DefaultMarshallingInfoParser.Create(
+ environment,
+ generatorDiagnostics,
+ symbol,
+ generatedComInterfaceAttributeData,
+ generatedComAttribute),
+ environment,
+ typeof(VtableIndexStubGenerator).Assembly);
if (!symbol.MethodImplementationFlags.HasFlag(MethodImplAttributes.PreserveSig))
{
diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs
index 410166893a514..7629c68e9713c 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingInfoParser.cs
@@ -73,7 +73,7 @@ public interface IMarshallingInfoAttributeParser
}
///
- /// A provider of marshalling info based only on the managed type any any previously parsed use-site attribute information
+ /// A provider of marshalling info based only on the managed type and any previously parsed use-site attribute information
///
public interface ITypeBasedMarshallingInfoProvider
{
diff --git a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs
index 5f82435829d89..c218273a37848 100644
--- a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs
+++ b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs
@@ -370,6 +370,8 @@ public GeneratedComClassAttribute() { }
public partial class GeneratedComInterfaceAttribute : System.Attribute
{
public GeneratedComInterfaceAttribute() { }
+ public StringMarshalling StringMarshalling { get { throw null; } set { } }
+ public Type? StringMarshallingCustomType { get { throw null; } set { } }
}
[System.CLSCompliantAttribute(false)]
public partial interface IComExposedClass
diff --git a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/GeneratedComInterfaceAttribute.cs b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/GeneratedComInterfaceAttribute.cs
index 7d81e174d9666..cbe1a1508ba96 100644
--- a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/GeneratedComInterfaceAttribute.cs
+++ b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/GeneratedComInterfaceAttribute.cs
@@ -6,5 +6,26 @@ namespace System.Runtime.InteropServices.Marshalling
[AttributeUsage(AttributeTargets.Interface)]
public class GeneratedComInterfaceAttribute : Attribute
{
+ ///
+ /// Gets or sets how to marshal string arguments to all methods on the interface.
+ /// If the attributed interface inherits from another interface with ,
+ /// it must have the same values for and .
+ ///
+ ///
+ /// If this field is set to a value other than ,
+ /// must not be specified.
+ ///
+ public StringMarshalling StringMarshalling { get; set; }
+
+ ///
+ /// Gets or sets the used to control how string arguments are marshalled for all methods on the interface.
+ /// If the attributed interface inherits from another interface with ,
+ /// it must have the same values for and .
+ ///
+ ///
+ /// If this field is specified, must not be specified
+ /// or must be set to .
+ ///
+ public Type? StringMarshallingCustomType { get; set; }
}
}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/StringMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/StringMarshallingTests.cs
new file mode 100644
index 0000000000000..175296047f845
--- /dev/null
+++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/StringMarshallingTests.cs
@@ -0,0 +1,111 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using SharedTypes.ComInterfaces;
+using Xunit;
+
+namespace ComInterfaceGenerator.Tests
+{
+ public unsafe partial class StringMarshallingTests
+ {
+ [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_utf8_marshalling")]
+ public static partial void* NewIUtf8Marshalling();
+
+ [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_utf16_marshalling")]
+ public static partial void* NewIUtf16Marshalling();
+
+ [GeneratedComClass]
+ internal partial class Utf8MarshalledClass : IUTF8Marshalling
+ {
+ string _data = "Hello, World!";
+
+ public string GetString() => _data;
+ public void SetString(string value) => _data = value;
+ }
+
+ [GeneratedComClass]
+ internal partial class Utf16MarshalledClass : IUTF16Marshalling
+ {
+ string _data = "Hello, World!";
+
+ public string GetString() => _data;
+ public void SetString(string value) => _data = value;
+ }
+
+ [GeneratedComClass]
+ internal partial class CustomUtf16MarshalledClass : ICustomStringMarshallingUtf16
+ {
+ string _data = "Hello, World!";
+
+ public string GetString() => _data;
+ public void SetString(string value) => _data = value;
+ }
+
+ [Fact]
+ public void ValidateStringMarshallingRCW()
+ {
+ var cw = new StrategyBasedComWrappers();
+ var utf8 = NewIUtf8Marshalling();
+ IUTF8Marshalling obj8 = (IUTF8Marshalling)cw.GetOrCreateObjectForComInstance((nint)utf8, CreateObjectFlags.None);
+ string value = obj8.GetString();
+ Assert.Equal("Hello, World!", value);
+ obj8.SetString("TestString");
+ value = obj8.GetString();
+ Assert.Equal("TestString", value);
+
+ var utf16 = NewIUtf16Marshalling();
+ IUTF16Marshalling obj16 = (IUTF16Marshalling)cw.GetOrCreateObjectForComInstance((nint)utf16, CreateObjectFlags.None);
+ Assert.Equal("Hello, World!", obj16.GetString());
+ obj16.SetString("TestString");
+ Assert.Equal("TestString", obj16.GetString());
+
+ var utf16custom = NewIUtf16Marshalling();
+ ICustomStringMarshallingUtf16 objCustom = (ICustomStringMarshallingUtf16)cw.GetOrCreateObjectForComInstance((nint)utf16custom, CreateObjectFlags.None);
+ Assert.Equal("Hello, World!", objCustom.GetString());
+ objCustom.SetString("TestString");
+ Assert.Equal("TestString", objCustom.GetString());
+ }
+
+ [Fact]
+ [ActiveIssue("https://github.com/dotnet/runtime/issues/85795", TargetFrameworkMonikers.Any)]
+ public void RcwToCcw()
+ {
+ var cw = new StrategyBasedComWrappers();
+
+ var utf8 = new Utf8MarshalledClass();
+ var utf8ComInstance = cw.GetOrCreateComInterfaceForObject(utf8, CreateComInterfaceFlags.None);
+ var utf8ComObject = (IUTF8Marshalling)cw.GetOrCreateObjectForComInstance(utf8ComInstance, CreateObjectFlags.None);
+ Assert.Equal(utf8.GetString(), utf8ComObject.GetString());
+ utf8.SetString("Set from CLR object");
+ Assert.Equal(utf8.GetString(), utf8ComObject.GetString());
+ utf8ComObject.SetString("Set from COM object");
+ Assert.Equal(utf8.GetString(), utf8ComObject.GetString());
+
+ var utf16 = new Utf16MarshalledClass();
+ var utf16ComInstance = cw.GetOrCreateComInterfaceForObject(utf16, CreateComInterfaceFlags.None);
+ var utf16ComObject = (IUTF16Marshalling)cw.GetOrCreateObjectForComInstance(utf16ComInstance, CreateObjectFlags.None);
+ Assert.Equal(utf16.GetString(), utf16ComObject.GetString());
+ utf16.SetString("Set from CLR object");
+ Assert.Equal(utf16.GetString(), utf16ComObject.GetString());
+ utf16ComObject.SetString("Set from COM object");
+ Assert.Equal(utf16.GetString(), utf16ComObject.GetString());
+
+ var customUtf16 = new CustomUtf16MarshalledClass();
+ var customUtf16ComInstance = cw.GetOrCreateComInterfaceForObject(customUtf16, CreateComInterfaceFlags.None);
+ var customUtf16ComObject = (ICustomStringMarshallingUtf16)cw.GetOrCreateObjectForComInstance(customUtf16ComInstance, CreateObjectFlags.None);
+ Assert.Equal(customUtf16.GetString(), customUtf16ComObject.GetString());
+ customUtf16.SetString("Set from CLR object");
+ Assert.Equal(customUtf16.GetString(), customUtf16ComObject.GetString());
+ customUtf16ComObject.SetString("Set from COM object");
+ Assert.Equal(customUtf16.GetString(), customUtf16ComObject.GetString());
+ }
+ }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/StringMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/StringMarshalling.cs
new file mode 100644
index 0000000000000..9059f53c4cc8a
--- /dev/null
+++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/StringMarshalling.cs
@@ -0,0 +1,201 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+using System.Text;
+using System.Threading.Tasks;
+using SharedTypes.ComInterfaces;
+using static System.Runtime.InteropServices.ComWrappers;
+
+namespace NativeExports.ComInterfaceGenerator
+{
+ public unsafe class StringMarshalling
+ {
+ [UnmanagedCallersOnly(EntryPoint = "new_utf8_marshalling")]
+ public static void* CreateUtf8ComObject()
+ {
+ MyComWrapper cw = new();
+ var myObject = new Utf8Implementation();
+ nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);
+
+ return (void*)ptr;
+ }
+
+ [UnmanagedCallersOnly(EntryPoint = "new_utf16_marshalling")]
+ public static void* CreateUtf16ComObject()
+ {
+ MyComWrapper cw = new();
+ var myObject = new Utf16Implementation();
+ nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);
+
+ return (void*)ptr;
+ }
+
+ class MyComWrapper : ComWrappers
+ {
+ static void* _s_comInterface1VTable = null;
+ static void* _s_comInterface2VTable = null;
+ static void* S_Utf8VTable
+ {
+ get
+ {
+ if (_s_comInterface1VTable != null)
+ return _s_comInterface1VTable;
+ void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(GetAndSetInt), sizeof(void*) * 5);
+ GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease);
+ vtable[0] = (void*)fpQueryInterface;
+ vtable[1] = (void*)fpAddReference;
+ vtable[2] = (void*)fpRelease;
+ vtable[3] = (delegate* unmanaged)&Utf8Implementation.ABI.GetStringUtf8;
+ vtable[4] = (delegate* unmanaged)&Utf8Implementation.ABI.SetStringUtf8;
+ _s_comInterface1VTable = vtable;
+ return _s_comInterface1VTable;
+ }
+ }
+ static void* S_Utf16VTable
+ {
+ get
+ {
+ if (_s_comInterface2VTable != null)
+ return _s_comInterface2VTable;
+ void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(GetAndSetInt), sizeof(void*) * 5);
+ GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease);
+ vtable[0] = (void*)fpQueryInterface;
+ vtable[1] = (void*)fpAddReference;
+ vtable[2] = (void*)fpRelease;
+ vtable[3] = (delegate* unmanaged)&Utf16Implementation.ABI.GetStringUtf16;
+ vtable[4] = (delegate* unmanaged)&Utf16Implementation.ABI.SetStringUtf16;
+ _s_comInterface2VTable = vtable;
+ return _s_comInterface2VTable;
+ }
+ }
+
+ protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
+ {
+ if (obj is IUTF8Marshalling)
+ {
+ ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(Utf8Implementation), sizeof(ComInterfaceEntry));
+ comInterfaceEntry->IID = new Guid(IUTF8Marshalling._guid);
+ comInterfaceEntry->Vtable = (nint)S_Utf8VTable;
+ count = 1;
+ return comInterfaceEntry;
+ }
+ else if (obj is IUTF16Marshalling)
+ {
+ ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(Utf16Implementation), sizeof(ComInterfaceEntry));
+ comInterfaceEntry->IID = new Guid(IUTF16Marshalling._guid);
+ comInterfaceEntry->Vtable = (nint)S_Utf16VTable;
+ count = 1;
+ return comInterfaceEntry;
+ }
+ count = 0;
+ return null;
+ }
+
+ protected override object? CreateObject(nint externalComObject, CreateObjectFlags flags) => throw new NotImplementedException();
+ protected override void ReleaseObjects(IEnumerable objects) => throw new NotImplementedException();
+ }
+
+ class Utf8Implementation : IUTF8Marshalling
+ {
+ string _data = "Hello, World!";
+
+ string IUTF8Marshalling.GetString()
+ {
+ return _data;
+ }
+ void IUTF8Marshalling.SetString(string x)
+ {
+ _data = x;
+ }
+
+ // Provides function pointers in the COM format to use in COM VTables
+ public static class ABI
+ {
+ [UnmanagedCallersOnly]
+ public static int GetStringUtf8(void* @this, byte** value)
+ {
+ try
+ {
+ string currValue = ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).GetString();
+ *value = Utf8StringMarshaller.ConvertToUnmanaged(currValue);
+ return 0;
+ }
+ catch (Exception e)
+ {
+ return e.HResult;
+ }
+ }
+
+ [UnmanagedCallersOnly]
+ public static int SetStringUtf8(void* @this, byte* newValue)
+ {
+ try
+ {
+ string value = Utf8StringMarshaller.ConvertToManaged(newValue);
+ ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).SetString(value);
+ return 0;
+ }
+ catch (Exception e)
+ {
+ return e.HResult;
+ }
+ }
+ }
+ }
+
+ class Utf16Implementation : IUTF16Marshalling
+ {
+ string _data = "Hello, World!";
+
+ string IUTF16Marshalling.GetString()
+ {
+ return _data;
+ }
+ void IUTF16Marshalling.SetString(string x)
+ {
+ _data = x;
+ }
+
+ // Provides function pointers in the COM format to use in COM VTables
+ public static class ABI
+ {
+ [UnmanagedCallersOnly]
+ public static int GetStringUtf16(void* @this, ushort** value)
+ {
+ try
+ {
+ string currValue = ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).GetString();
+ *value = Utf16StringMarshaller.ConvertToUnmanaged(currValue);
+ return 0;
+ }
+ catch (Exception e)
+ {
+ return e.HResult;
+ }
+ }
+
+ [UnmanagedCallersOnly]
+ public static int SetStringUtf16(void* @this, ushort* newValue)
+ {
+ try
+ {
+ string value = Utf16StringMarshaller.ConvertToManaged(newValue);
+ ComInterfaceDispatch.GetInstance((ComInterfaceDispatch*)@this).SetString(value);
+ return 0;
+ }
+ catch (Exception e)
+ {
+ return e.HResult;
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/ICustomStringMarshallingUtf16.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/ICustomStringMarshallingUtf16.cs
new file mode 100644
index 0000000000000..d792b62d6d5a9
--- /dev/null
+++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/ICustomStringMarshallingUtf16.cs
@@ -0,0 +1,20 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+
+namespace SharedTypes.ComInterfaces
+{
+ [Guid(_guid)]
+ [GeneratedComInterface(StringMarshalling = StringMarshalling.Custom, StringMarshallingCustomType = typeof(Utf16StringMarshaller))]
+ internal partial interface ICustomStringMarshallingUtf16
+ {
+ public string GetString();
+
+ public void SetString(string value);
+
+ public const string _guid = "E11D5F3E-DD57-41A6-A59E-7D110551A760";
+ }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IUTF16Marshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IUTF16Marshalling.cs
new file mode 100644
index 0000000000000..2ef5534aa6b23
--- /dev/null
+++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IUTF16Marshalling.cs
@@ -0,0 +1,20 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+
+namespace SharedTypes.ComInterfaces
+{
+ [Guid(_guid)]
+ [GeneratedComInterface(StringMarshalling = StringMarshalling.Utf16)]
+ internal partial interface IUTF16Marshalling
+ {
+ public string GetString();
+
+ public void SetString(string value);
+
+ public const string _guid = "E11D5F3E-DD57-41A6-A59E-7D110551A760";
+ }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IUTF8Marshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IUTF8Marshalling.cs
new file mode 100644
index 0000000000000..2689425abd506
--- /dev/null
+++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IUTF8Marshalling.cs
@@ -0,0 +1,20 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+
+namespace SharedTypes.ComInterfaces
+{
+ [Guid(_guid)]
+ [GeneratedComInterface(StringMarshalling = StringMarshalling.Utf8)]
+ internal partial interface IUTF8Marshalling
+ {
+ public string GetString();
+
+ public void SetString(string value);
+
+ public const string _guid = "E11D5F3E-DD57-41A6-A59E-7D110551A760";
+ }
+}