diff --git a/eng/PatchConfig.props b/eng/PatchConfig.props index 3ba7babc5710..83f2b2449755 100644 --- a/eng/PatchConfig.props +++ b/eng/PatchConfig.props @@ -44,4 +44,10 @@ Later on, this will be checked using this condition: Microsoft.AspNetCore.CookiePolicy; + + + Microsoft.AspNetCore.Http.Connections; + Microsoft.AspNetCore.SignalR.Core; + + diff --git a/src/SignalR/clients/ts/FunctionalTests/selenium/run-tests.ts b/src/SignalR/clients/ts/FunctionalTests/selenium/run-tests.ts index c74603b12c56..6f548f153484 100644 --- a/src/SignalR/clients/ts/FunctionalTests/selenium/run-tests.ts +++ b/src/SignalR/clients/ts/FunctionalTests/selenium/run-tests.ts @@ -1,7 +1,8 @@ import { ChildProcess, spawn } from "child_process"; -import * as fs from "fs"; +import * as _fs from "fs"; import { EOL } from "os"; import * as path from "path"; +import { promisify } from "util"; import { PassThrough, Readable } from "stream"; import { run } from "../../webdriver-tap-runner/lib"; @@ -9,6 +10,16 @@ import { run } from "../../webdriver-tap-runner/lib"; import * as _debug from "debug"; const debug = _debug("signalr-functional-tests:run"); +const ARTIFACTS_DIR = path.resolve(__dirname, "..", "..", "..", "..", "artifacts"); +const LOGS_DIR = path.resolve(ARTIFACTS_DIR, "logs"); + +// Promisify things from fs we want to use. +const fs = { + createWriteStream: _fs.createWriteStream, + exists: promisify(_fs.exists), + mkdir: promisify(_fs.mkdir), +}; + process.on("unhandledRejection", (reason) => { console.error(`Unhandled promise rejection: ${reason}`); process.exit(1); @@ -102,6 +113,13 @@ if (chromePath) { try { const serverPath = path.resolve(__dirname, "..", "bin", configuration, "netcoreapp2.1", "FunctionalTests.dll"); + if (!await fs.exists(ARTIFACTS_DIR)) { + await fs.mkdir(ARTIFACTS_DIR); + } + if (!await fs.exists(LOGS_DIR)) { + await fs.mkdir(LOGS_DIR); + } + debug(`Launching Functional Test Server: ${serverPath}`); const dotnet = spawn("dotnet", [serverPath], { env: { @@ -117,6 +135,9 @@ if (chromePath) { } } + const logStream = fs.createWriteStream(path.resolve(LOGS_DIR, "ts.functionaltests.dotnet.log")); + dotnet.stdout.pipe(logStream); + process.on("SIGINT", cleanup); process.on("exit", cleanup); diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index 045e821ee108..425c15d76e5d 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -274,13 +274,35 @@ private async Task WaitOnTasks(Task applicationTask, Task transportTask, bool cl // Cancel any pending flushes from back pressure Application?.Output.CancelPendingFlush(); - // Shutdown both sides and wait for nothing + // Normally it isn't safe to try and acquire this lock because the Send can hold onto it for a long time if there is backpressure + // It is safe to wait for this lock now because the Send will be in one of 4 states + // 1. In the middle of a write which is in the middle of being canceled by the CancelPendingFlush above, when it throws + // an OperationCanceledException it will complete the PipeWriter which will make any other Send waiting on the lock + // throw an InvalidOperationException if they call Write + // 2. About to write and see that there is a pending cancel from the CancelPendingFlush, go to 1 to see what happens + // 3. Enters the Send and sees the Dispose state from DisposeAndRemoveAsync and releases the lock + // 4. No Send in progress + await WriteLock.WaitAsync(); + try + { + // Complete the applications read loop + Application?.Output.Complete(transportTask.Exception?.InnerException); + } + finally + { + WriteLock.Release(); + } + + Log.WaitingForTransportAndApplication(_logger, TransportType); + + // Wait for application so we can complete the writer safely + await applicationTask.NoThrow(); + + // Shutdown application side now that it's finished Transport?.Output.Complete(applicationTask.Exception?.InnerException); - Application?.Output.Complete(transportTask.Exception?.InnerException); try { - Log.WaitingForTransportAndApplication(_logger, TransportType); // A poorly written application *could* in theory get stuck forever and it'll show up as a memory leak await Task.WhenAll(applicationTask, transportTask); } diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index 50910bccfec1..2d8b5c24ade9 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -511,6 +511,14 @@ private async Task ProcessSend(HttpContext context, HttpConnectionDispatcherOpti context.Response.StatusCode = StatusCodes.Status404NotFound; context.Response.ContentType = "text/plain"; + + // There are no writes anymore (since this is the write "loop") + // So it is safe to complete the writer + // We complete the writer here because we already have the WriteLock acquired + // and it's unsafe to complete outside of the lock + // Other code isn't guaranteed to be able to acquire the lock before another write + // even if CancelPendingFlush is called, and the other write could hang if there is backpressure + connection.Application.Output.Complete(); return; } @@ -549,11 +557,8 @@ private async Task ProcessDeleteAsync(HttpContext context) Log.TerminatingConection(_logger); - // Complete the receiving end of the pipe - connection.Application.Output.Complete(); - - // Dispose the connection gracefully, but don't wait for it. We assign it here so we can wait in tests - connection.DisposeAndRemoveTask = _manager.DisposeAndRemoveAsync(connection, closeGracefully: true); + // Dispose the connection, but don't wait for it. We assign it here so we can wait in tests + connection.DisposeAndRemoveTask = _manager.DisposeAndRemoveAsync(connection, closeGracefully: false); context.Response.StatusCode = StatusCodes.Status202Accepted; context.Response.ContentType = "text/plain"; diff --git a/src/SignalR/common/Http.Connections/src/Internal/TaskExtensions.cs b/src/SignalR/common/Http.Connections/src/Internal/TaskExtensions.cs new file mode 100644 index 000000000000..a901379b75c7 --- /dev/null +++ b/src/SignalR/common/Http.Connections/src/Internal/TaskExtensions.cs @@ -0,0 +1,27 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Runtime.CompilerServices; + +namespace System.Threading.Tasks +{ + internal static class TaskExtensions + { + public static async Task NoThrow(this Task task) + { + await new NoThrowAwaiter(task); + } + } + + internal readonly struct NoThrowAwaiter : ICriticalNotifyCompletion + { + private readonly Task _task; + public NoThrowAwaiter(Task task) { _task = task; } + public NoThrowAwaiter GetAwaiter() => this; + public bool IsCompleted => _task.IsCompleted; + // Observe exception + public void GetResult() { _ = _task.Exception; } + public void OnCompleted(Action continuation) => _task.GetAwaiter().OnCompleted(continuation); + public void UnsafeOnCompleted(Action continuation) => OnCompleted(continuation); + } +} diff --git a/src/SignalR/common/Shared/PipeWriterStream.cs b/src/SignalR/common/Shared/PipeWriterStream.cs index eb5b6d5addef..245731bfd925 100644 --- a/src/SignalR/common/Shared/PipeWriterStream.cs +++ b/src/SignalR/common/Shared/PipeWriterStream.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -76,7 +76,15 @@ private ValueTask WriteCoreAsync(ReadOnlyMemory source, CancellationToken _length += source.Length; var task = _pipeWriter.WriteAsync(source); - if (!task.IsCompletedSuccessfully) + if (task.IsCompletedSuccessfully) + { + // Cancellation can be triggered by PipeWriter.CancelPendingFlush + if (task.Result.IsCanceled) + { + throw new OperationCanceledException(); + } + } + else if (!task.IsCompletedSuccessfully) { return WriteSlowAsync(task); } diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 5a1049e78042..085249438aa5 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -33,6 +33,7 @@ public class HubConnectionContext private long _lastSendTimestamp = Stopwatch.GetTimestamp(); private ReadOnlyMemory _cachedPingMessage; + private volatile bool _connectionAborted; /// /// Initializes a new instance of the class. @@ -99,6 +100,12 @@ public virtual ValueTask WriteAsync(HubMessage message, CancellationToken cancel return new ValueTask(WriteSlowAsync(message)); } + if (_connectionAborted) + { + _writeLock.Release(); + return default; + } + // This method should never throw synchronously var task = WriteCore(message); @@ -129,6 +136,12 @@ public virtual ValueTask WriteAsync(SerializedHubMessage message, CancellationTo return new ValueTask(WriteSlowAsync(message)); } + if (_connectionAborted) + { + _writeLock.Release(); + return default; + } + // This method should never throw synchronously var task = WriteCore(message); @@ -158,6 +171,8 @@ private ValueTask WriteCore(HubMessage message) { Log.FailedWritingMessage(_logger, ex); + Abort(); + return new ValueTask(new FlushResult(isCanceled: false, isCompleted: true)); } } @@ -175,6 +190,8 @@ private ValueTask WriteCore(SerializedHubMessage message) { Log.FailedWritingMessage(_logger, ex); + Abort(); + return new ValueTask(new FlushResult(isCanceled: false, isCompleted: true)); } } @@ -188,6 +205,8 @@ private async Task CompleteWriteAsync(ValueTask task) catch (Exception ex) { Log.FailedWritingMessage(_logger, ex); + + Abort(); } finally { @@ -201,6 +220,11 @@ private async Task WriteSlowAsync(HubMessage message) await _writeLock.WaitAsync(); try { + if (_connectionAborted) + { + return; + } + // Failed to get the lock immediately when entering WriteAsync so await until it is available await WriteCore(message); @@ -208,6 +232,8 @@ private async Task WriteSlowAsync(HubMessage message) catch (Exception ex) { Log.FailedWritingMessage(_logger, ex); + + Abort(); } finally { @@ -219,6 +245,11 @@ private async Task WriteSlowAsync(SerializedHubMessage message) { try { + if (_connectionAborted) + { + return; + } + // Failed to get the lock immediately when entering WriteAsync so await until it is available await _writeLock.WaitAsync(); @@ -227,6 +258,8 @@ private async Task WriteSlowAsync(SerializedHubMessage message) catch (Exception ex) { Log.FailedWritingMessage(_logger, ex); + + Abort(); } finally { @@ -250,6 +283,11 @@ private async Task TryWritePingSlowAsync() { try { + if (_connectionAborted) + { + return; + } + await _connectionContext.Transport.Output.WriteAsync(_cachedPingMessage); Log.SentPing(_logger); @@ -257,6 +295,8 @@ private async Task TryWritePingSlowAsync() catch (Exception ex) { Log.FailedWritingMessage(_logger, ex); + + Abort(); } finally { @@ -293,6 +333,12 @@ private async Task WriteHandshakeResponseAsync(HandshakeResponseMessage message) /// public virtual void Abort() { + _connectionAborted = true; + + // Cancel any current writes or writes that are about to happen and have already gone past the _connectionAborted bool + // We have to do this outside of the lock otherwise it could hang if the write is observing backpressure + _connectionContext.Transport.Output.CancelPendingFlush(); + // If we already triggered the token then noop, this isn't thread safe but it's good enough // to avoid spawning a new task in the most common cases if (_connectionAbortedTokenSource.IsCancellationRequested) @@ -423,9 +469,24 @@ internal void Abort(Exception exception) internal Task AbortAsync() { Abort(); + + // Acquire lock to make sure all writes are completed + if (!_writeLock.Wait(0)) + { + return AbortAsyncSlow(); + } + + _writeLock.Release(); return _abortCompletedTcs.Task; } + private async Task AbortAsyncSlow() + { + await _writeLock.WaitAsync(); + _writeLock.Release(); + await _abortCompletedTcs.Task; + } + private void KeepAliveTick() { var timestamp = Stopwatch.GetTimestamp(); diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 06fa3841eb9c..c3100870953b 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -79,9 +79,11 @@ public async Task AbortFromHubMethodForcesClientDisconnect() { var connectionHandlerTask = await client.ConnectAsync(connectionHandler); - await client.InvokeAsync(nameof(AbortHub.Kill)); + await client.SendInvocationAsync(nameof(AbortHub.Kill)); await connectionHandlerTask.OrTimeout(); + + Assert.Null(client.TryRead()); } }