Skip to content

Commit

Permalink
Provide support for exposing .NET classes to COM through source gener…
Browse files Browse the repository at this point in the history
…ation (#83755)
  • Loading branch information
jkoritzinsky authored Mar 28, 2023
1 parent 3a28f6e commit ce48579
Show file tree
Hide file tree
Showing 25 changed files with 655 additions and 257 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using Microsoft.CodeAnalysis.CSharp;

namespace Microsoft.Interop
{
[Generator]
public class ComClassGenerator : IIncrementalGenerator
{
private sealed record ComClassInfo(string ClassName, ContainingSyntaxContext ContainingSyntaxContext, ContainingSyntax ClassSyntax, SequenceEqualImmutableArray<string> ImplementedInterfacesNames);
public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Get all types with the [GeneratedComClassAttribute] attribute.
var attributedClasses = context.SyntaxProvider
.ForAttributeWithMetadataName(
TypeNames.GeneratedComClassAttribute,
static (node, ct) => node is ClassDeclarationSyntax,
static (context, ct) =>
{
var type = (INamedTypeSymbol)context.TargetSymbol;
var syntax = (ClassDeclarationSyntax)context.TargetNode;
ImmutableArray<string>.Builder names = ImmutableArray.CreateBuilder<string>();
foreach (INamedTypeSymbol iface in type.AllInterfaces)
{
if (iface.GetAttributes().Any(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute))
{
names.Add(iface.ToDisplayString());
}
}
return new ComClassInfo(
type.ToDisplayString(),
new ContainingSyntaxContext(syntax),
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
new(names.ToImmutable()));
});

var className = attributedClasses.Select(static (info, ct) => info.ClassName);

var classInfoType = attributedClasses
.Select(static (info, ct) => new { info.ClassName, info.ImplementedInterfacesNames })
.Select(static (info, ct) => GenerateClassInfoType(info.ImplementedInterfacesNames.Array).NormalizeWhitespace());

var attribute = attributedClasses
.Select(static (info, ct) => new { info.ContainingSyntaxContext, info.ClassSyntax })
.Select(static (info, ct) => GenerateClassInfoAttributeOnUserType(info.ContainingSyntaxContext, info.ClassSyntax).NormalizeWhitespace());

context.RegisterSourceOutput(className.Zip(classInfoType).Zip(attribute), static (context, classInfo) =>
{
var ((className, classInfoType), attribute) = classInfo;
StringWriter writer = new();
writer.WriteLine(classInfoType.ToFullString());
writer.WriteLine();
writer.WriteLine(attribute);
context.AddSource(className, writer.ToString());
});
}

private const string ClassInfoTypeName = "ComClassInformation";

private static readonly AttributeSyntax s_comExposedClassAttributeTemplate =
Attribute(
GenericName(TypeNames.ComExposedClassAttribute)
.AddTypeArgumentListArguments(
IdentifierName(ClassInfoTypeName)));
private static MemberDeclarationSyntax GenerateClassInfoAttributeOnUserType(ContainingSyntaxContext containingSyntaxContext, ContainingSyntax classSyntax) =>
containingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(
TypeDeclaration(classSyntax.TypeKind, classSyntax.Identifier)
.WithModifiers(classSyntax.Modifiers)
.WithTypeParameterList(classSyntax.TypeParameters)
.AddAttributeLists(AttributeList(SingletonSeparatedList(s_comExposedClassAttributeTemplate))));
private static ClassDeclarationSyntax GenerateClassInfoType(ImmutableArray<string> implementedInterfaces)
{
const string vtablesField = "s_vtables";
const string vtablesLocal = "vtables";
const string detailsTempLocal = "details";
const string countIdentifier = "count";
var typeDeclaration = ClassDeclaration(ClassInfoTypeName)
.AddModifiers(
Token(SyntaxKind.FileKeyword),
Token(SyntaxKind.SealedKeyword),
Token(SyntaxKind.UnsafeKeyword))
.AddBaseListTypes(SimpleBaseType(ParseTypeName(TypeNames.IComExposedClass)))
.AddMembers(
FieldDeclaration(
VariableDeclaration(
PointerType(
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
SingletonSeparatedList(VariableDeclarator(vtablesField))))
.AddModifiers(
Token(SyntaxKind.PrivateKeyword),
Token(SyntaxKind.StaticKeyword),
Token(SyntaxKind.VolatileKeyword)));
List<StatementSyntax> vtableInitializationBlock = new()
{
// ComInterfaceEntry* vtables = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(<ClassInfoTypeName>), sizeof(ComInterfaceEntry) * <numInterfaces>);
LocalDeclarationStatement(
VariableDeclaration(
PointerType(
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
SingletonSeparatedList(
VariableDeclarator(vtablesLocal)
.WithInitializer(EqualsValueClause(
CastExpression(
PointerType(
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.System_Runtime_CompilerServices_RuntimeHelpers),
IdentifierName("AllocateTypeAssociatedMemory")))
.AddArgumentListArguments(
Argument(TypeOfExpression(IdentifierName(ClassInfoTypeName))),
Argument(
BinaryExpression(
SyntaxKind.MultiplyExpression,
SizeOfExpression(ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
LiteralExpression(
SyntaxKind.NumericLiteralExpression,
Literal(implementedInterfaces.Length))))))))))),
// IIUnknownDerivedDetails details;
LocalDeclarationStatement(
VariableDeclaration(
ParseTypeName(TypeNames.IIUnknownDerivedDetails),
SingletonSeparatedList(
VariableDeclarator(detailsTempLocal))))
};
for (int i = 0; i < implementedInterfaces.Length; i++)
{
string ifaceName = implementedInterfaces[i];

// details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<ifaceName>).TypeHandle);
vtableInitializationBlock.Add(
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(detailsTempLocal),
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName(TypeNames.StrategyBasedComWrappers),
IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")),
IdentifierName("GetIUnknownDerivedDetails")),
ArgumentList(
SingletonSeparatedList(
Argument(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
TypeOfExpression(ParseName(ifaceName)),
IdentifierName("TypeHandle")))))))));
// vtable[i] = new() { IID = details.Iid, Vtable = details.ManagedVirtualMethodTable };
vtableInitializationBlock.Add(
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
ElementAccessExpression(
IdentifierName(vtablesLocal),
BracketedArgumentList(
SingletonSeparatedList(
Argument(
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(i)))))),
ImplicitObjectCreationExpression(
ArgumentList(),
InitializerExpression(SyntaxKind.ObjectInitializerExpression,
SeparatedList(
new ExpressionSyntax[]
{
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName("IID"),
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(detailsTempLocal),
IdentifierName("Iid"))),
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName("Vtable"),
CastExpression(
IdentifierName("nint"),
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(detailsTempLocal),
IdentifierName("ManagedVirtualMethodTable"))))
}))))));
}

// s_vtable = vtable;
vtableInitializationBlock.Add(
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(vtablesField),
IdentifierName(vtablesLocal))));

BlockSyntax getComInterfaceEntriesMethodBody = Block(
// count = <count>;
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(countIdentifier),
LiteralExpression(SyntaxKind.NumericLiteralExpression,
Literal(implementedInterfaces.Length)))),
// if (s_vtable == null)
// { initializer block }
IfStatement(
BinaryExpression(SyntaxKind.EqualsExpression,
IdentifierName(vtablesField),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
Block(vtableInitializationBlock)),
// return s_vtable;
ReturnStatement(IdentifierName(vtablesField)));

typeDeclaration = typeDeclaration.AddMembers(
// public static unsafe ComWrappers.ComInterfaceDispatch* GetComInterfaceEntries(out int count)
// { body }
MethodDeclaration(
PointerType(
ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
"GetComInterfaceEntries")
.AddParameterListParameters(
Parameter(Identifier(countIdentifier))
.WithType(PredefinedType(Token(SyntaxKind.IntKeyword)))
.AddModifiers(Token(SyntaxKind.OutKeyword)))
.WithBody(getComInterfaceEntriesMethodBody)
.AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)));

return typeDeclaration;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public static class StepNames

public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Get all methods with the [GeneratedComInterface] attribute.
// Get all types with the [GeneratedComInterface] attribute.
var attributedInterfaces = context.SyntaxProvider
.ForAttributeWithMetadataName(
TypeNames.GeneratedComInterfaceAttribute,
Expand All @@ -62,7 +62,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
return new { data.Syntax, data.Symbol, Diagnostic = diagnostic };
});

// Split the methods we want to generate and the ones we don't into two separate groups.
// Split the types we want to generate and the ones we don't into two separate groups.
var interfacesToGenerate = interfacesWithDiagnostics.Where(static data => data.Diagnostic is null);
var invalidTypeDiagnostics = interfacesWithDiagnostics.Where(static data => data.Diagnostic is not null);

Expand Down Expand Up @@ -726,7 +726,7 @@ private static ClassDeclarationSyntax GenerateInterfaceInformation(ComInterfaceC
.WithExpressionBody(
ArrowExpressionClause(
ConditionalExpression(
BinaryExpression(SyntaxKind.EqualsExpression,
BinaryExpression(SyntaxKind.NotEqualsExpression,
IdentifierName(vtableFieldName),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
IdentifierName(vtableFieldName),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ internal static class IncrementalValuesProviderExtensions
});
}

/// <summary>
/// Format the syntax nodes in the given provider such that we will not re-normalize if the input nodes have not changed.
/// </summary>
/// <typeparam name="TNode">A syntax node kind.</typeparam>
/// <param name="provider">The input nodes</param>
/// <returns>A provider of the formatted syntax nodes.</returns>
/// <remarks>
/// Normalizing whitespace is very expensive, so if a generator will have cases where the input information into the step
/// that creates <paramref name="provider"/> may change but the results of <paramref name="provider"/> will say the same,
/// using this method to format the code in a separate step will reduce the amount of work the generator repeats when the
/// output code will not change.
/// </remarks>
public static IncrementalValuesProvider<TNode> SelectNormalized<TNode>(this IncrementalValuesProvider<TNode> provider)
where TNode : SyntaxNode
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,22 @@ public static string MarshalEx(InteropGenerationOptions options)

public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceDispatch = "System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch";

public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry = "System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry";

public const string StrategyBasedComWrappers = "System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers";

public const string IIUnknownInterfaceType = "System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType";
public const string IUnknownDerivedAttribute = "System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute";
public const string IIUnknownDerivedDetails = "System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails";

public const string ComWrappersUnwrapper = "System.Runtime.InteropServices.Marshalling.ComWrappersUnwrapper";
public const string UnmanagedObjectUnwrapperAttribute = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapperAttribute`1";

public const string IUnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.IUnmanagedObjectUnwrapper";
public const string UnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapper";

public const string GeneratedComClassAttribute = "System.Runtime.InteropServices.Marshalling.GeneratedComClassAttribute";
public const string ComExposedClassAttribute = "System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute";
public const string IComExposedClass = "System.Runtime.InteropServices.Marshalling.IComExposedClass";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace System.Runtime.InteropServices.Marshalling
{
/// <summary>
/// An attribute to mark this class as a type whose instances should be exposed to COM.
/// </summary>
/// <typeparam name="T">The type that provides information about how to expose the attributed type to COM.</typeparam>
[AttributeUsage(AttributeTargets.Class, Inherited = false)]
public sealed class ComExposedClassAttribute<T> : Attribute, IComExposedDetails
where T : IComExposedClass
{
/// <inheritdoc />
public unsafe ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) => T.GetComInterfaceEntries(out count);
}
}
Loading

0 comments on commit ce48579

Please sign in to comment.