diff --git a/docs/management-sdk-guide.md b/docs/management-sdk-guide.md
index 27f214fda..78f82c644 100644
--- a/docs/management-sdk-guide.md
+++ b/docs/management-sdk-guide.md
@@ -183,7 +183,7 @@ This SDK can communicates to Azure SignalR Service with two transport types:
| | Transient | Persistent |
| ------------------------------ | ------------------ | --------------------------------- |
| Default JSON library | `Newtonsoft.Json` | The same as Asp.Net Core SignalR:
`Newtonsoft.Json` for .NET Standard 2.0;
`System.Text.Json` for .NET Core App 3.1 and above |
-| MessaegPack clients support | since v1.21.0 | since v1.20.0 |
+| MessagePack clients support | since v1.21.0 | since v1.20.0 |
#### Json serialization
See [Customizing Json Serialization in Management SDK](./advanced-topics/json-object-serializer.md)
diff --git a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs
index 09da21aa0..e297c429f 100644
--- a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs
+++ b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs
@@ -19,19 +19,22 @@ namespace Microsoft.Azure.SignalR.AspNet
{
internal partial class ServiceConnection : ServiceConnectionBase
{
+ private const string ReconnectMessage = "asrs:reconnect";
+
private static readonly Dictionary CustomHeader = new Dictionary
{{Constants.AsrsUserAgent, ProductInfo.GetProductInfo()}};
- private const string ReconnectMessage = "asrs:reconnect";
-
private static readonly TimeSpan CloseApplicationTimeout = TimeSpan.FromSeconds(5);
private readonly ConcurrentDictionary _clientConnections =
new ConcurrentDictionary(StringComparer.Ordinal);
private readonly IConnectionFactory _connectionFactory;
+
private readonly IClientConnectionManager _clientConnectionManager;
+ private readonly AckHandler _ackHandler;
+
public ServiceConnection(
string serverId,
string connectionId,
@@ -42,6 +45,7 @@ public ServiceConnection(
ILoggerFactory loggerFactory,
IServiceMessageHandler serviceMessageHandler,
IServiceEventHandler serviceEventHandler,
+ AckHandler ackHandler,
ServiceConnectionType connectionType = ServiceConnectionType.Default)
: base(
serviceProtocol,
@@ -55,6 +59,7 @@ public ServiceConnection(
{
_connectionFactory = connectionFactory;
_clientConnectionManager = clientConnectionManager;
+ _ackHandler = ackHandler;
}
protected override Task CreateConnection(string target = null)
@@ -147,6 +152,17 @@ protected virtual async Task CleanupConnectionsAsyncCore(string instanceId = nul
}
}
+ private static string GetString(ReadOnlySequence buffer)
+ {
+ if (buffer.IsSingleSegment)
+ {
+ MemoryMarshal.TryGetArray(buffer.First, out var segment);
+ return Encoding.UTF8.GetString(segment.Array, segment.Offset, segment.Count);
+ }
+
+ return Encoding.UTF8.GetString(buffer.ToArray());
+ }
+
private async Task ForwardMessageToApplication(string connectionId, ServiceMessage message)
{
if (_clientConnections.TryGetValue(connectionId, out var clientContext))
@@ -230,6 +246,7 @@ private async Task OnConnectedAsyncCore(ClientConnectionContext clientContext, O
catch (Exception e)
{
Log.ConnectedStartingFailed(Logger, connectionId, e);
+
// Should not wait for application task inside the application task
_ = PerformDisconnectCore(connectionId, false);
_ = SafeWriteAsync(new CloseConnectionMessage(connectionId, e.Message));
@@ -276,14 +293,18 @@ private async Task ProcessMessageAsync(ClientConnectionContext clientContext, Ca
case OpenConnectionMessage openConnectionMessage:
await OnConnectedAsyncCore(clientContext, openConnectionMessage);
break;
+
case CloseConnectionMessage closeConnectionMessage:
+
// should not wait for application task when inside the application task
// As the messages are in a queue, close message should be after all the other messages
await PerformDisconnectCore(closeConnectionMessage.ConnectionId, false);
return;
+
case ConnectionDataMessage connectionDataMessage:
ProcessOutgoingMessages(clientContext, connectionDataMessage);
break;
+
default:
break;
}
@@ -304,17 +325,6 @@ private async Task ProcessMessageAsync(ClientConnectionContext clientContext, Ca
}
}
- private static string GetString(ReadOnlySequence buffer)
- {
- if (buffer.IsSingleSegment)
- {
- MemoryMarshal.TryGetArray(buffer.First, out var segment);
- return Encoding.UTF8.GetString(segment.Array, segment.Offset, segment.Count);
- }
-
- return Encoding.UTF8.GetString(buffer.ToArray());
- }
-
private string GetInstanceId(IDictionary header)
{
if (header.TryGetValue(Constants.AsrsInstanceId, out var instanceId))
@@ -324,4 +334,4 @@ private string GetInstanceId(IDictionary header)
return null;
}
}
-}
\ No newline at end of file
+}
diff --git a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionFactory.cs b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionFactory.cs
index b943be39d..5d6ad4d60 100644
--- a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionFactory.cs
+++ b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionFactory.cs
@@ -29,7 +29,7 @@ public ServiceConnectionFactory(
_serviceEventHandler = serviceEventHandler;
}
- public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, ServiceConnectionType type)
+ public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, AckHandler ackHandler, ServiceConnectionType type)
{
return new ServiceConnection(
_nameProvider.GetName(),
@@ -41,6 +41,7 @@ public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHan
_logger,
serviceMessageHandler,
_serviceEventHandler,
+ ackHandler,
type);
}
}
diff --git a/src/Microsoft.Azure.SignalR.Common/Constants.cs b/src/Microsoft.Azure.SignalR.Common/Constants.cs
index a691ff4be..790bb8633 100644
--- a/src/Microsoft.Azure.SignalR.Common/Constants.cs
+++ b/src/Microsoft.Azure.SignalR.Common/Constants.cs
@@ -113,5 +113,11 @@ public static class ErrorCodes
public const string InfoUserNotInGroup = "Info.User.NotInGroup";
public const string ErrorConnectionNotExisted = "Error.Connection.NotExisted";
}
+
+ public static class HttpClientNames
+ {
+ public const string Resilient = "Resilient";
+ public const string MessageResilient = "MessageResilient";
+ }
}
}
\ No newline at end of file
diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs
index 874961eed..411c39a73 100644
--- a/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs
+++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs
@@ -1,4 +1,7 @@
-using System;
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Http;
@@ -154,7 +157,14 @@ private async Task AuthorizeWithRetryAsync(CancellationToken ctoken = default)
catch (Exception e)
{
latest = e;
- await Task.Delay(AuthorizeRetryInterval);
+ try
+ {
+ await Task.Delay(AuthorizeRetryInterval, ctoken);
+ }
+ catch (OperationCanceledException)
+ {
+ break;
+ }
}
}
@@ -169,7 +179,6 @@ private async Task AuthorizeWithTokenAsync(string accessToken, CancellationToken
await new RestClient().SendAsync(
api,
HttpMethod.Get,
- "",
handleExpectedResponseAsync: HandleHttpResponseAsync,
cancellationToken: ctoken);
}
diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs
index 227d3850e..8bedc821e 100644
--- a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs
+++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs
@@ -88,11 +88,11 @@ private async Task UpdateAccessKeyAsync(AadAccessKey key)
try
{
await key.UpdateAccessKeyAsync();
- Log.SucceedToAuthorizeAccessKey(logger, key.Endpoint.AbsoluteUri);
+ Log.SucceedToAuthorizeAccessKey(logger, key.AuthorizeUrl);
}
catch (Exception e)
{
- Log.FailedToAuthorizeAccessKey(logger, key.Endpoint.AbsoluteUri, e);
+ Log.FailedToAuthorizeAccessKey(logger, key.AuthorizeUrl, e);
}
}
diff --git a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionFactory.cs b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionFactory.cs
index f3ac55c14..e182663d7 100644
--- a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionFactory.cs
+++ b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionFactory.cs
@@ -2,6 +2,6 @@
{
internal interface IServiceConnectionFactory
{
- IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, ServiceConnectionType type);
+ IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, AckHandler ackHandler, ServiceConnectionType type);
}
}
diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs
index 64ef90b2b..c9a257275 100644
--- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs
+++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs
@@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
-using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
@@ -17,18 +16,18 @@ namespace Microsoft.Azure.SignalR
internal abstract class ServiceConnectionContainerBase : IServiceConnectionContainer, IServiceMessageHandler
{
private const int CheckWindow = 5;
- private static readonly TimeSpan CheckTimeSpan = TimeSpan.FromMinutes(10);
// Give interval(5s) * 24 = 2min window for retry considering abnormal case.
private const int MaxRetryRemoveSeverConnection = 24;
+ private static readonly TimeSpan CheckTimeSpan = TimeSpan.FromMinutes(10);
+
private static readonly int MaxReconnectBackOffInternalInMilliseconds = 1000;
+
// Give (interval * 3 + 1) delay when check value expire.
private static readonly long DefaultServersPingTimeoutTicks = Stopwatch.Frequency * ((long)Constants.Periods.DefaultServersPingInterval.TotalSeconds * 3 + 1);
- private static readonly Tuple DefaultServersTagContext = new Tuple(string.Empty, 0);
- private static TimeSpan ReconnectInterval =>
- TimeSpan.FromMilliseconds(StaticRandom.Next(MaxReconnectBackOffInternalInMilliseconds));
+ private static readonly Tuple DefaultServersTagContext = new Tuple(string.Empty, 0);
private readonly BackOffPolicy _backOffPolicy = new BackOffPolicy();
@@ -36,8 +35,6 @@ internal abstract class ServiceConnectionContainerBase : IServiceConnectionConta
private readonly object _statusLock = new object();
- private (int count, DateTime? last) _inactiveInfo;
-
private readonly AckHandler _ackHandler;
private readonly CustomizedPingTimer _statusPing;
@@ -46,34 +43,21 @@ internal abstract class ServiceConnectionContainerBase : IServiceConnectionConta
private readonly ServiceDiagnosticLogsContext _serviceDiagnosticLogsContext = new ServiceDiagnosticLogsContext { EnableMessageLog = false };
+ private (int count, DateTime? last) _inactiveInfo;
+
private volatile List _serviceConnections;
private volatile ServiceConnectionStatus _status;
-
//
private volatile Tuple _serversTagContext = DefaultServersTagContext;
- private volatile bool _hasClients;
- private volatile bool _terminated = false;
- protected ILogger Logger { get; }
-
- protected List ServiceConnections
- {
- get { return _serviceConnections; }
- set { _serviceConnections = value; }
- }
-
- protected IServiceConnectionFactory ServiceConnectionFactory { get; }
-
- protected int FixedConnectionCount { get; }
+ private volatile bool _hasClients;
- protected virtual ServiceConnectionType InitialConnectionType { get; } = ServiceConnectionType.Default;
+ private volatile bool _terminated = false;
public HubServiceEndpoint Endpoint { get; }
- public event Action ConnectionStatusChanged;
-
public string ServersTag => _serversTagContext.Item1;
public bool HasClients => _hasClients;
@@ -99,6 +83,26 @@ private set
}
}
+ public Task ConnectionInitializedTask => Task.WhenAny(from connection in ServiceConnections
+ select connection.ConnectionInitializedTask);
+
+ protected ILogger Logger { get; }
+
+ protected List ServiceConnections
+ {
+ get => _serviceConnections;
+ set => _serviceConnections = value;
+ }
+
+ protected IServiceConnectionFactory ServiceConnectionFactory { get; }
+
+ protected int FixedConnectionCount { get; }
+
+ protected virtual ServiceConnectionType InitialConnectionType { get; } = ServiceConnectionType.Default;
+
+ private static TimeSpan ReconnectInterval =>
+ TimeSpan.FromMilliseconds(StaticRandom.Next(MaxReconnectBackOffInternalInMilliseconds));
+
protected ServiceConnectionContainerBase(IServiceConnectionFactory serviceConnectionFactory,
int minConnectionCount,
HubServiceEndpoint endpoint,
@@ -112,7 +116,7 @@ protected ServiceConnectionContainerBase(IServiceConnectionFactory serviceConnec
_ackHandler = ackHandler ?? new AckHandler();
// make sure it is after _endpoint is set
- // init initial connections
+ // init initial connections
List initial;
if (initialConnections == null)
{
@@ -140,7 +144,7 @@ protected ServiceConnectionContainerBase(IServiceConnectionFactory serviceConnec
ConnectionStatusChanged += OnStatusChanged;
_statusPing = new CustomizedPingTimer(Logger, Constants.CustomizedPingTimer.ServiceStatus, WriteServiceStatusPingAsync, Constants.Periods.DefaultStatusPingInterval, Constants.Periods.DefaultStatusPingInterval);
-
+
// when server connection count is specified to 0, the app server only handle negotiate requests
if (initial.Count > 0)
{
@@ -150,6 +154,8 @@ protected ServiceConnectionContainerBase(IServiceConnectionFactory serviceConnec
_serversPing = new CustomizedPingTimer(Logger, Constants.CustomizedPingTimer.Servers, WriteServersPingAsync, Constants.Periods.DefaultServersPingInterval, Constants.Periods.DefaultServersPingInterval);
}
+ public event Action ConnectionStatusChanged;
+
public async Task StartAsync()
{
using (new ServiceConnectionContainerScope(_serviceDiagnosticLogsContext))
@@ -165,27 +171,6 @@ public virtual Task StopAsync()
return Task.WhenAll(ServiceConnections.Select(c => c.StopAsync()));
}
- ///
- /// Start and manage the whole connection lifetime
- ///
- ///
- protected async Task StartCoreAsync(IServiceConnection connection, string target = null)
- {
- if (_terminated)
- {
- return;
- }
-
- try
- {
- await connection.StartAsync(target);
- }
- finally
- {
- await OnConnectionComplete(connection);
- }
- }
-
public virtual Task HandlePingAsync(PingMessage pingMessage)
{
if (RuntimeServicePingMessage.TryGetClientCount(pingMessage, out var clientCount))
@@ -226,9 +211,6 @@ public void HandleAck(AckMessage ackMessage)
_ackHandler.TriggerAck(ackMessage.AckId, (AckStatus)ackMessage.Status);
}
- public Task ConnectionInitializedTask => Task.WhenAny(from connection in ServiceConnections
- select connection.ConnectionInitializedTask);
-
public virtual Task WriteAsync(ServiceMessage serviceMessage)
{
return WriteToScopedOrRandomAvailableConnection(serviceMessage);
@@ -241,28 +223,17 @@ public async Task WriteAckableMessageAsync(ServiceMessage serviceMessage,
throw new ArgumentException($"{nameof(serviceMessage)} is not {nameof(IAckableMessage)}");
}
- var task = _ackHandler.CreateAck(out var id, cancellationToken);
+ var task = _ackHandler.CreateSingleAck(out var id, null, cancellationToken);
ackableMessage.AckId = id;
- // Sending regular messages completes as soon as the data leaves the outbound pipe,
- // whereas ackable ones complete upon full roundtrip of the message and the ack (or timeout).
+ // Sending regular messages completes as soon as the data leaves the outbound pipe,
+ // whereas ackable ones complete upon full roundtrip of the message and the ack (or timeout).
// Therefore sending them over different connections creates a possibility for processing them out of original order.
// By sending both message types over the same connection we ensure that they are sent (and processed) in their original order.
await WriteToScopedOrRandomAvailableConnection(serviceMessage);
var status = await task;
- switch (status)
- {
- case AckStatus.Ok:
- return true;
- case AckStatus.NotFound:
- return false;
- case AckStatus.Timeout:
- case AckStatus.InternalServerError:
- throw new TimeoutException($"Ack-able message {serviceMessage.GetType()}(ackId: {ackableMessage.AckId}) timed out.");
- default:
- throw new AzureSignalRException($"Ack-able message {serviceMessage.GetType()}(ackId: {ackableMessage.AckId}) gets error ack status {status}.");
- }
+ return AckHandler.HandleAckStatus(ackableMessage, status);
}
public virtual Task OfflineAsync(GracefulShutdownMode mode)
@@ -296,12 +267,65 @@ public void Dispose()
GC.SuppressFinalize(this);
}
+ internal static TimeSpan GetRetryDelay(int retryCount)
+ {
+ // retry count: 0, 1, 2, 3, 4, 5, 6, ...
+ // delay seconds: 1, 2, 4, 8, 16, 32, 60, ...
+ if (retryCount > 5)
+ {
+ return TimeSpan.FromMinutes(1) + ReconnectInterval;
+ }
+ return TimeSpan.FromSeconds(1 << retryCount) + ReconnectInterval;
+ }
+
+ internal bool GetServiceStatus(bool active, int checkWindow, TimeSpan checkTimeSpan)
+ {
+ if (active)
+ {
+ _inactiveInfo = (0, null);
+ return true;
+ }
+ else
+ {
+ var info = _inactiveInfo;
+ var last = info.last ?? DateTime.UtcNow;
+ var count = info.count;
+ count++;
+ _inactiveInfo = (count, last);
+
+ // Inactive it only when it checks over 5 times and elapsed for over 10 minutes
+ var inactive = count >= checkWindow && DateTime.UtcNow - last >= checkTimeSpan;
+ return !inactive;
+ }
+ }
+
+ ///
+ /// Start and manage the whole connection lifetime
+ ///
+ ///
+ protected async Task StartCoreAsync(IServiceConnection connection, string target = null)
+ {
+ if (_terminated)
+ {
+ return;
+ }
+
+ try
+ {
+ await connection.StartAsync(target);
+ }
+ finally
+ {
+ await OnConnectionComplete(connection);
+ }
+ }
+
///
/// Create a connection for a specific service connection type
///
protected IServiceConnection CreateServiceConnectionCore(ServiceConnectionType type)
{
- var connection = ServiceConnectionFactory.Create(Endpoint, this, type);
+ var connection = ServiceConnectionFactory.Create(Endpoint, this, _ackHandler, type);
connection.ConnectionStatusChanged += OnConnectionStatusChanged;
return connection;
@@ -324,6 +348,7 @@ protected virtual async Task OnConnectionComplete(IServiceConnection serviceConn
{
await RestartFixedServiceConnectionCoreAsync(index);
}
+
// the rest are "on demand" and are only created upon request
else
{
@@ -332,47 +357,6 @@ protected virtual async Task OnConnectionComplete(IServiceConnection serviceConn
}
}
- private async Task RestartFixedServiceConnectionCoreAsync(int index)
- {
- if (_terminated)
- {
- return;
- }
-
- Func> tryNewConnection = async () =>
- {
- var connection = CreateServiceConnectionCore(InitialConnectionType);
- ReplaceFixedConnection(index, connection);
-
- _ = StartCoreAsync(connection);
- await connection.ConnectionInitializedTask;
-
- return connection.Status == ServiceConnectionStatus.Connected;
- };
- await _backOffPolicy.CallProbeWithBackOffAsync(tryNewConnection, GetRetryDelay);
- }
-
- private void ReplaceFixedConnection(int index, IServiceConnection serviceConnection)
- {
- lock (_lock)
- {
- var newImmutableConnections = ServiceConnections.ToList();
- newImmutableConnections[index] = serviceConnection;
- ServiceConnections = newImmutableConnections;
- }
- }
-
- private void RemoveOnDemandConnection(IServiceConnection serviceConnection)
- {
- lock (_lock)
- {
- var newImmutableConnections = ServiceConnections.ToList();
- Debug.Assert(newImmutableConnections.IndexOf(serviceConnection) >= FixedConnectionCount);
- newImmutableConnections.Remove(serviceConnection);
- ServiceConnections = newImmutableConnections;
- }
- }
-
protected void AddOnDemandConnection(IServiceConnection serviceConnection)
{
lock (_lock)
@@ -426,36 +410,45 @@ protected async Task RemoveConnectionAsync(IServiceConnection c, GracefulShutdow
Log.TimeoutWaitingForFinAck(Logger, retry);
}
- internal bool GetServiceStatus(bool active, int checkWindow, TimeSpan checkTimeSpan)
+ private async Task RestartFixedServiceConnectionCoreAsync(int index)
{
- if (active)
+ if (_terminated)
{
- _inactiveInfo = (0, null);
- return true;
+ return;
}
- else
+
+ Func> tryNewConnection = async () =>
{
- var info = _inactiveInfo;
- var last = info.last ?? DateTime.UtcNow;
- var count = info.count;
- count++;
- _inactiveInfo = (count, last);
+ var connection = CreateServiceConnectionCore(InitialConnectionType);
+ ReplaceFixedConnection(index, connection);
- // Inactive it only when it checks over 5 times and elapsed for over 10 minutes
- var inactive = count >= checkWindow && DateTime.UtcNow - last >= checkTimeSpan;
- return !inactive;
+ _ = StartCoreAsync(connection);
+ await connection.ConnectionInitializedTask;
+
+ return connection.Status == ServiceConnectionStatus.Connected;
+ };
+ await _backOffPolicy.CallProbeWithBackOffAsync(tryNewConnection, GetRetryDelay);
+ }
+
+ private void ReplaceFixedConnection(int index, IServiceConnection serviceConnection)
+ {
+ lock (_lock)
+ {
+ var newImmutableConnections = ServiceConnections.ToList();
+ newImmutableConnections[index] = serviceConnection;
+ ServiceConnections = newImmutableConnections;
}
}
- internal static TimeSpan GetRetryDelay(int retryCount)
+ private void RemoveOnDemandConnection(IServiceConnection serviceConnection)
{
- // retry count: 0, 1, 2, 3, 4, 5, 6, ...
- // delay seconds: 1, 2, 4, 8, 16, 32, 60, ...
- if (retryCount > 5)
+ lock (_lock)
{
- return TimeSpan.FromMinutes(1) + ReconnectInterval;
+ var newImmutableConnections = ServiceConnections.ToList();
+ Debug.Assert(newImmutableConnections.IndexOf(serviceConnection) >= FixedConnectionCount);
+ newImmutableConnections.Remove(serviceConnection);
+ ServiceConnections = newImmutableConnections;
}
- return TimeSpan.FromSeconds(1 << retryCount) + ReconnectInterval;
}
private void OnStatusChanged(StatusChange obj)
@@ -597,12 +590,17 @@ private async Task SafeWriteAsync(ServiceMessage serviceMessage)
protected internal sealed class CustomizedPingTimer : IDisposable
{
private readonly object _lock = new object();
+
private readonly long _defaultPingTicks;
private readonly string _pingName;
+
private readonly Func _writePing;
+
private readonly TimeSpan _dueTime;
+
private readonly TimeSpan _intervalTime;
+
private readonly ILogger _logger;
// Considering parallel add endpoints to save time,
@@ -610,6 +608,7 @@ protected internal sealed class CustomizedPingTimer : IDisposable
private long _counter = 0;
private long _lastSendTimestamp = 0;
+
private TimerAwaitable _timer;
public CustomizedPingTimer(ILogger logger, string pingName, Func writePing, TimeSpan dueTime, TimeSpan intervalTime)
diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs
index e3fbf8c2b..2f0d75840 100644
--- a/src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs
+++ b/src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs
@@ -2,70 +2,140 @@
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.Azure.SignalR.Common;
+using Microsoft.Azure.SignalR.Protocol;
+
+#nullable enable
namespace Microsoft.Azure.SignalR
{
internal sealed class AckHandler : IDisposable
{
- private readonly ConcurrentDictionary _acks = new ConcurrentDictionary();
+ private readonly ConcurrentDictionary _acks = new();
private readonly Timer _timer;
- private readonly TimeSpan _ackInterval;
- private readonly TimeSpan _ackTtl;
- private int _currentId = 0;
+ private readonly TimeSpan _defaultAckTimeout;
+ private volatile bool _disposed;
+
+ private int _nextId;
+ private int NextId() => Interlocked.Increment(ref _nextId);
+
+ public AckHandler(int ackIntervalInMilliseconds = 3000, int ackTtlInMilliseconds = 10000) : this(TimeSpan.FromMilliseconds(ackIntervalInMilliseconds), TimeSpan.FromMilliseconds(ackTtlInMilliseconds)) { }
- public AckHandler(int ackIntervalInMilliseconds = 3000, int ackTtlInMilliseconds = 10000)
+ internal AckHandler(TimeSpan ackInterval, TimeSpan defaultAckTimeout)
{
- _ackInterval = TimeSpan.FromMilliseconds(ackIntervalInMilliseconds);
- _ackTtl = TimeSpan.FromMilliseconds(ackTtlInMilliseconds);
+ _defaultAckTimeout = defaultAckTimeout;
+ _timer = new Timer(_ => CheckAcks(), null, ackInterval, ackInterval);
+ }
- bool restoreFlow = false;
- try
+ public Task CreateSingleAck(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default)
+ {
+ id = NextId();
+ if (_disposed)
{
- if (!ExecutionContext.IsFlowSuppressed())
- {
- ExecutionContext.SuppressFlow();
- restoreFlow = true;
- }
+ return Task.FromResult(AckStatus.Ok);
+ }
+ var info = (IAckInfo)_acks.GetOrAdd(id, _ => new SingleAckInfo(ackTimeout ?? _defaultAckTimeout));
+ if (info is MultiAckInfo)
+ {
+ throw new InvalidOperationException();
+ }
+ cancellationToken.Register(() => info.Cancel());
+ return info.Task;
+ }
- _timer = new Timer(state => ((AckHandler)state).CheckAcks(), state: this, dueTime: _ackInterval, period: _ackInterval);
+ public static bool HandleAckStatus(IAckableMessage message, AckStatus status)
+ {
+ return status switch
+ {
+ AckStatus.Ok => true,
+ AckStatus.NotFound => false,
+ AckStatus.Timeout or AckStatus.InternalServerError => throw new TimeoutException($"Ack-able message {message.GetType()}(ackId: {message.AckId}) timed out."),
+ _ => throw new AzureSignalRException($"Ack-able message {message.GetType()}(ackId: {message.AckId}) gets error ack status {status}."),
+ };
+ }
+
+ public Task CreateMultiAck(out int id, TimeSpan? ackTimeout = default)
+ {
+ id = NextId();
+ if (_disposed)
+ {
+ return Task.FromResult(AckStatus.Ok);
}
- finally
+ var info = (IAckInfo)_acks.GetOrAdd(id, _ => new MultiAckInfo(ackTimeout ?? _defaultAckTimeout));
+ if (info is SingleAckInfo)
{
- // Restore the current ExecutionContext
- if (restoreFlow)
- {
- ExecutionContext.RestoreFlow();
- }
+ throw new InvalidOperationException();
}
+ return info.Task;
}
- public Task CreateAck(out int id, CancellationToken cancellationToken = default)
+ public void TriggerAck(int id, AckStatus status = AckStatus.Ok)
{
- id = Interlocked.Increment(ref _currentId);
- var tcs = _acks.GetOrAdd(id, _ => new AckInfo(_ackTtl)).Tcs;
- cancellationToken.Register(() => tcs.TrySetCanceled());
- return tcs.Task;
+ if (_acks.TryGetValue(id, out var info))
+ {
+ switch (info)
+ {
+ case IAckInfo ackInfo:
+ if (ackInfo.Ack(status))
+ {
+ _acks.TryRemove(id, out _);
+ }
+ break;
+ default:
+ throw new InvalidCastException($"Expected: IAckInfo<{typeof(IAckInfo).Name}>, actual type: {info.GetType().Name}");
+ }
+ }
}
- public void TriggerAck(int id, AckStatus ackStatus)
+ public void SetExpectedCount(int id, int expectedCount)
{
- if (_acks.TryRemove(id, out var ack))
+ if (_disposed)
+ {
+ return;
+ }
+
+ if (_acks.TryGetValue(id, out var info))
{
- ack.Tcs.TrySetResult(ackStatus);
+ if (info is not IMultiAckInfo multiAckInfo)
+ {
+ throw new InvalidOperationException();
+ }
+ if (multiAckInfo.SetExpectedCount(expectedCount))
+ {
+ _acks.TryRemove(id, out _);
+ }
}
}
private void CheckAcks()
{
+ if (_disposed)
+ {
+ return;
+ }
+
var utcNow = DateTime.UtcNow;
- foreach (var pair in _acks)
+ foreach (var item in _acks)
{
- if (utcNow > pair.Value.Expired)
+ var id = item.Key;
+ var ack = item.Value;
+ if (utcNow > ack.TimeoutAt)
{
- if (_acks.TryRemove(pair.Key, out var ack))
+ if (_acks.TryRemove(id, out _))
{
- ack.Tcs.TrySetResult(AckStatus.Timeout);
+ if (ack is SingleAckInfo singleAckInfo)
+ {
+ singleAckInfo.Ack(AckStatus.Timeout);
+ }
+ else if (ack is MultiAckInfo multipleAckInfo)
+ {
+ multipleAckInfo.ForceAck(AckStatus.Timeout);
+ }
+ else
+ {
+ ack.Cancel();
+ }
}
}
}
@@ -73,28 +143,135 @@ private void CheckAcks()
public void Dispose()
{
- _timer?.Dispose();
+ _disposed = true;
+
+ _timer.Dispose();
- foreach (var pair in _acks)
+ while (!_acks.IsEmpty)
{
- if (_acks.TryRemove(pair.Key, out var ack))
+ foreach (var item in _acks)
{
- ack.Tcs.TrySetCanceled();
+ var id = item.Key;
+ var ack = item.Value;
+ if (_acks.TryRemove(id, out _))
+ {
+ ack.Cancel();
+ if (ack is IDisposable disposable)
+ {
+ disposable.Dispose();
+ }
+ }
}
}
}
- private class AckInfo
+ private interface IAckInfo
+ {
+ DateTime TimeoutAt { get; }
+ void Cancel();
+ }
+
+ private interface IAckInfo : IAckInfo
+ {
+ Task Task { get; }
+ bool Ack(T status);
+ }
+
+ public interface IMultiAckInfo
+ {
+ bool SetExpectedCount(int expectedCount);
+ }
+
+ private sealed class SingleAckInfo : IAckInfo
+ {
+ public readonly TaskCompletionSource _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
+
+ public DateTime TimeoutAt { get; }
+
+ public SingleAckInfo(TimeSpan timeout)
+ {
+ TimeoutAt = DateTime.UtcNow + timeout;
+ }
+
+ public bool Ack(AckStatus status = AckStatus.Ok) =>
+ _tcs.TrySetResult(status);
+
+ public Task Task => _tcs.Task;
+
+ public void Cancel() => _tcs.TrySetCanceled();
+ }
+
+ private sealed class MultiAckInfo : IAckInfo, IMultiAckInfo
{
- public TaskCompletionSource Tcs { get; private set; }
+ public readonly TaskCompletionSource _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
- public DateTime Expired { get; private set; }
+ private int _ackCount;
+ private int? _expectedCount;
- public AckInfo(TimeSpan ttl)
+ public DateTime TimeoutAt { get; }
+
+ public MultiAckInfo(TimeSpan timeout)
{
- Expired = DateTime.UtcNow.Add(ttl);
- Tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ TimeoutAt = DateTime.UtcNow + timeout;
}
+
+ public bool SetExpectedCount(int expectedCount)
+ {
+ if (expectedCount < 0)
+ {
+ throw new ArgumentException("Cannot less than 0.", nameof(expectedCount));
+ }
+ bool result;
+ lock (_tcs)
+ {
+ if (_expectedCount != null)
+ {
+ throw new InvalidOperationException("Cannot set expected count more than once!");
+ }
+ _expectedCount = expectedCount;
+ result = expectedCount <= _ackCount;
+ }
+ if (result)
+ {
+ _tcs.TrySetResult(AckStatus.Ok);
+ }
+ return result;
+ }
+
+ public bool Ack(AckStatus status = AckStatus.Ok)
+ {
+ bool result;
+ lock (_tcs)
+ {
+ _ackCount++;
+ result = _expectedCount <= _ackCount;
+ }
+ if (result)
+ {
+ _tcs.TrySetResult(status);
+ }
+ return result;
+ }
+
+ ///
+ /// Forcely ack the multi ack regardless of the expected count.
+ ///
+ ///
+ ///
+ public bool ForceAck(AckStatus status = AckStatus.Ok)
+ {
+ lock (_tcs)
+ {
+ _ackCount = _expectedCount ?? 0;
+ }
+ _tcs.TrySetResult(status);
+ return true;
+ }
+
+ public Task Task => _tcs.Task;
+
+ public void Cancel() => _tcs.TrySetCanceled();
}
+
}
-}
\ No newline at end of file
+}
diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
index 550954d79..4ecf5cf45 100644
--- a/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
+++ b/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs
@@ -11,84 +11,84 @@
using System.Threading.Tasks;
using Azure.Core.Serialization;
using Microsoft.Azure.SignalR.Common;
+using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
+#nullable enable
+
namespace Microsoft.Azure.SignalR
{
internal class RestClient
{
private readonly IHttpClientFactory _httpClientFactory;
private readonly IPayloadContentBuilder _payloadContentBuilder;
- private readonly bool _enableMessageTracing;
- public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder contentBuilder, bool enableMessageTracing)
+ public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder contentBuilder)
{
_httpClientFactory = httpClientFactory;
_payloadContentBuilder = contentBuilder;
- _enableMessageTracing = enableMessageTracing;
}
-
- public RestClient(IHttpClientFactory httpClientFactory, ObjectSerializer objectSerializer, bool enableMessageTracing) : this(httpClientFactory, new JsonPayloadContentBuilder(objectSerializer), enableMessageTracing)
+ // TODO: Test only, will remove later
+ internal RestClient(IHttpClientFactory httpClientFactory) : this(httpClientFactory, new JsonPayloadContentBuilder(new JsonObjectSerializer()))
{
}
-
- public RestClient() : this(HttpClientFactory.Instance, new JsonObjectSerializer(), false)
+ // TODO: remove later
+ public RestClient() : this(HttpClientFactory.Instance)
{
}
public Task SendAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
- string productInfo,
- string methodName = null,
- object[] args = null,
- Func handleExpectedResponse = null,
+ string? methodName = null,
+ object[]? args = null,
+ Func? handleExpectedResponse = null,
CancellationToken cancellationToken = default)
{
if (handleExpectedResponse == null)
{
- return SendAsync(api, httpMethod, productInfo, methodName, args, handleExpectedResponseAsync: null, cancellationToken);
+ return SendAsync(api, httpMethod, methodName, args, handleExpectedResponseAsync: null, cancellationToken);
}
- return SendAsync(api, httpMethod, productInfo, methodName, args, response => Task.FromResult(handleExpectedResponse(response)), cancellationToken);
+ return SendAsync(api, httpMethod, methodName, args, response => Task.FromResult(handleExpectedResponse(response)), cancellationToken);
}
- public async Task SendAsync(
+ public Task SendAsync(
RestApiEndpoint api,
HttpMethod httpMethod,
- string productInfo,
- string methodName = null,
- object[] args = null,
- Func> handleExpectedResponseAsync = null,
+ string? methodName = null,
+ object[]? args = null,
+ Func>? handleExpectedResponseAsync = null,
CancellationToken cancellationToken = default)
{
- using var httpClient = _httpClientFactory.CreateClient();
- using var request = BuildRequest(api, httpMethod, productInfo, methodName, args);
+ return SendAsyncCore(Options.DefaultName, api, httpMethod, methodName, args, handleExpectedResponseAsync, cancellationToken);
+ }
- try
- {
- using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
- if (handleExpectedResponseAsync == null)
- {
- await ThrowExceptionOnResponseFailureAsync(response);
- }
- else
- {
- if (!await handleExpectedResponseAsync(response))
- {
- await ThrowExceptionOnResponseFailureAsync(response);
- }
- }
- }
- catch (HttpRequestException ex)
- {
- throw new AzureSignalRException($"An error happened when making request to {request.RequestUri}", ex);
- }
+ public Task SendWithRetryAsync(
+ RestApiEndpoint api,
+ HttpMethod httpMethod,
+ string? methodName = null,
+ object[]? args = null,
+ Func? handleExpectedResponse = null,
+ CancellationToken cancellationToken = default)
+ {
+ return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken);
+ }
+
+ public Task SendMessageWithRetryAsync(
+ RestApiEndpoint api,
+ HttpMethod httpMethod,
+ string? methodName = null,
+ object[]? args = null,
+ Func? handleExpectedResponse = null,
+ CancellationToken cancellationToken = default)
+ {
+ return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken);
}
- public async Task ThrowExceptionOnResponseFailureAsync(HttpResponseMessage response)
+ private async Task ThrowExceptionOnResponseFailureAsync(HttpResponseMessage response)
{
if (response.IsSuccessStatusCode)
{
@@ -106,14 +106,47 @@ public async Task ThrowExceptionOnResponseFailureAsync(HttpResponseMessage respo
#endif
throw response.StatusCode switch
{
- HttpStatusCode.BadRequest => new AzureSignalRInvalidArgumentException(response.RequestMessage.RequestUri.ToString(), innerException, detail),
- HttpStatusCode.Unauthorized => new AzureSignalRUnauthorizedException(response.RequestMessage.RequestUri.ToString(), innerException),
- HttpStatusCode.NotFound => new AzureSignalRInaccessibleEndpointException(response.RequestMessage.RequestUri.ToString(), innerException),
- _ => new AzureSignalRRuntimeException(response.RequestMessage.RequestUri.ToString(), innerException),
+ HttpStatusCode.BadRequest => new AzureSignalRInvalidArgumentException(response.RequestMessage?.RequestUri?.ToString(), innerException, detail),
+ HttpStatusCode.Unauthorized => new AzureSignalRUnauthorizedException(response.RequestMessage?.RequestUri?.ToString(), innerException),
+ HttpStatusCode.NotFound => new AzureSignalRInaccessibleEndpointException(response.RequestMessage?.RequestUri?.ToString(), innerException),
+ _ => new AzureSignalRRuntimeException(response.RequestMessage?.RequestUri?.ToString(), innerException),
};
}
- private static Uri GetUri(string url, IDictionary query)
+ private async Task SendAsyncCore(
+ string httpClientName,
+ RestApiEndpoint api,
+ HttpMethod httpMethod,
+ string? methodName = null,
+ object[]? args = null,
+ Func>? handleExpectedResponseAsync = null,
+ CancellationToken cancellationToken = default)
+ {
+ using var httpClient = _httpClientFactory.CreateClient(httpClientName);
+ using var request = BuildRequest(api, httpMethod, methodName, args);
+
+ try
+ {
+ using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
+ if (handleExpectedResponseAsync == null)
+ {
+ await ThrowExceptionOnResponseFailureAsync(response);
+ }
+ else
+ {
+ if (!await handleExpectedResponseAsync(response))
+ {
+ await ThrowExceptionOnResponseFailureAsync(response);
+ }
+ }
+ }
+ catch (HttpRequestException ex)
+ {
+ throw new AzureSignalRException($"An error happened when making request to {request.RequestUri}", ex);
+ }
+ }
+
+ private static Uri GetUri(string url, IDictionary? query)
{
if (query == null || query.Count == 0)
{
@@ -136,40 +169,25 @@ private static Uri GetUri(string url, IDictionary query)
sb.Append(sb.Length > 0 ? '&' : '?');
sb.Append(Uri.EscapeDataString(item.Key));
sb.Append('=');
- sb.Append(Uri.EscapeDataString(value));
+ sb.Append(Uri.EscapeDataString(value!));
}
}
builder.Query = sb.ToString();
return builder.Uri;
}
- private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMethod, string productInfo, string methodName = null, object[] args = null)
+ private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMethod, string? methodName = null, object[]? args = null)
{
var payload = httpMethod == HttpMethod.Post ? new PayloadMessage { Target = methodName, Arguments = args } : null;
- if (_enableMessageTracing)
- {
- AddTracingId(api);
- }
- return GenerateHttpRequest(api.Audience, api.Query, httpMethod, payload, api.Token, productInfo);
+ return GenerateHttpRequest(api.Audience, api.Query, httpMethod, payload, api.Token);
}
- private HttpRequestMessage GenerateHttpRequest(string url, IDictionary query, HttpMethod httpMethod, PayloadMessage payload, string tokenString, string productInfo)
+ private HttpRequestMessage GenerateHttpRequest(string url, IDictionary query, HttpMethod httpMethod, PayloadMessage? payload, string tokenString)
{
var request = new HttpRequestMessage(httpMethod, GetUri(url, query));
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", tokenString);
- request.Headers.Add(Constants.AsrsUserAgent, productInfo);
request.Content = _payloadContentBuilder.Build(payload);
return request;
}
-
- private void AddTracingId(RestApiEndpoint api)
- {
- var id = MessageWithTracingIdHelper.Generate();
- if (api.Query == null)
- {
- api.Query = new Dictionary();
- }
- api.Query.Add(Constants.Headers.AsrsMessageTracingId, id.ToString());
- }
}
}
\ No newline at end of file
diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/TimerAwaitable.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/TimerAwaitable.cs
index 6c84a058d..86fd38e8e 100644
--- a/src/Microsoft.Azure.SignalR.Common/Utilities/TimerAwaitable.cs
+++ b/src/Microsoft.Azure.SignalR.Common/Utilities/TimerAwaitable.cs
@@ -83,7 +83,8 @@ public void Stop()
// Stop should be used to trigger the call to end the loop which disposes
if (_disposed)
{
- throw new ObjectDisposedException(GetType().FullName);
+ // no need to throw, to allow Stop to be called multiple times safely
+ return;
}
_running = false;
diff --git a/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerOptions.cs b/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerOptions.cs
index 2e59940e6..994a42395 100644
--- a/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerOptions.cs
+++ b/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerOptions.cs
@@ -6,6 +6,8 @@
using Azure.Core.Serialization;
using Newtonsoft.Json;
+#nullable enable
+
namespace Microsoft.Azure.SignalR.Management
{
///
@@ -16,7 +18,7 @@ public class ServiceManagerOptions
///
/// Gets or sets the ApplicationName which will be prefixed to each hub name
///
- public string ApplicationName { get; set; }
+ public string? ApplicationName { get; set; }
///
/// Gets or sets the total number of connections from SDK to Azure SignalR Service. Default value is 1.
@@ -26,17 +28,17 @@ public class ServiceManagerOptions
///
/// Gets or sets a service endpoint of Azure SignalR Service instance by connection string.
///
- public string ConnectionString { get; set; } = null;
+ public string? ConnectionString { get; set; } = null;
///
/// Gets or sets multiple service endpoints of Azure SignalR Service instances.
///
- public ServiceEndpoint[] ServiceEndpoints { get; set; }
+ public ServiceEndpoint[]? ServiceEndpoints { get; set; }
///
/// Gets or sets the proxy used when ServiceManager will attempt to connect to Azure SignalR Service.
///
- public IWebProxy Proxy { get; set; }
+ public IWebProxy? Proxy { get; set; }
///
/// Gets or sets the transport type to Azure SignalR Service. Default value is Transient.
@@ -48,6 +50,8 @@ public class ServiceManagerOptions
///
public TimeSpan HttpClientTimeout { get; set; } = TimeSpan.FromSeconds(100);
+ public ServiceManagerRetryOptions? RetryOptions { get; set; }
+
///
/// Gets the json serializer settings that will be used to serialize content sent to Azure SignalR Service.
///
@@ -57,7 +61,7 @@ public class ServiceManagerOptions
///
/// If users want to use MessagePack, they should go to
///
- internal ObjectSerializer ObjectSerializer { get; set; }
+ internal ObjectSerializer? ObjectSerializer { get; set; }
///
/// Set a JSON object serializer used to serialize the data sent to clients.
@@ -73,7 +77,7 @@ public void UseJsonObjectSerializer(ObjectSerializer objectSerializer)
// not ready
internal bool EnableMessageTracing { get; set; } = false;
- internal string ProductInfo { get; set; }
+ internal string? ProductInfo { get; set; }
internal void ValidateOptions()
{
diff --git a/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerRetryMode.cs b/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerRetryMode.cs
new file mode 100644
index 000000000..c1094a5a2
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerRetryMode.cs
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+namespace Microsoft.Azure.SignalR.Management;
+
+#nullable enable
+
+///
+/// The type of approach to apply when calculating the delay between retry attempts.
+///
+public enum ServiceManagerRetryMode
+{
+ ///
+ /// Retry attempts happen at fixed intervals; each delay is a consistent duration.
+ ///
+ Fixed,
+ ///
+ /// Retry attempts will delay based on a backoff strategy, where each attempt will
+ /// increase the duration that it waits before retrying.
+ ///
+ Exponential
+}
\ No newline at end of file
diff --git a/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerRetryOptions.cs b/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerRetryOptions.cs
new file mode 100644
index 000000000..97ab35bdd
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerRetryOptions.cs
@@ -0,0 +1,34 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+
+namespace Microsoft.Azure.SignalR.Management;
+
+#nullable enable
+
+public class ServiceManagerRetryOptions
+{
+ ///
+ /// The maximum number of retry attempts before giving up.
+ ///
+ public int MaxRetries { get; set; } = 3;
+
+ ///
+ /// The delay between retry attempts for a fixed approach or the delay
+ /// on which to base calculations for a backoff-based approach.
+ ///
+ public TimeSpan Delay { get; set; } = TimeSpan.FromSeconds(0.8);
+
+ ///
+ /// The maximum permissible delay between retry attempts.
+ ///
+ public TimeSpan MaxDelay { get; set; } = TimeSpan.FromMinutes(1);
+
+ ///
+ /// The approach to use for calculating retry delays.
+ ///
+ public ServiceManagerRetryMode Mode { get; set; } = ServiceManagerRetryMode.Fixed;
+}
+
+
diff --git a/src/Microsoft.Azure.SignalR.Management/DependencyInjectionExtensions.cs b/src/Microsoft.Azure.SignalR.Management/DependencyInjectionExtensions.cs
index 83a3c6c99..0a9f3bfe1 100644
--- a/src/Microsoft.Azure.SignalR.Management/DependencyInjectionExtensions.cs
+++ b/src/Microsoft.Azure.SignalR.Management/DependencyInjectionExtensions.cs
@@ -2,8 +2,10 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Linq;
+using System.Net;
using System.Net.Http;
using System.Reflection;
+using System.Threading;
using System.Threading.Tasks;
using Azure.Core.Serialization;
using Microsoft.AspNetCore.Connections;
@@ -156,15 +158,93 @@ private static IServiceCollection TrySetProductInfo(this IServiceCollection serv
return services.Configure(o => o.ProductInfo ??= productInfo);
}
- private static IServiceCollection AddRestClientFactory(this IServiceCollection services) => services
- .AddHttpClient(Options.DefaultName, (sp, client) => client.Timeout = sp.GetRequiredService>().Value.HttpClientTimeout)
- .ConfigurePrimaryHttpMessageHandler(sp => new HttpClientHandler() { Proxy = sp.GetRequiredService>().Value.Proxy }).Services
- .AddSingleton(sp =>
+ private static IServiceCollection AddRestClientFactory(this IServiceCollection services)
+ {
+ // For AAD, health check.
+ services
+ .AddHttpClient(Options.DefaultName, (sp, client) =>
+ {
+ client.Timeout = sp.GetRequiredService>().Value.HttpClientTimeout;
+ ConfigureProduceInfo(sp, client);
+ })
+ .ConfigurePrimaryHttpMessageHandler(ConfigureProxy);
+
+ // For other data plane APIs.
+ services.AddSingleton(sp =>
+ {
+ var options = sp.GetRequiredService>().Value;
+ var retryOptions = options.RetryOptions;
+ return retryOptions == null
+ ? new DummyBackOffPolicy()
+ : retryOptions.Mode switch
+ {
+ ServiceManagerRetryMode.Fixed => ActivatorUtilities.CreateInstance(sp),
+ ServiceManagerRetryMode.Exponential => ActivatorUtilities.CreateInstance(sp),
+ _ => throw new NotSupportedException($"Retry mode {retryOptions.Mode} is not supported.")
+ };
+ });
+ services
+ .AddHttpClient(Constants.HttpClientNames.Resilient, (sp, client) =>
+ {
+ var options = sp.GetRequiredService>().Value;
+ if (options.RetryOptions == null)
+ {
+ client.Timeout = options.HttpClientTimeout;
+ }
+ else
+ {
+ // The timeout is enforced by TimeoutHttpMessageHandler.
+ client.Timeout = Timeout.InfiniteTimeSpan;
+ }
+ ConfigureProduceInfo(sp, client);
+ ConfigureMessageTracingId(sp, client);
+ })
+ .ConfigurePrimaryHttpMessageHandler(ConfigureProxy)
+ .AddHttpMessageHandler(sp => ActivatorUtilities.CreateInstance(sp, (HttpStatusCode code) => IsTransientErrorForNonMessageApi(code)))
+ .AddHttpMessageHandler(sp => ActivatorUtilities.CreateInstance(sp));
+
+ services
+ .AddHttpClient(Constants.HttpClientNames.MessageResilient, (sp, client) =>
+ {
+ client.Timeout = sp.GetRequiredService>().Value.HttpClientTimeout;
+ ConfigureProduceInfo(sp, client);
+ ConfigureMessageTracingId(sp, client);
+ })
+ .ConfigurePrimaryHttpMessageHandler(ConfigureProxy)
+ .AddHttpMessageHandler(sp => ActivatorUtilities.CreateInstance(sp, (HttpStatusCode code) => IsTransientErrorAndIdempotentForMessageApi(code)));
+
+ services.AddSingleton(sp =>
{
var options = sp.GetRequiredService>().Value;
var productInfo = options.ProductInfo;
var httpClientFactory = sp.GetRequiredService();
return new RestClientFactory(productInfo, httpClientFactory);
});
+
+ return services;
+
+ static HttpMessageHandler ConfigureProxy(IServiceProvider sp) => new HttpClientHandler() { Proxy = sp.GetRequiredService>().Value.Proxy };
+
+ static bool IsTransientErrorAndIdempotentForMessageApi(HttpStatusCode code) =>
+ // Runtime returns 500 for timeout errors too, to avoid duplicate message, we exclude 500 here.
+ code > HttpStatusCode.InternalServerError;
+
+ static bool IsTransientErrorForNonMessageApi(HttpStatusCode code) =>
+ code >= HttpStatusCode.InternalServerError ||
+ code == HttpStatusCode.RequestTimeout;
+
+ static void ConfigureProduceInfo(IServiceProvider sp, HttpClient client) =>
+ client.DefaultRequestHeaders.Add(Constants.AsrsUserAgent, sp.GetRequiredService>().Value.ProductInfo ??
+ // The following value should not be used.
+ "Microsoft.Azure.SignalR.Management/");
+
+ static void ConfigureMessageTracingId(IServiceProvider sp, HttpClient client)
+ {
+ if (sp.GetRequiredService>().Value.EnableMessageTracing)
+ {
+ client.DefaultRequestHeaders.Add(Constants.Headers.AsrsMessageTracingId, MessageWithTracingIdHelper.Generate().ToString());
+ }
+ }
+ }
}
}
diff --git a/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/ServiceHubLifetimeManagerFactory.cs b/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/ServiceHubLifetimeManagerFactory.cs
index 207303517..dffa6d3da 100644
--- a/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/ServiceHubLifetimeManagerFactory.cs
+++ b/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/ServiceHubLifetimeManagerFactory.cs
@@ -38,8 +38,8 @@ public IServiceHubLifetimeManager Create(string hubName) where THub
var payloadBuilderResolver = _serviceProvider.GetRequiredService();
var httpClientFactory = _serviceProvider.GetRequiredService();
var serviceEndpoint = _serviceProvider.GetRequiredService().Endpoints.First().Key;
- var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder(), _options.EnableMessageTracing);
- return new RestHubLifetimeManager(hubName, serviceEndpoint, _options.ProductInfo, _options.ApplicationName, restClient);
+ var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder());
+ return new RestHubLifetimeManager(hubName, serviceEndpoint, _options.ApplicationName, restClient);
}
default: throw new InvalidEnumArgumentException(nameof(ServiceManagerOptions.ServiceTransportType), (int)_options.ServiceTransportType, typeof(ServiceTransportType));
}
diff --git a/src/Microsoft.Azure.SignalR.Management/Resilient/DummyBackoffPolicy.cs b/src/Microsoft.Azure.SignalR.Management/Resilient/DummyBackoffPolicy.cs
new file mode 100644
index 000000000..a160273d4
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Management/Resilient/DummyBackoffPolicy.cs
@@ -0,0 +1,13 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Microsoft.Azure.SignalR.Management;
+
+internal class DummyBackOffPolicy : IBackOffPolicy
+{
+ public IEnumerable GetDelays() => Enumerable.Empty();
+}
diff --git a/src/Microsoft.Azure.SignalR.Management/Resilient/ExponentialBackoffPolicy.cs b/src/Microsoft.Azure.SignalR.Management/Resilient/ExponentialBackoffPolicy.cs
new file mode 100644
index 000000000..a9fbbf7fa
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Management/Resilient/ExponentialBackoffPolicy.cs
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.Extensions.Options;
+
+namespace Microsoft.Azure.SignalR.Management;
+
+internal class ExponentialBackOffPolicy : IBackOffPolicy
+{
+ private readonly int _maxRetries;
+ private readonly TimeSpan _minDelay;
+ private readonly TimeSpan _maxDelay;
+
+ public ExponentialBackOffPolicy(IOptions options)
+ {
+ var retryOptions = options.Value.RetryOptions ?? throw new ArgumentException();
+ if (retryOptions.Mode != ServiceManagerRetryMode.Exponential)
+ {
+ throw new ArgumentException();
+ }
+ _maxRetries = retryOptions.MaxRetries;
+ _minDelay = retryOptions.Delay;
+ _maxDelay = retryOptions.MaxDelay;
+ }
+ public IEnumerable GetDelays()
+ {
+ var lastDelay = TimeSpan.MinValue;
+ for (var i = 0; i < _maxRetries; i++)
+ {
+ if (lastDelay >= _maxDelay)
+ {
+ yield return _maxDelay;
+ }
+ else
+ {
+ var delay = TimeSpan.FromMilliseconds((1 << i) * (int)_minDelay.TotalMilliseconds);
+ lastDelay = delay < _maxDelay ? delay : _maxDelay;
+ yield return lastDelay;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Azure.SignalR.Management/Resilient/FixedBackoffPolicy.cs b/src/Microsoft.Azure.SignalR.Management/Resilient/FixedBackoffPolicy.cs
new file mode 100644
index 000000000..45c75eb35
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Management/Resilient/FixedBackoffPolicy.cs
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.Extensions.Options;
+
+namespace Microsoft.Azure.SignalR.Management;
+
+internal class FixedBackOffPolicy : IBackOffPolicy
+{
+ private readonly int _maxRetries;
+ private readonly TimeSpan _delay;
+ public FixedBackOffPolicy(IOptions options)
+ {
+ var retryOptions = options.Value.RetryOptions ?? throw new ArgumentException();
+ if (retryOptions.Mode != ServiceManagerRetryMode.Fixed)
+ {
+ throw new ArgumentException();
+ }
+ _maxRetries = retryOptions.MaxRetries;
+ _delay = retryOptions.Delay;
+ }
+
+ public IEnumerable GetDelays()
+ {
+ for (var i = 0; i < _maxRetries; i++)
+ {
+ yield return _delay;
+ }
+ }
+}
diff --git a/src/Microsoft.Azure.SignalR.Management/Resilient/IBackoffPolicy.cs b/src/Microsoft.Azure.SignalR.Management/Resilient/IBackoffPolicy.cs
new file mode 100644
index 000000000..73a799567
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Management/Resilient/IBackoffPolicy.cs
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.Extensions.Options;
+
+namespace Microsoft.Azure.SignalR.Management;
+
+#nullable enable
+
+internal interface IBackOffPolicy
+{
+ IEnumerable GetDelays();
+}
diff --git a/src/Microsoft.Azure.SignalR.Management/Resilient/RetryHttpMessageHandler.cs b/src/Microsoft.Azure.SignalR.Management/Resilient/RetryHttpMessageHandler.cs
new file mode 100644
index 000000000..c46be254c
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Management/Resilient/RetryHttpMessageHandler.cs
@@ -0,0 +1,75 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Azure.SignalR.Common;
+
+namespace Microsoft.Azure.SignalR.Management;
+
+#nullable enable
+
+internal class RetryHttpMessageHandler : DelegatingHandler
+{
+ private readonly IBackOffPolicy _retryDelayProvider;
+ private readonly Func _canRetry;
+
+ public RetryHttpMessageHandler(IBackOffPolicy retryDelayProvider, Func transientErrorPredicate)
+ {
+ _retryDelayProvider = retryDelayProvider;
+ _canRetry = transientErrorPredicate;
+ }
+
+ protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
+ {
+ IList? exceptions = null;
+ IEnumerator? delays = null;
+ do
+ {
+ Exception? ex;
+ try
+ {
+ var response = await base.SendAsync(request, cancellationToken);
+ if (_canRetry(response.StatusCode))
+ {
+ var innerException = new HttpRequestException(
+ $"Response status code does not indicate success: {(int)response.StatusCode} ({response.ReasonPhrase})");
+ ex = new AzureSignalRRuntimeException(request.RequestUri?.ToString(), innerException);
+ response.Dispose();
+ }
+ else
+ {
+ return response;
+ }
+ }
+ catch (TaskCanceledException operationCanceledException) when (!cancellationToken.IsCancellationRequested && operationCanceledException.InnerException is TimeoutException)
+ {
+ // Thrown by our timeout handler
+ ex = operationCanceledException;
+ }
+ delays ??= _retryDelayProvider.GetDelays().GetEnumerator();
+ if (!delays.MoveNext())
+ {
+ if (exceptions == null)
+ {
+ throw ex;
+ }
+ else
+ {
+ exceptions.Add(ex);
+ throw new AzureSignalRRuntimeException(request.RequestUri?.ToString(), new AggregateException(exceptions));
+ }
+ }
+ else
+ {
+ exceptions ??= new List();
+ exceptions.Add(ex);
+ }
+ await Task.Delay(delays.Current, cancellationToken);
+ } while (true);
+ }
+}
diff --git a/src/Microsoft.Azure.SignalR.Management/Resilient/TimeoutHttpMessageHandler.cs b/src/Microsoft.Azure.SignalR.Management/Resilient/TimeoutHttpMessageHandler.cs
new file mode 100644
index 000000000..46b0bb434
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Management/Resilient/TimeoutHttpMessageHandler.cs
@@ -0,0 +1,49 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Extensions.Options;
+
+namespace Microsoft.Azure.SignalR.Management;
+
+#nullable enable
+
+internal class TimeoutHttpMessageHandler : DelegatingHandler
+{
+ private readonly bool _enableTimeout = false;
+ private readonly TimeSpan _timeout;
+ public TimeoutHttpMessageHandler(IOptions serviceManagerOptions)
+ {
+ var options = serviceManagerOptions.Value;
+ if (options.RetryOptions == null)
+ {
+ // Timeout handled by HttpClient for backward compatibility
+ _timeout = Timeout.InfiniteTimeSpan;
+ }
+ else
+ {
+ _timeout = options.HttpClientTimeout;
+ _enableTimeout = true;
+ }
+ }
+ protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
+ {
+ if (_enableTimeout)
+ {
+ var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+ cts.CancelAfter(_timeout);
+ try
+ {
+ return await base.SendAsync(request, cts.Token);
+ }
+ catch (OperationCanceledException ex) when (!cancellationToken.IsCancellationRequested)
+ {
+ throw new TaskCanceledException($"The request was canceled due to the configured HttpClient.Timeout of {_timeout.TotalSeconds} seconds elapsing.", new TimeoutException(ex.Message, ex));
+ }
+ }
+ return await base.SendAsync(request, cancellationToken);
+ }
+}
diff --git a/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs b/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
index 492169437..cf2b968d5 100644
--- a/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
+++ b/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs
@@ -21,14 +21,12 @@ internal class RestHubLifetimeManager : HubLifetimeManager, IService
private readonly RestClient _restClient;
private readonly RestApiProvider _restApiProvider;
- private readonly string _productInfo;
private readonly string _hubName;
private readonly string _appName;
- public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string productInfo, string appName, RestClient restClient)
+ public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient)
{
_restApiProvider = new RestApiProvider(endpoint);
- _productInfo = productInfo;
_appName = appName;
_hubName = hubName;
_restClient = restClient;
@@ -47,7 +45,7 @@ public override async Task AddToGroupAsync(string connectionId, string groupName
}
var api = await _restApiProvider.GetConnectionGroupManagementEndpointAsync(_appName, _hubName, connectionId, groupName);
- await _restClient.SendAsync(api, HttpMethod.Put, _productInfo, handleExpectedResponse: static response => FilterExpectedResponse(response, ErrorCodes.ErrorConnectionNotExisted), cancellationToken: cancellationToken);
+ await _restClient.SendWithRetryAsync(api, HttpMethod.Put, handleExpectedResponse: static response => FilterExpectedResponse(response, ErrorCodes.ErrorConnectionNotExisted), cancellationToken: cancellationToken);
}
public override Task OnConnectedAsync(HubConnectionContext connection)
@@ -73,7 +71,7 @@ public override async Task RemoveFromGroupAsync(string connectionId, string grou
}
var api = await _restApiProvider.GetConnectionGroupManagementEndpointAsync(_appName, _hubName, connectionId, groupName);
- await _restClient.SendAsync(api, HttpMethod.Delete, _productInfo, handleExpectedResponseAsync: null, cancellationToken: cancellationToken);
+ await _restClient.SendWithRetryAsync(api, HttpMethod.Delete, handleExpectedResponse: null, cancellationToken: cancellationToken);
}
public async Task RemoveFromAllGroupsAsync(string connectionId, CancellationToken cancellationToken = default)
@@ -84,7 +82,7 @@ public async Task RemoveFromAllGroupsAsync(string connectionId, CancellationToke
}
var api = await _restApiProvider.GetRemoveConnectionFromAllGroupsAsync(_appName, _hubName, connectionId);
- await _restClient.SendAsync(api, HttpMethod.Delete, _productInfo, handleExpectedResponseAsync: null, cancellationToken: cancellationToken);
+ await _restClient.SendWithRetryAsync(api, HttpMethod.Delete, handleExpectedResponse: null, cancellationToken: cancellationToken);
}
public override Task SendAllAsync(string methodName, object[] args, CancellationToken cancellationToken = default)
@@ -100,7 +98,7 @@ public override async Task SendAllExceptAsync(string methodName, object[] args,
}
var api = await _restApiProvider.GetBroadcastEndpointAsync(_appName, _hubName, excluded: excludedConnectionIds);
- await _restClient.SendAsync(api, HttpMethod.Post, _productInfo, methodName, args, handleExpectedResponseAsync: null, cancellationToken: cancellationToken);
+ await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken);
}
public override async Task SendConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default)
@@ -116,7 +114,7 @@ public override async Task SendConnectionAsync(string connectionId, string metho
}
var api = await _restApiProvider.GetSendToConnectionEndpointAsync(_appName, _hubName, connectionId);
- await _restClient.SendAsync(api, HttpMethod.Post, _productInfo, methodName, args, handleExpectedResponseAsync: null, cancellationToken: cancellationToken);
+ await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken);
}
public override async Task SendConnectionsAsync(IReadOnlyList connectionIds, string methodName, object[] args, CancellationToken cancellationToken = default)
@@ -142,7 +140,7 @@ public override async Task SendGroupExceptAsync(string groupName, string methodN
}
var api = await _restApiProvider.GetSendToGroupEndpointAsync(_appName, _hubName, groupName, excluded: excludedConnectionIds);
- await _restClient.SendAsync(api, HttpMethod.Post, _productInfo, methodName, args, handleExpectedResponseAsync: null, cancellationToken: cancellationToken);
+ await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken);
}
public override async Task SendGroupsAsync(IReadOnlyList groupNames, string methodName, object[] args, CancellationToken cancellationToken = default)
@@ -173,7 +171,7 @@ public override async Task SendUserAsync(string userId, string methodName, objec
}
var api = await _restApiProvider.GetSendToUserEndpointAsync(_appName, _hubName, userId);
- await _restClient.SendAsync(api, HttpMethod.Post, _productInfo, methodName, args, handleExpectedResponseAsync: null, cancellationToken: cancellationToken);
+ await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken);
}
public override async Task SendUsersAsync(IReadOnlyList userIds, string methodName, object[] args, CancellationToken cancellationToken = default)
@@ -196,7 +194,7 @@ public async Task UserAddToGroupAsync(string userId, string groupName, Cancellat
ValidateUserIdAndGroupName(userId, groupName);
var api = await _restApiProvider.GetUserGroupManagementEndpointAsync(_appName, _hubName, userId, groupName);
- await _restClient.SendAsync(api, HttpMethod.Put, _productInfo, handleExpectedResponseAsync: null, cancellationToken: cancellationToken);
+ await _restClient.SendWithRetryAsync(api, HttpMethod.Put, handleExpectedResponse: null, cancellationToken: cancellationToken);
}
public async Task UserAddToGroupAsync(string userId, string groupName, TimeSpan ttl, CancellationToken cancellationToken = default)
@@ -212,7 +210,7 @@ public async Task UserAddToGroupAsync(string userId, string groupName, TimeSpan
{
["ttl"] = ((int)ttl.TotalSeconds).ToString(),
};
- await _restClient.SendAsync(api, HttpMethod.Put, _productInfo, handleExpectedResponseAsync: null, cancellationToken: cancellationToken);
+ await _restClient.SendWithRetryAsync(api, HttpMethod.Put, handleExpectedResponse: null, cancellationToken: cancellationToken);
}
public async Task UserRemoveFromGroupAsync(string userId, string groupName, CancellationToken cancellationToken = default)
@@ -220,20 +218,20 @@ public async Task UserRemoveFromGroupAsync(string userId, string groupName, Canc
ValidateUserIdAndGroupName(userId, groupName);
var api = await _restApiProvider.GetUserGroupManagementEndpointAsync(_appName, _hubName, userId, groupName);
- await _restClient.SendAsync(api, HttpMethod.Delete, _productInfo, handleExpectedResponseAsync: null, cancellationToken: cancellationToken);
+ await _restClient.SendWithRetryAsync(api, HttpMethod.Delete, handleExpectedResponse: null, cancellationToken: cancellationToken);
}
public async Task UserRemoveFromAllGroupsAsync(string userId, CancellationToken cancellationToken = default)
{
var api = await _restApiProvider.GetRemoveUserFromAllGroupsAsync(_appName, _hubName, userId);
- await _restClient.SendAsync(api, HttpMethod.Delete, _productInfo, handleExpectedResponseAsync: null, cancellationToken: cancellationToken);
+ await _restClient.SendWithRetryAsync(api, HttpMethod.Delete, handleExpectedResponse: null, cancellationToken: cancellationToken);
}
public async Task IsUserInGroup(string userId, string groupName, CancellationToken cancellationToken = default)
{
var isUserInGroup = false;
var api = await _restApiProvider.GetUserGroupManagementEndpointAsync(_appName, _hubName, userId, groupName);
- await _restClient.SendAsync(api, HttpMethod.Get, _productInfo, handleExpectedResponse: response =>
+ await _restClient.SendWithRetryAsync(api, HttpMethod.Get, handleExpectedResponse: response =>
{
isUserInGroup = response.StatusCode == HttpStatusCode.OK;
return FilterExpectedResponse(response, ErrorCodes.InfoUserNotInGroup);
@@ -248,7 +246,7 @@ public async Task CloseConnectionAsync(string connectionId, string reason, Cance
throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId));
}
var api = await _restApiProvider.GetCloseConnectionEndpointAsync(_appName, _hubName, connectionId, reason);
- await _restClient.SendAsync(api, HttpMethod.Delete, _productInfo, handleExpectedResponse: static response => FilterExpectedResponse(response, ErrorCodes.WarningConnectionNotExisted), cancellationToken: cancellationToken);
+ await _restClient.SendWithRetryAsync(api, HttpMethod.Delete, handleExpectedResponse: static response => FilterExpectedResponse(response, ErrorCodes.WarningConnectionNotExisted), cancellationToken: cancellationToken);
}
private static void ValidateUserIdAndGroupName(string userId, string groupName)
@@ -272,7 +270,7 @@ public async Task ConnectionExistsAsync(string connectionId, CancellationT
}
var exists = false;
var api = await _restApiProvider.GetCheckConnectionExistsEndpointAsync(_appName, _hubName, connectionId);
- await _restClient.SendAsync(api, HttpMethod.Head, _productInfo, handleExpectedResponse: response =>
+ await _restClient.SendWithRetryAsync(api, HttpMethod.Head, handleExpectedResponse: response =>
{
exists = response.StatusCode == HttpStatusCode.OK;
return FilterExpectedResponse(response, ErrorCodes.WarningConnectionNotExisted);
@@ -288,7 +286,7 @@ public async Task UserExistsAsync(string userId, CancellationToken cancell
}
var exists = false;
var api = await _restApiProvider.GetCheckUserExistsEndpointAsync(_appName, _hubName, userId);
- await _restClient.SendAsync(api, HttpMethod.Head, _productInfo, handleExpectedResponse: response =>
+ await _restClient.SendWithRetryAsync(api, HttpMethod.Head, handleExpectedResponse: response =>
{
exists = response.StatusCode == HttpStatusCode.OK;
return FilterExpectedResponse(response, ErrorCodes.WarningUserNotExisted);
@@ -304,7 +302,7 @@ public async Task GroupExistsAsync(string groupName, CancellationToken can
}
var exists = false;
var api = await _restApiProvider.GetCheckGroupExistsEndpointAsync(_appName, _hubName, groupName);
- await _restClient.SendAsync(api, HttpMethod.Head, _productInfo, handleExpectedResponse: response =>
+ await _restClient.SendWithRetryAsync(api, HttpMethod.Head, handleExpectedResponse: response =>
{
exists = response.StatusCode == HttpStatusCode.OK;
return FilterExpectedResponse(response, ErrorCodes.WarningGroupNotExisted);
diff --git a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs
index 33dfc7fd1..330dd76e1 100644
--- a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs
+++ b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs
@@ -51,7 +51,8 @@ public Task AddInvocation(string connectionId, string invocationId, Cancel
}
else
{
- tcs.TrySetException(new Exception(completionMessage.Error));
+ // Follow https://github.com/dotnet/aspnetcore/blob/v8.0.0-rc.2.23480.2/src/SignalR/common/Shared/ClientResultsManager.cs#L30
+ tcs.TrySetException(new HubException(completionMessage.Error));
}
})
);
diff --git a/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs b/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs
index 5f98d718b..a00c726cd 100644
--- a/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs
+++ b/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs
@@ -12,11 +12,10 @@ namespace Microsoft.Azure.SignalR
internal class DefaultEndpointRouter : DefaultMessageRouter, IEndpointRouter
{
///
- /// Randomly select from the available endpoints
+ /// Select an endpoint for negotiate request
///
/// The http context of the incoming request
/// All the available endpoints
- ///
public ServiceEndpoint GetNegotiateEndpoint(HttpContext context, IEnumerable endpoints)
{
// get primary endpoints snapshot
@@ -28,7 +27,7 @@ public ServiceEndpoint GetNegotiateEndpoint(HttpContext context, IEnumerable
- /// The availbale endpoints
+ /// The available endpoints
private ServiceEndpoint[] GetNegotiateEndpoints(IEnumerable endpoints)
{
var primary = endpoints.Where(s => s.Online && s.EndpointType == EndpointType.Primary).ToArray();
@@ -49,8 +48,7 @@ private ServiceEndpoint[] GetNegotiateEndpoints(IEnumerable end
///
/// Choose endpoint randomly by weight.
- /// The weight is defined as the remaining connection quota.
- /// The least weight is set to 1. So instance with no connection quota still has chance.
+ /// The weight is defined as (the remaining connection quota / the connection capacity).
///
private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] availableEndpoints)
{
@@ -58,7 +56,7 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available
if (availableEndpoints.Any(endpoint => endpoint.EndpointMetrics.ConnectionCapacity == 0) ||
availableEndpoints.Length == 1)
{
- return GetEndpointRandomly(availableEndpoints);
+ return availableEndpoints[StaticRandom.Next(availableEndpoints.Length)];
}
var we = new int[availableEndpoints.Length];
@@ -69,7 +67,7 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available
var remain = endpointMetrics.ConnectionCapacity -
(endpointMetrics.ClientConnectionCount +
endpointMetrics.ServerConnectionCount);
- var weight = remain > 0 ? remain : 1;
+ var weight = Math.Max((int)((double)remain / endpointMetrics.ConnectionCapacity * 1000), 1);
totalCapacity += weight;
we[i] = totalCapacity;
}
@@ -78,10 +76,5 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available
return availableEndpoints[Array.FindLastIndex(we, x => x <= index) + 1];
}
-
- private static ServiceEndpoint GetEndpointRandomly(ServiceEndpoint[] availableEndpoints)
- {
- return availableEndpoints[StaticRandom.Next(availableEndpoints.Length)];
- }
}
}
\ No newline at end of file
diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs
index 877166c2c..b9d8e57f3 100644
--- a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs
+++ b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs
@@ -54,7 +54,9 @@ internal class ClientConnectionContext : ConnectionContext,
IConnectionStatFeature
{
private const int WritingState = 1;
+
private const int CompletedState = 2;
+
private const int IdleState = 0;
private static readonly PipeOptions DefaultPipeOptions = new PipeOptions(pauseWriterThreshold: 0,
@@ -63,12 +65,13 @@ internal class ClientConnectionContext : ConnectionContext,
useSynchronizationContext: false);
private readonly TaskCompletionSource