diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs index 3b8aaae6605ed..caf39a7ac7bd0 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs @@ -33,6 +33,7 @@ internal sealed class Http3Connection : HttpConnectionBase // Our control stream. private QuicStream? _clientControl; + private Task _sendSettingsTask; // Server-advertised SETTINGS_MAX_FIELD_SECTION_SIZE // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4.1-2.2.1 @@ -88,7 +89,7 @@ public Http3Connection(HttpConnectionPool pool, HttpAuthority authority, QuicCon } // Errors are observed via Abort(). - _ = SendSettingsAsync(); + _sendSettingsTask = SendSettingsAsync(); // This process is cleaned up when _connection is disposed, and errors are observed via Abort(). _ = AcceptStreamsAsync(); @@ -150,6 +151,7 @@ private void CheckForShutdown() if (_clientControl != null) { + await _sendSettingsTask.ConfigureAwait(false); await _clientControl.DisposeAsync().ConfigureAwait(false); _clientControl = null; } @@ -486,7 +488,7 @@ private async Task ProcessServerStreamAsync(QuicStream stream) if (bytesRead == 0) { - // https://quicwg.org/base-drafts/draft-ietf-quic-http.html#name-unidirectional-streams + // https://www.rfc-editor.org/rfc/rfc9114.html#name-unidirectional-streams // A sender can close or reset a unidirectional stream unless otherwise specified. A receiver MUST // tolerate unidirectional streams being closed or reset prior to the reception of the unidirectional // stream header. diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index f5e037cf1d406..c18a43402fd93 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -31,6 +31,8 @@ internal sealed class Http3RequestStream : IHttpStreamHeadersHandler, IAsyncDisp private TaskCompletionSource? _expect100ContinueCompletionSource; // True indicates we should send content (e.g. received 100 Continue). private bool _disposed; private readonly CancellationTokenSource _requestBodyCancellationSource; + private Task? _sendRequestTask; // Set with SendContentAsync, must be awaited before QuicStream.DisposeAsync(); + private Task? _readResponseTask; // Set with ReadResponseAsync, must be awaited before QuicStream.DisposeAsync(); // Allocated when we receive a :status header. private HttpResponseMessage? _response; @@ -88,9 +90,25 @@ public void Dispose() { _disposed = true; AbortStream(); + // We aborted both sides, thus both task should unblock and should be finished before disposing the QuicStream. + WaitUnfinished(_sendRequestTask); + WaitUnfinished(_readResponseTask); _stream.Dispose(); DisposeSyncHelper(); } + + static void WaitUnfinished(Task? task) + { + if (task is not null && !task.IsCompleted) + { + try + { + task.GetAwaiter().GetResult(); + } + catch // Exceptions from both tasks are logged via _connection.LogException() in case they're not awaited in SendAsync, so the exception can be ignored here. + { } + } + } } private void RemoveFromConnectionIfDone() @@ -107,9 +125,25 @@ public async ValueTask DisposeAsync() { _disposed = true; AbortStream(); + // We aborted both sides, thus both task should unblock and should be finished before disposing the QuicStream. + await AwaitUnfinished(_sendRequestTask).ConfigureAwait(false); + await AwaitUnfinished(_readResponseTask).ConfigureAwait(false); await _stream.DisposeAsync().ConfigureAwait(false); DisposeSyncHelper(); } + + static async ValueTask AwaitUnfinished(Task? task) + { + if (task is not null && !task.IsCompleted) + { + try + { + await task.ConfigureAwait(false); + } + catch // Exceptions from both tasks are logged via _connection.LogException() in case they're not awaited in SendAsync, so the exception can be ignored here. + { } + } + } } private void DisposeSyncHelper() @@ -158,40 +192,39 @@ public async Task SendAsync(CancellationToken cancellationT await FlushSendBufferAsync(endStream: _request.Content == null, _requestBodyCancellationSource.Token).ConfigureAwait(false); } - Task sendContentTask; if (_request.Content != null) { - sendContentTask = SendContentAsync(_request.Content!, _requestBodyCancellationSource.Token); + _sendRequestTask = SendContentAsync(_request.Content!, _requestBodyCancellationSource.Token); } else { - sendContentTask = Task.CompletedTask; + _sendRequestTask = Task.CompletedTask; } // In parallel, send content and read response. // Depending on Expect 100 Continue usage, one will depend on the other making progress. - Task readResponseTask = ReadResponseAsync(_requestBodyCancellationSource.Token); + _readResponseTask = ReadResponseAsync(_requestBodyCancellationSource.Token); bool sendContentObserved = false; // If we're not doing duplex, wait for content to finish sending here. // If we are doing duplex and have the unlikely event that it completes here, observe the result. // See Http2Connection.SendAsync for a full comment on this logic -- it is identical behavior. - if (sendContentTask.IsCompleted || + if (_sendRequestTask.IsCompleted || _request.Content?.AllowDuplex != true || - await Task.WhenAny(sendContentTask, readResponseTask).ConfigureAwait(false) == sendContentTask || - sendContentTask.IsCompleted) + await Task.WhenAny(_sendRequestTask, _readResponseTask).ConfigureAwait(false) == _sendRequestTask || + _sendRequestTask.IsCompleted) { try { - await sendContentTask.ConfigureAwait(false); + await _sendRequestTask.ConfigureAwait(false); sendContentObserved = true; } catch { - // Exceptions will be bubbled up from sendContentTask here, - // which means the result of readResponseTask won't be observed directly: + // Exceptions will be bubbled up from _sendRequestTask here, + // which means the result of _readResponseTask won't be observed directly: // Do a background await to log any exceptions. - _connection.LogExceptions(readResponseTask); + _connection.LogExceptions(_readResponseTask); throw; } } @@ -199,11 +232,11 @@ await Task.WhenAny(sendContentTask, readResponseTask).ConfigureAwait(false) == s { // Duplex is being used, so we can't wait for content to finish sending. // Do a background await to log any exceptions. - _connection.LogExceptions(sendContentTask); + _connection.LogExceptions(_sendRequestTask); } // Wait for the response headers to be read. - await readResponseTask.ConfigureAwait(false); + await _readResponseTask.ConfigureAwait(false); Debug.Assert(_response != null && _response.Content != null); // Set our content stream. diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs index 844d2866bde61..55bc73b5d146d 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs @@ -1664,10 +1664,17 @@ public async Task ServerSendsTrailingHeaders_Success() } + public enum CloseOutboundControlStream + { + BogusData, + Dispose, + Abort, + } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task ServerClosesOutboundControlStream_ClientClosesConnection(bool graceful) + [InlineData(CloseOutboundControlStream.BogusData)] + [InlineData(CloseOutboundControlStream.Dispose)] + [InlineData(CloseOutboundControlStream.Abort)] + public async Task ServerClosesOutboundControlStream_ClientClosesConnection(CloseOutboundControlStream closeType) { using Http3LoopbackServer server = CreateHttp3LoopbackServer(); @@ -1680,13 +1687,31 @@ public async Task ServerClosesOutboundControlStream_ClientClosesConnection(bool await using Http3LoopbackStream requestStream = await connection.AcceptRequestStreamAsync(); // abort the control stream - if (graceful) + if (closeType == CloseOutboundControlStream.BogusData) { await connection.OutboundControlStream.SendResponseBodyAsync(Array.Empty(), isFinal: true); } - else + else if (closeType == CloseOutboundControlStream.Dispose) { - connection.OutboundControlStream.Abort(Http3LoopbackConnection.H3_INTERNAL_ERROR); + await connection.OutboundControlStream.DisposeAsync(); + } + else if (closeType == CloseOutboundControlStream.Abort) + { + int iterations = 5; + while (iterations-- > 0) + { + connection.OutboundControlStream.Abort(Http3LoopbackConnection.H3_INTERNAL_ERROR); + // This sends RESET_FRAME which might cause complete discard of any data including stream type, leading to client ignoring the stream. + // Attempt to establish the control stream again then. + if (await semaphore.WaitAsync(100)) + { + // Client finished with the expected error. + return; + } + await connection.OutboundControlStream.DisposeAsync(); + await connection.EstablishControlStreamAsync(Array.Empty()); + await Task.Delay(100); + } } // wait for client task before tearing down the requestStream and connection diff --git a/src/libraries/System.Net.Http/tests/StressTests/HttpStress/Directory.Build.targets b/src/libraries/System.Net.Http/tests/StressTests/HttpStress/Directory.Build.targets index e3ebd0de32875..db6e799e071df 100644 --- a/src/libraries/System.Net.Http/tests/StressTests/HttpStress/Directory.Build.targets +++ b/src/libraries/System.Net.Http/tests/StressTests/HttpStress/Directory.Build.targets @@ -6,6 +6,6 @@ Define this here because the SDK resets it unconditionally in Microsoft.NETCoreSdk.BundledVersions.props. --> - 8.0 + 9.0 \ No newline at end of file diff --git a/src/libraries/System.Net.Http/tests/StressTests/HttpStress/build-local.ps1 b/src/libraries/System.Net.Http/tests/StressTests/HttpStress/build-local.ps1 index dbdd2e696c634..19085c5af766c 100644 --- a/src/libraries/System.Net.Http/tests/StressTests/HttpStress/build-local.ps1 +++ b/src/libraries/System.Net.Http/tests/StressTests/HttpStress/build-local.ps1 @@ -3,7 +3,7 @@ ## Usage: ## ./build-local.ps1 [StressConfiguration] [LibrariesConfiguration] -$Version="8.0" +$Version="9.0" $RepoRoot="$(git rev-parse --show-toplevel)" $DailyDotnetRoot= "./.dotnet-daily" diff --git a/src/libraries/System.Net.Http/tests/StressTests/HttpStress/build-local.sh b/src/libraries/System.Net.Http/tests/StressTests/HttpStress/build-local.sh index f5a0e2b784575..44b5dbf21139f 100755 --- a/src/libraries/System.Net.Http/tests/StressTests/HttpStress/build-local.sh +++ b/src/libraries/System.Net.Http/tests/StressTests/HttpStress/build-local.sh @@ -5,7 +5,7 @@ ## Usage: ## ./build-local.sh [StressConfiguration] [LibrariesConfiguration] -version=8.0 +version=9.0 repo_root=$(git rev-parse --show-toplevel) daily_dotnet_root=./.dotnet-daily diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicBuffers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicBuffers.cs index 594245a1cb723..ba4a5a448b019 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicBuffers.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicBuffers.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Diagnostics; using System.Runtime.InteropServices; using Microsoft.Quic; @@ -32,8 +33,8 @@ private void FreeNativeMemory() { QUIC_BUFFER* buffers = _buffers; _buffers = null; - NativeMemory.Free(buffers); _count = 0; + NativeMemory.Free(buffers); } private void Reserve(int count) @@ -48,6 +49,10 @@ private void Reserve(int count) private void SetBuffer(int index, ReadOnlyMemory buffer) { + Debug.Assert(index < _count); + Debug.Assert(_buffers[index].Buffer is null); + Debug.Assert(_buffers[index].Length == 0); + _buffers[index].Buffer = (byte*)NativeMemory.Alloc((nuint)buffer.Length, (nuint)sizeof(byte)); _buffers[index].Length = (uint)buffer.Length; buffer.Span.CopyTo(_buffers[index].Span); @@ -93,8 +98,8 @@ public void Reset() } byte* buffer = _buffers[i].Buffer; _buffers[i].Buffer = null; - NativeMemory.Free(buffer); _buffers[i].Length = 0; + NativeMemory.Free(buffer); } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs index 337884c61a5d3..1c3b4872df163 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicConfiguration.cs @@ -7,7 +7,6 @@ using System.Security.Cryptography.X509Certificates; using System.Threading; using Microsoft.Quic; -using static Microsoft.Quic.MsQuic; namespace System.Net.Quic; diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicExtensions.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicExtensions.cs index a3d7bc6f3f7d3..5c079f6528741 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicExtensions.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicExtensions.cs @@ -66,6 +66,8 @@ public override string ToString() => $"{{ {nameof(SEND_SHUTDOWN_COMPLETE.Graceful)} = {SEND_SHUTDOWN_COMPLETE.Graceful} }}", QUIC_STREAM_EVENT_TYPE.SHUTDOWN_COMPLETE => $"{{ {nameof(SHUTDOWN_COMPLETE.ConnectionShutdown)} = {SHUTDOWN_COMPLETE.ConnectionShutdown}, {nameof(SHUTDOWN_COMPLETE.ConnectionShutdownByApp)} = {SHUTDOWN_COMPLETE.ConnectionShutdownByApp}, {nameof(SHUTDOWN_COMPLETE.ConnectionClosedRemotely)} = {SHUTDOWN_COMPLETE.ConnectionClosedRemotely}, {nameof(SHUTDOWN_COMPLETE.ConnectionErrorCode)} = {SHUTDOWN_COMPLETE.ConnectionErrorCode}, {nameof(SHUTDOWN_COMPLETE.ConnectionCloseStatus)} = {SHUTDOWN_COMPLETE.ConnectionCloseStatus} }}", + QUIC_STREAM_EVENT_TYPE.IDEAL_SEND_BUFFER_SIZE + => $"{{ {nameof(IDEAL_SEND_BUFFER_SIZE.ByteCount)} = {IDEAL_SEND_BUFFER_SIZE.ByteCount} }}", _ => string.Empty }; } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicTlsSecret.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicTlsSecret.cs index 4c45abe7acd92..ad2b3a87ccf00 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicTlsSecret.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicTlsSecret.cs @@ -2,11 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. #if DEBUG -using System.Collections.Generic; using System.IO; using System.Runtime.InteropServices; using System.Text; -using System.Threading; using Microsoft.Quic; using static Microsoft.Quic.MsQuic; diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ReceiveBuffers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ReceiveBuffers.cs index 531ac0171ca07..93f78acc87f02 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ReceiveBuffers.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ReceiveBuffers.cs @@ -66,7 +66,7 @@ public int CopyFrom(ReadOnlySpan quicBuffers, int totalLength, bool } } - public int CopyTo(Memory buffer, out bool isCompleted, out bool isEmpty) + public int CopyTo(Memory buffer, out bool completed, out bool empty) { lock (_syncRoot) { @@ -79,8 +79,8 @@ public int CopyTo(Memory buffer, out bool isCompleted, out bool isEmpty) _buffer.Discard(copied); } - isCompleted = _buffer.IsEmpty && _final; - isEmpty = _buffer.IsEmpty; + completed = _buffer.IsEmpty && _final; + empty = _buffer.IsEmpty; return copied; } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ResettableValueTaskSource.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ResettableValueTaskSource.cs index e7c0cf87bfd5d..c3135042b032b 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ResettableValueTaskSource.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ResettableValueTaskSource.cs @@ -26,22 +26,23 @@ private enum State } private State _state; + private bool _hasWaiter; private ManualResetValueTaskSourceCore _valueTaskSource; private CancellationTokenRegistration _cancellationRegistration; + private CancellationToken _cancelledToken; private Action? _cancellationAction; private GCHandle _keepAlive; + private FinalTaskSource _finalTaskSource; - private readonly TaskCompletionSource _finalTaskSource; - - public ResettableValueTaskSource(bool runContinuationsAsynchronously = true) + public ResettableValueTaskSource() { _state = State.None; - _valueTaskSource = new ManualResetValueTaskSourceCore() { RunContinuationsAsynchronously = runContinuationsAsynchronously }; + _hasWaiter = false; + _valueTaskSource = new ManualResetValueTaskSourceCore() { RunContinuationsAsynchronously = true }; _cancellationRegistration = default; + _cancelledToken = default; _keepAlive = default; - - // TODO: defer instantiation only after Task is retrieved - _finalTaskSource = new TaskCompletionSource(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None); + _finalTaskSource = new FinalTaskSource(); } /// @@ -56,18 +57,6 @@ public ResettableValueTaskSource(bool runContinuationsAsynchronously = true) /// public bool IsCompleted => (State)Volatile.Read(ref Unsafe.As(ref _state)) == State.Completed; - // TODO: Revisit this with https://github.com/dotnet/runtime/issues/79818 and https://github.com/dotnet/runtime/issues/79911 - public bool KeepAliveReleased - { - get - { - lock (this) - { - return !_keepAlive.IsAllocated; - } - } - } - /// /// Tries to get a value task representing this task source. If this task source is , it'll also transition it into state. /// It prevents concurrent operations from being invoked since it'll return false if the task source was already in state. @@ -91,11 +80,11 @@ public bool TryGetValueTask(out ValueTask valueTask, object? keepAlive = null, C _cancellationRegistration = cancellationToken.UnsafeRegister(static (obj, cancellationToken) => { (ResettableValueTaskSource thisRef, object? target) = ((ResettableValueTaskSource, object?))obj!; - // This will transition the state to Ready. - if (thisRef.TrySetException(new OperationCanceledException(cancellationToken))) + lock (thisRef) { - thisRef._cancellationAction?.Invoke(target); + thisRef._cancelledToken = cancellationToken; } + thisRef._cancellationAction?.Invoke(target); }, (this, keepAlive)); } } @@ -115,11 +104,13 @@ public bool TryGetValueTask(out ValueTask valueTask, object? keepAlive = null, C _state = State.Awaiting; } - // None, Completed, Final: return the current task. + // None, Ready, Completed: return the current task. if (state == State.None || state == State.Ready || state == State.Completed) { + // Remember that the value task with the current version is being given out. + _hasWaiter = true; valueTask = new ValueTask(this, _valueTaskSource.Version); return true; } @@ -130,84 +121,102 @@ public bool TryGetValueTask(out ValueTask valueTask, object? keepAlive = null, C } } - public Task GetFinalTask() => _finalTaskSource.Task; + /// + /// Gets a that will transition to a completed state with the last transition of this source, i.e. into . + /// + /// The that will transition to a completed state with the last transition of this source. + public Task GetFinalTask(object? keepAlive) + { + lock (this) + { + return _finalTaskSource.GetTask(keepAlive); + } + } private bool TryComplete(Exception? exception, bool final) { + // Dispose the cancellation registration before completing the task, so that it cannot run after the awaiting method returned. + // Dispose must be done outside of lock since it will wait on pending cancellation callbacks that can hold the lock from another thread. CancellationTokenRegistration cancellationRegistration = default; - try + lock (this) { - lock (this) + cancellationRegistration = _cancellationRegistration; + _cancellationRegistration = default; + } + cancellationRegistration.Dispose(); + + lock (this) + { + try { - try + State state = _state; + + // Completed: nothing to do. + if (state == State.Completed) { - State state = _state; + return false; + } - // Completed: nothing to do. - if (state == State.Completed) - { - return false; - } + // The task was non-finally completed without having anyone awaiting on it. + // In such case, discard the temporary result and replace it with this final completion. + if (state == State.Ready && !_hasWaiter && final) + { + _valueTaskSource.Reset(); + state = State.None; + } - // If the _valueTaskSource has already been set, we don't want to lose the result by overwriting it. - // So keep it as is and store the result in _finalTaskSource. + // If the _valueTaskSource has already been set, we don't want to lose the result by overwriting it. + // So keep it as is and store the result in _finalTaskSource. + if (state == State.None || + state == State.Awaiting) + { + _state = final ? State.Completed : State.Ready; + } + + // Unblock the current task source and in case of a final also the final task source. + if (exception is not null) + { + // Set up the exception stack trace for the caller. + exception = exception.StackTrace is null ? ExceptionDispatchInfo.SetCurrentStackTrace(exception) : exception; if (state == State.None || state == State.Awaiting) { - _state = final ? State.Completed : State.Ready; + _valueTaskSource.SetException(exception); } - - // Swap the cancellation registration so the one that's been registered gets eventually Disposed. - // Ideally, we would dispose it here, but if the callbacks kicks in, it tries to take the lock held by this thread leading to deadlock. - cancellationRegistration = _cancellationRegistration; - _cancellationRegistration = default; - - // Unblock the current task source and in case of a final also the final task source. - if (exception is not null) + } + else + { + if (state == State.None || + state == State.Awaiting) { - // Set up the exception stack trace for the caller. - exception = exception.StackTrace is null ? ExceptionDispatchInfo.SetCurrentStackTrace(exception) : exception; - if (state == State.None || - state == State.Awaiting) - { - _valueTaskSource.SetException(exception); - } - if (final) - { - return _finalTaskSource.TrySetException(exception); - } - return state != State.Ready; + _valueTaskSource.SetResult(final); } - else + } + if (final) + { + if (_finalTaskSource.TryComplete(exception)) { - if (state == State.None || - state == State.Awaiting) + // Signal the final task only if we don't have another result in the value task source. + // In that case, the final task will be signalled after the value task result is retrieved. + if (state != State.Ready) { - _valueTaskSource.SetResult(final); + _finalTaskSource.TrySignal(out _); } - if (final) - { - return _finalTaskSource.TrySetResult(); - } - return state != State.Ready; + return true; } + return false; } - finally + return state != State.Ready; + } + finally + { + // Un-root the the kept alive object in all cases. + if (_keepAlive.IsAllocated) { - // Un-root the the kept alive object in all cases. - if (_keepAlive.IsAllocated) - { - _keepAlive.Free(); - } + _keepAlive.Free(); } } } - finally - { - // Dispose the cancellation if registered. - // Must be done outside of lock since Dispose will wait on pending cancellation callbacks which require taking the lock. - cancellationRegistration.Dispose(); - } } /// @@ -241,11 +250,10 @@ void IValueTaskSource.OnCompleted(Action continuation, object? state, s void IValueTaskSource.GetResult(short token) { - bool successful = false; try { + _cancelledToken.ThrowIfCancellationRequested(); _valueTaskSource.GetResult(token); - successful = true; } finally { @@ -253,34 +261,109 @@ void IValueTaskSource.GetResult(short token) { State state = _state; + _hasWaiter = false; + _cancelledToken = default; + if (state == State.Ready) { _valueTaskSource.Reset(); _state = State.None; // Propagate the _finalTaskSource result into _valueTaskSource if completed. - if (_finalTaskSource.Task.IsCompleted) + if (_finalTaskSource.TrySignal(out Exception? exception)) { _state = State.Completed; - if (_finalTaskSource.Task.IsCompletedSuccessfully) + + if (exception is not null) { - _valueTaskSource.SetResult(true); + _valueTaskSource.SetException(exception); } else { - // We know it's always going to be a single exception since we're the ones setting it. - _valueTaskSource.SetException(_finalTaskSource.Task.Exception?.InnerException!); - } - - // In case the _valueTaskSource was successful, we want the potential error from _finalTaskSource to surface immediately. - // In other words, if _valueTaskSource was set with success while final exception arrived, this will throw that exception right away. - if (successful) - { - _valueTaskSource.GetResult(_valueTaskSource.Version); + _valueTaskSource.SetResult(true); } } + else + { + _state = State.None; + } + } + } + } + } + + /// + /// It remembers the result from and propagates it to only after is called. + /// Effectively allowing to separate setting of the result from task completion, which is necessary when the resettable portion of the value task source needs to consumed first. + /// + private struct FinalTaskSource + { + private TaskCompletionSource? _finalTaskSource; + private bool _isCompleted; + private bool _isSignaled; + private Exception? _exception; + + public FinalTaskSource() + { + _finalTaskSource = null; + _isCompleted = false; + _isSignaled = false; + _exception = null; + } + + public Task GetTask(object? keepAlive) + { + if (_finalTaskSource is null) + { + _finalTaskSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + if (!_isCompleted) + { + GCHandle handle = GCHandle.Alloc(keepAlive); + _finalTaskSource.Task.ContinueWith(static (_, state) => + { + ((GCHandle)state!).Free(); + }, handle, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default); } + if (_isSignaled) + { + TrySignal(out _); + } + } + return _finalTaskSource.Task; + } + + public bool TryComplete(Exception? exception = null) + { + if (_isCompleted) + { + return false; + } + + _exception = exception; + _isCompleted = true; + return true; + } + + public bool TrySignal(out Exception? exception) + { + if (!_isCompleted) + { + exception = default; + return false; } + + if (_exception is not null) + { + _finalTaskSource?.SetException(_exception); + } + else + { + _finalTaskSource?.SetResult(); + } + + exception = _exception; + _isSignaled = true; + return true; } } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ThrowHelper.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ThrowHelper.cs index ec677f9f4e58d..114c39c49c1e5 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ThrowHelper.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ThrowHelper.cs @@ -1,12 +1,12 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Security.Authentication; +using System.Diagnostics.CodeAnalysis; using System.Net.Security; using System.Net.Sockets; -using static Microsoft.Quic.MsQuic; -using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using System.Security.Authentication; +using static Microsoft.Quic.MsQuic; namespace System.Net.Quic; diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ValueTaskSource.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ValueTaskSource.cs index a6e40dbf7ea8a..2acd2138a1237 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ValueTaskSource.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ValueTaskSource.cs @@ -27,10 +27,10 @@ private enum State : byte private CancellationTokenRegistration _cancellationRegistration; private GCHandle _keepAlive; - public ValueTaskSource(bool runContinuationsAsynchronously = true) + public ValueTaskSource() { _state = State.None; - _valueTaskSource = new ManualResetValueTaskSourceCore() { RunContinuationsAsynchronously = runContinuationsAsynchronously }; + _valueTaskSource = new ManualResetValueTaskSourceCore() { RunContinuationsAsynchronously = true }; _cancellationRegistration = default; _keepAlive = default; } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs index a2ade033afe59..f1117a5fa7cd7 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs @@ -557,7 +557,7 @@ private static unsafe int NativeCallback(QUIC_HANDLE* connection, void* context, { if (NetEventSource.Log.IsEnabled()) { - NetEventSource.Error(null, $"Received event {connectionEvent->Type} while connection is already disposed"); + NetEventSource.Error(null, $"Received event {connectionEvent->Type} for [conn][{(nint)connection:X11}] while connection is already disposed"); } return QUIC_STATUS_INVALID_STATE; } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs index fcbfba56336ac..8a9eb59d3f178 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Diagnostics; using System.Net.Security; using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; @@ -329,7 +328,7 @@ private static unsafe int NativeCallback(QUIC_HANDLE* listener, void* context, Q { if (NetEventSource.Log.IsEnabled()) { - NetEventSource.Error(null, $"Received event {listenerEvent->Type} while listener is already disposed"); + NetEventSource.Error(null, $"Received event {listenerEvent->Type} for [list][{(nint)listener:X11}] while listener is already disposed"); } return QUIC_STATUS_INVALID_STATE; } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs index 6165f2085cb5f..bc3783d5253d9 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs @@ -76,6 +76,7 @@ public sealed partial class QuicStream if (target is QuicStream stream) { stream.Abort(QuicAbortDirection.Read, stream._defaultErrorCode); + stream._receiveTcs.TrySetResult(); } } catch (ObjectDisposedException) @@ -109,7 +110,8 @@ public sealed partial class QuicStream } }; private MsQuicBuffers _sendBuffers = new MsQuicBuffers(); - private readonly object _sendBuffersLock = new object(); + private int _sendLocked; + private Exception? _sendException; private readonly long _defaultErrorCode; @@ -135,7 +137,7 @@ public sealed partial class QuicStream /// or when for is called, /// or when the peer called for . /// - public Task ReadsClosed => _receiveTcs.GetFinalTask(); + public Task ReadsClosed => _receiveTcs.GetFinalTask(this); /// /// A that will get completed once writing side has been closed. @@ -144,7 +146,7 @@ public sealed partial class QuicStream /// or when for is called, /// or when the peer called for . /// - public Task WritesClosed => _sendTcs.GetFinalTask(); + public Task WritesClosed => _sendTcs.GetFinalTask(this); /// public override string ToString() => _handle.ToString(); @@ -334,7 +336,7 @@ public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationTo /// The region of memory to write data from. /// The token to monitor for cancellation requests. The default value is . /// Notifies the peer about gracefully closing the write side, i.e.: sends FIN flag with the data. - public ValueTask WriteAsync(ReadOnlyMemory buffer, bool completeWrites, CancellationToken cancellationToken = default) + public async ValueTask WriteAsync(ReadOnlyMemory buffer, bool completeWrites, CancellationToken cancellationToken = default) { ObjectDisposedException.ThrowIf(_disposed == 1, this); @@ -348,11 +350,11 @@ public ValueTask WriteAsync(ReadOnlyMemory buffer, bool completeWrites, Ca NetEventSource.Info(this, $"{this} Stream writing memory of '{buffer.Length}' bytes while {(completeWrites ? "completing" : "not completing")} writes."); } - if (_sendTcs.IsCompleted && cancellationToken.IsCancellationRequested) + if (_sendTcs.IsCompleted) { // Special case exception type for pre-canceled token while we've already transitioned to a final state and don't need to abort write. // It must happen before we try to get the value task, since the task source is versioned and each instance must be awaited. - return ValueTask.FromCanceled(cancellationToken); + cancellationToken.ThrowIfCancellationRequested(); } // Concurrent call, this one lost the race. @@ -364,7 +366,8 @@ public ValueTask WriteAsync(ReadOnlyMemory buffer, bool completeWrites, Ca // No need to call anything since we already have a result, most likely an exception. if (valueTask.IsCompleted) { - return valueTask; + await valueTask.ConfigureAwait(false); + return; } // For an empty buffer complete immediately, close the writing side of the stream if necessary. @@ -375,25 +378,15 @@ public ValueTask WriteAsync(ReadOnlyMemory buffer, bool completeWrites, Ca { CompleteWrites(); } - return valueTask; + await valueTask.ConfigureAwait(false); + return; } - lock (_sendBuffersLock) + // We own the lock, abort might happen, but exception will get stored instead. + if (Interlocked.CompareExchange(ref _sendLocked, 1, 0) == 0) { - ObjectDisposedException.ThrowIf(_disposed == 1, this); // TODO: valueTask is left unobserved unsafe { - if (_sendBuffers.Count > 0 && _sendBuffers.Buffers[0].Buffer != null) - { - // _sendBuffers are not reset, meaning SendComplete for the previous WriteAsync call didn't arrive yet. - // In case of cancellation, the task from _sendTcs is finished before the aborting. It is technically possible for subsequent - // WriteAsync to grab the next task from _sendTcs and start executing before SendComplete event occurs for the previous (canceled) write. - // This is not an "invalid nested call", because the previous task has finished. Best guess is to mimic OperationAborted as it will be from Abort - // that would execute soon enough, if not already. Not final, because Abort should be the one to set final exception. - _sendTcs.TrySetException(ThrowHelper.GetOperationAbortedException(SR.net_quic_writing_aborted), final: false); - return valueTask; - } - _sendBuffers.Initialize(buffer); int status = MsQuicApi.Api.StreamSend( _handle, @@ -401,15 +394,28 @@ public ValueTask WriteAsync(ReadOnlyMemory buffer, bool completeWrites, Ca (uint)_sendBuffers.Count, completeWrites ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE, null); - if (ThrowHelper.TryGetStreamExceptionForMsQuicStatus(status, out Exception? exception)) + // No SEND_COMPLETE expected, release buffer and unlock. + if (StatusFailed(status)) { _sendBuffers.Reset(); - _sendTcs.TrySetException(exception, final: true); + Volatile.Write(ref _sendLocked, 0); + + // There might be stored exception from when we held the lock. + if (ThrowHelper.TryGetStreamExceptionForMsQuicStatus(status, out Exception? exception)) + { + Interlocked.CompareExchange(ref _sendException, exception, null); + } + exception = Volatile.Read(ref _sendException); + if (exception is not null) + { + _sendTcs.TrySetException(exception, final: true); + } } + // SEND_COMPLETE expected, buffer and lock will be released then. } } - return valueTask; + await valueTask.ConfigureAwait(false); } /// @@ -429,19 +435,13 @@ public void Abort(QuicAbortDirection abortDirection, long errorCode) } QUIC_STREAM_SHUTDOWN_FLAGS flags = QUIC_STREAM_SHUTDOWN_FLAGS.NONE; - if (abortDirection.HasFlag(QuicAbortDirection.Read)) + if (abortDirection.HasFlag(QuicAbortDirection.Read) && !_receiveTcs.IsCompleted) { - if (_receiveTcs.TrySetException(ThrowHelper.GetOperationAbortedException(SR.net_quic_reading_aborted), final: true)) - { - flags |= QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE; - } + flags |= QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE; } - if (abortDirection.HasFlag(QuicAbortDirection.Write)) + if (abortDirection.HasFlag(QuicAbortDirection.Write) && !_sendTcs.IsCompleted) { - if (_sendTcs.TrySetException(ThrowHelper.GetOperationAbortedException(SR.net_quic_writing_aborted), final: true)) - { - flags |= QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND; - } + flags |= QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND; } // Nothing to abort, the requested sides to abort are already closed. if (flags == QUIC_STREAM_SHUTDOWN_FLAGS.NONE) @@ -453,7 +453,6 @@ public void Abort(QuicAbortDirection abortDirection, long errorCode) { NetEventSource.Info(this, $"{this} Aborting {abortDirection} with {errorCode}"); } - unsafe { ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.StreamShutdown( @@ -462,6 +461,21 @@ public void Abort(QuicAbortDirection abortDirection, long errorCode) (ulong)errorCode), "StreamShutdown failed"); } + + if (abortDirection.HasFlag(QuicAbortDirection.Read)) + { + _receiveTcs.TrySetException(ThrowHelper.GetOperationAbortedException(SR.net_quic_reading_aborted), final: true); + } + if (abortDirection.HasFlag(QuicAbortDirection.Write)) + { + var exception = ThrowHelper.GetOperationAbortedException(SR.net_quic_writing_aborted); + Interlocked.CompareExchange(ref _sendException, exception, null); + if (Interlocked.CompareExchange(ref _sendLocked, 1, 0) == 0) + { + _sendTcs.TrySetException(_sendException, final: true); + Volatile.Write(ref _sendLocked, 0); + } + } } /// @@ -475,16 +489,23 @@ public void CompleteWrites() { ObjectDisposedException.ThrowIf(_disposed == 1, this); - if (_shutdownTcs.TryInitialize(out _, this)) + // Nothing to complete, the writing side is already closed. + if (_sendTcs.IsCompleted) { - unsafe - { - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.StreamShutdown( - _handle, - QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, - default), - "StreamShutdown failed"); - } + return; + } + + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Info(this, $"{this} Completing writes."); + } + unsafe + { + ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.StreamShutdown( + _handle, + QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, + default), + "StreamShutdown failed"); } } @@ -528,11 +549,15 @@ private unsafe int HandleEventReceive(ref RECEIVE_DATA data) } private unsafe int HandleEventSendComplete(ref SEND_COMPLETE_DATA data) { - // In case of cancellation, the task from _sendTcs is finished before the aborting. It is technically possible for subsequent WriteAsync to grab the next task - // from _sendTcs and start executing before SendComplete event occurs for the previous (canceled) write - lock (_sendBuffersLock) + // Release buffer and unlock. + _sendBuffers.Reset(); + Volatile.Write(ref _sendLocked, 0); + + // There might be stored exception from when we held the lock. + Exception? exception = Volatile.Read(ref _sendException); + if (exception is not null) { - _sendBuffers.Reset(); + _sendTcs.TrySetException(exception, final: true); } if (data.Canceled == 0) { @@ -616,7 +641,7 @@ private unsafe int HandleStreamEvent(ref QUIC_STREAM_EVENT streamEvent) #pragma warning disable CS3016 [UnmanagedCallersOnly(CallConvs = new Type[] { typeof(CallConvCdecl) })] #pragma warning restore CS3016 - private static unsafe int NativeCallback(QUIC_HANDLE* connection, void* context, QUIC_STREAM_EVENT* streamEvent) + private static unsafe int NativeCallback(QUIC_HANDLE* stream, void* context, QUIC_STREAM_EVENT* streamEvent) { GCHandle stateHandle = GCHandle.FromIntPtr((IntPtr)context); @@ -625,7 +650,7 @@ private static unsafe int NativeCallback(QUIC_HANDLE* connection, void* context, { if (NetEventSource.Log.IsEnabled()) { - NetEventSource.Error(null, $"Received event {streamEvent->Type} while connection is already disposed"); + NetEventSource.Error(null, $"Received event {streamEvent->Type} for [strm][{(nint)stream:X11}] while stream is already disposed"); } return QUIC_STATUS_INVALID_STATE; } @@ -663,44 +688,37 @@ public override async ValueTask DisposeAsync() return; } - ValueTask valueTask; - // If the stream wasn't started successfully, gracelessly abort it. if (!_startedTcs.IsCompletedSuccessfully) { // Check if the stream has been shut down and if not, shut it down. - if (_shutdownTcs.TryInitialize(out valueTask, this)) - { - StreamShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT | QUIC_STREAM_SHUTDOWN_FLAGS.IMMEDIATE, _defaultErrorCode); - } + StreamShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT | QUIC_STREAM_SHUTDOWN_FLAGS.IMMEDIATE, _defaultErrorCode); } else { - // Abort the read side of the stream if it hasn't been fully consumed. - if (_receiveTcs.TrySetException(ThrowHelper.GetOperationAbortedException(), final: true)) + // Abort the read side and complete the write side if that side hasn't been completed yet. + if (!_receiveTcs.IsCompleted) { StreamShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE, _defaultErrorCode); } - // Check if the stream has been shut down and if not, shut it down. - if (_shutdownTcs.TryInitialize(out valueTask, this)) + if (!_sendTcs.IsCompleted) { StreamShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, default); } } // Wait for SHUTDOWN_COMPLETE, the last event, so that all resources can be safely released. - await valueTask.ConfigureAwait(false); + if (_shutdownTcs.TryInitialize(out ValueTask valueTask, this)) + { + await valueTask.ConfigureAwait(false); + } Debug.Assert(_startedTcs.IsCompleted); - // TODO: Revisit this with https://github.com/dotnet/runtime/issues/79818 and https://github.com/dotnet/runtime/issues/79911 - Debug.Assert(_receiveTcs.KeepAliveReleased); - Debug.Assert(_sendTcs.KeepAliveReleased); + Debug.Assert(_receiveTcs.IsCompleted); + Debug.Assert(_sendTcs.IsCompleted); _handle.Dispose(); - lock (_sendBuffersLock) - { - // TODO: memory leak if not disposed - _sendBuffers.Dispose(); - } + // TODO: memory leak if not disposed + _sendBuffers.Dispose(); unsafe void StreamShutdown(QUIC_STREAM_SHUTDOWN_FLAGS flags, long errorCode) { @@ -715,6 +733,17 @@ unsafe void StreamShutdown(QUIC_STREAM_SHUTDOWN_FLAGS flags, long errorCode) NetEventSource.Error(this, $"{this} StreamShutdown({flags}) failed: {ThrowHelper.GetErrorMessageForStatus(status)}."); } } + else + { + if (flags.HasFlag(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_RECEIVE) && !_receiveTcs.IsCompleted) + { + _receiveTcs.TrySetException(ThrowHelper.GetOperationAbortedException(SR.net_quic_reading_aborted), final: true); + } + if (flags.HasFlag(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND) && !_sendTcs.IsCompleted) + { + _sendTcs.TrySetException(ThrowHelper.GetOperationAbortedException(SR.net_quic_writing_aborted), final: true); + } + } } } } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs index e220decb1bd40..cd3c1a2394f38 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs @@ -1208,5 +1208,310 @@ async ValueTask ReleaseOnReadsClosedAsync() } ); } + + private const int SmallestPayload = 1; + private const int SmallPayload = 1024; + private const int BufferPayload = 64*1024; + private const int BufferPlusPayload = 64*1024+1; + private const int BigPayload = 1024*1024*1024; + + public static IEnumerable PayloadSizeAndTwoBools() + { + var boolValues = new [] { true, false }; + var payloadValues = !PlatformDetection.IsInHelix ? + new [] { SmallestPayload, SmallPayload, BufferPayload, BufferPlusPayload, BigPayload } : + new [] { SmallestPayload, SmallPayload, BufferPayload, BufferPlusPayload }; + return + from payload in payloadValues + from bool1 in boolValues + from bool2 in boolValues + select new object[] { payload, bool1, bool2 }; + } + + [Theory] + [MemberData(nameof(PayloadSizeAndTwoBools))] + public async Task ReadsClosedFinishes_ConnectionClose(int payloadSize, bool closeServer, bool useDispose) + { + using SemaphoreSlim serverSem = new SemaphoreSlim(0); + using SemaphoreSlim clientSem = new SemaphoreSlim(0); + + await RunClientServer( + serverFunction: async connection => + { + QuicError expectedError = QuicError.ConnectionAborted; + long expectedErrorCode = DefaultCloseErrorCodeClient; + + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); + await stream.WriteAsync(new byte[payloadSize], completeWrites: true); + // Make sure the data gets received by the peer if we expect the reading side to get buffered including FIN. + if (payloadSize <= BufferPayload) + { + await stream.WritesClosed; + } + serverSem.Release(); + await clientSem.WaitAsync(); + + if (closeServer) + { + expectedError = QuicError.OperationAborted; + expectedErrorCode = DefaultCloseErrorCodeServer; + if (useDispose) + { + await connection.DisposeAsync(); + } + else + { + await connection.CloseAsync(DefaultCloseErrorCodeServer); + } + } + + await CheckReadsClosed(stream, expectedError, expectedErrorCode); + }, + clientFunction: async connection => + { + QuicError expectedError = QuicError.ConnectionAborted; + long expectedErrorCode = DefaultCloseErrorCodeServer; + + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + await stream.WriteAsync(new byte[payloadSize], completeWrites: true); + if (payloadSize <= BufferPayload) + { + await stream.WritesClosed; + } + clientSem.Release(); + await serverSem.WaitAsync(); + + if (!closeServer) + { + expectedError = QuicError.OperationAborted; + expectedErrorCode = DefaultCloseErrorCodeClient; + if (useDispose) + { + await connection.DisposeAsync(); + } + else + { + await connection.CloseAsync(DefaultCloseErrorCodeClient); + } + } + + await CheckReadsClosed(stream, expectedError, expectedErrorCode); + } + ); + + async ValueTask CheckReadsClosed(QuicStream stream, QuicError expectedError, long expectedErrorCode) + { + // All data should be buffered if they fit in the internal buffer, reading should still pass. + if (payloadSize <= BufferPayload) + { + Assert.False(stream.ReadsClosed.IsCompleted); + var buffer = new byte[BufferPayload]; + var length = await ReadAll(stream, buffer); + Assert.True(stream.ReadsClosed.IsCompletedSuccessfully); + Assert.Equal(payloadSize, length); + } + else + { + var ex = await AssertThrowsQuicExceptionAsync(expectedError, () => stream.ReadsClosed); + if (expectedError == QuicError.OperationAborted) + { + Assert.Null(ex.ApplicationErrorCode); + } + else + { + Assert.Equal(expectedErrorCode, ex.ApplicationErrorCode); + } + } + } + } + + [Theory] + [MemberData(nameof(PayloadSizeAndTwoBools))] + public async Task WritesClosedFinishes_ConnectionClose(int payloadSize, bool closeServer, bool useDispose) + { + using SemaphoreSlim serverSem = new SemaphoreSlim(0); + using SemaphoreSlim clientSem = new SemaphoreSlim(0); + + await RunClientServer( + serverFunction: async connection => + { + QuicError expectedError = QuicError.ConnectionAborted; + long expectedErrorCode = DefaultCloseErrorCodeClient; + + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); + await stream.WriteAsync(new byte[payloadSize]); + serverSem.Release(); + await clientSem.WaitAsync(); + + if (closeServer) + { + expectedError = QuicError.OperationAborted; + expectedErrorCode = DefaultCloseErrorCodeServer; + if (useDispose) + { + await connection.DisposeAsync(); + } + else + { + await connection.CloseAsync(DefaultCloseErrorCodeServer); + } + } + + await CheckWritesClosed(stream, expectedError, expectedErrorCode); + }, + clientFunction: async connection => + { + QuicError expectedError = QuicError.ConnectionAborted; + long expectedErrorCode = DefaultCloseErrorCodeServer; + + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + await stream.WriteAsync(new byte[payloadSize]); + clientSem.Release(); + await serverSem.WaitAsync(); + + if (!closeServer) + { + expectedError = QuicError.OperationAborted; + expectedErrorCode = DefaultCloseErrorCodeClient; + if (useDispose) + { + await connection.DisposeAsync(); + } + else + { + await connection.CloseAsync(DefaultCloseErrorCodeClient); + } + } + + await CheckWritesClosed(stream, expectedError, expectedErrorCode); + } + ); + + async ValueTask CheckWritesClosed(QuicStream stream, QuicError expectedError, long expectedErrorCode) + { + var ex = await AssertThrowsQuicExceptionAsync(expectedError, () => stream.WritesClosed); + if (expectedError == QuicError.OperationAborted) + { + Assert.Null(ex.ApplicationErrorCode); + } + else + { + Assert.Equal(expectedErrorCode, ex.ApplicationErrorCode); + } + } + } + + [Theory] + [MemberData(nameof(PayloadSizeAndTwoBools))] + public async Task ReadsWritesClosedFinish_StreamDisposed(int payloadSize, bool disposeServer, bool completeWrites) + { + using SemaphoreSlim serverSem = new SemaphoreSlim(0); + using SemaphoreSlim clientSem = new SemaphoreSlim(0); + TaskCompletionSource tcs = new TaskCompletionSource(); + + await RunClientServer( + serverFunction: async connection => + { + // Establish stream, send the payload based on the input and synchronize with the peer. + await using QuicStream stream = await connection.AcceptInboundStreamAsync(); + await stream.WriteAsync(new byte[payloadSize], completeWrites); + serverSem.Release(); + await clientSem.WaitAsync(); + + if (disposeServer) + { + await DisposeSide(stream, tcs); + } + else + { + await WaitingSide(stream, tcs.Task, DefaultStreamErrorCodeClient); + } + }, + clientFunction: async connection => + { + // Establish stream, send the payload based on the input and synchronize with the peer. + await using QuicStream stream = await connection.OpenOutboundStreamAsync(QuicStreamType.Bidirectional); + await stream.WriteAsync(new byte[payloadSize], completeWrites); + clientSem.Release(); + await serverSem.WaitAsync(); + + if (disposeServer) + { + await WaitingSide(stream, tcs.Task, DefaultStreamErrorCodeServer); + } + else + { + await DisposeSide(stream, tcs); + } + }); + + async ValueTask DisposeSide(QuicStream stream, TaskCompletionSource tcs) + { + // Abort writing side if it's getting blocked by peer not consuming the data. + long? abortCode = null; + if (completeWrites || payloadSize >= BigPayload) + { + try + { + await stream.WritesClosed.WaitAsync(TimeSpan.FromSeconds(2.5)); + } + catch (TimeoutException) + { + Assert.True(payloadSize >= BigPayload); + abortCode = 0xABC; + stream.Abort(QuicAbortDirection.Write, abortCode.Value); + } + } + + await stream.DisposeAsync(); + + // Reads should be aborted as we didn't consume the data. + var readEx = await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, () => stream.ReadsClosed); + Assert.Null(readEx.ApplicationErrorCode); + + // Writes should be aborted as we aborted them. + if (abortCode.HasValue) + { + var writeEx = await AssertThrowsQuicExceptionAsync(QuicError.OperationAborted, () => stream.WritesClosed); + Assert.Null(writeEx.ApplicationErrorCode); + } + else + { + // Writes should be completed successfully as they should all fit in the peers buffers. + Assert.True(stream.WritesClosed.IsCompletedSuccessfully); + } + + tcs.SetResult(abortCode); + } + async ValueTask WaitingSide(QuicStream stream, Task task, long errorCode) + { + long? abortCode = await task; + + // Reads will be aborted by the peer as we didn't consume them all. + if (abortCode.HasValue) + { + var readEx = await AssertThrowsQuicExceptionAsync(QuicError.StreamAborted, () => stream.ReadsClosed); + Assert.Equal(abortCode.Value, readEx.ApplicationErrorCode); + } + // Reads should be still open as the peer closed gracefully and we are keeping the data in buffer. + else + { + Assert.False(stream.ReadsClosed.IsCompleted); + } + + if (!completeWrites) + { + // Writes must be aborted by the peer as we didn't complete them. + var writeEx = await AssertThrowsQuicExceptionAsync(QuicError.StreamAborted, () => stream.WritesClosed); + Assert.Equal(errorCode, writeEx.ApplicationErrorCode); + } + else + { + // Writes must be closed, but whether successfully or not depends on the timing. + // Peer might have aborted reading side before receiving all the data. + Assert.True(stream.WritesClosed.IsCompleted); + } + } + } } }