Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IAST] Safeguard Method Replace aspects with try/catch (#5841 -> v2) #5855

Merged
merged 1 commit into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// <copyright file="ReplaceAspectAnalyzer.cs" company="Datadog">
// Unless explicitly stated otherwise all files in this repository are licensed under the Apache 2 License.
// This product includes software developed at Datadog (https://www.datadoghq.com/). Copyright 2017 Datadog, Inc.
// </copyright>

#nullable enable
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;

namespace Datadog.Trace.Tools.Analyzers.AspectAnalyzers;

/// <summary>
/// An analyzer that analyzers aspects that use [AspectMethodInsertAfter] and [AspectMethodInsertBefore]
/// for example, and checks that they are all wrapped in a try-catch block. These methods should never throw
/// so they should always have a try-catch block around them.
/// </summary>
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class ReplaceAspectAnalyzer : DiagnosticAnalyzer
{
/// <summary>
/// The diagnostic ID displayed in error messages
/// </summary>
public const string DiagnosticId = "DD0005";

/// <summary>
/// The severity of the diagnostic
/// </summary>
public const DiagnosticSeverity Severity = DiagnosticSeverity.Error;

#pragma warning disable RS2008 // Enable analyzer release tracking for the analyzer project
private static readonly DiagnosticDescriptor MissingTryCatchRule = new(
DiagnosticId,
title: "Aspect is in incorrect format",
messageFormat: "Aspect method bodies should contain a single expression to set the result variable, and then have a try-catch block, and then return the created variable",
category: "Reliability",
defaultSeverity: Severity,
isEnabledByDefault: true,
description: "[AspectCtorReplace] and [AspectMethodReplace] Aspects should guarantee safety if possible. Please execute the target method first, then wrap the remainder of the aspect in a try-catch block, and finally return the variable.");
#pragma warning restore RS2008

/// <inheritdoc />
public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } = ImmutableArray.Create(MissingTryCatchRule);

/// <inheritdoc />
public override void Initialize(AnalysisContext context)
{
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
context.EnableConcurrentExecution();

// Consider registering other actions that act on syntax instead of or in addition to symbols
// See https://github.com/dotnet/roslyn/blob/master/docs/analyzers/Analyzer%20Actions%20Semantics.md for more information
context.RegisterSyntaxNodeAction(AnalyseMethod, SyntaxKind.MethodDeclaration);
}

private void AnalyseMethod(SyntaxNodeAnalysisContext context)
{
// assume that generated code is safe, so bail out for perf reasons
if (context.IsGeneratedCode || context.Node is not MethodDeclarationSyntax methodDeclaration)
{
return;
}

var attributes = methodDeclaration.AttributeLists;
if (!attributes.Any())
{
// no attributes, let's just bail
return;
}

var hasAspectAttribute = false;
foreach (var attributeList in attributes)
{
foreach (var attribute in attributeList.Attributes)
{
var name = attribute.Name.ToString();
if (name is "AspectCtorReplace" or "AspectMethodReplace"
or "AspectCtorReplaceAttribute" or "AspectMethodReplaceAttribute")
{
hasAspectAttribute = true;
break;
}
}
}

if (!hasAspectAttribute)
{
// not an aspect
return;
}

var bodyBlock = methodDeclaration.Body;
var isVoidMethod = methodDeclaration.ReturnType is PredefinedTypeSyntax { Keyword.Text: "void" };
int expectedStatements = isVoidMethod ? 2 : 3;

if (bodyBlock is null)
{
// If we don't have a bodyBlock, it's probably a lambda or expression bodied member
// These can't have try catch blocks, so we should bail out
var location = methodDeclaration.ExpressionBody?.GetLocation() ?? methodDeclaration.GetLocation();
context.ReportDiagnostic(Diagnostic.Create(MissingTryCatchRule, location));
return;
}

if (!bodyBlock.Statements.Any())
{
// ignore this case, for now, if there's nothing in there, it's safe, and we don't want to hassle users too soon
return;
}

if (bodyBlock.Statements.Count != expectedStatements)
{
// We require exactly a predefined amount of statements, so this must be an error
context.ReportDiagnostic(Diagnostic.Create(MissingTryCatchRule, bodyBlock.GetLocation()));
return;
}

// check the first statement
if (!isVoidMethod && bodyBlock.Statements[0] is not LocalDeclarationStatementSyntax)
{
// this is an error, and we can't go much further
context.ReportDiagnostic(Diagnostic.Create(MissingTryCatchRule, bodyBlock.GetLocation()));
return;
}

if (bodyBlock.Statements[1] is not TryStatementSyntax tryCatchStatement)
{
// oops, you should have a try block here
context.ReportDiagnostic(Diagnostic.Create(MissingTryCatchRule, bodyBlock.GetLocation()));
return;
}

CatchClauseSyntax? catchClause = null;
var hasFilter = false;
var isSystemException = false;
var isRethrowing = false;

foreach (var catchSyntax in tryCatchStatement.Catches)
{
catchClause = catchSyntax;
isSystemException = false;
isRethrowing = false;

// check that it's catching _everything_
hasFilter = catchClause.Filter is not null;
if (hasFilter)
{
// Skipping because we shouldn't be letting anything through
continue;
}

var exceptionTypeName = catchSyntax.Declaration?.Type is { } exceptionType
? context.SemanticModel.GetSymbolInfo(exceptionType).Symbol?.ToString()
: null;
isSystemException = exceptionTypeName is null or "System.Exception";
if (!isSystemException)
{
// skipping because it's not broad enough
continue;
}

// final requirement, must not be rethrowing
foreach (var statement in catchSyntax.Block.Statements)
{
if (statement is ThrowStatementSyntax)
{
isRethrowing = true;
break;
}
}

// if we get here, we know one of the loops is all good, so we can break
break;
}

if (catchClause is null || hasFilter || !isSystemException || isRethrowing)
{
// oops, no good
var location = catchClause?.GetLocation() ?? tryCatchStatement.GetLocation();
context.ReportDiagnostic(Diagnostic.Create(MissingTryCatchRule, location));
}

// final check, do we return the variable?
if (!isVoidMethod)
{
if (bodyBlock.Statements[2] is not ReturnStatementSyntax returnStatement)
{
context.ReportDiagnostic(Diagnostic.Create(MissingTryCatchRule, bodyBlock.GetLocation()));
return;
}

// should be returning the variable
if (returnStatement.Expression is not IdentifierNameSyntax identifierName)
{
context.ReportDiagnostic(Diagnostic.Create(MissingTryCatchRule, bodyBlock.GetLocation()));
return;
}

LocalDeclarationStatementSyntax localDeclaration = (LocalDeclarationStatementSyntax)bodyBlock.Statements[0];
if (!localDeclaration.Declaration.Variables.Any()
|| localDeclaration.Declaration.Variables[0] is not { } variable
|| variable.Identifier.ToString() != identifierName.Identifier.ToString())
{
// not returning the right thing
context.ReportDiagnostic(Diagnostic.Create(MissingTryCatchRule, bodyBlock.GetLocation()));
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// <copyright file="ReplaceAspectCodeFixProvider.cs" company="Datadog">
// Unless explicitly stated otherwise all files in this repository are licensed under the Apache 2 License.
// This product includes software developed at Datadog (https://www.datadoghq.com/). Copyright 2017 Datadog, Inc.
// </copyright>

#nullable enable
using System.Collections.Immutable;
using System.Composition;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Datadog.Trace.Tools.Analyzers.ThreadAbortAnalyzer;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Formatting;

namespace Datadog.Trace.Tools.Analyzers.AspectAnalyzers;

/// <summary>
/// A CodeFixProvider for the <see cref="ThreadAbortAnalyzer"/>
/// </summary>
[ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(ReplaceAspectCodeFixProvider))]
[Shared]
public class ReplaceAspectCodeFixProvider : CodeFixProvider
{
/// <inheritdoc />
public sealed override ImmutableArray<string> FixableDiagnosticIds
{
get => ImmutableArray.Create(ReplaceAspectAnalyzer.DiagnosticId);
}

/// <inheritdoc />
public sealed override FixAllProvider GetFixAllProvider()
{
// See https://github.com/dotnet/roslyn/blob/master/docs/analyzers/FixAllProvider.md for more information on Fix All Providers
return WellKnownFixAllProviders.BatchFixer;
}

/// <inheritdoc />
public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);

var diagnostic = context.Diagnostics.First();
var diagnosticSpan = diagnostic.Location.SourceSpan;

// Find the methodDeclaration identified by the diagnostic.
var methodDeclaration = root?.FindToken(diagnosticSpan.Start)
.Parent
?.AncestorsAndSelf()
.OfType<MethodDeclarationSyntax>()
.First();

if (methodDeclaration?.Body is { Statements.Count: >2 } body
&& body.Statements[0] is LocalDeclarationStatementSyntax localDeclaration
&& body.Statements[body.Statements.Count - 1] is ReturnStatementSyntax { Expression: IdentifierNameSyntax identifierName }
&& localDeclaration.Declaration.Variables.Count == 1
&& localDeclaration.Declaration.Variables[0] is { } variable
&& variable.Identifier.ToString() == identifierName.Identifier.ToString())
{
// Register a code action that will invoke the fix.
context.RegisterCodeFix(
CodeAction.Create(
title: "Wrap internals with exception handler",
createChangedDocument: c => AddTryCatch(context.Document, methodDeclaration, c),
equivalenceKey: nameof(ReplaceAspectCodeFixProvider)),
diagnostic);
}
}

private async Task<Document> AddTryCatch(Document document, MethodDeclarationSyntax methodDeclaration, CancellationToken cancellationToken)
{
// we know we're calling this with something we can fix,
// we just need to work out if we need to wrap the internals in a try-catch
// or add a catch statement
var body = methodDeclaration.Body!;
var localDeclaration = (LocalDeclarationStatementSyntax)body.Statements[0];
var returnSyntax = (ReturnStatementSyntax)body.Statements[body.Statements.Count - 1];
TryStatementSyntax tryCatch;

if (body.Statements.Count == 3 && body.Statements[1] is TryStatementSyntax tryStatementSyntax)
{
tryCatch = tryStatementSyntax;
}
else
{
var block = SyntaxFactory.Block(body.Statements.Skip(1).Take(body.Statements.Count - 2));
tryCatch = SyntaxFactory.TryStatement().WithBlock(block);
}

// Add the catch statement to the try-catch block
var parentType = methodDeclaration.AncestorsAndSelf()
.FirstOrDefault(x => x is TypeDeclarationSyntax or RecordDeclarationSyntax or StructDeclarationSyntax);
var typeName = parentType switch
{
StructDeclarationSyntax t => t.Identifier.Text,
RecordDeclarationSyntax t => t.Identifier.Text,
TypeDeclarationSyntax t => t.Identifier.Text,
_ => "UNKNOWN",
};

var methodName = methodDeclaration.Identifier.Text;

var catchDeclaration = SyntaxFactory.CatchDeclaration(SyntaxFactory.IdentifierName("Exception"), SyntaxFactory.Identifier("ex"));
var logExpression = SyntaxFactory.ExpressionStatement(
SyntaxFactory.ParseExpression($$"""IastModule.Log.Error(ex, $"Error invoking {nameof({{typeName}})}.{nameof({{methodName}})}")"""));

var catchSyntax = SyntaxFactory.CatchClause()
.WithDeclaration(catchDeclaration)
.WithBlock(SyntaxFactory.Block(logExpression));

var updatedTryCatch = tryCatch.AddCatches(catchSyntax);
var newBody = SyntaxFactory.Block(localDeclaration, updatedTryCatch, returnSyntax)
.WithAdditionalAnnotations(Formatter.Annotation);

var newMethodDeclaration = methodDeclaration.WithBody(newBody);

// replace the syntax and return updated document
var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
root = root!.ReplaceNode(methodDeclaration, newMethodDeclaration);
return document.WithSyntaxRoot(root);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -422,17 +422,15 @@ internal static partial class AspectDefinitions
" [AspectMethodReplace(\"System.String::Concat(System.Object,System.Object,System.Object,System.Object)\",\"\",[0],[False],[None],Default,[])] Concat(System.Object,System.Object,System.Object,System.Object)",
" [AspectMethodReplace(\"System.String::Concat(System.String[])\",\"\",[0],[False],[None],Default,[])] Concat(System.String[])",
" [AspectMethodReplace(\"System.String::Concat(System.Object[])\",\"\",[0],[False],[None],Default,[])] Concat(System.Object[])",
" [AspectMethodReplace(\"System.String::Concat(System.Collections.Generic.IEnumerable`1<System.String>)\",\"\",[0],[False],[None],Default,[])] Concat(System.Collections.IEnumerable)",
" [AspectMethodReplace(\"System.String::Concat(System.Collections.Generic.IEnumerable`1<!!0>)\",\"\",[0],[False],[None],Default,[])] Concat2(System.Collections.IEnumerable)",
" [AspectMethodReplace(\"System.String::Concat(System.Collections.Generic.IEnumerable`1<System.String>)\",\"\",[0],[False],[None],Default,[])] Concat(System.Collections.Generic.IEnumerable)",
" [AspectMethodReplace(\"System.String::Substring(System.Int32)\",\"\",[0],[False],[StringLiteral_0],Default,[])] Substring(System.String,System.Int32)",
" [AspectMethodReplace(\"System.String::Substring(System.Int32,System.Int32)\",\"\",[0],[False],[StringLiteral_0],Default,[])] Substring(System.String,System.Int32,System.Int32)",
" [AspectMethodReplace(\"System.String::ToCharArray()\",\"\",[0],[False],[StringLiteral_0],Default,[])] ToCharArray(System.String)",
" [AspectMethodReplace(\"System.String::ToCharArray(System.Int32,System.Int32)\",\"\",[0],[False],[StringLiteral_0],Default,[])] ToCharArray(System.String,System.Int32,System.Int32)",
" [AspectMethodReplace(\"System.String::Join(System.String,System.String[],System.Int32,System.Int32)\",\"\",[0],[False],[None],Default,[])] Join(System.String,System.String[],System.Int32,System.Int32)",
" [AspectMethodReplace(\"System.String::Join(System.String,System.Object[])\",\"\",[0],[False],[None],Default,[])] Join(System.String,System.Object[])",
" [AspectMethodReplace(\"System.String::Join(System.String,System.String[])\",\"\",[0],[False],[None],Default,[])] Join(System.String,System.String[])",
" [AspectMethodReplace(\"System.String::Join(System.String,System.Collections.Generic.IEnumerable`1<System.String>)\",\"\",[0],[False],[None],Default,[])] JoinString(System.String,System.Collections.IEnumerable)",
" [AspectMethodReplace(\"System.String::Join(System.String,System.Collections.Generic.IEnumerable`1<!!0>)\",\"\",[0],[False],[None],Default,[])] Join(System.String,System.Collections.IEnumerable)",
" [AspectMethodReplace(\"System.String::Join(System.String,System.Collections.Generic.IEnumerable`1<System.String>)\",\"\",[0],[False],[None],Default,[])] JoinString(System.String,System.Collections.Generic.IEnumerable)",
" [AspectMethodReplace(\"System.String::ToUpper()\",\"\",[0],[False],[StringLiteral_0],Default,[])] ToUpper(System.String)",
" [AspectMethodReplace(\"System.String::ToUpper(System.Globalization.CultureInfo)\",\"\",[0],[False],[StringLiteral_0],Default,[])] ToUpper(System.String,System.Globalization.CultureInfo)",
" [AspectMethodReplace(\"System.String::ToUpperInvariant()\",\"\",[0],[False],[StringLiteral_0],Default,[])] ToUpperInvariant(System.String)",
Expand Down
Loading
Loading