Skip to content

Commit

Permalink
[QUIC] Stream write cancellation (#53304)
Browse files Browse the repository at this point in the history
Add tests to check write cancellation behavior, fix pre-cancelled writes and fix mock stream.
Add throwing on msquic returning write canceled status.

Fixes #32077
  • Loading branch information
CarnaViire authored Jun 4, 2021
1 parent 911640b commit e0671e7
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ internal sealed class MockStream : QuicStreamProvider
private readonly bool _isInitiator;

private readonly StreamState _streamState;
private bool _writesCanceled;

internal MockStream(StreamState streamState, bool isInitiator)
{
Expand Down Expand Up @@ -84,6 +85,10 @@ internal override async ValueTask<int> ReadAsync(Memory<byte> buffer, Cancellati
internal override void Write(ReadOnlySpan<byte> buffer)
{
CheckDisposed();
if (Volatile.Read(ref _writesCanceled))
{
throw new OperationCanceledException();
}

StreamBuffer? streamBuffer = WriteStreamBuffer;
if (streamBuffer is null)
Expand All @@ -102,13 +107,24 @@ internal override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, Cancellation
internal override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, bool endStream, CancellationToken cancellationToken = default)
{
CheckDisposed();
if (Volatile.Read(ref _writesCanceled))
{
cancellationToken.ThrowIfCancellationRequested();
throw new OperationCanceledException();
}

StreamBuffer? streamBuffer = WriteStreamBuffer;
if (streamBuffer is null)
{
throw new NotSupportedException();
}

using var registration = cancellationToken.UnsafeRegister(static s =>
{
var stream = (MockStream)s!;
Volatile.Write(ref stream._writesCanceled, true);
}, this);

await streamBuffer.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);

if (endStream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,18 +216,14 @@ private async ValueTask<CancellationTokenRegistration> HandleWriteStartState(Can
throw new InvalidOperationException(SR.net_quic_writing_notallowed);
}

lock (_state)
// Make sure start has completed
if (!_started)
{
if (_state.SendState == SendState.Aborted)
{
throw new OperationCanceledException(SR.net_quic_sending_aborted);
}
else if (_state.SendState == SendState.ConnectionClosed)
{
throw GetConnectionAbortedException(_state);
}
await _state.SendResettableCompletionSource.GetTypelessValueTask().ConfigureAwait(false);
_started = true;
}

// if token was already cancelled, this would execute syncronously
CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) =>
{
var state = (State)s!;
Expand All @@ -248,11 +244,17 @@ private async ValueTask<CancellationTokenRegistration> HandleWriteStartState(Can
}
}, _state);

// Make sure start has completed
if (!_started)
lock (_state)
{
await _state.SendResettableCompletionSource.GetTypelessValueTask().ConfigureAwait(false);
_started = true;
if (_state.SendState == SendState.Aborted)
{
cancellationToken.ThrowIfCancellationRequested();
throw new OperationCanceledException(SR.net_quic_sending_aborted);
}
else if (_state.SendState == SendState.ConnectionClosed)
{
throw GetConnectionAbortedException(_state);
}
}

return registration;
Expand All @@ -262,7 +264,7 @@ private void HandleWriteCompletedState()
{
lock (_state)
{
if (_state.SendState == SendState.Finished || _state.SendState == SendState.Aborted)
if (_state.SendState == SendState.Finished)
{
_state.SendState = SendState.None;
}
Expand Down Expand Up @@ -827,6 +829,9 @@ private static uint HandleEventPeerSendShutdown(State state)

private static uint HandleEventSendComplete(State state, ref StreamEvent evt)
{
StreamEventDataSendComplete sendCompleteEvent = evt.Data.SendComplete;
bool canceled = sendCompleteEvent.Canceled != 0;

bool complete = false;

lock (state)
Expand All @@ -836,13 +841,26 @@ private static uint HandleEventSendComplete(State state, ref StreamEvent evt)
state.SendState = SendState.Finished;
complete = true;
}

if (canceled)
{
state.SendState = SendState.Aborted;
}
}

if (complete)
{
CleanupSendState(state);
// TODO throw if a write was canceled.
state.SendResettableCompletionSource.Complete(MsQuicStatusCodes.Success);

if (!canceled)
{
state.SendResettableCompletionSource.Complete(MsQuicStatusCodes.Success);
}
else
{
state.SendResettableCompletionSource.CompleteException(
ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Write was canceled")));
}
}

return MsQuicStatusCodes.Success;
Expand Down
133 changes: 133 additions & 0 deletions src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

Expand Down Expand Up @@ -434,6 +435,138 @@ await Task.Run(async () =>
Assert.Equal(ExpectedErrorCode, ex.ErrorCode);
}).WaitAsync(TimeSpan.FromSeconds(15));
}

[ActiveIssue("https://github.com/dotnet/runtime/issues/53530")]
[Fact]
public async Task StreamAbortedWithoutWriting_ReadThrows()
{
long expectedErrorCode = 1234;

await RunClientServer(
clientFunction: async connection =>
{
await using QuicStream stream = connection.OpenUnidirectionalStream();
stream.AbortWrite(expectedErrorCode);
await stream.ShutdownCompleted();
},
serverFunction: async connection =>
{
await using QuicStream stream = await connection.AcceptStreamAsync();
byte[] buffer = new byte[1];
QuicStreamAbortedException ex = await Assert.ThrowsAsync<QuicStreamAbortedException>(() => ReadAll(stream, buffer));
Assert.Equal(expectedErrorCode, ex.ErrorCode);
await stream.ShutdownCompleted();
}
);
}

[Fact]
public async Task WritePreCanceled_Throws()
{
long expectedErrorCode = 1234;

await RunClientServer(
clientFunction: async connection =>
{
await using QuicStream stream = connection.OpenUnidirectionalStream();
CancellationTokenSource cts = new CancellationTokenSource();
cts.Cancel();
await Assert.ThrowsAsync<OperationCanceledException>(() => stream.WriteAsync(new byte[1], cts.Token).AsTask());
// next write would also throw
await Assert.ThrowsAsync<OperationCanceledException>(() => stream.WriteAsync(new byte[1]).AsTask());
// manual write abort is still required
stream.AbortWrite(expectedErrorCode);
await stream.ShutdownCompleted();
},
serverFunction: async connection =>
{
await using QuicStream stream = await connection.AcceptStreamAsync();
byte[] buffer = new byte[1024 * 1024];
// TODO: it should always throw QuicStreamAbortedException, but sometimes it does not https://github.com/dotnet/runtime/issues/53530
//QuicStreamAbortedException ex = await Assert.ThrowsAsync<QuicStreamAbortedException>(() => ReadAll(stream, buffer));
try
{
await ReadAll(stream, buffer);
}
catch (QuicStreamAbortedException) { }
await stream.ShutdownCompleted();
}
);
}

[Fact]
public async Task WriteCanceled_NextWriteThrows()
{
long expectedErrorCode = 1234;

await RunClientServer(
clientFunction: async connection =>
{
await using QuicStream stream = connection.OpenUnidirectionalStream();
CancellationTokenSource cts = new CancellationTokenSource(500);
async Task WriteUntilCanceled()
{
var buffer = new byte[64 * 1024];
while (true)
{
await stream.WriteAsync(buffer, cancellationToken: cts.Token);
}
}
// a write would eventually be canceled
await Assert.ThrowsAsync<OperationCanceledException>(() => WriteUntilCanceled().WaitAsync(TimeSpan.FromSeconds(3)));
// next write would also throw
await Assert.ThrowsAsync<OperationCanceledException>(() => stream.WriteAsync(new byte[1]).AsTask());
// manual write abort is still required
stream.AbortWrite(expectedErrorCode);
await stream.ShutdownCompleted();
},
serverFunction: async connection =>
{
await using QuicStream stream = await connection.AcceptStreamAsync();
async Task ReadUntilAborted()
{
var buffer = new byte[1024];
while (true)
{
int res = await stream.ReadAsync(buffer);
if (res == 0)
{
break;
}
}
}
// TODO: it should always throw QuicStreamAbortedException, but sometimes it does not https://github.com/dotnet/runtime/issues/53530
//QuicStreamAbortedException ex = await Assert.ThrowsAsync<QuicStreamAbortedException>(() => ReadUntilAborted());
try
{
await ReadUntilAborted().WaitAsync(TimeSpan.FromSeconds(3));
}
catch (QuicStreamAbortedException) { }
await stream.ShutdownCompleted();
}
);
}
}

public sealed class QuicStreamTests_MockProvider : QuicStreamTests<MockProviderFactory> { }
Expand Down

0 comments on commit e0671e7

Please sign in to comment.