diff --git a/DllImportGenerator/DllImportGenerator/ArrayMarshallingCodeContext.cs b/DllImportGenerator/DllImportGenerator/ArrayMarshallingCodeContext.cs
index 705d5d7b418..30e262e03e2 100644
--- a/DllImportGenerator/DllImportGenerator/ArrayMarshallingCodeContext.cs
+++ b/DllImportGenerator/DllImportGenerator/ArrayMarshallingCodeContext.cs
@@ -12,8 +12,11 @@ namespace Microsoft.Interop
{
internal sealed class ArrayMarshallingCodeContext : StubCodeContext
{
+ public const string LocalManagedIdentifierSuffix = "_local";
+
private readonly string indexerIdentifier;
private readonly StubCodeContext parentContext;
+ private readonly bool appendLocalManagedIdentifierSuffix;
public override bool PinningSupported => false;
@@ -21,11 +24,28 @@ internal sealed class ArrayMarshallingCodeContext : StubCodeContext
public override bool CanUseAdditionalTemporaryState => false;
- public ArrayMarshallingCodeContext(Stage currentStage, string indexerIdentifier, StubCodeContext parentContext)
+ ///
+ /// Create a for marshalling elements of an array.
+ ///
+ /// The current marshalling stage.
+ /// The indexer in the loop to get the element to marshal from the array.
+ /// The parent context.
+ ///
+ /// For array marshalling, we sometimes cache the array in a local to avoid multithreading issues.
+ /// Set this to true to add the to the managed identifier when
+ /// marshalling the array elements to ensure that we use the local copy instead of the managed identifier
+ /// when marshalling elements.
+ ///
+ public ArrayMarshallingCodeContext(
+ Stage currentStage,
+ string indexerIdentifier,
+ StubCodeContext parentContext,
+ bool appendLocalManagedIdentifierSuffix)
{
CurrentStage = currentStage;
this.indexerIdentifier = indexerIdentifier;
this.parentContext = parentContext;
+ this.appendLocalManagedIdentifierSuffix = appendLocalManagedIdentifierSuffix;
}
///
@@ -36,6 +56,10 @@ public ArrayMarshallingCodeContext(Stage currentStage, string indexerIdentifier,
public override (string managed, string native) GetIdentifiers(TypePositionInfo info)
{
var (managed, native) = parentContext.GetIdentifiers(info);
+ if (appendLocalManagedIdentifierSuffix)
+ {
+ managed += LocalManagedIdentifierSuffix;
+ }
return ($"{managed}[{indexerIdentifier}]", $"{native}[{indexerIdentifier}]");
}
diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs
index 31638278ac0..033aa7b7d71 100644
--- a/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs
+++ b/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs
@@ -204,13 +204,13 @@ protected override ExpressionSyntax GenerateAllocationExpression(TypePositionInf
protected override ExpressionSyntax GenerateByteLengthCalculationExpression(TypePositionInfo info, StubCodeContext context)
{
- // sizeof() * .Length
- return BinaryExpression(SyntaxKind.MultiplyExpression,
- SizeOfExpression(GetElementTypeSyntax(info)),
- MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
- IdentifierName(context.GetIdentifiers(info).managed),
- IdentifierName("Length")
- ));
+ // checked(sizeof() * .Length)
+ return CheckedExpression(SyntaxKind.CheckedExpression,
+ BinaryExpression(SyntaxKind.MultiplyExpression,
+ SizeOfExpression(GetElementTypeSyntax(info)),
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName(context.GetIdentifiers(info).managed),
+ IdentifierName("Length"))));
}
protected override StatementSyntax GenerateStackallocOnlyValueMarshalling(TypePositionInfo info, StubCodeContext context, SyntaxToken byteLengthIdentifier, SyntaxToken stackAllocPtrIdentifier)
diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ConditionalStackallocMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ConditionalStackallocMarshallingGenerator.cs
index 7250c4f8e8c..7bc2a943e45 100644
--- a/DllImportGenerator/DllImportGenerator/Marshalling/ConditionalStackallocMarshallingGenerator.cs
+++ b/DllImportGenerator/DllImportGenerator/Marshalling/ConditionalStackallocMarshallingGenerator.cs
@@ -79,9 +79,7 @@ protected IEnumerable GenerateConditionalAllocationSyntax(
IdentifierName(nativeIdentifier),
LiteralExpression(SyntaxKind.NullLiteralExpression)));
yield return IfStatement(
- BinaryExpression(SyntaxKind.NotEqualsExpression,
- IdentifierName(managedIdentifier),
- LiteralExpression(SyntaxKind.NullLiteralExpression)),
+ GenerateNullCheckExpression(info, context),
Block(statements));
yield break;
}
@@ -137,24 +135,15 @@ protected IEnumerable GenerateConditionalAllocationSyntax(
ElseClause(marshalOnStack));
yield return IfStatement(
- BinaryExpression(
- SyntaxKind.EqualsExpression,
- IdentifierName(managedIdentifier),
- LiteralExpression(SyntaxKind.NullLiteralExpression)),
- Block(
- ExpressionStatement(
- AssignmentExpression(
- SyntaxKind.SimpleAssignmentExpression,
- IdentifierName(nativeIdentifier),
- LiteralExpression(SyntaxKind.NullLiteralExpression)))),
- ElseClause(Block(byteLenAssignment, allocBlock)));
+ GenerateNullCheckExpression(info, context),
+ Block(byteLenAssignment, allocBlock));
}
protected StatementSyntax GenerateConditionalAllocationFreeSyntax(
TypePositionInfo info,
StubCodeContext context)
{
- (string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info);
+ (string managedIdentifier, _) = context.GetIdentifiers(info);
string allocationMarkerIdentifier = GetAllocationMarkerIdentifier(managedIdentifier);
if (!UsesConditionalStackAlloc(info, context))
{
@@ -219,6 +208,22 @@ protected abstract StatementSyntax GenerateStackallocOnlyValueMarshalling(
protected abstract ExpressionSyntax GenerateFreeExpression(
TypePositionInfo info,
StubCodeContext context);
+
+ ///
+ /// Generate code to check if the managed value is not null.
+ ///
+ /// Object to marshal
+ /// Code generation context
+ /// An expression that checks if the managed value is not null.
+ protected virtual ExpressionSyntax GenerateNullCheckExpression(
+ TypePositionInfo info,
+ StubCodeContext context)
+ {
+ return BinaryExpression(
+ SyntaxKind.NotEqualsExpression,
+ IdentifierName(context.GetIdentifiers(info).managed),
+ LiteralExpression(SyntaxKind.NullLiteralExpression));
+ }
///
public abstract TypeSyntax AsNativeType(TypePositionInfo info);
diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/NonBlittableArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/NonBlittableArrayMarshaller.cs
index 6fe2811efcb..d4f9beebe40 100644
--- a/DllImportGenerator/DllImportGenerator/Marshalling/NonBlittableArrayMarshaller.cs
+++ b/DllImportGenerator/DllImportGenerator/Marshalling/NonBlittableArrayMarshaller.cs
@@ -1,5 +1,5 @@
using System.Collections.Generic;
-
+using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
@@ -18,7 +18,7 @@ internal class NonBlittableArrayMarshaller : ConditionalStackallocMarshallingGen
private const string IndexerIdentifier = "__i";
- private IMarshallingGenerator _elementMarshaller;
+ private readonly IMarshallingGenerator _elementMarshaller;
private readonly ExpressionSyntax _numElementsExpr;
public NonBlittableArrayMarshaller(IMarshallingGenerator elementMarshaller, ExpressionSyntax numElementsExpr)
@@ -64,6 +64,8 @@ public override ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext
public override IEnumerable Generate(TypePositionInfo info, StubCodeContext context)
{
var (managedIdentifer, nativeIdentifier) = context.GetIdentifiers(info);
+ bool cacheManagedValue = ShouldCacheManagedValue(info, context);
+ string managedLocal = !cacheManagedValue ? managedIdentifer : managedIdentifer + ArrayMarshallingCodeContext.LocalManagedIdentifierSuffix;
switch (context.CurrentStage)
{
@@ -71,70 +73,91 @@ public override IEnumerable Generate(TypePositionInfo info, Stu
if (TryGenerateSetupSyntax(info, context, out StatementSyntax conditionalAllocSetup))
yield return conditionalAllocSetup;
+ if (cacheManagedValue)
+ {
+ yield return LocalDeclarationStatement(
+ VariableDeclaration(
+ info.ManagedType.AsTypeSyntax(),
+ SingletonSeparatedList(
+ VariableDeclarator(managedLocal)
+ .WithInitializer(EqualsValueClause(
+ IdentifierName(managedIdentifer))))));
+ }
break;
case StubCodeContext.Stage.Marshal:
if (info.RefKind != RefKind.Out)
{
foreach (var statement in GenerateConditionalAllocationSyntax(
- info,
- context,
- StackAllocBytesThreshold))
+ info,
+ context,
+ StackAllocBytesThreshold))
{
yield return statement;
}
// Iterate through the elements of the array to marshal them
- var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context);
+ var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context, appendLocalManagedIdentifierSuffix: cacheManagedValue);
yield return IfStatement(BinaryExpression(SyntaxKind.NotEqualsExpression,
- IdentifierName(managedIdentifer),
+ IdentifierName(managedLocal),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
- MarshallerHelpers.GetForLoop(managedIdentifer, IndexerIdentifier)
+ MarshallerHelpers.GetForLoop(managedLocal, IndexerIdentifier)
.WithStatement(Block(
- List(_elementMarshaller.Generate(info with { ManagedType = GetElementTypeSymbol(info) }, arraySubContext)))));
+ List(_elementMarshaller.Generate(
+ info with { ManagedType = GetElementTypeSymbol(info) },
+ arraySubContext)))));
}
break;
case StubCodeContext.Stage.Unmarshal:
if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In))
{
- var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context);
-
+ var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context, appendLocalManagedIdentifierSuffix: cacheManagedValue);
+
yield return IfStatement(
BinaryExpression(SyntaxKind.NotEqualsExpression,
IdentifierName(nativeIdentifier),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
Block(
- // = new [];
+ // = new [];
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
- IdentifierName(managedIdentifer),
+ IdentifierName(managedLocal),
ArrayCreationExpression(
ArrayType(GetElementTypeSymbol(info).AsTypeSyntax(),
SingletonList(ArrayRankSpecifier(
SingletonSeparatedList(_numElementsExpr))))))),
// Iterate through the elements of the native array to unmarshal them
- MarshallerHelpers.GetForLoop(managedIdentifer, IndexerIdentifier)
+ MarshallerHelpers.GetForLoop(managedLocal, IndexerIdentifier)
.WithStatement(Block(
List(_elementMarshaller.Generate(
info with { ManagedType = GetElementTypeSymbol(info) },
arraySubContext))))),
ElseClause(
ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
- IdentifierName(managedIdentifer),
+ IdentifierName(managedLocal),
LiteralExpression(SyntaxKind.NullLiteralExpression)))));
+
+ if (cacheManagedValue)
+ {
+ yield return ExpressionStatement(
+ AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
+ IdentifierName(managedIdentifer),
+ IdentifierName(managedLocal))
+ );
+ }
}
break;
case StubCodeContext.Stage.Cleanup:
{
- var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context);
+ var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context, appendLocalManagedIdentifierSuffix: cacheManagedValue);
var elementCleanup = List(_elementMarshaller.Generate(info with { ManagedType = GetElementTypeSymbol(info) }, arraySubContext));
if (elementCleanup.Count != 0)
{
// Iterate through the elements of the native array to clean up any unmanaged resources.
yield return IfStatement(
BinaryExpression(SyntaxKind.NotEqualsExpression,
- IdentifierName(managedIdentifer),
+ IdentifierName(managedLocal),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
- MarshallerHelpers.GetForLoop(managedIdentifer, IndexerIdentifier)
+ MarshallerHelpers.GetForLoop(managedLocal, IndexerIdentifier)
.WithStatement(Block(elementCleanup)));
}
yield return GenerateConditionalAllocationFreeSyntax(info, context);
@@ -143,6 +166,11 @@ public override IEnumerable Generate(TypePositionInfo info, Stu
}
}
+ private static bool ShouldCacheManagedValue(TypePositionInfo info, StubCodeContext context)
+ {
+ return info.IsByRef && context.CanUseAdditionalTemporaryState;
+ }
+
public override bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context)
{
return true;
@@ -163,13 +191,18 @@ protected override ExpressionSyntax GenerateAllocationExpression(TypePositionInf
protected override ExpressionSyntax GenerateByteLengthCalculationExpression(TypePositionInfo info, StubCodeContext context)
{
- // sizeof() * .Length
- return BinaryExpression(SyntaxKind.MultiplyExpression,
- SizeOfExpression(GetNativeElementTypeSyntax(info)),
- MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
- IdentifierName(context.GetIdentifiers(info).managed),
- IdentifierName("Length")
- ));
+ string managedIdentifier = context.GetIdentifiers(info).managed;
+ if (ShouldCacheManagedValue(info, context))
+ {
+ managedIdentifier += ArrayMarshallingCodeContext.LocalManagedIdentifierSuffix;
+ }
+ // checked(sizeof() * .Length)
+ return CheckedExpression(SyntaxKind.CheckedExpression,
+ BinaryExpression(SyntaxKind.MultiplyExpression,
+ SizeOfExpression(GetNativeElementTypeSyntax(info)),
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName(managedIdentifier),
+ IdentifierName("Length"))));
}
protected override StatementSyntax GenerateStackallocOnlyValueMarshalling(TypePositionInfo info, StubCodeContext context, SyntaxToken byteLengthIdentifier, SyntaxToken stackAllocPtrIdentifier)
@@ -191,6 +224,19 @@ protected override ExpressionSyntax GenerateFreeExpression(TypePositionInfo info
ParseTypeName("System.IntPtr"),
IdentifierName(context.GetIdentifiers(info).native))))));
}
- }
+ protected override ExpressionSyntax GenerateNullCheckExpression(TypePositionInfo info, StubCodeContext context)
+ {
+ string managedIdentifier = context.GetIdentifiers(info).managed;
+ if (ShouldCacheManagedValue(info, context))
+ {
+ managedIdentifier += ArrayMarshallingCodeContext.LocalManagedIdentifierSuffix;
+ }
+
+ return BinaryExpression(
+ SyntaxKind.NotEqualsExpression,
+ IdentifierName(managedIdentifier),
+ LiteralExpression(SyntaxKind.NullLiteralExpression));
+ }
+ }
}