From 4c8092c5beba75101662397a8e8e1f69ed685c91 Mon Sep 17 00:00:00 2001 From: stdrickforce Date: Thu, 12 Oct 2023 13:59:16 +0800 Subject: [PATCH] Move App/Transport async task to ClientConnection class --- .../Constants.cs | 4 + .../ServiceConnectionBase.cs | 4 +- .../ClientConnectionContext.cs | 304 +++++++++++++----- .../ServerConnections/Log.cs | 184 +++++++++++ .../ServiceConnection.Log.cs | 188 ----------- .../ServerConnections/ServiceConnection.cs | 156 ++------- .../ServiceConnectionWritter.cs | 10 + 7 files changed, 442 insertions(+), 408 deletions(-) create mode 100644 src/Microsoft.Azure.SignalR/ServerConnections/Log.cs delete mode 100644 src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.Log.cs create mode 100644 src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionWritter.cs diff --git a/src/Microsoft.Azure.SignalR.Common/Constants.cs b/src/Microsoft.Azure.SignalR.Common/Constants.cs index a691ff4be..0f8ba2d21 100644 --- a/src/Microsoft.Azure.SignalR.Common/Constants.cs +++ b/src/Microsoft.Azure.SignalR.Common/Constants.cs @@ -46,6 +46,9 @@ public static class Periods // Custom handshake timeout of SignalR Service public const int DefaultHandshakeTimeout = 15; public const int MaxCustomHandshakeTimeout = 30; + + public static readonly TimeSpan DefaultServerHandshakeTimeout = TimeSpan.FromSeconds(15); + public static readonly TimeSpan DefaultClientHandshakeTimeout = TimeSpan.FromSeconds(15); } public static class ClaimType @@ -102,6 +105,7 @@ public static class Headers public const string AsrsHeaderPrefix = "X-ASRS-"; public const string AsrsServerId = AsrsHeaderPrefix + "Server-Id"; public const string AsrsMessageTracingId = AsrsHeaderPrefix + "Message-Tracing-Id"; + public const string AsrsIngressReloadMigrate = AsrsHeaderPrefix + "Ingress-Reload-Migrate"; public const string MicrosoftErrorCode = "x-ms-error-code"; } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs index 476f22025..93f94dcd2 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs @@ -17,8 +17,6 @@ namespace Microsoft.Azure.SignalR { internal abstract class ServiceConnectionBase : IServiceConnection { - protected static readonly TimeSpan DefaultHandshakeTimeout = TimeSpan.FromSeconds(15); - // Service ping rate is 5 sec to let server know service status. Set timeout for 30 sec for some space. private static readonly TimeSpan DefaultServiceTimeout = TimeSpan.FromSeconds(30); @@ -389,7 +387,7 @@ protected virtual async Task HandshakeAsync(ConnectionContext context) using var cts = new CancellationTokenSource(); if (!Debugger.IsAttached) { - cts.CancelAfter(DefaultHandshakeTimeout); + cts.CancelAfter(Constants.Periods.DefaultServerHandshakeTimeout); } if (await ReceiveHandshakeResponseAsync(context.Transport.Input, cts.Token)) diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs index b9d8e57f3..e2542533f 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs @@ -21,8 +21,12 @@ using Microsoft.AspNetCore.Http.Features.Authentication; using Microsoft.AspNetCore.Localization; using Microsoft.AspNetCore.WebUtilities; +using Microsoft.Azure.SignalR.Common; using Microsoft.Azure.SignalR.Protocol; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Primitives; +using SignalRProtocol = Microsoft.AspNetCore.SignalR.Protocol; namespace Microsoft.Azure.SignalR { @@ -53,39 +57,33 @@ internal class ClientConnectionContext : ConnectionContext, IHttpContextFeature, IConnectionStatFeature { - private const int WritingState = 1; - private const int CompletedState = 2; private const int IdleState = 0; + private const int WritingState = 1; + private static readonly PipeOptions DefaultPipeOptions = new PipeOptions(pauseWriterThreshold: 0, resumeWriterThreshold: 0, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); - private readonly TaskCompletionSource _connectionEndTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - private readonly CancellationTokenSource _abortOutgoingCts = new CancellationTokenSource(); + private readonly TaskCompletionSource _connectionEndTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly object _heartbeatLock = new object(); + private volatile bool _abortOnClose = true; + private int _connectionState = IdleState; private List<(Action handler, object state)> _heartbeatHandlers; - private volatile bool _abortOnClose = true; - private long _lastMessageReceivedAt; private long _receivedBytes; - public bool IsMigrated { get; } - - public string Protocol { get; } - - public string InstanceId { get; } - // Send "Abort" to service on close except that Service asks SDK to close public bool AbortOnClose { @@ -93,33 +91,46 @@ public bool AbortOnClose set => _abortOnClose = value; } + public IDuplexPipe Application { get; set; } + public override string ConnectionId { get; set; } public override IFeatureCollection Features { get; } - public override IDictionary Items { get; set; } = new ConnectionItems(new ConcurrentDictionary()); + public HttpContext HttpContext { get; set; } - public override IDuplexPipe Transport { get; set; } + public string InstanceId { get; } - public IDuplexPipe Application { get; set; } + public bool IsMigrated { get; } - public ClaimsPrincipal User { get; set; } + public override IDictionary Items { get; set; } = new ConnectionItems(new ConcurrentDictionary()); + + public DateTime LastMessageReceivedAtUtc => new DateTime(Volatile.Read(ref _lastMessageReceivedAt), DateTimeKind.Utc); public Task LifetimeTask => _connectionEndTcs.Task; - public ServiceConnectionBase ServiceConnection { get; set; } + public CancellationToken OutgoingAborted => _abortOutgoingCts.Token; - public HttpContext HttpContext { get; set; } + public string Protocol { get; } - public CancellationToken OutgoingAborted => _abortOutgoingCts.Token; + public long ReceivedBytes => Volatile.Read(ref _receivedBytes); - public DateTime LastMessageReceivedAtUtc => new DateTime(Volatile.Read(ref _lastMessageReceivedAt), DateTimeKind.Utc); + public ServiceConnectionBase ServiceConnection { get; set; } public DateTime StartedAtUtc { get; } = DateTime.UtcNow; - public long ReceivedBytes => Volatile.Read(ref _receivedBytes); + public override IDuplexPipe Transport { get; set; } + + public ClaimsPrincipal User { get; set; } + + public ServiceConnectionWritter WritterDelegate { get; set; } + + public ILogger Logger { get; set; } = NullLogger.Instance; - public ClientConnectionContext(OpenConnectionMessage serviceMessage, Action configureContext = null, PipeOptions transportPipeOptions = null, PipeOptions appPipeOptions = null) + public ClientConnectionContext(OpenConnectionMessage serviceMessage, + Action configureContext = null, + PipeOptions transportPipeOptions = null, + PipeOptions appPipeOptions = null) { ConnectionId = serviceMessage.ConnectionId; Protocol = serviceMessage.Protocol; @@ -145,6 +156,21 @@ public ClientConnectionContext(OpenConnectionMessage serviceMessage, Action + /// Cancel the outgoing process + /// + public void CancelOutgoing(int millisecondsDelay = 0) + { + if (millisecondsDelay <= 0) + { + _abortOutgoingCts.Cancel(); + } + else + { + _abortOutgoingCts.CancelAfter(millisecondsDelay); + } + } + public void CompleteIncoming() { // always set the connection state to completing when this method is called @@ -161,52 +187,114 @@ public void CompleteIncoming() } } - public async Task WriteMessageAsync(ReadOnlySequence payload) + public void OnCompleted() { - var previousState = Interlocked.CompareExchange(ref _connectionState, WritingState, IdleState); - - // Write should not be called from multiple threads - Debug.Assert(previousState != WritingState); + _connectionEndTcs.TrySetResult(null); + } - if (previousState == CompletedState) + public void OnHeartbeat(Action action, object state) + { + lock (_heartbeatLock) { - // already completing, don't write anymore - return; + if (_heartbeatHandlers == null) + { + _heartbeatHandlers = new List<(Action handler, object state)>(); + } + _heartbeatHandlers.Add((action, state)); } + } + + public async Task ProcessApplicationTaskAsync(ConnectionDelegate connectionDelegate) + { + Exception exception = null; try { - _lastMessageReceivedAt = DateTime.UtcNow.Ticks; - _receivedBytes += payload.Length; - - // Start write - await WriteMessageAsyncCore(payload); + // Wait for the application task to complete + // application task can end when exception, or Context.Abort() from hub + await connectionDelegate(this); + } + catch (Exception ex) + { + // Capture the exception to communicate it to the transport (this isn't strictly required) + exception = ex; + throw; } finally { - // Try to set the connection to idle if it is in writing state, if it is in complete state, complete the tcs - previousState = Interlocked.CompareExchange(ref _connectionState, IdleState, WritingState); - if (previousState == CompletedState) - { - Application.Output.Complete(); - } + // Close the transport side since the application is no longer running + Transport.Output.Complete(exception); + Transport.Input.Complete(); } } - public void OnCompleted() + public async Task ProcessOutgoingMessagesAsync() { - _connectionEndTcs.TrySetResult(null); - } - - public void OnHeartbeat(Action action, object state) - { - lock (_heartbeatLock) + try { - if (_heartbeatHandlers == null) + if (IsMigrated) { - _heartbeatHandlers = new List<(Action handler, object state)>(); + using var timeoutToken = new CancellationTokenSource(Constants.Periods.DefaultClientHandshakeTimeout); + using var source = CancellationTokenSource.CreateLinkedTokenSource(OutgoingAborted, timeoutToken.Token); + + // A handshake response is not expected to be given + // if the connection was migrated from another server, + // since the connection hasn't been `dropped` from the client point of view. + if (!await SkipHandshakeResponse(source.Token)) + { + return; + } } - _heartbeatHandlers.Add((action, state)); + + while (true) + { + var result = await Application.Input.ReadAsync(OutgoingAborted); + + if (result.IsCanceled) + { + break; + } + + var buffer = result.Buffer; + + if (!buffer.IsEmpty) + { + try + { + // Forward the message to the service + await WritterDelegate(new ConnectionDataMessage(ConnectionId, buffer)); + } + catch (ServiceConnectionNotActiveException) + { + // Service connection not active means the transport layer for this connection is closed, no need to continue processing + break; + } + catch (Exception ex) + { + Log.ErrorSendingMessage(Logger, ex); + } + } + + if (result.IsCompleted) + { + // This connection ended (the application itself shut down) we should remove it from the list of connections + break; + } + + Application.Input.AdvanceTo(buffer.End); + } + } + catch (Exception ex) + { + // The exception means application fail to process input anymore + // Cancel any pending flush so that we can quit and perform disconnect + // Here is abort close and WaitOnApplicationTask will send close message to notify client to disconnect + Log.SendLoopStopped(Logger, ConnectionId, ex); + Application.Output.CancelPendingFlush(); + } + finally + { + Application.Input.Complete(); } } @@ -226,18 +314,35 @@ public void TickHeartbeat() } } - /// - /// Cancel the outgoing process - /// - public void CancelOutgoing(int millisecondsDelay = 0) + public async Task WriteMessageAsync(ReadOnlySequence payload) { - if (millisecondsDelay <= 0) + var previousState = Interlocked.CompareExchange(ref _connectionState, WritingState, IdleState); + + // Write should not be called from multiple threads + Debug.Assert(previousState != WritingState); + + if (previousState == CompletedState) { - _abortOutgoingCts.Cancel(); + // already completing, don't write anymore + return; } - else + + try { - _abortOutgoingCts.CancelAfter(millisecondsDelay); + _lastMessageReceivedAt = DateTime.UtcNow.Ticks; + _receivedBytes += payload.Length; + + // Start write + await WriteMessageAsyncCore(payload); + } + finally + { + // Try to set the connection to idle if it is in writing state, if it is in complete state, complete the tcs + previousState = Interlocked.CompareExchange(ref _connectionState, IdleState, WritingState); + if (previousState == CompletedState) + { + Application.Output.Complete(); + } } } @@ -307,28 +412,6 @@ private FeatureCollection BuildFeatures(OpenConnectionMessage serviceMessage) return features; } - private async Task WriteMessageAsyncCore(ReadOnlySequence payload) - { - if (payload.IsSingleSegment) - { - // Write the raw connection payload to the pipe let the upstream handle it - await Application.Output.WriteAsync(payload.First); - } - else - { - var position = payload.Start; - while (payload.TryGet(ref position, out var memory)) - { - var result = await Application.Output.WriteAsync(memory); - if (result.IsCanceled) - { - // IsCanceled when CancelPendingFlush is called - break; - } - } - } - } - private HttpContext BuildHttpContext(OpenConnectionMessage message) { var httpContextFeatures = new FeatureCollection(); @@ -362,5 +445,64 @@ private string GetInstanceId(IDictionary header) } return string.Empty; } + + private async Task SkipHandshakeResponse(CancellationToken token) + { + try + { + while (true) + { + var result = await Application.Input.ReadAsync(token); + if (result.IsCanceled || token.IsCancellationRequested) + { + return false; + } + + var buffer = result.Buffer; + if (buffer.IsEmpty) + { + continue; + } + + if (SignalRProtocol.HandshakeProtocol.TryParseResponseMessage(ref buffer, out var message)) + { + Application.Input.AdvanceTo(buffer.Start); + return true; + } + + if (result.IsCompleted) + { + return false; + } + } + } + catch (Exception ex) + { + Log.ErrorSkippingHandshakeResponse(Logger, ex); + } + return false; + } + + private async Task WriteMessageAsyncCore(ReadOnlySequence payload) + { + if (payload.IsSingleSegment) + { + // Write the raw connection payload to the pipe let the upstream handle it + await Application.Output.WriteAsync(payload.First); + } + else + { + var position = payload.Start; + while (payload.TryGet(ref position, out var memory)) + { + var result = await Application.Output.WriteAsync(memory); + if (result.IsCanceled) + { + // IsCanceled when CancelPendingFlush is called + break; + } + } + } + } } } diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/Log.cs b/src/Microsoft.Azure.SignalR/ServerConnections/Log.cs new file mode 100644 index 000000000..71a015a4c --- /dev/null +++ b/src/Microsoft.Azure.SignalR/ServerConnections/Log.cs @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using Microsoft.Azure.SignalR.Protocol; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Azure.SignalR +{ + internal static class Log + { + // Category: ServiceConnection + private static readonly Action _waitingForTransport = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, "WaitingForTransport"), "Waiting for the transport layer to end."); + + private static readonly Action _transportComplete = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, "TransportComplete"), "Transport completed."); + + private static readonly Action _closeTimedOut = + LoggerMessage.Define(LogLevel.Debug, new EventId(3, "CloseTimedOut"), "Timed out waiting for close message sending to client, aborting the connection."); + + private static readonly Action _waitingForApplication = + LoggerMessage.Define(LogLevel.Debug, new EventId(4, "WaitingForApplication"), "Waiting for the application to end."); + + private static readonly Action _applicationComplete = + LoggerMessage.Define(LogLevel.Debug, new EventId(4, "ApplicationComplete"), "Application task completes."); + + private static readonly Action _failedToCleanupConnections = + LoggerMessage.Define(LogLevel.Error, new EventId(5, "FailedToCleanupConnection"), "Failed to clean up client connections."); + + private static readonly Action _errorSendingMessage = + LoggerMessage.Define(LogLevel.Error, new EventId(6, "ErrorSendingMessage"), "Error while sending message to the service, the connection carrying the traffic is dropped. Error detail: {message}"); + + private static readonly Action _sendLoopStopped = + LoggerMessage.Define(LogLevel.Error, new EventId(7, "SendLoopStopped"), "Error while processing messages from {TransportConnectionId}."); + + private static readonly Action _applicationTaskFailed = + LoggerMessage.Define(LogLevel.Error, new EventId(8, "ApplicationTaskFailed"), "Application task failed."); + + private static readonly Action _failToWriteMessageToApplication = + LoggerMessage.Define(LogLevel.Error, new EventId(9, "FailToWriteMessageToApplication"), "Failed to write message {tracingId} to {TransportConnectionId}."); + + private static readonly Action _receivedMessageForNonExistentConnection = + LoggerMessage.Define(LogLevel.Warning, new EventId(10, "ReceivedMessageForNonExistentConnection"), "Received message {tracingId} for connection {TransportConnectionId} which does not exist."); + + private static readonly Action _connectedStarting = + LoggerMessage.Define(LogLevel.Information, new EventId(11, "ConnectedStarting"), "Connection {TransportConnectionId} started."); + + private static readonly Action _connectedEnding = + LoggerMessage.Define(LogLevel.Information, new EventId(12, "ConnectedEnding"), "Connection {TransportConnectionId} ended."); + + private static readonly Action _closeConnection = + LoggerMessage.Define(LogLevel.Debug, new EventId(13, "CloseConnection"), "Sending close connection message to the service for {TransportConnectionId}."); + + private static readonly Action _writeMessageToApplication = + LoggerMessage.Define(LogLevel.Trace, new EventId(19, "WriteMessageToApplication"), "Writing {ReceivedBytes} to connection {TransportConnectionId}."); + + private static readonly Action _serviceConnectionConnected = + LoggerMessage.Define(LogLevel.Debug, new EventId(20, "ServiceConnectionConnected"), "Service connection {ServiceConnectionId} connected."); + + private static readonly Action _applicationTaskCancelled = + LoggerMessage.Define(LogLevel.Error, new EventId(21, "ApplicationTaskCancelled"), "Cancelled running application code, probably caused by time out."); + + private static readonly Action _migrationStarting = + LoggerMessage.Define(LogLevel.Debug, new EventId(22, "MigrationStarting"), "Connection {TransportConnectionId} migrated from another server."); + + private static readonly Action _errorSkippingHandshakeResponse = + LoggerMessage.Define(LogLevel.Error, new EventId(23, "ErrorSkippingHandshakeResponse"), "Error while skipping handshake response during migration, the connection will be dropped on the client-side. Error detail: {message}"); + + private static readonly Action _processConnectionFailed = + LoggerMessage.Define(LogLevel.Error, new EventId(24, "ProcessConnectionFailed"), "Error processing the connection {TransportConnectionId}."); + + private static readonly Action _closingClientConnections = + LoggerMessage.Define(LogLevel.Information, new EventId(25, "ClosingClientConnections"), "Closing {ClientCount} client connection(s) for server connection {ServerConnectionId}."); + + private static readonly Action _detectedLongRunningApplicationTask = + LoggerMessage.Define(LogLevel.Warning, new EventId(26, "DetectedLongRunningApplicationTask"), "The connection {TransportConnectionId} has a long running application logic that prevents the connection from complete."); + + public static void DetectedLongRunningApplicationTask(ILogger logger, string connectionId) + { + _detectedLongRunningApplicationTask(logger, connectionId, null); + } + + public static void WaitingForTransport(ILogger logger) + { + _waitingForTransport(logger, null); + } + + public static void TransportComplete(ILogger logger) + { + _transportComplete(logger, null); + } + + public static void CloseTimedOut(ILogger logger) + { + _closeTimedOut(logger, null); + } + + public static void WaitingForApplication(ILogger logger) + { + _waitingForApplication(logger, null); + } + + public static void ApplicationComplete(ILogger logger) + { + _applicationComplete(logger, null); + } + + public static void ClosingClientConnections(ILogger logger, int clientCount, string serverConnectionId) + { + _closingClientConnections(logger, clientCount, serverConnectionId, null); + } + + public static void FailedToCleanupConnections(ILogger logger, Exception exception) + { + _failedToCleanupConnections(logger, exception); + } + + public static void ErrorSendingMessage(ILogger logger, Exception exception) + { + _errorSendingMessage(logger, exception.Message, exception); + } + + public static void SendLoopStopped(ILogger logger, string connectionId, Exception exception) + { + _sendLoopStopped(logger, connectionId, exception); + } + + public static void ApplicationTaskFailed(ILogger logger, Exception exception) + { + _applicationTaskFailed(logger, exception); + } + + public static void FailToWriteMessageToApplication(ILogger logger, ConnectionDataMessage message, Exception exception) + { + _failToWriteMessageToApplication(logger, message.TracingId, message.ConnectionId, exception); + } + + public static void ReceivedMessageForNonExistentConnection(ILogger logger, ConnectionDataMessage message) + { + _receivedMessageForNonExistentConnection(logger, message.TracingId, message.ConnectionId, null); + } + + public static void ConnectedStarting(ILogger logger, string connectionId) + { + _connectedStarting(logger, connectionId, null); + } + + public static void MigrationStarting(ILogger logger, string connectionId) + { + _migrationStarting(logger, connectionId, null); + } + + public static void ConnectedEnding(ILogger logger, string connectionId) + { + _connectedEnding(logger, connectionId, null); + } + + public static void CloseConnection(ILogger logger, string connectionId) + { + _closeConnection(logger, connectionId, null); + } + + public static void WriteMessageToApplication(ILogger logger, long count, string connectionId) + { + _writeMessageToApplication(logger, count, connectionId, null); + } + + public static void ApplicationTaskCancelled(ILogger logger) + { + _applicationTaskCancelled(logger, null); + } + + public static void ErrorSkippingHandshakeResponse(ILogger logger, Exception ex) + { + _errorSkippingHandshakeResponse(logger, ex.Message, ex); + } + + public static void ProcessConnectionFailed(ILogger logger, string connectionId, Exception exception) + { + _processConnectionFailed(logger, connectionId, exception); + } + } +} diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.Log.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.Log.cs deleted file mode 100644 index a27c268e5..000000000 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.Log.cs +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System; -using Microsoft.Azure.SignalR.Protocol; -using Microsoft.Extensions.Logging; - -namespace Microsoft.Azure.SignalR -{ - internal partial class ServiceConnection - { - private static class Log - { - // Category: ServiceConnection - private static readonly Action _waitingForTransport = - LoggerMessage.Define(LogLevel.Debug, new EventId(2, "WaitingForTransport"), "Waiting for the transport layer to end."); - - private static readonly Action _transportComplete = - LoggerMessage.Define(LogLevel.Debug, new EventId(2, "TransportComplete"), "Transport completed."); - - private static readonly Action _closeTimedOut = - LoggerMessage.Define(LogLevel.Debug, new EventId(3, "CloseTimedOut"), "Timed out waiting for close message sending to client, aborting the connection."); - - private static readonly Action _waitingForApplication = - LoggerMessage.Define(LogLevel.Debug, new EventId(4, "WaitingForApplication"), "Waiting for the application to end."); - - private static readonly Action _applicationComplete = - LoggerMessage.Define(LogLevel.Debug, new EventId(4, "ApplicationComplete"), "Application task completes."); - - private static readonly Action _failedToCleanupConnections = - LoggerMessage.Define(LogLevel.Error, new EventId(5, "FailedToCleanupConnection"), "Failed to clean up client connections."); - - private static readonly Action _errorSendingMessage = - LoggerMessage.Define(LogLevel.Error, new EventId(6, "ErrorSendingMessage"), "Error while sending message to the service, the connection carrying the traffic is dropped. Error detail: {message}"); - - private static readonly Action _sendLoopStopped = - LoggerMessage.Define(LogLevel.Error, new EventId(7, "SendLoopStopped"), "Error while processing messages from {TransportConnectionId}."); - - private static readonly Action _applicationTaskFailed = - LoggerMessage.Define(LogLevel.Error, new EventId(8, "ApplicationTaskFailed"), "Application task failed."); - - private static readonly Action _failToWriteMessageToApplication = - LoggerMessage.Define(LogLevel.Error, new EventId(9, "FailToWriteMessageToApplication"), "Failed to write message {tracingId} to {TransportConnectionId}."); - - private static readonly Action _receivedMessageForNonExistentConnection = - LoggerMessage.Define(LogLevel.Warning, new EventId(10, "ReceivedMessageForNonExistentConnection"), "Received message {tracingId} for connection {TransportConnectionId} which does not exist."); - - private static readonly Action _connectedStarting = - LoggerMessage.Define(LogLevel.Information, new EventId(11, "ConnectedStarting"), "Connection {TransportConnectionId} started."); - - private static readonly Action _connectedEnding = - LoggerMessage.Define(LogLevel.Information, new EventId(12, "ConnectedEnding"), "Connection {TransportConnectionId} ended."); - - private static readonly Action _closeConnection = - LoggerMessage.Define(LogLevel.Debug, new EventId(13, "CloseConnection"), "Sending close connection message to the service for {TransportConnectionId}."); - - private static readonly Action _writeMessageToApplication = - LoggerMessage.Define(LogLevel.Trace, new EventId(19, "WriteMessageToApplication"), "Writing {ReceivedBytes} to connection {TransportConnectionId}."); - - private static readonly Action _serviceConnectionConnected = - LoggerMessage.Define(LogLevel.Debug, new EventId(20, "ServiceConnectionConnected"), "Service connection {ServiceConnectionId} connected."); - - private static readonly Action _applicationTaskCancelled = - LoggerMessage.Define(LogLevel.Error, new EventId(21, "ApplicationTaskCancelled"), "Cancelled running application code, probably caused by time out."); - - private static readonly Action _migrationStarting = - LoggerMessage.Define(LogLevel.Debug, new EventId(22, "MigrationStarting"), "Connection {TransportConnectionId} migrated from another server."); - - private static readonly Action _errorSkippingHandshakeResponse = - LoggerMessage.Define(LogLevel.Error, new EventId(23, "ErrorSkippingHandshakeResponse"), "Error while skipping handshake response during migration, the connection will be dropped on the client-side. Error detail: {message}"); - - private static readonly Action _processConnectionFailed = - LoggerMessage.Define(LogLevel.Error, new EventId(24, "ProcessConnectionFailed"), "Error processing the connection {TransportConnectionId}."); - - private static readonly Action _closingClientConnections = - LoggerMessage.Define(LogLevel.Information, new EventId(25, "ClosingClientConnections"), "Closing {ClientCount} client connection(s) for server connection {ServerConnectionId}."); - - - private static readonly Action _detectedLongRunningApplicationTask = - LoggerMessage.Define(LogLevel.Warning, new EventId(26, "DetectedLongRunningApplicationTask"), "The connection {TransportConnectionId} has a long running application logic that prevents the connection from complete."); - - public static void DetectedLongRunningApplicationTask(ILogger logger, string connectionId) - { - _detectedLongRunningApplicationTask(logger, connectionId, null); - } - - public static void WaitingForTransport(ILogger logger) - { - _waitingForTransport(logger, null); - } - - public static void TransportComplete(ILogger logger) - { - _transportComplete(logger, null); - } - - public static void CloseTimedOut(ILogger logger) - { - _closeTimedOut(logger, null); - } - - public static void WaitingForApplication(ILogger logger) - { - _waitingForApplication(logger, null); - } - - public static void ApplicationComplete(ILogger logger) - { - _applicationComplete(logger, null); - } - - public static void ClosingClientConnections(ILogger logger, int clientCount, string serverConnectionId) - { - _closingClientConnections(logger, clientCount, serverConnectionId, null); - } - - public static void FailedToCleanupConnections(ILogger logger, Exception exception) - { - _failedToCleanupConnections(logger, exception); - } - - public static void ErrorSendingMessage(ILogger logger, Exception exception) - { - _errorSendingMessage(logger, exception.Message, exception); - } - - public static void SendLoopStopped(ILogger logger, string connectionId, Exception exception) - { - _sendLoopStopped(logger, connectionId, exception); - } - - public static void ApplicationTaskFailed(ILogger logger, Exception exception) - { - _applicationTaskFailed(logger, exception); - } - - public static void FailToWriteMessageToApplication(ILogger logger, ConnectionDataMessage message, Exception exception) - { - _failToWriteMessageToApplication(logger, message.TracingId, message.ConnectionId, exception); - } - - public static void ReceivedMessageForNonExistentConnection(ILogger logger, ConnectionDataMessage message) - { - _receivedMessageForNonExistentConnection(logger, message.TracingId, message.ConnectionId, null); - } - - public static void ConnectedStarting(ILogger logger, string connectionId) - { - _connectedStarting(logger, connectionId, null); - } - - public static void MigrationStarting(ILogger logger, string connectionId) - { - _migrationStarting(logger, connectionId, null); - } - - public static void ConnectedEnding(ILogger logger, string connectionId) - { - _connectedEnding(logger, connectionId, null); - } - - public static void CloseConnection(ILogger logger, string connectionId) - { - _closeConnection(logger, connectionId, null); - } - - public static void WriteMessageToApplication(ILogger logger, long count, string connectionId) - { - _writeMessageToApplication(logger, count, connectionId, null); - } - - public static void ApplicationTaskCancelled(ILogger logger) - { - _applicationTaskCancelled(logger, null); - } - - public static void ErrorSkippingHandshakeResponse(ILogger logger, Exception ex) - { - _errorSkippingHandshakeResponse(logger, ex.Message, ex); - } - - public static void ProcessConnectionFailed(ILogger logger, string connectionId, Exception exception) - { - _processConnectionFailed(logger, connectionId, exception); - } - } - } -} \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs index 1a5aeda94..6caa2bee0 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs @@ -5,15 +5,12 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; -using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; -using Microsoft.Azure.SignalR.Common; using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; -using SignalRProtocol = Microsoft.AspNetCore.SignalR.Protocol; namespace Microsoft.Azure.SignalR { @@ -114,8 +111,23 @@ protected override ReadOnlyMemory GetPingMessage() }); } + protected Task OnIngressReloadAsync() + { + if (!_clientConnectionManager.ClientConnections.TryGetValue(ConnectionId, out var connection)) + { + return Task.CompletedTask; + } + + return Task.CompletedTask; + } + protected override Task OnClientConnectedAsync(OpenConnectionMessage message) { + if (message.Headers.TryGetValue(Constants.Headers.AsrsIngressReloadMigrate, out var _)) + { + return OnIngressReloadAsync(); + } + var connection = _clientConnectionFactory.CreateConnection(message, ConfigureContext); connection.ServiceConnection = this; @@ -229,11 +241,14 @@ private async Task ProcessClientConnectionAsync(ClientConnectionContext connecti { try { + connection.WritterDelegate = WriteAsync; + connection.Logger = Logger; + // Writing from the application to the service - var transport = ProcessOutgoingMessagesAsync(connection, connection.OutgoingAborted); + var transport = connection.ProcessOutgoingMessagesAsync(); // Waiting for the application to shutdown so we can clean up the connection - var app = ProcessApplicationTaskAsyncCore(connection); + var app = connection.ProcessApplicationTaskAsync(_connectionDelegate); var task = await Task.WhenAny(app, transport); @@ -303,143 +318,12 @@ private async Task ProcessClientConnectionAsync(ClientConnectionContext connecti } } - private async Task SkipHandshakeResponse(ClientConnectionContext connection, CancellationToken token) - { - try - { - while (true) - { - var result = await connection.Application.Input.ReadAsync(token); - if (result.IsCanceled || token.IsCancellationRequested) - { - return false; - } - - var buffer = result.Buffer; - if (buffer.IsEmpty) - { - continue; - } - - if (SignalRProtocol.HandshakeProtocol.TryParseResponseMessage(ref buffer, out var message)) - { - connection.Application.Input.AdvanceTo(buffer.Start); - return true; - } - - if (result.IsCompleted) - { - return false; - } - } - } - catch (Exception ex) - { - Log.ErrorSkippingHandshakeResponse(Logger, ex); - } - return false; - } - - private async Task ProcessOutgoingMessagesAsync(ClientConnectionContext connection, CancellationToken token = default) - { - try - { - if (connection.IsMigrated) - { - using var timeoutToken = new CancellationTokenSource(DefaultHandshakeTimeout); - using var source = CancellationTokenSource.CreateLinkedTokenSource(token, timeoutToken.Token); - - // A handshake response is not expected to be given - // if the connection was migrated from another server, - // since the connection hasn't been `dropped` from the client point of view. - if (!await SkipHandshakeResponse(connection, source.Token)) - { - return; - } - } - - while (true) - { - var result = await connection.Application.Input.ReadAsync(token); - - if (result.IsCanceled) - { - break; - } - - var buffer = result.Buffer; - - if (!buffer.IsEmpty) - { - try - { - // Forward the message to the service - await WriteAsync(new ConnectionDataMessage(connection.ConnectionId, buffer)); - } - catch (ServiceConnectionNotActiveException) - { - // Service connection not active means the transport layer for this connection is closed, no need to continue processing - break; - } - catch (Exception ex) - { - Log.ErrorSendingMessage(Logger, ex); - } - } - - if (result.IsCompleted) - { - // This connection ended (the application itself shut down) we should remove it from the list of connections - break; - } - - connection.Application.Input.AdvanceTo(buffer.End); - } - } - catch (Exception ex) - { - // The exception means application fail to process input anymore - // Cancel any pending flush so that we can quit and perform disconnect - // Here is abort close and WaitOnApplicationTask will send close message to notify client to disconnect - Log.SendLoopStopped(Logger, connection.ConnectionId, ex); - connection.Application.Output.CancelPendingFlush(); - } - finally - { - connection.Application.Input.Complete(); - } - } - private void AddClientConnection(ClientConnectionContext connection, OpenConnectionMessage message) { _clientConnectionManager.TryAddClientConnection(connection); _connectionIds.TryAdd(connection.ConnectionId, connection.InstanceId); } - private async Task ProcessApplicationTaskAsyncCore(ClientConnectionContext connection) - { - Exception exception = null; - - try - { - // Wait for the application task to complete - // application task can end when exception, or Context.Abort() from hub - await _connectionDelegate(connection); - } - catch (Exception ex) - { - // Capture the exception to communicate it to the transport (this isn't strictly required) - exception = ex; - throw; - } - finally - { - // Close the transport side since the application is no longer running - connection.Transport.Output.Complete(exception); - connection.Transport.Input.Complete(); - } - } - private async Task PerformDisconnectAsyncCore(string connectionId) { var connection = RemoveClientConnection(connectionId); diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionWritter.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionWritter.cs new file mode 100644 index 000000000..bf5a14b0d --- /dev/null +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionWritter.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Threading.Tasks; +using Microsoft.Azure.SignalR.Protocol; + +namespace Microsoft.Azure.SignalR +{ + public delegate Task ServiceConnectionWritter(ServiceMessage serviceMessage); +}