diff --git a/build/dependencies.props b/build/dependencies.props index 2f11bd156..f3bc3a678 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -8,6 +8,7 @@ 3.0.0 3.1.9 5.0.1 + 7.0.0-preview.7.22376.6 6.0.10 1.0.4 3.0.0 diff --git a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs index 5ec8e80d9..cc91d53ee 100644 --- a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs +++ b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs @@ -17,10 +17,9 @@ internal interface ICallerClientResultsManager : IClientResultsManager /// /// /// - /// The InstanceId of target client the caller server knows when this method is called. If the target client is managed by the caller server, the caller server knows the InstanceId of target client and this parameter is not null. Otherwise, this parameter is null. /// /// - Task AddInvocation(string connectionId, string invocationId, string instanceId, CancellationToken cancellationToken); + Task AddInvocation(string connectionId, string invocationId, CancellationToken cancellationToken); void AddServiceMapping(ServiceMappingMessage serviceMappingMessage); @@ -29,5 +28,7 @@ internal interface ICallerClientResultsManager : IClientResultsManager bool TryCompleteResult(string connectionId, ClientCompletionMessage message); bool TryCompleteResult(string connectionId, ErrorCompletionMessage message); + + void RemoveInvocation(string invocationId); } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/IRoutedClientResultsManager.cs b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/IRoutedClientResultsManager.cs index 1283c3242..e22f73c63 100644 --- a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/IRoutedClientResultsManager.cs +++ b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/IRoutedClientResultsManager.cs @@ -9,8 +9,6 @@ internal interface IRoutedClientResultsManager : IClientResultsManager { void AddInvocation(string connectionId, string invocationId, string callerServerId, CancellationToken cancellationToken); - bool ContainsInvocation(string invocationId); - void CleanupInvocationsByConnection(string connectionId); } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs index ac7647817..721b1f0d1 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs @@ -295,7 +295,7 @@ protected Task OnServiceErrorAsync(ServiceErrorMessage serviceErrorMessage) return Task.CompletedTask; } - protected Task OnPingMessageAsync(PingMessage pingMessage) + protected virtual Task OnPingMessageAsync(PingMessage pingMessage) { if (RuntimeServicePingMessage.TryGetOffline(pingMessage, out var instanceId)) { @@ -550,7 +550,7 @@ private async Task ProcessIncomingAsync(ConnectionContext connection) } } - private Task DispatchMessageAsync(ServiceMessage message) + protected virtual Task DispatchMessageAsync(ServiceMessage message) { return message switch { diff --git a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs index a1bd59978..4748392b1 100644 --- a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs +++ b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs @@ -31,7 +31,7 @@ public string GenerateInvocationId(string connectionId) return $"{connectionId}-{_clientResultManagerId}-{Interlocked.Increment(ref _lastInvocationId)}"; } - public Task AddInvocation(string connectionId, string invocationId, string instanceId, CancellationToken cancellationToken) + public Task AddInvocation(string connectionId, string invocationId, CancellationToken cancellationToken) { var tcs = new TaskCompletionSourceWithCancellation( cancellationToken, @@ -53,7 +53,7 @@ public Task AddInvocation(string connectionId, string invocationId, string { tcs.TrySetException(new Exception(completionMessage.Error)); } - }) { RouterInstanceId = instanceId } + }) ); Debug.Assert(result); @@ -66,18 +66,7 @@ public void AddServiceMapping(ServiceMappingMessage serviceMappingMessage) { if (_pendingInvocations.TryGetValue(serviceMappingMessage.InvocationId, out var invocation)) { - if (invocation.RouterInstanceId == null) - { - invocation.RouterInstanceId = serviceMappingMessage.InstanceId; - } - else - { - // do nothing - } - } - else - { - // do nothing + invocation.RouterInstanceId = serviceMappingMessage.InstanceId; } } @@ -174,6 +163,11 @@ public bool TryGetInvocationReturnType(string invocationId, out Type type) return false; } + public void RemoveInvocation(string invocationId) + { + _pendingInvocations.TryRemove(invocationId, out _); + } + // Unused, here to honor the IInvocationBinder interface but should never be called public IReadOnlyList GetParameterTypes(string methodName) => throw new NotImplementedException(); diff --git a/src/Microsoft.Azure.SignalR/ClientInvocation/RoutedClientResultsManager.cs b/src/Microsoft.Azure.SignalR/ClientInvocation/RoutedClientResultsManager.cs index 2d9990937..78c8b7b8c 100644 --- a/src/Microsoft.Azure.SignalR/ClientInvocation/RoutedClientResultsManager.cs +++ b/src/Microsoft.Azure.SignalR/ClientInvocation/RoutedClientResultsManager.cs @@ -40,11 +40,6 @@ public bool TryCompleteResult(string connectionId, CompletionMessage message) } } - public bool ContainsInvocation(string invocationId) - { - return _routedInvocations.TryGetValue(invocationId, out _); - } - public void CleanupInvocationsByConnection(string connectionId) { foreach (var (invocationId, invocation) in _routedInvocations) diff --git a/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs b/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs index ec7e29ffe..f35660863 100644 --- a/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs +++ b/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs @@ -88,6 +88,11 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil .AddSingleton() .AddSingleton() .AddSingleton() +#if NET7_0_OR_GREATER + .AddSingleton() +#else + .AddSingleton() +#endif .AddSingleton(typeof(NegotiateHandler<>)); // If a custom router is added, do not add the default router diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs index b1b7c989b..a4b7d7c5e 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs @@ -30,6 +30,7 @@ internal class ServiceHubDispatcher where THub : Hub private readonly IEndpointRouter _router; private readonly string _hubName; private readonly IServiceEventHandler _serviceEventHandler; + private readonly IClientInvocationManager _clientInvocationManager; protected readonly IServerNameProvider _nameProvider; @@ -45,6 +46,7 @@ public ServiceHubDispatcher( IServerNameProvider nameProvider, ServerLifetimeManager serverLifetimeManager, IClientConnectionFactory clientConnectionFactory, + IClientInvocationManager clientInvocationManager, IServiceEventHandler serviceEventHandler) { _serviceProtocol = serviceProtocol; @@ -62,6 +64,7 @@ public ServiceHubDispatcher( _nameProvider = nameProvider; _hubName = typeof(THub).Name; _serviceEventHandler = serviceEventHandler; + _clientInvocationManager = clientInvocationManager; serverLifetimeManager?.Register(ShutdownAsync); } @@ -150,7 +153,8 @@ internal virtual ServiceConnectionFactory GetServiceConnectionFactory( connectionDelegate, _clientConnectionFactory, _nameProvider, - _serviceEventHandler) + _serviceEventHandler, + _clientInvocationManager) { ConfigureContext = contextConfig, ShutdownMode = _options.GracefulShutdown.Mode diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs index 676326386..8ee58b737 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs @@ -3,10 +3,12 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Azure.SignalR.Common; using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Logging; @@ -18,8 +20,9 @@ internal class ServiceLifetimeManager : ServiceLifetimeManagerBase w { private const string MarkerNotConfiguredError = "'AddAzureSignalR(...)' was called without a matching call to 'IApplicationBuilder.UseAzureSignalR(...)'."; - + private readonly IClientInvocationManager _clientInvocationManager; private readonly IClientConnectionManager _clientConnectionManager; + private readonly string _callerId; public ServiceLifetimeManager( IServiceConnectionManager serviceConnectionManager, @@ -29,12 +32,15 @@ public ServiceLifetimeManager( AzureSignalRMarkerService marker, IOptions globalHubOptions, IOptions> hubOptions, - IBlazorDetector blazorDetector) + IBlazorDetector blazorDetector, + IServerNameProvider nameProvider, + IClientInvocationManager clientInvocationManager) : base( serviceConnectionManager, protocolResolver, globalHubOptions, - hubOptions, logger) + hubOptions, + logger) { // after core 3.0 UseAzureSignalR() is not required. #if NETSTANDARD2_0 @@ -43,12 +49,15 @@ public ServiceLifetimeManager( throw new InvalidOperationException(MarkerNotConfiguredError); } #endif - _clientConnectionManager = clientConnectionManager; - if (hubOptions.Value.SupportedProtocols != null && hubOptions.Value.SupportedProtocols.Any(x => x.Equals(Constants.Protocol.BlazorPack, StringComparison.OrdinalIgnoreCase))) { blazorDetector?.TrySetBlazor(typeof(THub).Name, true); } + + _callerId = nameProvider?.GetName() ?? throw new ArgumentNullException(nameof(nameProvider)); + + _clientInvocationManager = clientInvocationManager ?? throw new ArgumentNullException(nameof(clientInvocationManager)); + _clientConnectionManager = clientConnectionManager ?? throw new ArgumentNullException(nameof(clientConnectionManager)); } public override Task OnConnectedAsync(HubConnectionContext connection) @@ -103,6 +112,80 @@ public override async Task SendConnectionAsync(string connectionId, string metho } } +#if NET7_0_OR_GREATER + public override async Task InvokeConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default) + { + if (IsInvalidArgument(connectionId)) + { + throw new ArgumentNullException(nameof(connectionId)); + } + + if (IsInvalidArgument(methodName)) + { + throw new ArgumentNullException(nameof(methodName)); + } + + var invocationId = _clientInvocationManager.Caller.GenerateInvocationId(connectionId); + var message = AppendMessageTracingId(new ClientInvocationMessage(invocationId, connectionId, _callerId, SerializeAllProtocols(methodName, args, invocationId))); + await WriteAsync(message); + var task = _clientInvocationManager.Caller.AddInvocation(connectionId, invocationId, cancellationToken); + + // Exception handling follows https://source.dot.net/#Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs,349 + try + { + return await task; + } + catch + { + _clientInvocationManager.Caller.RemoveInvocation(invocationId); + throw; + } + } + + public override async Task SetConnectionResultAsync(string connectionId, CompletionMessage result) + { + if (IsInvalidArgument(connectionId)) + { + throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId)); + } + if (_clientConnectionManager.ClientConnections.TryGetValue(connectionId, out var clientConnectionContext)) + { + // Determine which manager (Caller / Router) the `result` belongs to. + // `TryCompletionResult` returns false when the corresponding invocation is not existing. + IClientResultsManager clientResultsManager = null; + var payload = new ReadOnlyMemory(); + if (_clientInvocationManager.Caller.TryCompleteResult(connectionId, result)) + { + clientResultsManager = _clientInvocationManager.Caller; + // For caller server, the only purpose of sending ClientCompletionMessage is to inform service to cleanup the invocation, which means only InvocationId and ConnectionId make sense. To avoid serialization for useless payload, we keep payload as empty bytes. + } + if (_clientInvocationManager.Router.TryCompleteResult(connectionId, result)) + { + clientResultsManager = _clientInvocationManager.Router; + // For router server, it should send a ClientCompletionMessage with accurate payload content, which is necessary for the caller server. + payload = SerializeCompletionMessage(result, clientConnectionContext.Protocol); + } + + // Block unknown `results` which belongs to neither Caller nor Router + if (clientResultsManager != null) + { + var protocol = clientConnectionContext.Protocol; + var message = AppendMessageTracingId(new ClientCompletionMessage(result.InvocationId, connectionId, _callerId, protocol, payload)); + await WriteAsync(message); + } + } + } + + public override bool TryGetReturnType(string invocationId, [NotNullWhen(true)] out Type type) + { + if (_clientInvocationManager.Router.TryGetInvocationReturnType(invocationId, out type)) + { + return true; + } + return _clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out type); + } +#endif + private MultiConnectionDataMessage CreateMessage(string connectionId, string methodName, object[] args, ClientConnectionContext serviceConnectionContext) { IDictionary> payloads; diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManagerBase.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManagerBase.cs index 99677181f..d34e6676f 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManagerBase.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManagerBase.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using System.Diagnostics.CodeAnalysis; using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Azure.SignalR.Protocol; @@ -283,10 +284,18 @@ protected static bool IsInvalidArgument(IReadOnlyList list) return list == null; } - protected IDictionary> SerializeAllProtocols(string method, object[] args) + protected IDictionary> SerializeAllProtocols(string method, object[] args, string invocationId = null) { var payloads = new Dictionary>(); - var message = new InvocationMessage(method, args); + InvocationMessage message; + if (invocationId == null) + { + message = new InvocationMessage(method, args); + } + else + { + message = new InvocationMessage(invocationId, method, args); + } var serializedHubMessages = _messageSerializer.SerializeMessage(message); foreach (var serializedMessage in serializedHubMessages) { @@ -298,6 +307,9 @@ protected IDictionary> SerializeAllProtocols(string protected ReadOnlyMemory SerializeProtocol(string protocol, string method, object[] args) => _messageSerializer.SerializeMessage(protocol, new InvocationMessage(method, args)); + protected ReadOnlyMemory SerializeCompletionMessage(CompletionMessage message, string protocol) => + _messageSerializer.SerializeMessage(protocol, message); + protected virtual T AppendMessageTracingId(T message) where T : ServiceMessage, IMessageWithTracingId { return message.WithTracingId(); diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs index 8b0860ef8..88a36a09c 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs @@ -41,6 +41,8 @@ internal partial class ServiceConnection : ServiceConnectionBase private readonly ConnectionDelegate _connectionDelegate; + private readonly IClientInvocationManager _clientInvocationManager; + public Action ConfigureContext { get; set; } public ServiceConnection(IServiceProtocol serviceProtocol, @@ -54,6 +56,7 @@ public ServiceConnection(IServiceProtocol serviceProtocol, HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, IServiceEventHandler serviceEventHandler, + IClientInvocationManager clientInvocationManager, ServiceConnectionType connectionType = ServiceConnectionType.Default, GracefulShutdownMode mode = GracefulShutdownMode.Off, int closeTimeOutMilliseconds = DefaultCloseTimeoutMilliseconds @@ -64,6 +67,7 @@ public ServiceConnection(IServiceProtocol serviceProtocol, _connectionDelegate = connectionDelegate; _clientConnectionFactory = clientConnectionFactory; _closeTimeOutMilliseconds = closeTimeOutMilliseconds; + _clientInvocationManager = clientInvocationManager; } protected override Task CreateConnection(string target = null) @@ -188,7 +192,33 @@ protected override async Task OnClientMessageAsync(ConnectionDataMessage connect } } - private async Task ProcessClientConnectionAsync(ClientConnectionContext connection) + protected override Task DispatchMessageAsync(ServiceMessage message) + { + return message switch + { + PingMessage pingMessage => OnPingMessageAsync(pingMessage), + ClientInvocationMessage clientInvocationMessage => OnClientInvocationAsync(clientInvocationMessage), + ServiceMappingMessage serviceMappingMessage => OnServiceMappingAsync(serviceMappingMessage), + ClientCompletionMessage clientCompletionMessage => OnClientCompletionAsync(clientCompletionMessage), + ErrorCompletionMessage errorCompletionMessage => OnErrorCompletionAsync(errorCompletionMessage), + _ => base.DispatchMessageAsync(message) + }; + } + + protected override Task OnPingMessageAsync(PingMessage pingMessage) + { +#if NET7_0_OR_GREATER + if (RuntimeServicePingMessage.TryGetOffline(pingMessage, out var instanceId)) + { + _clientInvocationManager.Caller.CleanupInvocationsByInstance(instanceId); + // Router invocations will be cleanup by its `CleanupInvocationsByConnection`, which is called by `RemoveClientConnection`. + // In `base.OnPingMessageAsync`, `CleanupClientConnections(instanceId)` will finally execute `RemoveClientConnection` for each ConnectionId. + } +#endif + return base.OnPingMessageAsync(pingMessage); + } + + private async Task ProcessClientConnectionAsync(ClientConnectionContext connection) { try { @@ -434,7 +464,34 @@ private ClientConnectionContext RemoveClientConnection(string connectionId) { _connectionIds.TryRemove(connectionId, out _); _clientConnectionManager.TryRemoveClientConnection(connectionId, out var connection); +#if NET7_0_OR_GREATER + _clientInvocationManager.Router.CleanupInvocationsByConnection(connectionId); +#endif return connection; } + + private Task OnClientInvocationAsync(ClientInvocationMessage message) + { + _clientInvocationManager.Router.AddInvocation(message.ConnectionId, message.InvocationId, message.CallerServerId, default); + return Task.CompletedTask; + } + + private Task OnServiceMappingAsync(ServiceMappingMessage message) + { + _clientInvocationManager.Caller.AddServiceMapping(message); + return Task.CompletedTask; + } + + private Task OnClientCompletionAsync(ClientCompletionMessage clientCompletionMessage) + { + _clientInvocationManager.Caller.TryCompleteResult(clientCompletionMessage.ConnectionId, clientCompletionMessage); + return Task.CompletedTask; + } + + private Task OnErrorCompletionAsync(ErrorCompletionMessage errorCompletionMessage) + { + _clientInvocationManager.Caller.TryCompleteResult(errorCompletionMessage.ConnectionId, errorCompletionMessage); + return Task.CompletedTask; + } } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionFactory.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionFactory.cs index d8b454d0b..2df05934f 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionFactory.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionFactory.cs @@ -16,6 +16,7 @@ internal class ServiceConnectionFactory : IServiceConnectionFactory private readonly IClientConnectionFactory _clientConnectionFactory; private readonly IServerNameProvider _nameProvider; private readonly IServiceEventHandler _serviceEventHandler; + private readonly IClientInvocationManager _clientInvocationManager; public GracefulShutdownMode ShutdownMode { get; set; } = GracefulShutdownMode.Off; @@ -29,7 +30,8 @@ public ServiceConnectionFactory( ConnectionDelegate connectionDelegate, IClientConnectionFactory clientConnectionFactory, IServerNameProvider nameProvider, - IServiceEventHandler serviceEventHandler) + IServiceEventHandler serviceEventHandler, + IClientInvocationManager clientInvocationManager) { _serviceProtocol = serviceProtocol; _clientConnectionManager = clientConnectionManager; @@ -39,6 +41,7 @@ public ServiceConnectionFactory( _clientConnectionFactory = clientConnectionFactory; _nameProvider = nameProvider; _serviceEventHandler = serviceEventHandler; + _clientInvocationManager = clientInvocationManager; } public virtual IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, ServiceConnectionType type) @@ -55,6 +58,7 @@ public virtual IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMe endpoint, serviceMessageHandler, _serviceEventHandler, + _clientInvocationManager, type, ShutdownMode ) diff --git a/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceConnectionFactory.cs b/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceConnectionFactory.cs index 6d3142361..550402cfc 100644 --- a/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceConnectionFactory.cs +++ b/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceConnectionFactory.cs @@ -19,6 +19,7 @@ public MockServiceConnectionFactory( ILoggerFactory loggerFactory, ConnectionDelegate connectionDelegate, IClientConnectionFactory clientConnectionFactory, + IClientInvocationManager clientInvocationManager, IServerNameProvider nameProvider) : base( serviceProtocol, @@ -28,7 +29,8 @@ public MockServiceConnectionFactory( connectionDelegate, clientConnectionFactory, nameProvider, - null) + null, + clientInvocationManager) { _mockService = mockService; } diff --git a/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceHubDispatcher.cs b/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceHubDispatcher.cs index b1cc1e4ab..7f53fde7b 100644 --- a/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceHubDispatcher.cs +++ b/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceHubDispatcher.cs @@ -19,12 +19,14 @@ internal class MockServiceHubDispatcher : ServiceHubDispatcher private IClientConnectionManager _clientConnectionManager; private IServiceProtocol _serviceProtocol; private IClientConnectionFactory _clientConnectionFactory; + private IClientInvocationManager _clientInvocationManager; public MockServiceHubDispatcher( IServiceProtocol serviceProtocol, IHubContext context, IServiceConnectionManager serviceConnectionManager, IClientConnectionManager clientConnectionManager, + IClientInvocationManager clientInvocationManager, IServiceEndpointManager serviceEndpointManager, IOptions options, ILoggerFactory loggerFactory, @@ -43,6 +45,7 @@ public MockServiceHubDispatcher( nameProvider, serverLifetimeManager, clientConnectionFactory, + clientInvocationManager, null) { MockService = new ConnectionTrackingMockService(); @@ -52,11 +55,12 @@ public MockServiceHubDispatcher( _clientConnectionManager = clientConnectionManager; _serviceProtocol = serviceProtocol; _clientConnectionFactory = clientConnectionFactory; + _clientInvocationManager = clientInvocationManager; } internal override ServiceConnectionFactory GetServiceConnectionFactory( ConnectionFactory connectionFactory, ConnectionDelegate connectionDelegate, Action contextConfig - ) => new MockServiceConnectionFactory(MockService, _serviceProtocol, _clientConnectionManager, connectionFactory, _loggerFactory, connectionDelegate, _clientConnectionFactory, _nameProvider); + ) => new MockServiceConnectionFactory(MockService, _serviceProtocol, _clientConnectionManager, connectionFactory, _loggerFactory, connectionDelegate, _clientConnectionFactory, _clientInvocationManager, _nameProvider); // this is the gateway for the tests to control the mock service side public IMockService MockService { diff --git a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnection.cs b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnection.cs index ab6c77923..b6d3376fa 100644 --- a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnection.cs +++ b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnection.cs @@ -31,7 +31,8 @@ internal class TestServiceConnection : ServiceConnectionBase public TestServiceConnection(ServiceConnectionStatus status = ServiceConnectionStatus.Connected, bool throws = false, ILogger logger = null, IServiceMessageHandler serviceMessageHandler = null, - IServiceEventHandler serviceEventHandler = null + IServiceEventHandler serviceEventHandler = null, + IClientInvocationManager clientInvocationManager = null ) : base( new ServiceProtocol(), "serverId", diff --git a/test/Microsoft.Azure.SignalR.Tests/ClientInvocation/ClientInvocationManagerTest.cs b/test/Microsoft.Azure.SignalR.Tests/ClientInvocation/ClientInvocationManagerTest.cs deleted file mode 100644 index 4a67d390c..000000000 --- a/test/Microsoft.Azure.SignalR.Tests/ClientInvocation/ClientInvocationManagerTest.cs +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -#if NET7_0_OR_GREATER -using System; -using System.Dynamic; -using System.Threading; -using Microsoft.AspNetCore.SignalR; -using Microsoft.AspNetCore.SignalR.Internal; -using Microsoft.AspNetCore.SignalR.Protocol; -using Microsoft.Azure.SignalR.Protocol; -using Microsoft.Extensions.Logging.Abstractions; -using Xunit; - -namespace Microsoft.Azure.SignalR.Common.Tests -{ - public class ClientInvocationManagerTest - { - private static readonly IHubProtocolResolver HubProtocolResolver = - new DefaultHubProtocolResolver(new IHubProtocol[] - { - new JsonHubProtocol(), - new MessagePackHubProtocol() - }, - NullLogger.Instance - ); - - [Fact] - /* - * Client 1 <--> --------- - * | Pod 1 | <--> Server A - * Client 2 <--> --------- - * - * Note: Client 1 and Client 2 are both managed by Server A - */ - public async void TestNormalCompleteWithoutRouterServer() - { - var connectionId = "Connection-0"; - var invocationResult = "invocation-success-result"; - var targetClientInstanceId = "Instance 1"; - ClientInvocationManager clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); - var invocationId = clientInvocationManager.Caller.GenerateInvocationId(connectionId); - - CancellationToken cancellationToken = new CancellationToken(); - // Server A knows the InstanceId of Client 2, so `instaceId` in `AddInvocation` is `targetClientInstanceId` ("Instance 1") - var task = clientInvocationManager.Caller.AddInvocation(connectionId, invocationId, targetClientInstanceId, cancellationToken); - - var ret = clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out Type T); - - Assert.True(ret); - Assert.Equal(typeof(string), T); - - var completionMessage = new CompletionMessage(invocationId, null, invocationResult, true); - ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, completionMessage); - Assert.True(ret); - - await task; - Assert.Equal(invocationResult, task.Result); - } - - [Theory] - [InlineData("json")] - [InlineData("messagepack")] - /* --------- <--> Client 2 - * Server 1 <--> Pod 1 <--> | Pod 2 | - * --------- <--> Server 2 - * - * Note: Server 2 manages Client 2. - */ - public async void TestNormalCompleteWithRouterServer(string protocol) - { - var instanceIds = new string[] { "Instance-0", "Instance-1" }; - var serverIds = new string[] { "Server-0", "Server-1" }; - var connectionIds = new string[] { "Connection-0", "Connection-1" }; - var invocationResult = "invocation-success-result"; - var ciManagers = new ClientInvocationManager[] - { - new ClientInvocationManager(HubProtocolResolver), - new ClientInvocationManager(HubProtocolResolver), - }; - var invocationId = ciManagers[0].Caller.GenerateInvocationId(connectionIds[0]); - var completionMessage = new CompletionMessage(invocationId, null, invocationResult, true); - - CancellationToken cancellationToken = new CancellationToken(); - // Server 1 doesn't know the InstanceId of Client 2, so `instaceId` is null for `AddInvocation` - var task = ciManagers[0].Caller.AddInvocation(connectionIds[0], invocationId, null, cancellationToken); - ciManagers[0].Caller.AddServiceMapping(new ServiceMappingMessage(invocationId, connectionIds[1], instanceIds[1])); - ciManagers[1].Router.AddInvocation(connectionIds[1], invocationId, serverIds[0], new CancellationToken()); - - var ret = ciManagers[1].Router.TryCompleteResult(connectionIds[1], completionMessage); - Assert.True(ret); - - var payload = GetBytes(protocol, completionMessage); - var clientCompletionMessage = new ClientCompletionMessage(invocationId, connectionIds[0], serverIds[1], protocol, payload); - - ret = ciManagers[0].Caller.TryCompleteResult(clientCompletionMessage.ConnectionId, clientCompletionMessage); - Assert.True(ret); - - await task; - - Assert.Equal(invocationResult, task.Result); - } - - internal static ReadOnlyMemory GetBytes(string proto, HubMessage message) - { - IHubProtocol hubProtocol = proto == "json" ? new JsonHubProtocol() : new MessagePackHubProtocol(); - return hubProtocol.GetMessageBytes(message); - } - - } -} -#endif \ No newline at end of file diff --git a/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs b/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs new file mode 100644 index 000000000..3be639312 --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +#if NET7_0_OR_GREATER +using System; +using System.Collections.Generic; +using System.Dynamic; +using System.Linq; +using System.Threading; +using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Azure.SignalR.Protocol; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace Microsoft.Azure.SignalR +{ + public class ClientInvocationManagerTests + { + private static readonly IHubProtocolResolver HubProtocolResolver = + new DefaultHubProtocolResolver(new IHubProtocol[] + { + new JsonHubProtocol(), + new MessagePackHubProtocol() + }, + NullLogger.Instance + ); + + private static readonly List TestConnectionIds = new() { "conn0", "conn1" }; + private static readonly List TestInstanceIds = new() { "instance0", "instance1" }; + private static readonly List TestServerIds = new() { "server1", "server2" }; + + [Theory] + [InlineData(true)] + [InlineData(false)] + /* + * Client 1 <--> --------- + * | Pod 1 | <--> Server A + * Client 2 <--> --------- + * + * Note: Client 1 and Client 2 are both managed by Server A + */ + public async void TestCompleteWithoutRouterServer(bool isCompletionWithResult) + { + var connectionId = TestConnectionIds[0]; + var targetClientInstanceId = TestInstanceIds[0]; + var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); + var invocationId = clientInvocationManager.Caller.GenerateInvocationId(connectionId); + var invocationResult = "invocation-correct-result"; + + CancellationToken cancellationToken = new CancellationToken(); + // Server A knows the InstanceId of Client 2, so `instaceId` in `AddInvocation` is `targetClientInstanceId` + var task = clientInvocationManager.Caller.AddInvocation(connectionId, invocationId, cancellationToken); + + var ret = clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out var t); + + Assert.True(ret); + Assert.Equal(typeof(string), t); + + var completionMessage = isCompletionWithResult + ? CompletionMessage.WithResult(invocationId, invocationResult) + : CompletionMessage.WithError(invocationId, invocationResult); + + ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, completionMessage); + Assert.True(ret); + + try + { + await task; + Assert.True(isCompletionWithResult); + Assert.Equal(invocationResult, task.Result); + } + catch (Exception e) + { + Assert.False(isCompletionWithResult); + Assert.Equal(invocationResult, e.Message); + } + } + + [Theory] + [InlineData("json", true)] + [InlineData("json", false)] + [InlineData("messagepack", true)] + [InlineData("messagepack", false)] + /* --------- <--> Client 2 + * Server 1 <--> Pod 1 <--> | Pod 2 | + * --------- <--> Server 2 + * + * Note: Server 2 manages Client 2. + */ + public async void TestCompleteWithRouterServer(string protocol, bool isCompletionWithResult) + { + var serverIds = new string[] { TestServerIds[0], TestServerIds[1] }; + var invocationResult = "invocation-correct-result"; + var ciManagers = new ClientInvocationManager[] + { + new ClientInvocationManager(HubProtocolResolver), + new ClientInvocationManager(HubProtocolResolver), + }; + var invocationId = ciManagers[0].Caller.GenerateInvocationId(TestConnectionIds[0]); + + CancellationToken cancellationToken = new CancellationToken(); + // Server 1 doesn't know the InstanceId of Client 2, so `instaceId` is null for `AddInvocation` + var task = ciManagers[0].Caller.AddInvocation(TestConnectionIds[0], invocationId, cancellationToken); + ciManagers[0].Caller.AddServiceMapping(new ServiceMappingMessage(invocationId, TestConnectionIds[1], TestInstanceIds[1])); + ciManagers[1].Router.AddInvocation(TestConnectionIds[1], invocationId, serverIds[0], new CancellationToken()); + + var completionMessage = isCompletionWithResult + ? CompletionMessage.WithResult(invocationId, invocationResult) + : CompletionMessage.WithError(invocationId, invocationResult); + + var ret = ciManagers[1].Router.TryCompleteResult(TestConnectionIds[1], completionMessage); + Assert.True(ret); + + var payload = GetBytes(protocol, completionMessage); + var clientCompletionMessage = new ClientCompletionMessage(invocationId, TestConnectionIds[0], serverIds[1], protocol, payload); + + ret = ciManagers[0].Caller.TryCompleteResult(clientCompletionMessage.ConnectionId, clientCompletionMessage); + Assert.True(ret); + + try + { + await task; + Assert.True(isCompletionWithResult); + Assert.Equal(invocationResult, task.Result); + } + catch (Exception e) + { + Assert.False(isCompletionWithResult); + Assert.Equal(invocationResult, e.Message); + } + } + + [Fact] + public void TestCallerManagerCancellation() + { + var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); + var invocationId = clientInvocationManager.Caller.GenerateInvocationId(TestConnectionIds[0]); + var cts = new CancellationTokenSource(); + var task = clientInvocationManager.Caller.AddInvocation(TestConnectionIds[0], invocationId, cts.Token); + + // Check if the invocation is existing + Assert.True(clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out _)); + // Cancel the invocation by CancellationToken + cts.Cancel(true); + // Check if the invocation task has the information + Assert.Equal("One or more errors occurred. (Canceled)", task.Exception.Message); + Assert.True(task.IsFaulted); + // Check if the invocation was removed + Assert.False(clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out _)); + } + + internal static ReadOnlyMemory GetBytes(string proto, HubMessage message) + { + IHubProtocol hubProtocol = proto == "json" ? new JsonHubProtocol() : new MessagePackHubProtocol(); + return hubProtocol.GetMessageBytes(message); + } + + } +} +#endif \ No newline at end of file diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs new file mode 100644 index 000000000..43685435d --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.Azure.SignalR +{ + internal class DefaultClientInvocationManager : IClientInvocationManager + { + public ICallerClientResultsManager Caller { get; } + public IRoutedClientResultsManager Router { get; } + + public DefaultClientInvocationManager() + { + var hubProtocolResolver = new DefaultHubProtocolResolver( + new IHubProtocol[] { + new JsonHubProtocol(), + new MessagePackHubProtocol() + }, + NullLogger.Instance); + + Caller = new CallerClientResultsManager(hubProtocolResolver); + Router = new RoutedClientResultsManager(); + } + } +} diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs index 34161183b..a39c5546e 100644 --- a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs +++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs @@ -11,6 +11,8 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Azure.SignalR.Protocol; using Microsoft.Azure.SignalR.Tests.Common; using Microsoft.Extensions.Logging.Abstractions; @@ -26,6 +28,8 @@ internal class ServiceConnectionProxy : IClientConnectionManager, IClientConnect public IClientConnectionManager ClientConnectionManager { get; } + public IClientInvocationManager ClientInvocationManager { get; } + public IServiceConnectionContainer ServiceConnectionContainer { get; } public IServiceMessageHandler ServiceMessageHandler { get; } @@ -55,6 +59,11 @@ public ServiceConnectionProxy( { ConnectionFactory = connectionFactoryCallback?.Invoke(ConnectionFactoryCallbackAsync) ?? new TestConnectionFactory(ConnectionFactoryCallbackAsync); ClientConnectionManager = new ClientConnectionManager(); + ClientInvocationManager = new ClientInvocationManager(new DefaultHubProtocolResolver(new IHubProtocol[] + { + new JsonHubProtocol(), + new MessagePackHubProtocol(), + }, NullLogger.Instance)); _clientPipeOptions = clientPipeOptions; ConnectionDelegateCallback = callback ?? OnConnectionAsync; @@ -81,6 +90,7 @@ public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHan endpoint, serviceMessageHandler, null, + ClientInvocationManager, type); ServiceConnections.TryAdd(connectionId, connection); return connection; diff --git a/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs b/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs index b0db0d438..cd2a1c60e 100644 --- a/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs @@ -1553,14 +1553,15 @@ public async Task ServiceConnectionContainerScopeWithPingUpdateTest() new ServiceEndpoint(ConnectionString2, EndpointType.Primary, "2") ); var endpoints = sem.GetEndpoints("hub"); + var clientInvocationManager = new DefaultClientInvocationManager(); var connection1 = new ServiceConnection(protocol, ccm, connectionFactory1, loggerFactory, connectionDelegate, ccf, - "serverId", "server-conn-1", endpoints[0], endpoints[0].ConnectionContainer as IServiceMessageHandler, null, closeTimeOutMilliseconds: 500); + "serverId", "server-conn-1", endpoints[0], endpoints[0].ConnectionContainer as IServiceMessageHandler, null, clientInvocationManager, closeTimeOutMilliseconds: 500); var connection2 = new ServiceConnection(protocol, ccm, connectionFactory2, loggerFactory, connectionDelegate, ccf, - "serverId", "server-conn-2", endpoints[1], endpoints[1].ConnectionContainer as IServiceMessageHandler, null, closeTimeOutMilliseconds: 500); + "serverId", "server-conn-2", endpoints[1], endpoints[1].ConnectionContainer as IServiceMessageHandler, null, clientInvocationManager, closeTimeOutMilliseconds: 500); var connection22 = new ServiceConnection(protocol, ccm, connectionFactory22, loggerFactory, connectionDelegate, ccf, - "serverId", "server-conn-22", endpoints[1], endpoints[1].ConnectionContainer as IServiceMessageHandler, null, closeTimeOutMilliseconds: 500); + "serverId", "server-conn-22", endpoints[1], endpoints[1].ConnectionContainer as IServiceMessageHandler, null, clientInvocationManager, closeTimeOutMilliseconds: 500); var router = new TestEndpointRouter(); diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionFacts.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionFacts.cs index cfe2efa32..aca12b753 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionFacts.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionFacts.cs @@ -513,7 +513,7 @@ public async Task ServiceConnectionInitializationDeadlockTest() try { SynchronizationContext.SetSynchronizationContext(null); - var conn = new TestServiceConnection(); + var conn = new TestServiceConnection(clientInvocationManager: new DefaultClientInvocationManager()); var initTask = conn.StartAsync(); await conn.ConnectionInitializedTask; conn.Stop(); diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs index a064b4c12..8f5aa9099 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs @@ -47,7 +47,7 @@ public async Task TestServiceConnectionWithNormalApplicationTask() builder.UseConnectionHandler(); ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null); + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager()); var connectionTask = connection.StartAsync(); @@ -103,7 +103,7 @@ public async Task TestServiceConnectionErrorCleansAllClients() builder.UseConnectionHandler(); ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null); + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager()); var connectionTask = connection.StartAsync(); @@ -159,7 +159,7 @@ public async Task TestServiceConnectionWithErrorApplicationTask() ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null); + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager()); var connectionTask = connection.StartAsync(); @@ -222,7 +222,7 @@ public async Task TestServiceConnectionWithEndlessApplicationTaskNeverEnds() ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, "serverId", Guid.NewGuid().ToString("N"), - null, null, null, closeTimeOutMilliseconds: 1); + null, null, null, new DefaultClientInvocationManager(), closeTimeOutMilliseconds: 1); var connectionTask = connection.StartAsync(); @@ -277,7 +277,7 @@ public async Task ClientConnectionOutgoingAbortCanEndLifeTime() builder.UseConnectionHandler(); ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null, + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), closeTimeOutMilliseconds: 500); var connectionTask = connection.StartAsync(); @@ -335,7 +335,7 @@ public async Task ClientConnectionContextAbortCanSendOutCloseMessage() ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null, closeTimeOutMilliseconds: 500); + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), closeTimeOutMilliseconds: 500); var connectionTask = connection.StartAsync(); @@ -398,7 +398,7 @@ public async Task ClientConnectionWithDiagnosticClientTagTest() ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null, closeTimeOutMilliseconds: 500); + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), closeTimeOutMilliseconds: 500); var connectionTask = connection.StartAsync(); @@ -456,7 +456,7 @@ public async Task ClientConnectionLastWillCanSendOut() builder.UseConnectionHandler(); ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null, closeTimeOutMilliseconds: 500); + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), closeTimeOutMilliseconds: 500); var connectionTask = connection.StartAsync(); diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceHubDispatcherTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceHubDispatcherTests.cs index 4bc52905a..070836061 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceHubDispatcherTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceHubDispatcherTests.cs @@ -40,6 +40,7 @@ public async void TestShutdown() null, null, null, + null, null ); diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs index fdb69f6b7..9d9001079 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs @@ -65,8 +65,9 @@ public async void ServiceLifetimeManagerTest(string functionName, Type type) { var serviceConnectionManager = new TestServiceConnectionManager(); var blazorDetector = new DefaultBlazorDetector(); + var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); var serviceLifetimeManager = new ServiceLifetimeManager(serviceConnectionManager, - new ClientConnectionManager(), HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions, blazorDetector); + new ClientConnectionManager(), HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions, blazorDetector, new DefaultServerNameProvider(), clientInvocationManager); await InvokeMethod(serviceLifetimeManager, functionName); @@ -85,6 +86,7 @@ public async void ServiceLifetimeManagerGroupTest(string functionName, Type type { var serviceConnectionManager = new TestServiceConnectionManager(); var blazorDetector = new DefaultBlazorDetector(); + var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); var serviceLifetimeManager = new ServiceLifetimeManager( serviceConnectionManager, new ClientConnectionManager(), @@ -93,7 +95,10 @@ public async void ServiceLifetimeManagerGroupTest(string functionName, Type type Marker, _globalHubOptions, _localHubOptions, - blazorDetector); + blazorDetector, + new DefaultServerNameProvider(), + clientInvocationManager + ); await InvokeMethod(serviceLifetimeManager, functionName); @@ -121,9 +126,10 @@ public async void ServiceLifetimeManagerIntegrationTest(string methodName, Type var serviceConnectionManager = new ServiceConnectionManager(); serviceConnectionManager.SetServiceConnection(proxy.ServiceConnectionContainer); + var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); var serviceLifetimeManager = new ServiceLifetimeManager(serviceConnectionManager, - proxy.ClientConnectionManager, HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions, blazorDetector); + proxy.ClientConnectionManager, HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions, blazorDetector, new DefaultServerNameProvider(), clientInvocationManager); var serverTask = proxy.WaitForServerConnectionAsync(1); _ = proxy.StartAsync(); @@ -167,8 +173,9 @@ public async void ServiceLifetimeManagerIgnoreBlazorHubProtocolTest(string funct IOptions globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { "json", "messagepack", MockProtocol, "json" } }); IOptions> localHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { "json", "messagepack", MockProtocol } }); var serviceConnectionManager = new TestServiceConnectionManager(); + var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); var serviceLifetimeManager = new ServiceLifetimeManager(serviceConnectionManager, - new ClientConnectionManager(), protocolResolver, Logger, Marker, globalHubOptions, localHubOptions, blazorDetector); + new ClientConnectionManager(), protocolResolver, Logger, Marker, globalHubOptions, localHubOptions, blazorDetector, new DefaultServerNameProvider(), clientInvocationManager); await InvokeMethod(serviceLifetimeManager, functionName); @@ -265,6 +272,9 @@ private HubLifetimeManager MockLifetimeManager(IServiceConnectionManage ); IOptions globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { MockProtocol } }); IOptions> localHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { MockProtocol } }); + + var clientInvocationManager = new ClientInvocationManager(protocolResolver); + return new ServiceLifetimeManager( serviceConnectionManager, clientConnectionManager, @@ -273,10 +283,152 @@ private HubLifetimeManager MockLifetimeManager(IServiceConnectionManage Marker, globalHubOptions, localHubOptions, - blazorDetector + blazorDetector, + new DefaultServerNameProvider(), + clientInvocationManager ); } +#if NET7_0_OR_GREATER + private static ServiceLifetimeManager GetTestClientInvocationServiceLifetimeManager( + ServiceConnectionBase serviceConnection, + IServiceConnectionManager serviceConnectionManager, + ClientConnectionManager clientConnectionManager, + ClientInvocationManager clientInvocationManager = null, + ClientConnectionContext clientConnectionContext = null, + string protocol = "json" + ) + { + // Add a client to ClientConnectionManager + if (clientConnectionContext != null) + { + clientConnectionContext.ServiceConnection = serviceConnection; + clientConnectionManager.TryAddClientConnection(clientConnectionContext); + } + + // Create ServiceLifetimeManager + return new ServiceLifetimeManager(serviceConnectionManager, + clientConnectionManager, HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions, null, new DefaultServerNameProvider(), clientInvocationManager ?? new ClientInvocationManager(HubProtocolResolver)); + } + + private static ClientConnectionContext GetClientConnectionContextWithConnection(string connectionId = null, string protocol = null) + { + var connectMessage = new OpenConnectionMessage(connectionId, new Claim[] { }); + connectMessage.Protocol = protocol; + return new ClientConnectionContext(connectMessage); + } + + + [Theory] + [InlineData("json", true)] + [InlineData("json", false)] + [InlineData("messagepack", true)] + [InlineData("messagepack", false)] + public async void TestClientInvocationOneService(string protocol, bool isCompletionWithResult) + { + var serviceConnection = new TestServiceConnection(); + var serviceConnectionManager = new TestServiceConnectionManager(); + + var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); + var clientConnectionContext = GetClientConnectionContextWithConnection(TestConnectionIds[1], protocol); + + var serviceLifetimeManager = GetTestClientInvocationServiceLifetimeManager(serviceConnection, serviceConnectionManager, new ClientConnectionManager(), clientInvocationManager, clientConnectionContext, protocol); + + var invocationResult = "invocation-correct-result"; + + // Invoke the client + var task = serviceLifetimeManager.InvokeConnectionAsync(TestConnectionIds[1], "InvokedMethod", Array.Empty(), default); + + // Check if the caller server sent a ClientInvocationMessage + Assert.IsType(serviceConnectionManager.ServiceMessage); + var invocation = (ClientInvocationMessage)serviceConnectionManager.ServiceMessage; + + // Check if the caller server added the invocation + Assert.True(clientInvocationManager.Caller.TryGetInvocationReturnType(invocation.InvocationId, out _)); + + // Complete the invocation by SerivceLifetimeManager + var completionMessage = isCompletionWithResult + ? CompletionMessage.WithResult(invocation.InvocationId, invocationResult) + : CompletionMessage.WithError(invocation.InvocationId, invocationResult); + + await serviceLifetimeManager.SetConnectionResultAsync(invocation.ConnectionId, completionMessage); + // Check if the caller server sent a ClientCompletionMessage + Assert.IsType(serviceConnectionManager.ServiceMessage); + + // Check if the invocation result is correct + try + { + await task; + Assert.True(isCompletionWithResult); + Assert.Equal(invocationResult, task.Result); + } + catch (Exception e) + { + Assert.False(isCompletionWithResult); + Assert.Equal(invocationResult, e.Message); + } + } + + [Theory] + [InlineData("json", true)] + [InlineData("json", false)] + [InlineData("messagepack", true)] + [InlineData("messagepack", false)] + public async void TestMultiClientInvocationsMultipleService(string protocol, bool isCompletionWithResult) + { + var clientConnectionContext = GetClientConnectionContextWithConnection(TestConnectionIds[1], protocol); + var clientConnectionManager = new ClientConnectionManager(); + + var serviceConnectionManager = new TestServiceConnectionManager(); + var clientInvocationManagers = new List() { + new ClientInvocationManager(HubProtocolResolver), + new ClientInvocationManager(HubProtocolResolver) + }; + + var serviceLifetimeManagers = new List>() { + GetTestClientInvocationServiceLifetimeManager( new TestServiceConnection(), serviceConnectionManager, clientConnectionManager, clientInvocationManagers[0], null, protocol), + GetTestClientInvocationServiceLifetimeManager( new TestServiceConnection(), serviceConnectionManager, clientConnectionManager, clientInvocationManagers[1], clientConnectionContext, protocol) + }; + + var invocationResult = "invocation-correct-result"; + + // Invoke a client + var task = serviceLifetimeManagers[0].InvokeConnectionAsync(TestConnectionIds[1], "InvokedMethod", Array.Empty()); + var invocation = (ClientInvocationMessage)serviceConnectionManager.ServiceMessage; + // Check if the invocation was added to caller server + Assert.True(clientInvocationManagers[0].Caller.TryGetInvocationReturnType(invocation.InvocationId, out _)); + + // Route server adds invocation + clientInvocationManagers[1].Router.AddInvocation(TestConnectionIds[1], invocation.InvocationId, "server-0", default); + // check if the invocation was adder to route server + Assert.True(clientInvocationManagers[1].Router.TryGetInvocationReturnType(invocation.InvocationId, out _)); + + // The route server receives CompletionMessage + var completionMessage = isCompletionWithResult + ? CompletionMessage.WithResult(invocation.InvocationId, invocationResult) + : CompletionMessage.WithError(invocation.InvocationId, invocationResult); + await serviceLifetimeManagers[1].SetConnectionResultAsync(invocation.ConnectionId, completionMessage); + + // Check if the router server sent ClientCompletionMessage + Assert.IsType(serviceConnectionManager.ServiceMessage); + var clientCompletionMessage = (ClientCompletionMessage)serviceConnectionManager.ServiceMessage; + + clientInvocationManagers[0].Caller.TryCompleteResult(clientCompletionMessage.ConnectionId, clientCompletionMessage); + + try + { + await task; + Assert.True(isCompletionWithResult); + Assert.Equal(invocationResult, task.Result); + } + catch (Exception e) + { + Assert.False(isCompletionWithResult); + Assert.Equal(invocationResult, e.Message); + } + } +#endif + private static async Task InvokeMethod(HubLifetimeManager serviceLifetimeManager, string methodName) { switch (methodName) diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs index 0d2aa5cf8..fd6b38ccb 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs @@ -39,7 +39,8 @@ public ServiceMessageTests(ITestOutputHelper output) : base(output) public async Task TestOpenConnectionMessageWithMigrateIn() { var clientConnectionFactory = new TestClientConnectionFactory(); - var connection = CreateServiceConnection(clientConnectionFactory: clientConnectionFactory); + var clientInvocationManager = new DefaultClientInvocationManager(); + var connection = CreateServiceConnection(clientConnectionFactory: clientConnectionFactory, clientInvocationManager: clientInvocationManager); _ = connection.StartAsync(); await connection.ConnectionInitializedTask.OrTimeout(); @@ -73,8 +74,9 @@ public async Task TestOpenConnectionMessageWithMigrateIn() public async Task TestCloseConnectionMessageWithMigrateOut() { var clientConnectionFactory = new TestClientConnectionFactory(); + var clientInvocationManager = new DefaultClientInvocationManager(); - var connection = CreateServiceConnection(clientConnectionFactory: clientConnectionFactory, handler: new TestConnectionHandler(3000, "foobar")); + var connection = CreateServiceConnection(clientConnectionFactory: clientConnectionFactory, handler: new TestConnectionHandler(3000, "foobar"), clientInvocationManager: clientInvocationManager); _ = connection.StartAsync(); await connection.ConnectionInitializedTask.OrTimeout(1000); @@ -113,8 +115,9 @@ public async Task TestCloseConnectionMessageWithMigrateOut() public async Task TestCloseConnectionMessage() { var clientConnectionFactory = new TestClientConnectionFactory(); + var clientInvocationManager = new DefaultClientInvocationManager(); - var connection = CreateServiceConnection(clientConnectionFactory: clientConnectionFactory, handler: new TestConnectionHandler(3000, "foobar")); + var connection = CreateServiceConnection(clientConnectionFactory: clientConnectionFactory, handler: new TestConnectionHandler(3000, "foobar"), clientInvocationManager: clientInvocationManager); _ = connection.StartAsync(); await connection.ConnectionInitializedTask.OrTimeout(1000); @@ -167,8 +170,9 @@ public async Task TestAccessKeyRequestMessage(Type keyType) var endpoint = MockServiceEndpoint(keyType.Name); Assert.IsAssignableFrom(keyType, endpoint.AccessKey); var hubServiceEndpoint = new HubServiceEndpoint("foo", null, endpoint); + var clientInvocationManager = new DefaultClientInvocationManager(); - var connection = CreateServiceConnection(hubServiceEndpoint: hubServiceEndpoint); + var connection = CreateServiceConnection(hubServiceEndpoint: hubServiceEndpoint, clientInvocationManager: clientInvocationManager); _ = connection.StartAsync(); await connection.ConnectionInitializedTask.OrTimeout(1000); @@ -191,8 +195,9 @@ public async Task TestAccessKeyResponseMessage(Type keyType) var endpoint = MockServiceEndpoint(keyType.Name); Assert.IsAssignableFrom(keyType, endpoint.AccessKey); var hubServiceEndpoint = new HubServiceEndpoint("foo", null, endpoint); + var clientInvocationManager = new DefaultClientInvocationManager(); - var connection = CreateServiceConnection(hubServiceEndpoint: hubServiceEndpoint); + var connection = CreateServiceConnection(hubServiceEndpoint: hubServiceEndpoint, clientInvocationManager: clientInvocationManager); _ = connection.StartAsync(); await connection.ConnectionInitializedTask.OrTimeout(1000); @@ -239,7 +244,9 @@ public async Task TestAccessKeyResponseMessageWithError(int minutesElapsed, int field.SetValue(key, DateTime.UtcNow - TimeSpan.FromMinutes(minutesElapsed)); } - var connection = CreateServiceConnection(loggerFactory: loggerFactory, hubServiceEndpoint: endpoint); + var clientInvocationManager = new DefaultClientInvocationManager(); + + var connection = CreateServiceConnection(loggerFactory: loggerFactory, hubServiceEndpoint: endpoint, clientInvocationManager: clientInvocationManager); var connectionTask = connection.StartAsync(); await connection.ConnectionInitializedTask.OrTimeout(1000); @@ -275,6 +282,7 @@ private static TestServiceConnection CreateServiceConnection(ConnectionHandler h IServiceMessageHandler messageHandler = null, IServiceEventHandler eventHandler = null, IClientConnectionFactory clientConnectionFactory = null, + IClientInvocationManager clientInvocationManager = null, HubServiceEndpoint hubServiceEndpoint = null, ILoggerFactory loggerFactory = null) { @@ -309,6 +317,7 @@ private static TestServiceConnection CreateServiceConnection(ConnectionHandler h hubServiceEndpoint ?? new TestHubServiceEndpoint(), messageHandler ?? new TestServiceMessageHandler(), eventHandler ?? new TestServiceEventHandler(), + clientInvocationManager, mode: mode ?? GracefulShutdownMode.Off ); } @@ -435,6 +444,7 @@ public TestServiceConnection(TestConnectionContainer container, HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, IServiceEventHandler serviceEventHandler, + IClientInvocationManager clientInvocationManager, ServiceConnectionType connectionType = ServiceConnectionType.Default, GracefulShutdownMode mode = GracefulShutdownMode.Off, int closeTimeOutMilliseconds = 10000) : base( @@ -449,6 +459,7 @@ public TestServiceConnection(TestConnectionContainer container, endpoint, serviceMessageHandler, serviceEventHandler, + clientInvocationManager, connectionType: connectionType, mode: mode, closeTimeOutMilliseconds: closeTimeOutMilliseconds) diff --git a/test/Microsoft.Azure.SignalR.Tests/TestServiceConnectionForCloseAsync.cs b/test/Microsoft.Azure.SignalR.Tests/TestServiceConnectionForCloseAsync.cs index e423f630f..234d5e609 100644 --- a/test/Microsoft.Azure.SignalR.Tests/TestServiceConnectionForCloseAsync.cs +++ b/test/Microsoft.Azure.SignalR.Tests/TestServiceConnectionForCloseAsync.cs @@ -6,7 +6,7 @@ namespace Microsoft.Azure.SignalR.Tests { internal class TestServiceConnectionForCloseAsync : TestServiceConnection { - public TestServiceConnectionForCloseAsync() : base(ServiceConnectionStatus.Connected, false) + public TestServiceConnectionForCloseAsync() : base(ServiceConnectionStatus.Connected, false, clientInvocationManager: new DefaultClientInvocationManager()) { }