diff --git a/src/TypedSignalR.Client/CodeAnalysis/SpecialSymbols.cs b/src/TypedSignalR.Client/CodeAnalysis/SpecialSymbols.cs index ac62535..6d60e18 100644 --- a/src/TypedSignalR.Client/CodeAnalysis/SpecialSymbols.cs +++ b/src/TypedSignalR.Client/CodeAnalysis/SpecialSymbols.cs @@ -1,3 +1,4 @@ +using System.Collections.Immutable; using Microsoft.CodeAnalysis; namespace TypedSignalR.Client.CodeAnalysis; @@ -9,9 +10,9 @@ public class SpecialSymbols public readonly INamedTypeSymbol CancellationTokenSymbol; public readonly INamedTypeSymbol AsyncEnumerableSymbol; public readonly INamedTypeSymbol ChannelReaderSymbol; - public readonly INamedTypeSymbol HubConnectionObserverSymbol; - public readonly IMethodSymbol CreateHubProxyMethodSymbol; - public readonly IMethodSymbol RegisterMethodSymbol; + public readonly ImmutableArray HubConnectionObserverSymbols; + public readonly ImmutableArray CreateHubProxyMethodSymbols; + public readonly ImmutableArray RegisterMethodSymbols; public SpecialSymbols( INamedTypeSymbol taskSymbol, @@ -19,17 +20,17 @@ public SpecialSymbols( INamedTypeSymbol cancellationTokenSymbol, INamedTypeSymbol asyncEnumerableSymbol, INamedTypeSymbol channelReaderSymbol, - INamedTypeSymbol hubConnectionObserverSymbol, - IMethodSymbol createHubProxyMethodSymbol, - IMethodSymbol registerMethodSymbol) + ImmutableArray hubConnectionObserverSymbols, + ImmutableArray createHubProxyMethodSymbols, + ImmutableArray registerMethodSymbols) { TaskSymbol = taskSymbol; GenericTaskSymbol = genericTaskSymbol; CancellationTokenSymbol = cancellationTokenSymbol; AsyncEnumerableSymbol = asyncEnumerableSymbol; ChannelReaderSymbol = channelReaderSymbol; - HubConnectionObserverSymbol = hubConnectionObserverSymbol; - CreateHubProxyMethodSymbol = createHubProxyMethodSymbol; - RegisterMethodSymbol = registerMethodSymbol; + HubConnectionObserverSymbols = hubConnectionObserverSymbols; + CreateHubProxyMethodSymbols = createHubProxyMethodSymbols; + RegisterMethodSymbols = registerMethodSymbols; } } diff --git a/src/TypedSignalR.Client/SourceGenerator.cs b/src/TypedSignalR.Client/SourceGenerator.cs index 283c23c..9e3f77f 100644 --- a/src/TypedSignalR.Client/SourceGenerator.cs +++ b/src/TypedSignalR.Client/SourceGenerator.cs @@ -129,9 +129,12 @@ private static ValidatedSourceSymbol ValidateCreateHubProxyMethodSymbol((SourceS return default; } - if (SymbolEqualityComparer.Default.Equals(extensionMethodSymbol, specialSymbols.CreateHubProxyMethodSymbol)) + foreach (var createHubProxyMethodSymbol in specialSymbols.CreateHubProxyMethodSymbols) { - return new ValidatedSourceSymbol(methodSymbol, location); + if (SymbolEqualityComparer.Default.Equals(extensionMethodSymbol, createHubProxyMethodSymbol)) + { + return new ValidatedSourceSymbol(methodSymbol, location); + } } return default; @@ -159,9 +162,12 @@ private static ValidatedSourceSymbol ValidateRegisterMethodSymbol((SourceSymbol, return default; } - if (SymbolEqualityComparer.Default.Equals(extensionMethodSymbol, specialSymbols.RegisterMethodSymbol)) + foreach (var registerMethodSymbol in specialSymbols.RegisterMethodSymbols) { - return new ValidatedSourceSymbol(methodSymbol, location); + if (SymbolEqualityComparer.Default.Equals(extensionMethodSymbol, registerMethodSymbol)) + { + return new ValidatedSourceSymbol(methodSymbol, location); + } } return default; @@ -233,41 +239,33 @@ private static SpecialSymbols GetSpecialSymbols(Compilation compilation) var cancellationTokenSymbol = compilation.GetTypeByMetadataName("System.Threading.CancellationToken"); var asyncEnumerableSymbol = compilation.GetTypeByMetadataName("System.Collections.Generic.IAsyncEnumerable`1"); var channelReaderSymbol = compilation.GetTypeByMetadataName("System.Threading.Channels.ChannelReader`1"); - var hubConnectionObserverSymbol = compilation.GetTypeByMetadataName("TypedSignalR.Client.IHubConnectionObserver"); - var memberSymbols = compilation.GetTypeByMetadataName("TypedSignalR.Client.HubConnectionExtensions")!.GetMembers(); + var hubConnectionObserverSymbol = compilation.GetTypesByMetadataName("TypedSignalR.Client.IHubConnectionObserver"); + var hubConnectionExtensions = compilation.GetTypesByMetadataName("TypedSignalR.Client.HubConnectionExtensions"); - IMethodSymbol? createHubProxyMethodSymbol = null; - IMethodSymbol? registerMethodSymbol = null; + ImmutableArray createHubProxyMethodSymbol = ImmutableArray.Empty; + ImmutableArray registerMethodSymbol = ImmutableArray.Empty; - foreach (var memberSymbol in memberSymbols) + foreach (var hubConnectionExtension in hubConnectionExtensions) { - if (memberSymbol is not IMethodSymbol methodSymbol) + foreach (var memberSymbol in hubConnectionExtension.GetMembers()) { - continue; - } + if (memberSymbol is not IMethodSymbol methodSymbol) + { + continue; + } - if (methodSymbol.Name is "CreateHubProxy") - { - if (methodSymbol.MethodKind is MethodKind.Ordinary) + if (methodSymbol.MethodKind is not MethodKind.Ordinary) { - createHubProxyMethodSymbol = methodSymbol; + continue; + } - if (registerMethodSymbol is not null) - { - break; - } + if (methodSymbol.Name is "CreateHubProxy") + { + createHubProxyMethodSymbol = createHubProxyMethodSymbol.Add(methodSymbol); } - } - else if (methodSymbol.Name is "Register") - { - if (methodSymbol.MethodKind is MethodKind.Ordinary) + else if (methodSymbol.Name is "Register") { - registerMethodSymbol = methodSymbol; - - if (createHubProxyMethodSymbol is not null) - { - break; - } + registerMethodSymbol = registerMethodSymbol.Add(methodSymbol); } } } @@ -323,9 +321,12 @@ private static IReadOnlyList ExtractReceiverTypesFromRegisterMetho ITypeSymbol receiverTypeSymbol = methodSymbol.TypeArguments[0]; - if (SymbolEqualityComparer.Default.Equals(receiverTypeSymbol, specialSymbols.HubConnectionObserverSymbol)) + foreach (var hubConnectionObserverSymbol in specialSymbols.HubConnectionObserverSymbols) { - continue; + if (SymbolEqualityComparer.Default.Equals(receiverTypeSymbol, hubConnectionObserverSymbol)) + { + continue; + } } var isValid = TypeValidator.ValidateReceiverTypeRule(context, receiverTypeSymbol, specialSymbols, location); diff --git a/src/TypedSignalR.Client/TypedSignalR.Client.csproj b/src/TypedSignalR.Client/TypedSignalR.Client.csproj index 00952d4..1faa581 100644 --- a/src/TypedSignalR.Client/TypedSignalR.Client.csproj +++ b/src/TypedSignalR.Client/TypedSignalR.Client.csproj @@ -31,7 +31,7 @@ - +