From 2d72286c5c50259a62286f7c32990309ccdcfa52 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sun, 10 Jan 2021 16:29:05 +0200 Subject: [PATCH 01/52] Initial API changes. --- .../ref/System.Net.WebSockets.cs | 18 +++++ .../src/Resources/Strings.resx | 67 ++++++++++++++++++- .../src/System.Net.WebSockets.csproj | 2 + .../System/Net/WebSockets/ManagedWebSocket.cs | 25 +++---- .../src/System/Net/WebSockets/WebSocket.cs | 54 +++++++-------- .../WebSockets/WebSocketCreationOptions.cs | 55 +++++++++++++++ .../Net/WebSockets/WebSocketDeflateOptions.cs | 60 +++++++++++++++++ 7 files changed, 232 insertions(+), 49 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs create mode 100644 src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs diff --git a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs index ca1e25ea21289..e1fff547f2124 100644 --- a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs +++ b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs @@ -31,6 +31,8 @@ protected WebSocket() { } public static System.Net.WebSockets.WebSocket CreateClientWebSocket(System.IO.Stream innerStream, string? subProtocol, int receiveBufferSize, int sendBufferSize, System.TimeSpan keepAliveInterval, bool useZeroMaskingKey, System.ArraySegment internalBuffer) { throw null; } [System.Runtime.Versioning.UnsupportedOSPlatform("browser")] public static System.Net.WebSockets.WebSocket CreateFromStream(System.IO.Stream stream, bool isServer, string? subProtocol, System.TimeSpan keepAliveInterval) { throw null; } + [System.Runtime.Versioning.UnsupportedOSPlatform("browser")] + public static System.Net.WebSockets.WebSocket CreateFromStream(System.IO.Stream stream, WebSocketCreationOptions options) { throw null; } public static System.ArraySegment CreateServerBuffer(int receiveBufferSize) { throw null; } public abstract void Dispose(); [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] @@ -133,4 +135,20 @@ public enum WebSocketState Closed = 5, Aborted = 6, } + + public sealed class WebSocketCreationOptions + { + public bool IsServer { get; set; } + public string? SubProtocol { get; set; } + public TimeSpan KeepAliveInterval { get; set; } + public WebSocketDeflateOptions? DeflateOptions { get; set; } + } + + public sealed class WebSocketDeflateOptions + { + public int ClientMaxWindowBits { get; set; } + public bool ClientContextTakeover { get; set; } + public int ServerMaxWindowBits { get; set; } + public bool ServerContextTakeover { get; set; } + } } diff --git a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx index a4f630ea24c03..a3eecbceb6b61 100644 --- a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx @@ -1,4 +1,64 @@ - + + + @@ -138,4 +198,7 @@ The base stream is not writeable. - + + The argument must be a value between {0} and {1}. + + \ No newline at end of file diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index d65e6c55737af..67ca96113f754 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -9,6 +9,8 @@ + + diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index bf4888a57b057..8579de1651515 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -29,14 +29,11 @@ internal sealed partial class ManagedWebSocket : WebSocket { /// Creates a from a connected to a websocket endpoint. /// The connected Stream. - /// true if this is the server-side of the connection; false if this is the client-side of the connection. - /// The agreed upon subprotocol for the connection. - /// The interval to use for keep-alive pings. + /// The options with which the websocket must be created. /// The created instance. - public static ManagedWebSocket CreateFromConnectedStream( - Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval) + public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocketCreationOptions options) { - return new ManagedWebSocket(stream, isServer, subprotocol, keepAliveInterval); + return new ManagedWebSocket(stream, options); } /// Thread-safe random number generator used to generate masks for each send. @@ -153,12 +150,7 @@ public static ManagedWebSocket CreateFromConnectedStream( /// private object ReceiveAsyncLock => _utf8TextState; // some object, as we're simply lock'ing on it - /// Initializes the websocket. - /// The connected Stream. - /// true if this is the server-side of the connection; false if this is the client-side of the connection. - /// The agreed upon subprotocol for the connection. - /// The interval to use for keep-alive pings. - private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval) + private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) { Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null"); Debug.Assert(ReceiveAsyncLock != null, $"Expected {nameof(ReceiveAsyncLock)} to be non-null"); @@ -167,11 +159,10 @@ private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Time Debug.Assert(stream != null, $"Expected non-null stream"); Debug.Assert(stream.CanRead, $"Expected readable stream"); Debug.Assert(stream.CanWrite, $"Expected writeable stream"); - Debug.Assert(keepAliveInterval == Timeout.InfiniteTimeSpan || keepAliveInterval >= TimeSpan.Zero, $"Invalid keepalive interval: {keepAliveInterval}"); _stream = stream; - _isServer = isServer; - _subprotocol = subprotocol; + _isServer = options.IsServer; + _subprotocol = options.SubProtocol; // Create a buffer just large enough to handle received packet headers (at most 14 bytes) and // control payloads (at most 125 bytes). Message payloads are read directly into the buffer @@ -201,7 +192,7 @@ private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Time // Now that we're opened, initiate the keep alive timer to send periodic pings. // We use a weak reference from the timer to the web socket to avoid a cycle // that could keep the web socket rooted in erroneous cases. - if (keepAliveInterval > TimeSpan.Zero) + if (options.KeepAliveInterval > TimeSpan.Zero) { _keepAliveTimer = new Timer(static s => { @@ -210,7 +201,7 @@ private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Time { thisRef.SendKeepAliveFrameAsync(); } - }, new WeakReference(this), keepAliveInterval, keepAliveInterval); + }, new WeakReference(this), options.KeepAliveInterval, options.KeepAliveInterval); } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs index 14732fb6b77ca..9d93222577568 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs @@ -137,29 +137,30 @@ public static ArraySegment CreateServerBuffer(int receiveBufferSize) [UnsupportedOSPlatform("browser")] public static WebSocket CreateFromStream(Stream stream, bool isServer, string? subProtocol, TimeSpan keepAliveInterval) { - if (stream == null) + return CreateFromStream(stream, new WebSocketCreationOptions { + IsServer = isServer, + SubProtocol = subProtocol, + KeepAliveInterval = keepAliveInterval + }); + } + + /// Creates a that operates on a representing a web socket connection. + /// The for the connection. + /// The options with which the websocket must be created. + [UnsupportedOSPlatform("browser")] + public static WebSocket CreateFromStream(Stream stream, WebSocketCreationOptions options) + { + if (stream is null) throw new ArgumentNullException(nameof(stream)); - } + + if (options is null) + throw new ArgumentNullException(nameof(options)); if (!stream.CanRead || !stream.CanWrite) - { throw new ArgumentException(!stream.CanRead ? SR.NotReadableStream : SR.NotWriteableStream, nameof(stream)); - } - if (subProtocol != null) - { - WebSocketValidate.ValidateSubprotocol(subProtocol); - } - - if (keepAliveInterval != Timeout.InfiniteTimeSpan && keepAliveInterval < TimeSpan.Zero) - { - throw new ArgumentOutOfRangeException(nameof(keepAliveInterval), keepAliveInterval, - SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, - 0)); - } - - return ManagedWebSocket.CreateFromConnectedStream(stream, isServer, subProtocol, keepAliveInterval); + return ManagedWebSocket.CreateFromConnectedStream(stream, options); } [EditorBrowsable(EditorBrowsableState.Never)] @@ -190,18 +191,6 @@ public static WebSocket CreateClientWebSocket(Stream innerStream, throw new ArgumentException(!innerStream.CanRead ? SR.NotReadableStream : SR.NotWriteableStream, nameof(innerStream)); } - if (subProtocol != null) - { - WebSocketValidate.ValidateSubprotocol(subProtocol); - } - - if (keepAliveInterval != Timeout.InfiniteTimeSpan && keepAliveInterval < TimeSpan.Zero) - { - throw new ArgumentOutOfRangeException(nameof(keepAliveInterval), keepAliveInterval, - SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, - 0)); - } - if (receiveBufferSize <= 0 || sendBufferSize <= 0) { throw new ArgumentOutOfRangeException( @@ -213,7 +202,12 @@ public static WebSocket CreateClientWebSocket(Stream innerStream, // Ignore useZeroMaskingKey. ManagedWebSocket doesn't currently support that debugging option. // Ignore internalBuffer. ManagedWebSocket uses its own small buffer for headers/control messages. - return ManagedWebSocket.CreateFromConnectedStream(innerStream, false, subProtocol, keepAliveInterval); + return ManagedWebSocket.CreateFromConnectedStream(innerStream, new WebSocketCreationOptions + { + IsServer = false, + KeepAliveInterval = keepAliveInterval, + SubProtocol = subProtocol + }); } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs new file mode 100644 index 0000000000000..3cfc9810610a1 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs @@ -0,0 +1,55 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading; + +namespace System.Net.WebSockets +{ + public sealed class WebSocketCreationOptions + { + private string? _subProtocol; + private TimeSpan _keepAliveInterval; + + /// + /// Defines if this websocket is the server-side of the connection. The default value is false. + /// + public bool IsServer { get; set; } + + /// + /// The agreed upon sub-protocol that was used when creating the connection. + /// + public string? SubProtocol + { + get => _subProtocol; + set + { + if (value is not null) + WebSocketValidate.ValidateSubprotocol(value); + + _subProtocol = value; + } + } + + /// + /// The keep-alive interval to use, or or to disable keep-alives. + /// The default is . + /// + public TimeSpan KeepAliveInterval + { + get => _keepAliveInterval; + set + { + if (value != Timeout.InfiniteTimeSpan && value < TimeSpan.Zero) + throw new ArgumentOutOfRangeException(nameof(KeepAliveInterval), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, 0)); + + _keepAliveInterval = value; + } + } + + /// + /// The agreed upon options for per message deflate. + /// + public WebSocketDeflateOptions? DeflateOptions { get; set; } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs new file mode 100644 index 0000000000000..375617ec3cb45 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs @@ -0,0 +1,60 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.WebSockets +{ + /// + /// Options to enable per-message deflate compression for . + /// + public sealed class WebSocketDeflateOptions + { + private int _clientMaxWindowBits = 15; + private int _serverMaxWindowBits = 15; + + /// + /// This parameter indicates the base-2 logarithm of the LZ77 sliding window size of the client context. + /// Must be a value between 8 and 15. The default is 15. + /// + public int ClientMaxWindowBits + { + get => _clientMaxWindowBits; + set + { + if (value < 8 || value > 15) + throw new ArgumentOutOfRangeException(nameof(ClientMaxWindowBits), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 8, 15)); + + _clientMaxWindowBits = value; + } + } + + /// + /// When true the client-side of the connection indicates that it will persist the deflate context accross messages. + /// The default is true. + /// + public bool ClientContextTakeover { get; set; } = true; + + /// + /// This parameter indicates the base-2 logarithm of the LZ77 sliding window size of the server context. + /// Must be a value between 8 and 15. The default is 15. + /// + public int ServerMaxWindowBits + { + get => _serverMaxWindowBits; + set + { + if (value < 8 || value > 15) + throw new ArgumentOutOfRangeException(nameof(ServerMaxWindowBits), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 8, 15)); + + _serverMaxWindowBits = value; + } + } + + /// + /// When true the server-side of the connection indicates that it will persist the deflate context accross messages. + /// The default is true. + /// + public bool ServerContextTakeover { get; set; } = true; + } +} From 3913164620677a8e29674915310827488650faa3 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 12 Feb 2021 14:09:44 +0200 Subject: [PATCH 02/52] Added inflater / deflater and retargeted the project for each of the supported platforms because we now depend on native API. --- .../src/Resources/Strings.resx | 24 ++ .../src/System.Net.WebSockets.csproj | 20 +- .../src/System/IO/Compression/Deflater.cs | 234 ++++++++++++ .../src/System/IO/Compression/Inflater.cs | 327 +++++++++++++++++ .../IO/Compression/ZLibNative.ZStream.cs | 34 ++ .../src/System/IO/Compression/ZLibNative.cs | 345 ++++++++++++++++++ .../System/Net/WebSockets/ManagedWebSocket.cs | 27 +- 7 files changed, 1004 insertions(+), 7 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs create mode 100644 src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs create mode 100644 src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.ZStream.cs create mode 100644 src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.cs diff --git a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx index a3eecbceb6b61..09cc98eaad0e2 100644 --- a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx @@ -201,4 +201,28 @@ The argument must be a value between {0} and {1}. + + The underlying compression routine could not be loaded correctly. + + + The underlying compression routine received incorrect initialization parameters. + + + The underlying compression routine could not reserve sufficient memory. + + + The underlying compression routine returned an unexpected error code {0}. + + + The version of the underlying compression routine does not match expected version. + + + The message was compressed using an unsupported compression method. + + + The WebSocket received a continuation frame with Per-Message Compressed flag set. + + + The stream state of the underlying compression routine is inconsistent. + \ No newline at end of file diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index 67ca96113f754..eba162e9f17fc 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -1,7 +1,7 @@ - + True - $(NetCoreAppCurrent) + $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser enable @@ -17,8 +17,24 @@ + + + + + + + + + + + + + diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs new file mode 100644 index 0000000000000..2a65f160017f0 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs @@ -0,0 +1,234 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using System.Net.WebSockets; +using System.Security; + +using ZErrorCode = System.IO.Compression.ZLibNative.ErrorCode; +using ZFlushCode = System.IO.Compression.ZLibNative.FlushCode; + +namespace System.IO.Compression +{ + /// + /// Provides a wrapper around the ZLib compression API. + /// + internal sealed class Deflater : IDisposable + { + private readonly ZLibNative.ZLibStreamHandle _zlibStream; + private MemoryHandle _inputBufferHandle; + private bool _isDisposed; + private const int minWindowBits = -15; // WindowBits must be between -8..-15 to write no header, 8..15 for a + private const int maxWindowBits = 31; // zlib header, or 24..31 for a GZip header + + // Note, DeflateStream or the deflater do not try to be thread safe. + // The lock is just used to make writing to unmanaged structures atomic to make sure + // that they do not get inconsistent fields that may lead to an unmanaged memory violation. + // To prevent *managed* buffer corruption or other weird behaviour users need to synchronise + // on the stream explicitly. + private object SyncLock => this; + + internal Deflater(int windowBits) + { + Debug.Assert(windowBits >= minWindowBits && windowBits <= maxWindowBits); + + var compressionLevel = ZLibNative.CompressionLevel.DefaultCompression; + var memLevel = ZLibNative.Deflate_DefaultMemLevel; + var strategy = ZLibNative.CompressionStrategy.DefaultStrategy; + + ZErrorCode errC; + try + { + errC = ZLibNative.CreateZLibStreamForDeflate(out _zlibStream, compressionLevel, + windowBits, memLevel, strategy); + } + catch (Exception cause) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + + switch (errC) + { + case ZErrorCode.Ok: + return; + + case ZErrorCode.MemError: + throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); + + case ZErrorCode.VersionError: + throw new WebSocketException(SR.ZLibErrorVersionMismatch); + + case ZErrorCode.StreamError: + throw new WebSocketException(SR.ZLibErrorIncorrectInitParameters); + + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errC)); + } + } + + ~Deflater() + { + Dispose(false); + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!_isDisposed) + { + if (disposing) + _zlibStream.Dispose(); + + DeallocateInputBufferHandle(); + _isDisposed = true; + } + } + + public bool NeedsInput() => 0 == _zlibStream.AvailIn; + + internal unsafe void SetInput(ReadOnlyMemory inputBuffer) + { + Debug.Assert(NeedsInput(), "We have something left in previous input!"); + if (0 == inputBuffer.Length) + { + return; + } + + lock (SyncLock) + { + _inputBufferHandle = inputBuffer.Pin(); + + _zlibStream.NextIn = (IntPtr)_inputBufferHandle.Pointer; + _zlibStream.AvailIn = (uint)inputBuffer.Length; + } + } + + internal unsafe void SetInput(byte* inputBufferPtr, int count) + { + Debug.Assert(NeedsInput(), "We have something left in previous input!"); + Debug.Assert(inputBufferPtr != null); + + if (count == 0) + { + return; + } + + lock (SyncLock) + { + _zlibStream.NextIn = (IntPtr)inputBufferPtr; + _zlibStream.AvailIn = (uint)count; + } + } + + internal int GetDeflateOutput(byte[] outputBuffer) + { + Debug.Assert(null != outputBuffer, "Can't pass in a null output buffer!"); + Debug.Assert(!NeedsInput(), "GetDeflateOutput should only be called after providing input"); + + try + { + int bytesRead; + ReadDeflateOutput(outputBuffer, ZFlushCode.NoFlush, out bytesRead); + return bytesRead; + } + finally + { + // Before returning, make sure to release input buffer if necessary: + if (0 == _zlibStream.AvailIn) + { + DeallocateInputBufferHandle(); + } + } + } + + private unsafe ZErrorCode ReadDeflateOutput(byte[] outputBuffer, ZFlushCode flushCode, out int bytesRead) + { + Debug.Assert(outputBuffer?.Length > 0); + + lock (SyncLock) + { + fixed (byte* bufPtr = &outputBuffer[0]) + { + _zlibStream.NextOut = (IntPtr)bufPtr; + _zlibStream.AvailOut = (uint)outputBuffer.Length; + + ZErrorCode errC = Deflate(flushCode); + bytesRead = outputBuffer.Length - (int)_zlibStream.AvailOut; + + return errC; + } + } + } + + internal bool Finish(byte[] outputBuffer, out int bytesRead) + { + Debug.Assert(null != outputBuffer, "Can't pass in a null output buffer!"); + Debug.Assert(outputBuffer.Length > 0, "Can't pass in an empty output buffer!"); + + ZErrorCode errC = ReadDeflateOutput(outputBuffer, ZFlushCode.Finish, out bytesRead); + return errC == ZErrorCode.StreamEnd; + } + + /// + /// Returns true if there was something to flush. Otherwise False. + /// + internal bool Flush(byte[] outputBuffer, out int bytesRead) + { + Debug.Assert(null != outputBuffer, "Can't pass in a null output buffer!"); + Debug.Assert(outputBuffer.Length > 0, "Can't pass in an empty output buffer!"); + Debug.Assert(NeedsInput(), "We have something left in previous input!"); + + + // Note: we require that NeedsInput() == true, i.e. that 0 == _zlibStream.AvailIn. + // If there is still input left we should never be getting here; instead we + // should be calling GetDeflateOutput. + + return ReadDeflateOutput(outputBuffer, ZFlushCode.SyncFlush, out bytesRead) == ZErrorCode.Ok; + } + + private void DeallocateInputBufferHandle() + { + lock (SyncLock) + { + _zlibStream.AvailIn = 0; + _zlibStream.NextIn = ZLibNative.ZNullPtr; + _inputBufferHandle.Dispose(); + } + } + + private ZErrorCode Deflate(ZFlushCode flushCode) + { + ZErrorCode errC; + try + { + errC = _zlibStream.Deflate(flushCode); + } + catch (Exception cause) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + + switch (errC) + { + case ZErrorCode.Ok: + case ZErrorCode.StreamEnd: + return errC; + + case ZErrorCode.BufError: + return errC; // This is a recoverable error + + case ZErrorCode.StreamError: + throw new WebSocketException(SR.ZLibErrorInconsistentStream); + + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errC)); + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs new file mode 100644 index 0000000000000..d289b3e199d1d --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs @@ -0,0 +1,327 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Net.WebSockets; +using System.Runtime.InteropServices; +using System.Security; + +namespace System.IO.Compression +{ + /// + /// Provides a wrapper around the ZLib decompression API. + /// + internal sealed class Inflater : IDisposable + { + private const int MinWindowBits = -15; // WindowBits must be between -8..-15 to ignore the header, 8..15 for + private const int MaxWindowBits = 47; // zlib headers, 24..31 for GZip headers, or 40..47 for either Zlib or GZip + + private bool _finished; // Whether the end of the stream has been reached + private bool _isDisposed; // Prevents multiple disposals + private readonly int _windowBits; // The WindowBits parameter passed to Inflater construction + private ZLibNative.ZLibStreamHandle _zlibStream; // The handle to the primary underlying zlib stream + private GCHandle _inputBufferHandle; // The handle to the buffer that provides input to _zlibStream + private readonly long _uncompressedSize; + private long _currentInflatedCount; + + private object SyncLock => this; // Used to make writing to unmanaged structures atomic + + /// + /// Initialized the Inflater with the given windowBits size + /// + internal Inflater(int windowBits, long uncompressedSize = -1) + { + Debug.Assert(windowBits >= MinWindowBits && windowBits <= MaxWindowBits); + _finished = false; + _isDisposed = false; + _windowBits = windowBits; + InflateInit(windowBits); + _uncompressedSize = uncompressedSize; + } + + public int AvailableOutput => (int)_zlibStream.AvailOut; + + /// + /// Returns true if the end of the stream has been reached. + /// + public bool Finished() => _finished; + + public unsafe bool Inflate(out byte b) + { + fixed (byte* bufPtr = &b) + { + int bytesRead = InflateVerified(bufPtr, 1); + Debug.Assert(bytesRead == 0 || bytesRead == 1); + return bytesRead != 0; + } + } + + public unsafe int Inflate(byte[] bytes, int offset, int length) + { + // If Inflate is called on an invalid or unready inflater, return 0 to indicate no bytes have been read. + if (length == 0) + return 0; + + Debug.Assert(null != bytes, "Can't pass in a null output buffer!"); + fixed (byte* bufPtr = bytes) + { + return InflateVerified(bufPtr + offset, length); + } + } + + public unsafe int Inflate(Span destination) + { + // If Inflate is called on an invalid or unready inflater, return 0 to indicate no bytes have been read. + if (destination.Length == 0) + return 0; + + fixed (byte* bufPtr = &MemoryMarshal.GetReference(destination)) + { + return InflateVerified(bufPtr, destination.Length); + } + } + + public unsafe int InflateVerified(byte* bufPtr, int length) + { + // State is valid; attempt inflation + try + { + int bytesRead = 0; + if (_uncompressedSize == -1) + { + ReadOutput(bufPtr, length, out bytesRead); + } + else + { + if (_uncompressedSize > _currentInflatedCount) + { + length = (int)Math.Min(length, _uncompressedSize - _currentInflatedCount); + ReadOutput(bufPtr, length, out bytesRead); + _currentInflatedCount += bytesRead; + } + else + { + _finished = true; + _zlibStream.AvailIn = 0; + } + } + return bytesRead; + } + finally + { + // Before returning, make sure to release input buffer if necessary: + if (0 == _zlibStream.AvailIn && _inputBufferHandle.IsAllocated) + { + DeallocateInputBufferHandle(); + } + } + } + + private unsafe void ReadOutput(byte* bufPtr, int length, out int bytesRead) + { + if (ReadInflateOutput(bufPtr, length, ZLibNative.FlushCode.NoFlush, out bytesRead) == ZLibNative.ErrorCode.StreamEnd) + { + if (!NeedsInput() && IsGzipStream() && _inputBufferHandle.IsAllocated) + { + _finished = ResetStreamForLeftoverInput(); + } + else + { + _finished = true; + } + } + } + + /// + /// If this stream has some input leftover that hasn't been processed then we should + /// check if it is another GZip file concatenated with this one. + /// + /// Returns false if the leftover input is another GZip data stream. + /// + private unsafe bool ResetStreamForLeftoverInput() + { + Debug.Assert(!NeedsInput()); + Debug.Assert(IsGzipStream()); + Debug.Assert(_inputBufferHandle.IsAllocated); + + lock (SyncLock) + { + IntPtr nextInPtr = _zlibStream.NextIn; + byte* nextInPointer = (byte*)nextInPtr.ToPointer(); + uint nextAvailIn = _zlibStream.AvailIn; + + // Check the leftover bytes to see if they start with he gzip header ID bytes + if (*nextInPointer != ZLibNative.GZip_Header_ID1 || (nextAvailIn > 1 && *(nextInPointer + 1) != ZLibNative.GZip_Header_ID2)) + { + return true; + } + + // Trash our existing zstream. + _zlibStream.Dispose(); + + // Create a new zstream + InflateInit(_windowBits); + + // SetInput on the new stream to the bits remaining from the last stream + _zlibStream.NextIn = nextInPtr; + _zlibStream.AvailIn = nextAvailIn; + _finished = false; + } + + return false; + } + + internal bool IsGzipStream() => _windowBits >= 24 && _windowBits <= 31; + + public bool NeedsInput() => _zlibStream.AvailIn == 0; + + public void SetInput(byte[] inputBuffer, int startIndex, int count) + { + Debug.Assert(NeedsInput(), "We have something left in previous input!"); + Debug.Assert(inputBuffer != null); + Debug.Assert(startIndex >= 0 && count >= 0 && count + startIndex <= inputBuffer.Length); + Debug.Assert(!_inputBufferHandle.IsAllocated); + + if (0 == count) + return; + + lock (SyncLock) + { + _inputBufferHandle = GCHandle.Alloc(inputBuffer, GCHandleType.Pinned); + _zlibStream.NextIn = _inputBufferHandle.AddrOfPinnedObject() + startIndex; + _zlibStream.AvailIn = (uint)count; + _finished = false; + } + } + + private void Dispose(bool disposing) + { + if (!_isDisposed) + { + if (disposing) + _zlibStream.Dispose(); + + if (_inputBufferHandle.IsAllocated) + DeallocateInputBufferHandle(); + + _isDisposed = true; + } + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + ~Inflater() + { + Dispose(false); + } + + /// + /// Creates the ZStream that will handle inflation. + /// + [MemberNotNull(nameof(_zlibStream))] + private void InflateInit(int windowBits) + { + ZLibNative.ErrorCode error; + try + { + error = ZLibNative.CreateZLibStreamForInflate(out _zlibStream, windowBits); + } + catch (Exception exception) // could not load the ZLib dll + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception); + } + + switch (error) + { + case ZLibNative.ErrorCode.Ok: // Successful initialization + return; + + case ZLibNative.ErrorCode.MemError: // Not enough memory + throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); + + case ZLibNative.ErrorCode.VersionError: //zlib library is incompatible with the version assumed + throw new WebSocketException(SR.ZLibErrorVersionMismatch); + + case ZLibNative.ErrorCode.StreamError: // Parameters are invalid + throw new WebSocketException(SR.ZLibErrorIncorrectInitParameters); + + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)error)); + } + } + + /// + /// Wrapper around the ZLib inflate function, configuring the stream appropriately. + /// + private unsafe ZLibNative.ErrorCode ReadInflateOutput(byte* bufPtr, int length, ZLibNative.FlushCode flushCode, out int bytesRead) + { + lock (SyncLock) + { + _zlibStream.NextOut = (IntPtr)bufPtr; + _zlibStream.AvailOut = (uint)length; + + ZLibNative.ErrorCode errC = Inflate(flushCode); + bytesRead = length - (int)_zlibStream.AvailOut; + + return errC; + } + } + + /// + /// Wrapper around the ZLib inflate function + /// + private ZLibNative.ErrorCode Inflate(ZLibNative.FlushCode flushCode) + { + ZLibNative.ErrorCode errC; + try + { + errC = _zlibStream.Inflate(flushCode); + } + catch (Exception cause) // could not load the Zlib DLL correctly + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + switch (errC) + { + case ZLibNative.ErrorCode.Ok: // progress has been made inflating + case ZLibNative.ErrorCode.StreamEnd: // The end of the input stream has been reached + return errC; + + case ZLibNative.ErrorCode.BufError: // No room in the output buffer - inflate() can be called again with more space to continue + return errC; + + case ZLibNative.ErrorCode.MemError: // Not enough memory to complete the operation + throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); + + case ZLibNative.ErrorCode.DataError: // The input data was corrupted (input stream not conforming to the zlib format or incorrect check value) + throw new WebSocketException(SR.UnsupportedCompression); + + case ZLibNative.ErrorCode.StreamError: //the stream structure was inconsistent (for example if next_in or next_out was NULL), + throw new WebSocketException(SR.ZLibErrorInconsistentStream); + + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errC)); + } + } + + /// + /// Frees the GCHandle being used to store the input buffer + /// + private void DeallocateInputBufferHandle() + { + Debug.Assert(_inputBufferHandle.IsAllocated); + + lock (SyncLock) + { + _zlibStream.AvailIn = 0; + _zlibStream.NextIn = ZLibNative.ZNullPtr; + _inputBufferHandle.Free(); + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.ZStream.cs b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.ZStream.cs new file mode 100644 index 0000000000000..bfb8c5145c04a --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.ZStream.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.InteropServices; + +namespace System.IO.Compression +{ + internal static partial class ZLibNative + { + /// + /// ZLib stream descriptor data structure + /// Do not construct instances of ZStream explicitly. + /// Always use ZLibNative.DeflateInit2_ or ZLibNative.InflateInit2_ instead. + /// Those methods will wrap this structure into a SafeHandle and thus make sure that it is always disposed correctly. + /// + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] + internal struct ZStream + { + internal void Init() + { + } + + internal IntPtr nextIn; //Bytef *next_in; /* next input byte */ + internal IntPtr nextOut; //Bytef *next_out; /* next output byte should be put there */ + + internal IntPtr msg; //char *msg; /* last error message, NULL if no error */ + + private readonly IntPtr internalState; //internal state that is not visible to managed code + + internal uint availIn; //uInt avail_in; /* number of bytes available at next_in */ + internal uint availOut; //uInt avail_out; /* remaining free space at next_out */ + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.cs b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.cs new file mode 100644 index 0000000000000..8118aeba0ecb8 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.cs @@ -0,0 +1,345 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.InteropServices; +using System.Security; + +namespace System.IO.Compression +{ + /// + /// This class provides declaration for constants and PInvokes as well as some basic tools for exposing the + /// native System.IO.Compression.Native.dll (effectively, ZLib) library to managed code. + /// + /// See also: How to choose a compression level (in comments to CompressionLevel. + /// + internal static partial class ZLibNative + { + // This is the NULL pointer for using with ZLib pointers; + // we prefer it to IntPtr.Zero to mimic the definition of Z_NULL in zlib.h: + internal static readonly IntPtr ZNullPtr = IntPtr.Zero; + + public enum FlushCode : int + { + NoFlush = 0, + SyncFlush = 2, + Finish = 4, + } + + public enum ErrorCode : int + { + Ok = 0, + StreamEnd = 1, + StreamError = -2, + DataError = -3, + MemError = -4, + BufError = -5, + VersionError = -6 + } + + /// + ///

ZLib can accept any integer value between 0 and 9 (inclusive) as a valid compression level parameter: + /// 1 gives best speed, 9 gives best compression, 0 gives no compression at all (the input data is simply copied a block at a time). + /// CompressionLevel.DefaultCompression = -1 requests a default compromise between speed and compression + /// (currently equivalent to level 6).

+ /// + ///

How to choose a compression level:

+ /// + ///

The names NoCompression, BestSpeed, DefaultCompression, BestCompression are taken over from + /// the corresponding ZLib definitions, which map to our public NoCompression, Fastest, Optimal, and SmallestSize respectively.

+ ///

Optimal Compression:

+ ///

ZLibNative.CompressionLevel compressionLevel = ZLibNative.CompressionLevel.DefaultCompression;
+ /// int windowBits = 15; // or -15 if no headers required
+ /// int memLevel = 8;
+ /// ZLibNative.CompressionStrategy strategy = ZLibNative.CompressionStrategy.DefaultStrategy;

+ /// + ///

Fastest compression:

+ ///

ZLibNative.CompressionLevel compressionLevel = ZLibNative.CompressionLevel.BestSpeed;
+ /// int windowBits = 15; // or -15 if no headers required
+ /// int memLevel = 8;
+ /// ZLibNative.CompressionStrategy strategy = ZLibNative.CompressionStrategy.DefaultStrategy;

+ /// + ///

No compression (even faster, useful for data that cannot be compressed such some image formats):

+ ///

ZLibNative.CompressionLevel compressionLevel = ZLibNative.CompressionLevel.NoCompression;
+ /// int windowBits = 15; // or -15 if no headers required
+ /// int memLevel = 7;
+ /// ZLibNative.CompressionStrategy strategy = ZLibNative.CompressionStrategy.DefaultStrategy;

+ /// + ///

Smallest Size Compression:

+ ///

ZLibNative.CompressionLevel compressionLevel = ZLibNative.CompressionLevel.BestCompression;
+ /// int windowBits = 15; // or -15 if no headers required
+ /// int memLevel = 8;
+ /// ZLibNative.CompressionStrategy strategy = ZLibNative.CompressionStrategy.DefaultStrategy;

+ ///
+ public enum CompressionLevel : int + { + NoCompression = 0, + BestSpeed = 1, + DefaultCompression = -1, + BestCompression = 9 + } + + /// + ///

From the ZLib manual:

+ ///

CompressionStrategy is used to tune the compression algorithm.
+ /// Use the value DefaultStrategy for normal data, Filtered for data produced by a filter (or predictor), + /// HuffmanOnly to force Huffman encoding only (no string match), or Rle to limit match distances to one + /// (run-length encoding). Filtered data consists mostly of small values with a somewhat random distribution. In this case, the + /// compression algorithm is tuned to compress them better. The effect of Filtered is to force more Huffman coding and] + /// less string matching; it is somewhat intermediate between DefaultStrategy and HuffmanOnly. + /// Rle is designed to be almost as fast as HuffmanOnly, but give better compression for PNG image data. + /// The strategy parameter only affects the compression ratio but not the correctness of the compressed output even if it is not set + /// appropriately. Fixed prevents the use of dynamic Huffman codes, allowing for a simpler decoder for special applications.

+ /// + ///

For .NET Framework use:

+ ///

We have investigated compression scenarios for a bunch of different frequently occurring compression data and found that in all + /// cases we investigated so far, DefaultStrategy provided best results

+ ///

See also: How to choose a compression level (in comments to CompressionLevel.

+ ///
+ public enum CompressionStrategy : int + { + DefaultStrategy = 0 + } + + /// + /// In version 1.2.3, ZLib provides on the Deflated-CompressionMethod. + /// + public enum CompressionMethod : int + { + Deflated = 8 + } + + /// + ///

From the ZLib manual:

+ ///

ZLib's windowBits parameter is the base two logarithm of the window size (the size of the history buffer). + /// It should be in the range 8..15 for this version of the library. Larger values of this parameter result in better compression + /// at the expense of memory usage. The default value is 15 if deflateInit is used instead.

+ /// Note: + /// windowBits can also be -8..-15 for raw deflate. In this case, -windowBits determines the window size. + /// Deflate will then generate raw deflate data with no ZLib header or trailer, and will not compute an adler32 check value.
+ ///

See also: How to choose a compression level (in comments to CompressionLevel.

+ ///
+ public const int Deflate_DefaultWindowBits = -15; // Legal values are 8..15 and -8..-15. 15 is the window size, + // negative val causes deflate to produce raw deflate data (no zlib header). + + /// + ///

From the ZLib manual:

+ ///

ZLib's windowBits parameter is the base two logarithm of the window size (the size of the history buffer). + /// It should be in the range 8..15 for this version of the library. Larger values of this parameter result in better compression + /// at the expense of memory usage. The default value is 15 if deflateInit is used instead.

+ ///
+ public const int ZLib_DefaultWindowBits = 15; + + /// + ///

Zlib's windowBits parameter is the base two logarithm of the window size (the size of the history buffer). + /// For GZip header encoding, windowBits should be equal to a value between 8..15 (to specify Window Size) added to + /// 16. The range of values for GZip encoding is therefore 24..31. + /// Note: + /// The GZip header will have no file name, no extra data, no comment, no modification time (set to zero), no header crc, and + /// the operating system will be set based on the OS that the ZLib library was compiled to. ZStream.adler + /// is a crc32 instead of an adler32.

+ ///
+ public const int GZip_DefaultWindowBits = 31; + + /// + ///

From the ZLib manual:

+ ///

The memLevel parameter specifies how much memory should be allocated for the internal compression state. + /// memLevel = 1 uses minimum memory but is slow and reduces compression ratio; memLevel = 9 uses maximum + /// memory for optimal speed. The default value is 8.

+ ///

See also: How to choose a compression level (in comments to CompressionLevel.

+ ///
+ public const int Deflate_DefaultMemLevel = 8; // Memory usage by deflate. Legal range: [1..9]. 8 is ZLib default. + // More is faster and better compression with more memory usage. + public const int Deflate_NoCompressionMemLevel = 7; + + public const byte GZip_Header_ID1 = 31; + public const byte GZip_Header_ID2 = 139; + + /** + * Do not remove the nested typing of types inside of System.IO.Compression.ZLibNative. + * This was done on purpose to: + * + * - Achieve the right encapsulation in a situation where ZLibNative may be compiled division-wide + * into different assemblies that wish to consume System.IO.Compression.Native. Since internal + * scope is effectively like public scope when compiling ZLibNative into a higher + * level assembly, we need a combination of inner types and private-scope members to achieve + * the right encapsulation. + * + * - Achieve late dynamic loading of System.IO.Compression.Native.dll at the right time. + * The native assembly will not be loaded unless it is actually used since the loading is performed by a static + * constructor of an inner type that is not directly referenced by user code. + * + * In Dev12 we would like to create a proper feature for loading native assemblies from user-specified + * directories in order to PInvoke into them. This would preferably happen in the native interop/PInvoke + * layer; if not we can add a Framework level feature. + */ + + /// + /// The ZLibStreamHandle could be a CriticalFinalizerObject rather than a + /// SafeHandleMinusOneIsInvalid. This would save an IntPtr field since + /// ZLibStreamHandle does not actually use its handle field. + /// Instead it uses a private ZStream zStream field which is the actual handle data + /// structure requiring critical finalization. + /// However, we would like to take advantage if the better debugability offered by the fact that a + /// releaseHandleFailed MDA is raised if the ReleaseHandle method returns + /// false, which can for instance happen if the underlying ZLib XxxxEnd + /// routines return an failure error code. + /// + public sealed class ZLibStreamHandle : SafeHandle + { + public enum State { NotInitialized, InitializedForDeflate, InitializedForInflate, Disposed } + + private ZStream _zStream; + + private volatile State _initializationState; + + + public ZLibStreamHandle() + : base(new IntPtr(-1), true) + { + _zStream.Init(); + + _initializationState = State.NotInitialized; + SetHandle(IntPtr.Zero); + } + + public override bool IsInvalid + { + get { return handle == new IntPtr(-1); } + } + + public State InitializationState + { + get { return _initializationState; } + } + + + protected override bool ReleaseHandle() => + InitializationState switch + { + State.NotInitialized => true, + State.InitializedForDeflate => (DeflateEnd() == ErrorCode.Ok), + State.InitializedForInflate => (InflateEnd() == ErrorCode.Ok), + State.Disposed => true, + _ => false, // This should never happen. Did we forget one of the State enum values in the switch? + }; + + public IntPtr NextIn + { + get { return _zStream.nextIn; } + set { _zStream.nextIn = value; } + } + + public uint AvailIn + { + get { return _zStream.availIn; } + set { _zStream.availIn = value; } + } + + public IntPtr NextOut + { + get { return _zStream.nextOut; } + set { _zStream.nextOut = value; } + } + + public uint AvailOut + { + get { return _zStream.availOut; } + set { _zStream.availOut = value; } + } + + private void EnsureNotDisposed() + { + if (InitializationState == State.Disposed) + throw new ObjectDisposedException(GetType().ToString()); + } + + + private void EnsureState(State requiredState) + { + if (InitializationState != requiredState) + throw new InvalidOperationException("InitializationState != " + requiredState.ToString()); + } + + + public ErrorCode DeflateInit2_(CompressionLevel level, int windowBits, int memLevel, CompressionStrategy strategy) + { + EnsureNotDisposed(); + EnsureState(State.NotInitialized); + + ErrorCode errC = Interop.zlib.DeflateInit2_(ref _zStream, level, CompressionMethod.Deflated, windowBits, memLevel, strategy); + _initializationState = State.InitializedForDeflate; + + return errC; + } + + + public ErrorCode Deflate(FlushCode flush) + { + EnsureNotDisposed(); + EnsureState(State.InitializedForDeflate); + return Interop.zlib.Deflate(ref _zStream, flush); + } + + + public ErrorCode DeflateEnd() + { + EnsureNotDisposed(); + EnsureState(State.InitializedForDeflate); + + ErrorCode errC = Interop.zlib.DeflateEnd(ref _zStream); + _initializationState = State.Disposed; + + return errC; + } + + + public ErrorCode InflateInit2_(int windowBits) + { + EnsureNotDisposed(); + EnsureState(State.NotInitialized); + + ErrorCode errC = Interop.zlib.InflateInit2_(ref _zStream, windowBits); + _initializationState = State.InitializedForInflate; + + return errC; + } + + + public ErrorCode Inflate(FlushCode flush) + { + EnsureNotDisposed(); + EnsureState(State.InitializedForInflate); + return Interop.zlib.Inflate(ref _zStream, flush); + } + + + public ErrorCode InflateEnd() + { + EnsureNotDisposed(); + EnsureState(State.InitializedForInflate); + + ErrorCode errC = Interop.zlib.InflateEnd(ref _zStream); + _initializationState = State.Disposed; + + return errC; + } + + // This can work even after XxflateEnd(). + public string GetErrorMessage() => _zStream.msg != ZNullPtr ? Marshal.PtrToStringAnsi(_zStream.msg)! : string.Empty; + } + + public static ErrorCode CreateZLibStreamForDeflate(out ZLibStreamHandle zLibStreamHandle, CompressionLevel level, + int windowBits, int memLevel, CompressionStrategy strategy) + { + zLibStreamHandle = new ZLibStreamHandle(); + return zLibStreamHandle.DeflateInit2_(level, windowBits, memLevel, strategy); + } + + + public static ErrorCode CreateZLibStreamForInflate(out ZLibStreamHandle zLibStreamHandle, int windowBits) + { + zLibStreamHandle = new ZLibStreamHandle(); + return zLibStreamHandle.InflateInit2_(windowBits); + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 8579de1651515..6810c5371ed0e 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -3,6 +3,7 @@ using System.Buffers; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.IO; using System.Numerics; using System.Runtime.CompilerServices; @@ -86,6 +87,8 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke /// private readonly SemaphoreSlim _sendFrameAsyncLock = new SemaphoreSlim(1, 1); + private readonly bool _compressionEnabled; + // We maintain the current WebSocketState in _state. However, we separately maintain _sentCloseFrame and _receivedCloseFrame // as there isn't a strict ordering between CloseSent and CloseReceived. If we receive a close frame from the server, we need to // transition to CloseReceived even if we're currently in CloseSent, and if we send a close frame, we need to transition to @@ -163,6 +166,7 @@ private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) _stream = stream; _isServer = options.IsServer; _subprotocol = options.SubProtocol; + _compressionEnabled = options.DeflateOptions is not null; // Create a buffer just large enough to handle received packet headers (at most 14 bytes) and // control payloads (at most 125 bytes). Message payloads are read directly into the buffer @@ -523,7 +527,6 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Read { // Ensure we have a _sendBuffer. AllocateSendBuffer(payloadBuffer.Length + MaxMessageHeaderLength); - Debug.Assert(_sendBuffer != null); // Write the message header data to the buffer. int headerLength; @@ -588,11 +591,11 @@ private void SendKeepAliveFrameAsync() } } - private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnlySpan payload, bool endOfMessage, bool useMask) + private int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnlySpan payload, bool endOfMessage, bool useMask) { // Client header format: // 1 bit - FIN - 1 if this is the final fragment in the message (it could be the only fragment), otherwise 0 - // 1 bit - RSV1 - Reserved - 0 + // 1 bit - RSV1 - Per-Message Deflate Compress // 1 bit - RSV2 - Reserved - 0 // 1 bit - RSV3 - Reserved - 0 // 4 bits - Opcode - How to interpret the payload @@ -616,7 +619,12 @@ private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnly sendBuffer[0] = (byte)opcode; // 4 bits for the opcode if (endOfMessage) { - sendBuffer[0] |= 0x80; // 1 bit for FIN + sendBuffer[0] |= 0b1000_0000; // 1 bit for FIN + } + + if (_compressionEnabled && opcode is MessageOpcode.Text or MessageOpcode.Binary) + { + sendBuffer[0] |= 0b0100_0000; } // Store the payload length. @@ -1040,8 +1048,9 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( Span receiveBufferSpan = _receiveBuffer.Span; header.Fin = (receiveBufferSpan[_receiveBufferOffset] & 0x80) != 0; - bool reservedSet = (receiveBufferSpan[_receiveBufferOffset] & 0x70) != 0; + bool reservedSet = (receiveBufferSpan[_receiveBufferOffset] & 0b0011_0000) != 0; header.Opcode = (MessageOpcode)(receiveBufferSpan[_receiveBufferOffset] & 0xF); + header.Compressed = (receiveBufferSpan[_receiveBufferOffset] & 0b0100_0000) != 0; bool masked = (receiveBufferSpan[_receiveBufferOffset + 1] & 0x80) != 0; header.PayloadLength = receiveBufferSpan[_receiveBufferOffset + 1] & 0x7F; @@ -1095,6 +1104,12 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( resultHeader = default; return SR.net_Websockets_ContinuationFromFinalFrame; } + if (header.Compressed) + { + // Per-Message Compressed flag must be set only in the first frame + resultHeader = default; + return SR.net_Websockets_PerMessageCompressedFlagInContinuation; + } break; case MessageOpcode.Binary: @@ -1320,6 +1335,7 @@ private void ThrowIfEOFUnexpected(bool throwOnPrematureClosure) } /// Gets a send buffer from the pool. + [MemberNotNull(nameof(_sendBuffer))] private void AllocateSendBuffer(int minLength) { Debug.Assert(_sendBuffer == null); // would only fail if had some catastrophic error previously that prevented cleaning up @@ -1570,6 +1586,7 @@ private struct MessageHeader internal bool Fin; internal long PayloadLength; internal int Mask; + internal bool Compressed; } /// From 1a3f8cfaec64043023e3ffb25452591a00829efd Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 12 Feb 2021 20:49:35 +0200 Subject: [PATCH 03/52] Websocket sending extracted into a dedicated class so we can implement deflate encoders more naturally. --- .../src/System.Net.WebSockets.csproj | 1 + .../src/System/IO/Compression/Deflater.cs | 162 ++-------- .../Net/WebSockets/ManagedWebSocket.Sender.cs | 292 ++++++++++++++++++ .../System/Net/WebSockets/ManagedWebSocket.cs | 175 +---------- 4 files changed, 334 insertions(+), 296 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index eba162e9f17fc..2ed0d8a17b9d9 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -15,6 +15,7 @@ + diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs index 2a65f160017f0..3f718eb8d85e8 100644 --- a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs @@ -1,10 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Buffers; -using System.Diagnostics; using System.Net.WebSockets; -using System.Security; using ZErrorCode = System.IO.Compression.ZLibNative.ErrorCode; using ZFlushCode = System.IO.Compression.ZLibNative.FlushCode; @@ -16,11 +13,8 @@ namespace System.IO.Compression /// internal sealed class Deflater : IDisposable { - private readonly ZLibNative.ZLibStreamHandle _zlibStream; - private MemoryHandle _inputBufferHandle; + private readonly ZLibNative.ZLibStreamHandle _handle; private bool _isDisposed; - private const int minWindowBits = -15; // WindowBits must be between -8..-15 to write no header, 8..15 for a - private const int maxWindowBits = 31; // zlib header, or 24..31 for a GZip header // Note, DeflateStream or the deflater do not try to be thread safe. // The lock is just used to make writing to unmanaged structures atomic to make sure @@ -31,24 +25,21 @@ internal sealed class Deflater : IDisposable internal Deflater(int windowBits) { - Debug.Assert(windowBits >= minWindowBits && windowBits <= maxWindowBits); - var compressionLevel = ZLibNative.CompressionLevel.DefaultCompression; var memLevel = ZLibNative.Deflate_DefaultMemLevel; var strategy = ZLibNative.CompressionStrategy.DefaultStrategy; - ZErrorCode errC; + ZErrorCode errorCode; try { - errC = ZLibNative.CreateZLibStreamForDeflate(out _zlibStream, compressionLevel, - windowBits, memLevel, strategy); + errorCode = ZLibNative.CreateZLibStreamForDeflate(out _handle, compressionLevel, windowBits, memLevel, strategy); } catch (Exception cause) { throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); } - switch (errC) + switch (errorCode) { case ZErrorCode.Ok: return; @@ -63,171 +54,82 @@ internal Deflater(int windowBits) throw new WebSocketException(SR.ZLibErrorIncorrectInitParameters); default: - throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errC)); + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); } } - ~Deflater() - { - Dispose(false); - } - public void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); - } - - private void Dispose(bool disposing) { if (!_isDisposed) { - if (disposing) - _zlibStream.Dispose(); - - DeallocateInputBufferHandle(); + _handle.Dispose(); _isDisposed = true; } } - public bool NeedsInput() => 0 == _zlibStream.AvailIn; - - internal unsafe void SetInput(ReadOnlyMemory inputBuffer) + public unsafe void Deflate(ReadOnlySpan input, Span output, out int consumed, out int written) { - Debug.Assert(NeedsInput(), "We have something left in previous input!"); - if (0 == inputBuffer.Length) + fixed (byte* fixedInput = input) + fixed (byte* fixedOutput = output) { - return; - } + _handle.NextIn = (IntPtr)fixedInput; + _handle.AvailIn = (uint)input.Length; - lock (SyncLock) - { - _inputBufferHandle = inputBuffer.Pin(); - - _zlibStream.NextIn = (IntPtr)_inputBufferHandle.Pointer; - _zlibStream.AvailIn = (uint)inputBuffer.Length; - } - } - - internal unsafe void SetInput(byte* inputBufferPtr, int count) - { - Debug.Assert(NeedsInput(), "We have something left in previous input!"); - Debug.Assert(inputBufferPtr != null); - - if (count == 0) - { - return; - } - - lock (SyncLock) - { - _zlibStream.NextIn = (IntPtr)inputBufferPtr; - _zlibStream.AvailIn = (uint)count; - } - } + _handle.NextOut = (IntPtr)fixedOutput; + _handle.AvailOut = (uint)output.Length; - internal int GetDeflateOutput(byte[] outputBuffer) - { - Debug.Assert(null != outputBuffer, "Can't pass in a null output buffer!"); - Debug.Assert(!NeedsInput(), "GetDeflateOutput should only be called after providing input"); + Deflate(ZFlushCode.NoFlush); - try - { - int bytesRead; - ReadDeflateOutput(outputBuffer, ZFlushCode.NoFlush, out bytesRead); - return bytesRead; - } - finally - { - // Before returning, make sure to release input buffer if necessary: - if (0 == _zlibStream.AvailIn) - { - DeallocateInputBufferHandle(); - } + consumed = input.Length - (int)_handle.AvailIn; + written = output.Length - (int)_handle.AvailOut; } } - private unsafe ZErrorCode ReadDeflateOutput(byte[] outputBuffer, ZFlushCode flushCode, out int bytesRead) + public unsafe int Finish(Span output, out bool completed) { - Debug.Assert(outputBuffer?.Length > 0); - - lock (SyncLock) + fixed (byte* fixedOutput = output) { - fixed (byte* bufPtr = &outputBuffer[0]) - { - _zlibStream.NextOut = (IntPtr)bufPtr; - _zlibStream.AvailOut = (uint)outputBuffer.Length; + _handle.NextIn = IntPtr.Zero; + _handle.AvailIn = 0; - ZErrorCode errC = Deflate(flushCode); - bytesRead = outputBuffer.Length - (int)_zlibStream.AvailOut; + _handle.NextOut = (IntPtr)fixedOutput; + _handle.AvailOut = (uint)output.Length; - return errC; - } - } - } + var errorCode = Deflate(ZFlushCode.SyncFlush); + var writtenBytes = output.Length - (int)_handle.AvailOut; - internal bool Finish(byte[] outputBuffer, out int bytesRead) - { - Debug.Assert(null != outputBuffer, "Can't pass in a null output buffer!"); - Debug.Assert(outputBuffer.Length > 0, "Can't pass in an empty output buffer!"); + completed = errorCode == ZErrorCode.Ok && writtenBytes < output.Length; - ZErrorCode errC = ReadDeflateOutput(outputBuffer, ZFlushCode.Finish, out bytesRead); - return errC == ZErrorCode.StreamEnd; - } - - /// - /// Returns true if there was something to flush. Otherwise False. - /// - internal bool Flush(byte[] outputBuffer, out int bytesRead) - { - Debug.Assert(null != outputBuffer, "Can't pass in a null output buffer!"); - Debug.Assert(outputBuffer.Length > 0, "Can't pass in an empty output buffer!"); - Debug.Assert(NeedsInput(), "We have something left in previous input!"); - - - // Note: we require that NeedsInput() == true, i.e. that 0 == _zlibStream.AvailIn. - // If there is still input left we should never be getting here; instead we - // should be calling GetDeflateOutput. - - return ReadDeflateOutput(outputBuffer, ZFlushCode.SyncFlush, out bytesRead) == ZErrorCode.Ok; - } - - private void DeallocateInputBufferHandle() - { - lock (SyncLock) - { - _zlibStream.AvailIn = 0; - _zlibStream.NextIn = ZLibNative.ZNullPtr; - _inputBufferHandle.Dispose(); + return writtenBytes; } } private ZErrorCode Deflate(ZFlushCode flushCode) { - ZErrorCode errC; + ZErrorCode errorCode; try { - errC = _zlibStream.Deflate(flushCode); + errorCode = _handle.Deflate(flushCode); } catch (Exception cause) { throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); } - switch (errC) + switch (errorCode) { case ZErrorCode.Ok: case ZErrorCode.StreamEnd: - return errC; + return errorCode; case ZErrorCode.BufError: - return errC; // This is a recoverable error + return errorCode; // This is a recoverable error case ZErrorCode.StreamError: throw new WebSocketException(SR.ZLibErrorInconsistentStream); default: - throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errC)); + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs new file mode 100644 index 0000000000000..03c6f8d4b9bef --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs @@ -0,0 +1,292 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using System.IO; +using System.IO.Compression; +using System.Security.Cryptography; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets +{ + internal partial class ManagedWebSocket + { + private sealed class Sender : IDisposable + { + private readonly int _maskLength; + private readonly Encoder? _encoder; + + public Sender(WebSocketCreationOptions options) + { + _maskLength = options.IsServer ? 0 : MaskLength; + + var deflate = options.DeflateOptions; + + if (deflate is not null) + { + // Important note here is that we must use negative window bits + // which will instruct the underlying implementation to not emit deflate headers + if (options.IsServer) + { + _encoder = deflate.ServerContextTakeover ? + new Deflate(-deflate.ServerMaxWindowBits) : + new PersistedDeflate(-deflate.ServerMaxWindowBits); + } + else + { + _encoder = deflate.ClientContextTakeover ? + new Deflate(-deflate.ClientMaxWindowBits) : + new PersistedDeflate(-deflate.ClientMaxWindowBits); + } + } + } + + public void Dispose() + { + _encoder?.Dispose(); + } + + public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory content, Stream stream, CancellationToken cancellationToken = default) + { + var buffer = new Buffer(content.Length + MaxMessageHeaderLength); + var compressed = false; + + // Reserve space for the frame header + buffer.Advance(MaxMessageHeaderLength); + + if (_encoder is not null && opcode is MessageOpcode.Text or MessageOpcode.Binary) + { + _encoder.Encode(content.Span, ref buffer); + compressed = true; + } + else if (content.Length > 0) + { + content.Span.CopyTo(buffer.GetSpan(content.Length)); + buffer.Advance(content.Length); + } + + var payload = buffer.WrittenSpan.Slice(MaxMessageHeaderLength); + var headerLength = CalculateHeaderLength(payload.Length); + + // Because we want the header to come just before to the payload + // we will use a slice that offsets the unused part. + var headerOffset = MaxMessageHeaderLength - headerLength; + var header = buffer.WrittenSpan.Slice(headerOffset, headerLength); + + // Write the message header data to the buffer. + EncodeHeader(header, opcode, endOfMessage, payload.Length, compressed); + + // If we added a mask to the header, XOR the payload with the mask. + if (payload.Length > 0 && _maskLength > 0) + { + ApplyMask(payload, BitConverter.ToInt32(header.Slice(header.Length - MaskLength)), 0); + } + + var array = buffer.GetArray(); + var releaseArray = true; + + try + { + var sendTask = stream.WriteAsync(new ReadOnlyMemory(array, headerOffset, headerLength + payload.Length), cancellationToken); + + if (sendTask.IsCompleted) + return sendTask; + + releaseArray = false; + return WaitAsync(sendTask.AsTask(), array); + } + finally + { + if (releaseArray) + ArrayPool.Shared.Return(array); + } + } + + private static async ValueTask WaitAsync(Task sendTask, byte[] buffer) + { + try + { + await sendTask.ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + private int CalculateHeaderLength(int payloadLength) => payloadLength switch + { + <= 125 => 2, + <= ushort.MaxValue => 4, + _ => 10 + } + _maskLength; + + private void EncodeHeader(Span header, MessageOpcode opcode, bool endOfMessage, int payloadLength, bool compressed) + { + // Client header format: + // 1 bit - FIN - 1 if this is the final fragment in the message (it could be the only fragment), otherwise 0 + // 1 bit - RSV1 - Per-Message Deflate Compress + // 1 bit - RSV2 - Reserved - 0 + // 1 bit - RSV3 - Reserved - 0 + // 4 bits - Opcode - How to interpret the payload + // - 0x0 - continuation + // - 0x1 - text + // - 0x2 - binary + // - 0x8 - connection close + // - 0x9 - ping + // - 0xA - pong + // - (0x3 to 0x7, 0xB-0xF - reserved) + // 1 bit - Masked - 1 if the payload is masked, 0 if it's not. Must be 1 for the client + // 7 bits, 7+16 bits, or 7+64 bits - Payload length + // - For length 0 through 125, 7 bits storing the length + // - For lengths 126 through 2^16, 7 bits storing the value 126, followed by 16 bits storing the length + // - For lengths 2^16+1 through 2^64, 7 bits storing the value 127, followed by 64 bytes storing the length + // 0 or 4 bytes - Mask, if Masked is 1 - random value XOR'd with each 4 bytes of the payload, round-robin + // Length bytes - Payload data + header[0] = (byte)opcode; // 4 bits for the opcode + if (endOfMessage) + { + header[0] |= 0b1000_0000; // 1 bit for FIN + } + + if (compressed) + { + header[0] |= 0b0100_0000; + } + + // Store the payload length. + if (payloadLength <= 125) + { + header[1] = (byte)payloadLength; + } + else if (payloadLength <= ushort.MaxValue) + { + header[1] = 126; + header[2] = (byte)(payloadLength / 256); + header[3] = unchecked((byte)payloadLength); + } + else + { + header[1] = 127; + for (int i = 9; i >= 2; i--) + { + header[i] = unchecked((byte)payloadLength); + payloadLength = payloadLength / 256; + } + } + + if (_maskLength > 0) + { + // Generate the mask. + header[1] |= 0x80; + RandomNumberGenerator.Fill(header.Slice(header.Length - _maskLength)); + } + } + + internal ref struct Buffer + { + private byte[] _array; + private int _index; + + public Buffer(int capacity) + { + _array = ArrayPool.Shared.Rent(capacity); + _index = 0; + } + + public Span WrittenSpan => new Span(_array, 0, _index); + + public int FreeCapacity => _array.Length - _index; + + public void Advance(int count) + { + _index += count; + + Debug.Assert(_index >= 0 || _index < _array.Length); + } + + public Span GetSpan(int sizeHint = 0) + { + if (sizeHint == 0) + sizeHint = 1; + + if (sizeHint > FreeCapacity) + { + var newArray = ArrayPool.Shared.Rent(_array.Length + sizeHint); + _array.AsSpan().CopyTo(newArray); + + ArrayPool.Shared.Return(_array); + _array = newArray; + } + + return _array.AsSpan(_index); + } + + public byte[] GetArray() => _array; + } + + private abstract class Encoder : IDisposable + { + public abstract void Dispose(); + + internal abstract void Encode(ReadOnlySpan payload, ref Buffer buffer); + } + + private class Deflate : Encoder + { + private readonly int _windowBits; + + public Deflate(int windowBits) => _windowBits = windowBits; + + public override void Dispose() { } + + internal override void Encode(ReadOnlySpan payload, ref Buffer buffer) + { + using var deflater = new Deflater(_windowBits); + + Encode(payload, ref buffer, deflater); + } + + protected static void Encode(ReadOnlySpan payload, ref Buffer buffer, Deflater deflater) + { + while (payload.Length > 0) + { + deflater.Deflate(payload, buffer.GetSpan(payload.Length), out var consumed, out var written); + buffer.Advance(written); + + payload = payload.Slice(consumed); + } + + while (true) + { + var bytesWritten = deflater.Finish(buffer.GetSpan(), out var completed); + buffer.Advance(bytesWritten); + + if (completed) + break; + } + + // The deflated block always ends with 0x00 0x00 0xFF 0xFF but the websocket protocol doesn't want it. + buffer.Advance(-4); + } + } + + private sealed class PersistedDeflate : Deflate + { + private readonly Deflater _deflater; + + public PersistedDeflate(int windowBits) : base(windowBits) + { + _deflater = new Deflater(windowBits); + } + + public override void Dispose() => _deflater.Dispose(); + + internal override void Encode(ReadOnlySpan payload, ref Buffer buffer) + => Encode(payload, ref buffer, _deflater); + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 6810c5371ed0e..4dcaa5a3175a2 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -3,13 +3,11 @@ using System.Buffers; using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.IO; using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Versioning; -using System.Security.Cryptography; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -37,8 +35,6 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke return new ManagedWebSocket(stream, options); } - /// Thread-safe random number generator used to generate masks for each send. - private static readonly RandomNumberGenerator s_random = RandomNumberGenerator.Create(); /// Encoding for the payload of text messages: UTF8 encoding that throws if invalid bytes are discovered, per the RFC. private static readonly UTF8Encoding s_textEncoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true); @@ -86,9 +82,7 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke /// Semaphore used to ensure that calls to SendFrameAsync don't run concurrently. /// private readonly SemaphoreSlim _sendFrameAsyncLock = new SemaphoreSlim(1, 1); - - private readonly bool _compressionEnabled; - + private readonly Sender _sender; // We maintain the current WebSocketState in _state. However, we separately maintain _sentCloseFrame and _receivedCloseFrame // as there isn't a strict ordering between CloseSent and CloseReceived. If we receive a close frame from the server, we need to // transition to CloseReceived even if we're currently in CloseSent, and if we send a close frame, we need to transition to @@ -127,13 +121,6 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke /// private int _receivedMaskOffsetOffset; /// - /// Temporary send buffer. This should be released back to the ArrayPool once it's - /// no longer needed for the current send operation. It is stored as an instance - /// field to minimize needing to pass it around and to avoid it becoming a field on - /// various async state machine objects. - /// - private byte[]? _sendBuffer; - /// /// Whether the last SendAsync had endOfMessage==false. We need to track this so that we /// can send the subsequent message with a continuation opcode if the last message was a fragment. /// @@ -166,7 +153,7 @@ private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) _stream = stream; _isServer = options.IsServer; _subprotocol = options.SubProtocol; - _compressionEnabled = options.DeflateOptions is not null; + _sender = new Sender(options); // Create a buffer just large enough to handle received packet headers (at most 14 bytes) and // control payloads (at most 125 bytes). Message payloads are read directly into the buffer @@ -225,6 +212,8 @@ private void DisposeCore() _disposed = true; _keepAliveTimer?.Dispose(); _stream?.Dispose(); + _sender.Dispose(); + if (_state < WebSocketState.Aborted) { _state = WebSocketState.Closed; @@ -440,12 +429,10 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, // If we get here, the cancellation token is not cancelable so we don't have to worry about it, // and we own the semaphore, so we don't need to asynchronously wait for it. ValueTask writeTask = default; - bool releaseSendBufferAndSemaphore = true; + bool releaseSemaphore = true; try { - // Write the payload synchronously to the buffer, then write that buffer out to the network. - int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, payloadBuffer.Span); - writeTask = _stream.WriteAsync(new ReadOnlyMemory(_sendBuffer, 0, sendBytes)); + writeTask = _sender.SendAsync(opcode, endOfMessage, payloadBuffer, _stream); // If the operation happens to complete synchronously (or, more specifically, by // the time we get from the previous line to here), release the semaphore, return @@ -458,7 +445,7 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, // Up until this point, if an exception occurred (such as when accessing _stream or when // calling GetResult), we want to release the semaphore and the send buffer. After this point, // both need to be held until writeTask completes. - releaseSendBufferAndSemaphore = false; + releaseSemaphore = false; } catch (Exception exc) { @@ -469,9 +456,8 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, } finally { - if (releaseSendBufferAndSemaphore) + if (releaseSemaphore) { - ReleaseSendBuffer(); _sendFrameAsyncLock.Release(); } } @@ -493,7 +479,6 @@ private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask) } finally { - ReleaseSendBuffer(); _sendFrameAsyncLock.Release(); } } @@ -503,10 +488,9 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM await _sendFrameAsyncLock.WaitAsync(cancellationToken).ConfigureAwait(false); try { - int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, payloadBuffer.Span); using (cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this)) { - await _stream.WriteAsync(new ReadOnlyMemory(_sendBuffer, 0, sendBytes), cancellationToken).ConfigureAwait(false); + await _sender.SendAsync(opcode, endOfMessage, payloadBuffer, _stream, cancellationToken).ConfigureAwait(false); } } catch (Exception exc) when (!(exc is OperationCanceledException)) @@ -517,51 +501,10 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM } finally { - ReleaseSendBuffer(); _sendFrameAsyncLock.Release(); } } - /// Writes a frame into the send buffer, which can then be sent over the network. - private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, ReadOnlySpan payloadBuffer) - { - // Ensure we have a _sendBuffer. - AllocateSendBuffer(payloadBuffer.Length + MaxMessageHeaderLength); - - // Write the message header data to the buffer. - int headerLength; - int? maskOffset = null; - if (_isServer) - { - // The server doesn't send a mask, so the mask offset returned by WriteHeader - // is actually the end of the header. - headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false); - } - else - { - // We need to know where the mask starts so that we can use the mask to manipulate the payload data, - // and we need to know the total length for sending it on the wire. - maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true); - headerLength = maskOffset.GetValueOrDefault() + MaskLength; - } - - // Write the payload - if (payloadBuffer.Length > 0) - { - payloadBuffer.CopyTo(new Span(_sendBuffer, headerLength, payloadBuffer.Length)); - - // If we added a mask to the header, XOR the payload with the mask. We do the manipulation in the send buffer so as to avoid - // changing the data in the caller-supplied payload buffer. - if (maskOffset.HasValue) - { - ApplyMask(new Span(_sendBuffer, headerLength, payloadBuffer.Length), _sendBuffer, maskOffset.Value, 0); - } - } - - // Return the number of bytes in the send buffer - return headerLength + payloadBuffer.Length; - } - private void SendKeepAliveFrameAsync() { bool acquiredLock = _sendFrameAsyncLock.Wait(0); @@ -591,85 +534,6 @@ private void SendKeepAliveFrameAsync() } } - private int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnlySpan payload, bool endOfMessage, bool useMask) - { - // Client header format: - // 1 bit - FIN - 1 if this is the final fragment in the message (it could be the only fragment), otherwise 0 - // 1 bit - RSV1 - Per-Message Deflate Compress - // 1 bit - RSV2 - Reserved - 0 - // 1 bit - RSV3 - Reserved - 0 - // 4 bits - Opcode - How to interpret the payload - // - 0x0 - continuation - // - 0x1 - text - // - 0x2 - binary - // - 0x8 - connection close - // - 0x9 - ping - // - 0xA - pong - // - (0x3 to 0x7, 0xB-0xF - reserved) - // 1 bit - Masked - 1 if the payload is masked, 0 if it's not. Must be 1 for the client - // 7 bits, 7+16 bits, or 7+64 bits - Payload length - // - For length 0 through 125, 7 bits storing the length - // - For lengths 126 through 2^16, 7 bits storing the value 126, followed by 16 bits storing the length - // - For lengths 2^16+1 through 2^64, 7 bits storing the value 127, followed by 64 bytes storing the length - // 0 or 4 bytes - Mask, if Masked is 1 - random value XOR'd with each 4 bytes of the payload, round-robin - // Length bytes - Payload data - - Debug.Assert(sendBuffer.Length >= MaxMessageHeaderLength, $"Expected sendBuffer to be at least {MaxMessageHeaderLength}, got {sendBuffer.Length}"); - - sendBuffer[0] = (byte)opcode; // 4 bits for the opcode - if (endOfMessage) - { - sendBuffer[0] |= 0b1000_0000; // 1 bit for FIN - } - - if (_compressionEnabled && opcode is MessageOpcode.Text or MessageOpcode.Binary) - { - sendBuffer[0] |= 0b0100_0000; - } - - // Store the payload length. - int maskOffset; - if (payload.Length <= 125) - { - sendBuffer[1] = (byte)payload.Length; - maskOffset = 2; // no additional payload length - } - else if (payload.Length <= ushort.MaxValue) - { - sendBuffer[1] = 126; - sendBuffer[2] = (byte)(payload.Length / 256); - sendBuffer[3] = unchecked((byte)payload.Length); - maskOffset = 2 + sizeof(ushort); // additional 2 bytes for 16-bit length - } - else - { - sendBuffer[1] = 127; - int length = payload.Length; - for (int i = 9; i >= 2; i--) - { - sendBuffer[i] = unchecked((byte)length); - length = length / 256; - } - maskOffset = 2 + sizeof(ulong); // additional 8 bytes for 64-bit length - } - - if (useMask) - { - // Generate the mask. - sendBuffer[1] |= 0x80; - WriteRandomMask(sendBuffer, maskOffset); - } - - // Return the position of the mask. - return maskOffset; - } - - /// Writes a 4-byte random mask to the specified buffer at the specified offset. - /// The buffer to which to write the mask. - /// The offset into the buffer at which to write the mask. - private static void WriteRandomMask(byte[] buffer, int offset) => - s_random.GetBytes(buffer, offset, MaskLength); - /// /// Receive the next text, binary, continuation, or close message, returning information about it and /// writing its payload into the supplied buffer. Other control messages may be consumed and processed @@ -1334,27 +1198,6 @@ private void ThrowIfEOFUnexpected(bool throwOnPrematureClosure) } } - /// Gets a send buffer from the pool. - [MemberNotNull(nameof(_sendBuffer))] - private void AllocateSendBuffer(int minLength) - { - Debug.Assert(_sendBuffer == null); // would only fail if had some catastrophic error previously that prevented cleaning up - _sendBuffer = ArrayPool.Shared.Rent(minLength); - } - - /// Releases the send buffer to the pool. - private void ReleaseSendBuffer() - { - Debug.Assert(_sendFrameAsyncLock.CurrentCount == 0, "Caller should hold the _sendFrameAsyncLock"); - - byte[]? old = _sendBuffer; - if (old != null) - { - _sendBuffer = null; - ArrayPool.Shared.Return(old); - } - } - private static int CombineMaskBytes(Span buffer, int maskOffset) => BitConverter.ToInt32(buffer.Slice(maskOffset)); From e8499ff75a0ee12b28d6bfa82946957097477d8a Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sun, 14 Feb 2021 12:43:42 +0200 Subject: [PATCH 04/52] Fixed a bug in the sender implementation where if a non persisted deflate was used and a message were to be sent with multiple frames the implementation would have used seprate deflates for each of the frames. --- .../Net/WebSockets/ManagedWebSocket.Sender.cs | 105 +++++++++++------- 1 file changed, 64 insertions(+), 41 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs index 03c6f8d4b9bef..f40d0a6165bcb 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs @@ -15,6 +15,8 @@ internal partial class ManagedWebSocket { private sealed class Sender : IDisposable { + private const byte PerMessageDeflateBit = 0b0100_0000; + private readonly int _maskLength; private readonly Encoder? _encoder; @@ -31,35 +33,32 @@ public Sender(WebSocketCreationOptions options) if (options.IsServer) { _encoder = deflate.ServerContextTakeover ? - new Deflate(-deflate.ServerMaxWindowBits) : - new PersistedDeflate(-deflate.ServerMaxWindowBits); + new Deflater(-deflate.ServerMaxWindowBits) : + new PersistedDeflater(-deflate.ServerMaxWindowBits); } else { _encoder = deflate.ClientContextTakeover ? - new Deflate(-deflate.ClientMaxWindowBits) : - new PersistedDeflate(-deflate.ClientMaxWindowBits); + new Deflater(-deflate.ClientMaxWindowBits) : + new PersistedDeflater(-deflate.ClientMaxWindowBits); } } } - public void Dispose() - { - _encoder?.Dispose(); - } + public void Dispose() => _encoder?.Dispose(); public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory content, Stream stream, CancellationToken cancellationToken = default) { var buffer = new Buffer(content.Length + MaxMessageHeaderLength); - var compressed = false; + byte reservedBits = 0; // Reserve space for the frame header buffer.Advance(MaxMessageHeaderLength); - if (_encoder is not null && opcode is MessageOpcode.Text or MessageOpcode.Binary) + // Encoding is onlt supported for user messages + if (_encoder is not null && opcode <= MessageOpcode.Continuation) { - _encoder.Encode(content.Span, ref buffer); - compressed = true; + _encoder.Encode(content.Span, ref buffer, continuation: opcode == MessageOpcode.Continuation, endOfMessage, out reservedBits); } else if (content.Length > 0) { @@ -76,7 +75,7 @@ public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemo var header = buffer.WrittenSpan.Slice(headerOffset, headerLength); // Write the message header data to the buffer. - EncodeHeader(header, opcode, endOfMessage, payload.Length, compressed); + EncodeHeader(header, opcode, endOfMessage, payload.Length, reservedBits); // If we added a mask to the header, XOR the payload with the mask. if (payload.Length > 0 && _maskLength > 0) @@ -84,23 +83,22 @@ public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemo ApplyMask(payload, BitConverter.ToInt32(header.Slice(header.Length - MaskLength)), 0); } - var array = buffer.GetArray(); var releaseArray = true; try { - var sendTask = stream.WriteAsync(new ReadOnlyMemory(array, headerOffset, headerLength + payload.Length), cancellationToken); + var sendTask = stream.WriteAsync(new ReadOnlyMemory(buffer.Array, headerOffset, headerLength + payload.Length), cancellationToken); if (sendTask.IsCompleted) return sendTask; releaseArray = false; - return WaitAsync(sendTask.AsTask(), array); + return WaitAsync(sendTask.AsTask(), buffer.Array); } finally { if (releaseArray) - ArrayPool.Shared.Return(array); + ArrayPool.Shared.Return(buffer.Array); } } @@ -123,8 +121,12 @@ private static async ValueTask WaitAsync(Task sendTask, byte[] buffer) _ => 10 } + _maskLength; - private void EncodeHeader(Span header, MessageOpcode opcode, bool endOfMessage, int payloadLength, bool compressed) + private void EncodeHeader(Span header, MessageOpcode opcode, bool endOfMessage, int payloadLength, byte reservedBits) { + // The current implementation only supports per message deflate extension. In the future + // if more extensions are implemented or we allow third party extensions this assert must be changed. + Debug.Assert((reservedBits | 0b0100_0000) == 0b0100_0000); + // Client header format: // 1 bit - FIN - 1 if this is the final fragment in the message (it could be the only fragment), otherwise 0 // 1 bit - RSV1 - Per-Message Deflate Compress @@ -146,16 +148,13 @@ private void EncodeHeader(Span header, MessageOpcode opcode, bool endOfMes // 0 or 4 bytes - Mask, if Masked is 1 - random value XOR'd with each 4 bytes of the payload, round-robin // Length bytes - Payload data header[0] = (byte)opcode; // 4 bits for the opcode + header[0] |= reservedBits; + if (endOfMessage) { header[0] |= 0b1000_0000; // 1 bit for FIN } - if (compressed) - { - header[0] |= 0b0100_0000; - } - // Store the payload length. if (payloadLength <= 125) { @@ -185,6 +184,10 @@ private void EncodeHeader(Span header, MessageOpcode opcode, bool endOfMes } } + /// + /// Helper class which allows writing to a rent'ed byte array + /// and auto-grow functionality. + /// internal ref struct Buffer { private byte[] _array; @@ -198,6 +201,8 @@ public Buffer(int capacity) public Span WrittenSpan => new Span(_array, 0, _index); + public byte[] Array => _array; + public int FreeCapacity => _array.Length - _index; public void Advance(int count) @@ -223,33 +228,48 @@ public Span GetSpan(int sizeHint = 0) return _array.AsSpan(_index); } - - public byte[] GetArray() => _array; } private abstract class Encoder : IDisposable { public abstract void Dispose(); - internal abstract void Encode(ReadOnlySpan payload, ref Buffer buffer); + internal abstract void Encode(ReadOnlySpan payload, ref Buffer buffer, bool continuation, bool endOfMessage, out byte reservedBits); } - private class Deflate : Encoder + /// + /// Deflate encoder which doesn't persist the deflator accross messages. + /// + private class Deflater : Encoder { private readonly int _windowBits; - public Deflate(int windowBits) => _windowBits = windowBits; + // Although the inflater isn't persisted accross messages, a single message + // might be split into multiple frames. + private IO.Compression.Deflater? _deflater; - public override void Dispose() { } + public Deflater(int windowBits) => _windowBits = windowBits; - internal override void Encode(ReadOnlySpan payload, ref Buffer buffer) + public override void Dispose() => _deflater?.Dispose(); + + internal override void Encode(ReadOnlySpan payload, ref Buffer buffer, bool continuation, bool endOfMessage, out byte reservedBits) { - using var deflater = new Deflater(_windowBits); + Debug.Assert((continuation && _deflater is not null) || (!continuation && _deflater is null), + "Invalid state. The deflater was expected to be null if not continuation and not null otherwise."); + + _deflater ??= new IO.Compression.Deflater(_windowBits); + + Encode(payload, ref buffer, _deflater); + reservedBits = continuation ? 0 : PerMessageDeflateBit; - Encode(payload, ref buffer, deflater); + if (endOfMessage) + { + _deflater.Dispose(); + _deflater = null; + } } - protected static void Encode(ReadOnlySpan payload, ref Buffer buffer, Deflater deflater) + public static void Encode(ReadOnlySpan payload, ref Buffer buffer, IO.Compression.Deflater deflater) { while (payload.Length > 0) { @@ -273,19 +293,22 @@ protected static void Encode(ReadOnlySpan payload, ref Buffer buffer, Defl } } - private sealed class PersistedDeflate : Deflate + /// + /// Deflate encoder which persists the deflator state accross messages. + /// + private sealed class PersistedDeflater : Encoder { - private readonly Deflater _deflater; + private readonly IO.Compression.Deflater _deflater; - public PersistedDeflate(int windowBits) : base(windowBits) - { - _deflater = new Deflater(windowBits); - } + public PersistedDeflater(int windowBits) => _deflater = new(windowBits); public override void Dispose() => _deflater.Dispose(); - internal override void Encode(ReadOnlySpan payload, ref Buffer buffer) - => Encode(payload, ref buffer, _deflater); + internal override void Encode(ReadOnlySpan payload, ref Buffer buffer, bool continuation, bool endOfMessage, out byte reservedBits) + { + Deflater.Encode(payload, ref buffer, _deflater); + reservedBits = continuation ? 0 : PerMessageDeflateBit; + } } } } From c4fdf98e02a56cfd9657f77cc35a402fbbc1e23b Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 15 Feb 2021 16:32:47 +0200 Subject: [PATCH 05/52] Removed tests that were testing RCV1 flag (per message compression) which is now supported. --- .../System.Net.WebSockets/tests/WebSocketCreateTest.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs index 7f391f7754ec6..dfe3bdc92ae2a 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs @@ -105,10 +105,9 @@ public async Task ReceiveAsync_UTF8SplitAcrossMultipleBuffers_ValidDataReceived( [PlatformSpecific(~TestPlatforms.Browser)] // System.Net.Sockets is not supported on this platform. [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)] [InlineData(0b_1000_0001, 0b_0_000_0001, false)] // fin + text, no mask + length == 1 - [InlineData(0b_1100_0001, 0b_0_000_0001, true)] // fin + rsv1 + text, no mask + length == 1 [InlineData(0b_1010_0001, 0b_0_000_0001, true)] // fin + rsv2 + text, no mask + length == 1 [InlineData(0b_1001_0001, 0b_0_000_0001, true)] // fin + rsv3 + text, no mask + length == 1 - [InlineData(0b_1111_0001, 0b_0_000_0001, true)] // fin + rsv1 + rsv2 + rsv3 + text, no mask + length == 1 + [InlineData(0b_1011_0001, 0b_0_000_0001, true)] // fin + rsv2 + rsv3 + text, no mask + length == 1 [InlineData(0b_1000_0001, 0b_1_000_0001, true)] // fin + text, mask + length == 1 [InlineData(0b_1000_0011, 0b_0_000_0001, true)] // fin + opcode==3, no mask + length == 1 [InlineData(0b_1000_0100, 0b_0_000_0001, true)] // fin + opcode==4, no mask + length == 1 From 224f88e23f66be63444f563996cd9f9bfe349ce9 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 15 Feb 2021 16:36:22 +0200 Subject: [PATCH 06/52] Added receiver encoders and implementation for per message compression. --- .../src/System.Net.WebSockets.csproj | 1 + .../src/System/IO/Compression/Deflater.cs | 17 +- .../src/System/IO/Compression/Inflater.cs | 268 +---------- .../WebSockets/ManagedWebSocket.Receiver.cs | 445 +++++++++++++++++ .../Net/WebSockets/ManagedWebSocket.Sender.cs | 8 +- .../System/Net/WebSockets/ManagedWebSocket.cs | 450 +++++------------- 6 files changed, 601 insertions(+), 588 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index 2ed0d8a17b9d9..9ebda23e96f1b 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -5,6 +5,7 @@ enable + diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs index 3f718eb8d85e8..2324cf69f0d52 100644 --- a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs @@ -14,14 +14,6 @@ namespace System.IO.Compression internal sealed class Deflater : IDisposable { private readonly ZLibNative.ZLibStreamHandle _handle; - private bool _isDisposed; - - // Note, DeflateStream or the deflater do not try to be thread safe. - // The lock is just used to make writing to unmanaged structures atomic to make sure - // that they do not get inconsistent fields that may lead to an unmanaged memory violation. - // To prevent *managed* buffer corruption or other weird behaviour users need to synchronise - // on the stream explicitly. - private object SyncLock => this; internal Deflater(int windowBits) { @@ -58,14 +50,7 @@ internal Deflater(int windowBits) } } - public void Dispose() - { - if (!_isDisposed) - { - _handle.Dispose(); - _isDisposed = true; - } - } + public void Dispose() => _handle.Dispose(); public unsafe void Deflate(ReadOnlySpan input, Span output, out int consumed, out int written) { diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs index d289b3e199d1d..4a5b8a490b498 100644 --- a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs @@ -1,11 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.Net.WebSockets; -using System.Runtime.InteropServices; -using System.Security; namespace System.IO.Compression { @@ -14,222 +10,14 @@ namespace System.IO.Compression /// internal sealed class Inflater : IDisposable { - private const int MinWindowBits = -15; // WindowBits must be between -8..-15 to ignore the header, 8..15 for - private const int MaxWindowBits = 47; // zlib headers, 24..31 for GZip headers, or 40..47 for either Zlib or GZip + private readonly ZLibNative.ZLibStreamHandle _handle; - private bool _finished; // Whether the end of the stream has been reached - private bool _isDisposed; // Prevents multiple disposals - private readonly int _windowBits; // The WindowBits parameter passed to Inflater construction - private ZLibNative.ZLibStreamHandle _zlibStream; // The handle to the primary underlying zlib stream - private GCHandle _inputBufferHandle; // The handle to the buffer that provides input to _zlibStream - private readonly long _uncompressedSize; - private long _currentInflatedCount; - - private object SyncLock => this; // Used to make writing to unmanaged structures atomic - - /// - /// Initialized the Inflater with the given windowBits size - /// - internal Inflater(int windowBits, long uncompressedSize = -1) - { - Debug.Assert(windowBits >= MinWindowBits && windowBits <= MaxWindowBits); - _finished = false; - _isDisposed = false; - _windowBits = windowBits; - InflateInit(windowBits); - _uncompressedSize = uncompressedSize; - } - - public int AvailableOutput => (int)_zlibStream.AvailOut; - - /// - /// Returns true if the end of the stream has been reached. - /// - public bool Finished() => _finished; - - public unsafe bool Inflate(out byte b) - { - fixed (byte* bufPtr = &b) - { - int bytesRead = InflateVerified(bufPtr, 1); - Debug.Assert(bytesRead == 0 || bytesRead == 1); - return bytesRead != 0; - } - } - - public unsafe int Inflate(byte[] bytes, int offset, int length) - { - // If Inflate is called on an invalid or unready inflater, return 0 to indicate no bytes have been read. - if (length == 0) - return 0; - - Debug.Assert(null != bytes, "Can't pass in a null output buffer!"); - fixed (byte* bufPtr = bytes) - { - return InflateVerified(bufPtr + offset, length); - } - } - - public unsafe int Inflate(Span destination) - { - // If Inflate is called on an invalid or unready inflater, return 0 to indicate no bytes have been read. - if (destination.Length == 0) - return 0; - - fixed (byte* bufPtr = &MemoryMarshal.GetReference(destination)) - { - return InflateVerified(bufPtr, destination.Length); - } - } - - public unsafe int InflateVerified(byte* bufPtr, int length) - { - // State is valid; attempt inflation - try - { - int bytesRead = 0; - if (_uncompressedSize == -1) - { - ReadOutput(bufPtr, length, out bytesRead); - } - else - { - if (_uncompressedSize > _currentInflatedCount) - { - length = (int)Math.Min(length, _uncompressedSize - _currentInflatedCount); - ReadOutput(bufPtr, length, out bytesRead); - _currentInflatedCount += bytesRead; - } - else - { - _finished = true; - _zlibStream.AvailIn = 0; - } - } - return bytesRead; - } - finally - { - // Before returning, make sure to release input buffer if necessary: - if (0 == _zlibStream.AvailIn && _inputBufferHandle.IsAllocated) - { - DeallocateInputBufferHandle(); - } - } - } - - private unsafe void ReadOutput(byte* bufPtr, int length, out int bytesRead) - { - if (ReadInflateOutput(bufPtr, length, ZLibNative.FlushCode.NoFlush, out bytesRead) == ZLibNative.ErrorCode.StreamEnd) - { - if (!NeedsInput() && IsGzipStream() && _inputBufferHandle.IsAllocated) - { - _finished = ResetStreamForLeftoverInput(); - } - else - { - _finished = true; - } - } - } - - /// - /// If this stream has some input leftover that hasn't been processed then we should - /// check if it is another GZip file concatenated with this one. - /// - /// Returns false if the leftover input is another GZip data stream. - /// - private unsafe bool ResetStreamForLeftoverInput() - { - Debug.Assert(!NeedsInput()); - Debug.Assert(IsGzipStream()); - Debug.Assert(_inputBufferHandle.IsAllocated); - - lock (SyncLock) - { - IntPtr nextInPtr = _zlibStream.NextIn; - byte* nextInPointer = (byte*)nextInPtr.ToPointer(); - uint nextAvailIn = _zlibStream.AvailIn; - - // Check the leftover bytes to see if they start with he gzip header ID bytes - if (*nextInPointer != ZLibNative.GZip_Header_ID1 || (nextAvailIn > 1 && *(nextInPointer + 1) != ZLibNative.GZip_Header_ID2)) - { - return true; - } - - // Trash our existing zstream. - _zlibStream.Dispose(); - - // Create a new zstream - InflateInit(_windowBits); - - // SetInput on the new stream to the bits remaining from the last stream - _zlibStream.NextIn = nextInPtr; - _zlibStream.AvailIn = nextAvailIn; - _finished = false; - } - - return false; - } - - internal bool IsGzipStream() => _windowBits >= 24 && _windowBits <= 31; - - public bool NeedsInput() => _zlibStream.AvailIn == 0; - - public void SetInput(byte[] inputBuffer, int startIndex, int count) - { - Debug.Assert(NeedsInput(), "We have something left in previous input!"); - Debug.Assert(inputBuffer != null); - Debug.Assert(startIndex >= 0 && count >= 0 && count + startIndex <= inputBuffer.Length); - Debug.Assert(!_inputBufferHandle.IsAllocated); - - if (0 == count) - return; - - lock (SyncLock) - { - _inputBufferHandle = GCHandle.Alloc(inputBuffer, GCHandleType.Pinned); - _zlibStream.NextIn = _inputBufferHandle.AddrOfPinnedObject() + startIndex; - _zlibStream.AvailIn = (uint)count; - _finished = false; - } - } - - private void Dispose(bool disposing) - { - if (!_isDisposed) - { - if (disposing) - _zlibStream.Dispose(); - - if (_inputBufferHandle.IsAllocated) - DeallocateInputBufferHandle(); - - _isDisposed = true; - } - } - - public void Dispose() - { - Dispose(true); - GC.SuppressFinalize(this); - } - - ~Inflater() - { - Dispose(false); - } - - /// - /// Creates the ZStream that will handle inflation. - /// - [MemberNotNull(nameof(_zlibStream))] - private void InflateInit(int windowBits) + internal Inflater(int windowBits) { ZLibNative.ErrorCode error; try { - error = ZLibNative.CreateZLibStreamForInflate(out _zlibStream, windowBits); + error = ZLibNative.CreateZLibStreamForInflate(out _handle, windowBits); } catch (Exception exception) // could not load the ZLib dll { @@ -255,45 +43,48 @@ private void InflateInit(int windowBits) } } - /// - /// Wrapper around the ZLib inflate function, configuring the stream appropriately. - /// - private unsafe ZLibNative.ErrorCode ReadInflateOutput(byte* bufPtr, int length, ZLibNative.FlushCode flushCode, out int bytesRead) + internal unsafe void Inflate(ReadOnlySpan input, Span output, out int consumed, out int written) { - lock (SyncLock) + fixed (byte* fixedInput = input) + fixed (byte* fixedOutput = output) { - _zlibStream.NextOut = (IntPtr)bufPtr; - _zlibStream.AvailOut = (uint)length; + _handle.NextIn = (IntPtr)fixedInput; + _handle.AvailIn = (uint)input.Length; + + _handle.NextOut = (IntPtr)fixedOutput; + _handle.AvailOut = (uint)output.Length; - ZLibNative.ErrorCode errC = Inflate(flushCode); - bytesRead = length - (int)_zlibStream.AvailOut; + Inflate(ZLibNative.FlushCode.NoFlush); - return errC; + consumed = input.Length - (int)_handle.AvailIn; + written = output.Length - (int)_handle.AvailOut; } } + public void Dispose() => _handle.Dispose(); + /// /// Wrapper around the ZLib inflate function /// private ZLibNative.ErrorCode Inflate(ZLibNative.FlushCode flushCode) { - ZLibNative.ErrorCode errC; + ZLibNative.ErrorCode errorCode; try { - errC = _zlibStream.Inflate(flushCode); + errorCode = _handle.Inflate(flushCode); } catch (Exception cause) // could not load the Zlib DLL correctly { throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); } - switch (errC) + switch (errorCode) { case ZLibNative.ErrorCode.Ok: // progress has been made inflating case ZLibNative.ErrorCode.StreamEnd: // The end of the input stream has been reached - return errC; + return errorCode; case ZLibNative.ErrorCode.BufError: // No room in the output buffer - inflate() can be called again with more space to continue - return errC; + return errorCode; case ZLibNative.ErrorCode.MemError: // Not enough memory to complete the operation throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); @@ -305,22 +96,7 @@ private ZLibNative.ErrorCode Inflate(ZLibNative.FlushCode flushCode) throw new WebSocketException(SR.ZLibErrorInconsistentStream); default: - throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errC)); - } - } - - /// - /// Frees the GCHandle being used to store the input buffer - /// - private void DeallocateInputBufferHandle() - { - Debug.Assert(_inputBufferHandle.IsAllocated); - - lock (SyncLock) - { - _zlibStream.AvailIn = 0; - _zlibStream.NextIn = ZLibNative.ZNullPtr; - _inputBufferHandle.Free(); + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs new file mode 100644 index 0000000000000..c600ec3e02457 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -0,0 +1,445 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets +{ + internal partial class ManagedWebSocket + { + private const int ReceivedConnectionClose = -1; + private const int ReceivedControlMessage = -2; + private const int ReceivedHeaderError = -3; + + private sealed class Receiver : IDisposable + { + private readonly bool _isServer; + private readonly Stream _stream; + private readonly Decoder? _decoder; + + /// + /// If we have a decoder we cannot use the buffer provided from clients because + /// we cannot guarantee that the decoding can happen in place. This buffer is rent'ed + /// and returned when consumed. + /// + private byte[]? _decoderBuffer; + + /// + /// The next index that needs to be consumed from the decoder's buffer. + /// + private int _decoderBufferPosition; + + /// + /// The number of usable bytes in the decoder's buffer. + /// + private int _decoderBufferCount; + + /// + /// The last header received in a ReceiveAsync. If ReceiveAsync got a header but then + /// returned fewer bytes than was indicated in the header, subsequent ReceiveAsync calls + /// will use the data from the header to construct the subsequent receive results, and + /// the payload length in this header will be decremented to indicate the number of bytes + /// remaining to be received for that header. As a result, between fragments, the payload + /// length in this header should be 0. + /// + private MessageHeader _lastHeader = new() { Opcode = MessageOpcode.Text, Fin = true }; + + /// + /// Because user messages can have continuations (split onto multiple frames) + /// but also we need to know the actual message type (text or binary) we need + /// a seperate field to track if the last received header is actually a continuation. + /// + private bool _continuation; + + /// + /// Buffer used for reading data from the network. + /// Not readonly here because the buffer is mutable and is a struct. + /// + private Buffer _readBuffer; + + /// + /// When dealing with partially read fragments of binary/text messages, a mask previously received may still + /// apply, and the first new byte received may not correspond to the 0th position in the mask. This value is + /// the next offset into the mask that should be applied. + /// + private int _receivedMaskOffset; + + public Receiver(Stream stream, WebSocketCreationOptions options) + { + _stream = stream; + _isServer = options.IsServer; + + // Create a buffer just large enough to handle received packet headers (at most 14 bytes) and + // control payloads (at most 125 bytes). Message payloads are read directly into the buffer + // supplied to ReceiveAsync. + _readBuffer = new(MaxControlPayloadLength + MaxMessageHeaderLength); + + var deflate = options.DeflateOptions; + + if (deflate is not null) + { + // Important note here is that we must use negative window bits + // which will instruct the underlying implementation to not expect deflate headers + if (options.IsServer) + { + _decoder = deflate.ServerContextTakeover ? + new Inflater(-deflate.ServerMaxWindowBits) : + new PersistedInflater(-deflate.ServerMaxWindowBits); + } + else + { + _decoder = deflate.ClientContextTakeover ? + new Inflater(-deflate.ClientMaxWindowBits) : + new PersistedInflater(-deflate.ClientMaxWindowBits); + } + } + } + + public void Dispose() => _decoder?.Dispose(); + + public MessageHeader GetLastHeader() => _lastHeader; + + /// Issues a read on the stream to wait for EOF. + public async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken cancellationToken) + { + if (_readBuffer.FreeLength == 0) + { + // Because we are going to need only 1 byte buffer, do a discard + // only when necessary (avoiding needless copying). + _readBuffer.DiscardConsumed(); + } + // Per RFC 6455 7.1.1, try to let the server close the connection. We give it up to a second. + // We simply issue a read and don't care what we get back; we could validate that we don't get + // additional data, but at this point we're about to close the connection and we're just stalling + // to try to get the server to close first. + ValueTask finalReadTask = _stream.ReadAsync(_readBuffer.FreeMemory.Slice(start: 0, length: 1), cancellationToken); + + if (!finalReadTask.IsCompletedSuccessfully) + { + // Wait an arbitrary amount of time to give the server (same as netfx, 1 second) + using var cts = finalReadTask.IsCompleted ? null : new CancellationTokenSource(TimeSpan.FromSeconds(1)); + using var cancellation = cts is not null ? cts.Token.UnsafeRegister(static s => ((Stream)s!).Dispose(), _stream) : default; + + // TODO: Once this is merged https://github.com/dotnet/runtime/issues/47525 + // use configure await with the option to suppress exceptions and remove the try catch + try + { + await finalReadTask.ConfigureAwait(false); + } + catch + { + // Eat any resulting exceptions. We were going to close the connection, anyway. + } + } + } + + public async ValueTask ReceiveControlMessageAsync(CancellationToken cancellationToken) + { + Debug.Assert(_lastHeader.Opcode > MessageOpcode.Binary); + + if (_lastHeader.PayloadLength == 0) + return new ControlMessage(_lastHeader.Opcode, ReadOnlyMemory.Empty); + + _readBuffer.DiscardConsumed(); + + while (_lastHeader.PayloadLength > _readBuffer.AvailableLength) + { + var byteCount = await _stream.ReadAsync(_readBuffer.FreeMemory, cancellationToken).ConfigureAwait(false); + if (byteCount <= 0) + return null; + + ApplyMask(_readBuffer.FreeMemory.Span.Slice(0, (int)Math.Min(_lastHeader.PayloadLength, byteCount))); + _readBuffer.Commit(byteCount); + } + + // Update the payload length in the header to indicate + // that we've received everything we need. + var payload = _readBuffer.Consume((int)_lastHeader.PayloadLength); + _lastHeader.PayloadLength = 0; + + return new ControlMessage(_lastHeader.Opcode, payload); + } + + public async ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken) + { + _readBuffer.DiscardConsumed(); + + // When there's nothing left over to receive, start a new + if (_lastHeader.PayloadLength == 0) + { + var success = await ReceiveHeaderAsync(cancellationToken).ConfigureAwait(false); + + if (!success) + return ReceivedConnectionClose; + + if (_lastHeader.Error is not null) + return ReceivedHeaderError; + + if (_lastHeader.Opcode > MessageOpcode.Binary) + { + // The received message is a control message and it's up + // to the websocket how to handle it. + return ReceivedControlMessage; + } + } + + if (buffer.IsEmpty) + return 0; + + // The number of bytes that are copied onto the provided buffer + var resultByteCount = 0; + + if (_readBuffer.AvailableLength > 0) + { + int consumed, written; + int available = (int)Math.Min(_readBuffer.AvailableLength, _lastHeader.PayloadLength); + + if (_decoder is not null && _decoder.IsNeeded(_lastHeader)) + { + _decoder.Decode(input: _readBuffer.AvailableSpan.Slice(0, available), + output: buffer.Span, out consumed, out written); + } + else + { + written = Math.Min(available, buffer.Length); + consumed = written; + _readBuffer.AvailableSpan.Slice(0, written).CopyTo(buffer.Span); + } + + _readBuffer.Consume(consumed); + _lastHeader.PayloadLength -= consumed; + + if (_lastHeader.PayloadLength == 0 || _readBuffer.AvailableLength > 0) + { + // If the payload length is 0 it means that we have consumed everything. + // Otherwise if available length is still non zero, than it means that the + // decoder needs more memory and the operation cannot continue. + return written; + } + + resultByteCount += written; + buffer = buffer.Slice(written); + } + + // At this point we should have consumed everything from the buffer + // and should start issuing reads on the stream. + Debug.Assert(_readBuffer.AvailableLength == 0 && _lastHeader.PayloadLength > 0); + + if (_decoder is null || !_decoder.IsNeeded(_lastHeader)) + { + if (buffer.Length > _lastHeader.PayloadLength) + { + // We don't want to receive more that we need + buffer = buffer.Slice(0, (int)_lastHeader.PayloadLength); + } + + var bytesRead = await _stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + if (bytesRead <= 0) + return ReceivedConnectionClose; + + resultByteCount += bytesRead; + ApplyMask(buffer.Span.Slice(0, bytesRead)); + } + else + { + if (_decoderBuffer is null) + { + // Rent a buffer but restrict it's max size to 1MB + _decoderBuffer = ArrayPool.Shared.Rent((int)Math.Min(_lastHeader.PayloadLength, 1_000_000)); + _decoderBufferCount = await _stream.ReadAsync(_decoderBuffer, cancellationToken).ConfigureAwait(false); + if (_decoderBufferCount <= 0) + { + ArrayPool.Shared.Return(_decoderBuffer); + return ReceivedConnectionClose; + } + + ApplyMask(_decoderBuffer.AsSpan(_decoderBufferPosition, _decoderBufferCount)); + } + + // There is lefover data that we need to decode + _decoder.Decode(input: _decoderBuffer.AsSpan(_decoderBufferPosition, _decoderBufferCount), + output: buffer.Span, out var consumed, out var written); + + resultByteCount += written; + _decoderBufferPosition += consumed; + _decoderBufferCount -= consumed; + _lastHeader.PayloadLength -= consumed; + + if (_decoderBufferCount == 0) + { + ArrayPool.Shared.Return(_decoderBuffer); + _decoderBuffer = null; + _decoderBufferPosition = 0; + } + } + + return resultByteCount; + } + + private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationToken) + { + Debug.Assert(_lastHeader.PayloadLength == 0); + + _receivedMaskOffset = 0; + _decoder?.Reset(); + + while (true) + { + if (TryParseMessageHeader(_readBuffer.AvailableSpan, _lastHeader, _isServer, out var header, out var consumedBytes)) + { + // If this is a continuation, replace the opcode with the one of the message it's continuing + if (_continuation = (header.Opcode == MessageOpcode.Continuation)) + { + header.Opcode = _lastHeader.Opcode; + header.Compressed = _lastHeader.Compressed; + } + + _lastHeader = header; + _readBuffer.Consume(consumedBytes); + + if (_isServer) + { + // Unmask any payload that we've received + if (header.PayloadLength > 0 && _readBuffer.AvailableLength > 0) + { + ApplyMask(_readBuffer.AvailableSpan.Slice(0, (int)Math.Min(_readBuffer.AvailableLength, header.PayloadLength))); + } + } + + break; + } + + // More data is neeed to parse the header + var byteCount = await _stream.ReadAsync(_readBuffer.FreeMemory, cancellationToken).ConfigureAwait(false); + if (byteCount <= 0) + return false; + + _readBuffer.Commit(byteCount); + } + + return true; + } + + private void ApplyMask(Span input) + { + if (_isServer) + { + _receivedMaskOffset = ManagedWebSocket.ApplyMask(input, _lastHeader.Mask, _receivedMaskOffset); + } + } + + [StructLayout(LayoutKind.Auto)] + private struct Buffer + { + private readonly byte[] _bytes; + private int _position; + private int _consumed; + + public Buffer(int capacity) + { + _bytes = new byte[capacity]; + _position = 0; + _consumed = 0; + } + + public int AvailableLength => _position - _consumed; + + public Span AvailableSpan => + new Span(_bytes, start: _consumed, length: _position - _consumed); + + public Memory AvailableMemory => + new Memory(_bytes, start: _consumed, length: _position - _consumed); + + public Memory FreeMemory => _bytes.AsMemory(_position); + + public int FreeLength => _bytes.Length - _position; + + public void Commit(int count) + { + _position += count; + } + + public Memory Consume(int count) + { + var memory = new Memory(_bytes, _consumed, count); + _consumed += count; + + return memory; + } + + public void DiscardConsumed() + { + if (AvailableLength > 0) + { + AvailableMemory.CopyTo(_bytes); + } + + _position -= _consumed; + _consumed = 0; + } + } + + private abstract class Decoder : IDisposable + { + public abstract bool IsNeeded(MessageHeader header); + + public abstract void Dispose(); + + public abstract void Reset(); + + public abstract void Decode(ReadOnlySpan input, Span output, out int consumed, out int written); + } + + private class Inflater : Decoder + { + private readonly int _windowBits; + + // Although the inflater isn't persisted accross messages, a single message + // might have been split into multiple frames. + private IO.Compression.Inflater? _inflater; + + public Inflater(int windowBits) => _windowBits = windowBits; + + public override bool IsNeeded(MessageHeader header) => header.Compressed; + + public override void Dispose() => _inflater?.Dispose(); + + public override void Reset() + { + _inflater?.Dispose(); + _inflater = null; + } + + public override void Decode(ReadOnlySpan input, Span output, out int consumed, out int written) + { + _inflater ??= new IO.Compression.Inflater(_windowBits); + _inflater.Inflate(input, output, out consumed, out written); + } + } + + private sealed class PersistedInflater : Decoder + { + private readonly IO.Compression.Inflater _inflater; + + public PersistedInflater(int windowBits) => _inflater = new(windowBits); + + public override bool IsNeeded(MessageHeader header) => header.Compressed; + + public override void Dispose() => _inflater.Dispose(); + + public override void Reset() { } + + public override void Decode(ReadOnlySpan input, Span output, out int consumed, out int written) + { + _inflater.Inflate(input, output, out consumed, out written); + } + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs index f40d0a6165bcb..5f45b3044ddf9 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs @@ -19,10 +19,12 @@ private sealed class Sender : IDisposable private readonly int _maskLength; private readonly Encoder? _encoder; + private readonly Stream _stream; - public Sender(WebSocketCreationOptions options) + public Sender(Stream stream, WebSocketCreationOptions options) { _maskLength = options.IsServer ? 0 : MaskLength; + _stream = stream; var deflate = options.DeflateOptions; @@ -47,7 +49,7 @@ public Sender(WebSocketCreationOptions options) public void Dispose() => _encoder?.Dispose(); - public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory content, Stream stream, CancellationToken cancellationToken = default) + public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory content, CancellationToken cancellationToken = default) { var buffer = new Buffer(content.Length + MaxMessageHeaderLength); byte reservedBits = 0; @@ -87,7 +89,7 @@ public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemo try { - var sendTask = stream.WriteAsync(new ReadOnlyMemory(buffer.Array, headerOffset, headerLength + payload.Length), cancellationToken); + var sendTask = _stream.WriteAsync(new ReadOnlyMemory(buffer.Array, headerOffset, headerLength + payload.Length), cancellationToken); if (sendTask.IsCompleted) return sendTask; diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 4dcaa5a3175a2..2b701cdfd8eb3 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -72,17 +72,14 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke private readonly Timer? _keepAliveTimer; /// CancellationTokenSource used to abort all current and future operations when anything is canceled or any error occurs. private readonly CancellationTokenSource _abortSource = new CancellationTokenSource(); - /// Buffer used for reading data from the network. - private readonly Memory _receiveBuffer; - /// - /// Tracks the state of the validity of the UTF8 encoding of text payloads. Text may be split across fragments. - /// - private readonly Utf8MessageState _utf8TextState = new Utf8MessageState(); + /// /// Semaphore used to ensure that calls to SendFrameAsync don't run concurrently. /// private readonly SemaphoreSlim _sendFrameAsyncLock = new SemaphoreSlim(1, 1); private readonly Sender _sender; + private readonly Receiver _receiver; + // We maintain the current WebSocketState in _state. However, we separately maintain _sentCloseFrame and _receivedCloseFrame // as there isn't a strict ordering between CloseSent and CloseReceived. If we receive a close frame from the server, we need to // transition to CloseReceived even if we're currently in CloseSent, and if we send a close frame, we need to transition to @@ -102,24 +99,10 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke private string? _closeStatusDescription; /// - /// The last header received in a ReceiveAsync. If ReceiveAsync got a header but then - /// returned fewer bytes than was indicated in the header, subsequent ReceiveAsync calls - /// will use the data from the header to construct the subsequent receive results, and - /// the payload length in this header will be decremented to indicate the number of bytes - /// remaining to be received for that header. As a result, between fragments, the payload - /// length in this header should be 0. - /// - private MessageHeader _lastReceiveHeader = new MessageHeader { Opcode = MessageOpcode.Text, Fin = true }; - /// The offset of the next available byte in the _receiveBuffer. - private int _receiveBufferOffset; - /// The number of bytes available in the _receiveBuffer. - private int _receiveBufferCount; - /// - /// When dealing with partially read fragments of binary/text messages, a mask previously received may still - /// apply, and the first new byte received may not correspond to the 0th position in the mask. This value is - /// the next offset into the mask that should be applied. + /// Tracks the state of the validity of the UTF8 encoding of text payloads. Text may be split across fragments. /// - private int _receivedMaskOffsetOffset; + private readonly Utf8MessageState _utf8TextState = new Utf8MessageState(); + /// /// Whether the last SendAsync had endOfMessage==false. We need to track this so that we /// can send the subsequent message with a continuation opcode if the last message was a fragment. @@ -138,10 +121,13 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke /// be pending while a Close{Output}Async is issued, which itself needs to loop until a close frame is received. /// As such, we need thread-safety in the management of . /// - private object ReceiveAsyncLock => _utf8TextState; // some object, as we're simply lock'ing on it + private object ReceiveAsyncLock => _receiver; // some object, as we're simply lock'ing on it private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) { + _sender = new Sender(stream, options); + _receiver = new Receiver(stream, options); + Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null"); Debug.Assert(ReceiveAsyncLock != null, $"Expected {nameof(ReceiveAsyncLock)} to be non-null"); Debug.Assert(StateUpdateLock != ReceiveAsyncLock, "Locks should be different objects"); @@ -153,13 +139,6 @@ private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) _stream = stream; _isServer = options.IsServer; _subprotocol = options.SubProtocol; - _sender = new Sender(options); - - // Create a buffer just large enough to handle received packet headers (at most 14 bytes) and - // control payloads (at most 125 bytes). Message payloads are read directly into the buffer - // supplied to ReceiveAsync. - const int ReceiveBufferMinLength = MaxControlPayloadLength; - _receiveBuffer = new byte[ReceiveBufferMinLength]; // Set up the abort source so that if it's triggered, we transition the instance appropriately. // There's no need to store the resulting CancellationTokenRegistration, as this instance owns @@ -432,7 +411,7 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, bool releaseSemaphore = true; try { - writeTask = _sender.SendAsync(opcode, endOfMessage, payloadBuffer, _stream); + writeTask = _sender.SendAsync(opcode, endOfMessage, payloadBuffer); // If the operation happens to complete synchronously (or, more specifically, by // the time we get from the previous line to here), release the semaphore, return @@ -490,7 +469,7 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM { using (cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this)) { - await _sender.SendAsync(opcode, endOfMessage, payloadBuffer, _stream, cancellationToken).ConfigureAwait(false); + await _sender.SendAsync(opcode, endOfMessage, payloadBuffer, cancellationToken).ConfigureAwait(false); } } catch (Exception exc) when (!(exc is OperationCanceledException)) @@ -562,134 +541,67 @@ private async ValueTask ReceiveAsyncPrivate 125) + continue; + } + else { - int minNeeded = - 2 + - (_isServer ? MaskLength : 0) + - (payloadLength <= 125 ? 0 : payloadLength == 126 ? sizeof(ushort) : sizeof(ulong)); // additional 2 or 8 bytes for 16-bit or 64-bit length - await EnsureBufferContainsAsync(minNeeded, cancellationToken).ConfigureAwait(false); + Debug.Assert(message.Opcode == MessageOpcode.Close); + + await HandleReceivedCloseAsync(message.Payload, cancellationToken).ConfigureAwait(false); + return resultGetter.GetResult(0, WebSocketMessageType.Close, true, _closeStatus, _closeStatusDescription); } } - - string? headerErrorMessage = TryParseMessageHeaderFromReceiveBuffer(out header); - if (headerErrorMessage != null) + else if (byteCount == ReceivedConnectionClose) { - await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, headerErrorMessage).ConfigureAwait(false); + ThrowIfEOFUnexpected(true); } - _receivedMaskOffsetOffset = 0; - } - - // If the header represents a ping or a pong, it's a control message meant - // to be transparent to the user, so handle it and then loop around to read again. - // Alternatively, if it's a close message, handle it and exit. - if (header.Opcode == MessageOpcode.Ping || header.Opcode == MessageOpcode.Pong) - { - await HandleReceivedPingPongAsync(header, cancellationToken).ConfigureAwait(false); - continue; - } - else if (header.Opcode == MessageOpcode.Close) - { - await HandleReceivedCloseAsync(header, cancellationToken).ConfigureAwait(false); - return resultGetter.GetResult(0, WebSocketMessageType.Close, true, _closeStatus, _closeStatusDescription); - } - - // If this is a continuation, replace the opcode with the one of the message it's continuing - if (header.Opcode == MessageOpcode.Continuation) - { - header.Opcode = _lastReceiveHeader.Opcode; - } - - // The message should now be a binary or text message. Handle it by reading the payload and returning the contents. - Debug.Assert(header.Opcode == MessageOpcode.Binary || header.Opcode == MessageOpcode.Text, $"Unexpected opcode {header.Opcode}"); - - // If there's no data to read, return an appropriate result. - if (header.PayloadLength == 0 || payloadBuffer.Length == 0) - { - _lastReceiveHeader = header; - return resultGetter.GetResult( - 0, - header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, - header.Fin && header.PayloadLength == 0, - null, null); - } - - // Otherwise, read as much of the payload as we can efficiently, and update the header to reflect how much data - // remains for future reads. We first need to copy any data that may be lingering in the receive buffer - // into the destination; then to minimize ReceiveAsync calls, we want to read as much as we can, stopping - // only when we've either read the whole message or when we've filled the payload buffer. - - // First copy any data lingering in the receive buffer. - int totalBytesReceived = 0; - if (_receiveBufferCount > 0) - { - int receiveBufferBytesToCopy = Math.Min(payloadBuffer.Length, (int)Math.Min(header.PayloadLength, _receiveBufferCount)); - Debug.Assert(receiveBufferBytesToCopy > 0); - _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(payloadBuffer.Span); - ConsumeFromBuffer(receiveBufferBytesToCopy); - totalBytesReceived += receiveBufferBytesToCopy; - Debug.Assert( - _receiveBufferCount == 0 || - totalBytesReceived == payloadBuffer.Length || - totalBytesReceived == header.PayloadLength); - } - - // Then read directly into the payload buffer until we've hit a limit. - while (totalBytesReceived < payloadBuffer.Length && - totalBytesReceived < header.PayloadLength) - { - int numBytesRead = await _stream.ReadAsync(payloadBuffer.Slice( - totalBytesReceived, - (int)Math.Min(payloadBuffer.Length, header.PayloadLength) - totalBytesReceived), cancellationToken).ConfigureAwait(false); - if (numBytesRead <= 0) + else { - ThrowIfEOFUnexpected(throwOnPrematureClosure: true); - break; + Debug.Assert(byteCount == ReceivedHeaderError); + + var error = _receiver.GetLastHeader().Error; + await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, error).ConfigureAwait(false); } - totalBytesReceived += numBytesRead; } - if (_isServer) - { - _receivedMaskOffsetOffset = ApplyMask(payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset); - } - header.PayloadLength -= totalBytesReceived; + var header = _receiver.GetLastHeader(); // If this a text message, validate that it contains valid UTF8. - if (header.Opcode == MessageOpcode.Text && - !TryValidateUtf8(payloadBuffer.Span.Slice(0, totalBytesReceived), header.Fin && header.PayloadLength == 0, _utf8TextState)) + if (header.Opcode == MessageOpcode.Text && byteCount > 0 && + !TryValidateUtf8(payloadBuffer.Span.Slice(0, byteCount), header.Fin && header.PayloadLength == 0, _utf8TextState)) { await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false); } - _lastReceiveHeader = header; return resultGetter.GetResult( - totalBytesReceived, - header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, - header.Fin && header.PayloadLength == 0, - null, null); + count: byteCount, + messageType: header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, + endOfMessage: header.Fin && header.PayloadLength == 0, + closeStatus: null, closeDescription: null); } } - catch (Exception exc) when (!(exc is OperationCanceledException)) + catch (Exception exc) when (exc is not OperationCanceledException) { if (_state == WebSocketState.Aborted) { @@ -711,10 +623,7 @@ private async ValueTask ReceiveAsyncPrivateProcesses a received close message. - /// The message header. - /// The CancellationToken used to cancel the websocket operation. - /// The received result message. - private async ValueTask HandleReceivedCloseAsync(MessageHeader header, CancellationToken cancellationToken) + private async ValueTask HandleReceivedCloseAsync(ReadOnlyMemory payload, CancellationToken cancellationToken) { lock (StateUpdateLock) { @@ -733,41 +642,30 @@ private async ValueTask HandleReceivedCloseAsync(MessageHeader header, Cancellat string closeStatusDescription = string.Empty; // Handle any payload by parsing it into the close status and description. - if (header.PayloadLength == 1) + if (payload.Length == 1) { // The close payload length can be 0 or >= 2, but not 1. await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted).ConfigureAwait(false); } - else if (header.PayloadLength >= 2) + else if (payload.Length >= 2) { - if (_receiveBufferCount < header.PayloadLength) - { - await EnsureBufferContainsAsync((int)header.PayloadLength, cancellationToken).ConfigureAwait(false); - } - - if (_isServer) - { - ApplyMask(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength), header.Mask, 0); - } - - closeStatus = (WebSocketCloseStatus)(_receiveBuffer.Span[_receiveBufferOffset] << 8 | _receiveBuffer.Span[_receiveBufferOffset + 1]); + closeStatus = (WebSocketCloseStatus)(payload.Span[0] << 8 | payload.Span[1]); if (!IsValidCloseStatus(closeStatus)) { await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted).ConfigureAwait(false); } - if (header.PayloadLength > 2) + if (payload.Length > 2) { try { - closeStatusDescription = s_textEncoding.GetString(_receiveBuffer.Span.Slice(_receiveBufferOffset + 2, (int)header.PayloadLength - 2)); + closeStatusDescription = s_textEncoding.GetString(payload.Span.Slice(2)); } catch (DecoderFallbackException exc) { await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, innerException: exc).ConfigureAwait(false); } } - ConsumeFromBuffer((int)header.PayloadLength); } // Store the close status and description onto the instance. @@ -776,66 +674,7 @@ private async ValueTask HandleReceivedCloseAsync(MessageHeader header, Cancellat if (!_isServer && _sentCloseFrame) { - await WaitForServerToCloseConnectionAsync(cancellationToken).ConfigureAwait(false); - } - } - - /// Issues a read on the stream to wait for EOF. - private async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken cancellationToken) - { - // Per RFC 6455 7.1.1, try to let the server close the connection. We give it up to a second. - // We simply issue a read and don't care what we get back; we could validate that we don't get - // additional data, but at this point we're about to close the connection and we're just stalling - // to try to get the server to close first. - ValueTask finalReadTask = _stream.ReadAsync(_receiveBuffer, cancellationToken); - if (!finalReadTask.IsCompletedSuccessfully) - { - const int WaitForCloseTimeoutMs = 1_000; // arbitrary amount of time to give the server (same as netfx) - using (var finalCts = new CancellationTokenSource(WaitForCloseTimeoutMs)) - using (finalCts.Token.Register(static s => ((ManagedWebSocket)s!).Abort(), this)) - { - try - { - await finalReadTask.ConfigureAwait(false); - } - catch - { - // Eat any resulting exceptions. We were going to close the connection, anyway. - } - } - } - } - - /// Processes a received ping or pong message. - /// The message header. - /// The CancellationToken used to cancel the websocket operation. - private async ValueTask HandleReceivedPingPongAsync(MessageHeader header, CancellationToken cancellationToken) - { - // Consume any (optional) payload associated with the ping/pong. - if (header.PayloadLength > 0 && _receiveBufferCount < header.PayloadLength) - { - await EnsureBufferContainsAsync((int)header.PayloadLength, cancellationToken).ConfigureAwait(false); - } - - // If this was a ping, send back a pong response. - if (header.Opcode == MessageOpcode.Ping) - { - if (_isServer) - { - ApplyMask(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength), header.Mask, 0); - } - - await SendFrameAsync( - MessageOpcode.Pong, - endOfMessage: true, - _receiveBuffer.Slice(_receiveBufferOffset, (int)header.PayloadLength), - cancellationToken).ConfigureAwait(false); - } - - // Regardless of whether it was a ping or pong, we no longer need the payload. - if (header.PayloadLength > 0) - { - ConsumeFromBuffer((int)header.PayloadLength); + await _receiver.WaitForServerToCloseConnectionAsync(cancellationToken).ConfigureAwait(false); } } @@ -892,97 +731,98 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( await CloseOutputAsync(closeStatus, string.Empty, default).ConfigureAwait(false); } - // Dump our receive buffer; we're in a bad state to do any further processing - _receiveBufferCount = 0; - // Let the caller know we've failed throw errorMessage != null ? new WebSocketException(error, errorMessage, innerException) : new WebSocketException(error, innerException); } - /// Parses a message header from the buffer. This assumes the header is in the buffer. - /// The read header. - /// null if a valid header was read; non-null containing the string error message to use if the header was invalid. - private string? TryParseMessageHeaderFromReceiveBuffer(out MessageHeader resultHeader) + private static bool TryParseMessageHeader( + ReadOnlySpan buffer, + MessageHeader previousHeader, + bool isServer, + out MessageHeader header, + out int consumedBytes) { - Debug.Assert(_receiveBufferCount >= 2, $"Expected to at least have the first two bytes of the header."); + header = default; + consumedBytes = 0; + + if (buffer.Length < 2) + return false; + + // Check first for reserved bits that should always be unset + if ((buffer[0] & 0b0011_0000) != 0) + return Error(ref header, SR.net_Websockets_ReservedBitsSet); - MessageHeader header = default; - Span receiveBufferSpan = _receiveBuffer.Span; + header.Fin = (buffer[0] & 0x80) != 0; + header.Opcode = (MessageOpcode)(buffer[0] & 0xF); + header.Compressed = (buffer[0] & 0b0100_0000) != 0; - header.Fin = (receiveBufferSpan[_receiveBufferOffset] & 0x80) != 0; - bool reservedSet = (receiveBufferSpan[_receiveBufferOffset] & 0b0011_0000) != 0; - header.Opcode = (MessageOpcode)(receiveBufferSpan[_receiveBufferOffset] & 0xF); - header.Compressed = (receiveBufferSpan[_receiveBufferOffset] & 0b0100_0000) != 0; + bool masked = (buffer[1] & 0x80) != 0; + if (masked && !isServer) + return Error(ref header, SR.net_Websockets_ClientReceivedMaskedFrame); - bool masked = (receiveBufferSpan[_receiveBufferOffset + 1] & 0x80) != 0; - header.PayloadLength = receiveBufferSpan[_receiveBufferOffset + 1] & 0x7F; + header.PayloadLength = buffer[1] & 0x7F; - ConsumeFromBuffer(2); + // We've consumed the first 2 bytes + buffer = buffer.Slice(2); + consumedBytes += 2; // Read the remainder of the payload length, if necessary if (header.PayloadLength == 126) { - Debug.Assert(_receiveBufferCount >= 2, $"Expected to have two bytes for the payload length."); - header.PayloadLength = (receiveBufferSpan[_receiveBufferOffset] << 8) | receiveBufferSpan[_receiveBufferOffset + 1]; - ConsumeFromBuffer(2); + if (buffer.Length < 2) + return false; + + header.PayloadLength = (buffer[0] << 8) | buffer[1]; + buffer = buffer.Slice(2); + consumedBytes += 2; } else if (header.PayloadLength == 127) { - Debug.Assert(_receiveBufferCount >= 8, $"Expected to have eight bytes for the payload length."); + if (buffer.Length < 8) + return false; + header.PayloadLength = 0; - for (int i = 0; i < 8; i++) + for (int i = 0; i < 8; ++i) { - header.PayloadLength = (header.PayloadLength << 8) | receiveBufferSpan[_receiveBufferOffset + i]; + header.PayloadLength = (header.PayloadLength << 8) | buffer[i]; } - ConsumeFromBuffer(8); - } - - if (reservedSet) - { - resultHeader = default; - return SR.net_Websockets_ReservedBitsSet; + buffer = buffer.Slice(8); + consumedBytes += 8; } if (masked) { - if (!_isServer) - { - resultHeader = default; - return SR.net_Websockets_ClientReceivedMaskedFrame; - } - header.Mask = CombineMaskBytes(receiveBufferSpan, _receiveBufferOffset); + if (buffer.Length < MaskLength) + return false; - // Consume the mask bytes - ConsumeFromBuffer(4); + header.Mask = BitConverter.ToInt32(buffer); + consumedBytes += MaskLength; } // Do basic validation of the header switch (header.Opcode) { case MessageOpcode.Continuation: - if (_lastReceiveHeader.Fin) + if (previousHeader.Fin) { // Can't continue from a final message - resultHeader = default; - return SR.net_Websockets_ContinuationFromFinalFrame; + return Error(ref header, SR.net_Websockets_ContinuationFromFinalFrame); } if (header.Compressed) { // Per-Message Compressed flag must be set only in the first frame - resultHeader = default; - return SR.net_Websockets_PerMessageCompressedFlagInContinuation; + return Error(ref header, SR.net_Websockets_PerMessageCompressedFlagInContinuation); } break; case MessageOpcode.Binary: case MessageOpcode.Text: - if (!_lastReceiveHeader.Fin) + if (!previousHeader.Fin) { // Must continue from a non-final message - resultHeader = default; - return SR.net_Websockets_NonContinuationAfterNonFinalFrame; + return Error(ref header, SR.net_Websockets_NonContinuationAfterNonFinalFrame); } break; @@ -992,20 +832,22 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( if (header.PayloadLength > MaxControlPayloadLength || !header.Fin) { // Invalid control messgae - resultHeader = default; - return SR.net_Websockets_InvalidControlMessage; + return Error(ref header, SR.net_Websockets_InvalidControlMessage); } break; default: // Unknown opcode - resultHeader = default; - return SR.Format(SR.net_Websockets_UnknownOpcode, header.Opcode); + return Error(ref header, SR.Format(SR.net_Websockets_UnknownOpcode, header.Opcode)); } - // Return the read header - resultHeader = header; - return null; + return true; + + static bool Error(ref MessageHeader header, string error) + { + header.Error = error; + return false; + } } /// Send a close message, then receive until we get a close response message. @@ -1141,44 +983,7 @@ private async ValueTask SendCloseFrameAsync(WebSocketCloseStatus closeStatus, st if (!_isServer && _receivedCloseFrame) { - await WaitForServerToCloseConnectionAsync(cancellationToken).ConfigureAwait(false); - } - } - - private void ConsumeFromBuffer(int count) - { - Debug.Assert(count >= 0, $"Expected non-negative count, got {count}"); - Debug.Assert(count <= _receiveBufferCount, $"Trying to consume {count}, which is more than exists {_receiveBufferCount}"); - _receiveBufferCount -= count; - _receiveBufferOffset += count; - } - - private async ValueTask EnsureBufferContainsAsync(int minimumRequiredBytes, CancellationToken cancellationToken, bool throwOnPrematureClosure = true) - { - Debug.Assert(minimumRequiredBytes <= _receiveBuffer.Length, $"Requested number of bytes {minimumRequiredBytes} must not exceed {_receiveBuffer.Length}"); - - // If we don't have enough data in the buffer to satisfy the minimum required, read some more. - if (_receiveBufferCount < minimumRequiredBytes) - { - // If there's any data in the buffer, shift it down. - if (_receiveBufferCount > 0) - { - _receiveBuffer.Span.Slice(_receiveBufferOffset, _receiveBufferCount).CopyTo(_receiveBuffer.Span); - } - _receiveBufferOffset = 0; - - // While we don't have enough data, read more. - while (_receiveBufferCount < minimumRequiredBytes) - { - int numRead = await _stream.ReadAsync(_receiveBuffer.Slice(_receiveBufferCount, _receiveBuffer.Length - _receiveBufferCount), cancellationToken).ConfigureAwait(false); - Debug.Assert(numRead >= 0, $"Expected non-negative bytes read, got {numRead}"); - if (numRead <= 0) - { - ThrowIfEOFUnexpected(throwOnPrematureClosure); - break; - } - _receiveBufferCount += numRead; - } + await _receiver.WaitForServerToCloseConnectionAsync(cancellationToken).ConfigureAwait(false); } } @@ -1201,19 +1006,6 @@ private void ThrowIfEOFUnexpected(bool throwOnPrematureClosure) private static int CombineMaskBytes(Span buffer, int maskOffset) => BitConverter.ToInt32(buffer.Slice(maskOffset)); - /// Applies a mask to a portion of a byte array. - /// The buffer to which the mask should be applied. - /// The array containing the mask to apply. - /// The offset into of the mask to apply of length . - /// The next position offset from of which by to apply next from the mask. - /// The updated maskOffsetOffset value. - private static int ApplyMask(Span toMask, byte[] mask, int maskOffset, int maskOffsetIndex) - { - Debug.Assert(maskOffsetIndex < MaskLength, $"Unexpected {nameof(maskOffsetIndex)}: {maskOffsetIndex}"); - Debug.Assert(mask.Length >= MaskLength + maskOffset, $"Unexpected inputs: {mask.Length}, {maskOffset}"); - return ApplyMask(toMask, CombineMaskBytes(mask, maskOffset), maskOffsetIndex); - } - /// Applies a mask to a portion of a byte array. /// The buffer to which the mask should be applied. /// The four-byte mask, stored as an Int32. @@ -1430,6 +1222,18 @@ private struct MessageHeader internal long PayloadLength; internal int Mask; internal bool Compressed; + internal string? Error; + } + + private readonly struct ControlMessage + { + internal ControlMessage(MessageOpcode opcode, ReadOnlyMemory payload) + { + Opcode = opcode; + Payload = payload; + } + internal MessageOpcode Opcode { get; } + internal ReadOnlyMemory Payload { get; } } /// From 600d0fcd8f1ead76353d8aaefebe8a515569c9f0 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 15 Feb 2021 16:36:42 +0200 Subject: [PATCH 07/52] Removed unused namespace. --- .../src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index c600ec3e02457..c9a802218587f 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -3,7 +3,6 @@ using System.Buffers; using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.IO; using System.Runtime.InteropServices; using System.Threading; From 0890bf8556388691fadedfc007772c98665200e5 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 15 Feb 2021 16:39:05 +0200 Subject: [PATCH 08/52] Fixed decoder state reset to not be called in continuations. --- .../Net/WebSockets/ManagedWebSocket.Receiver.cs | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index c9a802218587f..60c22c97db794 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -49,13 +49,6 @@ private sealed class Receiver : IDisposable /// private MessageHeader _lastHeader = new() { Opcode = MessageOpcode.Text, Fin = true }; - /// - /// Because user messages can have continuations (split onto multiple frames) - /// but also we need to know the actual message type (text or binary) we need - /// a seperate field to track if the last received header is actually a continuation. - /// - private bool _continuation; - /// /// Buffer used for reading data from the network. /// Not readonly here because the buffer is mutable and is a struct. @@ -286,18 +279,21 @@ private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationT Debug.Assert(_lastHeader.PayloadLength == 0); _receivedMaskOffset = 0; - _decoder?.Reset(); while (true) { if (TryParseMessageHeader(_readBuffer.AvailableSpan, _lastHeader, _isServer, out var header, out var consumedBytes)) { // If this is a continuation, replace the opcode with the one of the message it's continuing - if (_continuation = (header.Opcode == MessageOpcode.Continuation)) + if (header.Opcode == MessageOpcode.Continuation) { header.Opcode = _lastHeader.Opcode; header.Compressed = _lastHeader.Compressed; } + else + { + _decoder?.Reset(); + } _lastHeader = header; _readBuffer.Consume(consumedBytes); @@ -390,6 +386,9 @@ private abstract class Decoder : IDisposable public abstract void Dispose(); + /// + /// Resets the decoder state after fully processing a message. + /// public abstract void Reset(); public abstract void Decode(ReadOnlySpan input, Span output, out int consumed, out int written); From fb5665460bd3412faa99b2141299bf02274d45d5 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 15 Feb 2021 16:40:59 +0200 Subject: [PATCH 09/52] Calling dispose for receiver. --- .../Net/WebSockets/ManagedWebSocket.Receiver.cs | 11 ++++++++++- .../src/System/Net/WebSockets/ManagedWebSocket.cs | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index 60c22c97db794..dc08be485462f 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -93,7 +93,16 @@ public Receiver(Stream stream, WebSocketCreationOptions options) } } - public void Dispose() => _decoder?.Dispose(); + public void Dispose() + { + _decoder?.Dispose(); + + if (_decoderBuffer is not null) + { + ArrayPool.Shared.Return(_decoderBuffer); + _decoderBuffer = null; + } + } public MessageHeader GetLastHeader() => _lastHeader; diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 2b701cdfd8eb3..9b24f2d405cff 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -192,6 +192,7 @@ private void DisposeCore() _keepAliveTimer?.Dispose(); _stream?.Dispose(); _sender.Dispose(); + _receiver.Dispose(); if (_state < WebSocketState.Aborted) { From 7b3b453d32aeb5c659503ba8eecb9df6b8e15df4 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 15 Feb 2021 18:38:11 +0200 Subject: [PATCH 10/52] Removed unncecessary renting of memory when waiting for close message. --- .../System/Net/WebSockets/ManagedWebSocket.cs | 68 ++++++++----------- 1 file changed, 30 insertions(+), 38 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 9b24f2d405cff..063dba9639a57 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -359,14 +359,14 @@ public override ValueTask ReceiveAsync(Memory } } - private Task ValidateAndReceiveAsync(Task receiveTask, byte[] buffer, CancellationToken cancellationToken) + private Task ValidateAndReceiveAsync(Task receiveTask, CancellationToken cancellationToken) { if (receiveTask == null || (receiveTask.IsCompletedSuccessfully && !(receiveTask is Task wsrr && wsrr.Result.MessageType == WebSocketMessageType.Close) && !(receiveTask is Task vwsrr && vwsrr.Result.MessageType == WebSocketMessageType.Close))) { - ValueTask vt = ReceiveAsyncPrivate(buffer, cancellationToken); + ValueTask vt = ReceiveAsyncPrivate(Memory.Empty, cancellationToken); receiveTask = vt.IsCompletedSuccessfully ? (vt.Result.MessageType == WebSocketMessageType.Close ? s_cachedCloseTask : Task.CompletedTask) : vt.AsTask(); @@ -879,48 +879,40 @@ private async Task CloseAsyncPrivate(WebSocketCloseStatus closeStatus, string? s if (State == WebSocketState.CloseSent) { // Wait until we've received a close response - byte[] closeBuffer = ArrayPool.Shared.Rent(MaxMessageHeaderLength + MaxControlPayloadLength); - try + while (!_receivedCloseFrame) { - while (!_receivedCloseFrame) + Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}"); + Task receiveTask; + bool usingExistingReceive; + lock (ReceiveAsyncLock) { - Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}"); - Task receiveTask; - bool usingExistingReceive; - lock (ReceiveAsyncLock) + // Now that we're holding the ReceiveAsyncLock, double-check that we've not yet received the close frame. + // It could have been received between our check above and now due to a concurrent receive completing. + if (_receivedCloseFrame) { - // Now that we're holding the ReceiveAsyncLock, double-check that we've not yet received the close frame. - // It could have been received between our check above and now due to a concurrent receive completing. - if (_receivedCloseFrame) - { - break; - } - - // We've not yet processed a received close frame, which means we need to wait for a received close to complete. - // There may already be one in flight, in which case we want to just wait for that one rather than kicking off - // another (we don't support concurrent receive operations). We need to kick off a new receive if either we've - // never issued a receive or if the last issued receive completed for reasons other than a close frame. There is - // a race condition here, e.g. if there's a in-flight receive that completes after we check, but that's fine: worst - // case is we then await it, find that it's not what we need, and try again. - receiveTask = _lastReceiveAsync; - Task newReceiveTask = ValidateAndReceiveAsync(receiveTask, closeBuffer, cancellationToken); - usingExistingReceive = ReferenceEquals(receiveTask, newReceiveTask); - _lastReceiveAsync = receiveTask = newReceiveTask; + break; } - // Wait for whatever receive task we have. We'll then loop around again to re-check our state. - // If this is an existing receive, and if we have a cancelable token, we need to register with that - // token while we wait, since it may not be the same one that was given to the receive initially. - Debug.Assert(receiveTask != null); - using (usingExistingReceive ? cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this) : default) - { - await receiveTask.ConfigureAwait(false); - } + // We've not yet processed a received close frame, which means we need to wait for a received close to complete. + // There may already be one in flight, in which case we want to just wait for that one rather than kicking off + // another (we don't support concurrent receive operations). We need to kick off a new receive if either we've + // never issued a receive or if the last issued receive completed for reasons other than a close frame. There is + // a race condition here, e.g. if there's a in-flight receive that completes after we check, but that's fine: worst + // case is we then await it, find that it's not what we need, and try again. + receiveTask = _lastReceiveAsync; + Task newReceiveTask = ValidateAndReceiveAsync(receiveTask, cancellationToken); + usingExistingReceive = ReferenceEquals(receiveTask, newReceiveTask); + _lastReceiveAsync = receiveTask = newReceiveTask; + } + + // Wait for whatever receive task we have. We'll then loop around again to re-check our state. + // If this is an existing receive, and if we have a cancelable token, we need to register with that + // token while we wait, since it may not be the same one that was given to the receive initially. + Debug.Assert(receiveTask != null); + using (usingExistingReceive ? cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this) : default) + { + await receiveTask.ConfigureAwait(false); } - } - finally - { - ArrayPool.Shared.Return(closeBuffer); } } From 25849739d3d1b1b640a37c726b12f4ae92ff2961 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 15 Feb 2021 18:43:47 +0200 Subject: [PATCH 11/52] Removed unused method. --- .../src/System/Net/WebSockets/ManagedWebSocket.cs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 063dba9639a57..2381b728f461e 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -996,9 +996,6 @@ private void ThrowIfEOFUnexpected(bool throwOnPrematureClosure) } } - private static int CombineMaskBytes(Span buffer, int maskOffset) => - BitConverter.ToInt32(buffer.Slice(maskOffset)); - /// Applies a mask to a portion of a byte array. /// The buffer to which the mask should be applied. /// The four-byte mask, stored as an Int32. From 7597eb2cb187405a977de447ca1229660392cd1c Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 15 Feb 2021 18:45:16 +0200 Subject: [PATCH 12/52] Removed unnecessary null check, because the object is never null. --- .../src/System/Net/WebSockets/ManagedWebSocket.cs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 2381b728f461e..487a02047394a 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -361,10 +361,9 @@ public override ValueTask ReceiveAsync(Memory private Task ValidateAndReceiveAsync(Task receiveTask, CancellationToken cancellationToken) { - if (receiveTask == null || - (receiveTask.IsCompletedSuccessfully && - !(receiveTask is Task wsrr && wsrr.Result.MessageType == WebSocketMessageType.Close) && - !(receiveTask is Task vwsrr && vwsrr.Result.MessageType == WebSocketMessageType.Close))) + if ( receiveTask.IsCompletedSuccessfully && + !(receiveTask is Task wsrr && wsrr.Result.MessageType == WebSocketMessageType.Close) && + !(receiveTask is Task vwsrr && vwsrr.Result.MessageType == WebSocketMessageType.Close)) { ValueTask vt = ReceiveAsyncPrivate(Memory.Empty, cancellationToken); receiveTask = From a8cba93d580df8161cb5ecee5bbd3728a3fb9da6 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Tue, 16 Feb 2021 12:12:44 +0200 Subject: [PATCH 13/52] Updated websockets tests csproj to reflect that now the websockets has seperate builds for each platform. --- .../tests/System.Net.WebSockets.Tests.csproj | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj index 7cf0328df31ca..6ff2159ff8507 100644 --- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj +++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj @@ -1,6 +1,6 @@ - + - $(NetCoreAppCurrent) + $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser From e273bf3bb0693142313e6337081d99280abb65e2 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Tue, 16 Feb 2021 12:14:52 +0200 Subject: [PATCH 14/52] Removed socket listener from a test where we are only testing how the websocket handles incoming data where we could easily just use a memory stream. --- .../tests/WebSocketCreateTest.cs | 45 +++++++------------ 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs index dfe3bdc92ae2a..f9ccd309913db 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs @@ -103,7 +103,6 @@ public async Task ReceiveAsync_UTF8SplitAcrossMultipleBuffers_ValidDataReceived( [Theory] [PlatformSpecific(~TestPlatforms.Browser)] // System.Net.Sockets is not supported on this platform. - [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)] [InlineData(0b_1000_0001, 0b_0_000_0001, false)] // fin + text, no mask + length == 1 [InlineData(0b_1010_0001, 0b_0_000_0001, true)] // fin + rsv2 + text, no mask + length == 1 [InlineData(0b_1001_0001, 0b_0_000_0001, true)] // fin + rsv3 + text, no mask + length == 1 @@ -116,34 +115,22 @@ public async Task ReceiveAsync_UTF8SplitAcrossMultipleBuffers_ValidDataReceived( [InlineData(0b_1000_0111, 0b_0_000_0001, true)] // fin + opcode==7, no mask + length == 1 public async Task ReceiveAsync_InvalidFrameHeader_AbortsAndThrowsException(byte firstByte, byte secondByte, bool shouldFail) { - using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); - listener.Listen(1); - - await client.ConnectAsync(listener.LocalEndPoint); - using (Socket server = await listener.AcceptAsync()) - { - WebSocket websocket = CreateFromStream(new NetworkStream(client, ownsSocket: false), isServer: false, null, Timeout.InfiniteTimeSpan); - - await server.SendAsync(new ArraySegment(new byte[3] { firstByte, secondByte, (byte)'a' }), SocketFlags.None); + var stream = new MemoryStream(new byte[3] { firstByte, secondByte, (byte)'a' }); + using var websocket = CreateFromStream(stream, isServer: false, null, Timeout.InfiniteTimeSpan); - var buffer = new byte[1]; - Task t = websocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); - if (shouldFail) - { - await Assert.ThrowsAsync(() => t); - Assert.Equal(WebSocketState.Aborted, websocket.State); - } - else - { - WebSocketReceiveResult result = await t; - Assert.True(result.EndOfMessage); - Assert.Equal(1, result.Count); - Assert.Equal('a', (char)buffer[0]); - } - } + var buffer = new byte[1]; + Task t = websocket.ReceiveAsync(buffer, CancellationToken.None); + if (shouldFail) + { + await Assert.ThrowsAsync(() => t); + Assert.Equal(WebSocketState.Aborted, websocket.State); + } + else + { + WebSocketReceiveResult result = await t; + Assert.True(result.EndOfMessage); + Assert.Equal(1, result.Count); + Assert.Equal('a', (char)buffer[0]); } } @@ -308,7 +295,7 @@ private static async Task CreateWebSocketStream(Uri echoUri, Socket clie string statusLine = await reader.ReadLineAsync(); Assert.NotEmpty(statusLine); Assert.Equal("HTTP/1.1 101 Switching Protocols", statusLine); - while (!string.IsNullOrEmpty(await reader.ReadLineAsync())); + while (!string.IsNullOrEmpty(await reader.ReadLineAsync())) ; } return stream; From e26f044c7355a230ce3d6bdb374da5edba83303e Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Tue, 16 Feb 2021 12:28:25 +0200 Subject: [PATCH 15/52] Keep alive interval validated twice when using the existing CreateFromStream method in order to keep the name for the ArgumentException that would happen the same as it was. --- .../src/System/Net/WebSockets/WebSocket.cs | 4 ++++ .../src/System/Net/WebSockets/WebSocketCreationOptions.cs | 8 +++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs index 9d93222577568..b80c9b1482e03 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs @@ -137,6 +137,10 @@ public static ArraySegment CreateServerBuffer(int receiveBufferSize) [UnsupportedOSPlatform("browser")] public static WebSocket CreateFromStream(Stream stream, bool isServer, string? subProtocol, TimeSpan keepAliveInterval) { + if (WebSocketCreationOptions.IsKeepAliveValid(keepAliveInterval)) + throw new ArgumentOutOfRangeException(nameof(keepAliveInterval), keepAliveInterval, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, 0)); + return CreateFromStream(stream, new WebSocketCreationOptions { IsServer = isServer, diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs index 3cfc9810610a1..310797c299b52 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs @@ -39,7 +39,7 @@ public TimeSpan KeepAliveInterval get => _keepAliveInterval; set { - if (value != Timeout.InfiniteTimeSpan && value < TimeSpan.Zero) + if (!IsKeepAliveValid(value)) throw new ArgumentOutOfRangeException(nameof(KeepAliveInterval), value, SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, 0)); @@ -51,5 +51,11 @@ public TimeSpan KeepAliveInterval /// The agreed upon options for per message deflate. /// public WebSocketDeflateOptions? DeflateOptions { get; set; } + + /// + /// Returns whether the provided value is valid websocket keep-alive interval. + /// + internal static bool IsKeepAliveValid(TimeSpan value) => + value == Timeout.InfiniteTimeSpan || value >= TimeSpan.Zero; } } From 80bf74d8c95d8239c4783e210c2847bd0aca2992 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Tue, 16 Feb 2021 12:29:39 +0200 Subject: [PATCH 16/52] The *ContextTakeover properties were interpreted incorrectly. --- .../System/Net/WebSockets/ManagedWebSocket.Receiver.cs | 8 ++++---- .../src/System/Net/WebSockets/ManagedWebSocket.Sender.cs | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index dc08be485462f..4eb3b3c1aaaaa 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -81,14 +81,14 @@ public Receiver(Stream stream, WebSocketCreationOptions options) if (options.IsServer) { _decoder = deflate.ServerContextTakeover ? - new Inflater(-deflate.ServerMaxWindowBits) : - new PersistedInflater(-deflate.ServerMaxWindowBits); + new PersistedInflater(-deflate.ServerMaxWindowBits) : + new Inflater(-deflate.ServerMaxWindowBits); } else { _decoder = deflate.ClientContextTakeover ? - new Inflater(-deflate.ClientMaxWindowBits) : - new PersistedInflater(-deflate.ClientMaxWindowBits); + new PersistedInflater(-deflate.ClientMaxWindowBits) : + new Inflater(-deflate.ClientMaxWindowBits); } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs index 5f45b3044ddf9..62fa8198cc7a3 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs @@ -35,14 +35,14 @@ public Sender(Stream stream, WebSocketCreationOptions options) if (options.IsServer) { _encoder = deflate.ServerContextTakeover ? - new Deflater(-deflate.ServerMaxWindowBits) : - new PersistedDeflater(-deflate.ServerMaxWindowBits); + new PersistedDeflater(-deflate.ServerMaxWindowBits) : + new Deflater(-deflate.ServerMaxWindowBits); } else { _encoder = deflate.ClientContextTakeover ? - new Deflater(-deflate.ClientMaxWindowBits) : - new PersistedDeflater(-deflate.ClientMaxWindowBits); + new PersistedDeflater(-deflate.ClientMaxWindowBits) : + new Deflater(-deflate.ClientMaxWindowBits); } } } From 518740e12f118afb05fdb233fb9cbf5801859b08 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Tue, 16 Feb 2021 13:43:15 +0200 Subject: [PATCH 17/52] Fixed a check where ! (not) was missing. --- .../src/System/Net/WebSockets/WebSocket.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs index b80c9b1482e03..8ba34ce1af27e 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs @@ -137,7 +137,7 @@ public static ArraySegment CreateServerBuffer(int receiveBufferSize) [UnsupportedOSPlatform("browser")] public static WebSocket CreateFromStream(Stream stream, bool isServer, string? subProtocol, TimeSpan keepAliveInterval) { - if (WebSocketCreationOptions.IsKeepAliveValid(keepAliveInterval)) + if (!WebSocketCreationOptions.IsKeepAliveValid(keepAliveInterval)) throw new ArgumentOutOfRangeException(nameof(keepAliveInterval), keepAliveInterval, SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, 0)); From d517063e292a4346b31c267ad189b6f71ef6d54f Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Tue, 16 Feb 2021 15:30:01 +0200 Subject: [PATCH 18/52] Created basic tests for compression using the examples from RFC. Fixed a few bugs along the way. --- .../src/System/IO/Compression/Inflater.cs | 3 +- .../WebSockets/ManagedWebSocket.Receiver.cs | 39 ++++- .../Net/WebSockets/ManagedWebSocket.Sender.cs | 20 ++- .../System/Net/WebSockets/ManagedWebSocket.cs | 23 +-- .../tests/System.Net.WebSockets.Tests.csproj | 13 +- .../tests/WebSocketCreateTest.cs | 6 +- .../tests/WebSocketDeflateTests.cs | 71 +++++++++ .../tests/WebSocketStream.cs | 138 ++++++++++++++++++ 8 files changed, 280 insertions(+), 33 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs create mode 100644 src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs index 4a5b8a490b498..180b73725195c 100644 --- a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Net.WebSockets; +using System.Runtime.InteropServices; namespace System.IO.Compression { @@ -46,7 +47,7 @@ internal Inflater(int windowBits) internal unsafe void Inflate(ReadOnlySpan input, Span output, out int consumed, out int written) { fixed (byte* fixedInput = input) - fixed (byte* fixedOutput = output) + fixed (byte* fixedOutput = &MemoryMarshal.GetReference(output)) { _handle.NextIn = (IntPtr)fixedInput; _handle.AvailIn = (uint)input.Length; diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index 4eb3b3c1aaaaa..00ee7f5adbd24 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -62,6 +62,12 @@ private sealed class Receiver : IDisposable /// private int _receivedMaskOffset; + /// + /// When parsing message header if an error occurs the websocket is notified and this + /// will contain the error message. + /// + private string? _headerError; + public Receiver(Stream stream, WebSocketCreationOptions options) { _stream = stream; @@ -106,6 +112,8 @@ public void Dispose() public MessageHeader GetLastHeader() => _lastHeader; + public string? GetHeaderError() => _headerError; + /// Issues a read on the stream to wait for EOF. public async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken cancellationToken) { @@ -177,10 +185,7 @@ public async ValueTask ReceiveAsync(Memory buffer, CancellationToken var success = await ReceiveHeaderAsync(cancellationToken).ConfigureAwait(false); if (!success) - return ReceivedConnectionClose; - - if (_lastHeader.Error is not null) - return ReceivedHeaderError; + return _headerError is not null ? ReceivedHeaderError : ReceivedConnectionClose; if (_lastHeader.Opcode > MessageOpcode.Binary) { @@ -265,7 +270,7 @@ public async ValueTask ReceiveAsync(Memory buffer, CancellationToken // There is lefover data that we need to decode _decoder.Decode(input: _decoderBuffer.AsSpan(_decoderBufferPosition, _decoderBufferCount), - output: buffer.Span, out var consumed, out var written); + output: buffer.Span, out var consumed, out var written); resultByteCount += written; _decoderBufferPosition += consumed; @@ -291,7 +296,8 @@ private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationT while (true) { - if (TryParseMessageHeader(_readBuffer.AvailableSpan, _lastHeader, _isServer, out var header, out var consumedBytes)) + if (TryParseMessageHeader(_readBuffer.AvailableSpan, _lastHeader, _isServer, out var header, + out var error, out var consumedBytes)) { // If this is a continuation, replace the opcode with the one of the message it's continuing if (header.Opcode == MessageOpcode.Continuation) @@ -318,6 +324,11 @@ private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationT break; } + else if (error is not null) + { + _headerError = error; + return false; + } // More data is neeed to parse the header var byteCount = await _stream.ReadAsync(_readBuffer.FreeMemory, cancellationToken).ConfigureAwait(false); @@ -432,7 +443,10 @@ public override void Decode(ReadOnlySpan input, Span output, out int private sealed class PersistedInflater : Decoder { + private static ReadOnlySpan Trailer => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; + private readonly IO.Compression.Inflater _inflater; + private bool _needsTrailer; public PersistedInflater(int windowBits) => _inflater = new(windowBits); @@ -440,11 +454,22 @@ private sealed class PersistedInflater : Decoder public override void Dispose() => _inflater.Dispose(); - public override void Reset() { } + public override void Reset() + { + if (_needsTrailer) + { + _needsTrailer = false; + _inflater.Inflate(Trailer, Array.Empty(), out var consumed, out var written); + + Debug.Assert(consumed == 4); + Debug.Assert(written == 0); + } + } public override void Decode(ReadOnlySpan input, Span output, out int consumed, out int written) { _inflater.Inflate(input, output, out consumed, out written); + _needsTrailer = true; } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs index 62fa8198cc7a3..5329deaa6669b 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs @@ -261,7 +261,7 @@ internal override void Encode(ReadOnlySpan payload, ref Buffer buffer, boo _deflater ??= new IO.Compression.Deflater(_windowBits); - Encode(payload, ref buffer, _deflater); + Encode(payload, ref buffer, _deflater, endOfMessage); reservedBits = continuation ? 0 : PerMessageDeflateBit; if (endOfMessage) @@ -271,7 +271,7 @@ internal override void Encode(ReadOnlySpan payload, ref Buffer buffer, boo } } - public static void Encode(ReadOnlySpan payload, ref Buffer buffer, IO.Compression.Deflater deflater) + public static void Encode(ReadOnlySpan payload, ref Buffer buffer, IO.Compression.Deflater deflater, bool final) { while (payload.Length > 0) { @@ -290,8 +290,18 @@ public static void Encode(ReadOnlySpan payload, ref Buffer buffer, IO.Comp break; } - // The deflated block always ends with 0x00 0x00 0xFF 0xFF but the websocket protocol doesn't want it. - buffer.Advance(-4); + if (final) + { + // The deflated block always ends with 0x00 0x00 0xFF 0xFF + // but the websocket protocol doesn't want it. + Debug.Assert( + buffer.WrittenSpan[^4] == 0x00 && + buffer.WrittenSpan[^3] == 0x00 && + buffer.WrittenSpan[^2] == 0xFF && + buffer.WrittenSpan[^1] == 0xFF); + + buffer.Advance(-4); + } } } @@ -308,7 +318,7 @@ private sealed class PersistedDeflater : Encoder internal override void Encode(ReadOnlySpan payload, ref Buffer buffer, bool continuation, bool endOfMessage, out byte reservedBits) { - Deflater.Encode(payload, ref buffer, _deflater); + Deflater.Encode(payload, ref buffer, _deflater, endOfMessage); reservedBits = continuation ? 0 : PerMessageDeflateBit; } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 487a02047394a..80023b8ce5858 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -580,7 +580,7 @@ private async ValueTask ReceiveAsyncPrivate MaxControlPayloadLength || !header.Fin) { // Invalid control messgae - return Error(ref header, SR.net_Websockets_InvalidControlMessage); + return Error(ref error, SR.net_Websockets_InvalidControlMessage); } break; default: // Unknown opcode - return Error(ref header, SR.Format(SR.net_Websockets_UnknownOpcode, header.Opcode)); + return Error(ref error, SR.Format(SR.net_Websockets_UnknownOpcode, header.Opcode)); } return true; - static bool Error(ref MessageHeader header, string error) + static bool Error(ref string? target, string error) { - header.Error = error; + target = error; return false; } } @@ -1211,7 +1213,6 @@ private struct MessageHeader internal long PayloadLength; internal int Mask; internal bool Compressed; - internal string? Error; } private readonly struct ControlMessage diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj index 6ff2159ff8507..cdfd537caafcc 100644 --- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj +++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj @@ -1,17 +1,16 @@ - + $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser + + - - - + + + diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs index f9ccd309913db..b60ae01471c60 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs @@ -115,8 +115,10 @@ public async Task ReceiveAsync_UTF8SplitAcrossMultipleBuffers_ValidDataReceived( [InlineData(0b_1000_0111, 0b_0_000_0001, true)] // fin + opcode==7, no mask + length == 1 public async Task ReceiveAsync_InvalidFrameHeader_AbortsAndThrowsException(byte firstByte, byte secondByte, bool shouldFail) { - var stream = new MemoryStream(new byte[3] { firstByte, secondByte, (byte)'a' }); - using var websocket = CreateFromStream(stream, isServer: false, null, Timeout.InfiniteTimeSpan); + var (serverStream, clientStream) = WebSocketStream.Create(); + + serverStream.Write(firstByte, secondByte, (byte)'a'); + using var websocket = CreateFromStream(clientStream, isServer: false, null, Timeout.InfiniteTimeSpan); var buffer = new byte[1]; Task t = websocket.ReceiveAsync(buffer, CancellationToken.None); diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs new file mode 100644 index 0000000000000..1ad7811f1f344 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -0,0 +1,71 @@ +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.WebSockets.Tests +{ + [PlatformSpecific(~TestPlatforms.Browser)] + public class WebSocketDeflateTests + { + [Fact] + public async Task HelloWithContextTakeover() + { + (var server, var client) = WebSocketStream.Create(); + + server.Write(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + using var websocket = WebSocket.CreateFromStream(client, new WebSocketCreationOptions + { + DeflateOptions = new() + }); + + var buffer = new byte[5]; + var result = await websocket.ReceiveAsync(buffer, default); + + Assert.True(result.EndOfMessage); + Assert.Equal(buffer.Length, result.Count); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer)); + + // Because context takeover is set by default if we try to send + // the same message it would take fewer bytes. + server.Write(0xc1, 0x05, 0xf2, 0x00, 0x11, 0x00, 0x00); + + buffer.AsSpan().Clear(); + result = await websocket.ReceiveAsync(buffer, default); + + Assert.True(result.EndOfMessage); + Assert.Equal(buffer.Length, result.Count); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer)); + } + + [Fact] + public async Task HelloWithoutContextTakeover() + { + (var server, var client) = WebSocketStream.Create(); + + using var websocket = WebSocket.CreateFromStream(client, new WebSocketCreationOptions + { + DeflateOptions = new() + { + ClientContextTakeover = false + } + }); + + var buffer = new byte[5]; + + for (var i = 0; i < 100; ++i) + { + // Without context takeover the message should look the same every time + server.Write(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + buffer.AsSpan().Clear(); + + var result = await websocket.ReceiveAsync(buffer, default); + + Assert.True(result.EndOfMessage); + Assert.Equal(buffer.Length, result.Count); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer)); + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs new file mode 100644 index 0000000000000..03ab6aebbbd15 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs @@ -0,0 +1,138 @@ +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets.Tests +{ + /// + /// A helper stream class that can be used simulate + /// sending / receiving (duplex) data in a websocket. + /// + public class WebSocketStream : Stream + { + private readonly SemaphoreSlim _inputLock = new(initialCount: 0); + private readonly Queue _inputQueue = new(); + + private WebSocketStream _remoteStream; + + public static (WebSocketStream server, WebSocketStream client) Create() + { + var server = new WebSocketStream(); + var client = new WebSocketStream(); + + server._remoteStream = client; + client._remoteStream = server; + + return (server, client); + } + + private WebSocketStream() + { + GC.SuppressFinalize(this); + } + + public override bool CanRead => true; + + public override bool CanSeek => false; + + public override bool CanWrite => true; + + public override long Length => -1; + + public override long Position { get => -1; set => throw new NotSupportedException(); } + + protected override void Dispose(bool disposing) + { + _inputLock.Dispose(); + + lock (_remoteStream._inputQueue) + { + _remoteStream._inputQueue.Enqueue(Block.ConnectionClosed); + _remoteStream._inputLock.Release(); + } + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + try + { + await _inputLock.WaitAsync(cancellationToken); + } + catch (ObjectDisposedException) + { + return 0; + } + + lock (_inputQueue) + { + var block = _inputQueue.Peek(); + if (block == Block.ConnectionClosed) + return 0; + + var count = Math.Min(block.AvailableLength, buffer.Length); + + block.Available.Slice(0, count).CopyTo(buffer.Span); + block.Advance(count); + + if (block.AvailableLength == 0) + _inputQueue.Dequeue(); + + return count; + } + } + + public void Write(params byte[] data) + { + lock (_remoteStream._inputQueue) + { + _remoteStream._inputQueue.Enqueue(new Block(data)); + _remoteStream._inputLock.Release(); + } + } + + public override void Write(ReadOnlySpan buffer) + { + lock (_remoteStream._inputQueue) + { + _remoteStream._inputQueue.Enqueue(new Block(buffer.ToArray())); + _remoteStream._inputLock.Release(); + } + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + Write(buffer.Span); + return ValueTask.CompletedTask; + } + + public override void Flush() => throw new NotSupportedException(); + + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + + public override void SetLength(long value) => throw new NotSupportedException(); + + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + private sealed class Block + { + public static readonly Block ConnectionClosed = new(Array.Empty()); + + private readonly byte[] _data; + private int _position; + + public Block(byte[] data) + { + _data = data; + } + + public Span Available => _data.AsSpan(_position); + + public int AvailableLength => _data.Length - _position; + + public void Advance(int count) => _position += count; + } + } +} From 63df73ae472319b4ddab02b9676e3b7fd11dbc62 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Tue, 16 Feb 2021 15:49:24 +0200 Subject: [PATCH 19/52] Simplified the creation of the websocket stream used for testing. --- .../tests/WebSocketCreateTest.cs | 6 +-- .../tests/WebSocketStream.cs | 48 +++++++++++-------- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs index b60ae01471c60..66f0707c5c27a 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs @@ -115,10 +115,10 @@ public async Task ReceiveAsync_UTF8SplitAcrossMultipleBuffers_ValidDataReceived( [InlineData(0b_1000_0111, 0b_0_000_0001, true)] // fin + opcode==7, no mask + length == 1 public async Task ReceiveAsync_InvalidFrameHeader_AbortsAndThrowsException(byte firstByte, byte secondByte, bool shouldFail) { - var (serverStream, clientStream) = WebSocketStream.Create(); + var stream = new WebSocketStream(); - serverStream.Write(firstByte, secondByte, (byte)'a'); - using var websocket = CreateFromStream(clientStream, isServer: false, null, Timeout.InfiniteTimeSpan); + stream.Write(firstByte, secondByte, (byte)'a'); + using var websocket = CreateFromStream(stream.Remote, isServer: false, null, Timeout.InfiniteTimeSpan); var buffer = new byte[1]; Task t = websocket.ReceiveAsync(buffer, CancellationToken.None); diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs index 03ab6aebbbd15..e052347237982 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs @@ -14,24 +14,20 @@ public class WebSocketStream : Stream private readonly SemaphoreSlim _inputLock = new(initialCount: 0); private readonly Queue _inputQueue = new(); - private WebSocketStream _remoteStream; - - public static (WebSocketStream server, WebSocketStream client) Create() + public WebSocketStream() { - var server = new WebSocketStream(); - var client = new WebSocketStream(); - - server._remoteStream = client; - client._remoteStream = server; - - return (server, client); + GC.SuppressFinalize(this); + Remote = new WebSocketStream(this); } - private WebSocketStream() + private WebSocketStream(WebSocketStream remote) { GC.SuppressFinalize(this); + Remote = remote; } + public WebSocketStream Remote { get; } + public override bool CanRead => true; public override bool CanSeek => false; @@ -46,10 +42,16 @@ protected override void Dispose(bool disposing) { _inputLock.Dispose(); - lock (_remoteStream._inputQueue) + lock (Remote._inputQueue) { - _remoteStream._inputQueue.Enqueue(Block.ConnectionClosed); - _remoteStream._inputLock.Release(); + try + { + Remote._inputLock.Release(); + Remote._inputQueue.Enqueue(Block.ConnectionClosed); + } + catch ( ObjectDisposedException) + { + } } } @@ -84,19 +86,25 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation public void Write(params byte[] data) { - lock (_remoteStream._inputQueue) + lock (Remote._inputQueue) { - _remoteStream._inputQueue.Enqueue(new Block(data)); - _remoteStream._inputLock.Release(); + Remote._inputLock.Release(); + Remote._inputQueue.Enqueue(new Block(data)); } } public override void Write(ReadOnlySpan buffer) { - lock (_remoteStream._inputQueue) + lock (Remote._inputQueue) { - _remoteStream._inputQueue.Enqueue(new Block(buffer.ToArray())); - _remoteStream._inputLock.Release(); + try + { + Remote._inputLock.Release(); + Remote._inputQueue.Enqueue(new Block(buffer.ToArray())); + } + catch (ObjectDisposedException) + { + } } } From 4316bd1f18544de49e8cfabc41804283a6b1d51b Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Tue, 16 Feb 2021 15:50:06 +0200 Subject: [PATCH 20/52] Added duplex end to end test that verifies that client and server compress / decompress messages as expected. --- .../tests/WebSocketDeflateTests.cs | 65 ++++++++++++++++--- 1 file changed, 57 insertions(+), 8 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index 1ad7811f1f344..98ff39406b2d1 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -10,10 +10,10 @@ public class WebSocketDeflateTests [Fact] public async Task HelloWithContextTakeover() { - (var server, var client) = WebSocketStream.Create(); - - server.Write(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); - using var websocket = WebSocket.CreateFromStream(client, new WebSocketCreationOptions + var stream = new WebSocketStream(); + + stream.Write(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + using var websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions { DeflateOptions = new() }); @@ -28,7 +28,7 @@ public async Task HelloWithContextTakeover() // Because context takeover is set by default if we try to send // the same message it would take fewer bytes. - server.Write(0xc1, 0x05, 0xf2, 0x00, 0x11, 0x00, 0x00); + stream.Write(0xc1, 0x05, 0xf2, 0x00, 0x11, 0x00, 0x00); buffer.AsSpan().Clear(); result = await websocket.ReceiveAsync(buffer, default); @@ -41,9 +41,9 @@ public async Task HelloWithContextTakeover() [Fact] public async Task HelloWithoutContextTakeover() { - (var server, var client) = WebSocketStream.Create(); + var stream = new WebSocketStream(); - using var websocket = WebSocket.CreateFromStream(client, new WebSocketCreationOptions + using var websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions { DeflateOptions = new() { @@ -56,7 +56,7 @@ public async Task HelloWithoutContextTakeover() for (var i = 0; i < 100; ++i) { // Without context takeover the message should look the same every time - server.Write(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + stream.Write(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); buffer.AsSpan().Clear(); var result = await websocket.ReceiveAsync(buffer, default); @@ -67,5 +67,54 @@ public async Task HelloWithoutContextTakeover() Assert.Equal("Hello", Encoding.UTF8.GetString(buffer)); } } + + [Fact] + public async Task Duplex() + { + var stream = new WebSocketStream(); + using var server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DeflateOptions = new() + }); + using var client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DeflateOptions = new() + }); + + var buffer = new byte[1024]; + + for (var i = 0; i < 10; ++i) + { + var message = $"Sending number {i} from server."; + await SendTextAsync(message, server); + + var result = await client.ReceiveAsync(buffer.AsMemory(), default); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + + Assert.Equal(message, Encoding.UTF8.GetString(buffer.AsSpan(0, result.Count))); + } + + for (var i = 0; i < 10; ++i) + { + var message = $"Sending number {i} from client."; + await SendTextAsync(message, client); + + var result = await server.ReceiveAsync(buffer.AsMemory(), default); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + + Assert.Equal(message, Encoding.UTF8.GetString(buffer.AsSpan(0, result.Count))); + } + } + + private static ValueTask SendTextAsync(string text, WebSocket websocket) + { + var bytes = Encoding.UTF8.GetBytes(text); + return websocket.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, default); + } } } From 398e3d8164808055f1014c8f1f9887e0fe5aec1c Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Wed, 17 Feb 2021 11:52:00 +0200 Subject: [PATCH 21/52] Lighting up websocket compression in client. --- .../ref/System.Net.WebSockets.Client.cs | 2 + .../src/Resources/Strings.resx | 66 +++++------ .../ClientWebSocketOptions.cs | 7 ++ .../Net/WebSockets/ClientWebSocketOptions.cs | 3 + .../Net/WebSockets/WebSocketHandle.Managed.cs | 104 +++++++++++++++++- 5 files changed, 146 insertions(+), 36 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs index cee3a5170b862..beaa33f9226dc 100644 --- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs +++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs @@ -36,6 +36,8 @@ internal ClientWebSocketOptions() { } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] + public WebSocketDeflateOptions? DeflateOptions { get { throw null; } set { } } + [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.IWebProxy? Proxy { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.Security.RemoteCertificateValidationCallback? RemoteCertificateValidationCallback { get { throw null; } set { } } diff --git a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx index 3259b86c99fcb..5649bc52e7653 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx @@ -1,16 +1,17 @@ - - @@ -193,8 +194,11 @@ Connection was aborted. - + WebSocket binary type '{0}' not supported. - - + + + The WebSocket failed to negotiate max {0} window bits. The client requested {1} but the server responded with {2}. + + \ No newline at end of file diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs index 85b0f025b4650..2ed5c527421c9 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs @@ -100,6 +100,13 @@ public TimeSpan KeepAliveInterval set => throw new PlatformNotSupportedException(); } + [UnsupportedOSPlatform("browser")] + public WebSocketDeflateOptions? DeflateOptions + { + get => throw new PlatformNotSupportedException(); + set => throw new PlatformNotSupportedException(); + } + [UnsupportedOSPlatform("browser")] public void SetBuffer(int receiveBufferSize, int sendBufferSize) { diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index a7609a0ff0905..573c3eb8325b7 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -148,6 +148,9 @@ public TimeSpan KeepAliveInterval } } + [UnsupportedOSPlatform("browser")] + public WebSocketDeflateOptions? DeflateOptions { get; set; } + internal int ReceiveBufferSize => _receiveBufferSize; internal ArraySegment? Buffer => _buffer; diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index a378566af65ca..8316a94a63c03 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.IO; using System.Net.Http; using System.Net.Http.Headers; @@ -175,6 +176,21 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli } } + // Because deflate options are negotiated we need a new object + WebSocketDeflateOptions? deflateOptions = null; + + if (options.DeflateOptions is not null && response.Headers.TryGetValues(HttpKnownHeaderNames.SecWebSocketExtensions, out var extensions)) + { + foreach (var extension in extensions) + { + if (extension.StartsWith("permessage-deflate")) + { + deflateOptions = ParseDeflateOptions(extension, options.DeflateOptions); + break; + } + } + } + if (response.Content is null) { throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); @@ -184,11 +200,13 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli Stream connectedStream = response.Content.ReadAsStream(); Debug.Assert(connectedStream.CanWrite); Debug.Assert(connectedStream.CanRead); - WebSocket = WebSocket.CreateFromStream( - connectedStream, - isServer: false, - subprotocol, - options.KeepAliveInterval); + WebSocket = WebSocket.CreateFromStream(connectedStream, new WebSocketCreationOptions + { + IsServer = false, + SubProtocol = subprotocol, + KeepAliveInterval = options.KeepAliveInterval, + DeflateOptions = deflateOptions, + }); } catch (Exception exc) { @@ -218,6 +236,47 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli } } + private static WebSocketDeflateOptions ParseDeflateOptions(string extensions, WebSocketDeflateOptions original) + { + var options = new WebSocketDeflateOptions(); + + foreach (var value in extensions.Split(';', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)) + { + if (value == "client_no_context_takeover") + { + options.ClientContextTakeover = false; + } + else if (value == "server_no_context_takeover") + { + options.ServerContextTakeover = false; + } + else if (value.StartsWith("client_max_window_bits=")) + { + options.ClientMaxWindowBits = int.Parse(value.Substring("client_max_window_bits=".Length), + NumberFormatInfo.InvariantInfo); + } + else if (value.StartsWith("server_max_window_bits=")) + { + options.ServerMaxWindowBits = int.Parse(value.Substring("server_max_window_bits=".Length), + NumberFormatInfo.InvariantInfo); + } + } + + if (options.ClientMaxWindowBits > original.ClientMaxWindowBits) + { + throw new WebSocketException(string.Format(SR.net_WebSockets_WindowBitsNegotiationFailure, + "client", original.ClientMaxWindowBits, options.ClientMaxWindowBits)); + } + + if (options.ServerMaxWindowBits > original.ServerMaxWindowBits) + { + throw new WebSocketException(string.Format(SR.net_WebSockets_WindowBitsNegotiationFailure, + "server", original.ServerMaxWindowBits, options.ServerMaxWindowBits)); + } + + return options; + } + /// Adds the necessary headers for the web socket request. /// The request to which the headers should be added. /// The generated security key to send in the Sec-WebSocket-Key header. @@ -232,6 +291,41 @@ private static void AddWebSocketHeaders(HttpRequestMessage request, string secKe { request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketProtocol, string.Join(", ", options.RequestedSubProtocols)); } + if (options.DeflateOptions is not null) + { + request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketExtensions, string.Join("; ", GetDeflateOptions(options.DeflateOptions))); + + static IEnumerable GetDeflateOptions(WebSocketDeflateOptions options) + { + yield return "permessage-deflate"; + + if (options.ClientMaxWindowBits != 15) + { + yield return "client_max_window_bits=" + options.ClientMaxWindowBits; + } + else + { + // Advertise that we support this option + yield return "client_max_window_bits"; + } + + if (options.ServerMaxWindowBits != 15) + { + yield return "server_max_window_bits=" + options.ServerMaxWindowBits; + } + else + { + // Advertise that we support this option + yield return "server_max_window_bits"; + } + + if (!options.ServerContextTakeover) + yield return "server_no_context_takeover"; + + if (!options.ClientContextTakeover) + yield return "client_no_context_takeover"; + } + } } /// From 004c58909ca3ad394c0bbd9262a3f87472eaed92 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Thu, 18 Feb 2021 15:55:37 +0200 Subject: [PATCH 22/52] Updating the constraints of the MaxWindowBits properties - 8 is no longer a valid window bits value, because the underlying gzip library doesn't really support it. --- .../System/Net/WebSockets/WebSocketDeflateOptions.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs index 375617ec3cb45..a4044494f7cdc 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs @@ -13,16 +13,16 @@ public sealed class WebSocketDeflateOptions /// /// This parameter indicates the base-2 logarithm of the LZ77 sliding window size of the client context. - /// Must be a value between 8 and 15. The default is 15. + /// Must be a value between 9 and 15. The default is 15. /// public int ClientMaxWindowBits { get => _clientMaxWindowBits; set { - if (value < 8 || value > 15) + if (value < 9 || value > 15) throw new ArgumentOutOfRangeException(nameof(ClientMaxWindowBits), value, - SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 8, 15)); + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 9, 15)); _clientMaxWindowBits = value; } @@ -36,16 +36,16 @@ public int ClientMaxWindowBits /// /// This parameter indicates the base-2 logarithm of the LZ77 sliding window size of the server context. - /// Must be a value between 8 and 15. The default is 15. + /// Must be a value between 9 and 15. The default is 15. /// public int ServerMaxWindowBits { get => _serverMaxWindowBits; set { - if (value < 8 || value > 15) + if (value < 9 || value > 15) throw new ArgumentOutOfRangeException(nameof(ServerMaxWindowBits), value, - SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 8, 15)); + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 9, 15)); _serverMaxWindowBits = value; } From 16353ea4ab3b863050b8d1df58e568146242f1fc Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Thu, 18 Feb 2021 15:58:58 +0200 Subject: [PATCH 23/52] Added more tests and fixed some bugs related to splitting messages in multiple frames and reading with smaller output buffers. --- .../src/System/IO/Compression/Deflater.cs | 4 +- .../src/System/IO/Compression/Inflater.cs | 20 +- .../WebSockets/ManagedWebSocket.Receiver.cs | 231 +++++++++++++----- .../Net/WebSockets/ManagedWebSocket.Sender.cs | 44 ++-- .../System/Net/WebSockets/ManagedWebSocket.cs | 22 +- .../tests/WebSocketDeflateTests.cs | 121 ++++++++- .../tests/WebSocketStream.cs | 36 ++- 7 files changed, 375 insertions(+), 103 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs index 2324cf69f0d52..1abcb6347a9f1 100644 --- a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs @@ -63,7 +63,7 @@ public unsafe void Deflate(ReadOnlySpan input, Span output, out int _handle.NextOut = (IntPtr)fixedOutput; _handle.AvailOut = (uint)output.Length; - Deflate(ZFlushCode.NoFlush); + Deflate((ZFlushCode)5/*Z_BLOCK*/); consumed = input.Length - (int)_handle.AvailIn; written = output.Length - (int)_handle.AvailOut; @@ -80,7 +80,7 @@ public unsafe int Finish(Span output, out bool completed) _handle.NextOut = (IntPtr)fixedOutput; _handle.AvailOut = (uint)output.Length; - var errorCode = Deflate(ZFlushCode.SyncFlush); + var errorCode = Deflate((ZFlushCode)3/*Z_FULL_FLUSH*/); var writtenBytes = output.Length - (int)_handle.AvailOut; completed = errorCode == ZErrorCode.Ok && writtenBytes < output.Length; diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs index 180b73725195c..16d3738e5dc9a 100644 --- a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; +using System.Diagnostics; using System.Net.WebSockets; using System.Runtime.InteropServices; @@ -62,12 +64,24 @@ internal unsafe void Inflate(ReadOnlySpan input, Span output, out in } } + public unsafe int Inflate(Span destination) + { + fixed (byte* bufPtr = &MemoryMarshal.GetReference(destination)) + { + _handle.NextOut = (IntPtr)bufPtr; + _handle.AvailOut = (uint)destination.Length; + + Inflate(ZLibNative.FlushCode.NoFlush); + return destination.Length - (int)_handle.AvailOut; + } + } + public void Dispose() => _handle.Dispose(); /// /// Wrapper around the ZLib inflate function /// - private ZLibNative.ErrorCode Inflate(ZLibNative.FlushCode flushCode) + private void Inflate(ZLibNative.FlushCode flushCode) { ZLibNative.ErrorCode errorCode; try @@ -82,10 +96,8 @@ private ZLibNative.ErrorCode Inflate(ZLibNative.FlushCode flushCode) { case ZLibNative.ErrorCode.Ok: // progress has been made inflating case ZLibNative.ErrorCode.StreamEnd: // The end of the input stream has been reached - return errorCode; - case ZLibNative.ErrorCode.BufError: // No room in the output buffer - inflate() can be called again with more space to continue - return errorCode; + break; case ZLibNative.ErrorCode.MemError: // Not enough memory to complete the operation throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index 00ee7f5adbd24..539014a3d316e 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -12,9 +12,21 @@ namespace System.Net.WebSockets { internal partial class ManagedWebSocket { - private const int ReceivedConnectionClose = -1; - private const int ReceivedControlMessage = -2; - private const int ReceivedHeaderError = -3; + private enum ReceiveResultType + { + Message, + ConnectionClose, + ControlMessage, + HeaderError + } + + private readonly struct ReceiveResult + { + public int Count { get; init; } + public bool EndOfMessage { get; init; } + public ReceiveResultType ResultType { get; init; } + public WebSocketMessageType MessageType { get; init; } + } private sealed class Receiver : IDisposable { @@ -22,22 +34,29 @@ private sealed class Receiver : IDisposable private readonly Stream _stream; private readonly Decoder? _decoder; + /// + /// Because a message might be split into multiple fragments we need to + /// keep in mind that even if we've processed all of them, we need to call + /// finish on the decoder to flush any left over data. + /// + private bool _decodingFinished = true; + /// /// If we have a decoder we cannot use the buffer provided from clients because /// we cannot guarantee that the decoding can happen in place. This buffer is rent'ed /// and returned when consumed. /// - private byte[]? _decoderBuffer; + private byte[]? _decoderInputBuffer; /// - /// The next index that needs to be consumed from the decoder's buffer. + /// The next index that needs to be consumed from the decoder's input buffer. /// - private int _decoderBufferPosition; + private int _decoderInputPosition; /// /// The number of usable bytes in the decoder's buffer. /// - private int _decoderBufferCount; + private int _decoderInputCount; /// /// The last header received in a ReceiveAsync. If ReceiveAsync got a header but then @@ -103,15 +122,13 @@ public void Dispose() { _decoder?.Dispose(); - if (_decoderBuffer is not null) + if (_decoderInputBuffer is not null) { - ArrayPool.Shared.Return(_decoderBuffer); - _decoderBuffer = null; + ArrayPool.Shared.Return(_decoderInputBuffer); + _decoderInputBuffer = null; } } - public MessageHeader GetLastHeader() => _lastHeader; - public string? GetHeaderError() => _headerError; /// Issues a read on the stream to wait for EOF. @@ -175,28 +192,35 @@ public async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken can return new ControlMessage(_lastHeader.Opcode, payload); } - public async ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken) + public async ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken) { - _readBuffer.DiscardConsumed(); - // When there's nothing left over to receive, start a new if (_lastHeader.PayloadLength == 0) { + if (!_decodingFinished) + { + Debug.Assert(_decoder is not null); + _decodingFinished = _decoder.Finish(buffer.Span, out var written); + + return Result(written); + } + + _readBuffer.DiscardConsumed(); var success = await ReceiveHeaderAsync(cancellationToken).ConfigureAwait(false); if (!success) - return _headerError is not null ? ReceivedHeaderError : ReceivedConnectionClose; + return Result(_headerError is not null ? ReceiveResultType.HeaderError : ReceiveResultType.ConnectionClose); if (_lastHeader.Opcode > MessageOpcode.Binary) { // The received message is a control message and it's up // to the websocket how to handle it. - return ReceivedControlMessage; + return Result(ReceiveResultType.ControlMessage); } } if (buffer.IsEmpty) - return 0; + return default; // The number of bytes that are copied onto the provided buffer var resultByteCount = 0; @@ -221,15 +245,20 @@ public async ValueTask ReceiveAsync(Memory buffer, CancellationToken _readBuffer.Consume(consumed); _lastHeader.PayloadLength -= consumed; - if (_lastHeader.PayloadLength == 0 || _readBuffer.AvailableLength > 0) + resultByteCount += written; + + if (_lastHeader.PayloadLength == 0 || buffer.Length == written) { - // If the payload length is 0 it means that we have consumed everything. - // Otherwise if available length is still non zero, than it means that the - // decoder needs more memory and the operation cannot continue. - return written; + // We have either received everything or the buffer is full. + if (_decoder is not null && _lastHeader.PayloadLength == 0 && _lastHeader.Fin) + { + _decodingFinished = _decoder.Finish(buffer.Span.Slice(written), out written); + resultByteCount += written; + } + + return Result(resultByteCount); } - resultByteCount += written; buffer = buffer.Slice(written); } @@ -247,45 +276,55 @@ public async ValueTask ReceiveAsync(Memory buffer, CancellationToken var bytesRead = await _stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); if (bytesRead <= 0) - return ReceivedConnectionClose; + return Result(ReceiveResultType.ConnectionClose); resultByteCount += bytesRead; ApplyMask(buffer.Span.Slice(0, bytesRead)); } else { - if (_decoderBuffer is null) + if (_decoderInputBuffer is null) { // Rent a buffer but restrict it's max size to 1MB - _decoderBuffer = ArrayPool.Shared.Rent((int)Math.Min(_lastHeader.PayloadLength, 1_000_000)); - _decoderBufferCount = await _stream.ReadAsync(_decoderBuffer, cancellationToken).ConfigureAwait(false); - if (_decoderBufferCount <= 0) + var decoderBufferLength = (int)Math.Min(_lastHeader.PayloadLength, 1_000_000); + + _decoderInputBuffer = ArrayPool.Shared.Rent(decoderBufferLength); + _decoderInputCount = await _stream.ReadAsync(_decoderInputBuffer.AsMemory(0, decoderBufferLength), cancellationToken).ConfigureAwait(false); + _decoderInputPosition = 0; + + if (_decoderInputCount <= 0) { - ArrayPool.Shared.Return(_decoderBuffer); - return ReceivedConnectionClose; + ArrayPool.Shared.Return(_decoderInputBuffer); + _decoderInputBuffer = null; + + return Result(ReceiveResultType.ConnectionClose); } - ApplyMask(_decoderBuffer.AsSpan(_decoderBufferPosition, _decoderBufferCount)); + ApplyMask(_decoderInputBuffer.AsSpan(0, _decoderInputCount)); } - // There is lefover data that we need to decode - _decoder.Decode(input: _decoderBuffer.AsSpan(_decoderBufferPosition, _decoderBufferCount), + _decoder.Decode(input: _decoderInputBuffer.AsSpan(_decoderInputPosition, _decoderInputCount), output: buffer.Span, out var consumed, out var written); resultByteCount += written; - _decoderBufferPosition += consumed; - _decoderBufferCount -= consumed; + _decoderInputPosition += consumed; + _decoderInputCount -= consumed; _lastHeader.PayloadLength -= consumed; - if (_decoderBufferCount == 0) + if (_decoderInputCount == 0) { - ArrayPool.Shared.Return(_decoderBuffer); - _decoderBuffer = null; - _decoderBufferPosition = 0; + ArrayPool.Shared.Return(_decoderInputBuffer); + _decoderInputBuffer = null; + + if (_lastHeader.PayloadLength == 0 && _lastHeader.Fin) + { + _decodingFinished = _decoder.Finish(buffer.Span.Slice(written), out written); + resultByteCount += written; + } } } - return resultByteCount; + return Result(resultByteCount); } private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationToken) @@ -305,10 +344,6 @@ private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationT header.Opcode = _lastHeader.Opcode; header.Compressed = _lastHeader.Compressed; } - else - { - _decoder?.Reset(); - } _lastHeader = header; _readBuffer.Consume(consumedBytes); @@ -341,6 +376,19 @@ private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationT return true; } + private ReceiveResult Result(int count) => new ReceiveResult + { + Count = count, + ResultType = ReceiveResultType.Message, + MessageType = _lastHeader.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, + EndOfMessage = _lastHeader.Fin && _lastHeader.PayloadLength == 0 && _decodingFinished + }; + + private ReceiveResult Result(ReceiveResultType resultType) => new ReceiveResult + { + ResultType = resultType + }; + private void ApplyMask(Span input) { if (_isServer) @@ -406,17 +454,19 @@ private abstract class Decoder : IDisposable public abstract void Dispose(); + public abstract void Decode(ReadOnlySpan input, Span output, out int consumed, out int written); + /// - /// Resets the decoder state after fully processing a message. + /// Finishes the decoding by writing any outstanding data to the output. /// - public abstract void Reset(); - - public abstract void Decode(ReadOnlySpan input, Span output, out int consumed, out int written); + /// true if the finish completed, false to indicate that there is more outstanding data. + public abstract bool Finish(Span output, out int written); } private class Inflater : Decoder { private readonly int _windowBits; + private byte? _remainingByte; // Although the inflater isn't persisted accross messages, a single message // might have been split into multiple frames. @@ -428,10 +478,17 @@ private class Inflater : Decoder public override void Dispose() => _inflater?.Dispose(); - public override void Reset() + public override bool Finish(Span output, out int written) { - _inflater?.Dispose(); - _inflater = null; + Debug.Assert(_inflater is not null); + + if (Finish(_inflater, output, out written, ref _remainingByte)) + { + _inflater.Dispose(); + _inflater = null; + return true; + } + return false; } public override void Decode(ReadOnlySpan input, Span output, out int consumed, out int written) @@ -439,14 +496,63 @@ public override void Decode(ReadOnlySpan input, Span output, out int _inflater ??= new IO.Compression.Inflater(_windowBits); _inflater.Inflate(input, output, out consumed, out written); } + + public static bool Finish(IO.Compression.Inflater inflater, Span output, out int written, ref byte? remainingByte) + { + written = 0; + + if (output.Length == 0) + { + if (remainingByte is not null) + return false; + + if (IsFinished(inflater, out remainingByte)) + { + return true; + } + } + else + { + if (remainingByte is not null) + { + output[0] = remainingByte.GetValueOrDefault(); + written = 1; + remainingByte = null; + } + + written += inflater.Inflate(output.Slice(written)); + if (written < output.Length || IsFinished(inflater, out remainingByte)) + { + return true; + } + } + + return false; + } + + public static bool IsFinished(IO.Compression.Inflater inflater, out byte? remainingByte) + { + // There is no other way to make sure that we'e consumed all data + // but to try to inflate again with at least one byte of output buffer. + Span oneByte = stackalloc byte[1]; + if (inflater.Inflate(oneByte) == 0) + { + remainingByte = null; + return true; + } + + remainingByte = oneByte[0]; + return false; + } } private sealed class PersistedInflater : Decoder { - private static ReadOnlySpan Trailer => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; + private static ReadOnlySpan FlushMarker => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; private readonly IO.Compression.Inflater _inflater; - private bool _needsTrailer; + private bool _needsFlushMarker; + private byte? _remainingByte; public PersistedInflater(int windowBits) => _inflater = new(windowBits); @@ -454,22 +560,25 @@ private sealed class PersistedInflater : Decoder public override void Dispose() => _inflater.Dispose(); - public override void Reset() + public override bool Finish(Span output, out int written) { - if (_needsTrailer) + if (_needsFlushMarker) { - _needsTrailer = false; - _inflater.Inflate(Trailer, Array.Empty(), out var consumed, out var written); + _needsFlushMarker = false; + _inflater.Inflate(FlushMarker, output, out var consumed, out written); - Debug.Assert(consumed == 4); - Debug.Assert(written == 0); + Debug.Assert(consumed == FlushMarker.Length); + + return written < output.Length || Inflater.IsFinished(_inflater, out _remainingByte); } + + return Inflater.Finish(_inflater, output, out written, ref _remainingByte); } public override void Decode(ReadOnlySpan input, Span output, out int consumed, out int written) { _inflater.Inflate(input, output, out consumed, out written); - _needsTrailer = true; + _needsFlushMarker = true; } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs index 5329deaa6669b..d9beb1ad69d72 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs @@ -4,7 +4,6 @@ using System.Buffers; using System.Diagnostics; using System.IO; -using System.IO.Compression; using System.Security.Cryptography; using System.Threading; using System.Threading.Tasks; @@ -31,19 +30,20 @@ public Sender(Stream stream, WebSocketCreationOptions options) if (deflate is not null) { // Important note here is that we must use negative window bits - // which will instruct the underlying implementation to not emit deflate headers + // which will instruct the underlying implementation to not emit gzip headers if (options.IsServer) { - _encoder = deflate.ServerContextTakeover ? - new PersistedDeflater(-deflate.ServerMaxWindowBits) : - new Deflater(-deflate.ServerMaxWindowBits); - } - else - { + // If we are the server we must use the client options _encoder = deflate.ClientContextTakeover ? new PersistedDeflater(-deflate.ClientMaxWindowBits) : new Deflater(-deflate.ClientMaxWindowBits); } + else + { + _encoder = deflate.ServerContextTakeover ? + new PersistedDeflater(-deflate.ServerMaxWindowBits) : + new Deflater(-deflate.ServerMaxWindowBits); + } } } @@ -58,7 +58,7 @@ public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemo buffer.Advance(MaxMessageHeaderLength); // Encoding is onlt supported for user messages - if (_encoder is not null && opcode <= MessageOpcode.Continuation) + if (_encoder is not null && opcode <= MessageOpcode.Binary) { _encoder.Encode(content.Span, ref buffer, continuation: opcode == MessageOpcode.Continuation, endOfMessage, out reservedBits); } @@ -182,7 +182,7 @@ private void EncodeHeader(Span header, MessageOpcode opcode, bool endOfMes { // Generate the mask. header[1] |= 0x80; - RandomNumberGenerator.Fill(header.Slice(header.Length - _maskLength)); + RandomNumberGenerator.Fill(header.Slice(header.Length - MaskLength)); } } @@ -281,25 +281,25 @@ public static void Encode(ReadOnlySpan payload, ref Buffer buffer, IO.Comp payload = payload.Slice(consumed); } - while (true) - { - var bytesWritten = deflater.Finish(buffer.GetSpan(), out var completed); - buffer.Advance(bytesWritten); + // See comment by Mark Adler https://github.com/madler/zlib/issues/149#issuecomment-225237457 + // At that point there will be at most a few bits left to write. + // Then call deflate() with Z_FULL_FLUSH and no more input and at least six bytes of available output. + var bytesWritten = deflater.Finish(buffer.GetSpan(6), out var completed); + buffer.Advance(bytesWritten); - if (completed) - break; - } + Debug.Assert(completed); - if (final) - { - // The deflated block always ends with 0x00 0x00 0xFF 0xFF - // but the websocket protocol doesn't want it. - Debug.Assert( + // The deflated block always ends with 0x00 0x00 0xFF 0xFF + Debug.Assert( buffer.WrittenSpan[^4] == 0x00 && buffer.WrittenSpan[^3] == 0x00 && buffer.WrittenSpan[^2] == 0xFF && buffer.WrittenSpan[^1] == 0xFF); + if (final) + { + // As per RFC we need to remove the flush markers + // 0x00 0x00 0xFF 0xFF buffer.Advance(-4); } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 80023b8ce5858..7ac2cc8f62feb 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -541,11 +541,11 @@ private async ValueTask ReceiveAsyncPrivate ReceiveAsyncPrivate 0 && - !TryValidateUtf8(payloadBuffer.Span.Slice(0, byteCount), header.Fin && header.PayloadLength == 0, _utf8TextState)) + if (result.MessageType == WebSocketMessageType.Text && result.Count > 0 && + !TryValidateUtf8(payloadBuffer.Span.Slice(0, result.Count), result.EndOfMessage, _utf8TextState)) { await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false); } return resultGetter.GetResult( - count: byteCount, - messageType: header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, - endOfMessage: header.Fin && header.PayloadLength == 0, + count: result.Count, + messageType: result.MessageType, + endOfMessage: result.EndOfMessage, closeStatus: null, closeDescription: null); } } diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index 98ff39406b2d1..220d1defe9234 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -1,4 +1,8 @@ -using System.Text; +using System.Collections.Generic; +using System.IO; +using System.Security.Cryptography; +using System.Text; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -7,6 +11,17 @@ namespace System.Net.WebSockets.Tests [PlatformSpecific(~TestPlatforms.Browser)] public class WebSocketDeflateTests { + public static IEnumerable SupportedWindowBits + { + get + { + for (var i = 9; i <= 15; ++i) + { + yield return new object[] { i }; + } + } + } + [Fact] public async Task HelloWithContextTakeover() { @@ -68,6 +83,40 @@ public async Task HelloWithoutContextTakeover() } } + [Fact] + public async Task TwoDeflateBlocksInOneMessage() + { + // Two or more DEFLATE blocks may be used in one message. + using var cancellation = new CancellationTokenSource(TimeSpan.FromSeconds(1000)); + var stream = new WebSocketStream(); + using var websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DeflateOptions = new() + }); + // The first 3 octets(0xf2 0x48 0x05) and the least significant two + // bits of the 4th octet(0x00) constitute one DEFLATE block with + // "BFINAL" set to 0 and "BTYPE" set to 01 containing "He". The rest of + // the 4th octet contains the header bits with "BFINAL" set to 0 and + // "BTYPE" set to 00, and the 3 padding bits of 0. Together with the + // following 4 octets(0x00 0x00 0xff 0xff), the header bits constitute + // an empty DEFLATE block with no compression. A DEFLATE block + // containing "llo" follows the empty DEFLATE block. + stream.Write(0x41, 0x08, 0xf2, 0x48, 0x05, 0x00, 0x00, 0x00, 0xff, 0xff); + stream.Write(0x80, 0x05, 0xca, 0xc9, 0xc9, 0x07, 0x00); + + Memory buffer = new byte[5]; + var result = await websocket.ReceiveAsync(buffer, cancellation.Token); + + Assert.Equal(2, result.Count); + Assert.False(result.EndOfMessage); + + result = await websocket.ReceiveAsync(buffer.Slice(result.Count), cancellation.Token); + + Assert.Equal(3, result.Count); + Assert.True(result.EndOfMessage); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.Span)); + } + [Fact] public async Task Duplex() { @@ -111,6 +160,76 @@ public async Task Duplex() } } + [Theory] + [MemberData(nameof(SupportedWindowBits))] + public async Task LargeMessageSplitInMultipleFrames(int windowBits) + { + var stream = new WebSocketStream(); + using var server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DeflateOptions = new() + { + ClientMaxWindowBits = windowBits + } + }); + using var client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DeflateOptions = new() + { + ClientMaxWindowBits = windowBits + } + }); + + Memory testData = File.ReadAllBytes(typeof(WebSocketDeflateTests).Assembly.Location).AsMemory().TrimEnd((byte)0); + Memory receivedData = new byte[testData.Length]; + + // Test it a few times with different frame sizes + for (var i = 0; i < 10; ++i) + { + // Use a timeout cancellation token in case something doesn't work right + using var cancellation = new CancellationTokenSource(TimeSpan.FromSeconds(1000)); + + var frameSize = RandomNumberGenerator.GetInt32(1024, 2048); + var position = 0; + + while (position < testData.Length) + { + var currentFrameSize = Math.Min(frameSize, testData.Length - position); + var eof = position + currentFrameSize == testData.Length; + + await server.SendAsync(testData.Slice(position, currentFrameSize), WebSocketMessageType.Binary, eof, cancellation.Token); + position += currentFrameSize; + } + + Assert.Equal(testData.Length, position); + Assert.True(testData.Length > stream.Remote.Available, "The data must be compressed."); + + // Receive the data from the client side + receivedData.Span.Clear(); + position = 0; + + // Intentionally receive with a frame size that is less than what the sender used + frameSize /= 3; + + while (true) + { + var currentFrameSize = Math.Min(frameSize, testData.Length - position); + var result = await client.ReceiveAsync(receivedData.Slice(position, currentFrameSize), cancellation.Token); + + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + position += result.Count; + + if (result.EndOfMessage) + break; + } + + Assert.Equal(0, stream.Remote.Available); + Assert.Equal(testData.Length, position); + Assert.True(testData.Span.SequenceEqual(receivedData.Span)); + } + } + private static ValueTask SendTextAsync(string text, WebSocket websocket) { var bytes = Encoding.UTF8.GetBytes(text); diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs index e052347237982..b6ca90ed0860c 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs @@ -28,6 +28,27 @@ private WebSocketStream(WebSocketStream remote) public WebSocketStream Remote { get; } + /// + /// Returns the number of unread bytes. + /// + public int Available + { + get + { + var available = 0; + + lock (_inputQueue) + { + foreach ( var x in _inputQueue) + { + available += x.AvailableLength; + } + } + + return available; + } + } + public override bool CanRead => true; public override bool CanSeek => false; @@ -49,7 +70,7 @@ protected override void Dispose(bool disposing) Remote._inputLock.Release(); Remote._inputQueue.Enqueue(Block.ConnectionClosed); } - catch ( ObjectDisposedException) + catch (ObjectDisposedException) { } } @@ -78,7 +99,20 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation block.Advance(count); if (block.AvailableLength == 0) + { _inputQueue.Dequeue(); + } + else + { + try + { + // Because we haven't fully consumed the buffer + // we should release once the input lock so we can acquire + // it again on consequent receive. + _inputLock.Release(); + } + catch (ObjectDisposedException) { } + } return count; } From ed4b069cf550bc211c86fb436dc815d2322b4ad2 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Thu, 18 Feb 2021 16:02:39 +0200 Subject: [PATCH 24/52] Added cancellation token to tests to avoid cases where something is expected to happen but doesn't and we don't have timeouts. --- .../tests/WebSocketDeflateTests.cs | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index 220d1defe9234..a9255c93f90cf 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Security.Cryptography; using System.Text; @@ -11,6 +12,18 @@ namespace System.Net.WebSockets.Tests [PlatformSpecific(~TestPlatforms.Browser)] public class WebSocketDeflateTests { + private CancellationTokenSource? _cancellation; + + public WebSocketDeflateTests() + { + if (!Debugger.IsAttached) + { + _cancellation = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + } + } + + public CancellationToken CancellationToken => _cancellation?.Token ?? default; + public static IEnumerable SupportedWindowBits { get @@ -34,7 +47,7 @@ public async Task HelloWithContextTakeover() }); var buffer = new byte[5]; - var result = await websocket.ReceiveAsync(buffer, default); + var result = await websocket.ReceiveAsync(buffer, CancellationToken); Assert.True(result.EndOfMessage); Assert.Equal(buffer.Length, result.Count); @@ -46,7 +59,7 @@ public async Task HelloWithContextTakeover() stream.Write(0xc1, 0x05, 0xf2, 0x00, 0x11, 0x00, 0x00); buffer.AsSpan().Clear(); - result = await websocket.ReceiveAsync(buffer, default); + result = await websocket.ReceiveAsync(buffer, CancellationToken); Assert.True(result.EndOfMessage); Assert.Equal(buffer.Length, result.Count); @@ -74,7 +87,7 @@ public async Task HelloWithoutContextTakeover() stream.Write(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); buffer.AsSpan().Clear(); - var result = await websocket.ReceiveAsync(buffer, default); + var result = await websocket.ReceiveAsync(buffer, CancellationToken); Assert.True(result.EndOfMessage); Assert.Equal(buffer.Length, result.Count); @@ -87,7 +100,6 @@ public async Task HelloWithoutContextTakeover() public async Task TwoDeflateBlocksInOneMessage() { // Two or more DEFLATE blocks may be used in one message. - using var cancellation = new CancellationTokenSource(TimeSpan.FromSeconds(1000)); var stream = new WebSocketStream(); using var websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions { @@ -105,12 +117,12 @@ public async Task TwoDeflateBlocksInOneMessage() stream.Write(0x80, 0x05, 0xca, 0xc9, 0xc9, 0x07, 0x00); Memory buffer = new byte[5]; - var result = await websocket.ReceiveAsync(buffer, cancellation.Token); + var result = await websocket.ReceiveAsync(buffer, CancellationToken); Assert.Equal(2, result.Count); Assert.False(result.EndOfMessage); - result = await websocket.ReceiveAsync(buffer.Slice(result.Count), cancellation.Token); + result = await websocket.ReceiveAsync(buffer.Slice(result.Count), CancellationToken); Assert.Equal(3, result.Count); Assert.True(result.EndOfMessage); @@ -138,7 +150,7 @@ public async Task Duplex() var message = $"Sending number {i} from server."; await SendTextAsync(message, server); - var result = await client.ReceiveAsync(buffer.AsMemory(), default); + var result = await client.ReceiveAsync(buffer.AsMemory(), CancellationToken); Assert.True(result.EndOfMessage); Assert.Equal(WebSocketMessageType.Text, result.MessageType); @@ -151,7 +163,7 @@ public async Task Duplex() var message = $"Sending number {i} from client."; await SendTextAsync(message, client); - var result = await server.ReceiveAsync(buffer.AsMemory(), default); + var result = await server.ReceiveAsync(buffer.AsMemory(), CancellationToken); Assert.True(result.EndOfMessage); Assert.Equal(WebSocketMessageType.Text, result.MessageType); @@ -188,8 +200,6 @@ public async Task LargeMessageSplitInMultipleFrames(int windowBits) for (var i = 0; i < 10; ++i) { // Use a timeout cancellation token in case something doesn't work right - using var cancellation = new CancellationTokenSource(TimeSpan.FromSeconds(1000)); - var frameSize = RandomNumberGenerator.GetInt32(1024, 2048); var position = 0; @@ -198,7 +208,7 @@ public async Task LargeMessageSplitInMultipleFrames(int windowBits) var currentFrameSize = Math.Min(frameSize, testData.Length - position); var eof = position + currentFrameSize == testData.Length; - await server.SendAsync(testData.Slice(position, currentFrameSize), WebSocketMessageType.Binary, eof, cancellation.Token); + await server.SendAsync(testData.Slice(position, currentFrameSize), WebSocketMessageType.Binary, eof, CancellationToken); position += currentFrameSize; } @@ -215,7 +225,7 @@ public async Task LargeMessageSplitInMultipleFrames(int windowBits) while (true) { var currentFrameSize = Math.Min(frameSize, testData.Length - position); - var result = await client.ReceiveAsync(receivedData.Slice(position, currentFrameSize), cancellation.Token); + var result = await client.ReceiveAsync(receivedData.Slice(position, currentFrameSize), CancellationToken); Assert.Equal(WebSocketMessageType.Binary, result.MessageType); position += result.Count; @@ -230,10 +240,10 @@ public async Task LargeMessageSplitInMultipleFrames(int windowBits) } } - private static ValueTask SendTextAsync(string text, WebSocket websocket) + private ValueTask SendTextAsync(string text, WebSocket websocket) { var bytes = Encoding.UTF8.GetBytes(text); - return websocket.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, default); + return websocket.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, CancellationToken); } } } From afd23ef603f88eb8378884ed042167a93144611a Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Thu, 18 Feb 2021 16:16:33 +0200 Subject: [PATCH 25/52] Fixed a bug where the receiver would happily ignore per message deflate flag in the header if decoder is not configured. --- .../src/Resources/Strings.resx | 3 +++ .../WebSockets/ManagedWebSocket.Receiver.cs | 20 +++++++++---------- .../tests/WebSocketDeflateTests.cs | 14 +++++++++++++ 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx index 09cc98eaad0e2..b9f8ef004d5b2 100644 --- a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx @@ -225,4 +225,7 @@ The stream state of the underlying compression routine is inconsistent. + + The WebSocket received compressed frame when compression is not enabled. + \ No newline at end of file diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index 539014a3d316e..d9bc6b2515ac4 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -230,8 +230,9 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella int consumed, written; int available = (int)Math.Min(_readBuffer.AvailableLength, _lastHeader.PayloadLength); - if (_decoder is not null && _decoder.IsNeeded(_lastHeader)) + if (_lastHeader.Compressed) { + Debug.Assert(_decoder is not null); _decoder.Decode(input: _readBuffer.AvailableSpan.Slice(0, available), output: buffer.Span, out consumed, out written); } @@ -266,7 +267,7 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella // and should start issuing reads on the stream. Debug.Assert(_readBuffer.AvailableLength == 0 && _lastHeader.PayloadLength > 0); - if (_decoder is null || !_decoder.IsNeeded(_lastHeader)) + if (_decoder is null) { if (buffer.Length > _lastHeader.PayloadLength) { @@ -335,9 +336,14 @@ private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationT while (true) { - if (TryParseMessageHeader(_readBuffer.AvailableSpan, _lastHeader, _isServer, out var header, - out var error, out var consumedBytes)) + if (TryParseMessageHeader(_readBuffer.AvailableSpan, _lastHeader, _isServer, out var header, out var error, out var consumedBytes)) { + if (header.Compressed && _decoder is null) + { + _headerError = SR.net_Websockets_PerMessageCompressedFlagWhenNotEnabled; + return false; + } + // If this is a continuation, replace the opcode with the one of the message it's continuing if (header.Opcode == MessageOpcode.Continuation) { @@ -450,8 +456,6 @@ public void DiscardConsumed() private abstract class Decoder : IDisposable { - public abstract bool IsNeeded(MessageHeader header); - public abstract void Dispose(); public abstract void Decode(ReadOnlySpan input, Span output, out int consumed, out int written); @@ -474,8 +478,6 @@ private class Inflater : Decoder public Inflater(int windowBits) => _windowBits = windowBits; - public override bool IsNeeded(MessageHeader header) => header.Compressed; - public override void Dispose() => _inflater?.Dispose(); public override bool Finish(Span output, out int written) @@ -556,8 +558,6 @@ private sealed class PersistedInflater : Decoder public PersistedInflater(int windowBits) => _inflater = new(windowBits); - public override bool IsNeeded(MessageHeader header) => header.Compressed; - public override void Dispose() => _inflater.Dispose(); public override bool Finish(Span output, out int written) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index a9255c93f90cf..f4da4cb37cd15 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -240,6 +240,20 @@ public async Task LargeMessageSplitInMultipleFrames(int windowBits) } } + [Fact] + public async Task WebSocketWithoutDeflateShouldThrowOnCompressedMessage() + { + var stream = new WebSocketStream(); + + stream.Write(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + using var websocket = WebSocket.CreateFromStream(stream.Remote, new()); + + var exception = await Assert.ThrowsAsync(() => + websocket.ReceiveAsync(Memory.Empty, CancellationToken).AsTask()); + + Assert.Equal("The WebSocket received compressed frame when compression is not enabled.", exception.Message); + } + private ValueTask SendTextAsync(string text, WebSocket websocket) { var bytes = Encoding.UTF8.GetBytes(text); From f8d1f0e4e680ccbe80c0e0be84dd86dbd7ddc062 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Thu, 18 Feb 2021 16:45:15 +0200 Subject: [PATCH 26/52] Moving ZLibNative to Common so it can be cross compiled alongside Interop.zlib.cs. --- .../src/System/IO/Compression}/ZLibNative.ZStream.cs | 0 .../src/System/IO/Compression}/ZLibNative.cs | 0 .../System.IO.Compression/src/System.IO.Compression.csproj | 6 +++--- 3 files changed, 3 insertions(+), 3 deletions(-) rename src/libraries/{System.IO.Compression/src/System/IO/Compression/DeflateZLib => Common/src/System/IO/Compression}/ZLibNative.ZStream.cs (100%) rename src/libraries/{System.IO.Compression/src/System/IO/Compression/DeflateZLib => Common/src/System/IO/Compression}/ZLibNative.cs (100%) diff --git a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.ZStream.cs b/src/libraries/Common/src/System/IO/Compression/ZLibNative.ZStream.cs similarity index 100% rename from src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.ZStream.cs rename to src/libraries/Common/src/System/IO/Compression/ZLibNative.ZStream.cs diff --git a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.cs b/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs similarity index 100% rename from src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.cs rename to src/libraries/Common/src/System/IO/Compression/ZLibNative.cs diff --git a/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj b/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj index e2a7adee12f57..758e86a80d6b4 100644 --- a/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj +++ b/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj @@ -1,4 +1,4 @@ - + true $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser @@ -25,8 +25,8 @@ - - + + From 8872025b055bfe09792d2b21ce7d795136ef8f77 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Thu, 18 Feb 2021 16:46:26 +0200 Subject: [PATCH 27/52] Removed custom ZLibNative and using the one which now resides in Common. --- .../src/System.Net.WebSockets.csproj | 20 +- .../IO/Compression/ZLibNative.ZStream.cs | 34 -- .../src/System/IO/Compression/ZLibNative.cs | 345 ------------------ .../Compression/WebSocketDeflater.cs} | 8 +- .../Compression/WebSocketInflater.cs} | 10 +- .../WebSockets/ManagedWebSocket.Receiver.cs | 11 +- .../Net/WebSockets/ManagedWebSocket.Sender.cs | 9 +- 7 files changed, 27 insertions(+), 410 deletions(-) delete mode 100644 src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.ZStream.cs delete mode 100644 src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.cs rename src/libraries/System.Net.WebSockets/src/System/{IO/Compression/Deflater.cs => Net/WebSockets/Compression/WebSocketDeflater.cs} (95%) rename src/libraries/System.Net.WebSockets/src/System/{IO/Compression/Inflater.cs => Net/WebSockets/Compression/WebSocketInflater.cs} (95%) diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index 9ebda23e96f1b..56575010d4430 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -19,24 +19,20 @@ - - - - - - + + + + + + - + - + diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.ZStream.cs b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.ZStream.cs deleted file mode 100644 index bfb8c5145c04a..0000000000000 --- a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.ZStream.cs +++ /dev/null @@ -1,34 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Runtime.InteropServices; - -namespace System.IO.Compression -{ - internal static partial class ZLibNative - { - /// - /// ZLib stream descriptor data structure - /// Do not construct instances of ZStream explicitly. - /// Always use ZLibNative.DeflateInit2_ or ZLibNative.InflateInit2_ instead. - /// Those methods will wrap this structure into a SafeHandle and thus make sure that it is always disposed correctly. - /// - [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] - internal struct ZStream - { - internal void Init() - { - } - - internal IntPtr nextIn; //Bytef *next_in; /* next input byte */ - internal IntPtr nextOut; //Bytef *next_out; /* next output byte should be put there */ - - internal IntPtr msg; //char *msg; /* last error message, NULL if no error */ - - private readonly IntPtr internalState; //internal state that is not visible to managed code - - internal uint availIn; //uInt avail_in; /* number of bytes available at next_in */ - internal uint availOut; //uInt avail_out; /* remaining free space at next_out */ - } - } -} diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.cs b/src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.cs deleted file mode 100644 index 8118aeba0ecb8..0000000000000 --- a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/ZLibNative.cs +++ /dev/null @@ -1,345 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Runtime.InteropServices; -using System.Security; - -namespace System.IO.Compression -{ - /// - /// This class provides declaration for constants and PInvokes as well as some basic tools for exposing the - /// native System.IO.Compression.Native.dll (effectively, ZLib) library to managed code. - /// - /// See also: How to choose a compression level (in comments to CompressionLevel. - /// - internal static partial class ZLibNative - { - // This is the NULL pointer for using with ZLib pointers; - // we prefer it to IntPtr.Zero to mimic the definition of Z_NULL in zlib.h: - internal static readonly IntPtr ZNullPtr = IntPtr.Zero; - - public enum FlushCode : int - { - NoFlush = 0, - SyncFlush = 2, - Finish = 4, - } - - public enum ErrorCode : int - { - Ok = 0, - StreamEnd = 1, - StreamError = -2, - DataError = -3, - MemError = -4, - BufError = -5, - VersionError = -6 - } - - /// - ///

ZLib can accept any integer value between 0 and 9 (inclusive) as a valid compression level parameter: - /// 1 gives best speed, 9 gives best compression, 0 gives no compression at all (the input data is simply copied a block at a time). - /// CompressionLevel.DefaultCompression = -1 requests a default compromise between speed and compression - /// (currently equivalent to level 6).

- /// - ///

How to choose a compression level:

- /// - ///

The names NoCompression, BestSpeed, DefaultCompression, BestCompression are taken over from - /// the corresponding ZLib definitions, which map to our public NoCompression, Fastest, Optimal, and SmallestSize respectively.

- ///

Optimal Compression:

- ///

ZLibNative.CompressionLevel compressionLevel = ZLibNative.CompressionLevel.DefaultCompression;
- /// int windowBits = 15; // or -15 if no headers required
- /// int memLevel = 8;
- /// ZLibNative.CompressionStrategy strategy = ZLibNative.CompressionStrategy.DefaultStrategy;

- /// - ///

Fastest compression:

- ///

ZLibNative.CompressionLevel compressionLevel = ZLibNative.CompressionLevel.BestSpeed;
- /// int windowBits = 15; // or -15 if no headers required
- /// int memLevel = 8;
- /// ZLibNative.CompressionStrategy strategy = ZLibNative.CompressionStrategy.DefaultStrategy;

- /// - ///

No compression (even faster, useful for data that cannot be compressed such some image formats):

- ///

ZLibNative.CompressionLevel compressionLevel = ZLibNative.CompressionLevel.NoCompression;
- /// int windowBits = 15; // or -15 if no headers required
- /// int memLevel = 7;
- /// ZLibNative.CompressionStrategy strategy = ZLibNative.CompressionStrategy.DefaultStrategy;

- /// - ///

Smallest Size Compression:

- ///

ZLibNative.CompressionLevel compressionLevel = ZLibNative.CompressionLevel.BestCompression;
- /// int windowBits = 15; // or -15 if no headers required
- /// int memLevel = 8;
- /// ZLibNative.CompressionStrategy strategy = ZLibNative.CompressionStrategy.DefaultStrategy;

- ///
- public enum CompressionLevel : int - { - NoCompression = 0, - BestSpeed = 1, - DefaultCompression = -1, - BestCompression = 9 - } - - /// - ///

From the ZLib manual:

- ///

CompressionStrategy is used to tune the compression algorithm.
- /// Use the value DefaultStrategy for normal data, Filtered for data produced by a filter (or predictor), - /// HuffmanOnly to force Huffman encoding only (no string match), or Rle to limit match distances to one - /// (run-length encoding). Filtered data consists mostly of small values with a somewhat random distribution. In this case, the - /// compression algorithm is tuned to compress them better. The effect of Filtered is to force more Huffman coding and] - /// less string matching; it is somewhat intermediate between DefaultStrategy and HuffmanOnly. - /// Rle is designed to be almost as fast as HuffmanOnly, but give better compression for PNG image data. - /// The strategy parameter only affects the compression ratio but not the correctness of the compressed output even if it is not set - /// appropriately. Fixed prevents the use of dynamic Huffman codes, allowing for a simpler decoder for special applications.

- /// - ///

For .NET Framework use:

- ///

We have investigated compression scenarios for a bunch of different frequently occurring compression data and found that in all - /// cases we investigated so far, DefaultStrategy provided best results

- ///

See also: How to choose a compression level (in comments to CompressionLevel.

- ///
- public enum CompressionStrategy : int - { - DefaultStrategy = 0 - } - - /// - /// In version 1.2.3, ZLib provides on the Deflated-CompressionMethod. - /// - public enum CompressionMethod : int - { - Deflated = 8 - } - - /// - ///

From the ZLib manual:

- ///

ZLib's windowBits parameter is the base two logarithm of the window size (the size of the history buffer). - /// It should be in the range 8..15 for this version of the library. Larger values of this parameter result in better compression - /// at the expense of memory usage. The default value is 15 if deflateInit is used instead.

- /// Note: - /// windowBits can also be -8..-15 for raw deflate. In this case, -windowBits determines the window size. - /// Deflate will then generate raw deflate data with no ZLib header or trailer, and will not compute an adler32 check value.
- ///

See also: How to choose a compression level (in comments to CompressionLevel.

- ///
- public const int Deflate_DefaultWindowBits = -15; // Legal values are 8..15 and -8..-15. 15 is the window size, - // negative val causes deflate to produce raw deflate data (no zlib header). - - /// - ///

From the ZLib manual:

- ///

ZLib's windowBits parameter is the base two logarithm of the window size (the size of the history buffer). - /// It should be in the range 8..15 for this version of the library. Larger values of this parameter result in better compression - /// at the expense of memory usage. The default value is 15 if deflateInit is used instead.

- ///
- public const int ZLib_DefaultWindowBits = 15; - - /// - ///

Zlib's windowBits parameter is the base two logarithm of the window size (the size of the history buffer). - /// For GZip header encoding, windowBits should be equal to a value between 8..15 (to specify Window Size) added to - /// 16. The range of values for GZip encoding is therefore 24..31. - /// Note: - /// The GZip header will have no file name, no extra data, no comment, no modification time (set to zero), no header crc, and - /// the operating system will be set based on the OS that the ZLib library was compiled to. ZStream.adler - /// is a crc32 instead of an adler32.

- ///
- public const int GZip_DefaultWindowBits = 31; - - /// - ///

From the ZLib manual:

- ///

The memLevel parameter specifies how much memory should be allocated for the internal compression state. - /// memLevel = 1 uses minimum memory but is slow and reduces compression ratio; memLevel = 9 uses maximum - /// memory for optimal speed. The default value is 8.

- ///

See also: How to choose a compression level (in comments to CompressionLevel.

- ///
- public const int Deflate_DefaultMemLevel = 8; // Memory usage by deflate. Legal range: [1..9]. 8 is ZLib default. - // More is faster and better compression with more memory usage. - public const int Deflate_NoCompressionMemLevel = 7; - - public const byte GZip_Header_ID1 = 31; - public const byte GZip_Header_ID2 = 139; - - /** - * Do not remove the nested typing of types inside of System.IO.Compression.ZLibNative. - * This was done on purpose to: - * - * - Achieve the right encapsulation in a situation where ZLibNative may be compiled division-wide - * into different assemblies that wish to consume System.IO.Compression.Native. Since internal - * scope is effectively like public scope when compiling ZLibNative into a higher - * level assembly, we need a combination of inner types and private-scope members to achieve - * the right encapsulation. - * - * - Achieve late dynamic loading of System.IO.Compression.Native.dll at the right time. - * The native assembly will not be loaded unless it is actually used since the loading is performed by a static - * constructor of an inner type that is not directly referenced by user code. - * - * In Dev12 we would like to create a proper feature for loading native assemblies from user-specified - * directories in order to PInvoke into them. This would preferably happen in the native interop/PInvoke - * layer; if not we can add a Framework level feature. - */ - - /// - /// The ZLibStreamHandle could be a CriticalFinalizerObject rather than a - /// SafeHandleMinusOneIsInvalid. This would save an IntPtr field since - /// ZLibStreamHandle does not actually use its handle field. - /// Instead it uses a private ZStream zStream field which is the actual handle data - /// structure requiring critical finalization. - /// However, we would like to take advantage if the better debugability offered by the fact that a - /// releaseHandleFailed MDA is raised if the ReleaseHandle method returns - /// false, which can for instance happen if the underlying ZLib XxxxEnd - /// routines return an failure error code. - /// - public sealed class ZLibStreamHandle : SafeHandle - { - public enum State { NotInitialized, InitializedForDeflate, InitializedForInflate, Disposed } - - private ZStream _zStream; - - private volatile State _initializationState; - - - public ZLibStreamHandle() - : base(new IntPtr(-1), true) - { - _zStream.Init(); - - _initializationState = State.NotInitialized; - SetHandle(IntPtr.Zero); - } - - public override bool IsInvalid - { - get { return handle == new IntPtr(-1); } - } - - public State InitializationState - { - get { return _initializationState; } - } - - - protected override bool ReleaseHandle() => - InitializationState switch - { - State.NotInitialized => true, - State.InitializedForDeflate => (DeflateEnd() == ErrorCode.Ok), - State.InitializedForInflate => (InflateEnd() == ErrorCode.Ok), - State.Disposed => true, - _ => false, // This should never happen. Did we forget one of the State enum values in the switch? - }; - - public IntPtr NextIn - { - get { return _zStream.nextIn; } - set { _zStream.nextIn = value; } - } - - public uint AvailIn - { - get { return _zStream.availIn; } - set { _zStream.availIn = value; } - } - - public IntPtr NextOut - { - get { return _zStream.nextOut; } - set { _zStream.nextOut = value; } - } - - public uint AvailOut - { - get { return _zStream.availOut; } - set { _zStream.availOut = value; } - } - - private void EnsureNotDisposed() - { - if (InitializationState == State.Disposed) - throw new ObjectDisposedException(GetType().ToString()); - } - - - private void EnsureState(State requiredState) - { - if (InitializationState != requiredState) - throw new InvalidOperationException("InitializationState != " + requiredState.ToString()); - } - - - public ErrorCode DeflateInit2_(CompressionLevel level, int windowBits, int memLevel, CompressionStrategy strategy) - { - EnsureNotDisposed(); - EnsureState(State.NotInitialized); - - ErrorCode errC = Interop.zlib.DeflateInit2_(ref _zStream, level, CompressionMethod.Deflated, windowBits, memLevel, strategy); - _initializationState = State.InitializedForDeflate; - - return errC; - } - - - public ErrorCode Deflate(FlushCode flush) - { - EnsureNotDisposed(); - EnsureState(State.InitializedForDeflate); - return Interop.zlib.Deflate(ref _zStream, flush); - } - - - public ErrorCode DeflateEnd() - { - EnsureNotDisposed(); - EnsureState(State.InitializedForDeflate); - - ErrorCode errC = Interop.zlib.DeflateEnd(ref _zStream); - _initializationState = State.Disposed; - - return errC; - } - - - public ErrorCode InflateInit2_(int windowBits) - { - EnsureNotDisposed(); - EnsureState(State.NotInitialized); - - ErrorCode errC = Interop.zlib.InflateInit2_(ref _zStream, windowBits); - _initializationState = State.InitializedForInflate; - - return errC; - } - - - public ErrorCode Inflate(FlushCode flush) - { - EnsureNotDisposed(); - EnsureState(State.InitializedForInflate); - return Interop.zlib.Inflate(ref _zStream, flush); - } - - - public ErrorCode InflateEnd() - { - EnsureNotDisposed(); - EnsureState(State.InitializedForInflate); - - ErrorCode errC = Interop.zlib.InflateEnd(ref _zStream); - _initializationState = State.Disposed; - - return errC; - } - - // This can work even after XxflateEnd(). - public string GetErrorMessage() => _zStream.msg != ZNullPtr ? Marshal.PtrToStringAnsi(_zStream.msg)! : string.Empty; - } - - public static ErrorCode CreateZLibStreamForDeflate(out ZLibStreamHandle zLibStreamHandle, CompressionLevel level, - int windowBits, int memLevel, CompressionStrategy strategy) - { - zLibStreamHandle = new ZLibStreamHandle(); - return zLibStreamHandle.DeflateInit2_(level, windowBits, memLevel, strategy); - } - - - public static ErrorCode CreateZLibStreamForInflate(out ZLibStreamHandle zLibStreamHandle, int windowBits) - { - zLibStreamHandle = new ZLibStreamHandle(); - return zLibStreamHandle.InflateInit2_(windowBits); - } - } -} diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs similarity index 95% rename from src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs rename to src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index 1abcb6347a9f1..27e5f89f8f17c 100644 --- a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Deflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -1,21 +1,21 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Net.WebSockets; +using System.IO.Compression; using ZErrorCode = System.IO.Compression.ZLibNative.ErrorCode; using ZFlushCode = System.IO.Compression.ZLibNative.FlushCode; -namespace System.IO.Compression +namespace System.Net.WebSockets.Compression { /// /// Provides a wrapper around the ZLib compression API. /// - internal sealed class Deflater : IDisposable + internal sealed class WebSocketDeflater : IDisposable { private readonly ZLibNative.ZLibStreamHandle _handle; - internal Deflater(int windowBits) + internal WebSocketDeflater(int windowBits) { var compressionLevel = ZLibNative.CompressionLevel.DefaultCompression; var memLevel = ZLibNative.Deflate_DefaultMemLevel; diff --git a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs similarity index 95% rename from src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs rename to src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index 16d3738e5dc9a..b08e6337f5718 100644 --- a/src/libraries/System.Net.WebSockets/src/System/IO/Compression/Inflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -1,21 +1,19 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Buffers; -using System.Diagnostics; -using System.Net.WebSockets; +using System.IO.Compression; using System.Runtime.InteropServices; -namespace System.IO.Compression +namespace System.Net.WebSockets.Compression { /// /// Provides a wrapper around the ZLib decompression API. /// - internal sealed class Inflater : IDisposable + internal sealed class WebSocketInflater : IDisposable { private readonly ZLibNative.ZLibStreamHandle _handle; - internal Inflater(int windowBits) + internal WebSocketInflater(int windowBits) { ZLibNative.ErrorCode error; try diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index d9bc6b2515ac4..be41fb8fdc8be 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.Diagnostics; using System.IO; +using System.Net.WebSockets.Compression; using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; @@ -474,7 +475,7 @@ private class Inflater : Decoder // Although the inflater isn't persisted accross messages, a single message // might have been split into multiple frames. - private IO.Compression.Inflater? _inflater; + private WebSocketInflater? _inflater; public Inflater(int windowBits) => _windowBits = windowBits; @@ -495,11 +496,11 @@ public override bool Finish(Span output, out int written) public override void Decode(ReadOnlySpan input, Span output, out int consumed, out int written) { - _inflater ??= new IO.Compression.Inflater(_windowBits); + _inflater ??= new WebSocketInflater(_windowBits); _inflater.Inflate(input, output, out consumed, out written); } - public static bool Finish(IO.Compression.Inflater inflater, Span output, out int written, ref byte? remainingByte) + public static bool Finish(WebSocketInflater inflater, Span output, out int written, ref byte? remainingByte) { written = 0; @@ -532,7 +533,7 @@ public static bool Finish(IO.Compression.Inflater inflater, Span output, o return false; } - public static bool IsFinished(IO.Compression.Inflater inflater, out byte? remainingByte) + public static bool IsFinished(WebSocketInflater inflater, out byte? remainingByte) { // There is no other way to make sure that we'e consumed all data // but to try to inflate again with at least one byte of output buffer. @@ -552,7 +553,7 @@ private sealed class PersistedInflater : Decoder { private static ReadOnlySpan FlushMarker => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; - private readonly IO.Compression.Inflater _inflater; + private readonly WebSocketInflater _inflater; private bool _needsFlushMarker; private byte? _remainingByte; diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs index d9beb1ad69d72..65f12a1f82137 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.Diagnostics; using System.IO; +using System.Net.WebSockets.Compression; using System.Security.Cryptography; using System.Threading; using System.Threading.Tasks; @@ -248,7 +249,7 @@ private class Deflater : Encoder // Although the inflater isn't persisted accross messages, a single message // might be split into multiple frames. - private IO.Compression.Deflater? _deflater; + private WebSocketDeflater? _deflater; public Deflater(int windowBits) => _windowBits = windowBits; @@ -259,7 +260,7 @@ internal override void Encode(ReadOnlySpan payload, ref Buffer buffer, boo Debug.Assert((continuation && _deflater is not null) || (!continuation && _deflater is null), "Invalid state. The deflater was expected to be null if not continuation and not null otherwise."); - _deflater ??= new IO.Compression.Deflater(_windowBits); + _deflater ??= new WebSocketDeflater(_windowBits); Encode(payload, ref buffer, _deflater, endOfMessage); reservedBits = continuation ? 0 : PerMessageDeflateBit; @@ -271,7 +272,7 @@ internal override void Encode(ReadOnlySpan payload, ref Buffer buffer, boo } } - public static void Encode(ReadOnlySpan payload, ref Buffer buffer, IO.Compression.Deflater deflater, bool final) + public static void Encode(ReadOnlySpan payload, ref Buffer buffer, WebSocketDeflater deflater, bool final) { while (payload.Length > 0) { @@ -310,7 +311,7 @@ public static void Encode(ReadOnlySpan payload, ref Buffer buffer, IO.Comp ///
private sealed class PersistedDeflater : Encoder { - private readonly IO.Compression.Deflater _deflater; + private readonly WebSocketDeflater _deflater; public PersistedDeflater(int windowBits) => _deflater = new(windowBits); From a0bba92db77b6010d9d09eebd1f5feaa704b78fb Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Thu, 18 Feb 2021 20:02:24 +0200 Subject: [PATCH 28/52] Small refactoring to make inflate / deflate logic more clear. --- .../src/Resources/Strings.resx | 6 - .../Compression/WebSocketDeflater.cs | 170 +++++++++----- .../Compression/WebSocketInflater.cs | 215 +++++++++++++---- .../WebSockets/ManagedWebSocket.Receiver.cs | 177 ++------------ .../Net/WebSockets/ManagedWebSocket.Sender.cs | 222 ++++++------------ 5 files changed, 376 insertions(+), 414 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx index b9f8ef004d5b2..22cf53cd585a6 100644 --- a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx @@ -204,18 +204,12 @@ The underlying compression routine could not be loaded correctly. - - The underlying compression routine received incorrect initialization parameters. - The underlying compression routine could not reserve sufficient memory. The underlying compression routine returned an unexpected error code {0}. - - The version of the underlying compression routine does not match expected version. - The message was compressed using an unsupported compression method. diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index 27e5f89f8f17c..ca6cff6c5a89e 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -1,10 +1,10 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.IO.Compression; - -using ZErrorCode = System.IO.Compression.ZLibNative.ErrorCode; -using ZFlushCode = System.IO.Compression.ZLibNative.FlushCode; +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using static System.IO.Compression.ZLibNative; namespace System.Net.WebSockets.Compression { @@ -13,88 +13,120 @@ namespace System.Net.WebSockets.Compression ///
internal sealed class WebSocketDeflater : IDisposable { - private readonly ZLibNative.ZLibStreamHandle _handle; + private ZLibStreamHandle? _stream; + private readonly int _windowBits; + private readonly bool _persisted; - internal WebSocketDeflater(int windowBits) + internal WebSocketDeflater(int windowBits, bool persisted) { - var compressionLevel = ZLibNative.CompressionLevel.DefaultCompression; - var memLevel = ZLibNative.Deflate_DefaultMemLevel; - var strategy = ZLibNative.CompressionStrategy.DefaultStrategy; + Debug.Assert(windowBits >= 9 && windowBits <= 15); - ZErrorCode errorCode; - try - { - errorCode = ZLibNative.CreateZLibStreamForDeflate(out _handle, compressionLevel, windowBits, memLevel, strategy); - } - catch (Exception cause) - { - throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); - } + // We use negative window bits in order to produce raw deflate data + _windowBits = -windowBits; + _persisted = persisted; + } - switch (errorCode) + public void Dispose() => _stream?.Dispose(); + + public void Deflate(ReadOnlySpan payload, IBufferWriter output, bool continuation, bool endOfMessage) + { + Debug.Assert(!continuation || _stream is not null, "Invalid state. The stream should not be null in continuations."); + + if (_stream is null) + Initialize(); + + while (payload.Length > 0) { - case ZErrorCode.Ok: - return; + Deflate(payload, output.GetSpan(payload.Length), out var consumed, out var written); + output.Advance(written); - case ZErrorCode.MemError: - throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); + payload = payload.Slice(consumed); + } - case ZErrorCode.VersionError: - throw new WebSocketException(SR.ZLibErrorVersionMismatch); + // See comment by Mark Adler https://github.com/madler/zlib/issues/149#issuecomment-225237457 + // At that point there will be at most a few bits left to write. + // Then call deflate() with Z_FULL_FLUSH and no more input and at least six bytes of available output. + Span end = stackalloc byte[6]; + var count = Flush(end); + + end = end.Slice(0, count); + // The deflated block always ends with 0x00 0x00 0xFF 0xFF + Debug.Assert(count >= 4); + Debug.Assert(end[^4] == 0x00 && + end[^3] == 0x00 && + end[^2] == 0xFF && + end[^1] == 0xFF); + + if (endOfMessage) + { + // As per RFC we need to remove the flush markers + end = end.Slice(0, end.Length - 4); + } - case ZErrorCode.StreamError: - throw new WebSocketException(SR.ZLibErrorIncorrectInitParameters); + end.CopyTo(output.GetSpan(end.Length)); + output.Advance(end.Length); - default: - throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); + if (endOfMessage && !_persisted) + { + _stream.Dispose(); + _stream = null; } } - public void Dispose() => _handle.Dispose(); - - public unsafe void Deflate(ReadOnlySpan input, Span output, out int consumed, out int written) + private unsafe void Deflate(ReadOnlySpan input, Span output, out int consumed, out int written) { + Debug.Assert(_stream is not null); + fixed (byte* fixedInput = input) fixed (byte* fixedOutput = output) { - _handle.NextIn = (IntPtr)fixedInput; - _handle.AvailIn = (uint)input.Length; + _stream.NextIn = (IntPtr)fixedInput; + _stream.AvailIn = (uint)input.Length; - _handle.NextOut = (IntPtr)fixedOutput; - _handle.AvailOut = (uint)output.Length; + _stream.NextOut = (IntPtr)fixedOutput; + _stream.AvailOut = (uint)output.Length; - Deflate((ZFlushCode)5/*Z_BLOCK*/); + // If flush is set to Z_BLOCK, a deflate block is completed + // and emitted, as for Z_SYNC_FLUSH, but the output + // is not aligned on a byte boundary, and up to seven bits + // of the current block are held to be written as the next byte after + // the next deflate block is completed. + Deflate(_stream, (FlushCode)5/*Z_BLOCK*/); - consumed = input.Length - (int)_handle.AvailIn; - written = output.Length - (int)_handle.AvailOut; + consumed = input.Length - (int)_stream.AvailIn; + written = output.Length - (int)_stream.AvailOut; } } - public unsafe int Finish(Span output, out bool completed) + private unsafe int Flush(Span output) { + Debug.Assert(_stream is not null); + Debug.Assert(_stream.AvailIn == 0); + Debug.Assert(output.Length >= 6); + fixed (byte* fixedOutput = output) { - _handle.NextIn = IntPtr.Zero; - _handle.AvailIn = 0; + _stream.NextIn = IntPtr.Zero; + _stream.AvailIn = 0; - _handle.NextOut = (IntPtr)fixedOutput; - _handle.AvailOut = (uint)output.Length; + _stream.NextOut = (IntPtr)fixedOutput; + _stream.AvailOut = (uint)output.Length; - var errorCode = Deflate((ZFlushCode)3/*Z_FULL_FLUSH*/); - var writtenBytes = output.Length - (int)_handle.AvailOut; + var errorCode = Deflate(_stream, (FlushCode)3/*Z_FULL_FLUSH*/); + var writtenBytes = output.Length - (int)_stream.AvailOut; - completed = errorCode == ZErrorCode.Ok && writtenBytes < output.Length; + Debug.Assert(errorCode == ErrorCode.Ok); return writtenBytes; } } - private ZErrorCode Deflate(ZFlushCode flushCode) + private static ErrorCode Deflate(ZLibStreamHandle stream, FlushCode flushCode) { - ZErrorCode errorCode; + ErrorCode errorCode; try { - errorCode = _handle.Deflate(flushCode); + errorCode = stream.Deflate(flushCode); } catch (Exception cause) { @@ -103,19 +135,49 @@ private ZErrorCode Deflate(ZFlushCode flushCode) switch (errorCode) { - case ZErrorCode.Ok: - case ZErrorCode.StreamEnd: + case ErrorCode.Ok: + case ErrorCode.StreamEnd: return errorCode; - case ZErrorCode.BufError: + case ErrorCode.BufError: return errorCode; // This is a recoverable error - case ZErrorCode.StreamError: + case ErrorCode.StreamError: throw new WebSocketException(SR.ZLibErrorInconsistentStream); default: throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); } } + + [MemberNotNull(nameof(_stream))] + private void Initialize() + { + Debug.Assert(_stream is null); + + var compressionLevel = CompressionLevel.DefaultCompression; + var memLevel = Deflate_DefaultMemLevel; + var strategy = CompressionStrategy.DefaultStrategy; + + ErrorCode errorCode; + try + { + errorCode = CreateZLibStreamForDeflate(out _stream, compressionLevel, _windowBits, memLevel, strategy); + } + catch (Exception cause) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + + switch (errorCode) + { + case ErrorCode.Ok: + return; + case ErrorCode.MemError: + throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); + } + } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index b08e6337f5718..51f0d43475c7f 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -1,8 +1,10 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.IO.Compression; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Runtime.InteropServices; +using static System.IO.Compression.ZLibNative; namespace System.Net.WebSockets.Compression { @@ -11,80 +13,165 @@ namespace System.Net.WebSockets.Compression /// internal sealed class WebSocketInflater : IDisposable { - private readonly ZLibNative.ZLibStreamHandle _handle; + private static ReadOnlySpan FlushMarker => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; - internal WebSocketInflater(int windowBits) + private ZLibStreamHandle? _stream; + private readonly int _windowBits; + private readonly bool _persisted; + + /// + /// There is now way of knowing when decoding data if the underlying deflater + /// has flushed all outstanding data to consumer other than to provide a buffer + /// and see whether any bytes are written. There are cases when the consumers + /// provide a buffer exactly the size of the uncompressed data and in this case + /// to avoid requiring another read we will use this field. + /// + private byte? _remainingByte; + + /// + /// When the inflater is persisted we need to manually append the flush marker + /// before finishing the decoding. + /// + private bool _needsFlushMarker; + + internal WebSocketInflater(int windowBits, bool persisted) { - ZLibNative.ErrorCode error; - try + Debug.Assert(windowBits >= 9 && windowBits <= 15); + + // We use negative window bits to instruct deflater to expect raw deflate data + _windowBits = -windowBits; + _persisted = persisted; + } + + public void Dispose() => _stream?.Dispose(); + + public unsafe void Inflate(ReadOnlySpan input, Span output, out int consumed, out int written) + { + if (_stream is null) + Initialize(); + + fixed (byte* fixedInput = &MemoryMarshal.GetReference(input)) + fixed (byte* fixedOutput = &MemoryMarshal.GetReference(output)) { - error = ZLibNative.CreateZLibStreamForInflate(out _handle, windowBits); + _stream.NextIn = (IntPtr)fixedInput; + _stream.AvailIn = (uint)input.Length; + + _stream.NextOut = (IntPtr)fixedOutput; + _stream.AvailOut = (uint)output.Length; + + Inflate(_stream); + + consumed = input.Length - (int)_stream.AvailIn; + written = output.Length - (int)_stream.AvailOut; } - catch (Exception exception) // could not load the ZLib dll + + _needsFlushMarker = _persisted; + } + + /// + /// Finishes the decoding by writing any outstanding data to the output. + /// + /// true if the finish completed, false to indicate that there is more outstanding data. + public bool Finish(Span output, out int written) + { + Debug.Assert(_stream is not null); + + if (_needsFlushMarker) { - throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception); + Inflate(FlushMarker, output, out var _, out written); + _needsFlushMarker = false; + + if ( written < output.Length || IsFinished(_stream, out _remainingByte) ) + { + OnFinished(); + return true; + } } - switch (error) - { - case ZLibNative.ErrorCode.Ok: // Successful initialization - return; + written = 0; - case ZLibNative.ErrorCode.MemError: // Not enough memory - throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); + if (output.Length == 0) + { + if (_remainingByte is not null) + return false; + + if (IsFinished(_stream, out _remainingByte)) + { + OnFinished(); + return true; + } + } + else + { + if (_remainingByte is not null) + { + output[0] = _remainingByte.GetValueOrDefault(); + written = 1; + _remainingByte = null; + } + + written += Inflate(_stream, output.Slice(written)); + + if (written < output.Length || IsFinished(_stream, out _remainingByte)) + { + OnFinished(); + return true; + } + } - case ZLibNative.ErrorCode.VersionError: //zlib library is incompatible with the version assumed - throw new WebSocketException(SR.ZLibErrorVersionMismatch); + return false; + } - case ZLibNative.ErrorCode.StreamError: // Parameters are invalid - throw new WebSocketException(SR.ZLibErrorIncorrectInitParameters); + private void OnFinished() + { + Debug.Assert(_stream is not null); - default: - throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)error)); + if (!_persisted) + { + _stream.Dispose(); + _stream = null; } } - internal unsafe void Inflate(ReadOnlySpan input, Span output, out int consumed, out int written) + private static bool IsFinished(ZLibStreamHandle stream, out byte? remainingByte) { - fixed (byte* fixedInput = input) - fixed (byte* fixedOutput = &MemoryMarshal.GetReference(output)) + if (stream.AvailIn > 0) { - _handle.NextIn = (IntPtr)fixedInput; - _handle.AvailIn = (uint)input.Length; - - _handle.NextOut = (IntPtr)fixedOutput; - _handle.AvailOut = (uint)output.Length; - - Inflate(ZLibNative.FlushCode.NoFlush); + remainingByte = null; + return false; + } - consumed = input.Length - (int)_handle.AvailIn; - written = output.Length - (int)_handle.AvailOut; + // There is no other way to make sure that we'e consumed all data + // but to try to inflate again with at least one byte of output buffer. + Span oneByte = stackalloc byte[1]; + if (Inflate(stream, oneByte) == 0) + { + remainingByte = null; + return true; } + + remainingByte = oneByte[0]; + return false; } - public unsafe int Inflate(Span destination) + private static unsafe int Inflate(ZLibStreamHandle stream, Span destination) { fixed (byte* bufPtr = &MemoryMarshal.GetReference(destination)) { - _handle.NextOut = (IntPtr)bufPtr; - _handle.AvailOut = (uint)destination.Length; + stream.NextOut = (IntPtr)bufPtr; + stream.AvailOut = (uint)destination.Length; - Inflate(ZLibNative.FlushCode.NoFlush); - return destination.Length - (int)_handle.AvailOut; + Inflate(stream); + return destination.Length - (int)stream.AvailOut; } } - public void Dispose() => _handle.Dispose(); - - /// - /// Wrapper around the ZLib inflate function - /// - private void Inflate(ZLibNative.FlushCode flushCode) + private static void Inflate(ZLibStreamHandle stream) { - ZLibNative.ErrorCode errorCode; + ErrorCode errorCode; try { - errorCode = _handle.Inflate(flushCode); + errorCode = stream.Inflate(FlushCode.NoFlush); } catch (Exception cause) // could not load the Zlib DLL correctly { @@ -92,23 +179,49 @@ private void Inflate(ZLibNative.FlushCode flushCode) } switch (errorCode) { - case ZLibNative.ErrorCode.Ok: // progress has been made inflating - case ZLibNative.ErrorCode.StreamEnd: // The end of the input stream has been reached - case ZLibNative.ErrorCode.BufError: // No room in the output buffer - inflate() can be called again with more space to continue + case ErrorCode.Ok: // progress has been made inflating + case ErrorCode.StreamEnd: // The end of the input stream has been reached + case ErrorCode.BufError: // No room in the output buffer - inflate() can be called again with more space to continue break; - case ZLibNative.ErrorCode.MemError: // Not enough memory to complete the operation + case ErrorCode.MemError: // Not enough memory to complete the operation throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); - case ZLibNative.ErrorCode.DataError: // The input data was corrupted (input stream not conforming to the zlib format or incorrect check value) + case ErrorCode.DataError: // The input data was corrupted (input stream not conforming to the zlib format or incorrect check value) throw new WebSocketException(SR.UnsupportedCompression); - case ZLibNative.ErrorCode.StreamError: //the stream structure was inconsistent (for example if next_in or next_out was NULL), + case ErrorCode.StreamError: //the stream structure was inconsistent (for example if next_in or next_out was NULL), throw new WebSocketException(SR.ZLibErrorInconsistentStream); default: throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); } } + + [MemberNotNull(nameof(_stream))] + private void Initialize() + { + Debug.Assert(_stream is null); + + ErrorCode error; + try + { + error = CreateZLibStreamForInflate(out _stream, _windowBits); + } + catch (Exception exception) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception); + } + + switch (error) + { + case ErrorCode.Ok: + return; + case ErrorCode.MemError: + throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)error)); + } + } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index be41fb8fdc8be..b5e6addedf4d1 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -33,14 +33,14 @@ private sealed class Receiver : IDisposable { private readonly bool _isServer; private readonly Stream _stream; - private readonly Decoder? _decoder; + private readonly WebSocketInflater? _inflater; /// /// Because a message might be split into multiple fragments we need to /// keep in mind that even if we've processed all of them, we need to call /// finish on the decoder to flush any left over data. /// - private bool _decodingFinished = true; + private bool _inflateFinished = true; /// /// If we have a decoder we cannot use the buffer provided from clients because @@ -104,24 +104,15 @@ public Receiver(Stream stream, WebSocketCreationOptions options) { // Important note here is that we must use negative window bits // which will instruct the underlying implementation to not expect deflate headers - if (options.IsServer) - { - _decoder = deflate.ServerContextTakeover ? - new PersistedInflater(-deflate.ServerMaxWindowBits) : - new Inflater(-deflate.ServerMaxWindowBits); - } - else - { - _decoder = deflate.ClientContextTakeover ? - new PersistedInflater(-deflate.ClientMaxWindowBits) : - new Inflater(-deflate.ClientMaxWindowBits); - } + _inflater = options.IsServer ? + new WebSocketInflater(deflate.ServerMaxWindowBits, deflate.ServerContextTakeover) : + new WebSocketInflater(deflate.ClientMaxWindowBits, deflate.ClientContextTakeover); } } public void Dispose() { - _decoder?.Dispose(); + _inflater?.Dispose(); if (_decoderInputBuffer is not null) { @@ -198,10 +189,10 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella // When there's nothing left over to receive, start a new if (_lastHeader.PayloadLength == 0) { - if (!_decodingFinished) + if (!_inflateFinished) { - Debug.Assert(_decoder is not null); - _decodingFinished = _decoder.Finish(buffer.Span, out var written); + Debug.Assert(_inflater is not null); + _inflateFinished = _inflater.Finish(buffer.Span, out var written); return Result(written); } @@ -233,9 +224,9 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella if (_lastHeader.Compressed) { - Debug.Assert(_decoder is not null); - _decoder.Decode(input: _readBuffer.AvailableSpan.Slice(0, available), - output: buffer.Span, out consumed, out written); + Debug.Assert(_inflater is not null); + _inflater.Inflate(input: _readBuffer.AvailableSpan.Slice(0, available), + output: buffer.Span, out consumed, out written); } else { @@ -252,9 +243,9 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella if (_lastHeader.PayloadLength == 0 || buffer.Length == written) { // We have either received everything or the buffer is full. - if (_decoder is not null && _lastHeader.PayloadLength == 0 && _lastHeader.Fin) + if (_inflater is not null && _lastHeader.PayloadLength == 0 && _lastHeader.Fin) { - _decodingFinished = _decoder.Finish(buffer.Span.Slice(written), out written); + _inflateFinished = _inflater.Finish(buffer.Span.Slice(written), out written); resultByteCount += written; } @@ -268,7 +259,7 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella // and should start issuing reads on the stream. Debug.Assert(_readBuffer.AvailableLength == 0 && _lastHeader.PayloadLength > 0); - if (_decoder is null) + if (_inflater is null) { if (buffer.Length > _lastHeader.PayloadLength) { @@ -305,8 +296,8 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella ApplyMask(_decoderInputBuffer.AsSpan(0, _decoderInputCount)); } - _decoder.Decode(input: _decoderInputBuffer.AsSpan(_decoderInputPosition, _decoderInputCount), - output: buffer.Span, out var consumed, out var written); + _inflater.Inflate(input: _decoderInputBuffer.AsSpan(_decoderInputPosition, _decoderInputCount), + output: buffer.Span, out var consumed, out var written); resultByteCount += written; _decoderInputPosition += consumed; @@ -320,7 +311,7 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella if (_lastHeader.PayloadLength == 0 && _lastHeader.Fin) { - _decodingFinished = _decoder.Finish(buffer.Span.Slice(written), out written); + _inflateFinished = _inflater.Finish(buffer.Span.Slice(written), out written); resultByteCount += written; } } @@ -339,7 +330,7 @@ private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationT { if (TryParseMessageHeader(_readBuffer.AvailableSpan, _lastHeader, _isServer, out var header, out var error, out var consumedBytes)) { - if (header.Compressed && _decoder is null) + if (header.Compressed && _inflater is null) { _headerError = SR.net_Websockets_PerMessageCompressedFlagWhenNotEnabled; return false; @@ -388,7 +379,7 @@ private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationT Count = count, ResultType = ReceiveResultType.Message, MessageType = _lastHeader.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, - EndOfMessage = _lastHeader.Fin && _lastHeader.PayloadLength == 0 && _decodingFinished + EndOfMessage = _lastHeader.Fin && _lastHeader.PayloadLength == 0 && _inflateFinished }; private ReceiveResult Result(ReceiveResultType resultType) => new ReceiveResult @@ -454,134 +445,6 @@ public void DiscardConsumed() _consumed = 0; } } - - private abstract class Decoder : IDisposable - { - public abstract void Dispose(); - - public abstract void Decode(ReadOnlySpan input, Span output, out int consumed, out int written); - - /// - /// Finishes the decoding by writing any outstanding data to the output. - /// - /// true if the finish completed, false to indicate that there is more outstanding data. - public abstract bool Finish(Span output, out int written); - } - - private class Inflater : Decoder - { - private readonly int _windowBits; - private byte? _remainingByte; - - // Although the inflater isn't persisted accross messages, a single message - // might have been split into multiple frames. - private WebSocketInflater? _inflater; - - public Inflater(int windowBits) => _windowBits = windowBits; - - public override void Dispose() => _inflater?.Dispose(); - - public override bool Finish(Span output, out int written) - { - Debug.Assert(_inflater is not null); - - if (Finish(_inflater, output, out written, ref _remainingByte)) - { - _inflater.Dispose(); - _inflater = null; - return true; - } - return false; - } - - public override void Decode(ReadOnlySpan input, Span output, out int consumed, out int written) - { - _inflater ??= new WebSocketInflater(_windowBits); - _inflater.Inflate(input, output, out consumed, out written); - } - - public static bool Finish(WebSocketInflater inflater, Span output, out int written, ref byte? remainingByte) - { - written = 0; - - if (output.Length == 0) - { - if (remainingByte is not null) - return false; - - if (IsFinished(inflater, out remainingByte)) - { - return true; - } - } - else - { - if (remainingByte is not null) - { - output[0] = remainingByte.GetValueOrDefault(); - written = 1; - remainingByte = null; - } - - written += inflater.Inflate(output.Slice(written)); - if (written < output.Length || IsFinished(inflater, out remainingByte)) - { - return true; - } - } - - return false; - } - - public static bool IsFinished(WebSocketInflater inflater, out byte? remainingByte) - { - // There is no other way to make sure that we'e consumed all data - // but to try to inflate again with at least one byte of output buffer. - Span oneByte = stackalloc byte[1]; - if (inflater.Inflate(oneByte) == 0) - { - remainingByte = null; - return true; - } - - remainingByte = oneByte[0]; - return false; - } - } - - private sealed class PersistedInflater : Decoder - { - private static ReadOnlySpan FlushMarker => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; - - private readonly WebSocketInflater _inflater; - private bool _needsFlushMarker; - private byte? _remainingByte; - - public PersistedInflater(int windowBits) => _inflater = new(windowBits); - - public override void Dispose() => _inflater.Dispose(); - - public override bool Finish(Span output, out int written) - { - if (_needsFlushMarker) - { - _needsFlushMarker = false; - _inflater.Inflate(FlushMarker, output, out var consumed, out written); - - Debug.Assert(consumed == FlushMarker.Length); - - return written < output.Length || Inflater.IsFinished(_inflater, out _remainingByte); - } - - return Inflater.Finish(_inflater, output, out written, ref _remainingByte); - } - - public override void Decode(ReadOnlySpan input, Span output, out int consumed, out int written) - { - _inflater.Inflate(input, output, out consumed, out written); - _needsFlushMarker = true; - } - } } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs index 65f12a1f82137..c28e289155431 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs @@ -3,6 +3,7 @@ using System.Buffers; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.IO; using System.Net.WebSockets.Compression; using System.Security.Cryptography; @@ -15,12 +16,12 @@ internal partial class ManagedWebSocket { private sealed class Sender : IDisposable { - private const byte PerMessageDeflateBit = 0b0100_0000; - private readonly int _maskLength; - private readonly Encoder? _encoder; + private readonly WebSocketDeflater? _deflater; private readonly Stream _stream; + private readonly Buffer _buffer = new(); + public Sender(Stream stream, WebSocketCreationOptions options) { _maskLength = options.IsServer ? 0 : MaskLength; @@ -30,55 +31,47 @@ public Sender(Stream stream, WebSocketCreationOptions options) if (deflate is not null) { - // Important note here is that we must use negative window bits - // which will instruct the underlying implementation to not emit gzip headers - if (options.IsServer) - { - // If we are the server we must use the client options - _encoder = deflate.ClientContextTakeover ? - new PersistedDeflater(-deflate.ClientMaxWindowBits) : - new Deflater(-deflate.ClientMaxWindowBits); - } - else - { - _encoder = deflate.ServerContextTakeover ? - new PersistedDeflater(-deflate.ServerMaxWindowBits) : - new Deflater(-deflate.ServerMaxWindowBits); - } + // If we are the server we must use the client options + _deflater = options.IsServer ? + new WebSocketDeflater(deflate.ClientMaxWindowBits, deflate.ClientContextTakeover) : + new WebSocketDeflater(deflate.ServerMaxWindowBits, deflate.ServerContextTakeover); } } - public void Dispose() => _encoder?.Dispose(); + public void Dispose() => _deflater?.Dispose(); public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory content, CancellationToken cancellationToken = default) { - var buffer = new Buffer(content.Length + MaxMessageHeaderLength); - byte reservedBits = 0; - - // Reserve space for the frame header - buffer.Advance(MaxMessageHeaderLength); + bool compressed = false; // Encoding is onlt supported for user messages - if (_encoder is not null && opcode <= MessageOpcode.Binary) + if (_deflater is not null && opcode <= MessageOpcode.Binary) { - _encoder.Encode(content.Span, ref buffer, continuation: opcode == MessageOpcode.Continuation, endOfMessage, out reservedBits); + _buffer.EnsureFreeCapacity(MaxMessageHeaderLength + (int)(content.Length * 0.6)); + _buffer.Advance(MaxMessageHeaderLength); + + _deflater.Deflate(content.Span, _buffer, continuation: opcode == MessageOpcode.Continuation, endOfMessage); + compressed = true; } else if (content.Length > 0) { - content.Span.CopyTo(buffer.GetSpan(content.Length)); - buffer.Advance(content.Length); + _buffer.EnsureFreeCapacity(MaxMessageHeaderLength + content.Length); + _buffer.Advance(MaxMessageHeaderLength); + + content.Span.CopyTo(_buffer.GetSpan(content.Length)); + _buffer.Advance(content.Length); } - var payload = buffer.WrittenSpan.Slice(MaxMessageHeaderLength); + var payload = _buffer.WrittenSpan.Slice(MaxMessageHeaderLength); var headerLength = CalculateHeaderLength(payload.Length); // Because we want the header to come just before to the payload // we will use a slice that offsets the unused part. var headerOffset = MaxMessageHeaderLength - headerLength; - var header = buffer.WrittenSpan.Slice(headerOffset, headerLength); + var header = _buffer.WrittenSpan.Slice(headerOffset, headerLength); // Write the message header data to the buffer. - EncodeHeader(header, opcode, endOfMessage, payload.Length, reservedBits); + EncodeHeader(header, opcode, endOfMessage, payload.Length, compressed); // If we added a mask to the header, XOR the payload with the mask. if (payload.Length > 0 && _maskLength > 0) @@ -86,26 +79,26 @@ public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemo ApplyMask(payload, BitConverter.ToInt32(header.Slice(header.Length - MaskLength)), 0); } - var releaseArray = true; + var resetBuffer = true; try { - var sendTask = _stream.WriteAsync(new ReadOnlyMemory(buffer.Array, headerOffset, headerLength + payload.Length), cancellationToken); + var sendTask = _stream.WriteAsync(_buffer.WrittenMemory.Slice(headerOffset), cancellationToken); if (sendTask.IsCompleted) return sendTask; - releaseArray = false; - return WaitAsync(sendTask.AsTask(), buffer.Array); + resetBuffer = false; + return WaitAsync(sendTask); } finally { - if (releaseArray) - ArrayPool.Shared.Return(buffer.Array); + if (resetBuffer) + _buffer.Reset(); } } - private static async ValueTask WaitAsync(Task sendTask, byte[] buffer) + private async ValueTask WaitAsync(ValueTask sendTask) { try { @@ -113,23 +106,19 @@ private static async ValueTask WaitAsync(Task sendTask, byte[] buffer) } finally { - ArrayPool.Shared.Return(buffer); + _buffer.Reset(); } } - private int CalculateHeaderLength(int payloadLength) => payloadLength switch + private int CalculateHeaderLength(int payloadLength) => _maskLength + (payloadLength switch { <= 125 => 2, <= ushort.MaxValue => 4, _ => 10 - } + _maskLength; + }); - private void EncodeHeader(Span header, MessageOpcode opcode, bool endOfMessage, int payloadLength, byte reservedBits) + private void EncodeHeader(Span header, MessageOpcode opcode, bool endOfMessage, int payloadLength, bool compressed) { - // The current implementation only supports per message deflate extension. In the future - // if more extensions are implemented or we allow third party extensions this assert must be changed. - Debug.Assert((reservedBits | 0b0100_0000) == 0b0100_0000); - // Client header format: // 1 bit - FIN - 1 if this is the final fragment in the message (it could be the only fragment), otherwise 0 // 1 bit - RSV1 - Per-Message Deflate Compress @@ -151,7 +140,11 @@ private void EncodeHeader(Span header, MessageOpcode opcode, bool endOfMes // 0 or 4 bytes - Mask, if Masked is 1 - random value XOR'd with each 4 bytes of the payload, round-robin // Length bytes - Payload data header[0] = (byte)opcode; // 4 bits for the opcode - header[0] |= reservedBits; + + if (compressed && opcode != MessageOpcode.Continuation) + { + header[0] |= 0b0100_0000; + } if (endOfMessage) { @@ -191,136 +184,73 @@ private void EncodeHeader(Span header, MessageOpcode opcode, bool endOfMes /// Helper class which allows writing to a rent'ed byte array /// and auto-grow functionality. /// - internal ref struct Buffer + private sealed class Buffer : IBufferWriter { - private byte[] _array; + private readonly ArrayPool _arrayPool; + + private byte[]? _array; private int _index; - public Buffer(int capacity) + public Buffer() { - _array = ArrayPool.Shared.Rent(capacity); - _index = 0; + _arrayPool = ArrayPool.Shared; } public Span WrittenSpan => new Span(_array, 0, _index); - public byte[] Array => _array; - - public int FreeCapacity => _array.Length - _index; + public ReadOnlyMemory WrittenMemory => new ReadOnlyMemory(_array, 0, _index); public void Advance(int count) { + Debug.Assert(_array is not null); + Debug.Assert(count >= 0); + Debug.Assert(_index + count <= _array.Length); + _index += count; + } - Debug.Assert(_index >= 0 || _index < _array.Length); + public Memory GetMemory(int sizeHint = 0) + { + EnsureFreeCapacity(sizeHint); + return _array.AsMemory(_index); } public Span GetSpan(int sizeHint = 0) { - if (sizeHint == 0) - sizeHint = 1; - - if (sizeHint > FreeCapacity) - { - var newArray = ArrayPool.Shared.Rent(_array.Length + sizeHint); - _array.AsSpan().CopyTo(newArray); - - ArrayPool.Shared.Return(_array); - _array = newArray; - } - + EnsureFreeCapacity(sizeHint); return _array.AsSpan(_index); } - } - - private abstract class Encoder : IDisposable - { - public abstract void Dispose(); - - internal abstract void Encode(ReadOnlySpan payload, ref Buffer buffer, bool continuation, bool endOfMessage, out byte reservedBits); - } - - /// - /// Deflate encoder which doesn't persist the deflator accross messages. - /// - private class Deflater : Encoder - { - private readonly int _windowBits; - - // Although the inflater isn't persisted accross messages, a single message - // might be split into multiple frames. - private WebSocketDeflater? _deflater; - - public Deflater(int windowBits) => _windowBits = windowBits; - - public override void Dispose() => _deflater?.Dispose(); - internal override void Encode(ReadOnlySpan payload, ref Buffer buffer, bool continuation, bool endOfMessage, out byte reservedBits) + public void Reset() { - Debug.Assert((continuation && _deflater is not null) || (!continuation && _deflater is null), - "Invalid state. The deflater was expected to be null if not continuation and not null otherwise."); - - _deflater ??= new WebSocketDeflater(_windowBits); - - Encode(payload, ref buffer, _deflater, endOfMessage); - reservedBits = continuation ? 0 : PerMessageDeflateBit; - - if (endOfMessage) + if (_array is not null) { - _deflater.Dispose(); - _deflater = null; + _arrayPool.Return(_array); + _array = null; + _index = 0; } } - public static void Encode(ReadOnlySpan payload, ref Buffer buffer, WebSocketDeflater deflater, bool final) + [MemberNotNull(nameof(_array))] + public void EnsureFreeCapacity(int sizeHint) { - while (payload.Length > 0) - { - deflater.Deflate(payload, buffer.GetSpan(payload.Length), out var consumed, out var written); - buffer.Advance(written); - - payload = payload.Slice(consumed); - } - - // See comment by Mark Adler https://github.com/madler/zlib/issues/149#issuecomment-225237457 - // At that point there will be at most a few bits left to write. - // Then call deflate() with Z_FULL_FLUSH and no more input and at least six bytes of available output. - var bytesWritten = deflater.Finish(buffer.GetSpan(6), out var completed); - buffer.Advance(bytesWritten); - - Debug.Assert(completed); - - // The deflated block always ends with 0x00 0x00 0xFF 0xFF - Debug.Assert( - buffer.WrittenSpan[^4] == 0x00 && - buffer.WrittenSpan[^3] == 0x00 && - buffer.WrittenSpan[^2] == 0xFF && - buffer.WrittenSpan[^1] == 0xFF); + if (sizeHint == 0) + sizeHint = 1; - if (final) + if (_array is null) { - // As per RFC we need to remove the flush markers - // 0x00 0x00 0xFF 0xFF - buffer.Advance(-4); + _array = _arrayPool.Rent(sizeHint); + return; } - } - } - /// - /// Deflate encoder which persists the deflator state accross messages. - /// - private sealed class PersistedDeflater : Encoder - { - private readonly WebSocketDeflater _deflater; - - public PersistedDeflater(int windowBits) => _deflater = new(windowBits); - - public override void Dispose() => _deflater.Dispose(); + if (sizeHint > (_array.Length - _index)) + { + var newArray = _arrayPool.Rent(_array.Length + sizeHint); + _array.AsSpan().CopyTo(newArray); - internal override void Encode(ReadOnlySpan payload, ref Buffer buffer, bool continuation, bool endOfMessage, out byte reservedBits) - { - Deflater.Encode(payload, ref buffer, _deflater, endOfMessage); - reservedBits = continuation ? 0 : PerMessageDeflateBit; + _arrayPool.Return(_array); + _array = newArray; + } } } } From 7c0c51491d889938e9a9d52ca4c0e9293f20e7ad Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 19 Feb 2021 14:25:10 +0200 Subject: [PATCH 29/52] Parsing websocket deflate headers allocation free. --- .../src/System.Net.WebSockets.Client.csproj | 1 + .../Net/WebSockets/WebSocketHandle.Managed.cs | 53 +++++++++++-------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj index b74f3d8962be6..4297d189ec2f9 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj +++ b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj @@ -37,6 +37,7 @@ +
diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index 8316a94a63c03..3aab249544b6c 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -179,11 +179,11 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli // Because deflate options are negotiated we need a new object WebSocketDeflateOptions? deflateOptions = null; - if (options.DeflateOptions is not null && response.Headers.TryGetValues(HttpKnownHeaderNames.SecWebSocketExtensions, out var extensions)) + if (options.DeflateOptions is not null && response.Headers.TryGetValues(HttpKnownHeaderNames.SecWebSocketExtensions, out IEnumerable? extensions)) { - foreach (var extension in extensions) + foreach (ReadOnlySpan extension in extensions) { - if (extension.StartsWith("permessage-deflate")) + if (extension.TrimStart().StartsWith("permessage-deflate")) { deflateOptions = ParseDeflateOptions(extension, options.DeflateOptions); break; @@ -236,30 +236,41 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli } } - private static WebSocketDeflateOptions ParseDeflateOptions(string extensions, WebSocketDeflateOptions original) + private static WebSocketDeflateOptions ParseDeflateOptions(ReadOnlySpan extension, WebSocketDeflateOptions original) { var options = new WebSocketDeflateOptions(); - foreach (var value in extensions.Split(';', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)) + while (true) { - if (value == "client_no_context_takeover") - { - options.ClientContextTakeover = false; - } - else if (value == "server_no_context_takeover") - { - options.ServerContextTakeover = false; - } - else if (value.StartsWith("client_max_window_bits=")) - { - options.ClientMaxWindowBits = int.Parse(value.Substring("client_max_window_bits=".Length), - NumberFormatInfo.InvariantInfo); - } - else if (value.StartsWith("server_max_window_bits=")) + int end = extension.IndexOf(';'); + ReadOnlySpan value = (end >= 0 ? extension[..end] : extension).Trim(); + + if (!value.IsEmpty) { - options.ServerMaxWindowBits = int.Parse(value.Substring("server_max_window_bits=".Length), - NumberFormatInfo.InvariantInfo); + if (value == "client_no_context_takeover") + { + options.ClientContextTakeover = false; + } + else if (value == "server_no_context_takeover") + { + options.ServerContextTakeover = false; + } + else if (value.StartsWith("client_max_window_bits=")) + { + options.ClientMaxWindowBits = int.Parse(value["client_max_window_bits=".Length..], + provider: CultureInfo.InvariantCulture); + } + else if (value.StartsWith("server_max_window_bits=")) + { + options.ServerMaxWindowBits = int.Parse(value["server_max_window_bits=".Length..], + provider: CultureInfo.InvariantCulture); + } } + + if (end < 0) + break; + + extension = extension[(end + 1)..]; } if (options.ClientMaxWindowBits > original.ClientMaxWindowBits) From 148c84263dbd33217ff8d04c9ca42afdebf403bc Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 19 Feb 2021 15:12:53 +0200 Subject: [PATCH 30/52] Addressing code style and pr feedback. --- .../Compression/WebSocketDeflater.cs | 21 +++++------ .../Compression/WebSocketInflater.cs | 8 ++--- .../WebSockets/ManagedWebSocket.Receiver.cs | 36 ++++++++----------- .../Net/WebSockets/ManagedWebSocket.Sender.cs | 20 +++++------ .../System/Net/WebSockets/ManagedWebSocket.cs | 6 ++-- 5 files changed, 40 insertions(+), 51 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index ca6cff6c5a89e..8c4923e0420e0 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -35,32 +35,27 @@ public void Deflate(ReadOnlySpan payload, IBufferWriter output, bool if (_stream is null) Initialize(); - while (payload.Length > 0) + while (!payload.IsEmpty) { - Deflate(payload, output.GetSpan(payload.Length), out var consumed, out var written); + Deflate(payload, output.GetSpan(payload.Length), out int consumed, out int written); output.Advance(written); - payload = payload.Slice(consumed); + payload = payload[consumed..]; } // See comment by Mark Adler https://github.com/madler/zlib/issues/149#issuecomment-225237457 // At that point there will be at most a few bits left to write. // Then call deflate() with Z_FULL_FLUSH and no more input and at least six bytes of available output. Span end = stackalloc byte[6]; - var count = Flush(end); + int count = Flush(end); end = end.Slice(0, count); - // The deflated block always ends with 0x00 0x00 0xFF 0xFF - Debug.Assert(count >= 4); - Debug.Assert(end[^4] == 0x00 && - end[^3] == 0x00 && - end[^2] == 0xFF && - end[^1] == 0xFF); + Debug.Assert(end.EndsWith(WebSocketInflater.FlushMarker), "The deflated block must always end with a flush marker."); if (endOfMessage) { // As per RFC we need to remove the flush markers - end = end.Slice(0, end.Length - 4); + end = end[..^4]; } end.CopyTo(output.GetSpan(end.Length)); @@ -112,8 +107,8 @@ private unsafe int Flush(Span output) _stream.NextOut = (IntPtr)fixedOutput; _stream.AvailOut = (uint)output.Length; - var errorCode = Deflate(_stream, (FlushCode)3/*Z_FULL_FLUSH*/); - var writtenBytes = output.Length - (int)_stream.AvailOut; + ErrorCode errorCode = Deflate(_stream, (FlushCode)3/*Z_FULL_FLUSH*/); + int writtenBytes = output.Length - (int)_stream.AvailOut; Debug.Assert(errorCode == ErrorCode.Ok); diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index 51f0d43475c7f..fd6286c531b16 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -13,14 +13,14 @@ namespace System.Net.WebSockets.Compression /// internal sealed class WebSocketInflater : IDisposable { - private static ReadOnlySpan FlushMarker => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; + internal static ReadOnlySpan FlushMarker => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; private ZLibStreamHandle? _stream; private readonly int _windowBits; private readonly bool _persisted; /// - /// There is now way of knowing when decoding data if the underlying deflater + /// There is no way of knowing, when decoding data, if the underlying deflater /// has flushed all outstanding data to consumer other than to provide a buffer /// and see whether any bytes are written. There are cases when the consumers /// provide a buffer exactly the size of the uncompressed data and in this case @@ -90,7 +90,7 @@ public bool Finish(Span output, out int written) written = 0; - if (output.Length == 0) + if (output.IsEmpty) { if (_remainingByte is not null) return false; @@ -110,7 +110,7 @@ public bool Finish(Span output, out int written) _remainingByte = null; } - written += Inflate(_stream, output.Slice(written)); + written += Inflate(_stream, output[written..]); if (written < output.Length || IsFinished(_stream, out _remainingByte)) { diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index b5e6addedf4d1..5cab018db1dc1 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -168,7 +168,7 @@ public async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken can while (_lastHeader.PayloadLength > _readBuffer.AvailableLength) { - var byteCount = await _stream.ReadAsync(_readBuffer.FreeMemory, cancellationToken).ConfigureAwait(false); + int byteCount = await _stream.ReadAsync(_readBuffer.FreeMemory, cancellationToken).ConfigureAwait(false); if (byteCount <= 0) return null; @@ -178,7 +178,9 @@ public async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken can // Update the payload length in the header to indicate // that we've received everything we need. - var payload = _readBuffer.Consume((int)_lastHeader.PayloadLength); + ReadOnlyMemory payload = _readBuffer.AvailableMemory.Slice(0, (int)_lastHeader.PayloadLength); + + _readBuffer.Consume(payload.Length); _lastHeader.PayloadLength = 0; return new ControlMessage(_lastHeader.Opcode, payload); @@ -192,13 +194,13 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella if (!_inflateFinished) { Debug.Assert(_inflater is not null); - _inflateFinished = _inflater.Finish(buffer.Span, out var written); + _inflateFinished = _inflater.Finish(buffer.Span, out int written); return Result(written); } _readBuffer.DiscardConsumed(); - var success = await ReceiveHeaderAsync(cancellationToken).ConfigureAwait(false); + bool success = await ReceiveHeaderAsync(cancellationToken).ConfigureAwait(false); if (!success) return Result(_headerError is not null ? ReceiveResultType.HeaderError : ReceiveResultType.ConnectionClose); @@ -215,7 +217,7 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella return default; // The number of bytes that are copied onto the provided buffer - var resultByteCount = 0; + int resultByteCount = 0; if (_readBuffer.AvailableLength > 0) { @@ -267,7 +269,7 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella buffer = buffer.Slice(0, (int)_lastHeader.PayloadLength); } - var bytesRead = await _stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + int bytesRead = await _stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); if (bytesRead <= 0) return Result(ReceiveResultType.ConnectionClose); @@ -279,7 +281,7 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella if (_decoderInputBuffer is null) { // Rent a buffer but restrict it's max size to 1MB - var decoderBufferLength = (int)Math.Min(_lastHeader.PayloadLength, 1_000_000); + int decoderBufferLength = (int)Math.Min(_lastHeader.PayloadLength, 1_000_000); _decoderInputBuffer = ArrayPool.Shared.Rent(decoderBufferLength); _decoderInputCount = await _stream.ReadAsync(_decoderInputBuffer.AsMemory(0, decoderBufferLength), cancellationToken).ConfigureAwait(false); @@ -297,7 +299,7 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella } _inflater.Inflate(input: _decoderInputBuffer.AsSpan(_decoderInputPosition, _decoderInputCount), - output: buffer.Span, out var consumed, out var written); + output: buffer.Span, out int consumed, out int written); resultByteCount += written; _decoderInputPosition += consumed; @@ -328,7 +330,8 @@ private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationT while (true) { - if (TryParseMessageHeader(_readBuffer.AvailableSpan, _lastHeader, _isServer, out var header, out var error, out var consumedBytes)) + if (TryParseMessageHeader(_readBuffer.AvailableSpan, _lastHeader, _isServer, + out MessageHeader header, out string? error, out int consumedBytes)) { if (header.Compressed && _inflater is null) { @@ -364,7 +367,7 @@ private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationT } // More data is neeed to parse the header - var byteCount = await _stream.ReadAsync(_readBuffer.FreeMemory, cancellationToken).ConfigureAwait(false); + int byteCount = await _stream.ReadAsync(_readBuffer.FreeMemory, cancellationToken).ConfigureAwait(false); if (byteCount <= 0) return false; @@ -421,18 +424,9 @@ public Buffer(int capacity) public int FreeLength => _bytes.Length - _position; - public void Commit(int count) - { - _position += count; - } - - public Memory Consume(int count) - { - var memory = new Memory(_bytes, _consumed, count); - _consumed += count; + public void Commit(int count) => _position += count; - return memory; - } + public void Consume(int count) => _consumed += count; public void DiscardConsumed() { diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs index c28e289155431..04097a9b52090 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs @@ -44,7 +44,7 @@ public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemo { bool compressed = false; - // Encoding is onlt supported for user messages + // Compression is only supported for user messages if (_deflater is not null && opcode <= MessageOpcode.Binary) { _buffer.EnsureFreeCapacity(MaxMessageHeaderLength + (int)(content.Length * 0.6)); @@ -53,7 +53,7 @@ public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemo _deflater.Deflate(content.Span, _buffer, continuation: opcode == MessageOpcode.Continuation, endOfMessage); compressed = true; } - else if (content.Length > 0) + else if (!content.IsEmpty) { _buffer.EnsureFreeCapacity(MaxMessageHeaderLength + content.Length); _buffer.Advance(MaxMessageHeaderLength); @@ -62,28 +62,28 @@ public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemo _buffer.Advance(content.Length); } - var payload = _buffer.WrittenSpan.Slice(MaxMessageHeaderLength); - var headerLength = CalculateHeaderLength(payload.Length); + Span payload = _buffer.WrittenSpan.Slice(MaxMessageHeaderLength); + int headerLength = CalculateHeaderLength(payload.Length); // Because we want the header to come just before to the payload // we will use a slice that offsets the unused part. - var headerOffset = MaxMessageHeaderLength - headerLength; - var header = _buffer.WrittenSpan.Slice(headerOffset, headerLength); + int headerOffset = MaxMessageHeaderLength - headerLength; + Span header = _buffer.WrittenSpan.Slice(headerOffset, headerLength); // Write the message header data to the buffer. EncodeHeader(header, opcode, endOfMessage, payload.Length, compressed); // If we added a mask to the header, XOR the payload with the mask. - if (payload.Length > 0 && _maskLength > 0) + if (!payload.IsEmpty && _maskLength > 0) { ApplyMask(payload, BitConverter.ToInt32(header.Slice(header.Length - MaskLength)), 0); } - var resetBuffer = true; + bool resetBuffer = true; try { - var sendTask = _stream.WriteAsync(_buffer.WrittenMemory.Slice(headerOffset), cancellationToken); + ValueTask sendTask = _stream.WriteAsync(_buffer.WrittenMemory.Slice(headerOffset), cancellationToken); if (sendTask.IsCompleted) return sendTask; @@ -245,7 +245,7 @@ public void EnsureFreeCapacity(int sizeHint) if (sizeHint > (_array.Length - _index)) { - var newArray = _arrayPool.Rent(_array.Length + sizeHint); + byte[] newArray = _arrayPool.Rent(_array.Length + sizeHint); _array.AsSpan().CopyTo(newArray); _arrayPool.Return(_array); diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 7ac2cc8f62feb..0587eaf1be67a 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -541,7 +541,7 @@ private async ValueTask ReceiveAsyncPrivate ReceiveAsyncPrivate ReceiveAsyncPrivate Date: Fri, 19 Feb 2021 15:28:54 +0200 Subject: [PATCH 31/52] A few style improvements based on feedback. --- .../Net/WebSockets/Compression/WebSocketDeflater.cs | 2 +- .../Net/WebSockets/Compression/WebSocketInflater.cs | 8 ++++---- .../System/Net/WebSockets/ManagedWebSocket.Receiver.cs | 8 ++++---- .../src/System/Net/WebSockets/ManagedWebSocket.Sender.cs | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index 8c4923e0420e0..3e8603eefc259 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -49,7 +49,7 @@ public void Deflate(ReadOnlySpan payload, IBufferWriter output, bool Span end = stackalloc byte[6]; int count = Flush(end); - end = end.Slice(0, count); + end = end[..count]; Debug.Assert(end.EndsWith(WebSocketInflater.FlushMarker), "The deflated block must always end with a flush marker."); if (endOfMessage) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index fd6286c531b16..d05aaf48f457b 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -133,7 +133,7 @@ private void OnFinished() } } - private static bool IsFinished(ZLibStreamHandle stream, out byte? remainingByte) + private static unsafe bool IsFinished(ZLibStreamHandle stream, out byte? remainingByte) { if (stream.AvailIn > 0) { @@ -143,14 +143,14 @@ private static bool IsFinished(ZLibStreamHandle stream, out byte? remainingByte) // There is no other way to make sure that we'e consumed all data // but to try to inflate again with at least one byte of output buffer. - Span oneByte = stackalloc byte[1]; - if (Inflate(stream, oneByte) == 0) + byte b; + if (Inflate(stream, new Span(&b, 1)) == 0) { remainingByte = null; return true; } - remainingByte = oneByte[0]; + remainingByte = b; return false; } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index 5cab018db1dc1..59b523c6d617f 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -96,7 +96,7 @@ public Receiver(Stream stream, WebSocketCreationOptions options) // Create a buffer just large enough to handle received packet headers (at most 14 bytes) and // control payloads (at most 125 bytes). Message payloads are read directly into the buffer // supplied to ReceiveAsync. - _readBuffer = new(MaxControlPayloadLength + MaxMessageHeaderLength); + _readBuffer = new Buffer(MaxControlPayloadLength + MaxMessageHeaderLength); var deflate = options.DeflateOptions; @@ -254,7 +254,7 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella return Result(resultByteCount); } - buffer = buffer.Slice(written); + buffer = buffer[written..]; } // At this point we should have consumed everything from the buffer @@ -265,7 +265,7 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella { if (buffer.Length > _lastHeader.PayloadLength) { - // We don't want to receive more that we need + // We don't want to receive more than we need buffer = buffer.Slice(0, (int)_lastHeader.PayloadLength); } @@ -313,7 +313,7 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella if (_lastHeader.PayloadLength == 0 && _lastHeader.Fin) { - _inflateFinished = _inflater.Finish(buffer.Span.Slice(written), out written); + _inflateFinished = _inflater.Finish(buffer.Span[written..], out written); resultByteCount += written; } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs index 04097a9b52090..4ab9a9e65a206 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs @@ -196,7 +196,7 @@ public Buffer() _arrayPool = ArrayPool.Shared; } - public Span WrittenSpan => new Span(_array, 0, _index); + public Span WrittenSpan => _array.AsSpan(0, _index); public ReadOnlyMemory WrittenMemory => new ReadOnlyMemory(_array, 0, _index); From 65119823e53e3b2f944085bef93089422c8eb34f Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 19 Feb 2021 16:27:45 +0200 Subject: [PATCH 32/52] Stream is never null. --- .../src/System/Net/WebSockets/ManagedWebSocket.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 0587eaf1be67a..afd82475f582d 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -190,7 +190,7 @@ private void DisposeCore() { _disposed = true; _keepAliveTimer?.Dispose(); - _stream?.Dispose(); + _stream.Dispose(); _sender.Dispose(); _receiver.Dispose(); From cb3881727d49052a2a08f5e5236e38a11f710624 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 19 Feb 2021 16:28:39 +0200 Subject: [PATCH 33/52] Fixed a bug in the deflater when incompressible message is being sent. --- .../Compression/WebSocketDeflater.cs | 25 ++++++++++++++----- .../tests/WebSocketDeflateTests.cs | 8 +++--- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index 3e8603eefc259..e234803dd4779 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -43,23 +43,36 @@ public void Deflate(ReadOnlySpan payload, IBufferWriter output, bool payload = payload[consumed..]; } + // There is a catch here. If the payload we're trying to compress isn't really compressable + // then the resulting output will be larger. And in this case although we might have processed the input + // more output might be available. The only way to check for this scenario is to do another deflate + // attempt, but without any input this time. + while (true) + { + Deflate(null, output.GetSpan(), out int consumed, out int written); + Debug.Assert(consumed == 0); + + if (written == 0) + break; + + output.Advance(written); + } + // See comment by Mark Adler https://github.com/madler/zlib/issues/149#issuecomment-225237457 // At that point there will be at most a few bits left to write. // Then call deflate() with Z_FULL_FLUSH and no more input and at least six bytes of available output. - Span end = stackalloc byte[6]; + Span end = output.GetSpan(6); int count = Flush(end); - end = end[..count]; - Debug.Assert(end.EndsWith(WebSocketInflater.FlushMarker), "The deflated block must always end with a flush marker."); + Debug.Assert(end[..count].EndsWith(WebSocketInflater.FlushMarker), "The deflated block must always end with a flush marker."); if (endOfMessage) { // As per RFC we need to remove the flush markers - end = end[..^4]; + count -= 4; } - end.CopyTo(output.GetSpan(end.Length)); - output.Advance(end.Length); + output.Advance(count); if (endOfMessage && !_persisted) { diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index f4da4cb37cd15..9ab7875d4935b 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -1,6 +1,5 @@ using System.Collections.Generic; using System.Diagnostics; -using System.IO; using System.Security.Cryptography; using System.Text; using System.Threading; @@ -193,9 +192,12 @@ public async Task LargeMessageSplitInMultipleFrames(int windowBits) } }); - Memory testData = File.ReadAllBytes(typeof(WebSocketDeflateTests).Assembly.Location).AsMemory().TrimEnd((byte)0); + Memory testData = new byte[ushort.MaxValue]; Memory receivedData = new byte[testData.Length]; + // Make the data incompressible to make sure that the output is larger than the input + RandomNumberGenerator.Fill(testData.Span); + // Test it a few times with different frame sizes for (var i = 0; i < 10; ++i) { @@ -212,8 +214,8 @@ public async Task LargeMessageSplitInMultipleFrames(int windowBits) position += currentFrameSize; } + Assert.True(testData.Length < stream.Remote.Available, "The compressed data should be bigger."); Assert.Equal(testData.Length, position); - Assert.True(testData.Length > stream.Remote.Available, "The data must be compressed."); // Receive the data from the client side receivedData.Span.Clear(); From a7c049a249784569f62ce653f36be50eba4bad1a Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 19 Feb 2021 17:48:29 +0200 Subject: [PATCH 34/52] Added a test that sends / receives uncompressed messages with different sizes. Fixed a bug when receiving 0 byte message. Fixed another bug where the payload length wasn't reduced after receiving non compressed message. --- .../WebSockets/ManagedWebSocket.Receiver.cs | 3 +- .../Net/WebSockets/ManagedWebSocket.Sender.cs | 5 ++++ .../tests/WebSocketTests.cs | 29 +++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index 59b523c6d617f..ef3c63aac7045 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -214,7 +214,7 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella } if (buffer.IsEmpty) - return default; + return Result(count: 0); // The number of bytes that are copied onto the provided buffer int resultByteCount = 0; @@ -274,6 +274,7 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella return Result(ReceiveResultType.ConnectionClose); resultByteCount += bytesRead; + _lastHeader.PayloadLength -= bytesRead; ApplyMask(buffer.Span.Slice(0, bytesRead)); } else diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs index 4ab9a9e65a206..d3246189da8dd 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs @@ -61,6 +61,11 @@ public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemo content.Span.CopyTo(_buffer.GetSpan(content.Length)); _buffer.Advance(content.Length); } + else + { + _buffer.EnsureFreeCapacity(MaxMessageHeaderLength); + _buffer.Advance(MaxMessageHeaderLength); + } Span payload = _buffer.WrittenSpan.Slice(MaxMessageHeaderLength); int headerLength = CalculateHeaderLength(payload.Length); diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs index ad738a00ec864..18e5b05433d8d 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.IO; +using System.Security.Cryptography; +using System.Threading.Tasks; using Xunit; namespace System.Net.WebSockets.Tests @@ -171,6 +173,33 @@ public void ValueWebSocketReceiveResult_Ctor_ValidArguments_Roundtrip(int count, Assert.Equal(endOfMessage, r.EndOfMessage); } + [Theory] + [InlineData(0)] + [InlineData(125)] + [InlineData(ushort.MaxValue)] + [InlineData(ushort.MaxValue * 2)] + public async Task SendUncompressedClientMessage(int messageSize) + { + var stream = new WebSocketStream(); + using var server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true + }); + using var client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions()); + + var message = new byte[messageSize]; + RandomNumberGenerator.Fill(message); + + await client.SendAsync(message, WebSocketMessageType.Binary, true, default); + + var buffer = new byte[messageSize]; + var result = await server.ReceiveAsync(buffer, default); + + Assert.Equal(messageSize, result.Count); + Assert.True(result.EndOfMessage); + Assert.True(message.AsSpan().SequenceEqual(buffer)); + } + public abstract class ExposeProtectedWebSocket : WebSocket { public static new bool IsStateTerminal(WebSocketState state) => From 2f11043116dad366a949eea0741f58a826c0647c Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Fri, 19 Feb 2021 19:59:10 +0200 Subject: [PATCH 35/52] Added more tests that test control messages. Fixed a bug in websocket during client initiated close, the server would not respond with close message and the client's close task would hang until timeout or server closes the connection. --- .../System/Net/WebSockets/ManagedWebSocket.cs | 19 ++++- .../tests/WebSocketCreateTest.cs | 4 +- .../tests/WebSocketDeflateTests.cs | 20 ++--- .../tests/WebSocketStream.cs | 78 +++++++++++-------- .../tests/WebSocketTests.cs | 67 ++++++++++++++++ 5 files changed, 140 insertions(+), 48 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index afd82475f582d..f75827c5426e6 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -361,7 +361,7 @@ public override ValueTask ReceiveAsync(Memory private Task ValidateAndReceiveAsync(Task receiveTask, CancellationToken cancellationToken) { - if ( receiveTask.IsCompletedSuccessfully && + if (receiveTask.IsCompletedSuccessfully && !(receiveTask is Task wsrr && wsrr.Result.MessageType == WebSocketMessageType.Close) && !(receiveTask is Task vwsrr && vwsrr.Result.MessageType == WebSocketMessageType.Close)) { @@ -670,7 +670,22 @@ private async ValueTask HandleReceivedCloseAsync(ReadOnlyMemory payload, C _closeStatus = closeStatus; _closeStatusDescription = closeStatusDescription; - if (!_isServer && _sentCloseFrame) + bool closeOutput = false; + + lock (StateUpdateLock) + { + if (!_sentCloseFrame) + { + _sentCloseFrame = true; + closeOutput = true; + } + } + + if (closeOutput) + { + await CloseOutputAsyncCore(closeStatus, closeStatusDescription, cancellationToken).ConfigureAwait(false); + } + else if (!_isServer) { await _receiver.WaitForServerToCloseConnectionAsync(cancellationToken).ConfigureAwait(false); } diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs index 66f0707c5c27a..0f28f1b4171b0 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs @@ -117,8 +117,8 @@ public async Task ReceiveAsync_InvalidFrameHeader_AbortsAndThrowsException(byte { var stream = new WebSocketStream(); - stream.Write(firstByte, secondByte, (byte)'a'); - using var websocket = CreateFromStream(stream.Remote, isServer: false, null, Timeout.InfiniteTimeSpan); + stream.Enqueue(firstByte, secondByte, (byte)'a'); + using var websocket = CreateFromStream(stream, isServer: false, null, Timeout.InfiniteTimeSpan); var buffer = new byte[1]; Task t = websocket.ReceiveAsync(buffer, CancellationToken.None); diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index 9ab7875d4935b..12e9d20139903 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -39,8 +39,8 @@ public async Task HelloWithContextTakeover() { var stream = new WebSocketStream(); - stream.Write(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); - using var websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + using var websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { DeflateOptions = new() }); @@ -55,7 +55,7 @@ public async Task HelloWithContextTakeover() // Because context takeover is set by default if we try to send // the same message it would take fewer bytes. - stream.Write(0xc1, 0x05, 0xf2, 0x00, 0x11, 0x00, 0x00); + stream.Enqueue(0xc1, 0x05, 0xf2, 0x00, 0x11, 0x00, 0x00); buffer.AsSpan().Clear(); result = await websocket.ReceiveAsync(buffer, CancellationToken); @@ -70,7 +70,7 @@ public async Task HelloWithoutContextTakeover() { var stream = new WebSocketStream(); - using var websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + using var websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { DeflateOptions = new() { @@ -83,7 +83,7 @@ public async Task HelloWithoutContextTakeover() for (var i = 0; i < 100; ++i) { // Without context takeover the message should look the same every time - stream.Write(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); buffer.AsSpan().Clear(); var result = await websocket.ReceiveAsync(buffer, CancellationToken); @@ -100,7 +100,7 @@ public async Task TwoDeflateBlocksInOneMessage() { // Two or more DEFLATE blocks may be used in one message. var stream = new WebSocketStream(); - using var websocket = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + using var websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { DeflateOptions = new() }); @@ -112,8 +112,8 @@ public async Task TwoDeflateBlocksInOneMessage() // following 4 octets(0x00 0x00 0xff 0xff), the header bits constitute // an empty DEFLATE block with no compression. A DEFLATE block // containing "llo" follows the empty DEFLATE block. - stream.Write(0x41, 0x08, 0xf2, 0x48, 0x05, 0x00, 0x00, 0x00, 0xff, 0xff); - stream.Write(0x80, 0x05, 0xca, 0xc9, 0xc9, 0x07, 0x00); + stream.Enqueue(0x41, 0x08, 0xf2, 0x48, 0x05, 0x00, 0x00, 0x00, 0xff, 0xff); + stream.Enqueue(0x80, 0x05, 0xca, 0xc9, 0xc9, 0x07, 0x00); Memory buffer = new byte[5]; var result = await websocket.ReceiveAsync(buffer, CancellationToken); @@ -247,8 +247,8 @@ public async Task WebSocketWithoutDeflateShouldThrowOnCompressedMessage() { var stream = new WebSocketStream(); - stream.Write(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); - using var websocket = WebSocket.CreateFromStream(stream.Remote, new()); + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + using var websocket = WebSocket.CreateFromStream(stream, new()); var exception = await Assert.ThrowsAsync(() => websocket.ReceiveAsync(Memory.Empty, CancellationToken).AsTask()); diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs index b6ca90ed0860c..c20107a88f4ae 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs @@ -13,6 +13,7 @@ public class WebSocketStream : Stream { private readonly SemaphoreSlim _inputLock = new(initialCount: 0); private readonly Queue _inputQueue = new(); + private readonly CancellationTokenSource _disposed = new(); public WebSocketStream() { @@ -39,7 +40,7 @@ public int Available lock (_inputQueue) { - foreach ( var x in _inputQueue) + foreach (var x in _inputQueue) { available += x.AvailableLength; } @@ -49,6 +50,22 @@ public int Available } } + public Span NextAvailableBytes + { + get + { + lock (_inputQueue) + { + var block = _inputQueue.Peek(); + + if (block is null) + return default; + + return block.Available; + } + } + } + public override bool CanRead => true; public override bool CanSeek => false; @@ -61,30 +78,30 @@ public int Available protected override void Dispose(bool disposing) { - _inputLock.Dispose(); - - lock (Remote._inputQueue) + if (!_disposed.IsCancellationRequested) { - try + _disposed.Cancel(); + + lock (Remote._inputQueue) { Remote._inputLock.Release(); Remote._inputQueue.Enqueue(Block.ConnectionClosed); } - catch (ObjectDisposedException) - { - } } } public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { - try + using (var cancellation = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _disposed.Token)) { - await _inputLock.WaitAsync(cancellationToken); - } - catch (ObjectDisposedException) - { - return 0; + try + { + await _inputLock.WaitAsync(cancellation.Token).ConfigureAwait(false); + } + catch (OperationCanceledException) when (_disposed.IsCancellationRequested) + { + return 0; + } } lock (_inputQueue) @@ -104,26 +121,25 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation } else { - try - { - // Because we haven't fully consumed the buffer - // we should release once the input lock so we can acquire - // it again on consequent receive. - _inputLock.Release(); - } - catch (ObjectDisposedException) { } + // Because we haven't fully consumed the buffer + // we should release once the input lock so we can acquire + // it again on consequent receive. + _inputLock.Release(); } return count; } } - public void Write(params byte[] data) + /// + /// Receives the data and enqueues it for processing. + /// + public void Enqueue(params byte[] data) { - lock (Remote._inputQueue) + lock (_inputQueue) { - Remote._inputLock.Release(); - Remote._inputQueue.Enqueue(new Block(data)); + _inputLock.Release(); + _inputQueue.Enqueue(new Block(data)); } } @@ -131,14 +147,8 @@ public override void Write(ReadOnlySpan buffer) { lock (Remote._inputQueue) { - try - { - Remote._inputLock.Release(); - Remote._inputQueue.Enqueue(new Block(buffer.ToArray())); - } - catch (ObjectDisposedException) - { - } + Remote._inputLock.Release(); + Remote._inputQueue.Enqueue(new Block(buffer.ToArray())); } } diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs index 18e5b05433d8d..c864822cfd7b6 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs @@ -3,6 +3,7 @@ using System.IO; using System.Security.Cryptography; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -200,6 +201,72 @@ public async Task SendUncompressedClientMessage(int messageSize) Assert.True(message.AsSpan().SequenceEqual(buffer)); } + [Fact] + public async Task WhenPingReceivedPongMessageMustBeSent() + { + var stream = new WebSocketStream(); + using var server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + // Use server so we don't receive any masked payload + IsServer = true + }); + using var cancellation = new CancellationTokenSource(); + + stream.Enqueue(0b1000_1001, 0x00); + var receiveTask = server.ReceiveAsync(Memory.Empty, cancellation.Token).AsTask(); + + Assert.Equal(0, stream.Available); + Assert.Equal(2, stream.Remote.Available); + Assert.Equal(new byte[] { 0b1000_1010, 0x00 }, stream.Remote.NextAvailableBytes.ToArray()); + + cancellation.Cancel(); + await Assert.ThrowsAsync(async () => await receiveTask.ConfigureAwait(false)); + } + + [Fact] + public async Task WhenPongReceivedNothingShouldBeSentBack() + { + var stream = new WebSocketStream(); + using var client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions()); + + using var cancellation = new CancellationTokenSource(); + + stream.Enqueue(0b1000_1010, 0x00); + var receiveTask = client.ReceiveAsync(Memory.Empty, cancellation.Token).AsTask(); + + Assert.Equal(0, stream.Available); + Assert.Equal(0, stream.Remote.Available); + + cancellation.Cancel(); + await Assert.ThrowsAsync(async () => await receiveTask.ConfigureAwait(false)); + } + + [Fact] + public async Task ClosingWebSocketsGracefully() + { + var stream = new WebSocketStream(); + using var cancellation = new CancellationTokenSource(TimeSpan.FromSeconds(155)); + using var client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions()); + using var server = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + IsServer = true + }); + + var clientClose = client.CloseAsync(WebSocketCloseStatus.PolicyViolation, "Yeet", cancellation.Token); + var result = await server.ReceiveAsync(Memory.Empty, cancellation.Token); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(0, result.Count); + Assert.Equal("Yeet", server.CloseStatusDescription); + Assert.Equal(WebSocketCloseStatus.PolicyViolation, server.CloseStatus); + + await clientClose; + + Assert.Equal(WebSocketState.Closed, server.State); + Assert.Equal(WebSocketState.Closed, client.State); + } + public abstract class ExposeProtectedWebSocket : WebSocket { public static new bool IsStateTerminal(WebSocketState state) => From 7303fe240377fb44ca0f4a18134eb98a4e597ebf Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sat, 20 Feb 2021 17:02:25 +0200 Subject: [PATCH 36/52] Fixed definitions in ref assemblies. --- .../ref/System.Net.WebSockets.Client.cs | 2 +- .../ref/System.Net.WebSockets.cs | 22 +++++++++---------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs index beaa33f9226dc..660a2c5fbe748 100644 --- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs +++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs @@ -36,7 +36,7 @@ internal ClientWebSocketOptions() { } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] - public WebSocketDeflateOptions? DeflateOptions { get { throw null; } set { } } + public System.Net.WebSockets.WebSocketDeflateOptions? DeflateOptions { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.IWebProxy? Proxy { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] diff --git a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs index e1fff547f2124..12431032cc346 100644 --- a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs +++ b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs @@ -135,20 +135,18 @@ public enum WebSocketState Closed = 5, Aborted = 6, } - - public sealed class WebSocketCreationOptions + public sealed partial class WebSocketCreationOptions { - public bool IsServer { get; set; } - public string? SubProtocol { get; set; } - public TimeSpan KeepAliveInterval { get; set; } - public WebSocketDeflateOptions? DeflateOptions { get; set; } + public bool IsServer { get { throw null; } set { } } + public string? SubProtocol { get { throw null; } set { } } + public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } + public System.Net.WebSockets.WebSocketDeflateOptions? DeflateOptions { get { throw null; } set { } } } - - public sealed class WebSocketDeflateOptions + public sealed partial class WebSocketDeflateOptions { - public int ClientMaxWindowBits { get; set; } - public bool ClientContextTakeover { get; set; } - public int ServerMaxWindowBits { get; set; } - public bool ServerContextTakeover { get; set; } + public int ClientMaxWindowBits { get { throw null; } set { } } + public bool ClientContextTakeover { get { throw null; } set { } } + public int ServerMaxWindowBits { get { throw null; } set { } } + public bool ServerContextTakeover { get { throw null; } set { } } } } From 89b7e09414c02310a06c81de031759b0219ac49b Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sat, 20 Feb 2021 17:15:24 +0200 Subject: [PATCH 37/52] Added links to RFC of each of the websocket deflate properties. Also added explanation why we don't support 8 bits althouh the RFC allows it. --- .../src/System/Net/WebSockets/WebSocketDeflateOptions.cs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs index a4044494f7cdc..49a5d03a69b2b 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs @@ -6,6 +6,10 @@ namespace System.Net.WebSockets /// /// Options to enable per-message deflate compression for . /// + /// + /// Although the WebSocket spec allows window bits from 8 to 15, the current implementation doesn't support 8 bits. + /// For more information refer to the zlib manual https://zlib.net/manual.html. + /// public sealed class WebSocketDeflateOptions { private int _clientMaxWindowBits = 15; @@ -15,6 +19,7 @@ public sealed class WebSocketDeflateOptions /// This parameter indicates the base-2 logarithm of the LZ77 sliding window size of the client context. /// Must be a value between 9 and 15. The default is 15. /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.2.2 public int ClientMaxWindowBits { get => _clientMaxWindowBits; @@ -32,12 +37,14 @@ public int ClientMaxWindowBits /// When true the client-side of the connection indicates that it will persist the deflate context accross messages. /// The default is true. /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.1.2 public bool ClientContextTakeover { get; set; } = true; /// /// This parameter indicates the base-2 logarithm of the LZ77 sliding window size of the server context. /// Must be a value between 9 and 15. The default is 15. /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.2.1 public int ServerMaxWindowBits { get => _serverMaxWindowBits; @@ -55,6 +62,7 @@ public int ServerMaxWindowBits /// When true the server-side of the connection indicates that it will persist the deflate context accross messages. /// The default is true. /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.1.1 public bool ServerContextTakeover { get; set; } = true; } } From 108da5883d05eba95024dd71e777e03ac464fe1d Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sat, 20 Feb 2021 17:39:32 +0200 Subject: [PATCH 38/52] Addressing PR feedback - always using { } for single line blocks. --- .../Compression/WebSocketDeflater.cs | 5 +++- .../Compression/WebSocketInflater.cs | 6 +++-- .../WebSockets/ManagedWebSocket.Receiver.cs | 18 +++++++++----- .../Net/WebSockets/ManagedWebSocket.Sender.cs | 8 +++++-- .../System/Net/WebSockets/ManagedWebSocket.cs | 24 ++++++++++++------- .../WebSockets/WebSocketCreationOptions.cs | 6 +++-- .../Net/WebSockets/WebSocketDeflateOptions.cs | 6 +++-- .../tests/WebSocketStream.cs | 6 +++-- 8 files changed, 54 insertions(+), 25 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index e234803dd4779..d10782b8b4252 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -33,8 +33,9 @@ public void Deflate(ReadOnlySpan payload, IBufferWriter output, bool Debug.Assert(!continuation || _stream is not null, "Invalid state. The stream should not be null in continuations."); if (_stream is null) + { Initialize(); - + } while (!payload.IsEmpty) { Deflate(payload, output.GetSpan(payload.Length), out int consumed, out int written); @@ -53,7 +54,9 @@ public void Deflate(ReadOnlySpan payload, IBufferWriter output, bool Debug.Assert(consumed == 0); if (written == 0) + { break; + } output.Advance(written); } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index d05aaf48f457b..3c05af4ac2734 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -48,8 +48,9 @@ internal WebSocketInflater(int windowBits, bool persisted) public unsafe void Inflate(ReadOnlySpan input, Span output, out int consumed, out int written) { if (_stream is null) + { Initialize(); - + } fixed (byte* fixedInput = &MemoryMarshal.GetReference(input)) fixed (byte* fixedOutput = &MemoryMarshal.GetReference(output)) { @@ -93,8 +94,9 @@ public bool Finish(Span output, out int written) if (output.IsEmpty) { if (_remainingByte is not null) + { return false; - + } if (IsFinished(_stream, out _remainingByte)) { OnFinished(); diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index ef3c63aac7045..73cc49f73a947 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -162,16 +162,18 @@ public async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken can Debug.Assert(_lastHeader.Opcode > MessageOpcode.Binary); if (_lastHeader.PayloadLength == 0) + { return new ControlMessage(_lastHeader.Opcode, ReadOnlyMemory.Empty); - + } _readBuffer.DiscardConsumed(); while (_lastHeader.PayloadLength > _readBuffer.AvailableLength) { int byteCount = await _stream.ReadAsync(_readBuffer.FreeMemory, cancellationToken).ConfigureAwait(false); if (byteCount <= 0) + { return null; - + } ApplyMask(_readBuffer.FreeMemory.Span.Slice(0, (int)Math.Min(_lastHeader.PayloadLength, byteCount))); _readBuffer.Commit(byteCount); } @@ -203,8 +205,9 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella bool success = await ReceiveHeaderAsync(cancellationToken).ConfigureAwait(false); if (!success) + { return Result(_headerError is not null ? ReceiveResultType.HeaderError : ReceiveResultType.ConnectionClose); - + } if (_lastHeader.Opcode > MessageOpcode.Binary) { // The received message is a control message and it's up @@ -214,8 +217,9 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella } if (buffer.IsEmpty) + { return Result(count: 0); - + } // The number of bytes that are copied onto the provided buffer int resultByteCount = 0; @@ -271,8 +275,9 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella int bytesRead = await _stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); if (bytesRead <= 0) + { return Result(ReceiveResultType.ConnectionClose); - + } resultByteCount += bytesRead; _lastHeader.PayloadLength -= bytesRead; ApplyMask(buffer.Span.Slice(0, bytesRead)); @@ -370,8 +375,9 @@ private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationT // More data is neeed to parse the header int byteCount = await _stream.ReadAsync(_readBuffer.FreeMemory, cancellationToken).ConfigureAwait(false); if (byteCount <= 0) + { return false; - + } _readBuffer.Commit(byteCount); } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs index d3246189da8dd..fab2a0c544d76 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs @@ -91,15 +91,18 @@ public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemo ValueTask sendTask = _stream.WriteAsync(_buffer.WrittenMemory.Slice(headerOffset), cancellationToken); if (sendTask.IsCompleted) + { return sendTask; - + } resetBuffer = false; return WaitAsync(sendTask); } finally { if (resetBuffer) + { _buffer.Reset(); + } } } @@ -240,8 +243,9 @@ public void Reset() public void EnsureFreeCapacity(int sizeHint) { if (sizeHint == 0) + { sizeHint = 1; - + } if (_array is null) { _array = _arrayPool.Rent(sizeHint); diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index f75827c5426e6..0c25991bab5fa 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -549,8 +549,9 @@ private async ValueTask ReceiveAsyncPrivate ReceiveAsyncPrivate 15) + { throw new ArgumentOutOfRangeException(nameof(ClientMaxWindowBits), value, SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 9, 15)); - + } _clientMaxWindowBits = value; } } @@ -51,9 +52,10 @@ public int ServerMaxWindowBits set { if (value < 9 || value > 15) + { throw new ArgumentOutOfRangeException(nameof(ServerMaxWindowBits), value, SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 9, 15)); - + } _serverMaxWindowBits = value; } } diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs index c20107a88f4ae..bd46e40c73db0 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs @@ -59,8 +59,9 @@ public Span NextAvailableBytes var block = _inputQueue.Peek(); if (block is null) + { return default; - + } return block.Available; } } @@ -108,8 +109,9 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation { var block = _inputQueue.Peek(); if (block == Block.ConnectionClosed) + { return 0; - + } var count = Math.Min(block.AvailableLength, buffer.Length); block.Available.Slice(0, count).CopyTo(buffer.Span); From e0f98c383c365578845267d53e61725a468c2fb9 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sat, 20 Feb 2021 18:10:49 +0200 Subject: [PATCH 39/52] Moving string literals into constants. --- .../src/System.Net.WebSockets.Client.csproj | 1 + .../ClientWebSocketDeflateConstants.cs | 16 ++++++ .../Net/WebSockets/WebSocketHandle.Managed.cs | 50 +++++++++++++------ 3 files changed, 51 insertions(+), 16 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs diff --git a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj index 4297d189ec2f9..e84ea02f895ba 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj +++ b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj @@ -6,6 +6,7 @@ + diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs new file mode 100644 index 0000000000000..39b5619b27085 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.WebSockets +{ + internal static class ClientWebSocketDeflateConstants + { + public const string Extension = "permessage-deflate"; + + public const string ClientMaxWindowBits = "client_max_window_bits"; + public const string ClientNoContextTakeover = "client_no_context_takeover"; + + public const string ServerMaxWindowBits = "server_max_window_bits"; + public const string ServerNoContextTakeover = "server_no_context_takeover"; + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index 3aab249544b6c..784bdeb2d11e4 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -183,7 +183,7 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli { foreach (ReadOnlySpan extension in extensions) { - if (extension.TrimStart().StartsWith("permessage-deflate")) + if (extension.TrimStart().StartsWith(ClientWebSocketDeflateConstants.Extension)) { deflateOptions = ParseDeflateOptions(extension, options.DeflateOptions); break; @@ -247,23 +247,37 @@ private static WebSocketDeflateOptions ParseDeflateOptions(ReadOnlySpan ex if (!value.IsEmpty) { - if (value == "client_no_context_takeover") + if (value == ClientWebSocketDeflateConstants.ClientNoContextTakeover) { options.ClientContextTakeover = false; } - else if (value == "server_no_context_takeover") + else if (value == ClientWebSocketDeflateConstants.ServerNoContextTakeover) { options.ServerContextTakeover = false; } - else if (value.StartsWith("client_max_window_bits=")) + else if (value.StartsWith(ClientWebSocketDeflateConstants.ClientMaxWindowBits)) { - options.ClientMaxWindowBits = int.Parse(value["client_max_window_bits=".Length..], - provider: CultureInfo.InvariantCulture); + options.ClientMaxWindowBits = ParseWindowBits(value); } - else if (value.StartsWith("server_max_window_bits=")) + else if (value.StartsWith(ClientWebSocketDeflateConstants.ServerMaxWindowBits)) { - options.ServerMaxWindowBits = int.Parse(value["server_max_window_bits=".Length..], - provider: CultureInfo.InvariantCulture); + options.ServerMaxWindowBits = ParseWindowBits(value); + } + + static int ParseWindowBits(ReadOnlySpan value) + { + var startIndex = value.IndexOf('='); + + if (startIndex < 0 || + !int.TryParse(value.Slice(startIndex + 1), NumberStyles.Integer, CultureInfo.InvariantCulture, out int windowBits) || + windowBits < 9 || + windowBits > 15) + { + throw new WebSocketException(WebSocketError.HeaderError, + SR.Format(SR.net_WebSockets_InvalidResponseHeader, ClientWebSocketDeflateConstants.Extension, value.ToString())); + } + + return windowBits; } } @@ -308,33 +322,37 @@ private static void AddWebSocketHeaders(HttpRequestMessage request, string secKe static IEnumerable GetDeflateOptions(WebSocketDeflateOptions options) { - yield return "permessage-deflate"; + yield return ClientWebSocketDeflateConstants.Extension; if (options.ClientMaxWindowBits != 15) { - yield return "client_max_window_bits=" + options.ClientMaxWindowBits; + yield return $"{ClientWebSocketDeflateConstants.ClientMaxWindowBits}={options.ClientMaxWindowBits}"; } else { // Advertise that we support this option - yield return "client_max_window_bits"; + yield return ClientWebSocketDeflateConstants.ClientMaxWindowBits; } if (options.ServerMaxWindowBits != 15) { - yield return "server_max_window_bits=" + options.ServerMaxWindowBits; + yield return $"{ClientWebSocketDeflateConstants.ServerMaxWindowBits}={options.ServerMaxWindowBits}"; } else { // Advertise that we support this option - yield return "server_max_window_bits"; + yield return ClientWebSocketDeflateConstants.ServerMaxWindowBits; } if (!options.ServerContextTakeover) - yield return "server_no_context_takeover"; + { + yield return ClientWebSocketDeflateConstants.ServerNoContextTakeover; + } if (!options.ClientContextTakeover) - yield return "client_no_context_takeover"; + { + yield return ClientWebSocketDeflateConstants.ClientNoContextTakeover; + } } } } From c99534937d1d1decd136cff1e4a82c465e308a58 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sat, 20 Feb 2021 18:15:33 +0200 Subject: [PATCH 40/52] Consistent Common links. --- .../src/System.Net.WebSockets.csproj | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index 56575010d4430..64aaea87d46f0 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -21,18 +21,24 @@ - - - - + + + + - + - + From 1c2c47089854c3f9ca5cdf1c5f43ce93d3c35021 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sat, 20 Feb 2021 18:18:01 +0200 Subject: [PATCH 41/52] Fully qualified parameter type in ref assembly. --- .../System.Net.WebSockets/ref/System.Net.WebSockets.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs index 12431032cc346..7b63abe187071 100644 --- a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs +++ b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs @@ -32,7 +32,7 @@ protected WebSocket() { } [System.Runtime.Versioning.UnsupportedOSPlatform("browser")] public static System.Net.WebSockets.WebSocket CreateFromStream(System.IO.Stream stream, bool isServer, string? subProtocol, System.TimeSpan keepAliveInterval) { throw null; } [System.Runtime.Versioning.UnsupportedOSPlatform("browser")] - public static System.Net.WebSockets.WebSocket CreateFromStream(System.IO.Stream stream, WebSocketCreationOptions options) { throw null; } + public static System.Net.WebSockets.WebSocket CreateFromStream(System.IO.Stream stream, System.Net.WebSockets.WebSocketCreationOptions options) { throw null; } public static System.ArraySegment CreateServerBuffer(int receiveBufferSize) { throw null; } public abstract void Dispose(); [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] From 172089a441b76c5d8e46f0a6aec89118556a12de Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sun, 21 Feb 2021 16:46:09 +0200 Subject: [PATCH 42/52] Fixing failing tests. --- .../System/Net/WebSockets/ManagedWebSocket.cs | 17 +---------------- .../tests/WebSocketTests.cs | 1 + 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 0c25991bab5fa..58d20e8c96f72 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -672,22 +672,7 @@ private async ValueTask HandleReceivedCloseAsync(ReadOnlyMemory payload, C _closeStatus = closeStatus; _closeStatusDescription = closeStatusDescription; - bool closeOutput = false; - - lock (StateUpdateLock) - { - if (!_sentCloseFrame) - { - _sentCloseFrame = true; - closeOutput = true; - } - } - - if (closeOutput) - { - await CloseOutputAsyncCore(closeStatus, closeStatusDescription, cancellationToken).ConfigureAwait(false); - } - else if (!_isServer) + if (!_isServer) { await _receiver.WaitForServerToCloseConnectionAsync(cancellationToken).ConfigureAwait(false); } diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs index c864822cfd7b6..d2d36ea3a92e2 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs @@ -261,6 +261,7 @@ public async Task ClosingWebSocketsGracefully() Assert.Equal("Yeet", server.CloseStatusDescription); Assert.Equal(WebSocketCloseStatus.PolicyViolation, server.CloseStatus); + await server.CloseAsync(WebSocketCloseStatus.NormalClosure, null, cancellation.Token); await clientClose; Assert.Equal(WebSocketState.Closed, server.State); From 624405fc82ba52d87f837c1f42b11196b0690501 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sun, 21 Feb 2021 16:55:22 +0200 Subject: [PATCH 43/52] Added comments in code explaining why we use 9 instead of 8 window bits as minimum. --- .../src/System/Net/WebSockets/WebSocketDeflateOptions.cs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs index 09a5e45d25819..6ddb82c1b0c7e 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs @@ -25,6 +25,10 @@ public int ClientMaxWindowBits get => _clientMaxWindowBits; set { + // The underlying zlib component doesn't support 8 bits in deflater (see https://github.com/madler/zlib/issues/94#issuecomment-125832411 + // and https://zlib.net/manual.html). Quote from the manual "For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported.". + // We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream + // and thus it needs to know the window bits in advance. Also take a look at https://github.com/madler/zlib/issues/171. if (value < 9 || value > 15) { throw new ArgumentOutOfRangeException(nameof(ClientMaxWindowBits), value, @@ -51,6 +55,10 @@ public int ServerMaxWindowBits get => _serverMaxWindowBits; set { + // The underlying zlib component doesn't support 8 bits in deflater (see https://github.com/madler/zlib/issues/94#issuecomment-125832411 + // and https://zlib.net/manual.html). Quote from the manual "For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported.". + // We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream + // and thus it needs to know the window bits in advance. Also take a look at https://github.com/madler/zlib/issues/171. if (value < 9 || value > 15) { throw new ArgumentOutOfRangeException(nameof(ServerMaxWindowBits), value, From 383b4c98eda5308815a8e066f36500914b6156aa Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sun, 21 Feb 2021 17:24:28 +0200 Subject: [PATCH 44/52] Added struct layout auto for readonly struct. --- .../src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index 73cc49f73a947..ee92901112d74 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -21,6 +21,7 @@ private enum ReceiveResultType HeaderError } + [StructLayout(LayoutKind.Auto)] private readonly struct ReceiveResult { public int Count { get; init; } @@ -414,7 +415,7 @@ private struct Buffer public Buffer(int capacity) { - _bytes = new byte[capacity]; + _bytes = GC.AllocateUninitializedArray(capacity, pinned: true); _position = 0; _consumed = 0; } From da7d050d5ea52d3fd126dc932a9da9063e168f2f Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Sun, 21 Feb 2021 20:35:11 +0200 Subject: [PATCH 45/52] Forgot to check _sentCloseFrame flag before trying to wait for server connection close. --- .../src/System/Net/WebSockets/ManagedWebSocket.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index 58d20e8c96f72..ad314edc75bfc 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -672,7 +672,7 @@ private async ValueTask HandleReceivedCloseAsync(ReadOnlyMemory payload, C _closeStatus = closeStatus; _closeStatusDescription = closeStatusDescription; - if (!_isServer) + if (!_isServer && _sentCloseFrame) { await _receiver.WaitForServerToCloseConnectionAsync(cancellationToken).ConfigureAwait(false); } From 8d951fccc3641fec546b51ff352fa880b9c4cc1b Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 22 Feb 2021 14:34:15 +0200 Subject: [PATCH 46/52] Replaced RandomNumberGenerator with Random with seed to remove non-determinism from tests. --- .../System.Net.WebSockets/tests/WebSocketDeflateTests.cs | 7 +++---- .../System.Net.WebSockets/tests/WebSocketTests.cs | 3 +-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index 12e9d20139903..de4c00af42a91 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -1,6 +1,5 @@ using System.Collections.Generic; using System.Diagnostics; -using System.Security.Cryptography; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -196,13 +195,13 @@ public async Task LargeMessageSplitInMultipleFrames(int windowBits) Memory receivedData = new byte[testData.Length]; // Make the data incompressible to make sure that the output is larger than the input - RandomNumberGenerator.Fill(testData.Span); + var rng = new Random(0); + rng.NextBytes(testData.Span); // Test it a few times with different frame sizes for (var i = 0; i < 10; ++i) { - // Use a timeout cancellation token in case something doesn't work right - var frameSize = RandomNumberGenerator.GetInt32(1024, 2048); + var frameSize = rng.Next(1024, 2048); var position = 0; while (position < testData.Length) diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs index d2d36ea3a92e2..132ad560c81c9 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.IO; -using System.Security.Cryptography; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -189,7 +188,7 @@ public async Task SendUncompressedClientMessage(int messageSize) using var client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions()); var message = new byte[messageSize]; - RandomNumberGenerator.Fill(message); + new Random(0).NextBytes(message); await client.SendAsync(message, WebSocketMessageType.Binary, true, default); From eafbd00759e92e2f151bc32ad1d43f9644fea6cb Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 22 Feb 2021 19:05:23 +0200 Subject: [PATCH 47/52] Little refactoring of receiver to reduce complexity and make it more readable. --- .../WebSockets/ManagedWebSocket.Receiver.cs | 242 ++++++++++-------- 1 file changed, 139 insertions(+), 103 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs index ee92901112d74..8fab03fd081f6 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -44,21 +44,11 @@ private sealed class Receiver : IDisposable private bool _inflateFinished = true; /// - /// If we have a decoder we cannot use the buffer provided from clients because + /// If we have compression we cannot use the buffer provided from clients because /// we cannot guarantee that the decoding can happen in place. This buffer is rent'ed /// and returned when consumed. /// - private byte[]? _decoderInputBuffer; - - /// - /// The next index that needs to be consumed from the decoder's input buffer. - /// - private int _decoderInputPosition; - - /// - /// The number of usable bytes in the decoder's buffer. - /// - private int _decoderInputCount; + private Memory _inflateBuffer; /// /// The last header received in a ReceiveAsync. If ReceiveAsync got a header but then @@ -114,12 +104,7 @@ public Receiver(Stream stream, WebSocketCreationOptions options) public void Dispose() { _inflater?.Dispose(); - - if (_decoderInputBuffer is not null) - { - ArrayPool.Shared.Return(_decoderInputBuffer); - _decoderInputBuffer = null; - } + ReturnInflateBuffer(); } public string? GetHeaderError() => _headerError; @@ -189,7 +174,7 @@ public async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken can return new ControlMessage(_lastHeader.Opcode, payload); } - public async ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken) + public async ValueTask ReceiveAsync(Memory output, CancellationToken cancellationToken) { // When there's nothing left over to receive, start a new if (_lastHeader.PayloadLength == 0) @@ -197,15 +182,14 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella if (!_inflateFinished) { Debug.Assert(_inflater is not null); - _inflateFinished = _inflater.Finish(buffer.Span, out int written); + _inflateFinished = _inflater.Finish(output.Span, out int written); return Result(written); } _readBuffer.DiscardConsumed(); - bool success = await ReceiveHeaderAsync(cancellationToken).ConfigureAwait(false); - if (!success) + if (!await ReceiveHeaderAsync(cancellationToken).ConfigureAwait(false)) { return Result(_headerError is not null ? ReceiveResultType.HeaderError : ReceiveResultType.ConnectionClose); } @@ -217,116 +201,88 @@ public async ValueTask ReceiveAsync(Memory buffer, Cancella } } - if (buffer.IsEmpty) + if (output.IsEmpty) { return Result(count: 0); } - // The number of bytes that are copied onto the provided buffer - int resultByteCount = 0; + // The number of bytes that are written to the output buffer + int outputByteCount = 0; if (_readBuffer.AvailableLength > 0) { - int consumed, written; - int available = (int)Math.Min(_readBuffer.AvailableLength, _lastHeader.PayloadLength); - - if (_lastHeader.Compressed) - { - Debug.Assert(_inflater is not null); - _inflater.Inflate(input: _readBuffer.AvailableSpan.Slice(0, available), - output: buffer.Span, out consumed, out written); - } - else + if (!ConsumeReadBuffer(output.Span, out int written)) { - written = Math.Min(available, buffer.Length); - consumed = written; - _readBuffer.AvailableSpan.Slice(0, written).CopyTo(buffer.Span); + return Result(written); } + outputByteCount += written; + output = output[written..]; + } - _readBuffer.Consume(consumed); - _lastHeader.PayloadLength -= consumed; - - resultByteCount += written; - - if (_lastHeader.PayloadLength == 0 || buffer.Length == written) - { - // We have either received everything or the buffer is full. - if (_inflater is not null && _lastHeader.PayloadLength == 0 && _lastHeader.Fin) - { - _inflateFinished = _inflater.Finish(buffer.Span.Slice(written), out written); - resultByteCount += written; - } + // At this point we should have consumed everything from the read buffer + // and should start issuing reads on the stream. + Debug.Assert(_readBuffer.AvailableLength == 0 && _lastHeader.PayloadLength > 0); - return Result(resultByteCount); - } + int receivedByteCount = _lastHeader.Compressed ? + await ReceiveCompressedAsync(output, cancellationToken).ConfigureAwait(false) : + await ReceiveUncompressedAsync(output, cancellationToken).ConfigureAwait(false); - buffer = buffer[written..]; + if (receivedByteCount == 0) + { + return Result(ReceiveResultType.ConnectionClose); } - // At this point we should have consumed everything from the buffer - // and should start issuing reads on the stream. - Debug.Assert(_readBuffer.AvailableLength == 0 && _lastHeader.PayloadLength > 0); + return Result(outputByteCount + receivedByteCount); + } - if (_inflater is null) + private async ValueTask ReceiveUncompressedAsync(Memory output, CancellationToken cancellationToken) + { + Debug.Assert(!_lastHeader.Compressed); + + if (output.Length > _lastHeader.PayloadLength) { - if (buffer.Length > _lastHeader.PayloadLength) - { - // We don't want to receive more than we need - buffer = buffer.Slice(0, (int)_lastHeader.PayloadLength); - } + // We don't want to receive more than we need + output = output.Slice(0, (int)_lastHeader.PayloadLength); + } - int bytesRead = await _stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); - if (bytesRead <= 0) - { - return Result(ReceiveResultType.ConnectionClose); - } - resultByteCount += bytesRead; + int bytesRead = await _stream.ReadAsync(output, cancellationToken).ConfigureAwait(false); + if (bytesRead > 0) + { _lastHeader.PayloadLength -= bytesRead; - ApplyMask(buffer.Span.Slice(0, bytesRead)); + ApplyMask(output.Span.Slice(0, bytesRead)); } - else - { - if (_decoderInputBuffer is null) - { - // Rent a buffer but restrict it's max size to 1MB - int decoderBufferLength = (int)Math.Min(_lastHeader.PayloadLength, 1_000_000); - - _decoderInputBuffer = ArrayPool.Shared.Rent(decoderBufferLength); - _decoderInputCount = await _stream.ReadAsync(_decoderInputBuffer.AsMemory(0, decoderBufferLength), cancellationToken).ConfigureAwait(false); - _decoderInputPosition = 0; - if (_decoderInputCount <= 0) - { - ArrayPool.Shared.Return(_decoderInputBuffer); - _decoderInputBuffer = null; + return bytesRead; + } - return Result(ReceiveResultType.ConnectionClose); - } + private async ValueTask ReceiveCompressedAsync(Memory output, CancellationToken cancellationToken) + { + Debug.Assert(_lastHeader.Compressed); + Debug.Assert(_inflater is not null); - ApplyMask(_decoderInputBuffer.AsSpan(0, _decoderInputCount)); + if (_inflateBuffer.IsEmpty) + { + if (!await LoadInflateBufferAsync(cancellationToken).ConfigureAwait(false)) + { + return 0; } + } - _inflater.Inflate(input: _decoderInputBuffer.AsSpan(_decoderInputPosition, _decoderInputCount), - output: buffer.Span, out int consumed, out int written); + _inflater.Inflate(_inflateBuffer.Span, output.Span, out int consumed, out int outputByteCount); + _lastHeader.PayloadLength -= consumed; + _inflateBuffer = _inflateBuffer.Slice(consumed); - resultByteCount += written; - _decoderInputPosition += consumed; - _decoderInputCount -= consumed; - _lastHeader.PayloadLength -= consumed; + if (_inflateBuffer.IsEmpty) + { + ReturnInflateBuffer(); - if (_decoderInputCount == 0) + if (_lastHeader.PayloadLength == 0 && _lastHeader.Fin) { - ArrayPool.Shared.Return(_decoderInputBuffer); - _decoderInputBuffer = null; - - if (_lastHeader.PayloadLength == 0 && _lastHeader.Fin) - { - _inflateFinished = _inflater.Finish(buffer.Span[written..], out written); - resultByteCount += written; - } + _inflateFinished = _inflater.Finish(output.Span.Slice(outputByteCount), out var written); + outputByteCount += written; } } - return Result(resultByteCount); + return outputByteCount; } private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationToken) @@ -406,6 +362,86 @@ private void ApplyMask(Span input) } } + /// + /// Tries to consume anything remaining in _readBuffer for the current message.s + /// + /// + /// True when the read buffer is consumed and there's more to be processed, + /// and the output buffer is not full. + /// + private bool ConsumeReadBuffer(Span output, out int outputByteCount) + { + Debug.Assert(_readBuffer.AvailableLength > 0); + + int consumed, written; + int available = (int)Math.Min(_readBuffer.AvailableLength, _lastHeader.PayloadLength); + + if (_lastHeader.Compressed) + { + Debug.Assert(_inflater is not null); + _inflater.Inflate(input: _readBuffer.AvailableSpan.Slice(0, available), + output, out consumed, out written); + } + else + { + // We can copy directly to output + written = Math.Min(available, output.Length); + consumed = written; + _readBuffer.AvailableSpan.Slice(0, written).CopyTo(output); + } + + _readBuffer.Consume(consumed); + _lastHeader.PayloadLength -= consumed; + + outputByteCount = written; + + if (_lastHeader.PayloadLength == 0 || output.Length == written) + { + // We have either consumed everything or the output is full. + // In this case we try to finish inflating if needed and return. + if (_inflater is not null && _lastHeader.Compressed && _lastHeader.PayloadLength == 0 && _lastHeader.Fin) + { + _inflateFinished = _inflater.Finish(output.Slice(written), out written); + outputByteCount += written; + } + + return false; + } + + return true; + } + + private async ValueTask LoadInflateBufferAsync(CancellationToken cancellationToken) + { + // Rent a buffer but restrict it's max size to 1MB + int decoderBufferLength = (int)Math.Min(_lastHeader.PayloadLength, 1_000_000); + + _inflateBuffer = ArrayPool.Shared.Rent(decoderBufferLength); + int byteCount = await _stream.ReadAsync(_inflateBuffer, cancellationToken).ConfigureAwait(false); + + if (byteCount <= 0) + { + ReturnInflateBuffer(); + return false; + } + + _inflateBuffer = _inflateBuffer.Slice(0, byteCount); + ApplyMask(_inflateBuffer.Span); + + return true; + } + + private void ReturnInflateBuffer() + { + if (MemoryMarshal.TryGetArray(_inflateBuffer, out ArraySegment arraySegment) + && arraySegment.Array!.Length > 0) + { + Debug.Assert(arraySegment.Array is not null); + ArrayPool.Shared.Return(arraySegment.Array); + _inflateBuffer = null; + } + } + [StructLayout(LayoutKind.Auto)] private struct Buffer { From 75a5c4f9c5c3b7896dfbe737f32609525d1919f3 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 22 Feb 2021 19:27:29 +0200 Subject: [PATCH 48/52] Replaced unused CancellationTokenSource with a method. --- .../System/Net/WebSockets/ManagedWebSocket.cs | 40 ++++++++----------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index ad314edc75bfc..c106c71c41766 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -70,8 +70,6 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke private readonly string? _subprotocol; /// Timer used to send periodic pings to the server, at the interval specified private readonly Timer? _keepAliveTimer; - /// CancellationTokenSource used to abort all current and future operations when anything is canceled or any error occurs. - private readonly CancellationTokenSource _abortSource = new CancellationTokenSource(); /// /// Semaphore used to ensure that calls to SendFrameAsync don't run concurrently. @@ -115,7 +113,7 @@ public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocke private Task _lastReceiveAsync = Task.CompletedTask; /// Lock used to protect update and check-and-update operations on _state. - private object StateUpdateLock => _abortSource; + private object StateUpdateLock => _sender; /// /// We need to coordinate between receives and close operations happening concurrently, as a ReceiveAsync may /// be pending while a Close{Output}Async is issued, which itself needs to loop until a close frame is received. @@ -140,25 +138,6 @@ private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) _isServer = options.IsServer; _subprotocol = options.SubProtocol; - // Set up the abort source so that if it's triggered, we transition the instance appropriately. - // There's no need to store the resulting CancellationTokenRegistration, as this instance owns - // the CancellationTokenSource, and the lifetime of that CTS matches the lifetime of the registration. - _abortSource.Token.UnsafeRegister(static s => - { - var thisRef = (ManagedWebSocket)s!; - - lock (thisRef.StateUpdateLock) - { - WebSocketState state = thisRef._state; - if (state != WebSocketState.Closed && state != WebSocketState.Aborted) - { - thisRef._state = state != WebSocketState.None && state != WebSocketState.Connecting ? - WebSocketState.Aborted : - WebSocketState.Closed; - } - } - }, this); - // Now that we're opened, initiate the keep alive timer to send periodic pings. // We use a weak reference from the timer to the web socket to avoid a cycle // that could keep the web socket rooted in erroneous cases. @@ -317,10 +296,23 @@ private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, string public override void Abort() { - _abortSource.Cancel(); + OnAborted(); Dispose(); // forcibly tear down connection } + private void OnAborted() + { + lock (StateUpdateLock) + { + if (_state is not (WebSocketState.Closed or WebSocketState.Aborted)) + { + _state = _state is not (WebSocketState.None or WebSocketState.Connecting) ? + WebSocketState.Aborted : WebSocketState.Closed; + } + } + } + + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { return SendPrivateAsync(buffer, messageType, endOfMessage, cancellationToken); @@ -607,7 +599,7 @@ private async ValueTask ReceiveAsyncPrivate Date: Mon, 22 Feb 2021 19:34:47 +0200 Subject: [PATCH 49/52] Removed unneeded (duplicated) check. --- .../src/System/Net/WebSockets/ManagedWebSocket.cs | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index c106c71c41766..28d1c5df2d56c 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -190,14 +190,6 @@ private void DisposeCore() public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { - if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary) - { - throw new ArgumentException(SR.Format( - SR.net_WebSockets_Argument_InvalidMessageType, - nameof(WebSocketMessageType.Close), nameof(SendAsync), nameof(WebSocketMessageType.Binary), nameof(WebSocketMessageType.Text), nameof(CloseOutputAsync)), - nameof(messageType)); - } - WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); return SendPrivateAsync(buffer, messageType, endOfMessage, cancellationToken).AsTask(); From 2f31febc8207bb8337e40406086664e612c73dc5 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 22 Feb 2021 19:34:59 +0200 Subject: [PATCH 50/52] More tests. --- .../tests/System.Net.WebSockets.Tests.csproj | 1 + .../tests/WebSocketDeflateOptionsTests.cs | 48 ++++++++++ .../tests/WebSocketDeflateTests.cs | 89 +++++++++++++++++-- 3 files changed, 133 insertions(+), 5 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets/tests/WebSocketDeflateOptionsTests.cs diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj index cdfd537caafcc..a8e63a93296d0 100644 --- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj +++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj @@ -3,6 +3,7 @@ $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser + diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateOptionsTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateOptionsTests.cs new file mode 100644 index 0000000000000..43d9eea7ba297 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateOptionsTests.cs @@ -0,0 +1,48 @@ +using Xunit; + +namespace System.Net.WebSockets.Tests +{ + public class WebSocketDeflateOptionsTests + { + [Fact] + public void ClientMaxWindowBits() + { + var options = new WebSocketDeflateOptions(); + Assert.Equal(15, options.ClientMaxWindowBits); + + Assert.Throws(() => options.ClientMaxWindowBits = 8); + Assert.Throws(() => options.ClientMaxWindowBits = 16); + + options.ClientMaxWindowBits = 14; + Assert.Equal(14, options.ClientMaxWindowBits); + } + + [Fact] + public void ServerMaxWindowBits() + { + var options = new WebSocketDeflateOptions(); + Assert.Equal(15, options.ServerMaxWindowBits); + + Assert.Throws(() => options.ServerMaxWindowBits = 8); + Assert.Throws(() => options.ServerMaxWindowBits = 16); + + options.ServerMaxWindowBits = 14; + Assert.Equal(14, options.ServerMaxWindowBits); + } + + [Fact] + public void ContextTakeover() + { + var options = new WebSocketDeflateOptions(); + + Assert.True(options.ClientContextTakeover); + Assert.True(options.ServerContextTakeover); + + options.ClientContextTakeover = false; + Assert.False(options.ClientContextTakeover); + + options.ServerContextTakeover = false; + Assert.False(options.ServerContextTakeover); + } + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs index de4c00af42a91..68768acb054aa 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System.Buffers; +using System.Collections.Generic; using System.Diagnostics; using System.Text; using System.Threading; @@ -127,18 +128,30 @@ public async Task TwoDeflateBlocksInOneMessage() Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.Span)); } - [Fact] - public async Task Duplex() + [Theory] + [InlineData(false, false)] + [InlineData(true, true)] + [InlineData(false, true)] + [InlineData(true, false)] + public async Task Duplex(bool clientContextTakover, bool serverContextTakover) { var stream = new WebSocketStream(); using var server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions { IsServer = true, - DeflateOptions = new() + DeflateOptions = new WebSocketDeflateOptions + { + ClientContextTakeover = clientContextTakover, + ServerContextTakeover = serverContextTakover + } }); using var client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions { - DeflateOptions = new() + DeflateOptions = new WebSocketDeflateOptions + { + ClientContextTakeover = clientContextTakover, + ServerContextTakeover = serverContextTakover + } }); var buffer = new byte[1024]; @@ -255,10 +268,76 @@ public async Task WebSocketWithoutDeflateShouldThrowOnCompressedMessage() Assert.Equal("The WebSocket received compressed frame when compression is not enabled.", exception.Message); } + [Fact] + public async Task ReceiveUncompressedMessageWhenCompressionEnabled() + { + // We should be able to handle the situation where even if we have + // deflate compression enabled, uncompressed messages are OK + var stream = new WebSocketStream(); + var server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DeflateOptions = null + }); + var client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DeflateOptions = new WebSocketDeflateOptions() + }); + + // Server sends uncompressed + await SendTextAsync("Hello", server); + + // Although client has deflate options, it should still be able + // to handle uncompressed messages. + Assert.Equal("Hello", await ReceiveTextAsync(client)); + + // Client sends compressed, but server compression is disabled and should throw on receive + await SendTextAsync("Hello back", client); + var exception = await Assert.ThrowsAsync(() => ReceiveTextAsync(server)); + Assert.Equal("The WebSocket received compressed frame when compression is not enabled.", exception.Message); + Assert.Equal(WebSocketState.Aborted, server.State); + + // The client should close if we try to receive + var result = await client.ReceiveAsync(Memory.Empty, CancellationToken); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketCloseStatus.ProtocolError, client.CloseStatus); + Assert.Equal(WebSocketState.CloseReceived, client.State); + } + + [Fact] + public async Task ReceiveInvalidCompressedData() + { + var stream = new WebSocketStream(); + var client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DeflateOptions = new WebSocketDeflateOptions() + }); + + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + Assert.Equal("Hello", await ReceiveTextAsync(client)); + + stream.Enqueue(0xc1, 0x07, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00); + var exception = await Assert.ThrowsAsync(() => ReceiveTextAsync(client)); + + Assert.Equal("The message was compressed using an unsupported compression method.", exception.Message); + Assert.Equal(WebSocketState.Aborted, client.State); + } + private ValueTask SendTextAsync(string text, WebSocket websocket) { var bytes = Encoding.UTF8.GetBytes(text); return websocket.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, CancellationToken); } + + private async Task ReceiveTextAsync(WebSocket websocket) + { + using var buffer = MemoryPool.Shared.Rent(1024 * 32); + var result = await websocket.ReceiveAsync(buffer.Memory, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + + return Encoding.UTF8.GetString(buffer.Memory.Span.Slice(0, result.Count)); + } } } From 9ba2536c937ddf7754221d5de868bf5acfc86a4c Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 22 Feb 2021 21:09:43 +0200 Subject: [PATCH 51/52] Created test for deflate options in client websocket. --- .../Net/WebSockets/WebSocketHandle.Managed.cs | 23 +++-- .../tests/DeflateTests.cs | 99 +++++++++++++++++++ .../tests/LoopbackHelper.cs | 3 +- .../System.Net.WebSockets.Client.Tests.csproj | 1 + .../tests/System.Net.WebSockets.Tests.csproj | 9 +- 5 files changed, 122 insertions(+), 13 deletions(-) create mode 100644 src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index 784bdeb2d11e4..ccc598368b605 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -191,6 +191,11 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli } } + // Store the negotiated deflate options in the original options, because + // otherwise there is now way of clients to actually check whether we are using + // per message deflate or not. + options.DeflateOptions = deflateOptions; + if (response.Content is null) { throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); @@ -247,19 +252,19 @@ private static WebSocketDeflateOptions ParseDeflateOptions(ReadOnlySpan ex if (!value.IsEmpty) { - if (value == ClientWebSocketDeflateConstants.ClientNoContextTakeover) + if (value.Equals(ClientWebSocketDeflateConstants.ClientNoContextTakeover, StringComparison.Ordinal)) { options.ClientContextTakeover = false; } - else if (value == ClientWebSocketDeflateConstants.ServerNoContextTakeover) + else if (value.Equals(ClientWebSocketDeflateConstants.ServerNoContextTakeover, StringComparison.Ordinal)) { options.ServerContextTakeover = false; } - else if (value.StartsWith(ClientWebSocketDeflateConstants.ClientMaxWindowBits)) + else if (value.StartsWith(ClientWebSocketDeflateConstants.ClientMaxWindowBits, StringComparison.Ordinal)) { options.ClientMaxWindowBits = ParseWindowBits(value); } - else if (value.StartsWith(ClientWebSocketDeflateConstants.ServerMaxWindowBits)) + else if (value.StartsWith(ClientWebSocketDeflateConstants.ServerMaxWindowBits, StringComparison.Ordinal)) { options.ServerMaxWindowBits = ParseWindowBits(value); } @@ -334,6 +339,11 @@ static IEnumerable GetDeflateOptions(WebSocketDeflateOptions options) yield return ClientWebSocketDeflateConstants.ClientMaxWindowBits; } + if (!options.ClientContextTakeover) + { + yield return ClientWebSocketDeflateConstants.ClientNoContextTakeover; + } + if (options.ServerMaxWindowBits != 15) { yield return $"{ClientWebSocketDeflateConstants.ServerMaxWindowBits}={options.ServerMaxWindowBits}"; @@ -348,11 +358,6 @@ static IEnumerable GetDeflateOptions(WebSocketDeflateOptions options) { yield return ClientWebSocketDeflateConstants.ServerNoContextTakeover; } - - if (!options.ClientContextTakeover) - { - yield return ClientWebSocketDeflateConstants.ClientNoContextTakeover; - } } } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs new file mode 100644 index 0000000000000..a182830426c30 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -0,0 +1,99 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.IO; +using System.Net.Test.Common; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.WebSockets.Client.Tests +{ + public class DeflateTests : ClientWebSocketTestBase + { + public DeflateTests(ITestOutputHelper output) : base(output) + { + } + + [ConditionalTheory(nameof(WebSocketsSupported))] + [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)] + [ActiveIssue("https://github.com/dotnet/runtime/issues/42852", TestPlatforms.Browser)] + [InlineData(15, true, 15, true, "permessage-deflate; client_max_window_bits; server_max_window_bits")] + [InlineData(14, true, 15, true, "permessage-deflate; client_max_window_bits=14; server_max_window_bits")] + [InlineData(15, true, 14, true, "permessage-deflate; client_max_window_bits; server_max_window_bits=14")] + [InlineData(10, true, 11, true, "permessage-deflate; client_max_window_bits=10; server_max_window_bits=11")] + [InlineData(15, false, 15, true, "permessage-deflate; client_max_window_bits; client_no_context_takeover; server_max_window_bits")] + [InlineData(15, true, 15, false, "permessage-deflate; client_max_window_bits; server_max_window_bits; server_no_context_takeover")] + public async Task PerMessageDeflateHeaders(int clientWindowBits, bool clientContextTakeover, + int serverWindowBits, bool serverContextTakover, + string expected) + { + await LoopbackServer.CreateClientAndServerAsync(async uri => + { + using var client = new ClientWebSocket(); + using var cancellation = new CancellationTokenSource(TimeOutMilliseconds); + + client.Options.DeflateOptions = new WebSocketDeflateOptions + { + ClientMaxWindowBits = clientWindowBits, + ClientContextTakeover = clientContextTakeover, + ServerMaxWindowBits = serverWindowBits, + ServerContextTakeover = serverContextTakover + }; + + await client.ConnectAsync(uri, cancellation.Token); + + Assert.NotNull(client.Options.DeflateOptions); + Assert.Equal(clientWindowBits - 1, client.Options.DeflateOptions.ClientMaxWindowBits); + Assert.Equal(clientContextTakeover, client.Options.DeflateOptions.ClientContextTakeover); + Assert.Equal(serverWindowBits - 1, client.Options.DeflateOptions.ServerMaxWindowBits); + Assert.Equal(serverContextTakover, client.Options.DeflateOptions.ServerContextTakeover); + }, server => server.AcceptConnectionAsync(async connection => + { + var extensionsReply = CreateDeflateOptionsHeader(new WebSocketDeflateOptions + { + ClientMaxWindowBits = clientWindowBits - 1, + ClientContextTakeover = clientContextTakeover, + ServerMaxWindowBits = serverWindowBits - 1, + ServerContextTakeover = serverContextTakover + }); + Dictionary headers = await LoopbackHelper.WebSocketHandshakeAsync(connection, extensionsReply); + Assert.NotNull(headers); + Assert.True(headers.TryGetValue("Sec-WebSocket-Extensions", out string extensions)); + Assert.Equal(expected, extensions); + }), new LoopbackServer.Options { WebSocketEndpoint = true }); + } + + private static string CreateDeflateOptionsHeader(WebSocketDeflateOptions options) + { + var builder = new StringBuilder(); + builder.Append("permessage-deflate"); + + if (options.ClientMaxWindowBits != 15) + { + builder.Append("; client_max_window_bits=").Append(options.ClientMaxWindowBits); + } + + if (!options.ClientContextTakeover) + { + builder.Append("; client_no_context_takeover"); + } + + if (options.ServerMaxWindowBits != 15) + { + builder.Append("; server_max_window_bits=").Append(options.ServerMaxWindowBits); + } + + if (!options.ServerContextTakeover) + { + builder.Append("; server_no_context_takeover"); + } + + return builder.ToString(); + } + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs index 5726326c6ab8f..48d167b072f78 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs @@ -11,7 +11,7 @@ namespace System.Net.WebSockets.Client.Tests { public static class LoopbackHelper { - public static async Task> WebSocketHandshakeAsync(LoopbackServer.Connection connection) + public static async Task> WebSocketHandshakeAsync(LoopbackServer.Connection connection, string? extensions = null) { string serverResponse = null; List headers = await connection.ReadRequestHeaderAsync().ConfigureAwait(false); @@ -34,6 +34,7 @@ public static async Task> WebSocketHandshakeAsync(Loo "Content-Length: 0\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + + (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") + "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n"; } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj index 248546468d0fa..adb1c3e447fb0 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj +++ b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj @@ -38,6 +38,7 @@ + diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj index a8e63a93296d0..c7691fa535199 100644 --- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj +++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj @@ -10,8 +10,11 @@ - - - + + + From 2099294ef01405d28a8ece8ed4b983d2d6dc6fe9 Mon Sep 17 00:00:00 2001 From: Ivan Zlatanov Date: Mon, 22 Feb 2021 21:28:02 +0200 Subject: [PATCH 52/52] Fixed wrong test in http listener with websocket. --- .../src/System/Net/Windows/WebSockets/WebSocketBase.cs | 4 ++-- .../tests/HttpListenerWebSocketTests.cs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Net.HttpListener/src/System/Net/Windows/WebSockets/WebSocketBase.cs b/src/libraries/System.Net.HttpListener/src/System/Net/Windows/WebSockets/WebSocketBase.cs index 4f7fa86e43c85..0dc1af676dfee 100644 --- a/src/libraries/System.Net.HttpListener/src/System/Net/Windows/WebSockets/WebSocketBase.cs +++ b/src/libraries/System.Net.HttpListener/src/System/Net/Windows/WebSockets/WebSocketBase.cs @@ -225,6 +225,8 @@ public override Task SendAsync(ArraySegment buffer, bool endOfMessage, CancellationToken cancellationToken) { + WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); + if (messageType != WebSocketMessageType.Binary && messageType != WebSocketMessageType.Text) { @@ -237,8 +239,6 @@ public override Task SendAsync(ArraySegment buffer, nameof(messageType)); } - WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); - return SendAsyncCore(buffer, messageType, endOfMessage, cancellationToken); } diff --git a/src/libraries/System.Net.HttpListener/tests/HttpListenerWebSocketTests.cs b/src/libraries/System.Net.HttpListener/tests/HttpListenerWebSocketTests.cs index f9838ff2efe2f..72f26cf5f8faf 100644 --- a/src/libraries/System.Net.HttpListener/tests/HttpListenerWebSocketTests.cs +++ b/src/libraries/System.Net.HttpListener/tests/HttpListenerWebSocketTests.cs @@ -73,7 +73,7 @@ public async Task SendAsync_NoInnerBuffer_ThrowsArgumentNullException() public async Task SendAsync_InvalidMessageType_ThrowsArgumentNullException(WebSocketMessageType messageType) { HttpListenerWebSocketContext context = await GetWebSocketContext(); - await AssertExtensions.ThrowsAsync("messageType", () => context.WebSocket.SendAsync(new ArraySegment(), messageType, false, new CancellationToken())); + await AssertExtensions.ThrowsAsync("buffer.Array", () => context.WebSocket.SendAsync(new ArraySegment(), messageType, false, new CancellationToken())); } [ConditionalFact(nameof(IsNotWindows7AndIsWindowsImplementation))] // [ActiveIssue("https://github.com/dotnet/runtime/issues/22014", TestPlatforms.AnyUnix)]