Skip to content

Commit

Permalink
Enhance handling of resolving the marshaller type in presence of gene…
Browse files Browse the repository at this point in the history
…rics and reuse this logic from the suppressor.

Add a linear collection test.

PR feedback.
  • Loading branch information
jkoritzinsky committed Jul 14, 2022
1 parent c474b9c commit 2ceab41
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,20 @@ private static void SuppressMarkMethodsAsStaticDiagnosticIfNeeded(SuppressionAna
SemanticModel model = context.GetSemanticModel(diagnostic.Location.SourceTree);
ISymbol diagnosedSymbol = model.GetDeclaredSymbol(diagnostic.Location.SourceTree.GetRoot(context.CancellationToken).FindNode(diagnostic.Location.SourceSpan), context.CancellationToken);

if (diagnosedSymbol.Kind == SymbolKind.Method)
if (diagnosedSymbol.Kind != SymbolKind.Method)
{
if (FindContainingEntryPointTypeAndManagedType(diagnosedSymbol.ContainingType) is (INamedTypeSymbol entryPointMarshallerType, INamedTypeSymbol managedType))
return;
}

if (FindContainingEntryPointTypeAndManagedType(diagnosedSymbol.ContainingType) is (INamedTypeSymbol entryPointMarshallerType, INamedTypeSymbol managedType))
{
bool isLinearCollectionMarshaller = ManualTypeMarshallingHelper.IsLinearCollectionEntryPoint(entryPointMarshallerType);
(MarshallerShape _, StatefulMarshallerShapeHelper.MarshallerMethods methods) = StatefulMarshallerShapeHelper.GetShapeForType(diagnosedSymbol.ContainingType, managedType, isLinearCollectionMarshaller, context.Compilation);
if (methods.IsShapeMethod((IMethodSymbol)diagnosedSymbol))
{
bool isLinearCollectionMarshaller = entryPointMarshallerType.GetAttributes().Any(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.ContiguousCollectionMarshallerAttribute);
(MarshallerShape _, StatefulMarshallerShapeHelper.MarshallerMethods methods) = StatefulMarshallerShapeHelper.GetShapeForType(diagnosedSymbol.ContainingType, managedType, isLinearCollectionMarshaller, context.Compilation);
if (methods.IsShapeMethod((IMethodSymbol)diagnosedSymbol))
{
// If we are a method of the shape on the stateful marshaller shape, then we need to be our current shape.
// So, suppress the diagnostic to make this method static, as that would break the shape.
context.ReportSuppression(Suppression.Create(MarkMethodsAsStaticSuppression, diagnostic));
}
// If we are a method of the shape on the stateful marshaller shape, then we need to be our current shape.
// So, suppress the diagnostic to make this method static, as that would break the shape.
context.ReportSuppression(Suppression.Create(MarkMethodsAsStaticSuppression, diagnostic));
}
}
}
Expand All @@ -60,7 +62,8 @@ private static (INamedTypeSymbol EntryPointType, INamedTypeSymbol ManagedType)?
&& attr.AttributeConstructor is not null
&& !attr.ConstructorArguments[0].IsNull
&& attr.ConstructorArguments[2].Value is INamedTypeSymbol marshallerTypeInAttribute
&& SymbolEqualityComparer.Default.Equals(marshallerTypeInAttribute, marshallerType));
&& ManualTypeMarshallingHelper.TryResolveMarshallerType(containingType, marshallerTypeInAttribute, _ => { }, out ITypeSymbol? constructedMarshallerType)
&& SymbolEqualityComparer.Default.Equals(constructedMarshallerType, marshallerType));
if (attrData is not null)
{
return (containingType, (INamedTypeSymbol)attrData.ConstructorArguments[0].Value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Runtime.CompilerServices;
Expand Down Expand Up @@ -62,6 +63,8 @@ public static class MarshalUsingProperties
public const string ConstantElementCount = nameof(ConstantElementCount);
}

private static void IgnoreDiagnostic(Diagnostic diagnostic) { }

public static bool IsLinearCollectionEntryPoint(INamedTypeSymbol entryPointType)
{
return entryPointType.IsGenericType
Expand Down Expand Up @@ -149,28 +152,9 @@ private static bool TryGetMarshallersFromEntryType(
continue;

ITypeSymbol marshallerType = marshallerTypeOnAttr;
if (isLinearCollectionMarshalling && marshallerTypeOnAttr is INamedTypeSymbol namedMarshallerType)
if (!TryResolveMarshallerType(entryPointType, marshallerType, IgnoreDiagnostic, out marshallerType))
{
// Update the marshaller type with resolved type arguments based on the entry point type
// We expect the entry point to already have its type arguments updated based on the managed type
Stack<string> nestedTypeNames = new Stack<string>();
INamedTypeSymbol currentType = namedMarshallerType;
while (currentType is not null)
{
if (currentType.IsConstructedFromEqualTypes(entryPointType))
break;

nestedTypeNames.Push(currentType.Name);
currentType = currentType.ContainingType;
}

currentType = entryPointType;
foreach (string name in nestedTypeNames)
{
currentType = currentType.GetTypeMembers(name).First();
}

marshallerType = currentType;
continue;
}

// TODO: Report invalid shape for mode
Expand Down Expand Up @@ -198,6 +182,72 @@ private static bool TryGetMarshallersFromEntryType(
return true;
}

/// <summary>
/// Resolve the (possibly unbound generic) marshaller type to a fully constructed type based on the entry point type's generic parameters.
/// </summary>
/// <param name="entryPointType">The entry point type</param>
/// <param name="attributeMarshallerType">The marshaller type from the CustomMarshallerAttribute</param>
/// <returns>A fully constructed marshaller type</returns>
public static bool TryResolveMarshallerType(INamedTypeSymbol entryPointType, ITypeSymbol? attributeMarshallerType, Action<Diagnostic> reportDiagnostic, [NotNullWhen(true)] out ITypeSymbol? marshallerType)
{
if (attributeMarshallerType is null)
{
marshallerType = null;
return false;
}

if (attributeMarshallerType is not INamedTypeSymbol namedMarshallerType)
{
marshallerType = attributeMarshallerType;
return true;
}

// Update the marshaller type with resolved type arguments based on the entry point type
// We expect the entry point to already have its type arguments updated based on the managed type
Stack<INamedTypeSymbol> nestedTypes = new();
INamedTypeSymbol currentType = namedMarshallerType;
int totalArity = 0;
while (currentType is not null)
{
nestedTypes.Push(currentType);
totalArity += currentType.Arity;
currentType = currentType.ContainingType;
}

if (totalArity != entryPointType.Arity)
{
//TODO: Report diagnostic
marshallerType = null;
return false;
}

int currentArityOffset = 0;
currentType = null;
while (nestedTypes.Count > 0)
{
if (currentType is null)
{
currentType = nestedTypes.Pop();
}
else
{
INamedTypeSymbol originalType = nestedTypes.Pop();
currentType = currentType.GetTypeMembers(originalType.Name, originalType.Arity).First();
}

if (currentType.TypeParameters.Length > 0)
{
currentType = currentType.ConstructedFrom.Construct(
ImmutableArray.CreateRange(entryPointType.TypeArguments, currentArityOffset, currentType.TypeParameters.Length, x => x),
ImmutableArray.CreateRange(entryPointType.TypeArgumentNullableAnnotations, currentArityOffset, currentType.TypeParameters.Length, x => x));
currentArityOffset += currentType.TypeParameters.Length;
}
}

marshallerType = currentType;
return true;
}

/// <summary>
/// Resolve a non-<see cref="INamedTypeSymbol"/> <paramref name="managedType"/> to the correct
/// managed type if <paramref name="entryType"/> is generic and <paramref name="managedType"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
using Microsoft.Interop.Analyzers;
using Xunit;

namespace LibraryImportGenerator.Unit.Tests
namespace LibraryImportGenerator.UnitTests
{
[ActiveIssue("https://github.com/dotnet/runtime/issues/60650", TestRuntimes.Mono)]
public class ShapeBreakingDiagnosticSuppressorTests
Expand All @@ -40,27 +40,15 @@ public struct ManagedToUnmanagedIn
{
public static int BufferSize { get; } = 1;
public void {|#0:FromManaged|}(S s)
{
}
public void {|#0:FromManaged|}(S s) {}
public void {|#1:FromManaged|}(S s, Span<byte> buffer)
{
}
public void {|#1:FromManaged|}(S s, Span<byte> buffer){}
public ManagedToUnmanagedIn {|#2:ToUnmanaged|}()
{
return default;
}
public ManagedToUnmanagedIn {|#2:ToUnmanaged|}() => default;
public void {|#3:FromUnmanaged|}(ManagedToUnmanagedIn unmanaged)
{
}
public void {|#3:FromUnmanaged|}(ManagedToUnmanagedIn unmanaged) {}
public S {|#4:ToManaged|}()
{
return default;
}
public S {|#4:ToManaged|}() => default;
public void {|#5:Free|}() {}
Expand All @@ -80,6 +68,56 @@ public struct ManagedToUnmanagedIn
SuppressedDiagnostic(ShapeBreakingDiagnosticSuppressor.MarkMethodsAsStaticSuppression, DiagnosticSeverity.Info).WithLocation(7));
}

[Fact]
public async Task StatefulLinearCollectionMarshallerMethodsThatDoNotUseInstanceState_SuppressesDiagnostic()
{
await VerifySuppressorAsync("""
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices.Marshalling;
struct S
{
public bool b;
};
[CustomMarshaller(typeof(S), MarshalMode.ManagedToUnmanagedIn, typeof(Marshaller<>.ManagedToUnmanagedIn))]
[ContiguousCollectionMarshaller]
static class Marshaller<TNative>
{
public struct ManagedToUnmanagedIn
{
public void {|#0:FromManaged|}(S s) {}
public void {|#1:FromManaged|}(S s, Span<byte> buffer){}
public ManagedToUnmanagedIn {|#2:ToUnmanaged|}() => default;
public void {|#3:FromUnmanaged|}(ManagedToUnmanagedIn unmanaged) {}
public S {|#4:ToManaged|}() => default;
public ReadOnlySpan<int> {|#5:GetManagedValuesSource|}() => default;
public Span<TNative> {|#6:GetUnmanagedValuesDestination|}() => default;
public ReadOnlySpan<TNative> {|#7:GetUnmanagedValuesSource|}(int numElements) => default;
public Span<int> {|#8:GetManagedValuesDestination|}(int numElements) => default;
}
}
""",
SuppressedDiagnostic(ShapeBreakingDiagnosticSuppressor.MarkMethodsAsStaticSuppression, DiagnosticSeverity.Info).WithLocation(0),
SuppressedDiagnostic(ShapeBreakingDiagnosticSuppressor.MarkMethodsAsStaticSuppression, DiagnosticSeverity.Info).WithLocation(1),
SuppressedDiagnostic(ShapeBreakingDiagnosticSuppressor.MarkMethodsAsStaticSuppression, DiagnosticSeverity.Info).WithLocation(2),
SuppressedDiagnostic(ShapeBreakingDiagnosticSuppressor.MarkMethodsAsStaticSuppression, DiagnosticSeverity.Info).WithLocation(3),
SuppressedDiagnostic(ShapeBreakingDiagnosticSuppressor.MarkMethodsAsStaticSuppression, DiagnosticSeverity.Info).WithLocation(4),
SuppressedDiagnostic(ShapeBreakingDiagnosticSuppressor.MarkMethodsAsStaticSuppression, DiagnosticSeverity.Info).WithLocation(5),
SuppressedDiagnostic(ShapeBreakingDiagnosticSuppressor.MarkMethodsAsStaticSuppression, DiagnosticSeverity.Info).WithLocation(6),
SuppressedDiagnostic(ShapeBreakingDiagnosticSuppressor.MarkMethodsAsStaticSuppression, DiagnosticSeverity.Info).WithLocation(7),
SuppressedDiagnostic(ShapeBreakingDiagnosticSuppressor.MarkMethodsAsStaticSuppression, DiagnosticSeverity.Info).WithLocation(8));
}

[Fact]
public async Task MethodWithShapeMatchingNameButDifferingSignature_DoesNotSuppressDiagnostic()
{
Expand Down Expand Up @@ -153,7 +191,7 @@ private static async Task VerifySuppressorAsync(string source, params Diagnostic
await test.RunAsync(CancellationToken.None);
}

class Test : CSharpCodeFixVerifier<EmptyDiagnosticAnalyzer, EmptyCodeFixProvider>.Test
private class Test : CSharpCodeFixVerifier<EmptyDiagnosticAnalyzer, EmptyCodeFixProvider>.Test
{
public Test()
{
Expand Down

0 comments on commit 2ceab41

Please sign in to comment.