diff --git a/TunnelVisionLabs.ReferenceAssemblyAnnotator/Program.cs b/TunnelVisionLabs.ReferenceAssemblyAnnotator/Program.cs index 2ed3ae1..ef08b6e 100644 --- a/TunnelVisionLabs.ReferenceAssemblyAnnotator/Program.cs +++ b/TunnelVisionLabs.ReferenceAssemblyAnnotator/Program.cs @@ -68,7 +68,7 @@ static void AddAttributeOfInterest(Dictionary attributes private static void AnnotateAssembly(SuppressibleLoggingHelper? log, AssemblyDefinition assemblyDefinition, AssemblyDefinition annotatedAssemblyDefinition, Dictionary attributesOfInterest) { - Annotate(assemblyDefinition, annotatedAssemblyDefinition, attributesOfInterest); + Annotate(assemblyDefinition.MainModule, assemblyDefinition, annotatedAssemblyDefinition, attributesOfInterest); if (assemblyDefinition.Modules.Count != 1) throw new NotSupportedException(); @@ -77,14 +77,14 @@ private static void AnnotateAssembly(SuppressibleLoggingHelper? log, AssemblyDef private static void AnnotateModule(SuppressibleLoggingHelper? log, ModuleDefinition moduleDefinition, ModuleDefinition annotatedModuleDefinition, Dictionary attributesOfInterest) { - Annotate(moduleDefinition, annotatedModuleDefinition, attributesOfInterest); + Annotate(moduleDefinition, moduleDefinition, annotatedModuleDefinition, attributesOfInterest); foreach (var type in moduleDefinition.GetAllTypes()) { - AnnotateType(log, type, annotatedModuleDefinition, attributesOfInterest); + AnnotateType(log, moduleDefinition, type, annotatedModuleDefinition, attributesOfInterest); } } - private static void AnnotateType(SuppressibleLoggingHelper? log, TypeDefinition typeDefinition, ModuleDefinition annotatedModuleDefinition, Dictionary attributesOfInterest) + private static void AnnotateType(SuppressibleLoggingHelper? log, ModuleDefinition module, TypeDefinition typeDefinition, ModuleDefinition annotatedModuleDefinition, Dictionary attributesOfInterest) { if (attributesOfInterest.ContainsKey(typeDefinition.FullName)) return; @@ -95,82 +95,92 @@ private static void AnnotateType(SuppressibleLoggingHelper? log, TypeDefinition return; } - Annotate(typeDefinition, annotatedTypeDefinition, attributesOfInterest); + Annotate(module, typeDefinition, annotatedTypeDefinition, attributesOfInterest); for (int i = 0; i < typeDefinition.Interfaces.Count; i++) { for (int j = 0; j < annotatedTypeDefinition.Interfaces.Count; j++) { if (EquivalenceComparers.TypeReference.Equals(typeDefinition.Interfaces[i].InterfaceType, annotatedTypeDefinition.Interfaces[j].InterfaceType)) { - Annotate(typeDefinition.Interfaces[i], annotatedTypeDefinition.Interfaces[j], attributesOfInterest); + Annotate(module, typeDefinition.Interfaces[i], annotatedTypeDefinition.Interfaces[j], attributesOfInterest); } } } for (int i = 0; i < typeDefinition.GenericParameters.Count; i++) { - Annotate(typeDefinition.GenericParameters[i], annotatedTypeDefinition.GenericParameters[i], attributesOfInterest); + Annotate(module, typeDefinition.GenericParameters[i], annotatedTypeDefinition.GenericParameters[i], attributesOfInterest); } foreach (var method in typeDefinition.Methods) { - AnnotateMethod(log, method, annotatedTypeDefinition, attributesOfInterest); + AnnotateMethod(log, module, method, annotatedTypeDefinition, attributesOfInterest); } foreach (var property in typeDefinition.Properties) { - AnnotateProperty(property, annotatedTypeDefinition, attributesOfInterest); + AnnotateProperty(module, property, annotatedTypeDefinition, attributesOfInterest); } foreach (var field in typeDefinition.Fields) { - AnnotateField(field, annotatedTypeDefinition, attributesOfInterest); + AnnotateField(module, field, annotatedTypeDefinition, attributesOfInterest); } } - private static void AnnotateMethod(SuppressibleLoggingHelper? log, MethodDefinition methodDefinition, TypeDefinition annotatedTypeDefinition, Dictionary attributesOfInterest) + private static void AnnotateMethod(SuppressibleLoggingHelper? log, ModuleDefinition module, MethodDefinition methodDefinition, TypeDefinition annotatedTypeDefinition, Dictionary attributesOfInterest) { var annotatedMethodDefinition = FindMatchingMethod(log, methodDefinition, annotatedTypeDefinition); if (annotatedMethodDefinition is null) return; - Annotate(methodDefinition, annotatedMethodDefinition, attributesOfInterest); - Annotate(methodDefinition.MethodReturnType, annotatedMethodDefinition.MethodReturnType, attributesOfInterest); + Annotate(module, methodDefinition, annotatedMethodDefinition, attributesOfInterest); + Annotate(module, methodDefinition.MethodReturnType, annotatedMethodDefinition.MethodReturnType, attributesOfInterest); for (int i = 0; i < methodDefinition.Parameters.Count; i++) { - Annotate(methodDefinition.Parameters[i], annotatedMethodDefinition.Parameters[i], attributesOfInterest); + Annotate(module, methodDefinition.Parameters[i], annotatedMethodDefinition.Parameters[i], attributesOfInterest); } for (int i = 0; i < methodDefinition.GenericParameters.Count; i++) { - Annotate(methodDefinition.GenericParameters[i], annotatedMethodDefinition.GenericParameters[i], attributesOfInterest); + Annotate(module, methodDefinition.GenericParameters[i], annotatedMethodDefinition.GenericParameters[i], attributesOfInterest); } } - private static void AnnotateProperty(PropertyDefinition propertyDefinition, TypeDefinition annotatedTypeDefinition, Dictionary attributesOfInterest) + private static void AnnotateProperty(ModuleDefinition module, PropertyDefinition propertyDefinition, TypeDefinition annotatedTypeDefinition, Dictionary attributesOfInterest) { var annotatedPropertyDefinition = FindMatchingProperty(propertyDefinition, annotatedTypeDefinition); if (annotatedPropertyDefinition is null) return; - Annotate(propertyDefinition, annotatedPropertyDefinition, attributesOfInterest); + Annotate(module, propertyDefinition, annotatedPropertyDefinition, attributesOfInterest); for (int i = 0; i < propertyDefinition.Parameters.Count; i++) { - Annotate(propertyDefinition.Parameters[i], annotatedPropertyDefinition.Parameters[i], attributesOfInterest); + Annotate(module, propertyDefinition.Parameters[i], annotatedPropertyDefinition.Parameters[i], attributesOfInterest); } } - private static void AnnotateField(FieldDefinition fieldDefinition, TypeDefinition annotatedTypeDefinition, Dictionary attributesOfInterest) + private static void AnnotateField(ModuleDefinition module, FieldDefinition fieldDefinition, TypeDefinition annotatedTypeDefinition, Dictionary attributesOfInterest) { var annotatedFieldDefinition = FindMatchingField(fieldDefinition, annotatedTypeDefinition); if (annotatedFieldDefinition is null) return; - Annotate(fieldDefinition, annotatedFieldDefinition, attributesOfInterest); + Annotate(module, fieldDefinition, annotatedFieldDefinition, attributesOfInterest); } - private static void Annotate(ICustomAttributeProvider provider, ICustomAttributeProvider annotatedProvider, Dictionary attributesOfInterest) + private static void Annotate(ModuleDefinition module, ICustomAttributeProvider provider, ICustomAttributeProvider annotatedProvider, Dictionary attributesOfInterest) { + // Start by removing any prior attributes that need to be filtered out + for (int i = 0; i < provider.CustomAttributes.Count; i++) + { + if (IsExcludedAnnotation(provider, annotatedProvider, provider.CustomAttributes[i])) + { + provider.CustomAttributes.RemoveAt(i); + i--; + } + } + foreach (var customAttribute in annotatedProvider.CustomAttributes) { if (!attributesOfInterest.TryGetValue(customAttribute.AttributeType.FullName, out var attributeTypeDefinition)) @@ -179,11 +189,14 @@ private static void Annotate(ICustomAttributeProvider provider, ICustomAttribute if (customAttribute.Fields.Count != 0 || customAttribute.Properties.Count != 0) continue; + if (IsExcludedAnnotation(provider, annotatedProvider, customAttribute)) + continue; + var constructor = attributeTypeDefinition.Methods.SingleOrDefault(method => IsMatchingConstructor(method, customAttribute)); if (constructor is null) continue; - var newCustomAttribute = new CustomAttribute(constructor); + var newCustomAttribute = new CustomAttribute(module.ImportReference(constructor)); for (int i = 0; i < customAttribute.ConstructorArguments.Count; i++) { newCustomAttribute.ConstructorArguments.Add(new CustomAttributeArgument(constructor.Parameters[i].ParameterType, customAttribute.ConstructorArguments[i].Value)); @@ -207,6 +220,25 @@ static bool IsMatchingConstructor(MethodDefinition constructor, CustomAttribute } } + private static bool IsExcludedAnnotation(ICustomAttributeProvider provider, ICustomAttributeProvider annotatedProvider, CustomAttribute customAttribute) + { + if (provider is ParameterDefinition parameter + && parameter.Method is MethodDefinition { Name: nameof(GetHashCode), Parameters: { Count: 1 } } method + && method.DeclaringType is { Namespace: "System.Collections.Generic", GenericParameters: { Count: 1 } } type) + { + if (type.Name == "IEqualityComparer`1" + || type.Name == "EqualityComparer`1") + { + // Remove DisallowNullAttribute from: + // 1. System.Collections.Generic.IEqualityComparer.GetHashCode([DisallowNull] T) + // 2. System.Collections.Generic.EqualityComparer.GetHashCode([DisallowNull] T) + return customAttribute.AttributeType.FullName == "System.Diagnostics.CodeAnalysis.DisallowNullAttribute"; + } + } + + return false; + } + private static TypeDefinition? FindMatchingType(SuppressibleLoggingHelper? log, TypeDefinition typeDefinition, ModuleDefinition annotatedModuleDefinition) { if (typeDefinition.IsNested)