diff --git a/build/dependencies.props b/build/dependencies.props
index 2f11bd156..f3bc3a678 100644
--- a/build/dependencies.props
+++ b/build/dependencies.props
@@ -8,6 +8,7 @@
3.0.0
3.1.9
5.0.1
+ 7.0.0-preview.7.22376.6
6.0.10
1.0.4
3.0.0
diff --git a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs
index 5ec8e80d9..cc91d53ee 100644
--- a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs
+++ b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs
@@ -17,10 +17,9 @@ internal interface ICallerClientResultsManager : IClientResultsManager
///
///
///
- /// The InstanceId of target client the caller server knows when this method is called. If the target client is managed by the caller server, the caller server knows the InstanceId of target client and this parameter is not null. Otherwise, this parameter is null.
///
///
- Task AddInvocation(string connectionId, string invocationId, string instanceId, CancellationToken cancellationToken);
+ Task AddInvocation(string connectionId, string invocationId, CancellationToken cancellationToken);
void AddServiceMapping(ServiceMappingMessage serviceMappingMessage);
@@ -29,5 +28,7 @@ internal interface ICallerClientResultsManager : IClientResultsManager
bool TryCompleteResult(string connectionId, ClientCompletionMessage message);
bool TryCompleteResult(string connectionId, ErrorCompletionMessage message);
+
+ void RemoveInvocation(string invocationId);
}
}
\ No newline at end of file
diff --git a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/IRoutedClientResultsManager.cs b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/IRoutedClientResultsManager.cs
index 1283c3242..e22f73c63 100644
--- a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/IRoutedClientResultsManager.cs
+++ b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/IRoutedClientResultsManager.cs
@@ -9,8 +9,6 @@ internal interface IRoutedClientResultsManager : IClientResultsManager
{
void AddInvocation(string connectionId, string invocationId, string callerServerId, CancellationToken cancellationToken);
- bool ContainsInvocation(string invocationId);
-
void CleanupInvocationsByConnection(string connectionId);
}
}
\ No newline at end of file
diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs
index ac7647817..721b1f0d1 100644
--- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs
+++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs
@@ -295,7 +295,7 @@ protected Task OnServiceErrorAsync(ServiceErrorMessage serviceErrorMessage)
return Task.CompletedTask;
}
- protected Task OnPingMessageAsync(PingMessage pingMessage)
+ protected virtual Task OnPingMessageAsync(PingMessage pingMessage)
{
if (RuntimeServicePingMessage.TryGetOffline(pingMessage, out var instanceId))
{
@@ -550,7 +550,7 @@ private async Task ProcessIncomingAsync(ConnectionContext connection)
}
}
- private Task DispatchMessageAsync(ServiceMessage message)
+ protected virtual Task DispatchMessageAsync(ServiceMessage message)
{
return message switch
{
diff --git a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs
index a1bd59978..4748392b1 100644
--- a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs
+++ b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs
@@ -31,7 +31,7 @@ public string GenerateInvocationId(string connectionId)
return $"{connectionId}-{_clientResultManagerId}-{Interlocked.Increment(ref _lastInvocationId)}";
}
- public Task AddInvocation(string connectionId, string invocationId, string instanceId, CancellationToken cancellationToken)
+ public Task AddInvocation(string connectionId, string invocationId, CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSourceWithCancellation(
cancellationToken,
@@ -53,7 +53,7 @@ public Task AddInvocation(string connectionId, string invocationId, string
{
tcs.TrySetException(new Exception(completionMessage.Error));
}
- }) { RouterInstanceId = instanceId }
+ })
);
Debug.Assert(result);
@@ -66,18 +66,7 @@ public void AddServiceMapping(ServiceMappingMessage serviceMappingMessage)
{
if (_pendingInvocations.TryGetValue(serviceMappingMessage.InvocationId, out var invocation))
{
- if (invocation.RouterInstanceId == null)
- {
- invocation.RouterInstanceId = serviceMappingMessage.InstanceId;
- }
- else
- {
- // do nothing
- }
- }
- else
- {
- // do nothing
+ invocation.RouterInstanceId = serviceMappingMessage.InstanceId;
}
}
@@ -174,6 +163,11 @@ public bool TryGetInvocationReturnType(string invocationId, out Type type)
return false;
}
+ public void RemoveInvocation(string invocationId)
+ {
+ _pendingInvocations.TryRemove(invocationId, out _);
+ }
+
// Unused, here to honor the IInvocationBinder interface but should never be called
public IReadOnlyList GetParameterTypes(string methodName) => throw new NotImplementedException();
diff --git a/src/Microsoft.Azure.SignalR/ClientInvocation/RoutedClientResultsManager.cs b/src/Microsoft.Azure.SignalR/ClientInvocation/RoutedClientResultsManager.cs
index 2d9990937..78c8b7b8c 100644
--- a/src/Microsoft.Azure.SignalR/ClientInvocation/RoutedClientResultsManager.cs
+++ b/src/Microsoft.Azure.SignalR/ClientInvocation/RoutedClientResultsManager.cs
@@ -40,11 +40,6 @@ public bool TryCompleteResult(string connectionId, CompletionMessage message)
}
}
- public bool ContainsInvocation(string invocationId)
- {
- return _routedInvocations.TryGetValue(invocationId, out _);
- }
-
public void CleanupInvocationsByConnection(string connectionId)
{
foreach (var (invocationId, invocation) in _routedInvocations)
diff --git a/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs b/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs
index ec7e29ffe..f35660863 100644
--- a/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs
+++ b/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs
@@ -88,6 +88,11 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil
.AddSingleton()
.AddSingleton()
.AddSingleton()
+#if NET7_0_OR_GREATER
+ .AddSingleton()
+#else
+ .AddSingleton()
+#endif
.AddSingleton(typeof(NegotiateHandler<>));
// If a custom router is added, do not add the default router
diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs
index b1b7c989b..a4b7d7c5e 100644
--- a/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs
+++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs
@@ -30,6 +30,7 @@ internal class ServiceHubDispatcher where THub : Hub
private readonly IEndpointRouter _router;
private readonly string _hubName;
private readonly IServiceEventHandler _serviceEventHandler;
+ private readonly IClientInvocationManager _clientInvocationManager;
protected readonly IServerNameProvider _nameProvider;
@@ -45,6 +46,7 @@ public ServiceHubDispatcher(
IServerNameProvider nameProvider,
ServerLifetimeManager serverLifetimeManager,
IClientConnectionFactory clientConnectionFactory,
+ IClientInvocationManager clientInvocationManager,
IServiceEventHandler serviceEventHandler)
{
_serviceProtocol = serviceProtocol;
@@ -62,6 +64,7 @@ public ServiceHubDispatcher(
_nameProvider = nameProvider;
_hubName = typeof(THub).Name;
_serviceEventHandler = serviceEventHandler;
+ _clientInvocationManager = clientInvocationManager;
serverLifetimeManager?.Register(ShutdownAsync);
}
@@ -150,7 +153,8 @@ internal virtual ServiceConnectionFactory GetServiceConnectionFactory(
connectionDelegate,
_clientConnectionFactory,
_nameProvider,
- _serviceEventHandler)
+ _serviceEventHandler,
+ _clientInvocationManager)
{
ConfigureContext = contextConfig,
ShutdownMode = _options.GracefulShutdown.Mode
diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs
index 676326386..8ee58b737 100644
--- a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs
+++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs
@@ -3,10 +3,12 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
+using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Common;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
@@ -18,8 +20,9 @@ internal class ServiceLifetimeManager : ServiceLifetimeManagerBase w
{
private const string MarkerNotConfiguredError =
"'AddAzureSignalR(...)' was called without a matching call to 'IApplicationBuilder.UseAzureSignalR(...)'.";
-
+ private readonly IClientInvocationManager _clientInvocationManager;
private readonly IClientConnectionManager _clientConnectionManager;
+ private readonly string _callerId;
public ServiceLifetimeManager(
IServiceConnectionManager serviceConnectionManager,
@@ -29,12 +32,15 @@ public ServiceLifetimeManager(
AzureSignalRMarkerService marker,
IOptions globalHubOptions,
IOptions> hubOptions,
- IBlazorDetector blazorDetector)
+ IBlazorDetector blazorDetector,
+ IServerNameProvider nameProvider,
+ IClientInvocationManager clientInvocationManager)
: base(
serviceConnectionManager,
protocolResolver,
globalHubOptions,
- hubOptions, logger)
+ hubOptions,
+ logger)
{
// after core 3.0 UseAzureSignalR() is not required.
#if NETSTANDARD2_0
@@ -43,12 +49,15 @@ public ServiceLifetimeManager(
throw new InvalidOperationException(MarkerNotConfiguredError);
}
#endif
- _clientConnectionManager = clientConnectionManager;
-
if (hubOptions.Value.SupportedProtocols != null && hubOptions.Value.SupportedProtocols.Any(x => x.Equals(Constants.Protocol.BlazorPack, StringComparison.OrdinalIgnoreCase)))
{
blazorDetector?.TrySetBlazor(typeof(THub).Name, true);
}
+
+ _callerId = nameProvider?.GetName() ?? throw new ArgumentNullException(nameof(nameProvider));
+
+ _clientInvocationManager = clientInvocationManager ?? throw new ArgumentNullException(nameof(clientInvocationManager));
+ _clientConnectionManager = clientConnectionManager ?? throw new ArgumentNullException(nameof(clientConnectionManager));
}
public override Task OnConnectedAsync(HubConnectionContext connection)
@@ -103,6 +112,80 @@ public override async Task SendConnectionAsync(string connectionId, string metho
}
}
+#if NET7_0_OR_GREATER
+ public override async Task InvokeConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default)
+ {
+ if (IsInvalidArgument(connectionId))
+ {
+ throw new ArgumentNullException(nameof(connectionId));
+ }
+
+ if (IsInvalidArgument(methodName))
+ {
+ throw new ArgumentNullException(nameof(methodName));
+ }
+
+ var invocationId = _clientInvocationManager.Caller.GenerateInvocationId(connectionId);
+ var message = AppendMessageTracingId(new ClientInvocationMessage(invocationId, connectionId, _callerId, SerializeAllProtocols(methodName, args, invocationId)));
+ await WriteAsync(message);
+ var task = _clientInvocationManager.Caller.AddInvocation(connectionId, invocationId, cancellationToken);
+
+ // Exception handling follows https://source.dot.net/#Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs,349
+ try
+ {
+ return await task;
+ }
+ catch
+ {
+ _clientInvocationManager.Caller.RemoveInvocation(invocationId);
+ throw;
+ }
+ }
+
+ public override async Task SetConnectionResultAsync(string connectionId, CompletionMessage result)
+ {
+ if (IsInvalidArgument(connectionId))
+ {
+ throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
+ }
+ if (_clientConnectionManager.ClientConnections.TryGetValue(connectionId, out var clientConnectionContext))
+ {
+ // Determine which manager (Caller / Router) the `result` belongs to.
+ // `TryCompletionResult` returns false when the corresponding invocation is not existing.
+ IClientResultsManager clientResultsManager = null;
+ var payload = new ReadOnlyMemory();
+ if (_clientInvocationManager.Caller.TryCompleteResult(connectionId, result))
+ {
+ clientResultsManager = _clientInvocationManager.Caller;
+ // For caller server, the only purpose of sending ClientCompletionMessage is to inform service to cleanup the invocation, which means only InvocationId and ConnectionId make sense. To avoid serialization for useless payload, we keep payload as empty bytes.
+ }
+ if (_clientInvocationManager.Router.TryCompleteResult(connectionId, result))
+ {
+ clientResultsManager = _clientInvocationManager.Router;
+ // For router server, it should send a ClientCompletionMessage with accurate payload content, which is necessary for the caller server.
+ payload = SerializeCompletionMessage(result, clientConnectionContext.Protocol);
+ }
+
+ // Block unknown `results` which belongs to neither Caller nor Router
+ if (clientResultsManager != null)
+ {
+ var protocol = clientConnectionContext.Protocol;
+ var message = AppendMessageTracingId(new ClientCompletionMessage(result.InvocationId, connectionId, _callerId, protocol, payload));
+ await WriteAsync(message);
+ }
+ }
+ }
+
+ public override bool TryGetReturnType(string invocationId, [NotNullWhen(true)] out Type type)
+ {
+ if (_clientInvocationManager.Router.TryGetInvocationReturnType(invocationId, out type))
+ {
+ return true;
+ }
+ return _clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out type);
+ }
+#endif
+
private MultiConnectionDataMessage CreateMessage(string connectionId, string methodName, object[] args, ClientConnectionContext serviceConnectionContext)
{
IDictionary> payloads;
diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManagerBase.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManagerBase.cs
index 99677181f..d34e6676f 100644
--- a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManagerBase.cs
+++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManagerBase.cs
@@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
+using System.Diagnostics.CodeAnalysis;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Protocol;
@@ -283,10 +284,18 @@ protected static bool IsInvalidArgument(IReadOnlyList