diff --git a/ChangeLog.md b/ChangeLog.md index 17087977aa..71eeee9e42 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fix analyzer [RCS1090](https://josefpihrt.github.io/docs/roslynator/analyzers/RCS1090) ([PR](https://github.com/dotnet/roslynator/pull/1566)) +- Fix analyzer [RCS1124](https://josefpihrt.github.io/docs/roslynator/analyzers/RCS1124) ([PR](https://github.com/dotnet/roslynator/pull/1572)) - [CLI] Fix command `generate-doc` ([PR](https://github.com/dotnet/roslynator/pull/1568), [PR](https://github.com/dotnet/roslynator/pull/1570)) ### Change diff --git a/src/Analyzers.CodeFixes/CSharp/CodeFixes/LocalDeclarationStatementCodeFixProvider.cs b/src/Analyzers.CodeFixes/CSharp/CodeFixes/LocalDeclarationStatementCodeFixProvider.cs index 45db233aaf..ab1057011d 100644 --- a/src/Analyzers.CodeFixes/CSharp/CodeFixes/LocalDeclarationStatementCodeFixProvider.cs +++ b/src/Analyzers.CodeFixes/CSharp/CodeFixes/LocalDeclarationStatementCodeFixProvider.cs @@ -67,7 +67,7 @@ private static async Task RefactorAsync( SemanticModel semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false); - ExpressionSyntax value = GetExpressionToInline(localDeclaration, semanticModel, cancellationToken); + ExpressionSyntax value = GetExpressionToInline(localDeclaration, nextStatement, semanticModel, cancellationToken); StatementSyntax newStatement = GetStatementWithInlinedExpression(nextStatement, value); @@ -93,7 +93,11 @@ private static async Task RefactorAsync( return await document.ReplaceStatementsAsync(statementsInfo, newStatements, cancellationToken).ConfigureAwait(false); } - private static ExpressionSyntax GetExpressionToInline(LocalDeclarationStatementSyntax localDeclaration, SemanticModel semanticModel, CancellationToken cancellationToken) + private static ExpressionSyntax GetExpressionToInline( + LocalDeclarationStatementSyntax localDeclaration, + StatementSyntax statement, + SemanticModel semanticModel, + CancellationToken cancellationToken) { VariableDeclarationSyntax variableDeclaration = localDeclaration.Declaration; @@ -114,20 +118,41 @@ private static ExpressionSyntax GetExpressionToInline(LocalDeclarationStatementS { expression = expression.Parenthesize(); - ExpressionSyntax typeExpression = (variableDeclaration.Type.IsVar) - ? variableDeclaration.Variables[0].Initializer.Value - : variableDeclaration.Type; + TypeSyntax type = variableDeclaration.Type; + ITypeSymbol typeSymbol; - ITypeSymbol typeSymbol = semanticModel.GetTypeSymbol(typeExpression, cancellationToken); + if (type.IsVar) + { + typeSymbol = semanticModel.GetTypeSymbol(variableDeclaration.Variables[0].Initializer.Value, cancellationToken)!; + type = typeSymbol.ToTypeSyntax().WithSimplifierAnnotation(); + } + else + { + typeSymbol = semanticModel.GetTypeSymbol(type, cancellationToken)!; + } - if (typeSymbol.SupportsExplicitDeclaration()) + bool ShouldAddCast() { - TypeSyntax type = typeSymbol.ToMinimalTypeSyntax(semanticModel, localDeclaration.SpanStart); + if (!typeSymbol.SupportsExplicitDeclaration()) + return false; + + if (statement.IsKind(SyntaxKind.ReturnStatement)) + { + IMethodSymbol enclosingSymbol = semanticModel.GetEnclosingSymbol(variableDeclaration.Type.SpanStart, cancellationToken); + + if (enclosingSymbol is not null + && SymbolEqualityComparer.Default.Equals(typeSymbol, enclosingSymbol.ReturnType)) + { + return false; + } + } - expression = SyntaxFactory.CastExpression(type, expression).WithSimplifierAnnotation(); + return true; } - return expression; + return (ShouldAddCast()) + ? SyntaxFactory.CastExpression(type.WithoutTrivia(), expression).WithSimplifierAnnotation() + : expression; } } diff --git a/src/Tests/Analyzers.Tests/RCS1124InlineLocalVariableTests.cs b/src/Tests/Analyzers.Tests/RCS1124InlineLocalVariableTests.cs index 7e27c7659e..b3468a52e8 100644 --- a/src/Tests/Analyzers.Tests/RCS1124InlineLocalVariableTests.cs +++ b/src/Tests/Analyzers.Tests/RCS1124InlineLocalVariableTests.cs @@ -101,6 +101,33 @@ void M() "); } + [Fact, Trait(Traits.Analyzer, DiagnosticIdentifiers.InlineLocalVariable)] + public async Task Test_NullableReturnType_ReturnsNullable() + { + await VerifyDiagnosticAndFixAsync(@" +public struct S; + +public class C +{ + public static S? M() + { + [|S? i = new S();|] + return i; + } +} +", @" +public struct S; + +public class C +{ + public static S? M() + { + return new S(); + } +} +"); + } + [Fact, Trait(Traits.Analyzer, DiagnosticIdentifiers.InlineLocalVariable)] public async Task TestNoDiagnostic_YieldReturnIsNotLastStatement() {