Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add analyzer/codefix for usage of interpolated strings in raw query methods #30835

Merged
merged 23 commits into from
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/EFCore.Analyzers/AnalyzerReleases.Shipped.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,11 @@ EF1001 | Usage | Warning | InternalUsageDiagnosticAnalyzer
### Removed Rules
Rule ID | Category | Severity | Notes
--------|----------|----------|--------------------
EF1000 | Security | Disabled | RawSqlStringInjectionDiagnosticAnalyzer, [Documentation](https://docs.microsoft.com/ef/core/querying/raw-sql)
EF1000 | Security | Disabled | RawSqlStringInjectionDiagnosticAnalyzer, [Documentation](https://docs.microsoft.com/ef/core/querying/raw-sql)

## Release 8.0.0

### New Rules
Rule ID | Category | Severity | Notes
--------|----------|----------|-------
EF1002 | Security | Warning | InterpolatedStringUsageInRawQueriesDiagnosticAnalyzer, [Documentation](https://learn.microsoft.com/en-us/ef/core/querying/sql-queries#passing-parameters)
21 changes: 21 additions & 0 deletions src/EFCore.Analyzers/CompilationExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.CodeAnalysis;

namespace Microsoft.EntityFrameworkCore;

internal static class CompilationExtensions
{
public static INamedTypeSymbol? DbSetType(this Compilation compilation)
=> compilation.GetTypeByMetadataName("Microsoft.EntityFrameworkCore.DbSet`1");

public static INamedTypeSymbol? DbContextType(this Compilation compilation)
=> compilation.GetTypeByMetadataName("Microsoft.EntityFrameworkCore.DbContext");

public static INamedTypeSymbol? RelationalQueryableExtensionsType(this Compilation compilation)
=> compilation.GetTypeByMetadataName("Microsoft.EntityFrameworkCore.RelationalQueryableExtensions");

public static INamedTypeSymbol? RelationalDatabaseFacadeExtensionsType(this Compilation compilation)
=> compilation.GetTypeByMetadataName("Microsoft.EntityFrameworkCore.RelationalDatabaseFacadeExtensions");
}
1 change: 1 addition & 0 deletions src/EFCore.Analyzers/EFCore.Analyzers.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="$(MicrosoftCodeAnalysisVersion)" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="$(MicrosoftCodeAnalysisVersion)" PrivateAssets="all" />
<PackageReference Update="NETStandard.Library" PrivateAssets="all" />
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Immutable;
using System.Composition;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Microsoft.EntityFrameworkCore;

[ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(InterpolatedStringUsageInRawQueriesCodeFixProvider))]
[Shared]
public sealed class InterpolatedStringUsageInRawQueriesCodeFixProvider : CodeFixProvider
{
public override ImmutableArray<string> FixableDiagnosticIds
=> ImmutableArray.Create(InterpolatedStringUsageInRawQueriesDiagnosticAnalyzer.Id);

public override FixAllProvider GetFixAllProvider()
=> WellKnownFixAllProviders.BatchFixer;

public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
var document = context.Document;
var span = context.Span;
var cancellationToken = context.CancellationToken;

// We report only 1 diagnostic per span, so this is ok
var diagnostic = context.Diagnostics.First();

var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);

if (root!.FindNode(diagnostic.Location.SourceSpan) is not SimpleNameSyntax simpleName)
{
Debug.Fail("Analyzer reported diagnostic not on a SimpleNameSyntax. This should never happen");
return;
}

var invocationSyntax = simpleName.FirstAncestorOrSelf<InvocationExpressionSyntax>();

if (invocationSyntax is null)
{
return;
}

var foundInterpolation = false;

// Not all reported by analyzer cases are fixable. If there is a mix of interpolated arguments and normal ones, e.g. `FromSqlRaw($"SELECT * FROM [Users] WHERE [Id] = {id}", id)`,
// then replacing `FromSqlRaw` to `FromSqlInterpolated` creates compiler error since there is no overload for this.
// We find such cases by walking through syntaxes of each argument and searching for first interpolated string. If there are arguments after it, we consider such case unfixable.
foreach (var argument in invocationSyntax.ArgumentList.Arguments)
{
if (argument.Expression is InterpolatedStringExpressionSyntax)
{
foundInterpolation = true;
continue;
}

if (!foundInterpolation)
{
continue;
}

return;
}

context.RegisterCodeFix(
CodeAction.Create(
AnalyzerStrings.InterpolatedStringUsageInRawQueriesCodeActionTitle,
_ => Task.FromResult(document.WithSyntaxRoot(root.ReplaceNode(simpleName, GetReplacementName(simpleName)))),
nameof(InterpolatedStringUsageInRawQueriesCodeFixProvider)),
diagnostic);
}

private static SimpleNameSyntax GetReplacementName(SimpleNameSyntax oldName)
{
var oldNameToken = oldName.Identifier;
var oldMethodName = oldNameToken.ValueText;

var replacementMethodName = InterpolatedStringUsageInRawQueriesDiagnosticAnalyzer.GetReplacementMethodName(oldMethodName);
Debug.Assert(replacementMethodName != oldMethodName, "At this point we must find correct replacement name");

var replacementToken = SyntaxFactory.Identifier(replacementMethodName).WithTriviaFrom(oldNameToken);
return oldName.WithIdentifier(replacementToken);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;

namespace Microsoft.EntityFrameworkCore;

[DiagnosticAnalyzer(LanguageNames.CSharp)]
public sealed class InterpolatedStringUsageInRawQueriesDiagnosticAnalyzer : DiagnosticAnalyzer
{
public const string Id = "EF1002";

private static readonly DiagnosticDescriptor Descriptor
// HACK: Work around dotnet/roslyn-analyzers#5890 by not using target-typed new
= new DiagnosticDescriptor(
Id,
title: AnalyzerStrings.InterpolatedStringUsageInRawQueriesAnalyzerTitle,
messageFormat: AnalyzerStrings.InterpolatedStringUsageInRawQueriesMessageFormat,
category: "Security",
defaultSeverity: DiagnosticSeverity.Warning,
isEnabledByDefault: true);

public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics
=> ImmutableArray.Create(Descriptor);

public override void Initialize(AnalysisContext context)
{
context.EnableConcurrentExecution();
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);

context.RegisterOperationAction(AnalyzerInvocation, OperationKind.Invocation);
}

private void AnalyzerInvocation(OperationAnalysisContext context)
{
var invocation = (IInvocationOperation)context.Operation;
var targetMethod = invocation.TargetMethod;

var report = targetMethod.Name switch
{
"FromSqlRaw" => AnalyzeFromSqlRawInvocation(invocation),
"ExecuteSqlRaw" or "ExecuteSqlRawAsync" => AnalyzeExecuteSqlRawInvocation(invocation),
"SqlQueryRaw" => AnalyzeSqlQueryRawInvocation(invocation),
_ => false
};

if (report)
{
context.ReportDiagnostic(Diagnostic.Create(
Descriptor,
GetTargetLocation(invocation.Syntax),
DoctorKrolic marked this conversation as resolved.
Show resolved Hide resolved
targetMethod.Name,
GetReplacementMethodName(targetMethod.Name)));
}

static Location GetTargetLocation(SyntaxNode syntax)
{
if (syntax is not InvocationExpressionSyntax invocationExpression)
{
Debug.Fail("In theory should never happen");
return syntax.GetLocation();
}

var targetNode = invocationExpression.Expression;

while (targetNode is MemberAccessExpressionSyntax memberAccess)
{
targetNode = memberAccess.Name;
}

// Generic name case, e.g. `db.Database.SqlQueryRaw<int>(...)`.
// At this point `targetNode` is `SqlQueryRaw<int>`, but we need location of the actual identifier
if (targetNode is GenericNameSyntax genericName)
{
return genericName.Identifier.GetLocation();
}

// We should appear at name expression, representing method name token, e.g.:
// db.Users.[|FromSqlRaw|](...) or db.Database.[|ExecuteSqlRaw|](...)
return targetNode.GetLocation();
}
}

internal static string GetReplacementMethodName(string oldName) => oldName switch
{
"FromSqlRaw" => "FromSql",
"ExecuteSqlRaw" => "ExecuteSql",
"ExecuteSqlRawAsync" => "ExecuteSqlAsync",
"SqlQueryRaw" => "SqlQuery",
_ => oldName
};

private static bool AnalyzeFromSqlRawInvocation(IInvocationOperation invocation)
{
var targetMethod = invocation.TargetMethod;
Debug.Assert(targetMethod.Name == "FromSqlRaw");

var compilation = invocation.SemanticModel!.Compilation;
var correctFromSqlRaw = FromSqlRawMethod(compilation);

Debug.Assert(correctFromSqlRaw is not null, "Unable to find original `FromSqlRaw` method");

// Verify that the method is the one we analyze and its second argument, which corresponds to `string sql`, is an interpolated string
if (correctFromSqlRaw is null ||
!targetMethod.ConstructedFrom.Equals(correctFromSqlRaw, SymbolEqualityComparer.Default) ||
invocation.Arguments[1].Value is not IInterpolatedStringOperation interpolatedString)
{
return false;
}

// Report warning if interpolated string is not a constant and all its interpolations are not constants
return AnalyzeInterpolatedString(interpolatedString);
}

private static bool AnalyzeExecuteSqlRawInvocation(IInvocationOperation invocation)
{
var targetMethod = invocation.TargetMethod;
Debug.Assert(targetMethod.Name is "ExecuteSqlRaw" or "ExecuteSqlRawAsync");

var compilation = invocation.SemanticModel!.Compilation;

if (targetMethod.Name == "ExecuteSqlRaw")
{
var correctMethods = ExecuteSqlRawMethods(compilation);

Debug.Assert(correctMethods.Any(), "Unable to find any `ExecuteSqlRaw` methods");

if (!correctMethods.Contains(targetMethod.ConstructedFrom, SymbolEqualityComparer.Default))
{
return false;
}
}
else
{
var correctMethods = ExecuteSqlRawAsyncMethods(compilation);

Debug.Assert(correctMethods.Any(), "Unable to find any `ExecuteSqlRawAsync` methods");

if (!correctMethods.Contains(targetMethod.ConstructedFrom, SymbolEqualityComparer.Default))
{
return false;
}
}

// At this point assume that the method is correct since both `ExecuteSqlRaw` and `ExecuteSqlRawAsync` have multiple overloads.
// Checking for every possible one is too much work for almost no gain.
// So check whether the second argument, that corresponds to `string sql` parameter, is an interpolated string...
if (invocation.Arguments[1].Value is not IInterpolatedStringOperation interpolatedString)
{
return false;
}

// ...and report warning if interpolated string is not a constant and all its interpolations are not constants
return AnalyzeInterpolatedString(interpolatedString);
}

private static bool AnalyzeSqlQueryRawInvocation(IInvocationOperation invocation)
{
var targetMethod = invocation.TargetMethod;
Debug.Assert(targetMethod.Name == "SqlQueryRaw");

var compilation = invocation.SemanticModel!.Compilation;

var correctSqlQueryRaw = SqlQueryRawMethod(compilation);

Debug.Assert(correctSqlQueryRaw is not null, "Unable to find original `SqlQueryRaw` method");

// Verify that the method is the one we analyze and its second argument, which corresponds to `string sql`, is an interpolated string
if (correctSqlQueryRaw is null ||
!targetMethod.ConstructedFrom.Equals(correctSqlQueryRaw, SymbolEqualityComparer.Default) ||
invocation.Arguments[1].Value is not IInterpolatedStringOperation interpolatedString)
{
return false;
}

// Report warning if interpolated string is not a constant and all its interpolations are not constants
return AnalyzeInterpolatedString(interpolatedString);
}

private static bool AnalyzeInterpolatedString(IInterpolatedStringOperation interpolatedString)
{
if (interpolatedString.ConstantValue.HasValue)
{
return false;
}

foreach (var part in interpolatedString.Parts)
{
if (part is not IInterpolationOperation interpolation)
{
continue;
}

if (!interpolation.Expression.ConstantValue.HasValue)
{
// Found non-constant interpolation. Report it
return true;
}
}

return false;
}

private static IMethodSymbol? FromSqlRawMethod(Compilation compilation)
{
var type = compilation.RelationalQueryableExtensionsType();
return (IMethodSymbol?)type?.GetMembers("FromSqlRaw").FirstOrDefault(s => s is IMethodSymbol);
}

private static IEnumerable<IMethodSymbol> ExecuteSqlRawMethods(Compilation compilation)
{
var type = compilation.RelationalDatabaseFacadeExtensionsType();
return type?.GetMembers("ExecuteSqlRaw").Where(s => s is IMethodSymbol).Cast<IMethodSymbol>() ?? Array.Empty<IMethodSymbol>();
}

private static IEnumerable<IMethodSymbol> ExecuteSqlRawAsyncMethods(Compilation compilation)
{
var type = compilation.RelationalDatabaseFacadeExtensionsType();
return type?.GetMembers("ExecuteSqlRawAsync").Where(s => s is IMethodSymbol).Cast<IMethodSymbol>() ?? Array.Empty<IMethodSymbol>();
}

private static IMethodSymbol? SqlQueryRawMethod(Compilation compilation)
{
var type = compilation.RelationalDatabaseFacadeExtensionsType();
return (IMethodSymbol?)type?.GetMembers("SqlQueryRaw").FirstOrDefault(s => s is IMethodSymbol);
}
}
Loading