Skip to content

Commit

Permalink
Add Fixer for xUnit1051 using TestContext.Current.CancellationToken
Browse files Browse the repository at this point in the history
  • Loading branch information
campersau committed Aug 25, 2024
1 parent b107296 commit 25e4a56
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 6 deletions.
107 changes: 107 additions & 0 deletions src/xunit.analyzers.fixes/X1000/UseCancellationTokenFixer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
using System.Collections.Generic;
using System.Composition;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace Xunit.Analyzers.Fixes;

[ExportCodeFixProvider(LanguageNames.CSharp), Shared]
public class UseCancellationTokenFixer : BatchedCodeFixProvider
{
public const string Key_UseCancellationTokenArgument = "xUnit1051_UseCancellationTokenArgument";

public UseCancellationTokenFixer() :
base(Descriptors.X1051_UseCancellationToken.Id)
{ }

public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
var semanticModel = await context.Document.GetSemanticModelAsync(context.CancellationToken).ConfigureAwait(false);
if (semanticModel is null)
return;

var testContextType = TypeSymbolFactory.TestContext_V3(semanticModel.Compilation);
if (testContextType is null)
return;

var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);
if (root is null)
return;

var diagnostic = context.Diagnostics.FirstOrDefault();
if (diagnostic is null)
return;

if (!diagnostic.Properties.TryGetValue(Constants.Properties.ParameterName, out var parameterName))
return;
if (parameterName is null)
return;

if (!diagnostic.Properties.TryGetValue(Constants.Properties.ParameterIndex, out var parameterIndexText))
return;
if (!int.TryParse(parameterIndexText, out var parameterIndex))
return;

if (root.FindNode(diagnostic.Location.SourceSpan) is not InvocationExpressionSyntax invocation)
return;

var arguments = new List<ArgumentSyntax>(invocation.ArgumentList.Arguments);

for (var argumentIndex = 0; argumentIndex < arguments.Count; argumentIndex++)
{
if (arguments[argumentIndex].NameColon?.Name.Identifier.Text == parameterName)
{
parameterIndex = argumentIndex;
break;
}
}

context.RegisterCodeFix(
XunitCodeAction.Create(
async ct =>
{
var editor = await DocumentEditor.CreateAsync(context.Document, ct).ConfigureAwait(false);
var testContextCancellationTokenExpression = (ExpressionSyntax)editor.Generator.MemberAccessExpression(
editor.Generator.MemberAccessExpression(
editor.Generator.TypeExpression(testContextType),
"Current"
),
"CancellationToken"
);
if (parameterIndex < arguments.Count)
{
arguments[parameterIndex] = arguments[parameterIndex].WithExpression(testContextCancellationTokenExpression);
}
else
{
var argument = Argument(testContextCancellationTokenExpression);
if (parameterIndex > arguments.Count || arguments.Any(arg => arg.NameColon is not null))
{
argument = argument.WithNameColon(NameColon(parameterName));
}
arguments.Add(argument);
}
editor.ReplaceNode(
invocation,
invocation
.WithArgumentList(ArgumentList(SeparatedList(arguments)))
);
return editor.GetChangedDocument();
},
Key_UseCancellationTokenArgument,
"{0} TestContext.Current.CancellationToken", parameterIndex < arguments.Count ? "Use" : "Add"
),
context.Diagnostics
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
using System.Threading.Tasks;
using Xunit;
using Xunit.Analyzers.Fixes;
using Verify = CSharpVerifier<Xunit.Analyzers.UseCancellationToken>;

public class UseCancellationTokenFixerTests
{
[Fact]
public async Task UseCancellationTokenArgument()
{
var before = /* lang=c#-test */ """
using System.Threading;
using System.Threading.Tasks;
using Xunit;

public class TestClass {
[Fact]
public void TestMethod()
{
[|FunctionWithOverload(42)|];
[|FunctionWithOverload(42, default(CancellationToken))|];

[|FunctionWithDefaults()|];
[|FunctionWithDefaults(42)|];
[|FunctionWithDefaults(cancellationToken: default(CancellationToken))|];
[|FunctionWithDefaults(42, cancellationToken: default(CancellationToken))|];
}

void FunctionWithOverload(int _) { }
void FunctionWithOverload(int _1, CancellationToken _2) { }

void FunctionWithDefaults(int _1 = 2112, CancellationToken cancellationToken = default(CancellationToken)) { }
}
""";
var after = /* lang=c#-test */ """
using System.Threading;
using System.Threading.Tasks;
using Xunit;

public class TestClass {
[Fact]
public void TestMethod()
{
FunctionWithOverload(42, TestContext.Current.CancellationToken);
FunctionWithOverload(42, TestContext.Current.CancellationToken);

FunctionWithDefaults(cancellationToken: TestContext.Current.CancellationToken);
FunctionWithDefaults(42, TestContext.Current.CancellationToken);
FunctionWithDefaults(cancellationToken: TestContext.Current.CancellationToken);
FunctionWithDefaults(42, cancellationToken: TestContext.Current.CancellationToken);
}

void FunctionWithOverload(int _) { }
void FunctionWithOverload(int _1, CancellationToken _2) { }

void FunctionWithDefaults(int _1 = 2112, CancellationToken cancellationToken = default(CancellationToken)) { }
}
""";

await Verify.VerifyCodeFixV3(before, after, UseCancellationTokenFixer.Key_UseCancellationTokenArgument);
}

[Fact]
public async Task UseCancellationTokenArgument_AliasTestContext()
{
var before = /* lang=c#-test */ """
using System.Threading;
using System.Threading.Tasks;
using MyContext = Xunit.TestContext;

public class TestClass {
[Xunit.Fact]
public void TestMethod()
{
[|Function()|];
}

void Function(CancellationToken token = default(CancellationToken)) { }
}
""";
var after = /* lang=c#-test */ """
using System.Threading;
using System.Threading.Tasks;
using MyContext = Xunit.TestContext;

public class TestClass {
[Xunit.Fact]
public void TestMethod()
{
Function(MyContext.Current.CancellationToken);
}

void Function(CancellationToken token = default(CancellationToken)) { }
}
""";

await Verify.VerifyCodeFixV3(before, after, UseCancellationTokenFixer.Key_UseCancellationTokenArgument);
}
}
1 change: 1 addition & 0 deletions src/xunit.analyzers/Utility/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ public static class Xunit
public const string LongLivedMarshalByRefObject_Execution_V2 = "Xunit.LongLivedMarshalByRefObject";
public const string LongLivedMarshalByRefObject_RunnerUtility = "Xunit.Sdk.LongLivedMarshalByRefObject";
public const string MemberDataAttribute = "Xunit.MemberDataAttribute";
public const string TestContext_V3 = "Xunit.TestContext";
public const string TheoryAttribute = "Xunit.TheoryAttribute";
public const string TheoryData = "Xunit.TheoryData";
public const string TheoryDataRow_V3 = "Xunit.TheoryDataRow";
Expand Down
3 changes: 3 additions & 0 deletions src/xunit.analyzers/Utility/TypeSymbolFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ public static INamedTypeSymbol String(Compilation compilation) =>
public static INamedTypeSymbol? TaskOfT(Compilation compilation) =>
Guard.ArgumentNotNull(compilation).GetTypeByMetadataName("System.Threading.Tasks.Task`1");

public static INamedTypeSymbol? TestContext_V3(Compilation compilation) =>
Guard.ArgumentNotNull(compilation).GetTypeByMetadataName(Constants.Types.Xunit.TestContext_V3);

public static INamedTypeSymbol? TheoryAttribute(Compilation compilation) =>
Guard.ArgumentNotNull(compilation).GetTypeByMetadataName(Constants.Types.Xunit.TheoryAttribute);

Expand Down
22 changes: 16 additions & 6 deletions src/xunit.analyzers/X1000/UseCancellationToken.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using System.Collections.Immutable;
using System.Globalization;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand Down Expand Up @@ -49,14 +51,14 @@ public override void AnalyzeCompilation(
{
// Default parameter value
if (argument.ArgumentKind == ArgumentKind.DefaultValue)
Report(context, invocationOperation.Syntax.GetLocation());
Report(context, invocationOperation.Syntax.GetLocation(), argument.Parameter!);
// Explicit parameter value
else if (argument.Syntax is ArgumentSyntax argumentSyntax)
{
var kind = argumentSyntax.Expression.Kind();
if (kind == SyntaxKind.DefaultExpression || kind == SyntaxKind.DefaultLiteralExpression)
Report(context, invocationOperation.Syntax.GetLocation());
Report(context, invocationOperation.Syntax.GetLocation(), argument.Parameter!);
}
}
// Look for an overload with the exact same parameter types + a CancellationToken
Expand All @@ -80,7 +82,7 @@ public override void AnalyzeCompilation(
if (match)
{
Report(context, invocationOperation.Syntax.GetLocation());
Report(context, invocationOperation.Syntax.GetLocation(), method.Parameters.Last());
return;
}
}
Expand All @@ -89,13 +91,21 @@ public override void AnalyzeCompilation(

static void Report(
OperationAnalysisContext context,
Location location) =>
context.ReportDiagnostic(
Location location,
IParameterSymbol parameter)
{
var builder = ImmutableDictionary.CreateBuilder<string, string?>();
builder[Constants.Properties.ParameterName] = parameter.Name;
builder[Constants.Properties.ParameterIndex] = parameter.Ordinal.ToString(CultureInfo.InvariantCulture);

context.ReportDiagnostic(
Diagnostic.Create(
Descriptors.X1051_UseCancellationToken,
location
location,
builder.ToImmutable()
)
);
}
}

protected override bool ShouldAnalyze(XunitContext xunitContext) =>
Expand Down

0 comments on commit 25e4a56

Please sign in to comment.