Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide support for exposing .NET classes to COM through source generation #83755

Merged
merged 7 commits into from
Mar 28, 2023
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using Microsoft.CodeAnalysis.CSharp;

namespace Microsoft.Interop
{
[Generator]
public class ComClassGenerator : IIncrementalGenerator
{
private sealed record ComClassInfo(string ClassName, ContainingSyntaxContext ContainingSyntaxContext, ContainingSyntax ClassSyntax, SequenceEqualImmutableArray<string> ImplementedInterfacesNames);
public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Get all types with the [GeneratedComClassAttribute] attribute.
var attributedClasses = context.SyntaxProvider
.ForAttributeWithMetadataName(
TypeNames.GeneratedComClassAttribute,
static (node, ct) => node is ClassDeclarationSyntax,
static (context, ct) =>
{
var type = (INamedTypeSymbol)context.TargetSymbol;
var syntax = (ClassDeclarationSyntax)context.TargetNode;
ImmutableArray<string>.Builder names = ImmutableArray.CreateBuilder<string>();
foreach (INamedTypeSymbol iface in type.AllInterfaces)
{
if (iface.GetAttributes().Any(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute))
{
names.Add(iface.ToDisplayString());
}
}
return new ComClassInfo(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to warn / bail if there are no interfaces with GeneratedComInterface?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's still a valid scenario, but I could see us adding a warning for it. I'll file a follow-up issue for that.

type.ToDisplayString(),
new ContainingSyntaxContext(syntax),
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
new(names.ToImmutable()));
});

var className = attributedClasses.Select(static (info, ct) => info.ClassName);

var classInfoType = attributedClasses
.Select(static (info, ct) => new { info.ClassName, info.ImplementedInterfacesNames })
.Select(static (info, ct) => GenerateClassInfoType(info.ImplementedInterfacesNames.Array).NormalizeWhitespace());

var attribute = attributedClasses
.Select(static (info, ct) => new { info.ContainingSyntaxContext, info.ClassSyntax })
.Select(static (info, ct) => GenerateClassInfoAttributeOnUserType(info.ContainingSyntaxContext, info.ClassSyntax).NormalizeWhitespace());

context.RegisterSourceOutput(className.Zip(classInfoType).Zip(attribute), static (context, classInfo) =>
{
var ((className, classInfoType), attribute) = classInfo;
StringWriter writer = new();
writer.WriteLine(classInfoType.ToFullString());
writer.WriteLine();
writer.WriteLine(attribute);
context.AddSource(className, writer.ToString());
});
}

private const string ClassInfoTypeName = "ComClassInformation";
AaronRobinsonMSFT marked this conversation as resolved.
Show resolved Hide resolved

private static readonly AttributeSyntax s_comExposedClassAttributeTemplate =
Attribute(
GenericName(TypeNames.ComExposedClassAttribute)
.AddTypeArgumentListArguments(
IdentifierName(ClassInfoTypeName)));
private static MemberDeclarationSyntax GenerateClassInfoAttributeOnUserType(ContainingSyntaxContext containingSyntaxContext, ContainingSyntax classSyntax) =>
containingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(
TypeDeclaration(classSyntax.TypeKind, classSyntax.Identifier)
.WithModifiers(classSyntax.Modifiers)
.WithTypeParameterList(classSyntax.TypeParameters)
.AddAttributeLists(AttributeList(SingletonSeparatedList(s_comExposedClassAttributeTemplate))));
private static ClassDeclarationSyntax GenerateClassInfoType(ImmutableArray<string> implementedInterfaces)
{
const string vtablesField = "s_vtables";
const string vtablesLocal = "vtables";
const string detailsTempLocal = "details";
const string countIdentifier = "count";
var typeDeclaration = ClassDeclaration(ClassInfoTypeName)
.AddModifiers(
Token(SyntaxKind.FileKeyword),
Token(SyntaxKind.SealedKeyword),
Token(SyntaxKind.UnsafeKeyword))
.AddBaseListTypes(SimpleBaseType(ParseTypeName(TypeNames.IComExposedClass)))
.AddMembers(
FieldDeclaration(
VariableDeclaration(
PointerType(
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
SingletonSeparatedList(VariableDeclarator(vtablesField))))
.AddModifiers(
Token(SyntaxKind.PrivateKeyword),
Token(SyntaxKind.StaticKeyword),
Token(SyntaxKind.VolatileKeyword)));
List<StatementSyntax> vtableInitializationBlock = new()
{
// ComInterfaceEntry* vtables = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(<ClassInfoTypeName>), sizeof(ComInterfaceEntry) * <numInterfaces>);
LocalDeclarationStatement(
VariableDeclaration(
PointerType(
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
SingletonSeparatedList(
VariableDeclarator(vtablesLocal)
.WithInitializer(EqualsValueClause(
CastExpression(
PointerType(
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.System_Runtime_CompilerServices_RuntimeHelpers),
IdentifierName("AllocateTypeAssociatedMemory")))
.AddArgumentListArguments(
Argument(TypeOfExpression(IdentifierName(ClassInfoTypeName))),
Argument(
BinaryExpression(
SyntaxKind.MultiplyExpression,
SizeOfExpression(ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
LiteralExpression(
SyntaxKind.NumericLiteralExpression,
Literal(implementedInterfaces.Length))))))))))),
// IIUnknownDerivedDetails details;
LocalDeclarationStatement(
VariableDeclaration(
ParseTypeName(TypeNames.IIUnknownDerivedDetails),
SingletonSeparatedList(
VariableDeclarator(detailsTempLocal))))
};
for (int i = 0; i < implementedInterfaces.Length; i++)
{
string ifaceName = implementedInterfaces[i];

// details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<ifaceName>).TypeHandle);
vtableInitializationBlock.Add(
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(detailsTempLocal),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.StrategyBasedComWrappers),
IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")),
IdentifierName("GetIUnknownDerivedDetails")),
ArgumentList(
SingletonSeparatedList(
Argument(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
TypeOfExpression(ParseName(ifaceName)),
IdentifierName("TypeHandle")))))))));
// vtable[i] = new() { IID = details.Iid, Vtable = details.ManagedVirtualMethodTable };
vtableInitializationBlock.Add(
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
ElementAccessExpression(
IdentifierName(vtablesLocal),
BracketedArgumentList(
SingletonSeparatedList(
Argument(
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(i)))))),
ImplicitObjectCreationExpression(
ArgumentList(),
InitializerExpression(SyntaxKind.ObjectInitializerExpression,
SeparatedList(
new ExpressionSyntax[]
{
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName("IID"),
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(detailsTempLocal),
IdentifierName("Iid"))),
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName("Vtable"),
CastExpression(
IdentifierName("nint"),
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(detailsTempLocal),
IdentifierName("ManagedVirtualMethodTable"))))
}))))));
}

// s_vtable = vtable;
vtableInitializationBlock.Add(
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(vtablesField),
IdentifierName(vtablesLocal))));

BlockSyntax getComInterfaceEntriesMethodBody = Block(
// count = <count>;
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(countIdentifier),
LiteralExpression(SyntaxKind.NumericLiteralExpression,
Literal(implementedInterfaces.Length)))),
// if (s_vtable == null)
// { initializer block }
IfStatement(
BinaryExpression(SyntaxKind.EqualsExpression,
IdentifierName(vtablesField),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
Block(vtableInitializationBlock)),
// return s_vtable;
ReturnStatement(IdentifierName(vtablesField)));

typeDeclaration = typeDeclaration.AddMembers(
// public static unsafe ComWrappers.ComInterfaceDispatch* GetComInterfaceEntries(out int count)
// { body }
MethodDeclaration(
PointerType(
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
"GetComInterfaceEntries")
.AddParameterListParameters(
Parameter(Identifier(countIdentifier))
.WithType(PredefinedType(Token(SyntaxKind.IntKeyword)))
.AddModifiers(Token(SyntaxKind.OutKeyword)))
.WithBody(getComInterfaceEntriesMethodBody)
.AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)));

return typeDeclaration;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public static class StepNames

public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Get all methods with the [GeneratedComInterface] attribute.
// Get all types with the [GeneratedComInterface] attribute.
var attributedInterfaces = context.SyntaxProvider
.ForAttributeWithMetadataName(
TypeNames.GeneratedComInterfaceAttribute,
Expand All @@ -62,7 +62,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
return new { data.Syntax, data.Symbol, Diagnostic = diagnostic };
});

// Split the methods we want to generate and the ones we don't into two separate groups.
// Split the types we want to generate and the ones we don't into two separate groups.
var interfacesToGenerate = interfacesWithDiagnostics.Where(static data => data.Diagnostic is null);
var invalidTypeDiagnostics = interfacesWithDiagnostics.Where(static data => data.Diagnostic is not null);

Expand Down Expand Up @@ -726,7 +726,7 @@ private static ClassDeclarationSyntax GenerateInterfaceInformation(ComInterfaceC
.WithExpressionBody(
ArrowExpressionClause(
ConditionalExpression(
BinaryExpression(SyntaxKind.EqualsExpression,
BinaryExpression(SyntaxKind.NotEqualsExpression,
IdentifierName(vtableFieldName),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
IdentifierName(vtableFieldName),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ internal static class IncrementalValuesProviderExtensions
});
}

/// <summary>
/// Format the syntax nodes in the given provider such that we will not re-normalize if the input nodes have not changed.
/// </summary>
/// <typeparam name="TNode">A syntax node kind.</typeparam>
/// <param name="provider">The input nodes</param>
/// <returns>A provider of the formatted syntax nodes.</returns>
/// <remarks>
/// Normalizing whitespace is very expensive, so if a generator will have cases where the input information into the step
/// that creates <paramref name="provider"/> may change but the results of <paramref name="provider"/> will say the same,
/// using this method to format the code in a separate step will reduce the amount of work the generator repeats when the
/// output code will not change.
/// </remarks>
public static IncrementalValuesProvider<TNode> SelectNormalized<TNode>(this IncrementalValuesProvider<TNode> provider)
where TNode : SyntaxNode
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,22 @@ public static string MarshalEx(InteropGenerationOptions options)

public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceDispatch = "System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch";

public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry = "System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry";

public const string StrategyBasedComWrappers = "System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers";

public const string IIUnknownInterfaceType = "System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType";
public const string IUnknownDerivedAttribute = "System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute";
public const string IIUnknownDerivedDetails = "System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails";

public const string ComWrappersUnwrapper = "System.Runtime.InteropServices.Marshalling.ComWrappersUnwrapper";
public const string UnmanagedObjectUnwrapperAttribute = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapperAttribute`1";

public const string IUnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.IUnmanagedObjectUnwrapper";
public const string UnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapper";

public const string GeneratedComClassAttribute = "System.Runtime.InteropServices.Marshalling.GeneratedComClassAttribute";
public const string ComExposedClassAttribute = "System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute";
public const string IComExposedClass = "System.Runtime.InteropServices.Marshalling.IComExposedClass";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace System.Runtime.InteropServices.Marshalling
{
/// <summary>
/// An attribute to mark this class as a type whose instances should be exposed to COM.
/// </summary>
/// <typeparam name="T">The type that provides information about how to expose the attributed type to COM.</typeparam>
[AttributeUsage(AttributeTargets.Class, Inherited = false)]
public sealed class ComExposedClassAttribute<T> : Attribute, IComExposedDetails
where T : IComExposedClass
{
/// <inheritdoc />
public unsafe ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) => T.GetComInterfaceEntries(out count);
}
}
Loading