diff --git a/src/Refactorings/CSharp/Refactorings/DeconstructForeachVariableRefactoring.cs b/src/Refactorings/CSharp/Refactorings/DeconstructForeachVariableRefactoring.cs index e04da7437a..8a3bfdb8b1 100644 --- a/src/Refactorings/CSharp/Refactorings/DeconstructForeachVariableRefactoring.cs +++ b/src/Refactorings/CSharp/Refactorings/DeconstructForeachVariableRefactoring.cs @@ -1,6 +1,7 @@ // Copyright (c) Josef Pihrt and Contributors. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; using System.Threading; @@ -21,23 +22,35 @@ public static void ComputeRefactoring( { ITypeSymbol typeSymbol = semanticModel.GetTypeSymbol(forEachStatement.Type); - IMethodSymbol deconstructSymbol = typeSymbol.FindMember( - "Deconstruct", - symbol => - { - if (symbol.DeclaredAccessibility == Accessibility.Public) + IEnumerable parameters = null; + + if (typeSymbol.IsTupleType) + { + var tupleType = (INamedTypeSymbol)typeSymbol; + parameters = tupleType.TupleElements; + } + else + { + IMethodSymbol deconstructSymbol = typeSymbol.FindMember( + "Deconstruct", + symbol => { - ImmutableArray parameters = symbol.Parameters; + if (symbol.DeclaredAccessibility == Accessibility.Public) + { + ImmutableArray parameters = symbol.Parameters; - return parameters.Any() - && parameters.All(f => f.RefKind == RefKind.Out); - } + return parameters.Any() + && parameters.All(f => f.RefKind == RefKind.Out); + } - return false; - }); + return false; + }); - if (deconstructSymbol is null) - return; + if (deconstructSymbol is null) + return; + + parameters = deconstructSymbol.Parameters; + } ISymbol foreachSymbol = semanticModel.GetDeclaredSymbol(forEachStatement, context.CancellationToken); @@ -45,7 +58,7 @@ public static void ComputeRefactoring( return; var walker = new DeconstructForeachVariableWalker( - deconstructSymbol, + parameters, foreachSymbol, forEachStatement.Identifier.ValueText, semanticModel, @@ -53,73 +66,92 @@ public static void ComputeRefactoring( walker.Visit(forEachStatement.Statement); - if (!walker.Success) - return; - - context.RegisterRefactoring( - "Deconstruct foreach variable", - ct => RefactorAsync(context.Document, forEachStatement, deconstructSymbol, foreachSymbol, semanticModel, ct), - RefactoringDescriptors.DeconstructForeachVariable); + if (walker.Success) + { + context.RegisterRefactoring( + "Deconstruct foreach variable", + ct => RefactorAsync(context.Document, forEachStatement, parameters, foreachSymbol, semanticModel, ct), + RefactoringDescriptors.DeconstructForeachVariable); + } } private static async Task RefactorAsync( Document document, ForEachStatementSyntax forEachStatement, - IMethodSymbol deconstructSymbol, + IEnumerable deconstructSymbols, ISymbol identifierSymbol, SemanticModel semanticModel, CancellationToken cancellationToken) { + int position = forEachStatement.SpanStart; + ITypeSymbol elementType = semanticModel.GetForEachStatementInfo(forEachStatement).ElementType; + SyntaxNode enclosingSymbolSyntax = semanticModel.GetEnclosingSymbolSyntax(position, cancellationToken); + + ImmutableArray declaredSymbols = semanticModel.GetDeclaredSymbols(enclosingSymbolSyntax, excludeAnonymousTypeProperty: true, cancellationToken); + + ImmutableArray symbols = declaredSymbols + .Concat(semanticModel.LookupSymbols(position)) + .Distinct() + .Except(deconstructSymbols) + .ToImmutableArray(); + + Dictionary newNames = deconstructSymbols + .Select(parameter => + { + string name = StringUtility.FirstCharToLower(parameter.Name); + string newName = NameGenerator.Default.EnsureUniqueName(name, symbols); + + return (name: parameter.Name, newName); + }) + .ToDictionary(f => f.name, f => f.newName); + + var rewriter = new DeconstructForeachVariableRewriter(identifierSymbol, newNames, semanticModel, cancellationToken); + + var newStatement = (StatementSyntax)rewriter.Visit(forEachStatement.Statement); + DeclarationExpressionSyntax variableExpression = DeclarationExpression( CSharpFactory.VarType().WithTriviaFrom(forEachStatement.Type), ParenthesizedVariableDesignation( - deconstructSymbol.Parameters.Select(parameter => + deconstructSymbols.Select(parameter => { - return (VariableDesignationSyntax)SingleVariableDesignation( - Identifier( - SyntaxTriviaList.Empty, - parameter.Name, - SyntaxTriviaList.Empty)); + return SingleVariableDesignation( + Identifier(SyntaxTriviaList.Empty, newNames[parameter.Name], SyntaxTriviaList.Empty)); }) - .ToSeparatedSyntaxList()) + .ToSeparatedSyntaxList()) .WithTriviaFrom(forEachStatement.Identifier)) .WithFormatterAnnotation(); - var rewriter = new DeconstructForeachVariableRewriter(identifierSymbol, semanticModel, cancellationToken); - - var newStatement = (StatementSyntax)rewriter.Visit(forEachStatement.Statement); - - ForEachVariableStatementSyntax newForEachStatement = ForEachVariableStatement( + ForEachVariableStatementSyntax forEachVariableStatement = ForEachVariableStatement( forEachStatement.AttributeLists, forEachStatement.AwaitKeyword, forEachStatement.ForEachKeyword, forEachStatement.OpenParenToken, - variableExpression.WithFormatterAnnotation(), + variableExpression, forEachStatement.InKeyword, forEachStatement.Expression, forEachStatement.CloseParenToken, newStatement); - return await document.ReplaceNodeAsync(forEachStatement, newForEachStatement, cancellationToken).ConfigureAwait(false); + return await document.ReplaceNodeAsync(forEachStatement, forEachVariableStatement, cancellationToken).ConfigureAwait(false); } private class DeconstructForeachVariableWalker : CSharpSyntaxWalker { public DeconstructForeachVariableWalker( - IMethodSymbol deconstructMethod, + IEnumerable parameters, ISymbol identifierSymbol, string identifier, SemanticModel semanticModel, CancellationToken cancellationToken) { - DeconstructMethod = deconstructMethod; + Parameters = parameters; IdentifierSymbol = identifierSymbol; Identifier = identifier; SemanticModel = semanticModel; CancellationToken = cancellationToken; } - public IMethodSymbol DeconstructMethod { get; } + public IEnumerable Parameters { get; } public ISymbol IdentifierSymbol { get; } @@ -155,7 +187,7 @@ bool IsFixable(IdentifierNameSyntax node) var memberAccess = (MemberAccessExpressionSyntax)node.Parent; if (object.ReferenceEquals(memberAccess.Expression, node)) { - foreach (IParameterSymbol parameter in DeconstructMethod.Parameters) + foreach (ISymbol parameter in Parameters) { if (string.Equals(parameter.Name, memberAccess.Name.Identifier.ValueText, StringComparison.OrdinalIgnoreCase)) return true; @@ -172,16 +204,20 @@ private class DeconstructForeachVariableRewriter : CSharpSyntaxRewriter { public DeconstructForeachVariableRewriter( ISymbol identifierSymbol, + Dictionary names, SemanticModel semanticModel, CancellationToken cancellationToken) { IdentifierSymbol = identifierSymbol; + Names = names; SemanticModel = semanticModel; CancellationToken = cancellationToken; } public ISymbol IdentifierSymbol { get; } + public Dictionary Names { get; } + public SemanticModel SemanticModel { get; } public CancellationToken CancellationToken { get; } @@ -193,8 +229,12 @@ public override SyntaxNode VisitMemberAccessExpression(MemberAccessExpressionSyn && identifierName.Identifier.ValueText == IdentifierSymbol.Name && SymbolEqualityComparer.Default.Equals(SemanticModel.GetSymbol(identifierName, CancellationToken), IdentifierSymbol)) { - return IdentifierName(StringUtility.FirstCharToLower(node.Name.Identifier.ValueText)) - .WithTriviaFrom(identifierName); + string name = node.Name.Identifier.ValueText; + + if (!Names.TryGetValue(name, out string newName)) + newName = StringUtility.FirstCharToLower(name); + + return IdentifierName(newName).WithTriviaFrom(identifierName); } return base.VisitMemberAccessExpression(node); diff --git a/src/Tests/Refactorings.Tests/RR0217DeconstructForeachVariableTests.cs b/src/Tests/Refactorings.Tests/RR0217DeconstructForeachVariableTests.cs index 29feaa6e88..30f3b26080 100644 --- a/src/Tests/Refactorings.Tests/RR0217DeconstructForeachVariableTests.cs +++ b/src/Tests/Refactorings.Tests/RR0217DeconstructForeachVariableTests.cs @@ -11,7 +11,7 @@ public class RR0217DeconstructForeachVariableTests : AbstractCSharpRefactoringVe public override string RefactoringId { get; } = RefactoringIdentifiers.DeconstructForeachVariable; [Fact, Trait(Traits.Refactoring, RefactoringIdentifiers.DeconstructForeachVariable)] - public async Task Test_EmptyObjectInitializer() + public async Task Test_Dictionary() { await VerifyRefactoringAsync(@" using System.Collections.Generic; @@ -45,6 +45,110 @@ void M() } } } +", equivalenceKey: EquivalenceKey.Create(RefactoringId)); + } + + [Fact, Trait(Traits.Refactoring, RefactoringIdentifiers.DeconstructForeachVariable)] + public async Task Test_Dictionary_TopLevelStatement() + { + await VerifyRefactoringAsync(@" +using System.Collections.Generic; + +var dic = new Dictionary(); + +foreach ([||]var kvp in dic) +{ + var k = kvp.Key; + var v = kvp.Value.ToString(); +} +", @" +using System.Collections.Generic; + +var dic = new Dictionary(); + +foreach (var (key, value) in dic) +{ + var k = key; + var v = value.ToString(); +} +", equivalenceKey: EquivalenceKey.Create(RefactoringId)); + } + + [Fact, Trait(Traits.Refactoring, RefactoringIdentifiers.DeconstructForeachVariable)] + public async Task Test_Tuple() + { + await VerifyRefactoringAsync(@" +using System.Collections.Generic; + +class C +{ + void M() + { + var items = new List<(object, string)>(); + + foreach ([||]var item in items) + { + var k = item.Item1; + var v = item.Item2.ToString(); + } + } +} +", @" +using System.Collections.Generic; + +class C +{ + void M() + { + var items = new List<(object, string)>(); + + foreach (var (item1, item2) in items) + { + var k = item1; + var v = item2.ToString(); + } + } +} +", equivalenceKey: EquivalenceKey.Create(RefactoringId)); + } + + [Fact, Trait(Traits.Refactoring, RefactoringIdentifiers.DeconstructForeachVariable)] + public async Task Test_TupleWithNamedFields() + { + await VerifyRefactoringAsync(@" +using System.Collections.Generic; + +class C +{ + void M() + { + var p1 = false; + var items = new List<(object p1, string p2)>(); + + foreach ([||]var item in items) + { + var k = item.p1; + var v = item.p2.ToString(); + } + } +} +", @" +using System.Collections.Generic; + +class C +{ + void M() + { + var p1 = false; + var items = new List<(object p1, string p2)>(); + + foreach (var (p12, p2) in items) + { + var k = p12; + var v = p2.ToString(); + } + } +} ", equivalenceKey: EquivalenceKey.Create(RefactoringId)); } }