Skip to content

Commit

Permalink
Ensure CA1835 preserves nullability (#4664)
Browse files Browse the repository at this point in the history
* Ensure CA1835 preserves nullability

* Remove spacing from VB file.

* Add extra nullability sub-cases in new unit tests.

* Remove extra space causing CI failure.

* Add nullability tests with CancellationToken

Co-authored-by: carlossanlop <carlossanlop@users.noreply.github.com>
  • Loading branch information
carlossanlop and carlossanlop authored Jan 20, 2021
1 parent c480509 commit 33582d0
Show file tree
Hide file tree
Showing 6 changed files with 367 additions and 93 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.NetCore.Analyzers.Runtime;

Expand All @@ -14,30 +14,44 @@ namespace Microsoft.NetCore.CSharp.Analyzers.Runtime
[ExportCodeFixProvider(LanguageNames.CSharp)]
public sealed class CSharpPreferStreamAsyncMemoryOverloadsFixer : PreferStreamAsyncMemoryOverloadsFixer
{
protected override IArgumentOperation? GetArgumentByPositionOrName(ImmutableArray<IArgumentOperation> args, int index, string name, out bool isNamed)
protected override SyntaxNode? GetArgumentByPositionOrName(IInvocationOperation invocation, int index, string name, out bool isNamed)
{
isNamed = false;

// The expected position is beyond the total arguments, so we don't expect to find the argument in the array
if (index >= args.Length)
if (index < invocation.Arguments.Length &&
invocation.Syntax is InvocationExpressionSyntax expression)
{
return null;
}
// If the argument in the specified index does not have a name, then it is in its expected position
else if (args[index].Syntax is ArgumentSyntax argNode && argNode.NameColon == null)
{
return args[index];
}
// Otherwise, find it by name
else
{
isNamed = true;
return args.FirstOrDefault(argOperation =>
var args = invocation.Arguments;
// If the argument in the specified index does not have a name, then it is in its expected position
if (args[index].Syntax is ArgumentSyntax argNode && argNode.NameColon == null)
{
return args[index].Syntax;
}
// The argument in the specified index does not have a name but is part of a nullable expression
else if (args[index].Syntax is IdentifierNameSyntax identifierNameNode &&
identifierNameNode.Identifier.Value.Equals(name) &&
identifierNameNode.Parent is PostfixUnaryExpressionSyntax nullableExpression)
{
return nullableExpression;
}
// Otherwise, find it by name
else
{
return argOperation.Syntax is ArgumentSyntax argNode &&
argNode.NameColon?.Name?.Identifier.ValueText == name;
});
IArgumentOperation? operation = args.FirstOrDefault(argOperation =>
{
return argOperation.Syntax is ArgumentSyntax argNode &&
argNode.NameColon?.Name?.Identifier.ValueText == name;
});

if (operation != null)
{
isNamed = true;
return operation.Syntax;
}
}
}

return null;
}

protected override bool IsSystemNamespaceImported(IReadOnlyList<SyntaxNode> importList)
Expand All @@ -54,18 +68,70 @@ protected override bool IsSystemNamespaceImported(IReadOnlyList<SyntaxNode> impo

protected override bool IsPassingZeroAndBufferLength(SemanticModel model, SyntaxNode bufferValueNode, SyntaxNode offsetValueNode, SyntaxNode countValueNode)
{
return
// First argument should be an identifier name node
bufferValueNode is IdentifierNameSyntax firstArgumentIdentifierName &&
// First argument should be an identifier name node
if (bufferValueNode is ArgumentSyntax arg1 &&
arg1.Expression is IdentifierNameSyntax firstArgumentIdentifierName)
{
// Second argument should be a literal expression node with a constant value of zero
model.GetConstantValue(offsetValueNode) is Optional<object> optionalValue && optionalValue.HasValue && optionalValue.Value is 0 &&
// Third argument should be a member access node...
countValueNode is MemberAccessExpressionSyntax thirdArgumentMemberAccessExpression &&
thirdArgumentMemberAccessExpression.Expression is IdentifierNameSyntax thirdArgumentIdentifierName &&
// whose identifier is that of the first argument...
firstArgumentIdentifierName.Identifier.ValueText == thirdArgumentIdentifierName.Identifier.ValueText &&
// and the member name is `Length`
thirdArgumentMemberAccessExpression.Name.Identifier.ValueText == WellKnownMemberNames.LengthPropertyName;
if (offsetValueNode is ArgumentSyntax arg2 &&
arg2.Expression is LiteralExpressionSyntax literal &&
literal.Token.Value is int value && value == 0)
{
// Third argument should be a member access node...
if (countValueNode is ArgumentSyntax arg3 &&
arg3.Expression is MemberAccessExpressionSyntax thirdArgumentMemberAccessExpression &&
thirdArgumentMemberAccessExpression.Expression is IdentifierNameSyntax thirdArgumentIdentifierName &&
// whose identifier is that of the first argument...
firstArgumentIdentifierName.Identifier.ValueText == thirdArgumentIdentifierName.Identifier.ValueText &&
// and the member name is `Length`
thirdArgumentMemberAccessExpression.Name.Identifier.ValueText == WellKnownMemberNames.LengthPropertyName)
{
return true;
}
}
}
return false;
}

protected override SyntaxNode GetNodeWithNullability(IInvocationOperation invocation)
{
if (invocation.Syntax is InvocationExpressionSyntax invocationExpression &&
invocationExpression.Expression is MemberAccessExpressionSyntax memberAccessExpression &&
memberAccessExpression.Expression is PostfixUnaryExpressionSyntax postfixUnaryExpression)
{
return postfixUnaryExpression;
}

return invocation.Instance.Syntax;
}

protected override SyntaxNode GetNamedArgument(SyntaxGenerator generator, SyntaxNode node, bool isNamed, string newName)
{
if (isNamed)
{
SyntaxNode actualNode = node;

if (node is ArgumentSyntax argument)
{
actualNode = argument.Expression;
}

return generator.Argument(name: newName, RefKind.None, actualNode);
}

return node;
}

protected override SyntaxNode GetNamedMemberInvocation(SyntaxGenerator generator, SyntaxNode node, string memberName)
{
SyntaxNode actualNode = node;

if (node is ArgumentSyntax argument)
{
actualNode = argument.Expression;
}

return generator.MemberAccessExpression(actualNode.WithoutTrivia(), memberName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace Microsoft.NetCore.Analyzers.Runtime
public abstract class PreferStreamAsyncMemoryOverloadsFixer : CodeFixProvider
{
// Checks if the argument in the specified index has a name. If it doesn't, returns that arguments. If it does, then looks for the argument using the specified name, and returns it, or null if not found.
protected abstract IArgumentOperation? GetArgumentByPositionOrName(ImmutableArray<IArgumentOperation> args, int index, string name, out bool isNamed);
protected abstract SyntaxNode? GetArgumentByPositionOrName(IInvocationOperation invocation, int index, string name, out bool isNamed);

// Verifies if a namespace has already been added to the usings/imports list.
protected abstract bool IsSystemNamespaceImported(IReadOnlyList<SyntaxNode> importList);
Expand All @@ -41,6 +41,15 @@ public abstract class PreferStreamAsyncMemoryOverloadsFixer : CodeFixProvider
// where `buffer` is the name of the variable passed as the 0th argument.
protected abstract bool IsPassingZeroAndBufferLength(SemanticModel model, SyntaxNode bufferValueNode, SyntaxNode offsetValueNode, SyntaxNode countValueNode);

// Ensures the invocation node is returned with nullability.
protected abstract SyntaxNode GetNodeWithNullability(IInvocationOperation invocation);

// Ensures the argument is retrieved with the name and nullability.
protected abstract SyntaxNode GetNamedArgument(SyntaxGenerator generator, SyntaxNode node, bool isNamed, string newName);

// Ensures the member invocation is retrieved with the name and nullability.
protected abstract SyntaxNode GetNamedMemberInvocation(SyntaxGenerator generator, SyntaxNode node, string memberName);

public sealed override ImmutableArray<string> FixableDiagnosticIds =>
ImmutableArray.Create(PreferStreamAsyncMemoryOverloads.RuleId);

Expand Down Expand Up @@ -71,35 +80,35 @@ public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
return;
}

IArgumentOperation? bufferOperation = GetArgumentByPositionOrName(invocation.Arguments, 0, "buffer", out bool isBufferNamed);
if (bufferOperation == null)
SyntaxNode? bufferNode = GetArgumentByPositionOrName(invocation, 0, "buffer", out bool isBufferNamed);
if (bufferNode == null)
{
return;
}

IArgumentOperation? offsetOperation = GetArgumentByPositionOrName(invocation.Arguments, 1, "offset", out bool isOffsetNamed);
if (offsetOperation == null)
SyntaxNode? offsetNode = GetArgumentByPositionOrName(invocation, 1, "offset", out bool isOffsetNamed);
if (offsetNode == null)
{
return;
}

IArgumentOperation? countOperation = GetArgumentByPositionOrName(invocation.Arguments, 2, "count", out bool isCountNamed);
if (countOperation == null)
SyntaxNode? countNode = GetArgumentByPositionOrName(invocation, 2, "count", out bool isCountNamed);
if (countNode == null)
{
return;
}

// No nullcheck for this, because there is an overload that may not contain it
IArgumentOperation? cancellationTokenOperation = GetArgumentByPositionOrName(invocation.Arguments, 3, "cancellationToken", out bool isCancellationTokenNamed);
SyntaxNode? cancellationTokenNode = GetArgumentByPositionOrName(invocation, 3, "cancellationToken", out bool isCancellationTokenNamed);

string title = MicrosoftNetCoreAnalyzersResources.PreferStreamAsyncMemoryOverloadsTitle;

Task<Document> createChangedDocument(CancellationToken _) => FixInvocation(model, doc, root,
invocation, invocation.TargetMethod.Name,
bufferOperation.Value.Syntax, isBufferNamed,
offsetOperation.Value.Syntax, isOffsetNamed,
countOperation.Value.Syntax, isCountNamed,
cancellationTokenOperation?.Value.Syntax, isCancellationTokenNamed);
bufferNode, isBufferNamed,
offsetNode, isOffsetNamed,
countNode, isCountNamed,
cancellationTokenNode, isCancellationTokenNamed);

context.RegisterCodeFix(
new MyCodeAction(
Expand All @@ -111,56 +120,55 @@ Task<Document> createChangedDocument(CancellationToken _) => FixInvocation(model

private Task<Document> FixInvocation(SemanticModel model, Document doc, SyntaxNode root,
IInvocationOperation invocation, string methodName,
SyntaxNode bufferValueNode, bool isBufferNamed,
SyntaxNode offsetValueNode, bool isOffsetNamed,
SyntaxNode countValueNode, bool isCountNamed,
SyntaxNode? cancellationTokenValueNode, bool isCancellationTokenNamed)
SyntaxNode bufferNode, bool isBufferNamed,
SyntaxNode offsetNode, bool isOffsetNamed,
SyntaxNode countNode, bool isCountNamed,
SyntaxNode? cancellationTokenNode, bool isCancellationTokenNamed)
{
SyntaxGenerator generator = SyntaxGenerator.GetGenerator(doc);

// The stream-derived instance
SyntaxNode streamInstanceNode = invocation.Instance.Syntax;
SyntaxNode streamInstanceNode = GetNodeWithNullability(invocation);

// Depending on the arguments being passed to Read/WriteAsync, it's the substitution we will make
SyntaxNode replacedInvocationNode;

if (IsPassingZeroAndBufferLength(model, bufferValueNode, offsetValueNode, countValueNode))
if (IsPassingZeroAndBufferLength(model, bufferNode, offsetNode, countNode))
{
// Remove 0 and buffer.length
replacedInvocationNode =
(isBufferNamed ? generator.Argument(name: "buffer", RefKind.None, bufferValueNode) : bufferValueNode)
.WithTriviaFrom(bufferValueNode);
GetNamedArgument(generator, bufferNode, isBufferNamed, "buffer")
.WithTriviaFrom(bufferNode);
}
else
{
// buffer.AsMemory(int start, int length)
// offset should become start
// count should become length
SyntaxNode namedStartNode = isOffsetNamed ? generator.Argument(name: "start", RefKind.None, offsetValueNode) : offsetValueNode;
SyntaxNode namedLengthNode = isCountNamed ? generator.Argument(name: "length", RefKind.None, countValueNode) : countValueNode;
SyntaxNode namedStartNode = GetNamedArgument(generator, offsetNode, isOffsetNamed, "start");
SyntaxNode namedLengthNode = GetNamedArgument(generator, countNode, isCountNamed, "length");

// Generate an invocation of the AsMemory() method from the byte array object, using the correct named arguments
SyntaxNode asMemoryExpressionNode = generator.MemberAccessExpression(bufferValueNode.WithoutTrivia(), memberName: "AsMemory");
SyntaxNode asMemoryExpressionNode = GetNamedMemberInvocation(generator, bufferNode, "AsMemory");
SyntaxNode asMemoryInvocationNode = generator.InvocationExpression(
asMemoryExpressionNode,
namedStartNode.WithTriviaFrom(offsetValueNode),
namedLengthNode.WithTriviaFrom(countValueNode));
namedStartNode.WithTriviaFrom(offsetNode),
namedLengthNode.WithTriviaFrom(countNode));

// Generate the new buffer argument, ensuring we include the buffer argument name if the user originally indicated one
replacedInvocationNode =
(isBufferNamed ? generator.Argument(name: "buffer", RefKind.None, asMemoryInvocationNode) : asMemoryInvocationNode)
.WithTriviaFrom(bufferValueNode);
replacedInvocationNode = GetNamedArgument(generator, asMemoryInvocationNode, isBufferNamed, "buffer")
.WithTriviaFrom(bufferNode);
}

// Create an async method call for the stream object with no arguments
SyntaxNode asyncMethodNode = generator.MemberAccessExpression(streamInstanceNode, methodName);

// Add the arguments to the async method call, with or without CancellationToken
SyntaxNode[] nodeArguments;
if (cancellationTokenValueNode != null)
if (cancellationTokenNode != null)
{
SyntaxNode namedCancellationTokenNode = isCancellationTokenNamed ? generator.Argument(name: "cancellationToken", RefKind.None, cancellationTokenValueNode) : cancellationTokenValueNode;
nodeArguments = new SyntaxNode[] { replacedInvocationNode, namedCancellationTokenNode.WithTriviaFrom(cancellationTokenValueNode) };
SyntaxNode namedCancellationTokenNode = GetNamedArgument(generator, cancellationTokenNode, isCancellationTokenNamed, "cancellationToken");
nodeArguments = new SyntaxNode[] { replacedInvocationNode, namedCancellationTokenNode.WithTriviaFrom(cancellationTokenNode) };
}
else
{
Expand Down
Loading

0 comments on commit 33582d0

Please sign in to comment.