Skip to content

Commit

Permalink
Reliability improvements for array marshalling (#384)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkoritzinsky authored Dec 5, 2020
1 parent a7ccdd8 commit 5a78d17
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,40 @@ namespace Microsoft.Interop
{
internal sealed class ArrayMarshallingCodeContext : StubCodeContext
{
public const string LocalManagedIdentifierSuffix = "_local";

private readonly string indexerIdentifier;
private readonly StubCodeContext parentContext;
private readonly bool appendLocalManagedIdentifierSuffix;

public override bool PinningSupported => false;

public override bool StackSpaceUsable => false;

public override bool CanUseAdditionalTemporaryState => false;

public ArrayMarshallingCodeContext(Stage currentStage, string indexerIdentifier, StubCodeContext parentContext)
/// <summary>
/// Create a <see cref="StubCodeContext"/> for marshalling elements of an array.
/// </summary>
/// <param name="currentStage">The current marshalling stage.</param>
/// <param name="indexerIdentifier">The indexer in the loop to get the element to marshal from the array.</param>
/// <param name="parentContext">The parent context.</param>
/// <param name="appendLocalManagedIdentifierSuffix">
/// For array marshalling, we sometimes cache the array in a local to avoid multithreading issues.
/// Set this to <c>true</c> to add the <see cref="LocalManagedIdentifierSuffix"/> to the managed identifier when
/// marshalling the array elements to ensure that we use the local copy instead of the managed identifier
/// when marshalling elements.
/// </param>
public ArrayMarshallingCodeContext(
Stage currentStage,
string indexerIdentifier,
StubCodeContext parentContext,
bool appendLocalManagedIdentifierSuffix)
{
CurrentStage = currentStage;
this.indexerIdentifier = indexerIdentifier;
this.parentContext = parentContext;
this.appendLocalManagedIdentifierSuffix = appendLocalManagedIdentifierSuffix;
}

/// <summary>
Expand All @@ -36,6 +56,10 @@ public ArrayMarshallingCodeContext(Stage currentStage, string indexerIdentifier,
public override (string managed, string native) GetIdentifiers(TypePositionInfo info)
{
var (managed, native) = parentContext.GetIdentifiers(info);
if (appendLocalManagedIdentifierSuffix)
{
managed += LocalManagedIdentifierSuffix;
}
return ($"{managed}[{indexerIdentifier}]", $"{native}[{indexerIdentifier}]");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,13 @@ protected override ExpressionSyntax GenerateAllocationExpression(TypePositionInf

protected override ExpressionSyntax GenerateByteLengthCalculationExpression(TypePositionInfo info, StubCodeContext context)
{
// sizeof(<nativeElementType>) * <managedIdentifier>.Length
return BinaryExpression(SyntaxKind.MultiplyExpression,
SizeOfExpression(GetElementTypeSyntax(info)),
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(context.GetIdentifiers(info).managed),
IdentifierName("Length")
));
// checked(sizeof(<nativeElementType>) * <managedIdentifier>.Length)
return CheckedExpression(SyntaxKind.CheckedExpression,
BinaryExpression(SyntaxKind.MultiplyExpression,
SizeOfExpression(GetElementTypeSyntax(info)),
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(context.GetIdentifiers(info).managed),
IdentifierName("Length"))));
}

protected override StatementSyntax GenerateStackallocOnlyValueMarshalling(TypePositionInfo info, StubCodeContext context, SyntaxToken byteLengthIdentifier, SyntaxToken stackAllocPtrIdentifier)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ protected IEnumerable<StatementSyntax> GenerateConditionalAllocationSyntax(
IdentifierName(nativeIdentifier),
LiteralExpression(SyntaxKind.NullLiteralExpression)));
yield return IfStatement(
BinaryExpression(SyntaxKind.NotEqualsExpression,
IdentifierName(managedIdentifier),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
GenerateNullCheckExpression(info, context),
Block(statements));
yield break;
}
Expand Down Expand Up @@ -137,24 +135,15 @@ protected IEnumerable<StatementSyntax> GenerateConditionalAllocationSyntax(
ElseClause(marshalOnStack));

yield return IfStatement(
BinaryExpression(
SyntaxKind.EqualsExpression,
IdentifierName(managedIdentifier),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
Block(
ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(nativeIdentifier),
LiteralExpression(SyntaxKind.NullLiteralExpression)))),
ElseClause(Block(byteLenAssignment, allocBlock)));
GenerateNullCheckExpression(info, context),
Block(byteLenAssignment, allocBlock));
}

protected StatementSyntax GenerateConditionalAllocationFreeSyntax(
TypePositionInfo info,
StubCodeContext context)
{
(string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info);
(string managedIdentifier, _) = context.GetIdentifiers(info);
string allocationMarkerIdentifier = GetAllocationMarkerIdentifier(managedIdentifier);
if (!UsesConditionalStackAlloc(info, context))
{
Expand Down Expand Up @@ -219,6 +208,22 @@ protected abstract StatementSyntax GenerateStackallocOnlyValueMarshalling(
protected abstract ExpressionSyntax GenerateFreeExpression(
TypePositionInfo info,
StubCodeContext context);

/// <summary>
/// Generate code to check if the managed value is not null.
/// </summary>
/// <param name="info">Object to marshal</param>
/// <param name="context">Code generation context</param>
/// <returns>An expression that checks if the managed value is not null.</returns>
protected virtual ExpressionSyntax GenerateNullCheckExpression(
TypePositionInfo info,
StubCodeContext context)
{
return BinaryExpression(
SyntaxKind.NotEqualsExpression,
IdentifierName(context.GetIdentifiers(info).managed),
LiteralExpression(SyntaxKind.NullLiteralExpression));
}

/// <inheritdoc/>
public abstract TypeSyntax AsNativeType(TypePositionInfo info);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using System.Collections.Generic;

using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
Expand All @@ -18,7 +18,7 @@ internal class NonBlittableArrayMarshaller : ConditionalStackallocMarshallingGen

private const string IndexerIdentifier = "__i";

private IMarshallingGenerator _elementMarshaller;
private readonly IMarshallingGenerator _elementMarshaller;
private readonly ExpressionSyntax _numElementsExpr;

public NonBlittableArrayMarshaller(IMarshallingGenerator elementMarshaller, ExpressionSyntax numElementsExpr)
Expand Down Expand Up @@ -64,77 +64,100 @@ public override ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext
public override IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeContext context)
{
var (managedIdentifer, nativeIdentifier) = context.GetIdentifiers(info);
bool cacheManagedValue = ShouldCacheManagedValue(info, context);
string managedLocal = !cacheManagedValue ? managedIdentifer : managedIdentifer + ArrayMarshallingCodeContext.LocalManagedIdentifierSuffix;

switch (context.CurrentStage)
{
case StubCodeContext.Stage.Setup:
if (TryGenerateSetupSyntax(info, context, out StatementSyntax conditionalAllocSetup))
yield return conditionalAllocSetup;

if (cacheManagedValue)
{
yield return LocalDeclarationStatement(
VariableDeclaration(
info.ManagedType.AsTypeSyntax(),
SingletonSeparatedList(
VariableDeclarator(managedLocal)
.WithInitializer(EqualsValueClause(
IdentifierName(managedIdentifer))))));
}
break;
case StubCodeContext.Stage.Marshal:
if (info.RefKind != RefKind.Out)
{
foreach (var statement in GenerateConditionalAllocationSyntax(
info,
context,
StackAllocBytesThreshold))
info,
context,
StackAllocBytesThreshold))
{
yield return statement;
}

// Iterate through the elements of the array to marshal them
var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context);
var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context, appendLocalManagedIdentifierSuffix: cacheManagedValue);
yield return IfStatement(BinaryExpression(SyntaxKind.NotEqualsExpression,
IdentifierName(managedIdentifer),
IdentifierName(managedLocal),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
MarshallerHelpers.GetForLoop(managedIdentifer, IndexerIdentifier)
MarshallerHelpers.GetForLoop(managedLocal, IndexerIdentifier)
.WithStatement(Block(
List(_elementMarshaller.Generate(info with { ManagedType = GetElementTypeSymbol(info) }, arraySubContext)))));
List(_elementMarshaller.Generate(
info with { ManagedType = GetElementTypeSymbol(info) },
arraySubContext)))));
}
break;
case StubCodeContext.Stage.Unmarshal:
if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In))
{
var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context);
var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context, appendLocalManagedIdentifierSuffix: cacheManagedValue);

yield return IfStatement(
BinaryExpression(SyntaxKind.NotEqualsExpression,
IdentifierName(nativeIdentifier),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
Block(
// <managedIdentifier> = new <managedElementType>[<numElementsExpression>];
// <managedLocal> = new <managedElementType>[<numElementsExpression>];
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(managedIdentifer),
IdentifierName(managedLocal),
ArrayCreationExpression(
ArrayType(GetElementTypeSymbol(info).AsTypeSyntax(),
SingletonList(ArrayRankSpecifier(
SingletonSeparatedList(_numElementsExpr))))))),
// Iterate through the elements of the native array to unmarshal them
MarshallerHelpers.GetForLoop(managedIdentifer, IndexerIdentifier)
MarshallerHelpers.GetForLoop(managedLocal, IndexerIdentifier)
.WithStatement(Block(
List(_elementMarshaller.Generate(
info with { ManagedType = GetElementTypeSymbol(info) },
arraySubContext))))),
ElseClause(
ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(managedIdentifer),
IdentifierName(managedLocal),
LiteralExpression(SyntaxKind.NullLiteralExpression)))));

if (cacheManagedValue)
{
yield return ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(managedIdentifer),
IdentifierName(managedLocal))
);
}
}
break;
case StubCodeContext.Stage.Cleanup:
{
var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context);
var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context, appendLocalManagedIdentifierSuffix: cacheManagedValue);
var elementCleanup = List(_elementMarshaller.Generate(info with { ManagedType = GetElementTypeSymbol(info) }, arraySubContext));
if (elementCleanup.Count != 0)
{
// Iterate through the elements of the native array to clean up any unmanaged resources.
yield return IfStatement(
BinaryExpression(SyntaxKind.NotEqualsExpression,
IdentifierName(managedIdentifer),
IdentifierName(managedLocal),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
MarshallerHelpers.GetForLoop(managedIdentifer, IndexerIdentifier)
MarshallerHelpers.GetForLoop(managedLocal, IndexerIdentifier)
.WithStatement(Block(elementCleanup)));
}
yield return GenerateConditionalAllocationFreeSyntax(info, context);
Expand All @@ -143,6 +166,11 @@ public override IEnumerable<StatementSyntax> Generate(TypePositionInfo info, Stu
}
}

private static bool ShouldCacheManagedValue(TypePositionInfo info, StubCodeContext context)
{
return info.IsByRef && context.CanUseAdditionalTemporaryState;
}

public override bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context)
{
return true;
Expand All @@ -163,13 +191,18 @@ protected override ExpressionSyntax GenerateAllocationExpression(TypePositionInf

protected override ExpressionSyntax GenerateByteLengthCalculationExpression(TypePositionInfo info, StubCodeContext context)
{
// sizeof(<nativeElementType>) * <managedIdentifier>.Length
return BinaryExpression(SyntaxKind.MultiplyExpression,
SizeOfExpression(GetNativeElementTypeSyntax(info)),
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(context.GetIdentifiers(info).managed),
IdentifierName("Length")
));
string managedIdentifier = context.GetIdentifiers(info).managed;
if (ShouldCacheManagedValue(info, context))
{
managedIdentifier += ArrayMarshallingCodeContext.LocalManagedIdentifierSuffix;
}
// checked(sizeof(<nativeElementType>) * <managedIdentifier>.Length)
return CheckedExpression(SyntaxKind.CheckedExpression,
BinaryExpression(SyntaxKind.MultiplyExpression,
SizeOfExpression(GetNativeElementTypeSyntax(info)),
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(managedIdentifier),
IdentifierName("Length"))));
}

protected override StatementSyntax GenerateStackallocOnlyValueMarshalling(TypePositionInfo info, StubCodeContext context, SyntaxToken byteLengthIdentifier, SyntaxToken stackAllocPtrIdentifier)
Expand All @@ -191,6 +224,19 @@ protected override ExpressionSyntax GenerateFreeExpression(TypePositionInfo info
ParseTypeName("System.IntPtr"),
IdentifierName(context.GetIdentifiers(info).native))))));
}
}

protected override ExpressionSyntax GenerateNullCheckExpression(TypePositionInfo info, StubCodeContext context)
{
string managedIdentifier = context.GetIdentifiers(info).managed;
if (ShouldCacheManagedValue(info, context))
{
managedIdentifier += ArrayMarshallingCodeContext.LocalManagedIdentifierSuffix;
}

return BinaryExpression(
SyntaxKind.NotEqualsExpression,
IdentifierName(managedIdentifier),
LiteralExpression(SyntaxKind.NullLiteralExpression));
}
}
}

0 comments on commit 5a78d17

Please sign in to comment.