diff --git a/DllImportGenerator/DllImportGenerator/ArrayMarshallingCodeContext.cs b/DllImportGenerator/DllImportGenerator/ArrayMarshallingCodeContext.cs index 705d5d7b418..30e262e03e2 100644 --- a/DllImportGenerator/DllImportGenerator/ArrayMarshallingCodeContext.cs +++ b/DllImportGenerator/DllImportGenerator/ArrayMarshallingCodeContext.cs @@ -12,8 +12,11 @@ namespace Microsoft.Interop { internal sealed class ArrayMarshallingCodeContext : StubCodeContext { + public const string LocalManagedIdentifierSuffix = "_local"; + private readonly string indexerIdentifier; private readonly StubCodeContext parentContext; + private readonly bool appendLocalManagedIdentifierSuffix; public override bool PinningSupported => false; @@ -21,11 +24,28 @@ internal sealed class ArrayMarshallingCodeContext : StubCodeContext public override bool CanUseAdditionalTemporaryState => false; - public ArrayMarshallingCodeContext(Stage currentStage, string indexerIdentifier, StubCodeContext parentContext) + /// + /// Create a for marshalling elements of an array. + /// + /// The current marshalling stage. + /// The indexer in the loop to get the element to marshal from the array. + /// The parent context. + /// + /// For array marshalling, we sometimes cache the array in a local to avoid multithreading issues. + /// Set this to true to add the to the managed identifier when + /// marshalling the array elements to ensure that we use the local copy instead of the managed identifier + /// when marshalling elements. + /// + public ArrayMarshallingCodeContext( + Stage currentStage, + string indexerIdentifier, + StubCodeContext parentContext, + bool appendLocalManagedIdentifierSuffix) { CurrentStage = currentStage; this.indexerIdentifier = indexerIdentifier; this.parentContext = parentContext; + this.appendLocalManagedIdentifierSuffix = appendLocalManagedIdentifierSuffix; } /// @@ -36,6 +56,10 @@ public ArrayMarshallingCodeContext(Stage currentStage, string indexerIdentifier, public override (string managed, string native) GetIdentifiers(TypePositionInfo info) { var (managed, native) = parentContext.GetIdentifiers(info); + if (appendLocalManagedIdentifierSuffix) + { + managed += LocalManagedIdentifierSuffix; + } return ($"{managed}[{indexerIdentifier}]", $"{native}[{indexerIdentifier}]"); } diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs index 31638278ac0..033aa7b7d71 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs @@ -204,13 +204,13 @@ protected override ExpressionSyntax GenerateAllocationExpression(TypePositionInf protected override ExpressionSyntax GenerateByteLengthCalculationExpression(TypePositionInfo info, StubCodeContext context) { - // sizeof() * .Length - return BinaryExpression(SyntaxKind.MultiplyExpression, - SizeOfExpression(GetElementTypeSyntax(info)), - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(context.GetIdentifiers(info).managed), - IdentifierName("Length") - )); + // checked(sizeof() * .Length) + return CheckedExpression(SyntaxKind.CheckedExpression, + BinaryExpression(SyntaxKind.MultiplyExpression, + SizeOfExpression(GetElementTypeSyntax(info)), + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(context.GetIdentifiers(info).managed), + IdentifierName("Length")))); } protected override StatementSyntax GenerateStackallocOnlyValueMarshalling(TypePositionInfo info, StubCodeContext context, SyntaxToken byteLengthIdentifier, SyntaxToken stackAllocPtrIdentifier) diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ConditionalStackallocMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ConditionalStackallocMarshallingGenerator.cs index 7250c4f8e8c..7bc2a943e45 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ConditionalStackallocMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ConditionalStackallocMarshallingGenerator.cs @@ -79,9 +79,7 @@ protected IEnumerable GenerateConditionalAllocationSyntax( IdentifierName(nativeIdentifier), LiteralExpression(SyntaxKind.NullLiteralExpression))); yield return IfStatement( - BinaryExpression(SyntaxKind.NotEqualsExpression, - IdentifierName(managedIdentifier), - LiteralExpression(SyntaxKind.NullLiteralExpression)), + GenerateNullCheckExpression(info, context), Block(statements)); yield break; } @@ -137,24 +135,15 @@ protected IEnumerable GenerateConditionalAllocationSyntax( ElseClause(marshalOnStack)); yield return IfStatement( - BinaryExpression( - SyntaxKind.EqualsExpression, - IdentifierName(managedIdentifier), - LiteralExpression(SyntaxKind.NullLiteralExpression)), - Block( - ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(nativeIdentifier), - LiteralExpression(SyntaxKind.NullLiteralExpression)))), - ElseClause(Block(byteLenAssignment, allocBlock))); + GenerateNullCheckExpression(info, context), + Block(byteLenAssignment, allocBlock)); } protected StatementSyntax GenerateConditionalAllocationFreeSyntax( TypePositionInfo info, StubCodeContext context) { - (string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info); + (string managedIdentifier, _) = context.GetIdentifiers(info); string allocationMarkerIdentifier = GetAllocationMarkerIdentifier(managedIdentifier); if (!UsesConditionalStackAlloc(info, context)) { @@ -219,6 +208,22 @@ protected abstract StatementSyntax GenerateStackallocOnlyValueMarshalling( protected abstract ExpressionSyntax GenerateFreeExpression( TypePositionInfo info, StubCodeContext context); + + /// + /// Generate code to check if the managed value is not null. + /// + /// Object to marshal + /// Code generation context + /// An expression that checks if the managed value is not null. + protected virtual ExpressionSyntax GenerateNullCheckExpression( + TypePositionInfo info, + StubCodeContext context) + { + return BinaryExpression( + SyntaxKind.NotEqualsExpression, + IdentifierName(context.GetIdentifiers(info).managed), + LiteralExpression(SyntaxKind.NullLiteralExpression)); + } /// public abstract TypeSyntax AsNativeType(TypePositionInfo info); diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/NonBlittableArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/NonBlittableArrayMarshaller.cs index 6fe2811efcb..d4f9beebe40 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/NonBlittableArrayMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/NonBlittableArrayMarshaller.cs @@ -1,5 +1,5 @@ using System.Collections.Generic; - +using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -18,7 +18,7 @@ internal class NonBlittableArrayMarshaller : ConditionalStackallocMarshallingGen private const string IndexerIdentifier = "__i"; - private IMarshallingGenerator _elementMarshaller; + private readonly IMarshallingGenerator _elementMarshaller; private readonly ExpressionSyntax _numElementsExpr; public NonBlittableArrayMarshaller(IMarshallingGenerator elementMarshaller, ExpressionSyntax numElementsExpr) @@ -64,6 +64,8 @@ public override ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext public override IEnumerable Generate(TypePositionInfo info, StubCodeContext context) { var (managedIdentifer, nativeIdentifier) = context.GetIdentifiers(info); + bool cacheManagedValue = ShouldCacheManagedValue(info, context); + string managedLocal = !cacheManagedValue ? managedIdentifer : managedIdentifer + ArrayMarshallingCodeContext.LocalManagedIdentifierSuffix; switch (context.CurrentStage) { @@ -71,70 +73,91 @@ public override IEnumerable Generate(TypePositionInfo info, Stu if (TryGenerateSetupSyntax(info, context, out StatementSyntax conditionalAllocSetup)) yield return conditionalAllocSetup; + if (cacheManagedValue) + { + yield return LocalDeclarationStatement( + VariableDeclaration( + info.ManagedType.AsTypeSyntax(), + SingletonSeparatedList( + VariableDeclarator(managedLocal) + .WithInitializer(EqualsValueClause( + IdentifierName(managedIdentifer)))))); + } break; case StubCodeContext.Stage.Marshal: if (info.RefKind != RefKind.Out) { foreach (var statement in GenerateConditionalAllocationSyntax( - info, - context, - StackAllocBytesThreshold)) + info, + context, + StackAllocBytesThreshold)) { yield return statement; } // Iterate through the elements of the array to marshal them - var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context); + var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context, appendLocalManagedIdentifierSuffix: cacheManagedValue); yield return IfStatement(BinaryExpression(SyntaxKind.NotEqualsExpression, - IdentifierName(managedIdentifer), + IdentifierName(managedLocal), LiteralExpression(SyntaxKind.NullLiteralExpression)), - MarshallerHelpers.GetForLoop(managedIdentifer, IndexerIdentifier) + MarshallerHelpers.GetForLoop(managedLocal, IndexerIdentifier) .WithStatement(Block( - List(_elementMarshaller.Generate(info with { ManagedType = GetElementTypeSymbol(info) }, arraySubContext))))); + List(_elementMarshaller.Generate( + info with { ManagedType = GetElementTypeSymbol(info) }, + arraySubContext))))); } break; case StubCodeContext.Stage.Unmarshal: if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) { - var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context); - + var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context, appendLocalManagedIdentifierSuffix: cacheManagedValue); + yield return IfStatement( BinaryExpression(SyntaxKind.NotEqualsExpression, IdentifierName(nativeIdentifier), LiteralExpression(SyntaxKind.NullLiteralExpression)), Block( - // = new []; + // = new []; ExpressionStatement( AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(managedIdentifer), + IdentifierName(managedLocal), ArrayCreationExpression( ArrayType(GetElementTypeSymbol(info).AsTypeSyntax(), SingletonList(ArrayRankSpecifier( SingletonSeparatedList(_numElementsExpr))))))), // Iterate through the elements of the native array to unmarshal them - MarshallerHelpers.GetForLoop(managedIdentifer, IndexerIdentifier) + MarshallerHelpers.GetForLoop(managedLocal, IndexerIdentifier) .WithStatement(Block( List(_elementMarshaller.Generate( info with { ManagedType = GetElementTypeSymbol(info) }, arraySubContext))))), ElseClause( ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(managedIdentifer), + IdentifierName(managedLocal), LiteralExpression(SyntaxKind.NullLiteralExpression))))); + + if (cacheManagedValue) + { + yield return ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(managedIdentifer), + IdentifierName(managedLocal)) + ); + } } break; case StubCodeContext.Stage.Cleanup: { - var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context); + var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context, appendLocalManagedIdentifierSuffix: cacheManagedValue); var elementCleanup = List(_elementMarshaller.Generate(info with { ManagedType = GetElementTypeSymbol(info) }, arraySubContext)); if (elementCleanup.Count != 0) { // Iterate through the elements of the native array to clean up any unmanaged resources. yield return IfStatement( BinaryExpression(SyntaxKind.NotEqualsExpression, - IdentifierName(managedIdentifer), + IdentifierName(managedLocal), LiteralExpression(SyntaxKind.NullLiteralExpression)), - MarshallerHelpers.GetForLoop(managedIdentifer, IndexerIdentifier) + MarshallerHelpers.GetForLoop(managedLocal, IndexerIdentifier) .WithStatement(Block(elementCleanup))); } yield return GenerateConditionalAllocationFreeSyntax(info, context); @@ -143,6 +166,11 @@ public override IEnumerable Generate(TypePositionInfo info, Stu } } + private static bool ShouldCacheManagedValue(TypePositionInfo info, StubCodeContext context) + { + return info.IsByRef && context.CanUseAdditionalTemporaryState; + } + public override bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) { return true; @@ -163,13 +191,18 @@ protected override ExpressionSyntax GenerateAllocationExpression(TypePositionInf protected override ExpressionSyntax GenerateByteLengthCalculationExpression(TypePositionInfo info, StubCodeContext context) { - // sizeof() * .Length - return BinaryExpression(SyntaxKind.MultiplyExpression, - SizeOfExpression(GetNativeElementTypeSyntax(info)), - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(context.GetIdentifiers(info).managed), - IdentifierName("Length") - )); + string managedIdentifier = context.GetIdentifiers(info).managed; + if (ShouldCacheManagedValue(info, context)) + { + managedIdentifier += ArrayMarshallingCodeContext.LocalManagedIdentifierSuffix; + } + // checked(sizeof() * .Length) + return CheckedExpression(SyntaxKind.CheckedExpression, + BinaryExpression(SyntaxKind.MultiplyExpression, + SizeOfExpression(GetNativeElementTypeSyntax(info)), + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(managedIdentifier), + IdentifierName("Length")))); } protected override StatementSyntax GenerateStackallocOnlyValueMarshalling(TypePositionInfo info, StubCodeContext context, SyntaxToken byteLengthIdentifier, SyntaxToken stackAllocPtrIdentifier) @@ -191,6 +224,19 @@ protected override ExpressionSyntax GenerateFreeExpression(TypePositionInfo info ParseTypeName("System.IntPtr"), IdentifierName(context.GetIdentifiers(info).native)))))); } - } + protected override ExpressionSyntax GenerateNullCheckExpression(TypePositionInfo info, StubCodeContext context) + { + string managedIdentifier = context.GetIdentifiers(info).managed; + if (ShouldCacheManagedValue(info, context)) + { + managedIdentifier += ArrayMarshallingCodeContext.LocalManagedIdentifierSuffix; + } + + return BinaryExpression( + SyntaxKind.NotEqualsExpression, + IdentifierName(managedIdentifier), + LiteralExpression(SyntaxKind.NullLiteralExpression)); + } + } }