From c9f668448327d90312b62467e8bc5a93ea420b9f Mon Sep 17 00:00:00 2001 From: Martin Taillefer Date: Fri, 6 Jan 2023 14:03:33 -0800 Subject: [PATCH] Bunch of fixes for CA1859 * When using the ?? operator, the nullable annotation for the left-hand operand is now erased. This prevents the analyzer suggesting to use a replacement nullable type rather than its non-nullable variation. * We no longer suggest to upgrade the type of a local/field/parameter if the symbol is being used to invoke a method that's is an explicit implementation of an interface method. If the user would upgrade the type, the call to that method would no longer work. * Ensure that we never recommend upgrading the signature of a method that implements an interface method. * Don't recommend a method to be upgraded if the method is an implementation of a partial method definition. This is because there might be different implementations of the method with conflicting behavior. Note that if you use #if constructs, the diagnostic may still make recommendations that would break your code since the analyzer only knows about the select #if block. * Remove a field-specific optimization that was designed to speed up the analyzer since it actually broke analysis of fields, yielding bogus analysis results. --- .../UseConcreteTypeAnalyzer.Collector.cs | 49 +++--- .../Performance/UseConcreteTypeAnalyzer.cs | 95 +++++++---- .../Performance/UseConcreteTypeTests.cs | 155 +++++++++++++++++- 3 files changed, 247 insertions(+), 52 deletions(-) diff --git a/src/NetAnalyzers/Core/Microsoft.NetCore.Analyzers/Performance/UseConcreteTypeAnalyzer.Collector.cs b/src/NetAnalyzers/Core/Microsoft.NetCore.Analyzers/Performance/UseConcreteTypeAnalyzer.Collector.cs index c004413d91..d9248a13c5 100644 --- a/src/NetAnalyzers/Core/Microsoft.NetCore.Analyzers/Performance/UseConcreteTypeAnalyzer.Collector.cs +++ b/src/NetAnalyzers/Core/Microsoft.NetCore.Analyzers/Performance/UseConcreteTypeAnalyzer.Collector.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Threading; using Analyzer.Utilities.Extensions; +using Analyzer.Utilities.Lightup; using Analyzer.Utilities.PooledObjects; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.Operations; @@ -17,9 +18,9 @@ private sealed class Collector { private static readonly ObjectPool _pool = new(() => new Collector()); - public ConcurrentDictionary VirtualDispatchFields { get; } = new(SymbolEqualityComparer.Default); - public ConcurrentDictionary VirtualDispatchLocals { get; } = new(SymbolEqualityComparer.Default); - public ConcurrentDictionary VirtualDispatchParameters { get; } = new(SymbolEqualityComparer.Default); + public ConcurrentDictionary> VirtualDispatchFields { get; } = new(SymbolEqualityComparer.Default); + public ConcurrentDictionary> VirtualDispatchLocals { get; } = new(SymbolEqualityComparer.Default); + public ConcurrentDictionary> VirtualDispatchParameters { get; } = new(SymbolEqualityComparer.Default); public ConcurrentDictionary MethodsAssignedToDelegate { get; } = new(SymbolEqualityComparer.Default); public ConcurrentDictionary> FieldAssignments { get; } = new(SymbolEqualityComparer.Default); @@ -35,9 +36,9 @@ private Collector() private void Reset() { - VirtualDispatchFields.Clear(); - VirtualDispatchLocals.Clear(); - VirtualDispatchParameters.Clear(); + DrainDictionary(VirtualDispatchFields); + DrainDictionary(VirtualDispatchLocals); + DrainDictionary(VirtualDispatchParameters); MethodsAssignedToDelegate.Clear(); DrainDictionary(FieldAssignments); @@ -47,7 +48,8 @@ private void Reset() Void = null; - static void DrainDictionary(ConcurrentDictionary> d) + static void DrainDictionary(ConcurrentDictionary> d) + where U : notnull { foreach (var kvp in d) { @@ -97,7 +99,7 @@ public void HandleInvocation(IInvocationOperation op) var fieldRef = (IFieldReferenceOperation)instance; if (CanUpgrade(fieldRef.Field)) { - VirtualDispatchFields[fieldRef.Field] = true; + RecordVirtualDispatch(fieldRef.Field, op.TargetMethod); } break; @@ -106,14 +108,14 @@ public void HandleInvocation(IInvocationOperation op) case OperationKind.ParameterReference: { var parameterRef = (IParameterReferenceOperation)instance; - VirtualDispatchParameters[parameterRef.Parameter] = true; + RecordVirtualDispatch(parameterRef.Parameter, op.TargetMethod); break; } case OperationKind.LocalReference: { var localRef = (ILocalReferenceOperation)instance; - VirtualDispatchLocals[localRef.Local] = true; + RecordVirtualDispatch(localRef.Local, op.TargetMethod); break; } } @@ -306,7 +308,7 @@ private static bool CanUpgrade(IOperation target) /// Trivial reject for methods that can't be upgraded in order to avoid wasted work. /// private static bool CanUpgrade(IMethodSymbol methodSym) - => methodSym.DeclaredAccessibility == Accessibility.Private && methodSym.MethodKind == MethodKind.Ordinary; + => methodSym.DeclaredAccessibility == Accessibility.Private && methodSym.MethodKind == MethodKind.Ordinary && !methodSym.IsImplementationOfAnyInterfaceMember(); /// /// Trivial reject for fields that can't be upgraded in order to avoid wasted work. @@ -361,7 +363,16 @@ private void GetValueTypes(List values, IOperation op) case OperationKind.Coalesce: { var colOp = (ICoalesceOperation)op; + + var oldCount = values.Count; GetValueTypes(values, colOp.Value); + + if (values.Count > oldCount) + { + // erase any potential nullable annotations of the left-hand value since when the value is null, it doesn't get used + values[^1] = values[^1].WithNullableAnnotation(Analyzer.Utilities.Lightup.NullableAnnotation.NotAnnotated); + } + GetValueTypes(values, colOp.WhenNull); return; } @@ -373,19 +384,9 @@ private void GetValueTypes(List values, IOperation op) case OperationKind.PropertyReference: case OperationKind.MethodReference: case OperationKind.LocalReference: - { - if (op.Type != null) - { - values.Add(op.Type!); - } - - return; - } - case OperationKind.FieldReference: { - var fieldRefOp = (IFieldReferenceOperation)op; - if (CanUpgrade(fieldRefOp.Field)) + if (op.Type != null) { values.Add(op.Type!); } @@ -454,6 +455,10 @@ private void RecordAssignment(IOperation op, ITypeSymbol valueType) } } + private void RecordVirtualDispatch(IFieldSymbol field, IMethodSymbol target) => VirtualDispatchFields.GetOrAdd(field, _ => PooledConcurrentSet.GetInstance(SymbolEqualityComparer.Default)).Add(target); + private void RecordVirtualDispatch(ILocalSymbol local, IMethodSymbol target) => VirtualDispatchLocals.GetOrAdd(local, _ => PooledConcurrentSet.GetInstance(SymbolEqualityComparer.Default)).Add(target); + private void RecordVirtualDispatch(IParameterSymbol parameter, IMethodSymbol target) => VirtualDispatchParameters.GetOrAdd(parameter, _ => PooledConcurrentSet.GetInstance(SymbolEqualityComparer.Default)).Add(target); + private void RecordAssignment(IFieldSymbol field, ITypeSymbol valueType) => FieldAssignments.GetOrAdd(field, _ => PooledConcurrentSet.GetInstance(SymbolEqualityComparer.Default)).Add(valueType); private void RecordAssignment(ILocalSymbol local, ITypeSymbol valueType) => LocalAssignments.GetOrAdd(local, _ => PooledConcurrentSet.GetInstance(SymbolEqualityComparer.Default)).Add(valueType); private void RecordAssignment(IParameterSymbol parameter, ITypeSymbol valueType) => ParameterAssignments.GetOrAdd(parameter.OriginalDefinition, _ => PooledConcurrentSet.GetInstance(SymbolEqualityComparer.Default)).Add(valueType); diff --git a/src/NetAnalyzers/Core/Microsoft.NetCore.Analyzers/Performance/UseConcreteTypeAnalyzer.cs b/src/NetAnalyzers/Core/Microsoft.NetCore.Analyzers/Performance/UseConcreteTypeAnalyzer.cs index fafcc19ad8..3200ff6aaa 100644 --- a/src/NetAnalyzers/Core/Microsoft.NetCore.Analyzers/Performance/UseConcreteTypeAnalyzer.cs +++ b/src/NetAnalyzers/Core/Microsoft.NetCore.Analyzers/Performance/UseConcreteTypeAnalyzer.cs @@ -52,6 +52,8 @@ namespace Microsoft.NetCore.Analyzers.Performance /// * The method must be private. /// /// * The method must not have been assigned to a delegate. + /// + /// * The method must not be the implementation of a partial method definition. /// [DiagnosticAnalyzer(LanguageNames.CSharp, LanguageNames.VisualBasic)] public sealed partial class UseConcreteTypeAnalyzer : DiagnosticAnalyzer @@ -108,6 +110,7 @@ public override void Initialize(AnalysisContext context) { context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None); context.EnableConcurrentExecution(); + context.RegisterCompilationStartAction(context => { var voidType = context.Compilation.GetSpecialType(SpecialType.System_Void); @@ -141,33 +144,42 @@ public override void Initialize(AnalysisContext context) private static void Report(SymbolAnalysisContext context, Collector coll) { // for all eligible private fields that are used as the receiver for a virtual call - foreach (var field in coll.VirtualDispatchFields.Keys) + foreach (var pair in coll.VirtualDispatchFields) { + var field = pair.Key; + var methods = pair.Value; + if (coll.FieldAssignments.TryGetValue(field, out var assignments)) { - Report(field, field.Type, assignments, UseConcreteTypeForField); + Report(field, field.Type, assignments, methods, UseConcreteTypeForField); } } // for all eligible local variables that are used as the receiver for a virtual call - foreach (var local in coll.VirtualDispatchLocals.Keys) + foreach (var pair in coll.VirtualDispatchLocals) { + var local = pair.Key; + var methods = pair.Value; + if (coll.LocalAssignments.TryGetValue(local, out var assignments)) { - Report(local, local.Type, assignments, UseConcreteTypeForLocal); + Report(local, local.Type, assignments, methods, UseConcreteTypeForLocal); } } // for all eligible parameters that are used as the receiver for a virtual call - foreach (var parameter in coll.VirtualDispatchParameters.Keys) + foreach (var pair in coll.VirtualDispatchParameters) { + var parameter = pair.Key; + var methods = pair.Value; + if (coll.ParameterAssignments.TryGetValue(parameter, out var assignments)) { if (parameter.ContainingSymbol is IMethodSymbol method) { if (CanUpgrade(method)) { - Report(parameter, parameter.Type, assignments, UseConcreteTypeForParameter); + Report(parameter, parameter.Type, assignments, methods, UseConcreteTypeForParameter); } } } @@ -182,11 +194,11 @@ private static void Report(SymbolAnalysisContext context, Collector coll) // only report the method if it never assigned to a delegate if (CanUpgrade(method)) { - Report(method, method.ReturnType, returns, UseConcreteTypeForMethodReturn); + Report(method, method.ReturnType, returns, null, UseConcreteTypeForMethodReturn); } } - void Report(ISymbol sym, ITypeSymbol fromType, PooledConcurrentSet assignments, DiagnosticDescriptor desc) + void Report(ISymbol sym, ITypeSymbol fromType, PooledConcurrentSet assignments, PooledConcurrentSet? targets, DiagnosticDescriptor desc) { // a set of the values assigned to the given symbol using var types = PooledHashSet.GetInstance(assignments, SymbolEqualityComparer.Default); @@ -194,37 +206,62 @@ void Report(ISymbol sym, ITypeSymbol fromType, PooledConcurrentSet // 'void' is the magic value we use to represent null assignment var assignedNull = types.Remove(coll.Void!); - // We currently only handle the case where there is only a single consistent type of value assigned to the + // We currently only handle the case where there is a single consistent type of value assigned to the // symbol. If there are multiple different types, we could try to find the common base for these, but it doesn't // seem worth the complication. - if (types.Count == 1) + if (types.Count != 1) { - var toType = types.Single(); - if (assignedNull) - { - toType = toType.WithNullableAnnotation(Analyzer.Utilities.Lightup.NullableAnnotation.Annotated); - } + return; + } - if (!toType.DerivesFrom(fromType.OriginalDefinition)) - { - // can readily replace fromType by toType - return; - } + var toType = types.Single(); + if (assignedNull || fromType.NullableAnnotation() == Analyzer.Utilities.Lightup.NullableAnnotation.Annotated) + { + toType = toType.WithNullableAnnotation(Analyzer.Utilities.Lightup.NullableAnnotation.Annotated); + } - if (toType.TypeKind == TypeKind.Class - && !SymbolEqualityComparer.Default.Equals(fromType, toType) - && toType.SpecialType != SpecialType.System_Object - && toType.SpecialType != SpecialType.System_Delegate) + if (!toType.DerivesFrom(fromType.OriginalDefinition)) + { + // can readily replace fromType by toType + return; + } + + // if any of the methods that are invoked on toType are explicit implementations of interface methods, then we don't want + // to recommend upgrading the type otherwise it would break those call sites + if (targets != null) + { + foreach (var t in targets) { - var fromTypeName = GetTypeName(fromType); - var toTypeName = GetTypeName(toType); - var diagnostic = sym.CreateDiagnostic(desc, sym.Name, fromTypeName, toTypeName); - context.ReportDiagnostic(diagnostic); + foreach (var m in toType.GetMembers()) + { + if (m.IsImplementationOfAnyExplicitInterfaceMember()) + { + if (m.IsImplementationOfInterfaceMember(t)) + { + return; + } + } + } } } + + if (toType.TypeKind == TypeKind.Class + && !SymbolEqualityComparer.Default.Equals(fromType, toType) + && toType.SpecialType != SpecialType.System_Object + && toType.SpecialType != SpecialType.System_Delegate) + { + var fromTypeName = GetTypeName(fromType); + var toTypeName = GetTypeName(toType); + var diagnostic = sym.CreateDiagnostic(desc, sym.Name, fromTypeName, toTypeName); + context.ReportDiagnostic(diagnostic); + } } - bool CanUpgrade(IMethodSymbol methodSym) => !coll.MethodsAssignedToDelegate.ContainsKey(methodSym); + bool CanUpgrade(IMethodSymbol methodSym) + { + return !coll.MethodsAssignedToDelegate.ContainsKey(methodSym) + && methodSym.PartialDefinitionPart == null; + } static string GetTypeName(ITypeSymbol type) => type.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat); } diff --git a/src/NetAnalyzers/UnitTests/Microsoft.NetCore.Analyzers/Performance/UseConcreteTypeTests.cs b/src/NetAnalyzers/UnitTests/Microsoft.NetCore.Analyzers/Performance/UseConcreteTypeTests.cs index 3072076001..1aef5dea63 100644 --- a/src/NetAnalyzers/UnitTests/Microsoft.NetCore.Analyzers/Performance/UseConcreteTypeTests.cs +++ b/src/NetAnalyzers/UnitTests/Microsoft.NetCore.Analyzers/Performance/UseConcreteTypeTests.cs @@ -87,6 +87,97 @@ public static int Bar() await TestCSAsync(Source); } + [Fact] + public static async Task ShouldNotTrigger3() + { + const string Source = @" + #nullable enable + + using System; + using System.IO; + + namespace Example + { + internal static class C + { + private static Stream GetStream(int i) + { + if (i == 0) + { + return Stream.Null; + } + + return new MyStream(); + } + } + } + + public class MyStream : MemoryStream { } + "; + + await TestCSAsync(Source); + } + + [Fact] + public static async Task ShouldNotTrigger4() + { + const string Source = @" + #nullable enable + + using System; + using System.IO; + + namespace Example + { + internal partial class C + { + private partial Stream GetStream(int i); + } + + internal partial class C + { + private partial Stream GetStream(int i) + { + return new MyStream(); + } + } + } + + public class MyStream : MemoryStream { } + "; + + await TestCSAsync(Source); + } + + [Fact] + public static async Task ShouldNotTrigger5() + { + const string Source = @" + #nullable enable + + interface IFoo + { + int M(); + } + + internal class C : IFoo + { + int IFoo.M() => 42; + } + + internal class Use + { + static int Bar() + { + IFoo f = new C(); + return f.M(); + } + } + "; + + await TestCSAsync(Source); + } + [Fact] public static async Task ShouldTrigger1() { @@ -159,6 +250,68 @@ await TestCSAsync(Source, .WithArguments("foo", "Example.IFoo", "Example.Foo")); } + [Fact] + public static async Task ShouldTrigger3() + { + const string Source = @" + #nullable enable + + using System; + using System.IO; + + namespace Example + { + internal class C + { + private MemoryStream? _stream; + + private Stream {|#0:GetStream|}() + { + return _stream ?? Create(); + } + + private MemoryStream Create() => new MemoryStream(); + } + } + "; + + await TestCSAsync(Source, + VerifyCS.Diagnostic(UseConcreteTypeAnalyzer.UseConcreteTypeForMethodReturn) + .WithLocation(0) + .WithArguments("GetStream", "System.IO.Stream", "System.IO.MemoryStream")); + } + + [Fact] + public static async Task ShouldTrigger4() + { + const string Source = @" + #nullable enable + + using System; + using System.IO; + + namespace Example + { + internal class C + { + private MemoryStream? _stream; + + private Stream? {|#0:GetStream|}() + { + return _stream ?? Create(); + } + + private MemoryStream? Create() => new MemoryStream(); + } + } + "; + + await TestCSAsync(Source, + VerifyCS.Diagnostic(UseConcreteTypeAnalyzer.UseConcreteTypeForMethodReturn) + .WithLocation(0) + .WithArguments("GetStream", "System.IO.Stream?", "System.IO.MemoryStream?")); + } + [Fact] public static async Task Params() { @@ -261,7 +414,7 @@ private void Method2() await TestCSAsync(Source, VerifyCS.Diagnostic(UseConcreteTypeAnalyzer.UseConcreteTypeForMethodReturn) .WithLocation(0) - .WithArguments("Method1", "Example.IFoo?", "Example.Foo")); + .WithArguments("Method1", "Example.IFoo?", "Example.Foo?")); } [Fact]