diff --git a/docs/management-sdk-guide.md b/docs/management-sdk-guide.md index 27f214fda..78f82c644 100644 --- a/docs/management-sdk-guide.md +++ b/docs/management-sdk-guide.md @@ -183,7 +183,7 @@ This SDK can communicates to Azure SignalR Service with two transport types: | | Transient | Persistent | | ------------------------------ | ------------------ | --------------------------------- | | Default JSON library | `Newtonsoft.Json` | The same as Asp.Net Core SignalR:
`Newtonsoft.Json` for .NET Standard 2.0;
`System.Text.Json` for .NET Core App 3.1 and above | -| MessaegPack clients support | since v1.21.0 | since v1.20.0 | +| MessagePack clients support | since v1.21.0 | since v1.20.0 | #### Json serialization See [Customizing Json Serialization in Management SDK](./advanced-topics/json-object-serializer.md) diff --git a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs index 09da21aa0..e297c429f 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs @@ -19,19 +19,22 @@ namespace Microsoft.Azure.SignalR.AspNet { internal partial class ServiceConnection : ServiceConnectionBase { + private const string ReconnectMessage = "asrs:reconnect"; + private static readonly Dictionary CustomHeader = new Dictionary {{Constants.AsrsUserAgent, ProductInfo.GetProductInfo()}}; - private const string ReconnectMessage = "asrs:reconnect"; - private static readonly TimeSpan CloseApplicationTimeout = TimeSpan.FromSeconds(5); private readonly ConcurrentDictionary _clientConnections = new ConcurrentDictionary(StringComparer.Ordinal); private readonly IConnectionFactory _connectionFactory; + private readonly IClientConnectionManager _clientConnectionManager; + private readonly AckHandler _ackHandler; + public ServiceConnection( string serverId, string connectionId, @@ -42,6 +45,7 @@ public ServiceConnection( ILoggerFactory loggerFactory, IServiceMessageHandler serviceMessageHandler, IServiceEventHandler serviceEventHandler, + AckHandler ackHandler, ServiceConnectionType connectionType = ServiceConnectionType.Default) : base( serviceProtocol, @@ -55,6 +59,7 @@ public ServiceConnection( { _connectionFactory = connectionFactory; _clientConnectionManager = clientConnectionManager; + _ackHandler = ackHandler; } protected override Task CreateConnection(string target = null) @@ -147,6 +152,17 @@ protected virtual async Task CleanupConnectionsAsyncCore(string instanceId = nul } } + private static string GetString(ReadOnlySequence buffer) + { + if (buffer.IsSingleSegment) + { + MemoryMarshal.TryGetArray(buffer.First, out var segment); + return Encoding.UTF8.GetString(segment.Array, segment.Offset, segment.Count); + } + + return Encoding.UTF8.GetString(buffer.ToArray()); + } + private async Task ForwardMessageToApplication(string connectionId, ServiceMessage message) { if (_clientConnections.TryGetValue(connectionId, out var clientContext)) @@ -230,6 +246,7 @@ private async Task OnConnectedAsyncCore(ClientConnectionContext clientContext, O catch (Exception e) { Log.ConnectedStartingFailed(Logger, connectionId, e); + // Should not wait for application task inside the application task _ = PerformDisconnectCore(connectionId, false); _ = SafeWriteAsync(new CloseConnectionMessage(connectionId, e.Message)); @@ -276,14 +293,18 @@ private async Task ProcessMessageAsync(ClientConnectionContext clientContext, Ca case OpenConnectionMessage openConnectionMessage: await OnConnectedAsyncCore(clientContext, openConnectionMessage); break; + case CloseConnectionMessage closeConnectionMessage: + // should not wait for application task when inside the application task // As the messages are in a queue, close message should be after all the other messages await PerformDisconnectCore(closeConnectionMessage.ConnectionId, false); return; + case ConnectionDataMessage connectionDataMessage: ProcessOutgoingMessages(clientContext, connectionDataMessage); break; + default: break; } @@ -304,17 +325,6 @@ private async Task ProcessMessageAsync(ClientConnectionContext clientContext, Ca } } - private static string GetString(ReadOnlySequence buffer) - { - if (buffer.IsSingleSegment) - { - MemoryMarshal.TryGetArray(buffer.First, out var segment); - return Encoding.UTF8.GetString(segment.Array, segment.Offset, segment.Count); - } - - return Encoding.UTF8.GetString(buffer.ToArray()); - } - private string GetInstanceId(IDictionary header) { if (header.TryGetValue(Constants.AsrsInstanceId, out var instanceId)) @@ -324,4 +334,4 @@ private string GetInstanceId(IDictionary header) return null; } } -} \ No newline at end of file +} diff --git a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionFactory.cs b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionFactory.cs index b943be39d..5d6ad4d60 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionFactory.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionFactory.cs @@ -29,7 +29,7 @@ public ServiceConnectionFactory( _serviceEventHandler = serviceEventHandler; } - public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, ServiceConnectionType type) + public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, AckHandler ackHandler, ServiceConnectionType type) { return new ServiceConnection( _nameProvider.GetName(), @@ -41,6 +41,7 @@ public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHan _logger, serviceMessageHandler, _serviceEventHandler, + ackHandler, type); } } diff --git a/src/Microsoft.Azure.SignalR.Common/Constants.cs b/src/Microsoft.Azure.SignalR.Common/Constants.cs index a691ff4be..790bb8633 100644 --- a/src/Microsoft.Azure.SignalR.Common/Constants.cs +++ b/src/Microsoft.Azure.SignalR.Common/Constants.cs @@ -113,5 +113,11 @@ public static class ErrorCodes public const string InfoUserNotInGroup = "Info.User.NotInGroup"; public const string ErrorConnectionNotExisted = "Error.Connection.NotExisted"; } + + public static class HttpClientNames + { + public const string Resilient = "Resilient"; + public const string MessageResilient = "MessageResilient"; + } } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs index 874961eed..411c39a73 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/AadAccessKey.cs @@ -1,4 +1,7 @@ -using System; +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; using System.Collections.Generic; using System.Net; using System.Net.Http; @@ -154,7 +157,14 @@ private async Task AuthorizeWithRetryAsync(CancellationToken ctoken = default) catch (Exception e) { latest = e; - await Task.Delay(AuthorizeRetryInterval); + try + { + await Task.Delay(AuthorizeRetryInterval, ctoken); + } + catch (OperationCanceledException) + { + break; + } } } @@ -169,7 +179,6 @@ private async Task AuthorizeWithTokenAsync(string accessToken, CancellationToken await new RestClient().SendAsync( api, HttpMethod.Get, - "", handleExpectedResponseAsync: HandleHttpResponseAsync, cancellationToken: ctoken); } diff --git a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs index 227d3850e..8bedc821e 100644 --- a/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs +++ b/src/Microsoft.Azure.SignalR.Common/Endpoints/AccessKeySynchronizer.cs @@ -88,11 +88,11 @@ private async Task UpdateAccessKeyAsync(AadAccessKey key) try { await key.UpdateAccessKeyAsync(); - Log.SucceedToAuthorizeAccessKey(logger, key.Endpoint.AbsoluteUri); + Log.SucceedToAuthorizeAccessKey(logger, key.AuthorizeUrl); } catch (Exception e) { - Log.FailedToAuthorizeAccessKey(logger, key.Endpoint.AbsoluteUri, e); + Log.FailedToAuthorizeAccessKey(logger, key.AuthorizeUrl, e); } } diff --git a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionFactory.cs b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionFactory.cs index f3ac55c14..e182663d7 100644 --- a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionFactory.cs +++ b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionFactory.cs @@ -2,6 +2,6 @@ { internal interface IServiceConnectionFactory { - IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, ServiceConnectionType type); + IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, AckHandler ackHandler, ServiceConnectionType type); } } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs index 64ef90b2b..c9a257275 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Diagnostics; -using System.IO; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -17,18 +16,18 @@ namespace Microsoft.Azure.SignalR internal abstract class ServiceConnectionContainerBase : IServiceConnectionContainer, IServiceMessageHandler { private const int CheckWindow = 5; - private static readonly TimeSpan CheckTimeSpan = TimeSpan.FromMinutes(10); // Give interval(5s) * 24 = 2min window for retry considering abnormal case. private const int MaxRetryRemoveSeverConnection = 24; + private static readonly TimeSpan CheckTimeSpan = TimeSpan.FromMinutes(10); + private static readonly int MaxReconnectBackOffInternalInMilliseconds = 1000; + // Give (interval * 3 + 1) delay when check value expire. private static readonly long DefaultServersPingTimeoutTicks = Stopwatch.Frequency * ((long)Constants.Periods.DefaultServersPingInterval.TotalSeconds * 3 + 1); - private static readonly Tuple DefaultServersTagContext = new Tuple(string.Empty, 0); - private static TimeSpan ReconnectInterval => - TimeSpan.FromMilliseconds(StaticRandom.Next(MaxReconnectBackOffInternalInMilliseconds)); + private static readonly Tuple DefaultServersTagContext = new Tuple(string.Empty, 0); private readonly BackOffPolicy _backOffPolicy = new BackOffPolicy(); @@ -36,8 +35,6 @@ internal abstract class ServiceConnectionContainerBase : IServiceConnectionConta private readonly object _statusLock = new object(); - private (int count, DateTime? last) _inactiveInfo; - private readonly AckHandler _ackHandler; private readonly CustomizedPingTimer _statusPing; @@ -46,34 +43,21 @@ internal abstract class ServiceConnectionContainerBase : IServiceConnectionConta private readonly ServiceDiagnosticLogsContext _serviceDiagnosticLogsContext = new ServiceDiagnosticLogsContext { EnableMessageLog = false }; + private (int count, DateTime? last) _inactiveInfo; + private volatile List _serviceConnections; private volatile ServiceConnectionStatus _status; - // private volatile Tuple _serversTagContext = DefaultServersTagContext; - private volatile bool _hasClients; - private volatile bool _terminated = false; - protected ILogger Logger { get; } - - protected List ServiceConnections - { - get { return _serviceConnections; } - set { _serviceConnections = value; } - } - - protected IServiceConnectionFactory ServiceConnectionFactory { get; } - - protected int FixedConnectionCount { get; } + private volatile bool _hasClients; - protected virtual ServiceConnectionType InitialConnectionType { get; } = ServiceConnectionType.Default; + private volatile bool _terminated = false; public HubServiceEndpoint Endpoint { get; } - public event Action ConnectionStatusChanged; - public string ServersTag => _serversTagContext.Item1; public bool HasClients => _hasClients; @@ -99,6 +83,26 @@ private set } } + public Task ConnectionInitializedTask => Task.WhenAny(from connection in ServiceConnections + select connection.ConnectionInitializedTask); + + protected ILogger Logger { get; } + + protected List ServiceConnections + { + get => _serviceConnections; + set => _serviceConnections = value; + } + + protected IServiceConnectionFactory ServiceConnectionFactory { get; } + + protected int FixedConnectionCount { get; } + + protected virtual ServiceConnectionType InitialConnectionType { get; } = ServiceConnectionType.Default; + + private static TimeSpan ReconnectInterval => + TimeSpan.FromMilliseconds(StaticRandom.Next(MaxReconnectBackOffInternalInMilliseconds)); + protected ServiceConnectionContainerBase(IServiceConnectionFactory serviceConnectionFactory, int minConnectionCount, HubServiceEndpoint endpoint, @@ -112,7 +116,7 @@ protected ServiceConnectionContainerBase(IServiceConnectionFactory serviceConnec _ackHandler = ackHandler ?? new AckHandler(); // make sure it is after _endpoint is set - // init initial connections + // init initial connections List initial; if (initialConnections == null) { @@ -140,7 +144,7 @@ protected ServiceConnectionContainerBase(IServiceConnectionFactory serviceConnec ConnectionStatusChanged += OnStatusChanged; _statusPing = new CustomizedPingTimer(Logger, Constants.CustomizedPingTimer.ServiceStatus, WriteServiceStatusPingAsync, Constants.Periods.DefaultStatusPingInterval, Constants.Periods.DefaultStatusPingInterval); - + // when server connection count is specified to 0, the app server only handle negotiate requests if (initial.Count > 0) { @@ -150,6 +154,8 @@ protected ServiceConnectionContainerBase(IServiceConnectionFactory serviceConnec _serversPing = new CustomizedPingTimer(Logger, Constants.CustomizedPingTimer.Servers, WriteServersPingAsync, Constants.Periods.DefaultServersPingInterval, Constants.Periods.DefaultServersPingInterval); } + public event Action ConnectionStatusChanged; + public async Task StartAsync() { using (new ServiceConnectionContainerScope(_serviceDiagnosticLogsContext)) @@ -165,27 +171,6 @@ public virtual Task StopAsync() return Task.WhenAll(ServiceConnections.Select(c => c.StopAsync())); } - /// - /// Start and manage the whole connection lifetime - /// - /// - protected async Task StartCoreAsync(IServiceConnection connection, string target = null) - { - if (_terminated) - { - return; - } - - try - { - await connection.StartAsync(target); - } - finally - { - await OnConnectionComplete(connection); - } - } - public virtual Task HandlePingAsync(PingMessage pingMessage) { if (RuntimeServicePingMessage.TryGetClientCount(pingMessage, out var clientCount)) @@ -226,9 +211,6 @@ public void HandleAck(AckMessage ackMessage) _ackHandler.TriggerAck(ackMessage.AckId, (AckStatus)ackMessage.Status); } - public Task ConnectionInitializedTask => Task.WhenAny(from connection in ServiceConnections - select connection.ConnectionInitializedTask); - public virtual Task WriteAsync(ServiceMessage serviceMessage) { return WriteToScopedOrRandomAvailableConnection(serviceMessage); @@ -241,28 +223,17 @@ public async Task WriteAckableMessageAsync(ServiceMessage serviceMessage, throw new ArgumentException($"{nameof(serviceMessage)} is not {nameof(IAckableMessage)}"); } - var task = _ackHandler.CreateAck(out var id, cancellationToken); + var task = _ackHandler.CreateSingleAck(out var id, null, cancellationToken); ackableMessage.AckId = id; - // Sending regular messages completes as soon as the data leaves the outbound pipe, - // whereas ackable ones complete upon full roundtrip of the message and the ack (or timeout). + // Sending regular messages completes as soon as the data leaves the outbound pipe, + // whereas ackable ones complete upon full roundtrip of the message and the ack (or timeout). // Therefore sending them over different connections creates a possibility for processing them out of original order. // By sending both message types over the same connection we ensure that they are sent (and processed) in their original order. await WriteToScopedOrRandomAvailableConnection(serviceMessage); var status = await task; - switch (status) - { - case AckStatus.Ok: - return true; - case AckStatus.NotFound: - return false; - case AckStatus.Timeout: - case AckStatus.InternalServerError: - throw new TimeoutException($"Ack-able message {serviceMessage.GetType()}(ackId: {ackableMessage.AckId}) timed out."); - default: - throw new AzureSignalRException($"Ack-able message {serviceMessage.GetType()}(ackId: {ackableMessage.AckId}) gets error ack status {status}."); - } + return AckHandler.HandleAckStatus(ackableMessage, status); } public virtual Task OfflineAsync(GracefulShutdownMode mode) @@ -296,12 +267,65 @@ public void Dispose() GC.SuppressFinalize(this); } + internal static TimeSpan GetRetryDelay(int retryCount) + { + // retry count: 0, 1, 2, 3, 4, 5, 6, ... + // delay seconds: 1, 2, 4, 8, 16, 32, 60, ... + if (retryCount > 5) + { + return TimeSpan.FromMinutes(1) + ReconnectInterval; + } + return TimeSpan.FromSeconds(1 << retryCount) + ReconnectInterval; + } + + internal bool GetServiceStatus(bool active, int checkWindow, TimeSpan checkTimeSpan) + { + if (active) + { + _inactiveInfo = (0, null); + return true; + } + else + { + var info = _inactiveInfo; + var last = info.last ?? DateTime.UtcNow; + var count = info.count; + count++; + _inactiveInfo = (count, last); + + // Inactive it only when it checks over 5 times and elapsed for over 10 minutes + var inactive = count >= checkWindow && DateTime.UtcNow - last >= checkTimeSpan; + return !inactive; + } + } + + /// + /// Start and manage the whole connection lifetime + /// + /// + protected async Task StartCoreAsync(IServiceConnection connection, string target = null) + { + if (_terminated) + { + return; + } + + try + { + await connection.StartAsync(target); + } + finally + { + await OnConnectionComplete(connection); + } + } + /// /// Create a connection for a specific service connection type /// protected IServiceConnection CreateServiceConnectionCore(ServiceConnectionType type) { - var connection = ServiceConnectionFactory.Create(Endpoint, this, type); + var connection = ServiceConnectionFactory.Create(Endpoint, this, _ackHandler, type); connection.ConnectionStatusChanged += OnConnectionStatusChanged; return connection; @@ -324,6 +348,7 @@ protected virtual async Task OnConnectionComplete(IServiceConnection serviceConn { await RestartFixedServiceConnectionCoreAsync(index); } + // the rest are "on demand" and are only created upon request else { @@ -332,47 +357,6 @@ protected virtual async Task OnConnectionComplete(IServiceConnection serviceConn } } - private async Task RestartFixedServiceConnectionCoreAsync(int index) - { - if (_terminated) - { - return; - } - - Func> tryNewConnection = async () => - { - var connection = CreateServiceConnectionCore(InitialConnectionType); - ReplaceFixedConnection(index, connection); - - _ = StartCoreAsync(connection); - await connection.ConnectionInitializedTask; - - return connection.Status == ServiceConnectionStatus.Connected; - }; - await _backOffPolicy.CallProbeWithBackOffAsync(tryNewConnection, GetRetryDelay); - } - - private void ReplaceFixedConnection(int index, IServiceConnection serviceConnection) - { - lock (_lock) - { - var newImmutableConnections = ServiceConnections.ToList(); - newImmutableConnections[index] = serviceConnection; - ServiceConnections = newImmutableConnections; - } - } - - private void RemoveOnDemandConnection(IServiceConnection serviceConnection) - { - lock (_lock) - { - var newImmutableConnections = ServiceConnections.ToList(); - Debug.Assert(newImmutableConnections.IndexOf(serviceConnection) >= FixedConnectionCount); - newImmutableConnections.Remove(serviceConnection); - ServiceConnections = newImmutableConnections; - } - } - protected void AddOnDemandConnection(IServiceConnection serviceConnection) { lock (_lock) @@ -426,36 +410,45 @@ protected async Task RemoveConnectionAsync(IServiceConnection c, GracefulShutdow Log.TimeoutWaitingForFinAck(Logger, retry); } - internal bool GetServiceStatus(bool active, int checkWindow, TimeSpan checkTimeSpan) + private async Task RestartFixedServiceConnectionCoreAsync(int index) { - if (active) + if (_terminated) { - _inactiveInfo = (0, null); - return true; + return; } - else + + Func> tryNewConnection = async () => { - var info = _inactiveInfo; - var last = info.last ?? DateTime.UtcNow; - var count = info.count; - count++; - _inactiveInfo = (count, last); + var connection = CreateServiceConnectionCore(InitialConnectionType); + ReplaceFixedConnection(index, connection); - // Inactive it only when it checks over 5 times and elapsed for over 10 minutes - var inactive = count >= checkWindow && DateTime.UtcNow - last >= checkTimeSpan; - return !inactive; + _ = StartCoreAsync(connection); + await connection.ConnectionInitializedTask; + + return connection.Status == ServiceConnectionStatus.Connected; + }; + await _backOffPolicy.CallProbeWithBackOffAsync(tryNewConnection, GetRetryDelay); + } + + private void ReplaceFixedConnection(int index, IServiceConnection serviceConnection) + { + lock (_lock) + { + var newImmutableConnections = ServiceConnections.ToList(); + newImmutableConnections[index] = serviceConnection; + ServiceConnections = newImmutableConnections; } } - internal static TimeSpan GetRetryDelay(int retryCount) + private void RemoveOnDemandConnection(IServiceConnection serviceConnection) { - // retry count: 0, 1, 2, 3, 4, 5, 6, ... - // delay seconds: 1, 2, 4, 8, 16, 32, 60, ... - if (retryCount > 5) + lock (_lock) { - return TimeSpan.FromMinutes(1) + ReconnectInterval; + var newImmutableConnections = ServiceConnections.ToList(); + Debug.Assert(newImmutableConnections.IndexOf(serviceConnection) >= FixedConnectionCount); + newImmutableConnections.Remove(serviceConnection); + ServiceConnections = newImmutableConnections; } - return TimeSpan.FromSeconds(1 << retryCount) + ReconnectInterval; } private void OnStatusChanged(StatusChange obj) @@ -597,12 +590,17 @@ private async Task SafeWriteAsync(ServiceMessage serviceMessage) protected internal sealed class CustomizedPingTimer : IDisposable { private readonly object _lock = new object(); + private readonly long _defaultPingTicks; private readonly string _pingName; + private readonly Func _writePing; + private readonly TimeSpan _dueTime; + private readonly TimeSpan _intervalTime; + private readonly ILogger _logger; // Considering parallel add endpoints to save time, @@ -610,6 +608,7 @@ protected internal sealed class CustomizedPingTimer : IDisposable private long _counter = 0; private long _lastSendTimestamp = 0; + private TimerAwaitable _timer; public CustomizedPingTimer(ILogger logger, string pingName, Func writePing, TimeSpan dueTime, TimeSpan intervalTime) diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs index e3fbf8c2b..2f0d75840 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/AckHandler.cs @@ -2,70 +2,140 @@ using System.Collections.Concurrent; using System.Threading; using System.Threading.Tasks; +using Microsoft.Azure.SignalR.Common; +using Microsoft.Azure.SignalR.Protocol; + +#nullable enable namespace Microsoft.Azure.SignalR { internal sealed class AckHandler : IDisposable { - private readonly ConcurrentDictionary _acks = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _acks = new(); private readonly Timer _timer; - private readonly TimeSpan _ackInterval; - private readonly TimeSpan _ackTtl; - private int _currentId = 0; + private readonly TimeSpan _defaultAckTimeout; + private volatile bool _disposed; + + private int _nextId; + private int NextId() => Interlocked.Increment(ref _nextId); + + public AckHandler(int ackIntervalInMilliseconds = 3000, int ackTtlInMilliseconds = 10000) : this(TimeSpan.FromMilliseconds(ackIntervalInMilliseconds), TimeSpan.FromMilliseconds(ackTtlInMilliseconds)) { } - public AckHandler(int ackIntervalInMilliseconds = 3000, int ackTtlInMilliseconds = 10000) + internal AckHandler(TimeSpan ackInterval, TimeSpan defaultAckTimeout) { - _ackInterval = TimeSpan.FromMilliseconds(ackIntervalInMilliseconds); - _ackTtl = TimeSpan.FromMilliseconds(ackTtlInMilliseconds); + _defaultAckTimeout = defaultAckTimeout; + _timer = new Timer(_ => CheckAcks(), null, ackInterval, ackInterval); + } - bool restoreFlow = false; - try + public Task CreateSingleAck(out int id, TimeSpan? ackTimeout = default, CancellationToken cancellationToken = default) + { + id = NextId(); + if (_disposed) { - if (!ExecutionContext.IsFlowSuppressed()) - { - ExecutionContext.SuppressFlow(); - restoreFlow = true; - } + return Task.FromResult(AckStatus.Ok); + } + var info = (IAckInfo)_acks.GetOrAdd(id, _ => new SingleAckInfo(ackTimeout ?? _defaultAckTimeout)); + if (info is MultiAckInfo) + { + throw new InvalidOperationException(); + } + cancellationToken.Register(() => info.Cancel()); + return info.Task; + } - _timer = new Timer(state => ((AckHandler)state).CheckAcks(), state: this, dueTime: _ackInterval, period: _ackInterval); + public static bool HandleAckStatus(IAckableMessage message, AckStatus status) + { + return status switch + { + AckStatus.Ok => true, + AckStatus.NotFound => false, + AckStatus.Timeout or AckStatus.InternalServerError => throw new TimeoutException($"Ack-able message {message.GetType()}(ackId: {message.AckId}) timed out."), + _ => throw new AzureSignalRException($"Ack-able message {message.GetType()}(ackId: {message.AckId}) gets error ack status {status}."), + }; + } + + public Task CreateMultiAck(out int id, TimeSpan? ackTimeout = default) + { + id = NextId(); + if (_disposed) + { + return Task.FromResult(AckStatus.Ok); } - finally + var info = (IAckInfo)_acks.GetOrAdd(id, _ => new MultiAckInfo(ackTimeout ?? _defaultAckTimeout)); + if (info is SingleAckInfo) { - // Restore the current ExecutionContext - if (restoreFlow) - { - ExecutionContext.RestoreFlow(); - } + throw new InvalidOperationException(); } + return info.Task; } - public Task CreateAck(out int id, CancellationToken cancellationToken = default) + public void TriggerAck(int id, AckStatus status = AckStatus.Ok) { - id = Interlocked.Increment(ref _currentId); - var tcs = _acks.GetOrAdd(id, _ => new AckInfo(_ackTtl)).Tcs; - cancellationToken.Register(() => tcs.TrySetCanceled()); - return tcs.Task; + if (_acks.TryGetValue(id, out var info)) + { + switch (info) + { + case IAckInfo ackInfo: + if (ackInfo.Ack(status)) + { + _acks.TryRemove(id, out _); + } + break; + default: + throw new InvalidCastException($"Expected: IAckInfo<{typeof(IAckInfo).Name}>, actual type: {info.GetType().Name}"); + } + } } - public void TriggerAck(int id, AckStatus ackStatus) + public void SetExpectedCount(int id, int expectedCount) { - if (_acks.TryRemove(id, out var ack)) + if (_disposed) + { + return; + } + + if (_acks.TryGetValue(id, out var info)) { - ack.Tcs.TrySetResult(ackStatus); + if (info is not IMultiAckInfo multiAckInfo) + { + throw new InvalidOperationException(); + } + if (multiAckInfo.SetExpectedCount(expectedCount)) + { + _acks.TryRemove(id, out _); + } } } private void CheckAcks() { + if (_disposed) + { + return; + } + var utcNow = DateTime.UtcNow; - foreach (var pair in _acks) + foreach (var item in _acks) { - if (utcNow > pair.Value.Expired) + var id = item.Key; + var ack = item.Value; + if (utcNow > ack.TimeoutAt) { - if (_acks.TryRemove(pair.Key, out var ack)) + if (_acks.TryRemove(id, out _)) { - ack.Tcs.TrySetResult(AckStatus.Timeout); + if (ack is SingleAckInfo singleAckInfo) + { + singleAckInfo.Ack(AckStatus.Timeout); + } + else if (ack is MultiAckInfo multipleAckInfo) + { + multipleAckInfo.ForceAck(AckStatus.Timeout); + } + else + { + ack.Cancel(); + } } } } @@ -73,28 +143,135 @@ private void CheckAcks() public void Dispose() { - _timer?.Dispose(); + _disposed = true; + + _timer.Dispose(); - foreach (var pair in _acks) + while (!_acks.IsEmpty) { - if (_acks.TryRemove(pair.Key, out var ack)) + foreach (var item in _acks) { - ack.Tcs.TrySetCanceled(); + var id = item.Key; + var ack = item.Value; + if (_acks.TryRemove(id, out _)) + { + ack.Cancel(); + if (ack is IDisposable disposable) + { + disposable.Dispose(); + } + } } } } - private class AckInfo + private interface IAckInfo + { + DateTime TimeoutAt { get; } + void Cancel(); + } + + private interface IAckInfo : IAckInfo + { + Task Task { get; } + bool Ack(T status); + } + + public interface IMultiAckInfo + { + bool SetExpectedCount(int expectedCount); + } + + private sealed class SingleAckInfo : IAckInfo + { + public readonly TaskCompletionSource _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + + public DateTime TimeoutAt { get; } + + public SingleAckInfo(TimeSpan timeout) + { + TimeoutAt = DateTime.UtcNow + timeout; + } + + public bool Ack(AckStatus status = AckStatus.Ok) => + _tcs.TrySetResult(status); + + public Task Task => _tcs.Task; + + public void Cancel() => _tcs.TrySetCanceled(); + } + + private sealed class MultiAckInfo : IAckInfo, IMultiAckInfo { - public TaskCompletionSource Tcs { get; private set; } + public readonly TaskCompletionSource _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); - public DateTime Expired { get; private set; } + private int _ackCount; + private int? _expectedCount; - public AckInfo(TimeSpan ttl) + public DateTime TimeoutAt { get; } + + public MultiAckInfo(TimeSpan timeout) { - Expired = DateTime.UtcNow.Add(ttl); - Tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + TimeoutAt = DateTime.UtcNow + timeout; } + + public bool SetExpectedCount(int expectedCount) + { + if (expectedCount < 0) + { + throw new ArgumentException("Cannot less than 0.", nameof(expectedCount)); + } + bool result; + lock (_tcs) + { + if (_expectedCount != null) + { + throw new InvalidOperationException("Cannot set expected count more than once!"); + } + _expectedCount = expectedCount; + result = expectedCount <= _ackCount; + } + if (result) + { + _tcs.TrySetResult(AckStatus.Ok); + } + return result; + } + + public bool Ack(AckStatus status = AckStatus.Ok) + { + bool result; + lock (_tcs) + { + _ackCount++; + result = _expectedCount <= _ackCount; + } + if (result) + { + _tcs.TrySetResult(status); + } + return result; + } + + /// + /// Forcely ack the multi ack regardless of the expected count. + /// + /// + /// + public bool ForceAck(AckStatus status = AckStatus.Ok) + { + lock (_tcs) + { + _ackCount = _expectedCount ?? 0; + } + _tcs.TrySetResult(status); + return true; + } + + public Task Task => _tcs.Task; + + public void Cancel() => _tcs.TrySetCanceled(); } + } -} \ No newline at end of file +} diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs index 550954d79..4ecf5cf45 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/RestClient.cs @@ -11,84 +11,84 @@ using System.Threading.Tasks; using Azure.Core.Serialization; using Microsoft.Azure.SignalR.Common; +using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; +#nullable enable + namespace Microsoft.Azure.SignalR { internal class RestClient { private readonly IHttpClientFactory _httpClientFactory; private readonly IPayloadContentBuilder _payloadContentBuilder; - private readonly bool _enableMessageTracing; - public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder contentBuilder, bool enableMessageTracing) + public RestClient(IHttpClientFactory httpClientFactory, IPayloadContentBuilder contentBuilder) { _httpClientFactory = httpClientFactory; _payloadContentBuilder = contentBuilder; - _enableMessageTracing = enableMessageTracing; } - - public RestClient(IHttpClientFactory httpClientFactory, ObjectSerializer objectSerializer, bool enableMessageTracing) : this(httpClientFactory, new JsonPayloadContentBuilder(objectSerializer), enableMessageTracing) + // TODO: Test only, will remove later + internal RestClient(IHttpClientFactory httpClientFactory) : this(httpClientFactory, new JsonPayloadContentBuilder(new JsonObjectSerializer())) { } - - public RestClient() : this(HttpClientFactory.Instance, new JsonObjectSerializer(), false) + // TODO: remove later + public RestClient() : this(HttpClientFactory.Instance) { } public Task SendAsync( RestApiEndpoint api, HttpMethod httpMethod, - string productInfo, - string methodName = null, - object[] args = null, - Func handleExpectedResponse = null, + string? methodName = null, + object[]? args = null, + Func? handleExpectedResponse = null, CancellationToken cancellationToken = default) { if (handleExpectedResponse == null) { - return SendAsync(api, httpMethod, productInfo, methodName, args, handleExpectedResponseAsync: null, cancellationToken); + return SendAsync(api, httpMethod, methodName, args, handleExpectedResponseAsync: null, cancellationToken); } - return SendAsync(api, httpMethod, productInfo, methodName, args, response => Task.FromResult(handleExpectedResponse(response)), cancellationToken); + return SendAsync(api, httpMethod, methodName, args, response => Task.FromResult(handleExpectedResponse(response)), cancellationToken); } - public async Task SendAsync( + public Task SendAsync( RestApiEndpoint api, HttpMethod httpMethod, - string productInfo, - string methodName = null, - object[] args = null, - Func> handleExpectedResponseAsync = null, + string? methodName = null, + object[]? args = null, + Func>? handleExpectedResponseAsync = null, CancellationToken cancellationToken = default) { - using var httpClient = _httpClientFactory.CreateClient(); - using var request = BuildRequest(api, httpMethod, productInfo, methodName, args); + return SendAsyncCore(Options.DefaultName, api, httpMethod, methodName, args, handleExpectedResponseAsync, cancellationToken); + } - try - { - using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); - if (handleExpectedResponseAsync == null) - { - await ThrowExceptionOnResponseFailureAsync(response); - } - else - { - if (!await handleExpectedResponseAsync(response)) - { - await ThrowExceptionOnResponseFailureAsync(response); - } - } - } - catch (HttpRequestException ex) - { - throw new AzureSignalRException($"An error happened when making request to {request.RequestUri}", ex); - } + public Task SendWithRetryAsync( + RestApiEndpoint api, + HttpMethod httpMethod, + string? methodName = null, + object[]? args = null, + Func? handleExpectedResponse = null, + CancellationToken cancellationToken = default) + { + return SendAsyncCore(Constants.HttpClientNames.Resilient, api, httpMethod, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken); + } + + public Task SendMessageWithRetryAsync( + RestApiEndpoint api, + HttpMethod httpMethod, + string? methodName = null, + object[]? args = null, + Func? handleExpectedResponse = null, + CancellationToken cancellationToken = default) + { + return SendAsyncCore(Constants.HttpClientNames.MessageResilient, api, httpMethod, methodName, args, handleExpectedResponse == null ? null : response => Task.FromResult(handleExpectedResponse(response)), cancellationToken); } - public async Task ThrowExceptionOnResponseFailureAsync(HttpResponseMessage response) + private async Task ThrowExceptionOnResponseFailureAsync(HttpResponseMessage response) { if (response.IsSuccessStatusCode) { @@ -106,14 +106,47 @@ public async Task ThrowExceptionOnResponseFailureAsync(HttpResponseMessage respo #endif throw response.StatusCode switch { - HttpStatusCode.BadRequest => new AzureSignalRInvalidArgumentException(response.RequestMessage.RequestUri.ToString(), innerException, detail), - HttpStatusCode.Unauthorized => new AzureSignalRUnauthorizedException(response.RequestMessage.RequestUri.ToString(), innerException), - HttpStatusCode.NotFound => new AzureSignalRInaccessibleEndpointException(response.RequestMessage.RequestUri.ToString(), innerException), - _ => new AzureSignalRRuntimeException(response.RequestMessage.RequestUri.ToString(), innerException), + HttpStatusCode.BadRequest => new AzureSignalRInvalidArgumentException(response.RequestMessage?.RequestUri?.ToString(), innerException, detail), + HttpStatusCode.Unauthorized => new AzureSignalRUnauthorizedException(response.RequestMessage?.RequestUri?.ToString(), innerException), + HttpStatusCode.NotFound => new AzureSignalRInaccessibleEndpointException(response.RequestMessage?.RequestUri?.ToString(), innerException), + _ => new AzureSignalRRuntimeException(response.RequestMessage?.RequestUri?.ToString(), innerException), }; } - private static Uri GetUri(string url, IDictionary query) + private async Task SendAsyncCore( + string httpClientName, + RestApiEndpoint api, + HttpMethod httpMethod, + string? methodName = null, + object[]? args = null, + Func>? handleExpectedResponseAsync = null, + CancellationToken cancellationToken = default) + { + using var httpClient = _httpClientFactory.CreateClient(httpClientName); + using var request = BuildRequest(api, httpMethod, methodName, args); + + try + { + using var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); + if (handleExpectedResponseAsync == null) + { + await ThrowExceptionOnResponseFailureAsync(response); + } + else + { + if (!await handleExpectedResponseAsync(response)) + { + await ThrowExceptionOnResponseFailureAsync(response); + } + } + } + catch (HttpRequestException ex) + { + throw new AzureSignalRException($"An error happened when making request to {request.RequestUri}", ex); + } + } + + private static Uri GetUri(string url, IDictionary? query) { if (query == null || query.Count == 0) { @@ -136,40 +169,25 @@ private static Uri GetUri(string url, IDictionary query) sb.Append(sb.Length > 0 ? '&' : '?'); sb.Append(Uri.EscapeDataString(item.Key)); sb.Append('='); - sb.Append(Uri.EscapeDataString(value)); + sb.Append(Uri.EscapeDataString(value!)); } } builder.Query = sb.ToString(); return builder.Uri; } - private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMethod, string productInfo, string methodName = null, object[] args = null) + private HttpRequestMessage BuildRequest(RestApiEndpoint api, HttpMethod httpMethod, string? methodName = null, object[]? args = null) { var payload = httpMethod == HttpMethod.Post ? new PayloadMessage { Target = methodName, Arguments = args } : null; - if (_enableMessageTracing) - { - AddTracingId(api); - } - return GenerateHttpRequest(api.Audience, api.Query, httpMethod, payload, api.Token, productInfo); + return GenerateHttpRequest(api.Audience, api.Query, httpMethod, payload, api.Token); } - private HttpRequestMessage GenerateHttpRequest(string url, IDictionary query, HttpMethod httpMethod, PayloadMessage payload, string tokenString, string productInfo) + private HttpRequestMessage GenerateHttpRequest(string url, IDictionary query, HttpMethod httpMethod, PayloadMessage? payload, string tokenString) { var request = new HttpRequestMessage(httpMethod, GetUri(url, query)); request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", tokenString); - request.Headers.Add(Constants.AsrsUserAgent, productInfo); request.Content = _payloadContentBuilder.Build(payload); return request; } - - private void AddTracingId(RestApiEndpoint api) - { - var id = MessageWithTracingIdHelper.Generate(); - if (api.Query == null) - { - api.Query = new Dictionary(); - } - api.Query.Add(Constants.Headers.AsrsMessageTracingId, id.ToString()); - } } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Common/Utilities/TimerAwaitable.cs b/src/Microsoft.Azure.SignalR.Common/Utilities/TimerAwaitable.cs index 6c84a058d..86fd38e8e 100644 --- a/src/Microsoft.Azure.SignalR.Common/Utilities/TimerAwaitable.cs +++ b/src/Microsoft.Azure.SignalR.Common/Utilities/TimerAwaitable.cs @@ -83,7 +83,8 @@ public void Stop() // Stop should be used to trigger the call to end the loop which disposes if (_disposed) { - throw new ObjectDisposedException(GetType().FullName); + // no need to throw, to allow Stop to be called multiple times safely + return; } _running = false; diff --git a/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerOptions.cs b/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerOptions.cs index 2e59940e6..994a42395 100644 --- a/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerOptions.cs +++ b/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerOptions.cs @@ -6,6 +6,8 @@ using Azure.Core.Serialization; using Newtonsoft.Json; +#nullable enable + namespace Microsoft.Azure.SignalR.Management { /// @@ -16,7 +18,7 @@ public class ServiceManagerOptions /// /// Gets or sets the ApplicationName which will be prefixed to each hub name /// - public string ApplicationName { get; set; } + public string? ApplicationName { get; set; } /// /// Gets or sets the total number of connections from SDK to Azure SignalR Service. Default value is 1. @@ -26,17 +28,17 @@ public class ServiceManagerOptions /// /// Gets or sets a service endpoint of Azure SignalR Service instance by connection string. /// - public string ConnectionString { get; set; } = null; + public string? ConnectionString { get; set; } = null; /// /// Gets or sets multiple service endpoints of Azure SignalR Service instances. /// - public ServiceEndpoint[] ServiceEndpoints { get; set; } + public ServiceEndpoint[]? ServiceEndpoints { get; set; } /// /// Gets or sets the proxy used when ServiceManager will attempt to connect to Azure SignalR Service. /// - public IWebProxy Proxy { get; set; } + public IWebProxy? Proxy { get; set; } /// /// Gets or sets the transport type to Azure SignalR Service. Default value is Transient. @@ -48,6 +50,8 @@ public class ServiceManagerOptions /// public TimeSpan HttpClientTimeout { get; set; } = TimeSpan.FromSeconds(100); + public ServiceManagerRetryOptions? RetryOptions { get; set; } + /// /// Gets the json serializer settings that will be used to serialize content sent to Azure SignalR Service. /// @@ -57,7 +61,7 @@ public class ServiceManagerOptions /// /// If users want to use MessagePack, they should go to /// - internal ObjectSerializer ObjectSerializer { get; set; } + internal ObjectSerializer? ObjectSerializer { get; set; } /// /// Set a JSON object serializer used to serialize the data sent to clients. @@ -73,7 +77,7 @@ public void UseJsonObjectSerializer(ObjectSerializer objectSerializer) // not ready internal bool EnableMessageTracing { get; set; } = false; - internal string ProductInfo { get; set; } + internal string? ProductInfo { get; set; } internal void ValidateOptions() { diff --git a/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerRetryMode.cs b/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerRetryMode.cs new file mode 100644 index 000000000..c1094a5a2 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerRetryMode.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.SignalR.Management; + +#nullable enable + +/// +/// The type of approach to apply when calculating the delay between retry attempts. +/// +public enum ServiceManagerRetryMode +{ + /// + /// Retry attempts happen at fixed intervals; each delay is a consistent duration. + /// + Fixed, + /// + /// Retry attempts will delay based on a backoff strategy, where each attempt will + /// increase the duration that it waits before retrying. + /// + Exponential +} \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerRetryOptions.cs b/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerRetryOptions.cs new file mode 100644 index 000000000..97ab35bdd --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Management/Configuration/ServiceManagerRetryOptions.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; + +namespace Microsoft.Azure.SignalR.Management; + +#nullable enable + +public class ServiceManagerRetryOptions +{ + /// + /// The maximum number of retry attempts before giving up. + /// + public int MaxRetries { get; set; } = 3; + + /// + /// The delay between retry attempts for a fixed approach or the delay + /// on which to base calculations for a backoff-based approach. + /// + public TimeSpan Delay { get; set; } = TimeSpan.FromSeconds(0.8); + + /// + /// The maximum permissible delay between retry attempts. + /// + public TimeSpan MaxDelay { get; set; } = TimeSpan.FromMinutes(1); + + /// + /// The approach to use for calculating retry delays. + /// + public ServiceManagerRetryMode Mode { get; set; } = ServiceManagerRetryMode.Fixed; +} + + diff --git a/src/Microsoft.Azure.SignalR.Management/DependencyInjectionExtensions.cs b/src/Microsoft.Azure.SignalR.Management/DependencyInjectionExtensions.cs index 83a3c6c99..0a9f3bfe1 100644 --- a/src/Microsoft.Azure.SignalR.Management/DependencyInjectionExtensions.cs +++ b/src/Microsoft.Azure.SignalR.Management/DependencyInjectionExtensions.cs @@ -2,8 +2,10 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; using System.Linq; +using System.Net; using System.Net.Http; using System.Reflection; +using System.Threading; using System.Threading.Tasks; using Azure.Core.Serialization; using Microsoft.AspNetCore.Connections; @@ -156,15 +158,93 @@ private static IServiceCollection TrySetProductInfo(this IServiceCollection serv return services.Configure(o => o.ProductInfo ??= productInfo); } - private static IServiceCollection AddRestClientFactory(this IServiceCollection services) => services - .AddHttpClient(Options.DefaultName, (sp, client) => client.Timeout = sp.GetRequiredService>().Value.HttpClientTimeout) - .ConfigurePrimaryHttpMessageHandler(sp => new HttpClientHandler() { Proxy = sp.GetRequiredService>().Value.Proxy }).Services - .AddSingleton(sp => + private static IServiceCollection AddRestClientFactory(this IServiceCollection services) + { + // For AAD, health check. + services + .AddHttpClient(Options.DefaultName, (sp, client) => + { + client.Timeout = sp.GetRequiredService>().Value.HttpClientTimeout; + ConfigureProduceInfo(sp, client); + }) + .ConfigurePrimaryHttpMessageHandler(ConfigureProxy); + + // For other data plane APIs. + services.AddSingleton(sp => + { + var options = sp.GetRequiredService>().Value; + var retryOptions = options.RetryOptions; + return retryOptions == null + ? new DummyBackOffPolicy() + : retryOptions.Mode switch + { + ServiceManagerRetryMode.Fixed => ActivatorUtilities.CreateInstance(sp), + ServiceManagerRetryMode.Exponential => ActivatorUtilities.CreateInstance(sp), + _ => throw new NotSupportedException($"Retry mode {retryOptions.Mode} is not supported.") + }; + }); + services + .AddHttpClient(Constants.HttpClientNames.Resilient, (sp, client) => + { + var options = sp.GetRequiredService>().Value; + if (options.RetryOptions == null) + { + client.Timeout = options.HttpClientTimeout; + } + else + { + // The timeout is enforced by TimeoutHttpMessageHandler. + client.Timeout = Timeout.InfiniteTimeSpan; + } + ConfigureProduceInfo(sp, client); + ConfigureMessageTracingId(sp, client); + }) + .ConfigurePrimaryHttpMessageHandler(ConfigureProxy) + .AddHttpMessageHandler(sp => ActivatorUtilities.CreateInstance(sp, (HttpStatusCode code) => IsTransientErrorForNonMessageApi(code))) + .AddHttpMessageHandler(sp => ActivatorUtilities.CreateInstance(sp)); + + services + .AddHttpClient(Constants.HttpClientNames.MessageResilient, (sp, client) => + { + client.Timeout = sp.GetRequiredService>().Value.HttpClientTimeout; + ConfigureProduceInfo(sp, client); + ConfigureMessageTracingId(sp, client); + }) + .ConfigurePrimaryHttpMessageHandler(ConfigureProxy) + .AddHttpMessageHandler(sp => ActivatorUtilities.CreateInstance(sp, (HttpStatusCode code) => IsTransientErrorAndIdempotentForMessageApi(code))); + + services.AddSingleton(sp => { var options = sp.GetRequiredService>().Value; var productInfo = options.ProductInfo; var httpClientFactory = sp.GetRequiredService(); return new RestClientFactory(productInfo, httpClientFactory); }); + + return services; + + static HttpMessageHandler ConfigureProxy(IServiceProvider sp) => new HttpClientHandler() { Proxy = sp.GetRequiredService>().Value.Proxy }; + + static bool IsTransientErrorAndIdempotentForMessageApi(HttpStatusCode code) => + // Runtime returns 500 for timeout errors too, to avoid duplicate message, we exclude 500 here. + code > HttpStatusCode.InternalServerError; + + static bool IsTransientErrorForNonMessageApi(HttpStatusCode code) => + code >= HttpStatusCode.InternalServerError || + code == HttpStatusCode.RequestTimeout; + + static void ConfigureProduceInfo(IServiceProvider sp, HttpClient client) => + client.DefaultRequestHeaders.Add(Constants.AsrsUserAgent, sp.GetRequiredService>().Value.ProductInfo ?? + // The following value should not be used. + "Microsoft.Azure.SignalR.Management/"); + + static void ConfigureMessageTracingId(IServiceProvider sp, HttpClient client) + { + if (sp.GetRequiredService>().Value.EnableMessageTracing) + { + client.DefaultRequestHeaders.Add(Constants.Headers.AsrsMessageTracingId, MessageWithTracingIdHelper.Generate().ToString()); + } + } + } } } diff --git a/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/ServiceHubLifetimeManagerFactory.cs b/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/ServiceHubLifetimeManagerFactory.cs index 207303517..dffa6d3da 100644 --- a/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/ServiceHubLifetimeManagerFactory.cs +++ b/src/Microsoft.Azure.SignalR.Management/HubInstanceFactories/ServiceHubLifetimeManagerFactory.cs @@ -38,8 +38,8 @@ public IServiceHubLifetimeManager Create(string hubName) where THub var payloadBuilderResolver = _serviceProvider.GetRequiredService(); var httpClientFactory = _serviceProvider.GetRequiredService(); var serviceEndpoint = _serviceProvider.GetRequiredService().Endpoints.First().Key; - var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder(), _options.EnableMessageTracing); - return new RestHubLifetimeManager(hubName, serviceEndpoint, _options.ProductInfo, _options.ApplicationName, restClient); + var restClient = new RestClient(httpClientFactory, payloadBuilderResolver.GetPayloadContentBuilder()); + return new RestHubLifetimeManager(hubName, serviceEndpoint, _options.ApplicationName, restClient); } default: throw new InvalidEnumArgumentException(nameof(ServiceManagerOptions.ServiceTransportType), (int)_options.ServiceTransportType, typeof(ServiceTransportType)); } diff --git a/src/Microsoft.Azure.SignalR.Management/Resilient/DummyBackoffPolicy.cs b/src/Microsoft.Azure.SignalR.Management/Resilient/DummyBackoffPolicy.cs new file mode 100644 index 000000000..a160273d4 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Management/Resilient/DummyBackoffPolicy.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.Azure.SignalR.Management; + +internal class DummyBackOffPolicy : IBackOffPolicy +{ + public IEnumerable GetDelays() => Enumerable.Empty(); +} diff --git a/src/Microsoft.Azure.SignalR.Management/Resilient/ExponentialBackoffPolicy.cs b/src/Microsoft.Azure.SignalR.Management/Resilient/ExponentialBackoffPolicy.cs new file mode 100644 index 000000000..a9fbbf7fa --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Management/Resilient/ExponentialBackoffPolicy.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Options; + +namespace Microsoft.Azure.SignalR.Management; + +internal class ExponentialBackOffPolicy : IBackOffPolicy +{ + private readonly int _maxRetries; + private readonly TimeSpan _minDelay; + private readonly TimeSpan _maxDelay; + + public ExponentialBackOffPolicy(IOptions options) + { + var retryOptions = options.Value.RetryOptions ?? throw new ArgumentException(); + if (retryOptions.Mode != ServiceManagerRetryMode.Exponential) + { + throw new ArgumentException(); + } + _maxRetries = retryOptions.MaxRetries; + _minDelay = retryOptions.Delay; + _maxDelay = retryOptions.MaxDelay; + } + public IEnumerable GetDelays() + { + var lastDelay = TimeSpan.MinValue; + for (var i = 0; i < _maxRetries; i++) + { + if (lastDelay >= _maxDelay) + { + yield return _maxDelay; + } + else + { + var delay = TimeSpan.FromMilliseconds((1 << i) * (int)_minDelay.TotalMilliseconds); + lastDelay = delay < _maxDelay ? delay : _maxDelay; + yield return lastDelay; + } + } + } +} diff --git a/src/Microsoft.Azure.SignalR.Management/Resilient/FixedBackoffPolicy.cs b/src/Microsoft.Azure.SignalR.Management/Resilient/FixedBackoffPolicy.cs new file mode 100644 index 000000000..45c75eb35 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Management/Resilient/FixedBackoffPolicy.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Options; + +namespace Microsoft.Azure.SignalR.Management; + +internal class FixedBackOffPolicy : IBackOffPolicy +{ + private readonly int _maxRetries; + private readonly TimeSpan _delay; + public FixedBackOffPolicy(IOptions options) + { + var retryOptions = options.Value.RetryOptions ?? throw new ArgumentException(); + if (retryOptions.Mode != ServiceManagerRetryMode.Fixed) + { + throw new ArgumentException(); + } + _maxRetries = retryOptions.MaxRetries; + _delay = retryOptions.Delay; + } + + public IEnumerable GetDelays() + { + for (var i = 0; i < _maxRetries; i++) + { + yield return _delay; + } + } +} diff --git a/src/Microsoft.Azure.SignalR.Management/Resilient/IBackoffPolicy.cs b/src/Microsoft.Azure.SignalR.Management/Resilient/IBackoffPolicy.cs new file mode 100644 index 000000000..73a799567 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Management/Resilient/IBackoffPolicy.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Options; + +namespace Microsoft.Azure.SignalR.Management; + +#nullable enable + +internal interface IBackOffPolicy +{ + IEnumerable GetDelays(); +} diff --git a/src/Microsoft.Azure.SignalR.Management/Resilient/RetryHttpMessageHandler.cs b/src/Microsoft.Azure.SignalR.Management/Resilient/RetryHttpMessageHandler.cs new file mode 100644 index 000000000..c46be254c --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Management/Resilient/RetryHttpMessageHandler.cs @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Azure.SignalR.Common; + +namespace Microsoft.Azure.SignalR.Management; + +#nullable enable + +internal class RetryHttpMessageHandler : DelegatingHandler +{ + private readonly IBackOffPolicy _retryDelayProvider; + private readonly Func _canRetry; + + public RetryHttpMessageHandler(IBackOffPolicy retryDelayProvider, Func transientErrorPredicate) + { + _retryDelayProvider = retryDelayProvider; + _canRetry = transientErrorPredicate; + } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + IList? exceptions = null; + IEnumerator? delays = null; + do + { + Exception? ex; + try + { + var response = await base.SendAsync(request, cancellationToken); + if (_canRetry(response.StatusCode)) + { + var innerException = new HttpRequestException( + $"Response status code does not indicate success: {(int)response.StatusCode} ({response.ReasonPhrase})"); + ex = new AzureSignalRRuntimeException(request.RequestUri?.ToString(), innerException); + response.Dispose(); + } + else + { + return response; + } + } + catch (TaskCanceledException operationCanceledException) when (!cancellationToken.IsCancellationRequested && operationCanceledException.InnerException is TimeoutException) + { + // Thrown by our timeout handler + ex = operationCanceledException; + } + delays ??= _retryDelayProvider.GetDelays().GetEnumerator(); + if (!delays.MoveNext()) + { + if (exceptions == null) + { + throw ex; + } + else + { + exceptions.Add(ex); + throw new AzureSignalRRuntimeException(request.RequestUri?.ToString(), new AggregateException(exceptions)); + } + } + else + { + exceptions ??= new List(); + exceptions.Add(ex); + } + await Task.Delay(delays.Current, cancellationToken); + } while (true); + } +} diff --git a/src/Microsoft.Azure.SignalR.Management/Resilient/TimeoutHttpMessageHandler.cs b/src/Microsoft.Azure.SignalR.Management/Resilient/TimeoutHttpMessageHandler.cs new file mode 100644 index 000000000..46b0bb434 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Management/Resilient/TimeoutHttpMessageHandler.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Options; + +namespace Microsoft.Azure.SignalR.Management; + +#nullable enable + +internal class TimeoutHttpMessageHandler : DelegatingHandler +{ + private readonly bool _enableTimeout = false; + private readonly TimeSpan _timeout; + public TimeoutHttpMessageHandler(IOptions serviceManagerOptions) + { + var options = serviceManagerOptions.Value; + if (options.RetryOptions == null) + { + // Timeout handled by HttpClient for backward compatibility + _timeout = Timeout.InfiniteTimeSpan; + } + else + { + _timeout = options.HttpClientTimeout; + _enableTimeout = true; + } + } + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + if (_enableTimeout) + { + var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cts.CancelAfter(_timeout); + try + { + return await base.SendAsync(request, cts.Token); + } + catch (OperationCanceledException ex) when (!cancellationToken.IsCancellationRequested) + { + throw new TaskCanceledException($"The request was canceled due to the configured HttpClient.Timeout of {_timeout.TotalSeconds} seconds elapsing.", new TimeoutException(ex.Message, ex)); + } + } + return await base.SendAsync(request, cancellationToken); + } +} diff --git a/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs b/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs index 492169437..cf2b968d5 100644 --- a/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR.Management/RestHubLifetimeManager.cs @@ -21,14 +21,12 @@ internal class RestHubLifetimeManager : HubLifetimeManager, IService private readonly RestClient _restClient; private readonly RestApiProvider _restApiProvider; - private readonly string _productInfo; private readonly string _hubName; private readonly string _appName; - public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string productInfo, string appName, RestClient restClient) + public RestHubLifetimeManager(string hubName, ServiceEndpoint endpoint, string appName, RestClient restClient) { _restApiProvider = new RestApiProvider(endpoint); - _productInfo = productInfo; _appName = appName; _hubName = hubName; _restClient = restClient; @@ -47,7 +45,7 @@ public override async Task AddToGroupAsync(string connectionId, string groupName } var api = await _restApiProvider.GetConnectionGroupManagementEndpointAsync(_appName, _hubName, connectionId, groupName); - await _restClient.SendAsync(api, HttpMethod.Put, _productInfo, handleExpectedResponse: static response => FilterExpectedResponse(response, ErrorCodes.ErrorConnectionNotExisted), cancellationToken: cancellationToken); + await _restClient.SendWithRetryAsync(api, HttpMethod.Put, handleExpectedResponse: static response => FilterExpectedResponse(response, ErrorCodes.ErrorConnectionNotExisted), cancellationToken: cancellationToken); } public override Task OnConnectedAsync(HubConnectionContext connection) @@ -73,7 +71,7 @@ public override async Task RemoveFromGroupAsync(string connectionId, string grou } var api = await _restApiProvider.GetConnectionGroupManagementEndpointAsync(_appName, _hubName, connectionId, groupName); - await _restClient.SendAsync(api, HttpMethod.Delete, _productInfo, handleExpectedResponseAsync: null, cancellationToken: cancellationToken); + await _restClient.SendWithRetryAsync(api, HttpMethod.Delete, handleExpectedResponse: null, cancellationToken: cancellationToken); } public async Task RemoveFromAllGroupsAsync(string connectionId, CancellationToken cancellationToken = default) @@ -84,7 +82,7 @@ public async Task RemoveFromAllGroupsAsync(string connectionId, CancellationToke } var api = await _restApiProvider.GetRemoveConnectionFromAllGroupsAsync(_appName, _hubName, connectionId); - await _restClient.SendAsync(api, HttpMethod.Delete, _productInfo, handleExpectedResponseAsync: null, cancellationToken: cancellationToken); + await _restClient.SendWithRetryAsync(api, HttpMethod.Delete, handleExpectedResponse: null, cancellationToken: cancellationToken); } public override Task SendAllAsync(string methodName, object[] args, CancellationToken cancellationToken = default) @@ -100,7 +98,7 @@ public override async Task SendAllExceptAsync(string methodName, object[] args, } var api = await _restApiProvider.GetBroadcastEndpointAsync(_appName, _hubName, excluded: excludedConnectionIds); - await _restClient.SendAsync(api, HttpMethod.Post, _productInfo, methodName, args, handleExpectedResponseAsync: null, cancellationToken: cancellationToken); + await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken); } public override async Task SendConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default) @@ -116,7 +114,7 @@ public override async Task SendConnectionAsync(string connectionId, string metho } var api = await _restApiProvider.GetSendToConnectionEndpointAsync(_appName, _hubName, connectionId); - await _restClient.SendAsync(api, HttpMethod.Post, _productInfo, methodName, args, handleExpectedResponseAsync: null, cancellationToken: cancellationToken); + await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken); } public override async Task SendConnectionsAsync(IReadOnlyList connectionIds, string methodName, object[] args, CancellationToken cancellationToken = default) @@ -142,7 +140,7 @@ public override async Task SendGroupExceptAsync(string groupName, string methodN } var api = await _restApiProvider.GetSendToGroupEndpointAsync(_appName, _hubName, groupName, excluded: excludedConnectionIds); - await _restClient.SendAsync(api, HttpMethod.Post, _productInfo, methodName, args, handleExpectedResponseAsync: null, cancellationToken: cancellationToken); + await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken); } public override async Task SendGroupsAsync(IReadOnlyList groupNames, string methodName, object[] args, CancellationToken cancellationToken = default) @@ -173,7 +171,7 @@ public override async Task SendUserAsync(string userId, string methodName, objec } var api = await _restApiProvider.GetSendToUserEndpointAsync(_appName, _hubName, userId); - await _restClient.SendAsync(api, HttpMethod.Post, _productInfo, methodName, args, handleExpectedResponseAsync: null, cancellationToken: cancellationToken); + await _restClient.SendMessageWithRetryAsync(api, HttpMethod.Post, methodName, args, handleExpectedResponse: null, cancellationToken: cancellationToken); } public override async Task SendUsersAsync(IReadOnlyList userIds, string methodName, object[] args, CancellationToken cancellationToken = default) @@ -196,7 +194,7 @@ public async Task UserAddToGroupAsync(string userId, string groupName, Cancellat ValidateUserIdAndGroupName(userId, groupName); var api = await _restApiProvider.GetUserGroupManagementEndpointAsync(_appName, _hubName, userId, groupName); - await _restClient.SendAsync(api, HttpMethod.Put, _productInfo, handleExpectedResponseAsync: null, cancellationToken: cancellationToken); + await _restClient.SendWithRetryAsync(api, HttpMethod.Put, handleExpectedResponse: null, cancellationToken: cancellationToken); } public async Task UserAddToGroupAsync(string userId, string groupName, TimeSpan ttl, CancellationToken cancellationToken = default) @@ -212,7 +210,7 @@ public async Task UserAddToGroupAsync(string userId, string groupName, TimeSpan { ["ttl"] = ((int)ttl.TotalSeconds).ToString(), }; - await _restClient.SendAsync(api, HttpMethod.Put, _productInfo, handleExpectedResponseAsync: null, cancellationToken: cancellationToken); + await _restClient.SendWithRetryAsync(api, HttpMethod.Put, handleExpectedResponse: null, cancellationToken: cancellationToken); } public async Task UserRemoveFromGroupAsync(string userId, string groupName, CancellationToken cancellationToken = default) @@ -220,20 +218,20 @@ public async Task UserRemoveFromGroupAsync(string userId, string groupName, Canc ValidateUserIdAndGroupName(userId, groupName); var api = await _restApiProvider.GetUserGroupManagementEndpointAsync(_appName, _hubName, userId, groupName); - await _restClient.SendAsync(api, HttpMethod.Delete, _productInfo, handleExpectedResponseAsync: null, cancellationToken: cancellationToken); + await _restClient.SendWithRetryAsync(api, HttpMethod.Delete, handleExpectedResponse: null, cancellationToken: cancellationToken); } public async Task UserRemoveFromAllGroupsAsync(string userId, CancellationToken cancellationToken = default) { var api = await _restApiProvider.GetRemoveUserFromAllGroupsAsync(_appName, _hubName, userId); - await _restClient.SendAsync(api, HttpMethod.Delete, _productInfo, handleExpectedResponseAsync: null, cancellationToken: cancellationToken); + await _restClient.SendWithRetryAsync(api, HttpMethod.Delete, handleExpectedResponse: null, cancellationToken: cancellationToken); } public async Task IsUserInGroup(string userId, string groupName, CancellationToken cancellationToken = default) { var isUserInGroup = false; var api = await _restApiProvider.GetUserGroupManagementEndpointAsync(_appName, _hubName, userId, groupName); - await _restClient.SendAsync(api, HttpMethod.Get, _productInfo, handleExpectedResponse: response => + await _restClient.SendWithRetryAsync(api, HttpMethod.Get, handleExpectedResponse: response => { isUserInGroup = response.StatusCode == HttpStatusCode.OK; return FilterExpectedResponse(response, ErrorCodes.InfoUserNotInGroup); @@ -248,7 +246,7 @@ public async Task CloseConnectionAsync(string connectionId, string reason, Cance throw new ArgumentException(NullOrEmptyStringErrorMessage, nameof(connectionId)); } var api = await _restApiProvider.GetCloseConnectionEndpointAsync(_appName, _hubName, connectionId, reason); - await _restClient.SendAsync(api, HttpMethod.Delete, _productInfo, handleExpectedResponse: static response => FilterExpectedResponse(response, ErrorCodes.WarningConnectionNotExisted), cancellationToken: cancellationToken); + await _restClient.SendWithRetryAsync(api, HttpMethod.Delete, handleExpectedResponse: static response => FilterExpectedResponse(response, ErrorCodes.WarningConnectionNotExisted), cancellationToken: cancellationToken); } private static void ValidateUserIdAndGroupName(string userId, string groupName) @@ -272,7 +270,7 @@ public async Task ConnectionExistsAsync(string connectionId, CancellationT } var exists = false; var api = await _restApiProvider.GetCheckConnectionExistsEndpointAsync(_appName, _hubName, connectionId); - await _restClient.SendAsync(api, HttpMethod.Head, _productInfo, handleExpectedResponse: response => + await _restClient.SendWithRetryAsync(api, HttpMethod.Head, handleExpectedResponse: response => { exists = response.StatusCode == HttpStatusCode.OK; return FilterExpectedResponse(response, ErrorCodes.WarningConnectionNotExisted); @@ -288,7 +286,7 @@ public async Task UserExistsAsync(string userId, CancellationToken cancell } var exists = false; var api = await _restApiProvider.GetCheckUserExistsEndpointAsync(_appName, _hubName, userId); - await _restClient.SendAsync(api, HttpMethod.Head, _productInfo, handleExpectedResponse: response => + await _restClient.SendWithRetryAsync(api, HttpMethod.Head, handleExpectedResponse: response => { exists = response.StatusCode == HttpStatusCode.OK; return FilterExpectedResponse(response, ErrorCodes.WarningUserNotExisted); @@ -304,7 +302,7 @@ public async Task GroupExistsAsync(string groupName, CancellationToken can } var exists = false; var api = await _restApiProvider.GetCheckGroupExistsEndpointAsync(_appName, _hubName, groupName); - await _restClient.SendAsync(api, HttpMethod.Head, _productInfo, handleExpectedResponse: response => + await _restClient.SendWithRetryAsync(api, HttpMethod.Head, handleExpectedResponse: response => { exists = response.StatusCode == HttpStatusCode.OK; return FilterExpectedResponse(response, ErrorCodes.WarningGroupNotExisted); diff --git a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs index 33dfc7fd1..330dd76e1 100644 --- a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs +++ b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs @@ -51,7 +51,8 @@ public Task AddInvocation(string connectionId, string invocationId, Cancel } else { - tcs.TrySetException(new Exception(completionMessage.Error)); + // Follow https://github.com/dotnet/aspnetcore/blob/v8.0.0-rc.2.23480.2/src/SignalR/common/Shared/ClientResultsManager.cs#L30 + tcs.TrySetException(new HubException(completionMessage.Error)); } }) ); diff --git a/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs b/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs index 5f98d718b..a00c726cd 100644 --- a/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs +++ b/src/Microsoft.Azure.SignalR/EndpointRouters/DefaultEndpointRouter.cs @@ -12,11 +12,10 @@ namespace Microsoft.Azure.SignalR internal class DefaultEndpointRouter : DefaultMessageRouter, IEndpointRouter { /// - /// Randomly select from the available endpoints + /// Select an endpoint for negotiate request /// /// The http context of the incoming request /// All the available endpoints - /// public ServiceEndpoint GetNegotiateEndpoint(HttpContext context, IEnumerable endpoints) { // get primary endpoints snapshot @@ -28,7 +27,7 @@ public ServiceEndpoint GetNegotiateEndpoint(HttpContext context, IEnumerable - /// The availbale endpoints + /// The available endpoints private ServiceEndpoint[] GetNegotiateEndpoints(IEnumerable endpoints) { var primary = endpoints.Where(s => s.Online && s.EndpointType == EndpointType.Primary).ToArray(); @@ -49,8 +48,7 @@ private ServiceEndpoint[] GetNegotiateEndpoints(IEnumerable end /// /// Choose endpoint randomly by weight. - /// The weight is defined as the remaining connection quota. - /// The least weight is set to 1. So instance with no connection quota still has chance. + /// The weight is defined as (the remaining connection quota / the connection capacity). /// private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] availableEndpoints) { @@ -58,7 +56,7 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available if (availableEndpoints.Any(endpoint => endpoint.EndpointMetrics.ConnectionCapacity == 0) || availableEndpoints.Length == 1) { - return GetEndpointRandomly(availableEndpoints); + return availableEndpoints[StaticRandom.Next(availableEndpoints.Length)]; } var we = new int[availableEndpoints.Length]; @@ -69,7 +67,7 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available var remain = endpointMetrics.ConnectionCapacity - (endpointMetrics.ClientConnectionCount + endpointMetrics.ServerConnectionCount); - var weight = remain > 0 ? remain : 1; + var weight = Math.Max((int)((double)remain / endpointMetrics.ConnectionCapacity * 1000), 1); totalCapacity += weight; we[i] = totalCapacity; } @@ -78,10 +76,5 @@ private ServiceEndpoint GetEndpointAccordingToWeight(ServiceEndpoint[] available return availableEndpoints[Array.FindLastIndex(we, x => x <= index) + 1]; } - - private static ServiceEndpoint GetEndpointRandomly(ServiceEndpoint[] availableEndpoints) - { - return availableEndpoints[StaticRandom.Next(availableEndpoints.Length)]; - } } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs index 877166c2c..b9d8e57f3 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs @@ -54,7 +54,9 @@ internal class ClientConnectionContext : ConnectionContext, IConnectionStatFeature { private const int WritingState = 1; + private const int CompletedState = 2; + private const int IdleState = 0; private static readonly PipeOptions DefaultPipeOptions = new PipeOptions(pauseWriterThreshold: 0, @@ -63,12 +65,13 @@ internal class ClientConnectionContext : ConnectionContext, useSynchronizationContext: false); private readonly TaskCompletionSource _connectionEndTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - private readonly CancellationTokenSource _abortOutgoingCts = new CancellationTokenSource(); - private int _connectionState = IdleState; + private readonly CancellationTokenSource _abortOutgoingCts = new CancellationTokenSource(); private readonly object _heartbeatLock = new object(); + private int _connectionState = IdleState; + private List<(Action handler, object state)> _heartbeatHandlers; private volatile bool _abortOnClose = true; @@ -175,6 +178,7 @@ public async Task WriteMessageAsync(ReadOnlySequence payload) { _lastMessageReceivedAt = DateTime.UtcNow.Ticks; _receivedBytes += payload.Length; + // Start write await WriteMessageAsyncCore(payload); } @@ -237,6 +241,53 @@ public void CancelOutgoing(int millisecondsDelay = 0) } } + internal static bool TryGetRemoteIpAddress(IHeaderDictionary headers, out IPAddress address) + { + var forwardedFor = headers.GetCommaSeparatedValues("X-Forwarded-For"); + if (forwardedFor.Length > 0 && IPAddress.TryParse(forwardedFor[0], out address)) + { + return true; + } + address = null; + return false; + } + + private static void ProcessQuery(string queryString, out string originalPath) + { + originalPath = string.Empty; + var query = QueryHelpers.ParseNullableQuery(queryString); + if (query == null) + { + return; + } + + if (query.TryGetValue(Constants.QueryParameter.RequestCulture, out var culture)) + { + SetCurrentThreadCulture(culture.FirstOrDefault()); + } + if (query.TryGetValue(Constants.QueryParameter.OriginalPath, out var path)) + { + originalPath = path.FirstOrDefault(); + } + } + + private static void SetCurrentThreadCulture(string cultureName) + { + if (!string.IsNullOrEmpty(cultureName)) + { + try + { + var requestCulture = new RequestCulture(cultureName); + CultureInfo.CurrentCulture = requestCulture.Culture; + CultureInfo.CurrentUICulture = requestCulture.UICulture; + } + catch (Exception) + { + // skip invalid culture, normal won't hit. + } + } + } + private FeatureCollection BuildFeatures(OpenConnectionMessage serviceMessage) { var features = new FeatureCollection(); @@ -311,52 +362,5 @@ private string GetInstanceId(IDictionary header) } return string.Empty; } - - internal static bool TryGetRemoteIpAddress(IHeaderDictionary headers, out IPAddress address) - { - var forwardedFor = headers.GetCommaSeparatedValues("X-Forwarded-For"); - if (forwardedFor.Length > 0 && IPAddress.TryParse(forwardedFor[0], out address)) - { - return true; - } - address = null; - return false; - } - - private static void ProcessQuery(string queryString, out string originalPath) - { - originalPath = string.Empty; - var query = QueryHelpers.ParseNullableQuery(queryString); - if (query == null) - { - return; - } - - if (query.TryGetValue(Constants.QueryParameter.RequestCulture, out var culture)) - { - SetCurrentThreadCulture(culture.FirstOrDefault()); - } - if (query.TryGetValue(Constants.QueryParameter.OriginalPath, out var path)) - { - originalPath = path.FirstOrDefault(); - } - } - - private static void SetCurrentThreadCulture(string cultureName) - { - if (!string.IsNullOrEmpty(cultureName)) - { - try - { - var requestCulture = new RequestCulture(cultureName); - CultureInfo.CurrentCulture = requestCulture.Culture; - CultureInfo.CurrentUICulture = requestCulture.UICulture; - } - catch (Exception) - { - // skip invalid culture, normal won't hit. - } - } - } } } diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs index f67636c22..4c04c0647 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs @@ -21,16 +21,20 @@ internal partial class ServiceConnection : ServiceConnectionBase { private const int DefaultCloseTimeoutMilliseconds = 30000; + private const string ClientConnectionCountInHub = "#clientInHub"; + + private const string ClientConnectionCountInServiceConnection = "#client"; + // Fix issue: https://github.com/Azure/azure-signalr/issues/198 // .NET Framework has restriction about reserved string as the header name like "User-Agent" private static readonly Dictionary CustomHeader = new Dictionary { { Constants.AsrsUserAgent, ProductInfo.GetProductInfo() } }; - private const string ClientConnectionCountInHub = "#clientInHub"; - private const string ClientConnectionCountInServiceConnection = "#client"; - private readonly IConnectionFactory _connectionFactory; + private readonly IClientConnectionFactory _clientConnectionFactory; + private readonly int _closeTimeOutMilliseconds; + private readonly IClientConnectionManager _clientConnectionManager; private readonly ConcurrentDictionary _connectionIds = @@ -43,6 +47,8 @@ internal partial class ServiceConnection : ServiceConnectionBase private readonly IClientInvocationManager _clientInvocationManager; + private readonly AckHandler _ackHandler; + public Action ConfigureContext { get; set; } public ServiceConnection(IServiceProtocol serviceProtocol, @@ -57,6 +63,7 @@ public ServiceConnection(IServiceProtocol serviceProtocol, IServiceMessageHandler serviceMessageHandler, IServiceEventHandler serviceEventHandler, IClientInvocationManager clientInvocationManager, + AckHandler ackHandler, ServiceConnectionType connectionType = ServiceConnectionType.Default, GracefulShutdownMode mode = GracefulShutdownMode.Off, int closeTimeOutMilliseconds = DefaultCloseTimeoutMilliseconds @@ -68,6 +75,7 @@ public ServiceConnection(IServiceProtocol serviceProtocol, _clientConnectionFactory = clientConnectionFactory; _closeTimeOutMilliseconds = closeTimeOutMilliseconds; _clientInvocationManager = clientInvocationManager; + _ackHandler = ackHandler; } protected override Task CreateConnection(string target = null) @@ -155,10 +163,12 @@ protected override Task OnClientDisconnectedAsync(CloseConnectionMessage closeCo { context.AbortOnClose = false; context.Features.Set(new ConnectionMigrationFeature(ServerId, to)); + // We have to prevent SignalR `{type: 7}` (close message) from reaching our client while doing migration. // Since all data messages will be sent to `ServiceConnection` directly. // We can simply ignore all messages came from the application. context.CancelOutgoing(); + // The close connection message must be the last message, so we could complete the pipe. context.CompleteIncoming(); } @@ -211,14 +221,15 @@ protected override Task OnPingMessageAsync(PingMessage pingMessage) if (RuntimeServicePingMessage.TryGetOffline(pingMessage, out var instanceId)) { _clientInvocationManager.Caller.CleanupInvocationsByInstance(instanceId); + // Router invocations will be cleanup by its `CleanupInvocationsByConnection`, which is called by `RemoveClientConnection`. - // In `base.OnPingMessageAsync`, `CleanupClientConnections(instanceId)` will finally execute `RemoveClientConnection` for each ConnectionId. + // In `base.OnPingMessageAsync`, `CleanupClientConnections(instanceId)` will finally execute `RemoveClientConnection` for each ConnectionId. } #endif return base.OnPingMessageAsync(pingMessage); } - private async Task ProcessClientConnectionAsync(ClientConnectionContext connection) + private async Task ProcessClientConnectionAsync(ClientConnectionContext connection) { try { @@ -276,6 +287,7 @@ private async Task ProcessClientConnectionAsync(ClientConnectionContext connecti // Inform the Service that we will remove the client because SignalR told us it is disconnected. var serviceMessage = new CloseConnectionMessage(connection.ConnectionId, errorMessage: exception?.Message); + // when it fails, it means the underlying connection is dropped // service is responsible for closing the client connections in this case and there is no need to throw await SafeWriteAsync(serviceMessage); @@ -494,4 +506,4 @@ private Task OnErrorCompletionAsync(ErrorCompletionMessage errorCompletionMessag return Task.CompletedTask; } } -} \ No newline at end of file +} diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionFactory.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionFactory.cs index 2df05934f..4bb9c97d2 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionFactory.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionFactory.cs @@ -44,7 +44,7 @@ public ServiceConnectionFactory( _clientInvocationManager = clientInvocationManager; } - public virtual IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, ServiceConnectionType type) + public virtual IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, AckHandler ackHandler, ServiceConnectionType type) { return new ServiceConnection( _serviceProtocol, @@ -59,6 +59,7 @@ public virtual IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMe serviceMessageHandler, _serviceEventHandler, _clientInvocationManager, + ackHandler, type, ShutdownMode ) diff --git a/test/Microsoft.Azure.SignalR.AspNet.Tests/TestClasses/TestServiceConnectionProxy.cs b/test/Microsoft.Azure.SignalR.AspNet.Tests/TestClasses/TestServiceConnectionProxy.cs index 2a99b11a0..81e0a9c5a 100644 --- a/test/Microsoft.Azure.SignalR.AspNet.Tests/TestClasses/TestServiceConnectionProxy.cs +++ b/test/Microsoft.Azure.SignalR.AspNet.Tests/TestClasses/TestServiceConnectionProxy.cs @@ -32,7 +32,8 @@ public TestServiceConnectionProxy(IClientConnectionManager clientConnectionManag clientConnectionManager, loggerFactory, serviceMessageHandler ?? new TestServiceMessageHandler(), - null) + null, + new AckHandler()) { } diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AadAccessKeyTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AadAccessKeyTests.cs index d10eed434..12a79e553 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AadAccessKeyTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/Auth/AadAccessKeyTests.cs @@ -3,8 +3,9 @@ using System.Security.Claims; using System.Threading; using System.Threading.Tasks; +using Azure.Core; using Azure.Identity; - +using Moq; using Xunit; namespace Microsoft.Azure.SignalR.Common.Tests.Auth @@ -27,9 +28,12 @@ public void TestConstructor(string endpoint, string expectedAuthorizeUrl) [Fact] public async Task TestUpdateAccessKey() { - var credential = new DefaultAzureCredential(); - var endpoint = "http://localhost"; - var key = new AadAccessKey(new Uri(endpoint), credential); + var mockCredential = new Mock(); + mockCredential.Setup(credential => credential.GetTokenAsync( + It.IsAny(), + It.IsAny())) + .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); + var key = new AadAccessKey(new Uri("http://localhost"), mockCredential.Object); var audience = "http://localhost/chat"; var claims = Array.Empty(); @@ -57,8 +61,12 @@ await Assert.ThrowsAsync( [InlineData(true, 56, false)] public async Task TestUpdateAccessKeyShouldSkip(bool isAuthorized, int timeElapsed, bool shouldSkip) { - var key = new AadAccessKey(new Uri("http://localhost"), new DefaultAzureCredential()); - + var mockCredential = new Mock(); + mockCredential.Setup(credential => credential.GetTokenAsync( + It.IsAny(), + It.IsAny())) + .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); + var key = new AadAccessKey(new Uri("http://localhost"), mockCredential.Object); var isAuthorizedField = typeof(AadAccessKey).GetField("_isAuthorized", BindingFlags.NonPublic | BindingFlags.Instance); isAuthorizedField.SetValue(key, isAuthorized); Assert.Equal(isAuthorized, (bool)isAuthorizedField.GetValue(key)); @@ -81,7 +89,7 @@ public async Task TestUpdateAccessKeyShouldSkip(bool isAuthorized, int timeElaps } else { - await Assert.ThrowsAsync(async () => await key.UpdateAccessKeyAsync(source.Token)); + await Assert.ThrowsAsync(async () => await key.UpdateAccessKeyAsync(source.Token)); Assert.False((bool)isAuthorizedField.GetValue(key)); Assert.True(lastUpdatedTime < (DateTime)lastUpdatedTimeField.GetValue(key)); Assert.True(initializedTcs.Task.IsCompleted); @@ -91,8 +99,12 @@ public async Task TestUpdateAccessKeyShouldSkip(bool isAuthorized, int timeElaps [Fact] public async Task TestInitializeFailed() { - var credential = new DefaultAzureCredential(); - var key = new AadAccessKey(new Uri("http://localhost"), credential); + var mockCredential = new Mock(); + mockCredential.Setup(credential => credential.GetTokenAsync( + It.IsAny(), + It.IsAny())) + .ThrowsAsync(new InvalidOperationException("Mock GetTokenAsync throws an exception")); + var key = new AadAccessKey(new Uri("http://localhost"), mockCredential.Object); var audience = "http://localhost/chat"; var claims = Array.Empty(); @@ -104,7 +116,7 @@ public async Task TestInitializeFailed() ); var source = new CancellationTokenSource(TimeSpan.FromSeconds(1)); - await Assert.ThrowsAnyAsync( + await Assert.ThrowsAsync( async () => await key.UpdateAccessKeyAsync(source.Token) ); diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/RestClients/RestClientFacts.cs b/test/Microsoft.Azure.SignalR.Common.Tests/RestClients/RestClientFacts.cs index 87c649383..14b478670 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/RestClients/RestClientFacts.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/RestClients/RestClientFacts.cs @@ -4,7 +4,6 @@ using System.Net; using System.Net.Http; using System.Threading.Tasks; -using Azure.Core.Serialization; using Microsoft.Azure.SignalR.Tests.Common; using Microsoft.Extensions.DependencyInjection; using Xunit; @@ -21,7 +20,7 @@ public async Task TestHttpRequestExceptionWithStatusCodeSetAsync() var httpClientFactory = new ServiceCollection() .AddHttpClient("").ConfigurePrimaryHttpMessageHandler(() => new TestRootHandler(HttpStatusCode.InsufficientStorage)).Services .BuildServiceProvider().GetRequiredService(); - var client = new RestClient(httpClientFactory, new JsonObjectSerializer(), true); + var client = new RestClient(httpClientFactory); var apiEndpoint = new RestApiEndpoint("https://localhost.test.com", "token"); var exception = await Assert.ThrowsAsync(() => { diff --git a/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceConnectionFactory.cs b/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceConnectionFactory.cs index 550402cfc..4e4f4f678 100644 --- a/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceConnectionFactory.cs +++ b/test/Microsoft.Azure.SignalR.IntegrationTests/Infrastructure/MockServiceConnectionFactory.cs @@ -35,9 +35,9 @@ public MockServiceConnectionFactory( _mockService = mockService; } - public override IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, ServiceConnectionType type) + public override IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, AckHandler ackHandler, ServiceConnectionType type) { - var serviceConnection = base.Create(endpoint, serviceMessageHandler, type); + var serviceConnection = base.Create(endpoint, serviceMessageHandler, ackHandler, type); return new MockServiceConnection(_mockService, serviceConnection); } } diff --git a/test/Microsoft.Azure.SignalR.Management.Tests/DependencyInjectionExtensionFacts.cs b/test/Microsoft.Azure.SignalR.Management.Tests/DependencyInjectionExtensionFacts.cs index 6e806921f..c97131e39 100644 --- a/test/Microsoft.Azure.SignalR.Management.Tests/DependencyInjectionExtensionFacts.cs +++ b/test/Microsoft.Azure.SignalR.Management.Tests/DependencyInjectionExtensionFacts.cs @@ -278,19 +278,121 @@ public async Task CustomizeHttpClientTimeoutTestAsync() o.ConnectionString = FakeEndpointUtils.GetFakeConnectionString(1).Single(); o.HttpClientTimeout = TimeSpan.FromSeconds(1); }) - .ConfigureServices(services => services.AddHttpClient(Options.DefaultName).AddHttpMessageHandler(sp => new WaitInfinitelyHandler())) + .ConfigureServices(services => + { + services.AddHttpClient(Constants.HttpClientNames.MessageResilient).AddHttpMessageHandler(sp => new WaitInfinitelyHandler()); + services.AddHttpClient(Constants.HttpClientNames.Resilient).AddHttpMessageHandler(sp => new WaitInfinitelyHandler()); + }) .BuildServiceManager(); var requestStartTime = DateTime.UtcNow; var serviceHubContext = await serviceManager.CreateHubContextAsync("hub", default); - await Assert.ThrowsAsync(() => serviceHubContext.Clients.All.SendCoreAsync("method", null)); + await TestCoreAsync(() => serviceHubContext.Clients.All.SendCoreAsync("method", null)); + await TestCoreAsync(() => serviceHubContext.ClientManager.CloseConnectionAsync("connectionId")); + } + + static async Task TestCoreAsync(Func testAction) + { + var requestStartTime = DateTime.UtcNow; + await Assert.ThrowsAsync(testAction); var elapsed = DateTime.UtcNow - requestStartTime; - _outputHelper.WriteLine($"Request elapsed time: {elapsed.Ticks}"); // Don't know why, the elapsed time sometimes is shorter than 1 second, but it should be close to 1 second. Assert.True(elapsed >= TimeSpan.FromSeconds(0.8)); Assert.True(elapsed < TimeSpan.FromSeconds(1.2)); } } + [Theory] + [InlineData("")] + [InlineData(Constants.HttpClientNames.MessageResilient)] + [InlineData(Constants.HttpClientNames.Resilient)] + public async Task HttpClientProductInfoTestAsync(string httpClientName) + { + using var hubContext = await new ServiceManagerBuilder() + .WithOptions(o => o.ConnectionString = FakeEndpointUtils.GetFakeConnectionString(1).Single()) + .ConfigureServices(services => services.AddHttpClient(httpClientName) + .ConfigurePrimaryHttpMessageHandler(() => + new TestRootHandler((message, token) => + { + if (message.Headers.TryGetValues(Constants.AsrsUserAgent, out var values)) + { + Assert.Single(values); + Assert.Matches("^Microsoft.Azure.SignalR.Management/", values.Single()); + } + else + { + throw new Exception("Product info header is missing"); + } + }))) + .BuildServiceManager() + .CreateHubContextAsync("hubName", default); + var serviceProvider = (hubContext as ServiceHubContextImpl).ServiceProvider; + var httpClientFactory = serviceProvider.GetRequiredService(); + using var httpClient = httpClientFactory.CreateClient(httpClientName); + await httpClient.SendAsync(new HttpRequestMessage(HttpMethod.Get, "http://abc")); + } + + [Theory] + [InlineData(Constants.HttpClientNames.Resilient)] + [InlineData(Constants.HttpClientNames.MessageResilient)] + public async Task HttpClientMessageTracingIdEnabledTestAsync(string httpClientName) + { + using var hubContext = await new ServiceManagerBuilder() + .WithOptions(o => + { + o.ConnectionString = FakeEndpointUtils.GetFakeConnectionString(1).Single(); + o.EnableMessageTracing = true; + }) + .ConfigureServices(services => services.AddHttpClient(httpClientName) + .ConfigurePrimaryHttpMessageHandler(() => + new TestRootHandler((message, token) => + { + if (message.Headers.TryGetValues(Constants.Headers.AsrsMessageTracingId, out var values)) + { + Assert.Single(values); + Convert.ToUInt64(values.Single()); + } + else + { + throw new Exception("Message tracing Id header is missing"); + } + }))) + .BuildServiceManager() + .CreateHubContextAsync("hubName", default); + var serviceProvider = (hubContext as ServiceHubContextImpl).ServiceProvider; + var httpClientFactory = serviceProvider.GetRequiredService(); + using var httpClient = httpClientFactory.CreateClient(httpClientName); + await httpClient.SendAsync(new HttpRequestMessage(HttpMethod.Get, "http://abc")); + } + + + [Theory] + [InlineData(Constants.HttpClientNames.Resilient)] + [InlineData(Constants.HttpClientNames.MessageResilient)] + public async Task HttpClientMessageTracingIdDisabledTestAsync(string httpClientName) + { + using var hubContext = await new ServiceManagerBuilder() + .WithOptions(o => + { + o.ConnectionString = FakeEndpointUtils.GetFakeConnectionString(1).Single(); + o.EnableMessageTracing = false; + }) + .ConfigureServices(services => services.AddHttpClient(httpClientName) + .ConfigurePrimaryHttpMessageHandler(() => + new TestRootHandler((message, token) => + { + if (message.Headers.TryGetValues(Constants.Headers.AsrsMessageTracingId, out var values)) + { + throw new Exception("Message tracing Id header is not expected"); + } + }))) + .BuildServiceManager() + .CreateHubContextAsync("hubName", default); + var serviceProvider = (hubContext as ServiceHubContextImpl).ServiceProvider; + var httpClientFactory = serviceProvider.GetRequiredService(); + using var httpClient = httpClientFactory.CreateClient(httpClientName); + await httpClient.SendAsync(new HttpRequestMessage(HttpMethod.Get, "http://abc")); + } + private class WaitInfinitelyHandler : DelegatingHandler { protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) diff --git a/test/Microsoft.Azure.SignalR.Management.Tests/IInternalServiceHubContextFacts.cs b/test/Microsoft.Azure.SignalR.Management.Tests/IInternalServiceHubContextFacts.cs index 90b8cdea7..d8e110d81 100644 --- a/test/Microsoft.Azure.SignalR.Management.Tests/IInternalServiceHubContextFacts.cs +++ b/test/Microsoft.Azure.SignalR.Management.Tests/IInternalServiceHubContextFacts.cs @@ -202,7 +202,7 @@ public async Task AddNotExistedConnectionToGroup_NoError_Test() .WithLoggerFactory(loggerFactory) .ConfigureServices(services => { - services.AddHttpClient(string.Empty).AddHttpMessageHandler(sp => + services.AddHttpClient(Constants.HttpClientNames.Resilient).AddHttpMessageHandler(sp => { var response = new HttpResponseMessage(HttpStatusCode.NotFound); response.Headers.Add(Constants.Headers.MicrosoftErrorCode, "Error.Connection.NotExisted"); diff --git a/test/Microsoft.Azure.SignalR.Management.Tests/Resilient/ExponentialBackoffPolicyTests.cs b/test/Microsoft.Azure.SignalR.Management.Tests/Resilient/ExponentialBackoffPolicyTests.cs new file mode 100644 index 000000000..13e18d682 --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Management.Tests/Resilient/ExponentialBackoffPolicyTests.cs @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Options; +using Xunit; + +namespace Microsoft.Azure.SignalR.Management.Tests.Resilient; + +public class ExponentialBackoffPolicyTests +{ + [Fact] + public void GetDelays_ReturnsExpectedDelays() + { + // Arrange + var options = Options.Create(new ServiceManagerOptions + { + RetryOptions = new ServiceManagerRetryOptions + { + Mode = ServiceManagerRetryMode.Exponential, + MaxRetries = 5, + Delay = TimeSpan.FromSeconds(1), + MaxDelay = TimeSpan.FromSeconds(10) + } + }); + var provider = new ExponentialBackOffPolicy(options); + + // Act + var delays = provider.GetDelays(); + + // Assert + var expectedDelays = new List + { + TimeSpan.FromSeconds(1), + TimeSpan.FromSeconds(2), + TimeSpan.FromSeconds(4), + TimeSpan.FromSeconds(8), + TimeSpan.FromSeconds(10) + }; + Assert.Equal(expectedDelays, delays); + } + + [Fact] + public void Constructor_ThrowsInvalidOperationException_WhenRetryModeIsNotExponential() + { + // Arrange + var options = Options.Create(new ServiceManagerOptions + { + RetryOptions = new ServiceManagerRetryOptions + { + Mode = ServiceManagerRetryMode.Fixed, + MaxRetries = 5, + Delay = TimeSpan.FromSeconds(1), + MaxDelay = TimeSpan.FromSeconds(10) + } + }); + + // Act & Assert + Assert.Throws(() => new ExponentialBackOffPolicy(options)); + } + + [Fact] + public void Constructor_ThrowsInvalidOperationException_WhenRetryOptionsIsNull() + { + // Arrange + var options = Options.Create(new ServiceManagerOptions()); + + // Act & Assert + Assert.Throws(() => new ExponentialBackOffPolicy(options)); + } +} diff --git a/test/Microsoft.Azure.SignalR.Management.Tests/Resilient/HttpClientRetryFacts.cs b/test/Microsoft.Azure.SignalR.Management.Tests/Resilient/HttpClientRetryFacts.cs new file mode 100644 index 000000000..f23e781aa --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Management.Tests/Resilient/HttpClientRetryFacts.cs @@ -0,0 +1,278 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Azure.SignalR; +using Microsoft.Azure.SignalR.Common; +using Microsoft.Azure.SignalR.Tests.Common; +using Microsoft.Extensions.DependencyInjection; +using Moq; +using Moq.Protected; +using Xunit; + +namespace Microsoft.Azure.SignalR.Management.Tests; + +using HttpHandlerSetup = Moq.Language.Flow.ISetup>; + +public class HttpClientRetryFacts +{ + private const string HubName = "hub"; + private static readonly Func>[] NonMessageApiTransientHttpErrorSetup = new Func>[] + { + (setup) => setup.ReturnsAsync(new HttpResponseMessage(HttpStatusCode.InternalServerError)), + (setup) => setup.ReturnsAsync(new HttpResponseMessage(HttpStatusCode.BadGateway)), + (setup) => setup.ReturnsAsync(new HttpResponseMessage(HttpStatusCode.RequestTimeout)), + // Simulate timeout + (setup) => setup.Returns(async (request, token) => + { + await Task.Delay(-1, token); + return new HttpResponseMessage(HttpStatusCode.OK); + }) + }; + + public static readonly IEnumerable NullRetryOptionsTestData = NonMessageApiTransientHttpErrorSetup.Zip(new Action[] + { + void (ex) => Assert.IsType(ex), + void (ex) => Assert.IsType(ex), + void (ex) => Assert.IsType(ex), + void (ex) => + { + var canceled = Assert.IsType(ex); + Assert.IsType(canceled.InnerException); + }, + }, (setup, assert) => new object[] { setup, assert }); + + [Theory] + [MemberData(nameof(NullRetryOptionsTestData))] + public async Task NullRetryOptionsTest(Func> setup, Action assert) + { + var handlerMock = new Mock(); + setup(handlerMock.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny())); + var hubContext = await new ServiceManagerBuilder() + .WithOptions(o => + { + o.ConnectionString = FakeEndpointUtils.GetFakeConnectionString(1).Single(); + o.HttpClientTimeout = TimeSpan.FromMilliseconds(1); + }) + .ConfigureServices(services => services + .AddHttpClient(Constants.HttpClientNames.Resilient) + .ConfigurePrimaryHttpMessageHandler(sp => handlerMock.Object)) + .BuildServiceManager() + .CreateHubContextAsync(HubName, default); + var exception = await Assert.ThrowsAnyAsync(() => hubContext.ClientManager.GroupExistsAsync("groupName")); + assert(exception); + } + + public static readonly Func[] NonMessageApis = new Func[] + { + hubContext=>hubContext.Groups.AddToGroupAsync("connectionId", "groupName"), + hubContext=>hubContext.Groups.RemoveFromGroupAsync("connectionId", "groupName"), + hubContext=>hubContext.Groups.RemoveFromAllGroupsAsync("connectionId"), + hubContext=>hubContext.UserGroups.AddToGroupAsync("userId", "groupName"), + hubContext=>hubContext.UserGroups.RemoveFromGroupAsync("userId", "groupName"), + hubContext=>hubContext.UserGroups.RemoveFromAllGroupsAsync("userId"), + hubContext=>hubContext.ClientManager.GroupExistsAsync("groupName"), + hubContext=>hubContext.ClientManager.UserExistsAsync("userId"), + hubContext=>hubContext.ClientManager.ConnectionExistsAsync("connectionId"), + hubContext=>hubContext.ClientManager.CloseConnectionAsync("connectionId", "reason"), + }; + + public static readonly IEnumerable FixedDelayRetryTestData = + from pair in NonMessageApiTransientHttpErrorSetup.Zip(new Action[] + { + void (ex) => Assert.All(ex.InnerExceptions,inner=> Assert.IsType(inner)), + void (ex) => Assert.All(ex.InnerExceptions,inner=> Assert.IsType(inner)), + void (ex) => Assert.All(ex.InnerExceptions,inner=> Assert.IsType(inner)), + void (ex) => Assert.All(ex.InnerExceptions,inner=> + { + var operationCanceled = Assert.IsType(inner); + Assert.IsType(operationCanceled.InnerException); + }), + }) + from api in NonMessageApis + select new object[] { pair.First, pair.Second, api }; + + [Theory] + [MemberData(nameof(FixedDelayRetryTestData))] + public async Task FixedDelayRetryTestNonMessageApi(Func> setup, Action assert, Func api) + { + await FixedDelayRetryTestCore(setup, assert, api, Constants.HttpClientNames.Resilient); + } + + private static readonly Func>[] MessageApiTransientHttpErrorSetup = new Func>[] + { + (setup) => setup.ReturnsAsync(new HttpResponseMessage(HttpStatusCode.ServiceUnavailable)), + (setup) => setup.ReturnsAsync(new HttpResponseMessage(HttpStatusCode.BadGateway)) + }; + + public static readonly Func[] MessageApis = new Func[] + { + hubContext=>hubContext.Clients.All.SendAsync("method"), + hubContext=>hubContext.Clients.Client("abc").SendAsync("method"), + hubContext=>hubContext.Clients.Group("groupName").SendAsync("method"), + hubContext=>hubContext.Clients.User("userName").SendAsync("method"), + }; + + public static readonly IEnumerable FixedDelayRetryTestMessageApiTestData = + from setup in MessageApiTransientHttpErrorSetup + from api in MessageApis + select new object[] { setup, api }; + + [Theory] + [MemberData(nameof(FixedDelayRetryTestMessageApiTestData))] + public async Task FixedDelayRetryTestMessageApi(Func> setup, Func api) + { + await FixedDelayRetryTestCore(setup, void (ex) => Assert.All(ex.InnerExceptions, e => Assert.IsType(e)), api, Constants.HttpClientNames.MessageResilient); + } + + private static async Task FixedDelayRetryTestCore(Func> setup, Action assert, Func testAction, string httpClientName) + { + var handlerMock = new Mock(); + setup(handlerMock.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny())); + + var hubContext = await new ServiceManagerBuilder() + .WithOptions(o => + { + o.ConnectionString = FakeEndpointUtils.GetFakeConnectionString(1).Single(); + o.HttpClientTimeout = TimeSpan.FromMilliseconds(1000); + o.RetryOptions = new ServiceManagerRetryOptions + { + Mode = ServiceManagerRetryMode.Fixed, + Delay = TimeSpan.FromMilliseconds(50), + MaxRetries = 3 + }; + }) + .ConfigureServices(services => services + .AddHttpClient(httpClientName) + .ConfigurePrimaryHttpMessageHandler(sp => handlerMock.Object)) + .BuildServiceManager() + .CreateHubContextAsync(HubName, default); + var exception = await Assert.ThrowsAnyAsync(() => testAction(hubContext)); + var aggregationException = Assert.IsType(exception.InnerException); + assert(aggregationException); + + handlerMock.Protected().Verify("SendAsync", Times.Exactly(4), ItExpr.IsAny(), ItExpr.IsAny()); + } + + private static readonly Func>[] NonMessageApi_NotTransientHttpErrorSetup = new Func>[] + { + (setup) => setup.ReturnsAsync(new HttpResponseMessage(HttpStatusCode.BadRequest)), + (setup) => setup.ReturnsAsync(new HttpResponseMessage(HttpStatusCode.NotFound)), + (setup) => setup.ReturnsAsync(new HttpResponseMessage(HttpStatusCode.Unauthorized)) + }; + + public static IEnumerable NotRetryable_RetryTestNonMessageApiTestData = + from pair in NonMessageApi_NotTransientHttpErrorSetup.Zip(new Action[] + { + void (ex) => Assert.IsType(ex), + void (ex) => Assert.IsType(ex), + void (ex) => Assert.IsType(ex) + }) + from api in NonMessageApis + select new object[] { pair.First, pair.Second, api }; + + [Theory] + [MemberData(nameof(NotRetryable_RetryTestNonMessageApiTestData))] + public async Task NonRetryableError_RetryTestNonMessageApi(Func> setup, Action assert, Func api) + { + await NonRetryableError_RetryTestCore(setup, assert, api, Constants.HttpClientNames.Resilient); + } + + private static readonly Func>[] MessageApi_NotTransientHttpErrorSetup = new Func>[] + { + (setup) => setup.ReturnsAsync(new HttpResponseMessage(HttpStatusCode.InternalServerError)), + (setup) => setup.ReturnsAsync(new HttpResponseMessage(HttpStatusCode.BadRequest)), + (setup) => setup.ReturnsAsync(new HttpResponseMessage(HttpStatusCode.NotFound)), + (setup) => setup.ReturnsAsync(new HttpResponseMessage(HttpStatusCode.Unauthorized)), + // Simulate timeout + (setup) => setup.Returns(async (request, token) => + { + await Task.Delay(-1, token); + return new HttpResponseMessage(HttpStatusCode.OK); + }) + }; + + public static IEnumerable NotRetryable_RetryTest_MessageApiTestData = + from pair in MessageApi_NotTransientHttpErrorSetup.Zip(new Action[] + { + void (ex) => Assert.IsType(ex), + void (ex) => Assert.IsType(ex), + void (ex) => Assert.IsType(ex), + void (ex) => Assert.IsType(ex), + void (ex) => Assert.IsType(ex), + }) + from api in MessageApis + select new object[] { pair.First, pair.Second, api }; + + [Theory] + [MemberData(nameof(NotRetryable_RetryTest_MessageApiTestData))] + public async Task NonRetryableError_RetryTestMessageApi(Func> setup, Action assert, Func api) + { + await NonRetryableError_RetryTestCore(setup, assert, api, Constants.HttpClientNames.MessageResilient); + } + + private static async Task NonRetryableError_RetryTestCore(Func> setup, Action assert, Func testAction, string httpClientName) + { + var handlerMock = new Mock(); + setup(handlerMock.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny())); + + var hubContext = await new ServiceManagerBuilder() + .WithOptions(o => + { + o.ConnectionString = FakeEndpointUtils.GetFakeConnectionString(1).Single(); + o.HttpClientTimeout = TimeSpan.FromMilliseconds(1000); + o.RetryOptions = new ServiceManagerRetryOptions + { + Mode = ServiceManagerRetryMode.Fixed, + Delay = TimeSpan.FromMilliseconds(50), + MaxRetries = 3 + }; + }) + .ConfigureServices(services => services + .AddHttpClient(httpClientName) + .ConfigurePrimaryHttpMessageHandler(sp => handlerMock.Object)) + .BuildServiceManager() + .CreateHubContextAsync(HubName, default); + var exception = await Assert.ThrowsAnyAsync(() => testAction(hubContext)); + assert(exception); + } + + [Fact] + public async Task TheSecondRetrySuccessTest() + { + var handlerMock = new Mock(); + handlerMock.Protected() + .SetupSequence>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage(HttpStatusCode.InternalServerError)) + .ReturnsAsync(new HttpResponseMessage(HttpStatusCode.OK)); + var hubContext = await new ServiceManagerBuilder() + .WithOptions(o => + { + o.ConnectionString = FakeEndpointUtils.GetFakeConnectionString(1).Single(); + o.HttpClientTimeout = TimeSpan.FromMilliseconds(1); + o.RetryOptions = new ServiceManagerRetryOptions + { + Mode = ServiceManagerRetryMode.Fixed, + Delay = TimeSpan.FromMilliseconds(50), + MaxRetries = 3 + }; + }) + .ConfigureServices(services => services + .AddHttpClient(Constants.HttpClientNames.Resilient) + .ConfigurePrimaryHttpMessageHandler(sp => handlerMock.Object)) + .BuildServiceManager() + .CreateHubContextAsync(HubName, default); + await hubContext.ClientManager.GroupExistsAsync("groupName"); + handlerMock.Protected().Verify("SendAsync", Times.Exactly(2), ItExpr.IsAny(), ItExpr.IsAny()); + } +} diff --git a/test/Microsoft.Azure.SignalR.Management.Tests/RestApiProviderFacts.cs b/test/Microsoft.Azure.SignalR.Management.Tests/RestApiProviderFacts.cs index d501c4790..a99dd570e 100644 --- a/test/Microsoft.Azure.SignalR.Management.Tests/RestApiProviderFacts.cs +++ b/test/Microsoft.Azure.SignalR.Management.Tests/RestApiProviderFacts.cs @@ -4,9 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Net.Http; using System.Threading.Tasks; -using Azure.Core.Serialization; using Microsoft.Azure.SignalR.Tests; using Xunit; @@ -33,28 +31,6 @@ internal async Task RestApiTest(Task task, string expectedAudie Assert.Equal(expectedTokenString, api.Token); } - [Theory] - [InlineData(true)] - [InlineData(false)] - internal async Task EnableMessageTracingIdInRestApiTest(bool enable) - { - var api = await _restApiProvider.GetBroadcastEndpointAsync("app", "hub"); - var client = new RestClient(HttpClientFactory.Instance, new NewtonsoftJsonObjectSerializer(), enable); - try - { - await client.SendAsync(api, HttpMethod.Post, "", handleExpectedResponse: default).OrTimeout(200); - } - catch - { - } - Assert.Equal(enable, api.Query?.ContainsKey(Constants.Headers.AsrsMessageTracingId) ?? false); - if (enable) - { - var id = Convert.ToUInt64(api.Query[Constants.Headers.AsrsMessageTracingId]); - Assert.Equal(MessageWithTracingIdHelper.Prefix, id); - } - } - public static IEnumerable GetTestData() => from context in GetContext() from pair in GetTestDataByContext(context) diff --git a/test/Microsoft.Azure.SignalR.Management.Tests/Serialization/SerailizerFacts.cs b/test/Microsoft.Azure.SignalR.Management.Tests/Serialization/SerailizerFacts.cs index 434e01e05..68258ca4e 100644 --- a/test/Microsoft.Azure.SignalR.Management.Tests/Serialization/SerailizerFacts.cs +++ b/test/Microsoft.Azure.SignalR.Management.Tests/Serialization/SerailizerFacts.cs @@ -115,7 +115,7 @@ private ServiceManagerBuilder CreateTransientBuilder(string expectedHttpBody) .WithLoggerFactory(_loggerFactory) .ConfigureServices(services => { - services.AddHttpClient(string.Empty).AddHttpMessageHandler(() => new TestRootHandler((message, cancellationToken) => + services.AddHttpClient(Constants.HttpClientNames.MessageResilient).AddHttpMessageHandler(() => new TestRootHandler((message, cancellationToken) => { var actualBody = message.Content.ReadAsStringAsync().Result; diff --git a/test/Microsoft.Azure.SignalR.Management.Tests/StronglyTypedServiceHubContextFacts.cs b/test/Microsoft.Azure.SignalR.Management.Tests/StronglyTypedServiceHubContextFacts.cs index 47123b0bf..ccf49b096 100644 --- a/test/Microsoft.Azure.SignalR.Management.Tests/StronglyTypedServiceHubContextFacts.cs +++ b/test/Microsoft.Azure.SignalR.Management.Tests/StronglyTypedServiceHubContextFacts.cs @@ -11,7 +11,6 @@ using Microsoft.Azure.SignalR.Tests.Common; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; using Newtonsoft.Json; using Xunit; using Xunit.Abstractions; @@ -69,7 +68,7 @@ void assertion(HttpRequestMessage request, CancellationToken t) var expected = JsonConvert.SerializeObject(payload); Assert.Equal(expected, actual); } - var services = new ServiceCollection().AddHttpClient(Options.DefaultName) + var services = new ServiceCollection().AddHttpClient(Constants.HttpClientNames.Resilient) .ConfigurePrimaryHttpMessageHandler(() => new TestRootHandler(assertion)).Services .AddSignalRServiceManager(); await using var hubContext = await Create(ServiceTransportType.Transient, services); @@ -116,7 +115,7 @@ void assertion(HttpRequestMessage request, CancellationToken t) } var services = new ServiceCollection() .AddSignalRServiceManager() - .AddHttpClient(Options.DefaultName).ConfigurePrimaryHttpMessageHandler(() => new TestRootHandler(assertion)).Services; + .AddHttpClient(Constants.HttpClientNames.Resilient).ConfigurePrimaryHttpMessageHandler(() => new TestRootHandler(assertion)).Services; await using var hubContext = await Create(ServiceTransportType.Transient, services); await hubContext.Groups.AddToGroupAsync(connectionId, groupName); @@ -158,7 +157,7 @@ void assertion(HttpRequestMessage request, CancellationToken t) } var services = new ServiceCollection() .AddSignalRServiceManager() - .AddHttpClient(Options.DefaultName).ConfigurePrimaryHttpMessageHandler(() => new TestRootHandler(assertion)).Services; + .AddHttpClient(Constants.HttpClientNames.Resilient).ConfigurePrimaryHttpMessageHandler(() => new TestRootHandler(assertion)).Services; await using var hubContext = await Create(ServiceTransportType.Transient, services); await hubContext.UserGroups.AddToGroupAsync(userId, groupName); @@ -200,7 +199,7 @@ void assertion(HttpRequestMessage request, CancellationToken t) } var services = new ServiceCollection() .AddSignalRServiceManager() - .AddHttpClient(Options.DefaultName).ConfigurePrimaryHttpMessageHandler(() => new TestRootHandler(assertion)).Services; + .AddHttpClient(Constants.HttpClientNames.Resilient).ConfigurePrimaryHttpMessageHandler(() => new TestRootHandler(assertion)).Services; await using var hubContext = await Create(ServiceTransportType.Transient, services); await hubContext.ClientManager.CloseConnectionAsync(connectionId); diff --git a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnectionFactory.cs b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnectionFactory.cs index 15cc0be6e..4281f63de 100644 --- a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnectionFactory.cs +++ b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnectionFactory.cs @@ -18,7 +18,7 @@ public TestServiceConnectionFactory(Func ge _generator = generator; } - public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, ServiceConnectionType type) + public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, AckHandler ackHandler, ServiceConnectionType type) { var conn = _generator?.Invoke(endpoint) ?? new TestServiceConnection(serviceMessageHandler: serviceMessageHandler); var receiver = CreatedConnections.GetOrAdd(endpoint, e => new()); diff --git a/test/Microsoft.Azure.SignalR.Tests/AckHandlerTest.cs b/test/Microsoft.Azure.SignalR.Tests/AckHandlerTest.cs new file mode 100644 index 000000000..93c460786 --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Tests/AckHandlerTest.cs @@ -0,0 +1,83 @@ +using System; +using System.Text; +using System.Threading.Tasks; + +using Microsoft.Azure.SignalR.Tests; +using Microsoft.Extensions.Options; + +using Xunit; + +namespace Microsoft.Azure.SignalR.Tests +{ + public class AckHandlerTest + { + [Fact] + public void TestOnce() + { + var handler = new AckHandler(); + var task = handler.CreateSingleAck(out var ackId); + handler.TriggerAck(ackId); + Assert.True(task.IsCompletedSuccessfully); + Assert.Equal(AckStatus.Ok, task.Result); + } + + [Fact] + public async Task TestOnce_Timeout() + { + var handler = new AckHandler(TimeSpan.FromSeconds(1), TimeSpan.FromMilliseconds(20)); + var task = handler.CreateSingleAck(out var ackId); + Assert.False(task.IsCompleted); + await Task.Delay(TimeSpan.FromSeconds(1.5)); + Assert.True(task.IsCompleted); + // This assertion is different from RT for different behaviour when timeout of AckHandler. See annotation in AckHandler.cs method CheckAcs + Assert.Equal(AckStatus.Timeout, task.Result); + } + + [Fact] + public void TestTwice_SetExpectedFirst() + { + var handler = new AckHandler(); + var task = handler.CreateMultiAck(out var ackId); + handler.SetExpectedCount(ackId, 2); + handler.TriggerAck(ackId); + Assert.False(task.IsCompleted); + handler.TriggerAck(ackId); + Assert.True(task.IsCompletedSuccessfully); + } + + [Fact] + public void TestTwice_AckFirst() + { + var handler = new AckHandler(); + var task = handler.CreateMultiAck(out var ackId); + handler.TriggerAck(ackId); + Assert.False(task.IsCompleted); + handler.TriggerAck(ackId); + Assert.False(task.IsCompleted); + handler.SetExpectedCount(ackId, 2); + Assert.True(task.IsCompletedSuccessfully); + } + + [Fact] + public async Task TestTwice_Timeout() + { + var handler = new AckHandler(TimeSpan.FromSeconds(1), TimeSpan.FromMilliseconds(20)); + var task = handler.CreateMultiAck(out var ackId); + Assert.False(task.IsCompleted); + handler.SetExpectedCount(ackId, 2); + Assert.False(task.IsCompleted); + await Task.Delay(TimeSpan.FromSeconds(1.5)); + Assert.True(task.IsCompleted); + // This assertion is different from RT for different behaviour when timeout of AckHandler. See annotation in AckHandler.cs method CheckAcs + Assert.Equal(AckStatus.Timeout, task.Result); + } + + [Fact] + public void TestInvalid_SetExpectedForSingle() + { + var handler = new AckHandler(TimeSpan.FromSeconds(1), TimeSpan.FromMilliseconds(20)); + var task = handler.CreateSingleAck(out var ackId); + Assert.Throws(() => handler.SetExpectedCount(ackId, 2)); + } + } +} diff --git a/test/Microsoft.Azure.SignalR.Tests/EndpointRouterTests.cs b/test/Microsoft.Azure.SignalR.Tests/EndpointRouterTests.cs index 48cb54859..e329a1e8c 100644 --- a/test/Microsoft.Azure.SignalR.Tests/EndpointRouterTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/EndpointRouterTests.cs @@ -10,26 +10,69 @@ namespace Microsoft.Azure.SignalR.Tests public class EndpointRouterTests { [Fact] - public void TestDefaultEndpointWeightedRouter() + public void TestDefaultEndpointRouterWeightedMode() { - const int loops = 1000; - var context = new RandomContext(); var drt = new DefaultEndpointRouter(); - const string u1Full = "u1_full", u1Empty = "u1_empty"; - var u1F = GenerateServiceEndpoint(1000, 10, 990, u1Full); - var u1E = GenerateServiceEndpoint(1000, 10, 0, u1Empty); - var el = new List() { u1E, u1F }; + const int loops = 20; + var context = new RandomContext(); + + const string small = "small_instance", large = "large_instance"; + var uSmall = GenerateServiceEndpoint(10, 0, 9, small); + var uLarge = GenerateServiceEndpoint(1000, 0, 900, large); + var el = new List() { uLarge, uSmall }; context.BenchTest(loops, () => - drt.GetNegotiateEndpoint(null, el).Name); - var u1ECount = context.GetCount(u1Empty); - const int smallVar = 10; - Assert.True(u1ECount is > loops - smallVar and <= loops); - var u1FCount = context.GetCount(u1Full); - Assert.True(u1FCount <= smallVar); + { + var ep = drt.GetNegotiateEndpoint(null, el); + ep.EndpointMetrics.ClientConnectionCount++; + return ep.Name; + }); + var uLargeCount = context.GetCount(large); + const int smallVar = 3; + var uSmallCount = context.GetCount(small); + Assert.True(uLargeCount is >= loops - smallVar and <= loops); + Assert.True(uSmallCount is >= 1 and <= smallVar); context.Reset(); } + [Theory] + [InlineData(200)] + [InlineData(300)] + [InlineData(400)] + [InlineData(500)] + public void TestDefaultEndpointRouterWeightedModeWhenAutoScaleIsEnabled(int quotaOfScaleUpInstance) + { + var drt = new DefaultEndpointRouter(); + + var loops = 100 + (quotaOfScaleUpInstance / 5); + var context = new RandomContext(); + const double quotaBarForScaleUp = 0.8; + + var endpointA = GenerateServiceEndpoint(quotaOfScaleUpInstance, 0, 80, "a"); + var endpointB = GenerateServiceEndpoint(100, 0, 70, "b"); + var endpointC = GenerateServiceEndpoint(100, 0, 70, "c"); + var el = new List() {endpointA, endpointB, endpointC}; + context.BenchTest(loops, () => + { + var ep = drt.GetNegotiateEndpoint(null, el); + ep.EndpointMetrics.ClientConnectionCount++; + var percent = (ep.EndpointMetrics.ClientConnectionCount + ep.EndpointMetrics.ServerConnectionCount) / + (double)ep.EndpointMetrics.ConnectionCapacity; + if (percent > quotaBarForScaleUp) + { + ep.EndpointMetrics.ConnectionCapacity += 100; + } + + return ep.Name; + }); + + Assert.Equal(context.GetCount("a") + context.GetCount("b") + context.GetCount("c"), loops); + Assert.Equal(quotaOfScaleUpInstance, endpointA.EndpointMetrics.ConnectionCapacity); + Assert.Equal(200, endpointB.EndpointMetrics.ConnectionCapacity); + Assert.Equal(200, endpointC.EndpointMetrics.ConnectionCapacity); + + context.Reset(); + } private static ServiceEndpoint GenerateServiceEndpoint(int capacity, int serverConnectionCount, int clientConnectionCount, string name) diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs index a39c5546e..e9bfbc2f6 100644 --- a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs +++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs @@ -74,8 +74,10 @@ public ServiceConnectionProxy( ServiceMessageHandler = (StrongServiceConnectionContainer) ServiceConnectionContainer; } - public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, - ServiceConnectionType type) + public IServiceConnection Create(HubServiceEndpoint endpoint, + IServiceMessageHandler serviceMessageHandler, + AckHandler ackHandler, + ServiceConnectionType type) { var connectionId = Guid.NewGuid().ToString("N"); var connection = new ServiceConnection( @@ -91,6 +93,7 @@ public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHan serviceMessageHandler, null, ClientInvocationManager, + ackHandler, type); ServiceConnections.TryAdd(connectionId, connection); return connection; diff --git a/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs b/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs index cd2a1c60e..4493b6f64 100644 --- a/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs @@ -1555,13 +1555,13 @@ public async Task ServiceConnectionContainerScopeWithPingUpdateTest() var endpoints = sem.GetEndpoints("hub"); var clientInvocationManager = new DefaultClientInvocationManager(); var connection1 = new ServiceConnection(protocol, ccm, connectionFactory1, loggerFactory, connectionDelegate, ccf, - "serverId", "server-conn-1", endpoints[0], endpoints[0].ConnectionContainer as IServiceMessageHandler, null, clientInvocationManager, closeTimeOutMilliseconds: 500); + "serverId", "server-conn-1", endpoints[0], endpoints[0].ConnectionContainer as IServiceMessageHandler, null, clientInvocationManager, new AckHandler(), closeTimeOutMilliseconds: 500); var connection2 = new ServiceConnection(protocol, ccm, connectionFactory2, loggerFactory, connectionDelegate, ccf, - "serverId", "server-conn-2", endpoints[1], endpoints[1].ConnectionContainer as IServiceMessageHandler, null, clientInvocationManager, closeTimeOutMilliseconds: 500); + "serverId", "server-conn-2", endpoints[1], endpoints[1].ConnectionContainer as IServiceMessageHandler, null, clientInvocationManager, new AckHandler(), closeTimeOutMilliseconds: 500); var connection22 = new ServiceConnection(protocol, ccm, connectionFactory22, loggerFactory, connectionDelegate, ccf, - "serverId", "server-conn-22", endpoints[1], endpoints[1].ConnectionContainer as IServiceMessageHandler, null, clientInvocationManager, closeTimeOutMilliseconds: 500); + "serverId", "server-conn-22", endpoints[1], endpoints[1].ConnectionContainer as IServiceMessageHandler, null, clientInvocationManager, new AckHandler(), closeTimeOutMilliseconds: 500); var router = new TestEndpointRouter(); diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerBaseTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerBaseTests.cs index 26ee1c74c..9419f56bd 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerBaseTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerBaseTests.cs @@ -16,61 +16,6 @@ public class ServiceConnectionContainerBaseTests : VerifiableLoggedTest public ServiceConnectionContainerBaseTests(ITestOutputHelper helper) : base(helper) { } - - [Theory] - [InlineData(ServiceConnectionStatus.Disconnected)] - [InlineData(ServiceConnectionStatus.Connected)] - [InlineData(ServiceConnectionStatus.Connecting)] - [InlineData(ServiceConnectionStatus.Inited)] - internal async Task TestIfConnectionWillNotRestartAfterShutdown(ServiceConnectionStatus status) - { - List connections = new List - { - new SimpleTestServiceConnection(), - new SimpleTestServiceConnection(status: status) - }; - - IServiceConnection connection = connections[1]; - - using TestServiceConnectionContainer container = new TestServiceConnectionContainer(connections, factory: new SimpleTestServiceConnectionFactory()); - container.ShutdownForTest(); - - await container.OnConnectionCompleteForTestShutdown(connection); - - // the connection should not be replaced when shutting down - Assert.Equal(container.Connections[1], connection); - // its status is not changed - Assert.Equal(status, container.Connections[1].Status); - // the container is not listening to the connection's status changes after shutdown - Assert.Equal(1, (connection as SimpleTestServiceConnection).ConnectionStatusChangedRemoveCount); - } - - [Theory] - [InlineData(GracefulShutdownMode.Off)] - [InlineData(GracefulShutdownMode.WaitForClientsClose)] - [InlineData(GracefulShutdownMode.MigrateClients)] - internal async Task TestOffline(GracefulShutdownMode mode) - { - List connections = new List - { - new SimpleTestServiceConnection(), - new SimpleTestServiceConnection() - }; - using TestServiceConnectionContainer container = new TestServiceConnectionContainer(connections, factory: new SimpleTestServiceConnectionFactory()); - - foreach (SimpleTestServiceConnection c in connections) - { - Assert.False(c.ConnectionOfflineTask.IsCompleted); - } - - await container.OfflineAsync(mode); - - foreach (SimpleTestServiceConnection c in connections) - { - Assert.True(c.ConnectionOfflineTask.IsCompleted); - } - } - [Theory] [InlineData(3, 3, 0)] [InlineData(0, 1, 1)] // stop more than start will log warn @@ -95,10 +40,10 @@ public async Task TestServersPing(int startCount, int stopCount, int expectedWar new SimpleTestServiceConnection(), new SimpleTestServiceConnection() }; - using TestServiceConnectionContainer container = + using TestServiceConnectionContainer container = new TestServiceConnectionContainer( - connections, - factory: new SimpleTestServiceConnectionFactory(), + connections, + factory: new SimpleTestServiceConnectionFactory(), logger: loggerFactory.CreateLogger()); await container.StartAsync(); @@ -114,12 +59,12 @@ public async Task TestServersPing(int startCount, int stopCount, int expectedWar // default interval is 5s, add 2s for delay, validate any one connection write servers ping. if (tasks.Count > 0) - { + { await Task.WhenAny(connections.Select(c => { var connection = c as SimpleTestServiceConnection; return connection.ServersPingTask.OrTimeout(7000); - })); + })); } tasks.Clear(); @@ -211,21 +156,77 @@ await Task.WhenAny(connections.Select(c => } } + [Theory] + [InlineData(ServiceConnectionStatus.Disconnected)] + [InlineData(ServiceConnectionStatus.Connected)] + [InlineData(ServiceConnectionStatus.Connecting)] + [InlineData(ServiceConnectionStatus.Inited)] + internal async Task TestIfConnectionWillNotRestartAfterShutdown(ServiceConnectionStatus status) + { + List connections = new List + { + new SimpleTestServiceConnection(), + new SimpleTestServiceConnection(status: status) + }; + + IServiceConnection connection = connections[1]; + + using TestServiceConnectionContainer container = new TestServiceConnectionContainer(connections, factory: new SimpleTestServiceConnectionFactory()); + container.ShutdownForTest(); + + await container.OnConnectionCompleteForTestShutdown(connection); + + // the connection should not be replaced when shutting down + Assert.Equal(container.Connections[1], connection); + + // its status is not changed + Assert.Equal(status, container.Connections[1].Status); + + // the container is not listening to the connection's status changes after shutdown + Assert.Equal(1, (connection as SimpleTestServiceConnection).ConnectionStatusChangedRemoveCount); + } + + [Theory] + [InlineData(GracefulShutdownMode.Off)] + [InlineData(GracefulShutdownMode.WaitForClientsClose)] + [InlineData(GracefulShutdownMode.MigrateClients)] + internal async Task TestOffline(GracefulShutdownMode mode) + { + List connections = new List + { + new SimpleTestServiceConnection(), + new SimpleTestServiceConnection() + }; + using TestServiceConnectionContainer container = new TestServiceConnectionContainer(connections, factory: new SimpleTestServiceConnectionFactory()); + + foreach (SimpleTestServiceConnection c in connections) + { + Assert.False(c.ConnectionOfflineTask.IsCompleted); + } + + await container.OfflineAsync(mode); + + foreach (SimpleTestServiceConnection c in connections) + { + Assert.True(c.ConnectionOfflineTask.IsCompleted); + } + } private sealed class SimpleTestServiceConnectionFactory : IServiceConnectionFactory { - public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, ServiceConnectionType type) => new SimpleTestServiceConnection(); + public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, AckHandler ackHandler, ServiceConnectionType type) => new SimpleTestServiceConnection(); } private sealed class SimpleTestServiceConnection : IServiceConnection { + private readonly TaskCompletionSource _offline = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + private readonly TaskCompletionSource _serversPing = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + public Task ConnectionInitializedTask => Task.Delay(TimeSpan.FromSeconds(1)); public ServiceConnectionStatus Status { get; set; } = ServiceConnectionStatus.Disconnected; - private readonly TaskCompletionSource _offline = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - private readonly TaskCompletionSource _serversPing = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - public Task ConnectionOfflineTask => _offline.Task; public Task ServersPingTask => _serversPing.Task; @@ -241,8 +242,8 @@ public SimpleTestServiceConnection(ServiceConnectionStatus status = ServiceConne public event Action ConnectionStatusChanged { - add { ConnectionStatusChangedAddCount++; } - remove { ConnectionStatusChangedRemoveCount++; } + add => ConnectionStatusChangedAddCount++; + remove => ConnectionStatusChangedRemoveCount++; } public Task StartAsync(string target = null) diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs index 8f5aa9099..102ae5c2c 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs @@ -47,7 +47,7 @@ public async Task TestServiceConnectionWithNormalApplicationTask() builder.UseConnectionHandler(); ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager()); + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), new AckHandler()); var connectionTask = connection.StartAsync(); @@ -103,7 +103,7 @@ public async Task TestServiceConnectionErrorCleansAllClients() builder.UseConnectionHandler(); ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager()); + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), new AckHandler()); var connectionTask = connection.StartAsync(); @@ -159,7 +159,7 @@ public async Task TestServiceConnectionWithErrorApplicationTask() ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager()); + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), new AckHandler()); var connectionTask = connection.StartAsync(); @@ -222,7 +222,7 @@ public async Task TestServiceConnectionWithEndlessApplicationTaskNeverEnds() ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, "serverId", Guid.NewGuid().ToString("N"), - null, null, null, new DefaultClientInvocationManager(), closeTimeOutMilliseconds: 1); + null, null, null, new DefaultClientInvocationManager(), new AckHandler(), closeTimeOutMilliseconds: 1); var connectionTask = connection.StartAsync(); @@ -277,7 +277,7 @@ public async Task ClientConnectionOutgoingAbortCanEndLifeTime() builder.UseConnectionHandler(); ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), new AckHandler(), closeTimeOutMilliseconds: 500); var connectionTask = connection.StartAsync(); @@ -335,7 +335,7 @@ public async Task ClientConnectionContextAbortCanSendOutCloseMessage() ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), closeTimeOutMilliseconds: 500); + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), new AckHandler(), closeTimeOutMilliseconds: 500); var connectionTask = connection.StartAsync(); @@ -398,7 +398,7 @@ public async Task ClientConnectionWithDiagnosticClientTagTest() ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), closeTimeOutMilliseconds: 500); + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), new AckHandler(), closeTimeOutMilliseconds: 500); var connectionTask = connection.StartAsync(); @@ -456,7 +456,7 @@ public async Task ClientConnectionLastWillCanSendOut() builder.UseConnectionHandler(); ConnectionDelegate handler = builder.Build(); var connection = new ServiceConnection(protocol, ccm, connectionFactory, loggerFactory, handler, ccf, - "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), closeTimeOutMilliseconds: 500); + "serverId", Guid.NewGuid().ToString("N"), null, null, null, new DefaultClientInvocationManager(), new AckHandler(), closeTimeOutMilliseconds: 500); var connectionTask = connection.StartAsync(); diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs index 90ccaa735..ca29456f5 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs @@ -146,22 +146,6 @@ public async Task TestCloseConnectionMessage() await connection.StopAsync(); } - private ServiceEndpoint MockServiceEndpoint(string keyTypeName) - { - switch (keyTypeName) - { - case nameof(AccessKey): - return new ServiceEndpoint(_keyConnectionString); - case nameof(AadAccessKey): - var endpoint = new ServiceEndpoint(_aadConnectionString); - var p = typeof(ServiceEndpoint).GetProperty("AccessKey", BindingFlags.NonPublic | BindingFlags.Instance); - p.SetValue(endpoint, new TestAadAccessKey()); - return endpoint; - default: - throw new NotImplementedException(); - } - } - [Theory] [InlineData(typeof(AccessKey))] [InlineData(typeof(AadAccessKey))] @@ -322,6 +306,24 @@ private static TestServiceConnection CreateServiceConnection(ConnectionHandler h ); } + private ServiceEndpoint MockServiceEndpoint(string keyTypeName) + { + switch (keyTypeName) + { + case nameof(AccessKey): + return new ServiceEndpoint(_keyConnectionString); + + case nameof(AadAccessKey): + var endpoint = new ServiceEndpoint(_aadConnectionString); + var p = typeof(ServiceEndpoint).GetProperty("AccessKey", BindingFlags.NonPublic | BindingFlags.Instance); + p.SetValue(endpoint, new TestAadAccessKey()); + return endpoint; + + default: + throw new NotImplementedException(); + } + } + private class TestAadAccessKey : AadAccessKey { public string Token { get; } = Guid.NewGuid().ToString(); @@ -335,6 +337,7 @@ public override Task GenerateAadTokenAsync(CancellationToken ctoken = de return Task.FromResult(Token); } } + private sealed class TestConnectionContainer { public TestConnection Instance { get; set; } @@ -343,6 +346,7 @@ private sealed class TestConnectionContainer private sealed class TestConnectionHandler : ConnectionHandler { private readonly int _shutdownAfter = 0; + private readonly string _lastWords; public TestConnectionHandler(int shutdownAfter = 0, string lastWords = null) @@ -389,6 +393,7 @@ public override async Task OnConnectedAsync(ConnectionContext connection) await connection.Transport.Output.FlushAsync(); } } + /// /// ------------------------- Client Connection------------------------------ -------------Service Connection--------- /// | Transport Application | | Transport Application | @@ -409,15 +414,19 @@ private sealed class TestServiceConnection : ServiceConnection private readonly TestConnectionContainer _container; private readonly TaskCompletionSource _clientConnectedTcs = new TaskCompletionSource(); + private readonly TaskCompletionSource _clientDisconnectedTcs = new TaskCompletionSource(); private ReadOnlySequence _payload = new ReadOnlySequence(); + public TestClientConnectionManager ClientConnectionManager { get; } public PipeReader Reader => _connection.Application.Input; + public PipeWriter Writer => _connection.Application.Output; public Task ClientConnectedTask => _clientConnectedTcs.Task; + public Task ClientDisconnectedTask => _clientDisconnectedTcs.Task; public ServiceProtocol DefaultServiceProtocol { get; } = new ServiceProtocol(); @@ -460,6 +469,7 @@ public TestServiceConnection(TestConnectionContainer container, serviceMessageHandler, serviceEventHandler, clientInvocationManager, + new AckHandler(), connectionType: connectionType, mode: mode, closeTimeOutMilliseconds: closeTimeOutMilliseconds) @@ -467,6 +477,7 @@ public TestServiceConnection(TestConnectionContainer container, _container = container; ClientConnectionManager = clientConnectionManager; } + public async Task ExpectStringMessage(string expected, string connectionId = null) { var payload = await GetPayloadAsync(connectionId: connectionId); @@ -479,7 +490,7 @@ public async Task ExpectStringMessage(string expected, string connectionId = nul _payload = payload.Slice(expectedBytes.Length); } - public async Task ExpectServiceMessage() where T: ServiceMessage + public async Task ExpectServiceMessage() where T : ServiceMessage { var result = await Reader.ReadAsync(); var buffer = result.Buffer; diff --git a/version.props b/version.props index 5df721e7a..fbe0e62a4 100644 --- a/version.props +++ b/version.props @@ -1,7 +1,7 @@ - 1.21.7 + 1.22.0 1.1.0 1.9.0