diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs new file mode 100644 index 0000000000000..5085cbc8090c4 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs @@ -0,0 +1,231 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using System.IO; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using Microsoft.CodeAnalysis.CSharp; + +namespace Microsoft.Interop +{ + [Generator] + public class ComClassGenerator : IIncrementalGenerator + { + private sealed record ComClassInfo(string ClassName, ContainingSyntaxContext ContainingSyntaxContext, ContainingSyntax ClassSyntax, SequenceEqualImmutableArray ImplementedInterfacesNames); + public void Initialize(IncrementalGeneratorInitializationContext context) + { + // Get all types with the [GeneratedComClassAttribute] attribute. + var attributedClasses = context.SyntaxProvider + .ForAttributeWithMetadataName( + TypeNames.GeneratedComClassAttribute, + static (node, ct) => node is ClassDeclarationSyntax, + static (context, ct) => + { + var type = (INamedTypeSymbol)context.TargetSymbol; + var syntax = (ClassDeclarationSyntax)context.TargetNode; + ImmutableArray.Builder names = ImmutableArray.CreateBuilder(); + foreach (INamedTypeSymbol iface in type.AllInterfaces) + { + if (iface.GetAttributes().Any(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)) + { + names.Add(iface.ToDisplayString()); + } + } + return new ComClassInfo( + type.ToDisplayString(), + new ContainingSyntaxContext(syntax), + new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList), + new(names.ToImmutable())); + }); + + var className = attributedClasses.Select(static (info, ct) => info.ClassName); + + var classInfoType = attributedClasses + .Select(static (info, ct) => new { info.ClassName, info.ImplementedInterfacesNames }) + .Select(static (info, ct) => GenerateClassInfoType(info.ImplementedInterfacesNames.Array).NormalizeWhitespace()); + + var attribute = attributedClasses + .Select(static (info, ct) => new { info.ContainingSyntaxContext, info.ClassSyntax }) + .Select(static (info, ct) => GenerateClassInfoAttributeOnUserType(info.ContainingSyntaxContext, info.ClassSyntax).NormalizeWhitespace()); + + context.RegisterSourceOutput(className.Zip(classInfoType).Zip(attribute), static (context, classInfo) => + { + var ((className, classInfoType), attribute) = classInfo; + StringWriter writer = new(); + writer.WriteLine(classInfoType.ToFullString()); + writer.WriteLine(); + writer.WriteLine(attribute); + context.AddSource(className, writer.ToString()); + }); + } + + private const string ClassInfoTypeName = "ComClassInformation"; + + private static readonly AttributeSyntax s_comExposedClassAttributeTemplate = + Attribute( + GenericName(TypeNames.ComExposedClassAttribute) + .AddTypeArgumentListArguments( + IdentifierName(ClassInfoTypeName))); + private static MemberDeclarationSyntax GenerateClassInfoAttributeOnUserType(ContainingSyntaxContext containingSyntaxContext, ContainingSyntax classSyntax) => + containingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier( + TypeDeclaration(classSyntax.TypeKind, classSyntax.Identifier) + .WithModifiers(classSyntax.Modifiers) + .WithTypeParameterList(classSyntax.TypeParameters) + .AddAttributeLists(AttributeList(SingletonSeparatedList(s_comExposedClassAttributeTemplate)))); + private static ClassDeclarationSyntax GenerateClassInfoType(ImmutableArray implementedInterfaces) + { + const string vtablesField = "s_vtables"; + const string vtablesLocal = "vtables"; + const string detailsTempLocal = "details"; + const string countIdentifier = "count"; + var typeDeclaration = ClassDeclaration(ClassInfoTypeName) + .AddModifiers( + Token(SyntaxKind.FileKeyword), + Token(SyntaxKind.SealedKeyword), + Token(SyntaxKind.UnsafeKeyword)) + .AddBaseListTypes(SimpleBaseType(ParseTypeName(TypeNames.IComExposedClass))) + .AddMembers( + FieldDeclaration( + VariableDeclaration( + PointerType( + ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)), + SingletonSeparatedList(VariableDeclarator(vtablesField)))) + .AddModifiers( + Token(SyntaxKind.PrivateKeyword), + Token(SyntaxKind.StaticKeyword), + Token(SyntaxKind.VolatileKeyword))); + List vtableInitializationBlock = new() + { + // ComInterfaceEntry* vtables = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(), sizeof(ComInterfaceEntry) * ); + LocalDeclarationStatement( + VariableDeclaration( + PointerType( + ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)), + SingletonSeparatedList( + VariableDeclarator(vtablesLocal) + .WithInitializer(EqualsValueClause( + CastExpression( + PointerType( + ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)), + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Runtime_CompilerServices_RuntimeHelpers), + IdentifierName("AllocateTypeAssociatedMemory"))) + .AddArgumentListArguments( + Argument(TypeOfExpression(IdentifierName(ClassInfoTypeName))), + Argument( + BinaryExpression( + SyntaxKind.MultiplyExpression, + SizeOfExpression(ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)), + LiteralExpression( + SyntaxKind.NumericLiteralExpression, + Literal(implementedInterfaces.Length))))))))))), + // IIUnknownDerivedDetails details; + LocalDeclarationStatement( + VariableDeclaration( + ParseTypeName(TypeNames.IIUnknownDerivedDetails), + SingletonSeparatedList( + VariableDeclarator(detailsTempLocal)))) + }; + for (int i = 0; i < implementedInterfaces.Length; i++) + { + string ifaceName = implementedInterfaces[i]; + + // details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof().TypeHandle); + vtableInitializationBlock.Add( + ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(detailsTempLocal), + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.StrategyBasedComWrappers), + IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")), + IdentifierName("GetIUnknownDerivedDetails")), + ArgumentList( + SingletonSeparatedList( + Argument( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + TypeOfExpression(ParseName(ifaceName)), + IdentifierName("TypeHandle"))))))))); + // vtable[i] = new() { IID = details.Iid, Vtable = details.ManagedVirtualMethodTable }; + vtableInitializationBlock.Add( + ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + ElementAccessExpression( + IdentifierName(vtablesLocal), + BracketedArgumentList( + SingletonSeparatedList( + Argument( + LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(i)))))), + ImplicitObjectCreationExpression( + ArgumentList(), + InitializerExpression(SyntaxKind.ObjectInitializerExpression, + SeparatedList( + new ExpressionSyntax[] + { + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName("IID"), + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(detailsTempLocal), + IdentifierName("Iid"))), + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName("Vtable"), + CastExpression( + IdentifierName("nint"), + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(detailsTempLocal), + IdentifierName("ManagedVirtualMethodTable")))) + })))))); + } + + // s_vtable = vtable; + vtableInitializationBlock.Add( + ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(vtablesField), + IdentifierName(vtablesLocal)))); + + BlockSyntax getComInterfaceEntriesMethodBody = Block( + // count = ; + ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(countIdentifier), + LiteralExpression(SyntaxKind.NumericLiteralExpression, + Literal(implementedInterfaces.Length)))), + // if (s_vtable == null) + // { initializer block } + IfStatement( + BinaryExpression(SyntaxKind.EqualsExpression, + IdentifierName(vtablesField), + LiteralExpression(SyntaxKind.NullLiteralExpression)), + Block(vtableInitializationBlock)), + // return s_vtable; + ReturnStatement(IdentifierName(vtablesField))); + + typeDeclaration = typeDeclaration.AddMembers( + // public static unsafe ComWrappers.ComInterfaceDispatch* GetComInterfaceEntries(out int count) + // { body } + MethodDeclaration( + PointerType( + ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)), + "GetComInterfaceEntries") + .AddParameterListParameters( + Parameter(Identifier(countIdentifier)) + .WithType(PredefinedType(Token(SyntaxKind.IntKeyword))) + .AddModifiers(Token(SyntaxKind.OutKeyword))) + .WithBody(getComInterfaceEntriesMethodBody) + .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword))); + + return typeDeclaration; + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs index f773acbb3eeb3..17704038a19bc 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs @@ -45,7 +45,7 @@ public static class StepNames public void Initialize(IncrementalGeneratorInitializationContext context) { - // Get all methods with the [GeneratedComInterface] attribute. + // Get all types with the [GeneratedComInterface] attribute. var attributedInterfaces = context.SyntaxProvider .ForAttributeWithMetadataName( TypeNames.GeneratedComInterfaceAttribute, @@ -62,7 +62,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return new { data.Syntax, data.Symbol, Diagnostic = diagnostic }; }); - // Split the methods we want to generate and the ones we don't into two separate groups. + // Split the types we want to generate and the ones we don't into two separate groups. var interfacesToGenerate = interfacesWithDiagnostics.Where(static data => data.Diagnostic is null); var invalidTypeDiagnostics = interfacesWithDiagnostics.Where(static data => data.Diagnostic is not null); @@ -726,7 +726,7 @@ private static ClassDeclarationSyntax GenerateInterfaceInformation(ComInterfaceC .WithExpressionBody( ArrowExpressionClause( ConditionalExpression( - BinaryExpression(SyntaxKind.EqualsExpression, + BinaryExpression(SyntaxKind.NotEqualsExpression, IdentifierName(vtableFieldName), LiteralExpression(SyntaxKind.NullLiteralExpression)), IdentifierName(vtableFieldName), diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs index 870368e4a5f98..fb0dd80ca7f1a 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs @@ -31,6 +31,18 @@ internal static class IncrementalValuesProviderExtensions }); } + /// + /// Format the syntax nodes in the given provider such that we will not re-normalize if the input nodes have not changed. + /// + /// A syntax node kind. + /// The input nodes + /// A provider of the formatted syntax nodes. + /// + /// Normalizing whitespace is very expensive, so if a generator will have cases where the input information into the step + /// that creates may change but the results of will say the same, + /// using this method to format the code in a separate step will reduce the amount of work the generator repeats when the + /// output code will not change. + /// public static IncrementalValuesProvider SelectNormalized(this IncrementalValuesProvider provider) where TNode : SyntaxNode { diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs index c74ca0b8a1d62..fa271c817da5e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs @@ -117,13 +117,22 @@ public static string MarshalEx(InteropGenerationOptions options) public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceDispatch = "System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch"; + public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry = "System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry"; + + public const string StrategyBasedComWrappers = "System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers"; + public const string IIUnknownInterfaceType = "System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType"; public const string IUnknownDerivedAttribute = "System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute"; + public const string IIUnknownDerivedDetails = "System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails"; public const string ComWrappersUnwrapper = "System.Runtime.InteropServices.Marshalling.ComWrappersUnwrapper"; public const string UnmanagedObjectUnwrapperAttribute = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapperAttribute`1"; public const string IUnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.IUnmanagedObjectUnwrapper"; public const string UnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapper"; + + public const string GeneratedComClassAttribute = "System.Runtime.InteropServices.Marshalling.GeneratedComClassAttribute"; + public const string ComExposedClassAttribute = "System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute"; + public const string IComExposedClass = "System.Runtime.InteropServices.Marshalling.IComExposedClass"; } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ComExposedClassAttribute.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ComExposedClassAttribute.cs new file mode 100644 index 0000000000000..5fb0d0719a698 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ComExposedClassAttribute.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace System.Runtime.InteropServices.Marshalling +{ + /// + /// An attribute to mark this class as a type whose instances should be exposed to COM. + /// + /// The type that provides information about how to expose the attributed type to COM. + [AttributeUsage(AttributeTargets.Class, Inherited = false)] + public sealed class ComExposedClassAttribute : Attribute, IComExposedDetails + where T : IComExposedClass + { + /// + public unsafe ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) => T.GetComInterfaceEntries(out count); + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ComObject.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ComObject.cs index c66d4a72c7b97..2fecd065244d7 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ComObject.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ComObject.cs @@ -90,7 +90,7 @@ private bool LookUpVTableInfo(RuntimeTypeHandle handle, out IIUnknownCacheStrate qiHResult = 0; if (!CacheStrategy.TryGetTableInfo(handle, out result)) { - IUnknownDerivedDetails? details = InterfaceDetailsStrategy.GetIUnknownDerivedDetails(handle); + IIUnknownDerivedDetails? details = InterfaceDetailsStrategy.GetIUnknownDerivedDetails(handle); if (details is null) { return false; diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/DefaultCaching.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/DefaultCaching.cs index 07152e03414a9..5b86d12c40254 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/DefaultCaching.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/DefaultCaching.cs @@ -11,7 +11,7 @@ internal sealed unsafe class DefaultCaching : IIUnknownCacheStrategy // [TODO] Implement some smart/thread-safe caching private readonly Dictionary _cache = new(); - IIUnknownCacheStrategy.TableInfo IIUnknownCacheStrategy.ConstructTableInfo(RuntimeTypeHandle handle, IUnknownDerivedDetails details, void* ptr) + IIUnknownCacheStrategy.TableInfo IIUnknownCacheStrategy.ConstructTableInfo(RuntimeTypeHandle handle, IIUnknownDerivedDetails details, void* ptr) { var obj = (void***)ptr; return new IIUnknownCacheStrategy.TableInfo() diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/DefaultIUnknownInterfaceDetailsStrategy.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/DefaultIUnknownInterfaceDetailsStrategy.cs index b1c3ff0f2afe5..33f8bc1d5ec39 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/DefaultIUnknownInterfaceDetailsStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/DefaultIUnknownInterfaceDetailsStrategy.cs @@ -7,9 +7,14 @@ internal sealed class DefaultIUnknownInterfaceDetailsStrategy : IIUnknownInterfa { public static readonly IIUnknownInterfaceDetailsStrategy Instance = new DefaultIUnknownInterfaceDetailsStrategy(); - public IUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type) + public IComExposedDetails? GetComExposedTypeDetails(RuntimeTypeHandle type) { - return IUnknownDerivedDetails.GetFromAttribute(type); + return IComExposedDetails.GetFromAttribute(type); + } + + public IIUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type) + { + return IIUnknownDerivedDetails.GetFromAttribute(type); } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComClassAttribute.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComClassAttribute.cs new file mode 100644 index 0000000000000..13fe839a57645 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComClassAttribute.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace System.Runtime.InteropServices.Marshalling +{ + [AttributeUsage(AttributeTargets.Class)] + public sealed class GeneratedComClassAttribute : Attribute + { + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComInterfaceAttribute.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComInterfaceAttribute.cs index 09c81e4151222..7d81e174d9666 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComInterfaceAttribute.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComInterfaceAttribute.cs @@ -3,8 +3,6 @@ namespace System.Runtime.InteropServices.Marshalling { - public interface IComObjectWrapper { } - [AttributeUsage(AttributeTargets.Interface)] public class GeneratedComInterfaceAttribute : Attribute { diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComWrappersBase.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComWrappersBase.cs deleted file mode 100644 index 7f3f420b1ce8c..0000000000000 --- a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComWrappersBase.cs +++ /dev/null @@ -1,51 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -// This type is for the COM source generator and implements part of the COM-specific interactions. -// This API need to be exposed to implement the COM source generator in one form or another. - -using System.Collections; - -namespace System.Runtime.InteropServices.Marshalling -{ - public abstract class GeneratedComWrappersBase : ComWrappers - { - protected virtual IIUnknownInterfaceDetailsStrategy CreateInterfaceDetailsStrategy() => DefaultIUnknownInterfaceDetailsStrategy.Instance; - - protected virtual IIUnknownStrategy CreateIUnknownStrategy() => FreeThreadedStrategy.Instance; - - protected virtual IIUnknownCacheStrategy CreateCacheStrategy() => new DefaultCaching(); - - protected override sealed unsafe object CreateObject(nint externalComObject, CreateObjectFlags flags) - { - if (flags.HasFlag(CreateObjectFlags.TrackerObject) - || flags.HasFlag(CreateObjectFlags.Aggregation)) - { - throw new NotSupportedException(); - } - - var rcw = new ComObject(CreateInterfaceDetailsStrategy(), CreateIUnknownStrategy(), CreateCacheStrategy(), (void*)externalComObject); - if (flags.HasFlag(CreateObjectFlags.UniqueInstance)) - { - // Set value on MyComObject to enable the FinalRelease option. - // This could also be achieved through an internal factory - // function on ComObject type. - } - return rcw; - } - - protected override sealed void ReleaseObjects(IEnumerable objects) - { - throw new NotImplementedException(); - } - - public ComObject GetOrCreateUniqueObjectForComInstance(nint comInstance, CreateObjectFlags flags) - { - if (flags.HasFlag(CreateObjectFlags.Unwrap)) - { - throw new ArgumentException("Cannot create a unique object if unwrapping a ComWrappers-based COM object is requested.", nameof(flags)); - } - return (ComObject)GetOrCreateObjectForComInstance(comInstance, flags | CreateObjectFlags.UniqueInstance); - } - } -} diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IComExposedClass.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IComExposedClass.cs new file mode 100644 index 0000000000000..070f4d77c98cb --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IComExposedClass.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Runtime.InteropServices.Marshalling +{ + /// + /// Type level information for managed class types exposed to COM. + /// + public unsafe interface IComExposedClass + { + /// + /// Get the COM interface information to provide to a instance to expose this type to COM. + /// + /// The number of COM interfaces this type implements. + /// The interface entry information for the interfaces the type implements. + public static abstract ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count); + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IComExposedDetails.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IComExposedDetails.cs new file mode 100644 index 0000000000000..4c7c419879506 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IComExposedDetails.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Reflection; + +namespace System.Runtime.InteropServices.Marshalling +{ + /// + /// Details about a managed class type exposed to COM. + /// + public unsafe interface IComExposedDetails + { + /// + /// Get the COM interface information to provide to a instance to expose this type to COM. + /// + /// The number of COM interfaces this type implements. + /// The interface entry information for the interfaces the type implements. + ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count); + + internal static IComExposedDetails? GetFromAttribute(RuntimeTypeHandle handle) + { + var type = Type.GetTypeFromHandle(handle); + if (type is null) + { + return null; + } + return (IComExposedDetails?)type.GetCustomAttribute(typeof(ComExposedClassAttribute<>)); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownCacheStrategy.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownCacheStrategy.cs index d043344b340d5..0784e6eea8856 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownCacheStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownCacheStrategy.cs @@ -25,7 +25,7 @@ public readonly struct TableInfo /// Pointer to the instance to query /// A instance /// True if success, otherwise false. - TableInfo ConstructTableInfo(RuntimeTypeHandle handle, IUnknownDerivedDetails interfaceDetails, void* ptr); + TableInfo ConstructTableInfo(RuntimeTypeHandle handle, IIUnknownDerivedDetails interfaceDetails, void* ptr); /// /// Get associated . diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IUnknownDerivedDetails.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownDerivedDetails.cs similarity index 82% rename from src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IUnknownDerivedDetails.cs rename to src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownDerivedDetails.cs index 0c6a9e210208f..0b36127bb6859 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IUnknownDerivedDetails.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownDerivedDetails.cs @@ -11,7 +11,7 @@ namespace System.Runtime.InteropServices.Marshalling /// /// Details for the IUnknown derived interface. /// - public interface IUnknownDerivedDetails + public interface IIUnknownDerivedDetails { /// /// Interface ID. @@ -28,14 +28,14 @@ public interface IUnknownDerivedDetails /// public unsafe void** ManagedVirtualMethodTable { get; } - internal static IUnknownDerivedDetails? GetFromAttribute(RuntimeTypeHandle handle) + internal static IIUnknownDerivedDetails? GetFromAttribute(RuntimeTypeHandle handle) { var type = Type.GetTypeFromHandle(handle); if (type is null) { return null; } - return (IUnknownDerivedDetails?)type.GetCustomAttribute(typeof(IUnknownDerivedAttribute<,>)); + return (IIUnknownDerivedDetails?)type.GetCustomAttribute(typeof(IUnknownDerivedAttribute<,>)); } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownInterfaceDetailsStrategy.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownInterfaceDetailsStrategy.cs index cd9f84931448b..44921171edf71 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownInterfaceDetailsStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownInterfaceDetailsStrategy.cs @@ -16,6 +16,13 @@ public interface IIUnknownInterfaceDetailsStrategy /// /// RuntimeTypeHandle instance /// Details if type is known. - IUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type); + IIUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type); + + /// + /// Given a get the details about the type that are exposed to COM. + /// + /// RuntimeTypeHandle instance + /// Details if type is known. + IComExposedDetails? GetComExposedTypeDetails(RuntimeTypeHandle type); } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownInterfaceType.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownInterfaceType.cs index a600826201d70..73b41fcde9a82 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownInterfaceType.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownInterfaceType.cs @@ -6,9 +6,19 @@ namespace System.Runtime.InteropServices.Marshalling { + /// + /// Type level information for an IUnknown-derived interface. + /// public unsafe interface IIUnknownInterfaceType { + /// + /// The Interface ID (IID) for the interface. + /// public abstract static Guid Iid { get; } + + /// + /// A pointer to the virtual method table to enable unmanaged callers to call a managed implementation of the interface. + /// public abstract static void** ManagedVirtualMethodTable { get; } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IUnknownDerivedAttribute.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IUnknownDerivedAttribute.cs index 8624e00212665..02f8d78903c25 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IUnknownDerivedAttribute.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IUnknownDerivedAttribute.cs @@ -6,8 +6,13 @@ namespace System.Runtime.InteropServices.Marshalling { - [AttributeUsage(AttributeTargets.Interface)] - public class IUnknownDerivedAttribute : Attribute, IUnknownDerivedDetails + /// + /// An attribute to mark this interface as a managed representation of an IUnknown-derived interface. + /// + /// The type that provides type-level information about the interface. + /// The type to use for calling from managed callers to unmanaged implementations of the interface. + [AttributeUsage(AttributeTargets.Interface, Inherited = false)] + public class IUnknownDerivedAttribute : Attribute, IIUnknownDerivedDetails where T : IIUnknownInterfaceType { public IUnknownDerivedAttribute() diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/StrategyBasedComWrappers.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/StrategyBasedComWrappers.cs index e6e3442abc69f..e06f50b3d476b 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/StrategyBasedComWrappers.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/StrategyBasedComWrappers.cs @@ -1,14 +1,12 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -// Types that are only needed for the VTable source generator or to provide abstract concepts that the COM generator would use under the hood. -// These are types that we can exclude from the API proposals and either inline into the generated code, provide as file-scoped types, or not provide publicly (indicated by comments on each type). - using System.Collections; +using System.Reflection; namespace System.Runtime.InteropServices.Marshalling { - public abstract class StrategyBasedComWrappers : InteropServices.ComWrappers + public class StrategyBasedComWrappers : ComWrappers { public static IIUnknownInterfaceDetailsStrategy DefaultIUnknownInterfaceDetailsStrategy { get; } = Marshalling.DefaultIUnknownInterfaceDetailsStrategy.Instance; @@ -22,6 +20,16 @@ public abstract class StrategyBasedComWrappers : InteropServices.ComWrappers protected virtual IIUnknownCacheStrategy CreateCacheStrategy() => CreateDefaultCacheStrategy(); + protected override sealed unsafe ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) + { + if (obj.GetType().GetCustomAttribute(typeof(ComExposedClassAttribute<>)) is IComExposedDetails details) + { + return details.GetComInterfaceEntries(out count); + } + count = 0; + return null; + } + protected override sealed unsafe object CreateObject(nint externalComObject, CreateObjectFlags flags) { if (flags.HasFlag(CreateObjectFlags.TrackerObject) diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComClassTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComClassTests.cs new file mode 100644 index 0000000000000..54818174c056a --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComClassTests.cs @@ -0,0 +1,91 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +using Xunit; + +namespace ComInterfaceGenerator.Tests +{ + unsafe partial class NativeExportsNE + { + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "set_com_object_data")] + public static partial void SetComObjectData(void* obj, int data); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_com_object_data")] + public static partial int GetComObjectData(void* obj); + } + + [GeneratedComClass] + partial class ManagedObjectExposedToCom : IComInterface1 + { + public int Data { get; set; } + int IComInterface1.GetData() => Data; + void IComInterface1.SetData(int n) => Data = n; + } + + [GeneratedComClass] + partial class DerivedComObject : ManagedObjectExposedToCom + { + } + + public unsafe class GeneratedComClassTests + { + [Fact] + public void ComInstanceProvidesInterfaceForDirectlyImplementedComInterface() + { + ManagedObjectExposedToCom obj = new(); + StrategyBasedComWrappers wrappers = new(); + nint ptr = wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None); + Assert.NotEqual(0, ptr); + var iid = typeof(IComInterface1).GUID; + Assert.Equal(0, Marshal.QueryInterface(ptr, ref iid, out nint iComInterface)); + Assert.NotEqual(0, iComInterface); + Marshal.Release(iComInterface); + Marshal.Release(ptr); + } + + [Fact] + public void ComInstanceProvidesInterfaceForIndirectlyImplementedComInterface() + { + DerivedComObject obj = new(); + StrategyBasedComWrappers wrappers = new(); + nint ptr = wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None); + Assert.NotEqual(0, ptr); + var iid = typeof(IComInterface1).GUID; + Assert.Equal(0, Marshal.QueryInterface(ptr, ref iid, out nint iComInterface)); + Assert.NotEqual(0, iComInterface); + Marshal.Release(iComInterface); + Marshal.Release(ptr); + } + + [Fact] + public void CallsToComInterfaceWriteChangesToManagedObject() + { + ManagedObjectExposedToCom obj = new(); + StrategyBasedComWrappers wrappers = new(); + void* ptr = (void*)wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None); + Assert.NotEqual(0, (nint)ptr); + obj.Data = 3; + Assert.Equal(3, obj.Data); + NativeExportsNE.SetComObjectData(ptr, 42); + Assert.Equal(42, obj.Data); + Marshal.Release((nint)ptr); + } + + [Fact] + public void CallsToComInterfaceReadChangesFromManagedObject() + { + ManagedObjectExposedToCom obj = new(); + StrategyBasedComWrappers wrappers = new(); + void* ptr = (void*)wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None); + Assert.NotEqual(0, (nint)ptr); + obj.Data = 3; + Assert.Equal(3, obj.Data); + obj.Data = 12; + Assert.Equal(obj.Data, NativeExportsNE.GetComObjectData(ptr)); + Marshal.Release((nint)ptr); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs new file mode 100644 index 0000000000000..bb464c0147c95 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; + +namespace ComInterfaceGenerator.Tests +{ + [GeneratedComInterface] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] + [Guid("2c3f9903-b586-46b1-881b-adfce9af47b1")] + public partial interface IComInterface1 + { + int GetData(); + void SetData(int n); + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwTests.cs index 0841e6e879ee5..4f41e80511f5b 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwTests.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Collections; using System.Diagnostics; using System.Linq; @@ -14,23 +13,9 @@ namespace ComInterfaceGenerator.Tests; -[GeneratedComInterface] -[InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] -[Guid("2c3f9903-b586-46b1-881b-adfce9af47b1")] -public partial interface IComInterface1 +internal unsafe partial class NativeExportsNE { - int GetData(); - void SetData(int n); -} - -internal sealed unsafe partial class MyGeneratedComWrappers : StrategyBasedComWrappers -{ - protected sealed override unsafe ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) => throw new UnreachableException("Not creating CCWs yet"); -} - -public static unsafe partial class Native -{ - [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "get_com_object")] + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_com_object")] public static partial void* NewNativeObject(); } @@ -40,8 +25,8 @@ public class RcwTests [Fact] public unsafe void CallRcwFromGeneratedComInterface() { - var ptr = Native.NewNativeObject(); // new_native_object - var cw = new MyGeneratedComWrappers(); + var ptr = NativeExportsNE.NewNativeObject(); // new_native_object + var cw = new StrategyBasedComWrappers(); var obj = cw.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None); var intObj = (IComInterface1)obj; diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs new file mode 100644 index 0000000000000..2bb8d5e187ea3 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs @@ -0,0 +1,104 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.Interop.UnitTests; +using Xunit; + +namespace ComInterfaceGenerator.Unit.Tests +{ + public class ComClassGeneratorOutputShape + { + [Fact] + public async Task SingleComClass() + { + string source = """ + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + partial interface INativeAPI + { + } + + [GeneratedComClass] + partial class C : INativeAPI {} + """; + Compilation comp = await TestUtils.CreateCompilation(source); + TestUtils.AssertPreSourceGeneratorCompilation(comp); + + var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.ComClassGenerator()); + TestUtils.AssertPostSourceGeneratorCompilation(newComp); + // We'll create one syntax tree for the new interface. + Assert.Equal(comp.SyntaxTrees.Count() + 1, newComp.SyntaxTrees.Count()); + + VerifyShape(newComp, "C"); + } + + [Fact] + public async Task MultipleComClasses() + { + string source = $$""" + using System.Runtime.InteropServices; + using System.Runtime.InteropServices.Marshalling; + + [GeneratedComInterface] + partial interface I + { + } + [GeneratedComInterface] + partial interface J + { + } + + [GeneratedComClass] + partial class C : I, J + { + } + + [GeneratedComClass] + partial class D : I, J + { + } + + [GeneratedComClass] + partial class E : C + { + } + """; + Compilation comp = await TestUtils.CreateCompilation(source); + TestUtils.AssertPreSourceGeneratorCompilation(comp); + + var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.ComClassGenerator()); + TestUtils.AssertPostSourceGeneratorCompilation(newComp); + // We'll create one syntax tree per user-defined interface. + Assert.Equal(comp.SyntaxTrees.Count() + 3, newComp.SyntaxTrees.Count()); + + VerifyShape(newComp, "C"); + VerifyShape(newComp, "D"); + VerifyShape(newComp, "E"); + } + private static void VerifyShape(Compilation comp, string userDefinedClassMetadataName) + { + INamedTypeSymbol? userDefinedClass = comp.Assembly.GetTypeByMetadataName(userDefinedClassMetadataName); + Assert.NotNull(userDefinedClass); + + INamedTypeSymbol? comExposedClassAttribute = comp.GetTypeByMetadataName("System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute`1"); + + Assert.NotNull(comExposedClassAttribute); + + AttributeData iUnknownDerivedAttribute = Assert.Single( + userDefinedClass.GetAttributes(), + attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.OriginalDefinition, comExposedClassAttribute)); + + Assert.Collection(Assert.IsAssignableFrom(iUnknownDerivedAttribute.AttributeClass).TypeArguments, + infoType => + { + Assert.True(Assert.IsAssignableFrom(infoType).IsFileLocal); + }); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/GeneratedComInterfaceAnalyzerTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/GeneratedComInterfaceAnalyzerTests.cs index e4291e6c1f858..a77166a211d7c 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/GeneratedComInterfaceAnalyzerTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/GeneratedComInterfaceAnalyzerTests.cs @@ -34,12 +34,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync(_usings + snippet); } @@ -54,12 +48,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync(_usings + snippet); } @@ -75,13 +63,6 @@ interface IFoo { void Bar() {} } - - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync(_usings + snippet); } @@ -97,12 +78,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync(_usings + snippet); } @@ -118,12 +93,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync(_usings + snippet); } @@ -139,12 +108,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync(_usings + snippet); } @@ -160,12 +123,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync(_usings + snippet); } @@ -181,12 +138,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync(_usings + snippet); } @@ -206,12 +157,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync(_usings + snippet); } @@ -232,12 +177,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync(_usings + snippet); } @@ -254,12 +193,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync(_usings + snippet); } @@ -276,12 +209,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync( _usings + snippet, @@ -302,12 +229,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync( _usings + snippet, @@ -328,12 +249,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync( _usings + snippet, @@ -354,12 +269,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync( _usings + snippet, @@ -380,12 +289,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync( _usings + snippet, @@ -406,12 +309,6 @@ interface IFoo { void Bar() {} } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync( _usings + snippet, @@ -438,12 +335,6 @@ void Bar() {} [GeneratedComInterface] partial interface IFoo { } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync(_usings + snippet); } @@ -462,12 +353,6 @@ void Bar() {} [GeneratedComInterface] partial interface IFoo { } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync(_usings + snippet); } @@ -486,12 +371,6 @@ void Bar() {} [GeneratedComInterface] partial interface IFoo { } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync( _usings + snippet, @@ -514,12 +393,6 @@ void Bar() {} [GeneratedComInterface] partial interface IFoo { } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync( _usings + snippet, @@ -542,12 +415,6 @@ void Bar() {} [GeneratedComInterface] partial interface IFoo { } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync( _usings + snippet, @@ -570,12 +437,6 @@ void Bar() {} [GeneratedComInterface] partial interface IFoo { } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync( _usings + snippet, @@ -598,12 +459,6 @@ void Bar() {} [GeneratedComInterface] partial interface IFoo { } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync( _usings + snippet, @@ -626,12 +481,6 @@ void Bar() {} [GeneratedComInterface] partial interface IFoo { } - - public unsafe partial class MyComWrappers : GeneratedComWrappersBase - { - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;} - } - """; await VerifyCS.VerifyAnalyzerAsync( _usings + snippet, diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaces.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaces.cs index 52db26bd0818a..651f02ba66cc1 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaces.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaces.cs @@ -15,7 +15,7 @@ namespace NativeExports; -public static unsafe class ComInterfaceGeneratorExports +public static unsafe class ComInterfaces { interface IComInterface1 { @@ -30,42 +30,70 @@ interface IComInterface1 [UnmanagedCallersOnly(EntryPoint = "get_com_object")] public static void* CreateComObject() { - MyComWrapper cw = new(); var myObject = new MyObject(); - nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None); + nint ptr = ComWrappersInstance.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None); return (void*)ptr; } + [UnmanagedCallersOnly(EntryPoint = "set_com_object_data")] + public static void SetComObjectData(void* ptr, int value) + { + IComInterface1 obj = (IComInterface1)ComWrappersInstance.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None); + obj.SetData(value); + } + + [UnmanagedCallersOnly(EntryPoint = "get_com_object_data")] + public static int GetComObjectData(void* ptr) + { + IComInterface1 obj = (IComInterface1)ComWrappersInstance.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None); + return obj.GetData(); + } + + private static readonly ComWrappers ComWrappersInstance = new MyComWrapper(); + class MyComWrapper : System.Runtime.InteropServices.ComWrappers { - static void* _s_comInterface1VTable = null; - static void* s_comInterface1VTable + static volatile void* s_comInterface1VTable = null; + static void* IComInterface1VTable { get { - if (MyComWrapper._s_comInterface1VTable != null) - return _s_comInterface1VTable; - void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComInterfaceGeneratorExports), sizeof(void*) * 5); + if (s_comInterface1VTable != null) + return s_comInterface1VTable; + void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComInterfaces), sizeof(void*) * 5); GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease); vtable[0] = (void*)fpQueryInterface; vtable[1] = (void*)fpAddReference; vtable[2] = (void*)fpRelease; vtable[3] = (delegate* unmanaged)&MyObject.ABI.GetData; vtable[4] = (delegate* unmanaged)&MyObject.ABI.SetData; - _s_comInterface1VTable = vtable; - return _s_comInterface1VTable; + s_comInterface1VTable = vtable; + return s_comInterface1VTable; } } - protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) + + static volatile ComInterfaceEntry* s_myObjectComInterfaceEntries = null; + static ComInterfaceEntry* MyObjectComInterfaceEntries { - if (obj is MyObject) + get { + if (s_myObjectComInterfaceEntries != null) + return s_myObjectComInterfaceEntries; + ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(MyObject), sizeof(ComInterfaceEntry)); comInterfaceEntry->IID = IComInterface1.IID; - comInterfaceEntry->Vtable = (nint)s_comInterface1VTable; + comInterfaceEntry->Vtable = (nint)IComInterface1VTable; + s_myObjectComInterfaceEntries = comInterfaceEntry; + return s_myObjectComInterfaceEntries; + } + } + protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) + { + if (obj is MyObject) + { count = 1; - return comInterfaceEntry; + return MyObjectComInterfaceEntries; } count = 0; return null; @@ -77,14 +105,14 @@ protected override object CreateObject(nint ptr, CreateObjectFlags flags) { return null; } - return new IComInterface1Impl(ptr); + return new IComInterface1Impl(IComInterfaceImpl); } protected override void ReleaseObjects(IEnumerable objects) { } } // Wrapper for calling CCWs from the ComInterfaceGenerator - class IComInterface1Impl : IComInterface1 + sealed class IComInterface1Impl : IComInterface1 { nint _ptr; @@ -93,6 +121,11 @@ public IComInterface1Impl(nint @this) _ptr = @this; } + ~IComInterface1Impl() + { + int refCount = Marshal.Release(_ptr); + } + int GetData(nint inst) { int value;