diff --git a/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems b/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems index 1386cf0e..be490d25 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems +++ b/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems @@ -54,10 +54,10 @@ - + @@ -70,6 +70,8 @@ + + @@ -77,6 +79,7 @@ + diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/INotifyPropertyChangedGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/INotifyPropertyChangedGenerator.cs index 90d6a244..93c2bb18 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/INotifyPropertyChangedGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/INotifyPropertyChangedGenerator.cs @@ -4,9 +4,10 @@ using System.Collections.Immutable; using System.Linq; -using CommunityToolkit.Mvvm.SourceGenerators.Diagnostics; using CommunityToolkit.Mvvm.SourceGenerators.Extensions; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; using CommunityToolkit.Mvvm.SourceGenerators.Input.Models; +using CommunityToolkit.Mvvm.SourceGenerators.Models; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using static CommunityToolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors; @@ -28,16 +29,16 @@ public INotifyPropertyChangedGenerator() } /// - protected override INotifyPropertyChangedInfo? ValidateTargetTypeAndGetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData, Compilation compilation, out ImmutableArray diagnostics) + private protected override INotifyPropertyChangedInfo? ValidateTargetTypeAndGetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData, Compilation compilation, out ImmutableArray diagnostics) { - ImmutableArray.Builder builder = ImmutableArray.CreateBuilder(); + diagnostics = ImmutableArray.Empty; INotifyPropertyChangedInfo? info = null; // Check if the type already implements INotifyPropertyChanged if (typeSymbol.AllInterfaces.Any(i => i.HasFullyQualifiedName("global::System.ComponentModel.INotifyPropertyChanged"))) { - builder.Add(DuplicateINotifyPropertyChangedInterfaceForINotifyPropertyChangedAttributeError, typeSymbol, typeSymbol); + diagnostics = ImmutableArray.Create(DiagnosticInfo.Create(DuplicateINotifyPropertyChangedInterfaceForINotifyPropertyChangedAttributeError, typeSymbol, typeSymbol)); goto End; } @@ -46,7 +47,7 @@ public INotifyPropertyChangedGenerator() if (typeSymbol.HasOrInheritsAttributeWithFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.ObservableObjectAttribute") || typeSymbol.InheritsAttributeWithFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.INotifyPropertyChangedAttribute")) { - builder.Add(InvalidAttributeCombinationForINotifyPropertyChangedAttributeError, typeSymbol, typeSymbol); + diagnostics = ImmutableArray.Create(DiagnosticInfo.Create(InvalidAttributeCombinationForINotifyPropertyChangedAttributeError, typeSymbol, typeSymbol)); goto End; } @@ -56,8 +57,6 @@ public INotifyPropertyChangedGenerator() info = new INotifyPropertyChangedInfo(includeAdditionalHelperMethods); End: - diagnostics = builder.ToImmutable(); - return info; } @@ -67,9 +66,17 @@ protected override ImmutableArray FilterDeclaredMembers // If requested, only include the event and the basic methods to raise it, but not the additional helpers if (!info.IncludeAdditionalHelperMethods) { - return memberDeclarations.Where(static member => member - is EventFieldDeclarationSyntax - or MethodDeclarationSyntax { Identifier.ValueText: "OnPropertyChanged" }).ToImmutableArray(); + using ImmutableArrayBuilder selectedMembers = ImmutableArrayBuilder.Rent(); + + foreach (MemberDeclarationSyntax memberDeclaration in memberDeclarations) + { + if (memberDeclaration is EventFieldDeclarationSyntax or MethodDeclarationSyntax { Identifier.ValueText: "OnPropertyChanged" }) + { + selectedMembers.Add(memberDeclaration); + } + } + + return selectedMembers.ToImmutable(); } return memberDeclarations; diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/Models/AttributeInfo.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/Models/AttributeInfo.cs index e9a64379..ae84ffa1 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/Models/AttributeInfo.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/Models/AttributeInfo.cs @@ -32,23 +32,24 @@ public static AttributeInfo From(AttributeData attributeData) { string typeName = attributeData.AttributeClass!.GetFullyQualifiedName(); + using ImmutableArrayBuilder constructorArguments = ImmutableArrayBuilder.Rent(); + using ImmutableArrayBuilder<(string, TypedConstantInfo)> namedArguments = ImmutableArrayBuilder<(string, TypedConstantInfo)>.Rent(); + // Get the constructor arguments - ImmutableArray constructorArguments = - attributeData.ConstructorArguments - .Select(TypedConstantInfo.From) - .ToImmutableArray(); + foreach (TypedConstant typedConstant in attributeData.ConstructorArguments) + { + constructorArguments.Add(TypedConstantInfo.From(typedConstant)); + } // Get the named arguments - ImmutableArray<(string, TypedConstantInfo)>.Builder namedArguments = ImmutableArray.CreateBuilder<(string, TypedConstantInfo)>(); - - foreach (KeyValuePair arg in attributeData.NamedArguments) + foreach (KeyValuePair namedConstant in attributeData.NamedArguments) { - namedArguments.Add((arg.Key, TypedConstantInfo.From(arg.Value))); + namedArguments.Add((namedConstant.Key, TypedConstantInfo.From(namedConstant.Value))); } return new( typeName, - constructorArguments, + constructorArguments.ToImmutable(), namedArguments.ToImmutable()); } @@ -64,8 +65,8 @@ public static AttributeInfo From(INamedTypeSymbol typeSymbol, SemanticModel sema { string typeName = typeSymbol.GetFullyQualifiedName(); - ImmutableArray.Builder constructorArguments = ImmutableArray.CreateBuilder(); - ImmutableArray<(string, TypedConstantInfo)>.Builder namedArguments = ImmutableArray.CreateBuilder<(string, TypedConstantInfo)>(); + using ImmutableArrayBuilder constructorArguments = ImmutableArrayBuilder.Rent(); + using ImmutableArrayBuilder<(string, TypedConstantInfo)> namedArguments = ImmutableArrayBuilder<(string, TypedConstantInfo)>.Rent(); foreach (AttributeArgumentSyntax argument in arguments) { diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/Models/TypedConstantInfo.Factory.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/Models/TypedConstantInfo.Factory.cs index 8e58a6e0..825b075e 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/Models/TypedConstantInfo.Factory.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/Models/TypedConstantInfo.Factory.cs @@ -6,6 +6,7 @@ using System.Collections.Immutable; using System.Linq; using System.Threading; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Operations; @@ -127,7 +128,7 @@ public static TypedConstantInfo From( return new Array(elementTypeName, ImmutableArray.Empty); } - ImmutableArray.Builder items = ImmutableArray.CreateBuilder(initializerExpression.Expressions.Count); + using ImmutableArrayBuilder items = ImmutableArrayBuilder.Rent(); // Enumerate all array elements and extract serialized info for them foreach (ExpressionSyntax initializationExpression in initializerExpression.Expressions) @@ -140,7 +141,7 @@ public static TypedConstantInfo From( items.Add(From(initializationOperation, semanticModel, initializationExpression, token)); } - return new Array(elementTypeName, items.MoveToImmutable()); + return new Array(elementTypeName, items.ToImmutable()); } throw new ArgumentException("Invalid attribute argument value"); diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/Models/ValidationInfo.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/Models/ValidationInfo.cs index 734e6b68..650a9a89 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/Models/ValidationInfo.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/Models/ValidationInfo.cs @@ -22,10 +22,16 @@ internal sealed record ValidationInfo( string TypeName, ImmutableArray PropertyNames) { + /// + public bool Equals(ValidationInfo? obj) => Comparer.Default.Equals(this, obj); + + /// + public override int GetHashCode() => Comparer.Default.GetHashCode(this); + /// /// An implementation for . /// - public sealed class Comparer : Comparer + private sealed class Comparer : Comparer { /// protected override void AddToHashCode(ref HashCode hashCode, ValidationInfo obj) diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableObjectGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableObjectGenerator.cs index 2c45778d..b4d796fa 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableObjectGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableObjectGenerator.cs @@ -4,8 +4,8 @@ using System.Collections.Immutable; using System.Linq; -using CommunityToolkit.Mvvm.SourceGenerators.Diagnostics; using CommunityToolkit.Mvvm.SourceGenerators.Extensions; +using CommunityToolkit.Mvvm.SourceGenerators.Models; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using static CommunityToolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors; @@ -27,14 +27,14 @@ public ObservableObjectGenerator() } /// - protected override object? ValidateTargetTypeAndGetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData, Compilation compilation, out ImmutableArray diagnostics) + private protected override object? ValidateTargetTypeAndGetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData, Compilation compilation, out ImmutableArray diagnostics) { - ImmutableArray.Builder builder = ImmutableArray.CreateBuilder(); + diagnostics = ImmutableArray.Empty; // Check if the type already implements INotifyPropertyChanged... if (typeSymbol.AllInterfaces.Any(i => i.HasFullyQualifiedName("global::System.ComponentModel.INotifyPropertyChanged"))) { - builder.Add(DuplicateINotifyPropertyChangedInterfaceForObservableObjectAttributeError, typeSymbol, typeSymbol); + diagnostics = ImmutableArray.Create(DiagnosticInfo.Create(DuplicateINotifyPropertyChangedInterfaceForObservableObjectAttributeError, typeSymbol, typeSymbol)); goto End; } @@ -42,7 +42,7 @@ public ObservableObjectGenerator() // ...or INotifyPropertyChanging if (typeSymbol.AllInterfaces.Any(i => i.HasFullyQualifiedName("global::System.ComponentModel.INotifyPropertyChanging"))) { - builder.Add(DuplicateINotifyPropertyChangingInterfaceForObservableObjectAttributeError, typeSymbol, typeSymbol); + diagnostics = ImmutableArray.Create(DiagnosticInfo.Create(DuplicateINotifyPropertyChangingInterfaceForObservableObjectAttributeError, typeSymbol, typeSymbol)); goto End; } @@ -51,14 +51,12 @@ public ObservableObjectGenerator() if (typeSymbol.InheritsAttributeWithFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.ObservableObjectAttribute") || typeSymbol.HasOrInheritsAttributeWithFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.INotifyPropertyChangedAttribute")) { - builder.Add(InvalidAttributeCombinationForObservableObjectAttributeError, typeSymbol, typeSymbol); + diagnostics = ImmutableArray.Create(DiagnosticInfo.Create(InvalidAttributeCombinationForObservableObjectAttributeError, typeSymbol, typeSymbol)); goto End; } End: - diagnostics = builder.ToImmutable(); - return null; } diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.Execute.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.Execute.cs index 21a77780..6209609d 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.Execute.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.Execute.cs @@ -9,8 +9,9 @@ using System.Linq; using System.Threading; using CommunityToolkit.Mvvm.SourceGenerators.ComponentModel.Models; -using CommunityToolkit.Mvvm.SourceGenerators.Diagnostics; using CommunityToolkit.Mvvm.SourceGenerators.Extensions; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; +using CommunityToolkit.Mvvm.SourceGenerators.Models; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -43,9 +44,9 @@ public static bool TryGetInfo( SemanticModel semanticModel, CancellationToken token, [NotNullWhen(true)] out PropertyInfo? propertyInfo, - out ImmutableArray diagnostics) + out ImmutableArray diagnostics) { - ImmutableArray.Builder builder = ImmutableArray.CreateBuilder(); + using ImmutableArrayBuilder builder = ImmutableArrayBuilder.Rent(); // Validate the target type if (!IsTargetTypeValid(fieldSymbol, out bool shouldInvokeOnPropertyChanging)) @@ -100,10 +101,11 @@ public static bool TryGetInfo( return false; } - ImmutableArray.Builder propertyChangedNames = ImmutableArray.CreateBuilder(); - ImmutableArray.Builder propertyChangingNames = ImmutableArray.CreateBuilder(); - ImmutableArray.Builder notifiedCommandNames = ImmutableArray.CreateBuilder(); - ImmutableArray.Builder forwardedAttributes = ImmutableArray.CreateBuilder(); + using ImmutableArrayBuilder propertyChangedNames = ImmutableArrayBuilder.Rent(); + using ImmutableArrayBuilder propertyChangingNames = ImmutableArrayBuilder.Rent(); + using ImmutableArrayBuilder notifiedCommandNames = ImmutableArrayBuilder.Rent(); + using ImmutableArrayBuilder forwardedAttributes = ImmutableArrayBuilder.Rent(); + bool notifyRecipients = false; bool notifyDataErrorInfo = false; bool hasOrInheritsClassLevelNotifyPropertyChangedRecipients = false; @@ -137,14 +139,14 @@ public static bool TryGetInfo( foreach (AttributeData attributeData in fieldSymbol.GetAttributes()) { // Gather dependent property and command names - if (TryGatherDependentPropertyChangedNames(fieldSymbol, attributeData, propertyChangedNames, builder) || - TryGatherDependentCommandNames(fieldSymbol, attributeData, notifiedCommandNames, builder)) + if (TryGatherDependentPropertyChangedNames(fieldSymbol, attributeData, in propertyChangedNames, in builder) || + TryGatherDependentCommandNames(fieldSymbol, attributeData, in notifiedCommandNames, in builder)) { continue; } // Check whether the property should also notify recipients - if (TryGetIsNotifyingRecipients(fieldSymbol, attributeData, builder, hasOrInheritsClassLevelNotifyPropertyChangedRecipients, out isBroadcastTargetValid)) + if (TryGetIsNotifyingRecipients(fieldSymbol, attributeData, in builder, hasOrInheritsClassLevelNotifyPropertyChangedRecipients, out isBroadcastTargetValid)) { notifyRecipients = isBroadcastTargetValid; @@ -152,7 +154,7 @@ public static bool TryGetInfo( } // Check whether the property should also be validated - if (TryGetNotifyDataErrorInfo(fieldSymbol, attributeData, builder, hasOrInheritsClassLevelNotifyDataErrorInfo, out isValidationTargetValid)) + if (TryGetNotifyDataErrorInfo(fieldSymbol, attributeData, in builder, hasOrInheritsClassLevelNotifyDataErrorInfo, out isValidationTargetValid)) { notifyDataErrorInfo = isValidationTargetValid; @@ -266,9 +268,7 @@ public static bool TryGetInfo( /// The input instance to process. /// Whether or not property changing events should also be raised. /// Whether or not the containing type for is valid. - private static bool IsTargetTypeValid( - IFieldSymbol fieldSymbol, - out bool shouldInvokeOnPropertyChanging) + private static bool IsTargetTypeValid(IFieldSymbol fieldSymbol, out bool shouldInvokeOnPropertyChanging) { // The [ObservableProperty] attribute can only be used in types that are known to expose the necessary OnPropertyChanged and OnPropertyChanging methods. // That means that the containing type for the field needs to match one of the following conditions: @@ -318,8 +318,8 @@ private static bool IsGeneratedPropertyInvalid(string propertyName, ITypeSymbol private static bool TryGatherDependentPropertyChangedNames( IFieldSymbol fieldSymbol, AttributeData attributeData, - ImmutableArray.Builder propertyChangedNames, - ImmutableArray.Builder diagnostics) + in ImmutableArrayBuilder propertyChangedNames, + in ImmutableArrayBuilder diagnostics) { // Validates a property name using existing properties bool IsPropertyNameValid(string propertyName) @@ -383,8 +383,8 @@ bool IsPropertyNameValidWithGeneratedMembers(string propertyName) private static bool TryGatherDependentCommandNames( IFieldSymbol fieldSymbol, AttributeData attributeData, - ImmutableArray.Builder notifiedCommandNames, - ImmutableArray.Builder diagnostics) + in ImmutableArrayBuilder notifiedCommandNames, + in ImmutableArrayBuilder diagnostics) { // Validates a command name using existing properties bool IsCommandNameValid(string commandName, out bool shouldLookForGeneratedMembersToo) @@ -505,7 +505,7 @@ private static bool TryGetIsNotifyingRecipients(IFieldSymbol fieldSymbol, out bo private static bool TryGetIsNotifyingRecipients( IFieldSymbol fieldSymbol, AttributeData attributeData, - ImmutableArray.Builder diagnostics, + in ImmutableArrayBuilder diagnostics, bool hasOrInheritsClassLevelNotifyPropertyChangedRecipients, out bool isBroadcastTargetValid) { @@ -608,7 +608,7 @@ private static bool TryGetNotifyDataErrorInfo(IFieldSymbol fieldSymbol, out bool private static bool TryGetNotifyDataErrorInfo( IFieldSymbol fieldSymbol, AttributeData attributeData, - ImmutableArray.Builder diagnostics, + in ImmutableArrayBuilder diagnostics, bool hasOrInheritsClassLevelNotifyDataErrorInfo, out bool isValidationTargetValid) { @@ -701,7 +701,7 @@ private static bool TryGetNotifyDataErrorInfo( /// The generated instance for . public static MemberDeclarationSyntax GetPropertySyntax(PropertyInfo propertyInfo) { - ImmutableArray.Builder setterStatements = ImmutableArray.CreateBuilder(); + using ImmutableArrayBuilder setterStatements = ImmutableArrayBuilder.Rent(); // Get the property type syntax TypeSyntax propertyType = IdentifierName(propertyInfo.TypeNameWithNullabilityAnnotations); @@ -848,7 +848,7 @@ public static MemberDeclarationSyntax GetPropertySyntax(PropertyInfo propertyInf .AddArgumentListArguments( Argument(fieldExpression), Argument(IdentifierName("value")))), - Block(setterStatements)); + Block(setterStatements.ToArray())); // Prepare the forwarded attributes, if any ImmutableArray forwardedAttributes = diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs index bbc95542..10e45fe4 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs @@ -43,7 +43,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Get the hierarchy info for the target symbol, and try to gather the property info HierarchyInfo hierarchy = HierarchyInfo.From(fieldSymbol.ContainingType); - _ = Execute.TryGetInfo(fieldDeclaration, fieldSymbol, context.SemanticModel, token, out PropertyInfo? propertyInfo, out ImmutableArray diagnostics); + _ = Execute.TryGetInfo(fieldDeclaration, fieldSymbol, context.SemanticModel, token, out PropertyInfo? propertyInfo, out ImmutableArray diagnostics); return (Hierarchy: hierarchy, new Result(propertyInfo, diagnostics)); }) @@ -53,16 +53,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.ReportDiagnostics(propertyInfoWithErrors.Select(static (item, _) => item.Info.Errors)); // Get the filtered sequence to enable caching - IncrementalValuesProvider<(HierarchyInfo Hierarchy, PropertyInfo Info)> propertyInfo = + IncrementalValuesProvider<(HierarchyInfo Hierarchy, Result Info)> propertyInfo = propertyInfoWithErrors - .Select(static (item, _) => (item.Hierarchy, Info: item.Info.Value)) - .Where(static item => item.Info is not null)! - .WithComparers(HierarchyInfo.Comparer.Default, PropertyInfo.Comparer.Default); + .Where(static item => item.Info.Value is not null)!; // Split and group by containing type IncrementalValuesProvider<(HierarchyInfo Hierarchy, ImmutableArray Properties)> groupedPropertyInfo = propertyInfo - .GroupBy(HierarchyInfo.Comparer.Default) + .GroupBy(HierarchyInfo.Comparer.Default, static item => item.Value) .WithComparers(HierarchyInfo.Comparer.Default, PropertyInfo.Comparer.Default.ForImmutableArray()); // Generate the requested properties and methods @@ -84,7 +82,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Gather all property changing names IncrementalValueProvider> propertyChangingNames = propertyInfo - .SelectMany(static (item, _) => item.Info.PropertyChangingNames) + .SelectMany(static (item, _) => item.Info.Value.PropertyChangingNames) .Collect() .Select(static (item, _) => item.Distinct().ToImmutableArray()) .WithComparer(EqualityComparer.Default.ForImmutableArray()); @@ -103,7 +101,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Gather all property changed names IncrementalValueProvider> propertyChangedNames = propertyInfo - .SelectMany(static (item, _) => item.Info.PropertyChangedNames) + .SelectMany(static (item, _) => item.Info.Value.PropertyChangedNames) .Collect() .Select(static (item, _) => item.Distinct().ToImmutableArray()) .WithComparer(EqualityComparer.Default.ForImmutableArray()); diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableRecipientGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableRecipientGenerator.cs index d2098938..c2736051 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableRecipientGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableRecipientGenerator.cs @@ -5,8 +5,10 @@ using System.Collections.Immutable; using System.Linq; using CommunityToolkit.Mvvm.SourceGenerators.ComponentModel.Models; -using CommunityToolkit.Mvvm.SourceGenerators.Diagnostics; using CommunityToolkit.Mvvm.SourceGenerators.Extensions; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; +using CommunityToolkit.Mvvm.SourceGenerators.Input.Models; +using CommunityToolkit.Mvvm.SourceGenerators.Models; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -30,16 +32,16 @@ public ObservableRecipientGenerator() } /// - protected override ObservableRecipientInfo? ValidateTargetTypeAndGetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData, Compilation compilation, out ImmutableArray diagnostics) + private protected override ObservableRecipientInfo? ValidateTargetTypeAndGetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData, Compilation compilation, out ImmutableArray diagnostics) { - ImmutableArray.Builder builder = ImmutableArray.CreateBuilder(); + diagnostics = ImmutableArray.Empty; ObservableRecipientInfo? info = null; // Check if the type already inherits from ObservableRecipient if (typeSymbol.InheritsFromFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.ObservableRecipient")) { - builder.Add(DuplicateObservableRecipientError, typeSymbol, typeSymbol); + diagnostics = ImmutableArray.Create(DiagnosticInfo.Create(DuplicateObservableRecipientError, typeSymbol, typeSymbol)); goto End; } @@ -47,7 +49,7 @@ public ObservableRecipientGenerator() // Check if the type already inherits [ObservableRecipient] if (typeSymbol.InheritsAttributeWithFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.ObservableRecipientAttribute")) { - builder.Add(InvalidAttributeCombinationForObservableRecipientAttributeError, typeSymbol, typeSymbol); + diagnostics = ImmutableArray.Create(DiagnosticInfo.Create(InvalidAttributeCombinationForObservableRecipientAttributeError, typeSymbol, typeSymbol)); goto End; } @@ -60,7 +62,7 @@ public ObservableRecipientGenerator() a.AttributeClass?.HasFullyQualifiedName("global::CommunityToolkit.Mvvm.ComponentModel.INotifyPropertyChangedAttribute") == true && !a.HasNamedArgument("IncludeAdditionalHelperMethods", false))) { - builder.Add(MissingBaseObservableObjectFunctionalityError, typeSymbol, typeSymbol); + diagnostics = ImmutableArray.Create(DiagnosticInfo.Create(MissingBaseObservableObjectFunctionalityError, typeSymbol, typeSymbol)); goto End; } @@ -84,15 +86,13 @@ public ObservableRecipientGenerator() hasOnDeactivatedMethod); End: - diagnostics = builder.ToImmutable(); - return info; } /// protected override ImmutableArray FilterDeclaredMembers(ObservableRecipientInfo info, ImmutableArray memberDeclarations) { - ImmutableArray.Builder builder = ImmutableArray.CreateBuilder(); + using ImmutableArrayBuilder builder = ImmutableArrayBuilder.Rent(); // If the target type has no constructors, generate constructors as well if (!info.HasExplicitConstructors) diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.Execute.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.Execute.cs index 52d91a39..bca9391d 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.Execute.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.Execute.cs @@ -6,6 +6,7 @@ using System.Collections.Immutable; using System.Linq; using CommunityToolkit.Mvvm.SourceGenerators.Extensions; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; using CommunityToolkit.Mvvm.SourceGenerators.Input.Models; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; @@ -39,7 +40,7 @@ public static bool IsObservableValidator(INamedTypeSymbol typeSymbol) /// The resulting instance for . public static ValidationInfo GetInfo(INamedTypeSymbol typeSymbol) { - ImmutableArray.Builder propertyNames = ImmutableArray.CreateBuilder(); + using ImmutableArrayBuilder propertyNames = ImmutableArrayBuilder.Rent(); foreach (ISymbol memberSymbol in typeSymbol.GetMembers()) { @@ -92,7 +93,7 @@ public static ValidationInfo GetInfo(INamedTypeSymbol typeSymbol) /// A instance for the current type being inspected. public static RecipientInfo GetInfo(INamedTypeSymbol typeSymbol, ImmutableArray interfaceSymbols) { - ImmutableArray.Builder names = ImmutableArray.CreateBuilder(interfaceSymbols.Length); + using ImmutableArrayBuilder names = ImmutableArrayBuilder.Rent(); foreach (INamedTypeSymbol interfaceSymbol in interfaceSymbols) { @@ -102,7 +103,7 @@ public static RecipientInfo GetInfo(INamedTypeSymbol typeSymbol, ImmutableArray< return new( typeSymbol.GetFullMetadataNameForFileName(), typeSymbol.GetFullyQualifiedName(), - names.MoveToImmutable()); + names.ToImmutable()); } /// @@ -112,8 +113,7 @@ public static RecipientInfo GetInfo(INamedTypeSymbol typeSymbol, ImmutableArray< /// The head instance with the type attributes. public static CompilationUnitSyntax GetSyntax(bool isDynamicallyAccessedMembersAttributeAvailable) { - int numberOfAttributes = 5 + (isDynamicallyAccessedMembersAttributeAvailable ? 1 : 0); - ImmutableArray.Builder attributes = ImmutableArray.CreateBuilder(numberOfAttributes); + using ImmutableArrayBuilder attributes = ImmutableArrayBuilder.Rent(); // Prepare the base attributes with are always present: // @@ -171,7 +171,7 @@ public static CompilationUnitSyntax GetSyntax(bool isDynamicallyAccessedMembersA Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.PartialKeyword)) - .AddAttributeLists(attributes.MoveToImmutable().ToArray()))) + .AddAttributeLists(attributes.ToArray()))) .NormalizeWhitespace(); } @@ -261,7 +261,7 @@ public static CompilationUnitSyntax GetSyntax(ValidationInfo validationInfo) /// The sequence of instances to validate declared properties. private static ImmutableArray EnumerateValidationStatements(ValidationInfo validationInfo) { - ImmutableArray.Builder statements = ImmutableArray.CreateBuilder(validationInfo.PropertyNames.Length); + using ImmutableArrayBuilder statements = ImmutableArrayBuilder.Rent(); // This loop produces a sequence of statements as follows: // @@ -294,7 +294,7 @@ private static ImmutableArray EnumerateValidationStatements(Val IdentifierName(propertyName)))))))); } - return statements.MoveToImmutable(); + return statements.ToImmutable(); } } } diff --git a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs index dd385094..5703941b 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs @@ -85,7 +85,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) INamedTypeSymbol typeSymbol = (INamedTypeSymbol)context.TargetSymbol; // Gather all generation info, and any diagnostics - TInfo? info = ValidateTargetTypeAndGetInfo(typeSymbol, context.Attributes[0], context.SemanticModel.Compilation, out ImmutableArray diagnostics); + TInfo? info = ValidateTargetTypeAndGetInfo(typeSymbol, context.Attributes[0], context.SemanticModel.Compilation, out ImmutableArray diagnostics); // If there are any diagnostics, there's no need to compute the hierarchy info at all, just return them if (diagnostics.Length > 0) @@ -129,7 +129,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) /// The resulting diagnostics, if any. /// The extracted info for the current type, if possible. /// If is empty, the returned info will always be ignored and no sources will be produced. - protected abstract TInfo? ValidateTargetTypeAndGetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData, Compilation compilation, out ImmutableArray diagnostics); + private protected abstract TInfo? ValidateTargetTypeAndGetInfo(INamedTypeSymbol typeSymbol, AttributeData attributeData, Compilation compilation, out ImmutableArray diagnostics); /// /// Filters the nodes to generate from the input parsed tree. diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticExtensions.cs b/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticExtensions.cs deleted file mode 100644 index 8da472e5..00000000 --- a/CommunityToolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticExtensions.cs +++ /dev/null @@ -1,46 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -// This file is ported and adapted from ComputeSharp (Sergio0694/ComputeSharp), -// more info in ThirdPartyNotices.txt in the root of the project. - -using System.Collections.Immutable; -using System.Linq; -using Microsoft.CodeAnalysis; - -namespace CommunityToolkit.Mvvm.SourceGenerators.Diagnostics; - -/// -/// Extension methods specifically for creating diagnostics. -/// -internal static class DiagnosticExtensions -{ - /// - /// Adds a new diagnostics to the target builder. - /// - /// The collection of produced instances. - /// The input for the diagnostics to create. - /// The source to attach the diagnostics to. - /// The optional arguments for the formatted message to include. - public static void Add( - this ImmutableArray.Builder diagnostics, - DiagnosticDescriptor descriptor, - ISymbol symbol, - params object[] args) - { - diagnostics.Add(descriptor.CreateDiagnostic(symbol, args)); - } - - /// - /// Creates a new instance with the specified parameters. - /// - /// The input for the diagnostics to create. - /// The source to attach the diagnostics to. - /// The optional arguments for the formatted message to include. - /// The resulting instance. - public static Diagnostic CreateDiagnostic(this DiagnosticDescriptor descriptor, ISymbol symbol, params object[] args) - { - return Diagnostic.Create(descriptor, symbol.Locations.FirstOrDefault(), args); - } -} diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Extensions/DiagnosticsExtensions.cs b/CommunityToolkit.Mvvm.SourceGenerators/Extensions/DiagnosticsExtensions.cs new file mode 100644 index 00000000..7d9fc098 --- /dev/null +++ b/CommunityToolkit.Mvvm.SourceGenerators/Extensions/DiagnosticsExtensions.cs @@ -0,0 +1,80 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +// This file is ported and adapted from ComputeSharp (Sergio0694/ComputeSharp), +// more info in ThirdPartyNotices.txt in the root of the project. + +using System.Collections.Immutable; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; +using CommunityToolkit.Mvvm.SourceGenerators.Models; +using Microsoft.CodeAnalysis; + +namespace CommunityToolkit.Mvvm.SourceGenerators.Extensions; + +/// +/// Extension methods for , specifically for reporting diagnostics. +/// +internal static class DiagnosticsExtensions +{ + /// + /// Adds a new diagnostics to the target builder. + /// + /// The collection of produced instances. + /// The input for the diagnostics to create. + /// The source to attach the diagnostics to. + /// The optional arguments for the formatted message to include. + public static void Add( + this in ImmutableArrayBuilder diagnostics, + DiagnosticDescriptor descriptor, + ISymbol symbol, + params object[] args) + { + diagnostics.Add(DiagnosticInfo.Create(descriptor, symbol, args)); + } + + /// + /// Adds a new diagnostics to the target builder. + /// + /// The collection of produced instances. + /// The input for the diagnostics to create. + /// The source to attach the diagnostics to. + /// The optional arguments for the formatted message to include. + public static void Add( + this in ImmutableArrayBuilder diagnostics, + DiagnosticDescriptor descriptor, + SyntaxNode node, + params object[] args) + { + diagnostics.Add(DiagnosticInfo.Create(descriptor, node, args)); + } + + /// + /// Registers an output node into an to output diagnostics. + /// + /// The input instance. + /// The input sequence of diagnostics. + public static void ReportDiagnostics(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider diagnostics) + { + context.RegisterSourceOutput(diagnostics, static (context, diagnostic) => + { + context.ReportDiagnostic(diagnostic.ToDiagnostic()); + }); + } + + /// + /// Registers an output node into an to output diagnostics. + /// + /// The input instance. + /// The input sequence of diagnostics. + public static void ReportDiagnostics(this IncrementalGeneratorInitializationContext context, IncrementalValuesProvider> diagnostics) + { + context.RegisterSourceOutput(diagnostics, static (context, diagnostics) => + { + foreach (DiagnosticInfo diagnostic in diagnostics) + { + context.ReportDiagnostic(diagnostic.ToDiagnostic()); + } + }); + } +} diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Extensions/IncrementalValuesProviderExtensions.cs b/CommunityToolkit.Mvvm.SourceGenerators/Extensions/IncrementalValuesProviderExtensions.cs index 85eff06e..aea9a11f 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Extensions/IncrementalValuesProviderExtensions.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Extensions/IncrementalValuesProviderExtensions.cs @@ -22,33 +22,36 @@ internal static class IncrementalValuesProviderExtensions /// /// The type of left items in each tuple. /// The type of right items in each tuple. + /// The type of resulting projected elements. /// The input instance. /// A comparer. + /// A projection function to convert gathered elements. /// An with the grouped results. - public static IncrementalValuesProvider<(TLeft Left, ImmutableArray Right)> GroupBy( + public static IncrementalValuesProvider<(TLeft Left, ImmutableArray Right)> GroupBy( this IncrementalValuesProvider<(TLeft Left, TRight Right)> source, - IEqualityComparer comparer) + IEqualityComparer comparer, + Func projection) { return source.Collect().SelectMany((item, _) => { - Dictionary.Builder> map = new(comparer); + Dictionary.Builder> map = new(comparer); foreach ((TLeft hierarchy, TRight info) in item) { - if (!map.TryGetValue(hierarchy, out ImmutableArray.Builder builder)) + if (!map.TryGetValue(hierarchy, out ImmutableArray.Builder builder)) { - builder = ImmutableArray.CreateBuilder(); + builder = ImmutableArray.CreateBuilder(); map.Add(hierarchy, builder); } - builder.Add(info); + builder.Add(projection(info)); } - ImmutableArray<(TLeft Hierarchy, ImmutableArray Properties)>.Builder result = - ImmutableArray.CreateBuilder<(TLeft, ImmutableArray)>(); + ImmutableArray<(TLeft Hierarchy, ImmutableArray Elements)>.Builder result = + ImmutableArray.CreateBuilder<(TLeft, ImmutableArray)>(); - foreach (KeyValuePair.Builder> entry in map) + foreach (KeyValuePair.Builder> entry in map) { result.Add((entry.Key, entry.Value.ToImmutable())); } diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Helpers/ImmutableArrayBuilder{T}.cs b/CommunityToolkit.Mvvm.SourceGenerators/Helpers/ImmutableArrayBuilder{T}.cs new file mode 100644 index 00000000..206187b6 --- /dev/null +++ b/CommunityToolkit.Mvvm.SourceGenerators/Helpers/ImmutableArrayBuilder{T}.cs @@ -0,0 +1,89 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Immutable; + +namespace CommunityToolkit.Mvvm.SourceGenerators.Helpers; + +/// +/// A helper type to build instances with pooled buffers. +/// +/// The type of items to create arrays for. +internal ref struct ImmutableArrayBuilder +{ + /// + /// The shared instance to share objects. + /// + private static readonly ObjectPool.Builder> sharedObjectPool = new(ImmutableArray.CreateBuilder); + + /// + /// The owner instance. + /// + private readonly ObjectPool.Builder> objectPool; + + /// + /// The rented instance to use. + /// + private ImmutableArray.Builder? builder; + + /// + /// Rents a new pooled instance through a new value. + /// + /// A to interact with the underlying instance. + public static ImmutableArrayBuilder Rent() + { + return new(sharedObjectPool, sharedObjectPool.Allocate()); + } + + /// + /// Creates a new object with the specified parameters. + /// + /// + /// + private ImmutableArrayBuilder(ObjectPool.Builder> objectPool, ImmutableArray.Builder builder) + { + this.objectPool = objectPool; + this.builder = builder; + } + + /// + public readonly int Count + { + get => this.builder!.Count; + } + + /// + public readonly void Add(T item) + { + this.builder!.Add(item); + } + + /// + public readonly ImmutableArray ToImmutable() + { + return this.builder!.ToImmutable(); + } + + /// + public readonly T[] ToArray() + { + return this.builder!.ToArray(); + } + + /// + public void Dispose() + { + ImmutableArray.Builder? builder = this.builder; + + this.builder = null; + + if (builder is not null) + { + builder.Clear(); + + this.objectPool.Free(builder); + } + } +} \ No newline at end of file diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Helpers/ObjectPool{T}.cs b/CommunityToolkit.Mvvm.SourceGenerators/Helpers/ObjectPool{T}.cs new file mode 100644 index 00000000..73e26d42 --- /dev/null +++ b/CommunityToolkit.Mvvm.SourceGenerators/Helpers/ObjectPool{T}.cs @@ -0,0 +1,163 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +// Ported from Roslyn, see: https://github.com/dotnet/roslyn/blob/main/src/Dependencies/PooledObjects/ObjectPool%601.cs. + +using System; +using System.Runtime.CompilerServices; +using System.Threading; + +namespace CommunityToolkit.Mvvm.SourceGenerators.Helpers; + +/// +/// +/// Generic implementation of object pooling pattern with predefined pool size limit. The main purpose +/// is that limited number of frequently used objects can be kept in the pool for further recycling. +/// +/// +/// Notes: +/// +/// +/// It is not the goal to keep all returned objects. Pool is not meant for storage. If there +/// is no space in the pool, extra returned objects will be dropped. +/// +/// +/// It is implied that if object was obtained from a pool, the caller will return it back in +/// a relatively short time. Keeping checked out objects for long durations is ok, but +/// reduces usefulness of pooling. Just new up your own. +/// +/// +/// +/// +/// Not returning objects to the pool in not detrimental to the pool's work, but is a bad practice. +/// Rationale: if there is no intent for reusing the object, do not use pool - just use "new". +/// +/// +/// The type of objects to pool. +internal sealed class ObjectPool + where T : class +{ + /// + /// The factory is stored for the lifetime of the pool. We will call this only when pool needs to + /// expand. compared to "new T()", Func gives more flexibility to implementers and faster than "new T()". + /// + private readonly Func factory; + + /// + /// The array of cached items. + /// + private readonly Element[] items; + + /// + /// Storage for the pool objects. The first item is stored in a dedicated field + /// because we expect to be able to satisfy most requests from it. + /// + private T? firstItem; + + /// + /// Creates a new instance with the specified parameters. + /// + /// The input factory to produce items. + public ObjectPool(Func factory) + : this(factory, Environment.ProcessorCount * 2) + { + } + + /// + /// Creates a new instance with the specified parameters. + /// + /// The input factory to produce items. + /// The pool size to use. + public ObjectPool(Func factory, int size) + { + this.factory = factory; + this.items = new Element[size - 1]; + } + + /// + /// Produces a instance. + /// + /// The returned item to use. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public T Allocate() + { + T? item = this.firstItem; + + if (item is null || item != Interlocked.CompareExchange(ref this.firstItem, null, item)) + { + item = AllocateSlow(); + } + + return item; + } + + /// + /// Returns a given instance to the pool. + /// + /// The instance to return. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Free(T obj) + { + if (this.firstItem is null) + { + this.firstItem = obj; + } + else + { + FreeSlow(obj); + } + } + + /// + /// Allocates a new item. + /// + /// The returned item to use. + [MethodImpl(MethodImplOptions.NoInlining)] + private T AllocateSlow() + { + foreach (ref Element element in this.items.AsSpan()) + { + T? instance = element.Value; + + if (instance is not null) + { + if (instance == Interlocked.CompareExchange(ref element.Value, null, instance)) + { + return instance; + } + } + } + + return this.factory(); + } + + /// + /// Frees a given item. + /// + /// The item to return to the pool. + [MethodImpl(MethodImplOptions.NoInlining)] + private void FreeSlow(T obj) + { + foreach (ref Element element in this.items.AsSpan()) + { + if (element.Value is null) + { + element.Value = obj; + + break; + } + } + } + + /// + /// A container for a produced item (using a wrapper to avoid covariance checks). + /// + private struct Element + { + /// + /// The value held at the current element. + /// + internal T? Value; + } +} \ No newline at end of file diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CommandInfo.cs b/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CommandInfo.cs index f9240249..5bec857b 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CommandInfo.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Input/Models/CommandInfo.cs @@ -42,10 +42,16 @@ internal sealed record CommandInfo( bool FlowExceptionsToTaskScheduler, bool IncludeCancelCommand) { + /// + public bool Equals(CommandInfo? obj) => Comparer.Default.Equals(this, obj); + + /// + public override int GetHashCode() => Comparer.Default.GetHashCode(this); + /// /// An implementation for . /// - public sealed class Comparer : Comparer + private sealed class Comparer : Comparer { /// protected override void AddToHashCode(ref HashCode hashCode, CommandInfo obj) diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs b/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs index 3fc4f510..8c15335b 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs @@ -7,9 +7,10 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Linq; -using CommunityToolkit.Mvvm.SourceGenerators.Diagnostics; using CommunityToolkit.Mvvm.SourceGenerators.Extensions; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; using CommunityToolkit.Mvvm.SourceGenerators.Input.Models; +using CommunityToolkit.Mvvm.SourceGenerators.Models; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -27,15 +28,20 @@ partial class RelayCommandGenerator internal static class Execute { /// - /// Processes a given target method. + /// Processes a given annotated methods and produces command info, if possible. /// /// The input instance to process. /// The instance the method was annotated with. + /// The resulting instance, if successfully generated. /// The resulting diagnostics from the processing operation. - /// The resulting instance for , if available. - public static CommandInfo? GetInfo(IMethodSymbol methodSymbol, AttributeData attributeData, out ImmutableArray diagnostics) + /// Whether a instance could be generated successfully. + public static bool TryGetInfo( + IMethodSymbol methodSymbol, + AttributeData attributeData, + [NotNullWhen(true)] out CommandInfo? commandInfo, + out ImmutableArray diagnostics) { - ImmutableArray.Builder builder = ImmutableArray.CreateBuilder(); + using ImmutableArrayBuilder builder = ImmutableArrayBuilder.Rent(); // Validate the method definition is unique if (!IsCommandDefinitionUnique(methodSymbol, builder)) @@ -49,7 +55,7 @@ internal static class Execute // Get the command type symbols if (!TryMapCommandTypesFromMethod( methodSymbol, - builder, + in builder, out string? commandInterfaceType, out string? commandClassType, out string? delegateType, @@ -66,7 +72,7 @@ internal static class Execute methodSymbol, attributeData, commandClassType, - builder, + in builder, out bool allowConcurrentExecutions)) { goto Failure; @@ -77,7 +83,7 @@ internal static class Execute methodSymbol, attributeData, commandClassType, - builder, + in builder, out bool flowExceptionsToTaskScheduler)) { goto Failure; @@ -88,7 +94,7 @@ internal static class Execute methodSymbol, attributeData, commandTypeArguments, - builder, + in builder, out string? canExecuteMemberName, out CanExecuteExpressionType? canExecuteExpressionType)) { @@ -101,15 +107,13 @@ internal static class Execute attributeData, commandClassType, supportsCancellation, - builder, + in builder, out bool generateCancelCommand)) { goto Failure; } - diagnostics = builder.ToImmutable(); - - return new( + commandInfo = new CommandInfo( methodSymbol.Name, fieldName, propertyName, @@ -124,10 +128,15 @@ internal static class Execute flowExceptionsToTaskScheduler, generateCancelCommand); + diagnostics = builder.ToImmutable(); + + return true; + Failure: + commandInfo = null; diagnostics = builder.ToImmutable(); - return null; + return false; } /// @@ -170,7 +179,7 @@ public static ImmutableArray GetSyntax(CommandInfo comm .WithOpenBracketToken(Token(TriviaList(Comment($"/// The backing field for .")), SyntaxKind.OpenBracketToken, TriviaList()))); // Prepares the argument to pass the underlying method to invoke - ImmutableArray.Builder commandCreationArguments = ImmutableArray.CreateBuilder(); + using ImmutableArrayBuilder commandCreationArguments = ImmutableArrayBuilder.Rent(); // The first argument is the execute method, which is always present commandCreationArguments.Add( @@ -354,7 +363,7 @@ public static ImmutableArray GetSyntax(CommandInfo comm /// The input instance to process. /// The current collection of gathered diagnostics. /// Whether or not was unique within its containing type. - private static bool IsCommandDefinitionUnique(IMethodSymbol methodSymbol, ImmutableArray.Builder diagnostics) + private static bool IsCommandDefinitionUnique(IMethodSymbol methodSymbol, in ImmutableArrayBuilder diagnostics) { // If a duplicate is present in any of the base types, always emit a diagnostic for the current method. // That is, there is no need to check the order: we assume the priority is top-down in the type hierarchy. @@ -450,7 +459,7 @@ public static (string FieldName, string PropertyName) GetGeneratedFieldAndProper /// Whether or not was valid and the requested types have been set. private static bool TryMapCommandTypesFromMethod( IMethodSymbol methodSymbol, - ImmutableArray.Builder diagnostics, + in ImmutableArrayBuilder diagnostics, [NotNullWhen(true)] out string? commandInterfaceType, [NotNullWhen(true)] out string? commandClassType, [NotNullWhen(true)] out string? delegateType, @@ -580,7 +589,7 @@ private static bool TryGetAllowConcurrentExecutionsSwitch( IMethodSymbol methodSymbol, AttributeData attributeData, string commandClassType, - ImmutableArray.Builder diagnostics, + in ImmutableArrayBuilder diagnostics, out bool allowConcurrentExecutions) { // Try to get the custom switch for concurrent executions (the default is false) @@ -618,7 +627,7 @@ private static bool TryGetFlowExceptionsToTaskSchedulerSwitch( IMethodSymbol methodSymbol, AttributeData attributeData, string commandClassType, - ImmutableArray.Builder diagnostics, + in ImmutableArrayBuilder diagnostics, out bool flowExceptionsToTaskScheduler) { // Try to get the custom switch for task scheduler exception flow (the default is false) @@ -657,7 +666,7 @@ private static bool TryGetIncludeCancelCommandSwitch( AttributeData attributeData, string commandClassType, bool supportsCancellation, - ImmutableArray.Builder diagnostics, + in ImmutableArrayBuilder diagnostics, out bool generateCancelCommand) { // Try to get the custom switch for cancel command generation (the default is false) @@ -697,7 +706,7 @@ private static bool TryGetCanExecuteExpressionType( IMethodSymbol methodSymbol, AttributeData attributeData, ImmutableArray commandTypeArguments, - ImmutableArray.Builder diagnostics, + in ImmutableArrayBuilder diagnostics, out string? canExecuteMemberName, out CanExecuteExpressionType? canExecuteExpressionType) { diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs index 25331793..78bdb229 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs @@ -38,31 +38,30 @@ public void Initialize(IncrementalGeneratorInitializationContext context) IMethodSymbol methodSymbol = (IMethodSymbol)context.TargetSymbol; - // Produce the incremental models - HierarchyInfo hierarchy = HierarchyInfo.From(methodSymbol.ContainingType); - CommandInfo? commandInfo = Execute.GetInfo(methodSymbol, context.Attributes[0], out ImmutableArray diagnostics); + // Get the hierarchy info for the target symbol, and try to gather the command info + HierarchyInfo? hierarchy = HierarchyInfo.From(methodSymbol.ContainingType); + + _ = Execute.TryGetInfo(methodSymbol, context.Attributes[0], out CommandInfo? commandInfo, out ImmutableArray diagnostics); return (Hierarchy: hierarchy, new Result(commandInfo, diagnostics)); }) - .Where(static item => item.Hierarchy is not null); + .Where(static item => item.Hierarchy is not null)!; // Output the diagnostics context.ReportDiagnostics(commandInfoWithErrors.Select(static (item, _) => item.Info.Errors)); // Get the filtered sequence to enable caching - IncrementalValuesProvider<(HierarchyInfo Hierarchy, CommandInfo Info)> commandInfo = + IncrementalValuesProvider<(HierarchyInfo Hierarchy, Result Info)> commandInfo = commandInfoWithErrors - .Where(static item => item.Info.Value is not null) - .Select(static (item, _) => (item.Hierarchy, item.Info.Value!)) - .WithComparers(HierarchyInfo.Comparer.Default, CommandInfo.Comparer.Default); + .Where(static item => item.Info.Value is not null)!; // Generate the commands context.RegisterSourceOutput(commandInfo, static (context, item) => { - ImmutableArray memberDeclarations = Execute.GetSyntax(item.Info); + ImmutableArray memberDeclarations = Execute.GetSyntax(item.Info.Value); CompilationUnitSyntax compilationUnit = item.Hierarchy.GetCompilationUnit(memberDeclarations); - context.AddSource($"{item.Hierarchy.FilenameHint}.{item.Info.MethodName}.g.cs", compilationUnit.GetText(Encoding.UTF8)); + context.AddSource($"{item.Hierarchy.FilenameHint}.{item.Info.Value.MethodName}.g.cs", compilationUnit.GetText(Encoding.UTF8)); }); } } diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.Execute.cs b/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.Execute.cs index a9c13cab..a58a55c8 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.Execute.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.Execute.cs @@ -5,6 +5,7 @@ using System.Collections.Immutable; using System.Linq; using CommunityToolkit.Mvvm.SourceGenerators.Extensions; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; using CommunityToolkit.Mvvm.SourceGenerators.Input.Models; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; @@ -28,7 +29,7 @@ private static class Execute /// An array of interface type symbols. public static ImmutableArray GetInterfaces(INamedTypeSymbol typeSymbol) { - ImmutableArray.Builder iRecipientInterfaces = ImmutableArray.CreateBuilder(); + using ImmutableArrayBuilder iRecipientInterfaces = ImmutableArrayBuilder.Rent(); foreach (INamedTypeSymbol interfaceSymbol in typeSymbol.AllInterfaces) { @@ -50,7 +51,7 @@ public static ImmutableArray GetInterfaces(INamedTypeSymbol ty /// A instance for the current type being inspected. public static RecipientInfo GetInfo(INamedTypeSymbol typeSymbol, ImmutableArray interfaceSymbols) { - ImmutableArray.Builder names = ImmutableArray.CreateBuilder(interfaceSymbols.Length); + using ImmutableArrayBuilder names = ImmutableArrayBuilder.Rent(); foreach (INamedTypeSymbol interfaceSymbol in interfaceSymbols) { @@ -60,7 +61,7 @@ public static RecipientInfo GetInfo(INamedTypeSymbol typeSymbol, ImmutableArray< return new( typeSymbol.GetFullMetadataNameForFileName(), typeSymbol.GetFullyQualifiedName(), - names.MoveToImmutable()); + names.ToImmutable()); } /// @@ -70,8 +71,7 @@ public static RecipientInfo GetInfo(INamedTypeSymbol typeSymbol, ImmutableArray< /// The head instance with the type attributes. public static CompilationUnitSyntax GetSyntax(bool isDynamicallyAccessedMembersAttributeAvailable) { - int numberOfAttributes = 5 + (isDynamicallyAccessedMembersAttributeAvailable ? 1 : 0); - ImmutableArray.Builder attributes = ImmutableArray.CreateBuilder(numberOfAttributes); + using ImmutableArrayBuilder attributes = ImmutableArrayBuilder.Rent(); // Prepare the base attributes with are always present: // @@ -129,7 +129,7 @@ public static CompilationUnitSyntax GetSyntax(bool isDynamicallyAccessedMembersA Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.PartialKeyword)) - .AddAttributeLists(attributes.MoveToImmutable().ToArray()))) + .AddAttributeLists(attributes.ToArray()))) .NormalizeWhitespace(); } @@ -284,7 +284,7 @@ public static CompilationUnitSyntax GetSyntax(RecipientInfo recipientInfo) /// The sequence of instances to register message handlers. private static ImmutableArray EnumerateRegistrationStatements(RecipientInfo recipientInfo) { - ImmutableArray.Builder statements = ImmutableArray.CreateBuilder(recipientInfo.MessageTypes.Length); + using ImmutableArrayBuilder statements = ImmutableArrayBuilder.Rent(); // This loop produces a sequence of statements as follows: // @@ -305,7 +305,7 @@ private static ImmutableArray EnumerateRegistrationStatements(R .AddArgumentListArguments(Argument(IdentifierName("recipient"))))); } - return statements.MoveToImmutable(); + return statements.ToImmutable(); } /// @@ -315,7 +315,7 @@ private static ImmutableArray EnumerateRegistrationStatements(R /// The sequence of instances to register message handlers. private static ImmutableArray EnumerateRegistrationStatementsWithTokens(RecipientInfo recipientInfo) { - ImmutableArray.Builder statements = ImmutableArray.CreateBuilder(recipientInfo.MessageTypes.Length); + using ImmutableArrayBuilder statements = ImmutableArrayBuilder.Rent(); // This loop produces a sequence of statements as follows: // @@ -336,7 +336,7 @@ private static ImmutableArray EnumerateRegistrationStatementsWi .AddArgumentListArguments(Argument(IdentifierName("recipient")), Argument(IdentifierName("token"))))); } - return statements.MoveToImmutable(); + return statements.ToImmutable(); } } } diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs b/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs index be66b42c..b076c77d 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs @@ -61,8 +61,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return Execute.GetInfo(typeSymbol, interfaceSymbols); }) - .Where(static item => item is not null)! - .WithComparer(RecipientInfo.Comparer.Default); + .Where(static item => item is not null)!; // Check whether the header file is needed IncrementalValueProvider isHeaderFileNeeded = diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Messaging/Models/RecipientInfo.cs b/CommunityToolkit.Mvvm.SourceGenerators/Messaging/Models/RecipientInfo.cs index b9d1f068..6dc6bfa8 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Messaging/Models/RecipientInfo.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Messaging/Models/RecipientInfo.cs @@ -22,10 +22,16 @@ internal sealed record RecipientInfo( string TypeName, ImmutableArray MessageTypes) { + /// + public bool Equals(RecipientInfo? obj) => Comparer.Default.Equals(this, obj); + + /// + public override int GetHashCode() => Comparer.Default.GetHashCode(this); + /// /// An implementation for . /// - public sealed class Comparer : Comparer + private sealed class Comparer : Comparer { /// protected override void AddToHashCode(ref HashCode hashCode, RecipientInfo obj) diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Models/DiagnosticInfo.cs b/CommunityToolkit.Mvvm.SourceGenerators/Models/DiagnosticInfo.cs new file mode 100644 index 00000000..a324ada6 --- /dev/null +++ b/CommunityToolkit.Mvvm.SourceGenerators/Models/DiagnosticInfo.cs @@ -0,0 +1,104 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +// This file is ported and adapted from ComputeSharp (Sergio0694/ComputeSharp), +// more info in ThirdPartyNotices.txt in the root of the project. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using CommunityToolkit.Mvvm.SourceGenerators.Extensions; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; + +namespace CommunityToolkit.Mvvm.SourceGenerators.Models; + +/// +/// A model for a serializeable diagnostic info. +/// +/// The wrapped instance. +/// The tree to use as location for the diagnostic, if available. +/// The span to use as location for the diagnostic. +/// The diagnostic arguments. +internal sealed record DiagnosticInfo( + DiagnosticDescriptor Descriptor, + SyntaxTree? SyntaxTree, + TextSpan TextSpan, + ImmutableArray Arguments) +{ + /// + public bool Equals(DiagnosticInfo? obj) => Comparer.Default.Equals(this, obj); + + /// + public override int GetHashCode() => Comparer.Default.GetHashCode(this); + + /// + /// Creates a new instance with the state from this model. + /// + /// A new instance with the state from this model. + public Diagnostic ToDiagnostic() + { + if (SyntaxTree is not null) + { + return Diagnostic.Create(Descriptor, Location.Create(SyntaxTree, TextSpan), Arguments.ToArray()); + } + + return Diagnostic.Create(Descriptor, null, Arguments.ToArray()); + } + + /// + /// Creates a new instance with the specified parameters. + /// + /// The input for the diagnostics to create. + /// The source to attach the diagnostics to. + /// The optional arguments for the formatted message to include. + /// A new instance with the specified parameters. + public static DiagnosticInfo Create(DiagnosticDescriptor descriptor, ISymbol symbol, params object[] args) + { + Location location = symbol.Locations.First(); + + return new(descriptor, location.SourceTree, location.SourceSpan, args.Select(static arg => arg.ToString()).ToImmutableArray()); + } + + /// + /// Creates a new instance with the specified parameters. + /// + /// The input for the diagnostics to create. + /// The source to attach the diagnostics to. + /// The optional arguments for the formatted message to include. + /// A new instance with the specified parameters. + public static DiagnosticInfo Create(DiagnosticDescriptor descriptor, SyntaxNode node, params object[] args) + { + Location location = node.GetLocation(); + + return new(descriptor, location.SourceTree, location.SourceSpan, args.Select(static arg => arg.ToString()).ToImmutableArray()); + } + + /// + /// An implementation for . + /// + private sealed class Comparer : Comparer + { + /// + protected override void AddToHashCode(ref HashCode hashCode, DiagnosticInfo obj) + { + hashCode.Add(obj.Descriptor); + hashCode.Add(obj.SyntaxTree); + hashCode.Add(obj.TextSpan); + hashCode.AddRange(obj.Arguments); + } + + /// + protected override bool AreEqual(DiagnosticInfo x, DiagnosticInfo y) + { + return + x.Descriptor.Equals(y.Descriptor) && + x.SyntaxTree == y.SyntaxTree && + x.TextSpan.Equals(y.TextSpan) && + x.Arguments.SequenceEqual(y.Arguments); + } + } +} diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.cs b/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.cs index 93924ce7..b2382507 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Models/HierarchyInfo.cs @@ -32,7 +32,7 @@ internal sealed partial record HierarchyInfo(string FilenameHint, string Metadat /// A instance describing . public static HierarchyInfo From(INamedTypeSymbol typeSymbol) { - ImmutableArray.Builder hierarchy = ImmutableArray.CreateBuilder(); + using ImmutableArrayBuilder hierarchy = ImmutableArrayBuilder.Rent(); for (INamedTypeSymbol? parent = typeSymbol; parent is not null; diff --git a/CommunityToolkit.Mvvm.SourceGenerators/Models/Result.cs b/CommunityToolkit.Mvvm.SourceGenerators/Models/Result.cs index 83ca3d1b..ed192954 100644 --- a/CommunityToolkit.Mvvm.SourceGenerators/Models/Result.cs +++ b/CommunityToolkit.Mvvm.SourceGenerators/Models/Result.cs @@ -5,8 +5,12 @@ // This file is ported and adapted from ComputeSharp (Sergio0694/ComputeSharp), // more info in ThirdPartyNotices.txt in the root of the project. +using System; +using System.Collections.Generic; using System.Collections.Immutable; -using Microsoft.CodeAnalysis; +using System.Linq; +using CommunityToolkit.Mvvm.SourceGenerators.Extensions; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; namespace CommunityToolkit.Mvvm.SourceGenerators.Models; @@ -16,4 +20,33 @@ namespace CommunityToolkit.Mvvm.SourceGenerators.Models; /// The type of the wrapped value. /// The wrapped value for the current result. /// The associated diagnostic errors, if any. -internal sealed record Result(TValue Value, ImmutableArray Errors); +internal sealed record Result(TValue Value, ImmutableArray Errors) + where TValue : IEquatable? +{ + /// + public bool Equals(Result? obj) => Comparer.Default.Equals(this, obj); + + /// + public override int GetHashCode() => Comparer.Default.GetHashCode(this); + + /// + /// An implementation for . + /// + private sealed class Comparer : Comparer, Comparer> + { + /// + protected override void AddToHashCode(ref HashCode hashCode, Result obj) + { + hashCode.Add(obj.Value); + hashCode.AddRange(obj.Errors); + } + + /// + protected override bool AreEqual(Result x, Result y) + { + return + EqualityComparer.Default.Equals(x.Value, y.Value) && + x.Errors.SequenceEqual(y.Errors); + } + } +}