diff --git a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.ZStream.cs b/src/libraries/Common/src/System/IO/Compression/ZLibNative.ZStream.cs similarity index 100% rename from src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.ZStream.cs rename to src/libraries/Common/src/System/IO/Compression/ZLibNative.ZStream.cs diff --git a/src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.cs b/src/libraries/Common/src/System/IO/Compression/ZLibNative.cs similarity index 100% rename from src/libraries/System.IO.Compression/src/System/IO/Compression/DeflateZLib/ZLibNative.cs rename to src/libraries/Common/src/System/IO/Compression/ZLibNative.cs diff --git a/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj b/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj index e2a7adee12f57..758e86a80d6b4 100644 --- a/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj +++ b/src/libraries/System.IO.Compression/src/System.IO.Compression.csproj @@ -1,4 +1,4 @@ - + true $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser @@ -25,8 +25,8 @@ - - + + diff --git a/src/libraries/System.Net.HttpListener/src/System/Net/Windows/WebSockets/WebSocketBase.cs b/src/libraries/System.Net.HttpListener/src/System/Net/Windows/WebSockets/WebSocketBase.cs index 4f7fa86e43c85..0dc1af676dfee 100644 --- a/src/libraries/System.Net.HttpListener/src/System/Net/Windows/WebSockets/WebSocketBase.cs +++ b/src/libraries/System.Net.HttpListener/src/System/Net/Windows/WebSockets/WebSocketBase.cs @@ -225,6 +225,8 @@ public override Task SendAsync(ArraySegment buffer, bool endOfMessage, CancellationToken cancellationToken) { + WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); + if (messageType != WebSocketMessageType.Binary && messageType != WebSocketMessageType.Text) { @@ -237,8 +239,6 @@ public override Task SendAsync(ArraySegment buffer, nameof(messageType)); } - WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); - return SendAsyncCore(buffer, messageType, endOfMessage, cancellationToken); } diff --git a/src/libraries/System.Net.HttpListener/tests/HttpListenerWebSocketTests.cs b/src/libraries/System.Net.HttpListener/tests/HttpListenerWebSocketTests.cs index f9838ff2efe2f..72f26cf5f8faf 100644 --- a/src/libraries/System.Net.HttpListener/tests/HttpListenerWebSocketTests.cs +++ b/src/libraries/System.Net.HttpListener/tests/HttpListenerWebSocketTests.cs @@ -73,7 +73,7 @@ public async Task SendAsync_NoInnerBuffer_ThrowsArgumentNullException() public async Task SendAsync_InvalidMessageType_ThrowsArgumentNullException(WebSocketMessageType messageType) { HttpListenerWebSocketContext context = await GetWebSocketContext(); - await AssertExtensions.ThrowsAsync("messageType", () => context.WebSocket.SendAsync(new ArraySegment(), messageType, false, new CancellationToken())); + await AssertExtensions.ThrowsAsync("buffer.Array", () => context.WebSocket.SendAsync(new ArraySegment(), messageType, false, new CancellationToken())); } [ConditionalFact(nameof(IsNotWindows7AndIsWindowsImplementation))] // [ActiveIssue("https://github.com/dotnet/runtime/issues/22014", TestPlatforms.AnyUnix)] diff --git a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs index cee3a5170b862..660a2c5fbe748 100644 --- a/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs +++ b/src/libraries/System.Net.WebSockets.Client/ref/System.Net.WebSockets.Client.cs @@ -36,6 +36,8 @@ internal ClientWebSocketOptions() { } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] + public System.Net.WebSockets.WebSocketDeflateOptions? DeflateOptions { get { throw null; } set { } } + [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.IWebProxy? Proxy { get { throw null; } set { } } [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")] public System.Net.Security.RemoteCertificateValidationCallback? RemoteCertificateValidationCallback { get { throw null; } set { } } diff --git a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx index 3259b86c99fcb..5649bc52e7653 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets.Client/src/Resources/Strings.resx @@ -1,16 +1,17 @@ - - @@ -193,8 +194,11 @@ Connection was aborted. - + WebSocket binary type '{0}' not supported. - - + + + The WebSocket failed to negotiate max {0} window bits. The client requested {1} but the server responded with {2}. + + \ No newline at end of file diff --git a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj index b74f3d8962be6..e84ea02f895ba 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj +++ b/src/libraries/System.Net.WebSockets.Client/src/System.Net.WebSockets.Client.csproj @@ -6,6 +6,7 @@ + @@ -37,6 +38,7 @@ + diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs index 85b0f025b4650..2ed5c527421c9 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/BrowserWebSockets/ClientWebSocketOptions.cs @@ -100,6 +100,13 @@ public TimeSpan KeepAliveInterval set => throw new PlatformNotSupportedException(); } + [UnsupportedOSPlatform("browser")] + public WebSocketDeflateOptions? DeflateOptions + { + get => throw new PlatformNotSupportedException(); + set => throw new PlatformNotSupportedException(); + } + [UnsupportedOSPlatform("browser")] public void SetBuffer(int receiveBufferSize, int sendBufferSize) { diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs new file mode 100644 index 0000000000000..39b5619b27085 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketDeflateConstants.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.WebSockets +{ + internal static class ClientWebSocketDeflateConstants + { + public const string Extension = "permessage-deflate"; + + public const string ClientMaxWindowBits = "client_max_window_bits"; + public const string ClientNoContextTakeover = "client_no_context_takeover"; + + public const string ServerMaxWindowBits = "server_max_window_bits"; + public const string ServerNoContextTakeover = "server_no_context_takeover"; + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs index a7609a0ff0905..573c3eb8325b7 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/ClientWebSocketOptions.cs @@ -148,6 +148,9 @@ public TimeSpan KeepAliveInterval } } + [UnsupportedOSPlatform("browser")] + public WebSocketDeflateOptions? DeflateOptions { get; set; } + internal int ReceiveBufferSize => _receiveBufferSize; internal ArraySegment? Buffer => _buffer; diff --git a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs index a378566af65ca..ccc598368b605 100644 --- a/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs +++ b/src/libraries/System.Net.WebSockets.Client/src/System/Net/WebSockets/WebSocketHandle.Managed.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.IO; using System.Net.Http; using System.Net.Http.Headers; @@ -175,6 +176,26 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli } } + // Because deflate options are negotiated we need a new object + WebSocketDeflateOptions? deflateOptions = null; + + if (options.DeflateOptions is not null && response.Headers.TryGetValues(HttpKnownHeaderNames.SecWebSocketExtensions, out IEnumerable? extensions)) + { + foreach (ReadOnlySpan extension in extensions) + { + if (extension.TrimStart().StartsWith(ClientWebSocketDeflateConstants.Extension)) + { + deflateOptions = ParseDeflateOptions(extension, options.DeflateOptions); + break; + } + } + } + + // Store the negotiated deflate options in the original options, because + // otherwise there is now way of clients to actually check whether we are using + // per message deflate or not. + options.DeflateOptions = deflateOptions; + if (response.Content is null) { throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); @@ -184,11 +205,13 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli Stream connectedStream = response.Content.ReadAsStream(); Debug.Assert(connectedStream.CanWrite); Debug.Assert(connectedStream.CanRead); - WebSocket = WebSocket.CreateFromStream( - connectedStream, - isServer: false, - subprotocol, - options.KeepAliveInterval); + WebSocket = WebSocket.CreateFromStream(connectedStream, new WebSocketCreationOptions + { + IsServer = false, + SubProtocol = subprotocol, + KeepAliveInterval = options.KeepAliveInterval, + DeflateOptions = deflateOptions, + }); } catch (Exception exc) { @@ -218,6 +241,72 @@ public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken, Cli } } + private static WebSocketDeflateOptions ParseDeflateOptions(ReadOnlySpan extension, WebSocketDeflateOptions original) + { + var options = new WebSocketDeflateOptions(); + + while (true) + { + int end = extension.IndexOf(';'); + ReadOnlySpan value = (end >= 0 ? extension[..end] : extension).Trim(); + + if (!value.IsEmpty) + { + if (value.Equals(ClientWebSocketDeflateConstants.ClientNoContextTakeover, StringComparison.Ordinal)) + { + options.ClientContextTakeover = false; + } + else if (value.Equals(ClientWebSocketDeflateConstants.ServerNoContextTakeover, StringComparison.Ordinal)) + { + options.ServerContextTakeover = false; + } + else if (value.StartsWith(ClientWebSocketDeflateConstants.ClientMaxWindowBits, StringComparison.Ordinal)) + { + options.ClientMaxWindowBits = ParseWindowBits(value); + } + else if (value.StartsWith(ClientWebSocketDeflateConstants.ServerMaxWindowBits, StringComparison.Ordinal)) + { + options.ServerMaxWindowBits = ParseWindowBits(value); + } + + static int ParseWindowBits(ReadOnlySpan value) + { + var startIndex = value.IndexOf('='); + + if (startIndex < 0 || + !int.TryParse(value.Slice(startIndex + 1), NumberStyles.Integer, CultureInfo.InvariantCulture, out int windowBits) || + windowBits < 9 || + windowBits > 15) + { + throw new WebSocketException(WebSocketError.HeaderError, + SR.Format(SR.net_WebSockets_InvalidResponseHeader, ClientWebSocketDeflateConstants.Extension, value.ToString())); + } + + return windowBits; + } + } + + if (end < 0) + break; + + extension = extension[(end + 1)..]; + } + + if (options.ClientMaxWindowBits > original.ClientMaxWindowBits) + { + throw new WebSocketException(string.Format(SR.net_WebSockets_WindowBitsNegotiationFailure, + "client", original.ClientMaxWindowBits, options.ClientMaxWindowBits)); + } + + if (options.ServerMaxWindowBits > original.ServerMaxWindowBits) + { + throw new WebSocketException(string.Format(SR.net_WebSockets_WindowBitsNegotiationFailure, + "server", original.ServerMaxWindowBits, options.ServerMaxWindowBits)); + } + + return options; + } + /// Adds the necessary headers for the web socket request. /// The request to which the headers should be added. /// The generated security key to send in the Sec-WebSocket-Key header. @@ -232,6 +321,45 @@ private static void AddWebSocketHeaders(HttpRequestMessage request, string secKe { request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketProtocol, string.Join(", ", options.RequestedSubProtocols)); } + if (options.DeflateOptions is not null) + { + request.Headers.TryAddWithoutValidation(HttpKnownHeaderNames.SecWebSocketExtensions, string.Join("; ", GetDeflateOptions(options.DeflateOptions))); + + static IEnumerable GetDeflateOptions(WebSocketDeflateOptions options) + { + yield return ClientWebSocketDeflateConstants.Extension; + + if (options.ClientMaxWindowBits != 15) + { + yield return $"{ClientWebSocketDeflateConstants.ClientMaxWindowBits}={options.ClientMaxWindowBits}"; + } + else + { + // Advertise that we support this option + yield return ClientWebSocketDeflateConstants.ClientMaxWindowBits; + } + + if (!options.ClientContextTakeover) + { + yield return ClientWebSocketDeflateConstants.ClientNoContextTakeover; + } + + if (options.ServerMaxWindowBits != 15) + { + yield return $"{ClientWebSocketDeflateConstants.ServerMaxWindowBits}={options.ServerMaxWindowBits}"; + } + else + { + // Advertise that we support this option + yield return ClientWebSocketDeflateConstants.ServerMaxWindowBits; + } + + if (!options.ServerContextTakeover) + { + yield return ClientWebSocketDeflateConstants.ServerNoContextTakeover; + } + } + } } /// diff --git a/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs new file mode 100644 index 0000000000000..a182830426c30 --- /dev/null +++ b/src/libraries/System.Net.WebSockets.Client/tests/DeflateTests.cs @@ -0,0 +1,99 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.IO; +using System.Net.Test.Common; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.WebSockets.Client.Tests +{ + public class DeflateTests : ClientWebSocketTestBase + { + public DeflateTests(ITestOutputHelper output) : base(output) + { + } + + [ConditionalTheory(nameof(WebSocketsSupported))] + [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)] + [ActiveIssue("https://github.com/dotnet/runtime/issues/42852", TestPlatforms.Browser)] + [InlineData(15, true, 15, true, "permessage-deflate; client_max_window_bits; server_max_window_bits")] + [InlineData(14, true, 15, true, "permessage-deflate; client_max_window_bits=14; server_max_window_bits")] + [InlineData(15, true, 14, true, "permessage-deflate; client_max_window_bits; server_max_window_bits=14")] + [InlineData(10, true, 11, true, "permessage-deflate; client_max_window_bits=10; server_max_window_bits=11")] + [InlineData(15, false, 15, true, "permessage-deflate; client_max_window_bits; client_no_context_takeover; server_max_window_bits")] + [InlineData(15, true, 15, false, "permessage-deflate; client_max_window_bits; server_max_window_bits; server_no_context_takeover")] + public async Task PerMessageDeflateHeaders(int clientWindowBits, bool clientContextTakeover, + int serverWindowBits, bool serverContextTakover, + string expected) + { + await LoopbackServer.CreateClientAndServerAsync(async uri => + { + using var client = new ClientWebSocket(); + using var cancellation = new CancellationTokenSource(TimeOutMilliseconds); + + client.Options.DeflateOptions = new WebSocketDeflateOptions + { + ClientMaxWindowBits = clientWindowBits, + ClientContextTakeover = clientContextTakeover, + ServerMaxWindowBits = serverWindowBits, + ServerContextTakeover = serverContextTakover + }; + + await client.ConnectAsync(uri, cancellation.Token); + + Assert.NotNull(client.Options.DeflateOptions); + Assert.Equal(clientWindowBits - 1, client.Options.DeflateOptions.ClientMaxWindowBits); + Assert.Equal(clientContextTakeover, client.Options.DeflateOptions.ClientContextTakeover); + Assert.Equal(serverWindowBits - 1, client.Options.DeflateOptions.ServerMaxWindowBits); + Assert.Equal(serverContextTakover, client.Options.DeflateOptions.ServerContextTakeover); + }, server => server.AcceptConnectionAsync(async connection => + { + var extensionsReply = CreateDeflateOptionsHeader(new WebSocketDeflateOptions + { + ClientMaxWindowBits = clientWindowBits - 1, + ClientContextTakeover = clientContextTakeover, + ServerMaxWindowBits = serverWindowBits - 1, + ServerContextTakeover = serverContextTakover + }); + Dictionary headers = await LoopbackHelper.WebSocketHandshakeAsync(connection, extensionsReply); + Assert.NotNull(headers); + Assert.True(headers.TryGetValue("Sec-WebSocket-Extensions", out string extensions)); + Assert.Equal(expected, extensions); + }), new LoopbackServer.Options { WebSocketEndpoint = true }); + } + + private static string CreateDeflateOptionsHeader(WebSocketDeflateOptions options) + { + var builder = new StringBuilder(); + builder.Append("permessage-deflate"); + + if (options.ClientMaxWindowBits != 15) + { + builder.Append("; client_max_window_bits=").Append(options.ClientMaxWindowBits); + } + + if (!options.ClientContextTakeover) + { + builder.Append("; client_no_context_takeover"); + } + + if (options.ServerMaxWindowBits != 15) + { + builder.Append("; server_max_window_bits=").Append(options.ServerMaxWindowBits); + } + + if (!options.ServerContextTakeover) + { + builder.Append("; server_no_context_takeover"); + } + + return builder.ToString(); + } + } +} diff --git a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs index 5726326c6ab8f..48d167b072f78 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs +++ b/src/libraries/System.Net.WebSockets.Client/tests/LoopbackHelper.cs @@ -11,7 +11,7 @@ namespace System.Net.WebSockets.Client.Tests { public static class LoopbackHelper { - public static async Task> WebSocketHandshakeAsync(LoopbackServer.Connection connection) + public static async Task> WebSocketHandshakeAsync(LoopbackServer.Connection connection, string? extensions = null) { string serverResponse = null; List headers = await connection.ReadRequestHeaderAsync().ConfigureAwait(false); @@ -34,6 +34,7 @@ public static async Task> WebSocketHandshakeAsync(Loo "Content-Length: 0\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + + (extensions is null ? null : $"Sec-WebSocket-Extensions: {extensions}\r\n") + "Sec-WebSocket-Accept: " + responseSecurityAcceptValue + "\r\n\r\n"; } } diff --git a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj index 248546468d0fa..adb1c3e447fb0 100644 --- a/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj +++ b/src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj @@ -38,6 +38,7 @@ + diff --git a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs index ca1e25ea21289..7b63abe187071 100644 --- a/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs +++ b/src/libraries/System.Net.WebSockets/ref/System.Net.WebSockets.cs @@ -31,6 +31,8 @@ protected WebSocket() { } public static System.Net.WebSockets.WebSocket CreateClientWebSocket(System.IO.Stream innerStream, string? subProtocol, int receiveBufferSize, int sendBufferSize, System.TimeSpan keepAliveInterval, bool useZeroMaskingKey, System.ArraySegment internalBuffer) { throw null; } [System.Runtime.Versioning.UnsupportedOSPlatform("browser")] public static System.Net.WebSockets.WebSocket CreateFromStream(System.IO.Stream stream, bool isServer, string? subProtocol, System.TimeSpan keepAliveInterval) { throw null; } + [System.Runtime.Versioning.UnsupportedOSPlatform("browser")] + public static System.Net.WebSockets.WebSocket CreateFromStream(System.IO.Stream stream, System.Net.WebSockets.WebSocketCreationOptions options) { throw null; } public static System.ArraySegment CreateServerBuffer(int receiveBufferSize) { throw null; } public abstract void Dispose(); [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)] @@ -133,4 +135,18 @@ public enum WebSocketState Closed = 5, Aborted = 6, } + public sealed partial class WebSocketCreationOptions + { + public bool IsServer { get { throw null; } set { } } + public string? SubProtocol { get { throw null; } set { } } + public System.TimeSpan KeepAliveInterval { get { throw null; } set { } } + public System.Net.WebSockets.WebSocketDeflateOptions? DeflateOptions { get { throw null; } set { } } + } + public sealed partial class WebSocketDeflateOptions + { + public int ClientMaxWindowBits { get { throw null; } set { } } + public bool ClientContextTakeover { get { throw null; } set { } } + public int ServerMaxWindowBits { get { throw null; } set { } } + public bool ServerContextTakeover { get { throw null; } set { } } + } } diff --git a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx index a4f630ea24c03..22cf53cd585a6 100644 --- a/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx +++ b/src/libraries/System.Net.WebSockets/src/Resources/Strings.resx @@ -1,4 +1,64 @@ - + + + @@ -138,4 +198,28 @@ The base stream is not writeable. - + + The argument must be a value between {0} and {1}. + + + The underlying compression routine could not be loaded correctly. + + + The underlying compression routine could not reserve sufficient memory. + + + The underlying compression routine returned an unexpected error code {0}. + + + The message was compressed using an unsupported compression method. + + + The WebSocket received a continuation frame with Per-Message Compressed flag set. + + + The stream state of the underlying compression routine is inconsistent. + + + The WebSocket received compressed frame when compression is not enabled. + + \ No newline at end of file diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index d65e6c55737af..64aaea87d46f0 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -1,22 +1,44 @@ - + True - $(NetCoreAppCurrent) + $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser enable + + + + + + + + + + + + + + + + + diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs new file mode 100644 index 0000000000000..d10782b8b4252 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -0,0 +1,194 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using static System.IO.Compression.ZLibNative; + +namespace System.Net.WebSockets.Compression +{ + /// + /// Provides a wrapper around the ZLib compression API. + /// + internal sealed class WebSocketDeflater : IDisposable + { + private ZLibStreamHandle? _stream; + private readonly int _windowBits; + private readonly bool _persisted; + + internal WebSocketDeflater(int windowBits, bool persisted) + { + Debug.Assert(windowBits >= 9 && windowBits <= 15); + + // We use negative window bits in order to produce raw deflate data + _windowBits = -windowBits; + _persisted = persisted; + } + + public void Dispose() => _stream?.Dispose(); + + public void Deflate(ReadOnlySpan payload, IBufferWriter output, bool continuation, bool endOfMessage) + { + Debug.Assert(!continuation || _stream is not null, "Invalid state. The stream should not be null in continuations."); + + if (_stream is null) + { + Initialize(); + } + while (!payload.IsEmpty) + { + Deflate(payload, output.GetSpan(payload.Length), out int consumed, out int written); + output.Advance(written); + + payload = payload[consumed..]; + } + + // There is a catch here. If the payload we're trying to compress isn't really compressable + // then the resulting output will be larger. And in this case although we might have processed the input + // more output might be available. The only way to check for this scenario is to do another deflate + // attempt, but without any input this time. + while (true) + { + Deflate(null, output.GetSpan(), out int consumed, out int written); + Debug.Assert(consumed == 0); + + if (written == 0) + { + break; + } + + output.Advance(written); + } + + // See comment by Mark Adler https://github.com/madler/zlib/issues/149#issuecomment-225237457 + // At that point there will be at most a few bits left to write. + // Then call deflate() with Z_FULL_FLUSH and no more input and at least six bytes of available output. + Span end = output.GetSpan(6); + int count = Flush(end); + + Debug.Assert(end[..count].EndsWith(WebSocketInflater.FlushMarker), "The deflated block must always end with a flush marker."); + + if (endOfMessage) + { + // As per RFC we need to remove the flush markers + count -= 4; + } + + output.Advance(count); + + if (endOfMessage && !_persisted) + { + _stream.Dispose(); + _stream = null; + } + } + + private unsafe void Deflate(ReadOnlySpan input, Span output, out int consumed, out int written) + { + Debug.Assert(_stream is not null); + + fixed (byte* fixedInput = input) + fixed (byte* fixedOutput = output) + { + _stream.NextIn = (IntPtr)fixedInput; + _stream.AvailIn = (uint)input.Length; + + _stream.NextOut = (IntPtr)fixedOutput; + _stream.AvailOut = (uint)output.Length; + + // If flush is set to Z_BLOCK, a deflate block is completed + // and emitted, as for Z_SYNC_FLUSH, but the output + // is not aligned on a byte boundary, and up to seven bits + // of the current block are held to be written as the next byte after + // the next deflate block is completed. + Deflate(_stream, (FlushCode)5/*Z_BLOCK*/); + + consumed = input.Length - (int)_stream.AvailIn; + written = output.Length - (int)_stream.AvailOut; + } + } + + private unsafe int Flush(Span output) + { + Debug.Assert(_stream is not null); + Debug.Assert(_stream.AvailIn == 0); + Debug.Assert(output.Length >= 6); + + fixed (byte* fixedOutput = output) + { + _stream.NextIn = IntPtr.Zero; + _stream.AvailIn = 0; + + _stream.NextOut = (IntPtr)fixedOutput; + _stream.AvailOut = (uint)output.Length; + + ErrorCode errorCode = Deflate(_stream, (FlushCode)3/*Z_FULL_FLUSH*/); + int writtenBytes = output.Length - (int)_stream.AvailOut; + + Debug.Assert(errorCode == ErrorCode.Ok); + + return writtenBytes; + } + } + + private static ErrorCode Deflate(ZLibStreamHandle stream, FlushCode flushCode) + { + ErrorCode errorCode; + try + { + errorCode = stream.Deflate(flushCode); + } + catch (Exception cause) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + + switch (errorCode) + { + case ErrorCode.Ok: + case ErrorCode.StreamEnd: + return errorCode; + + case ErrorCode.BufError: + return errorCode; // This is a recoverable error + + case ErrorCode.StreamError: + throw new WebSocketException(SR.ZLibErrorInconsistentStream); + + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); + } + } + + [MemberNotNull(nameof(_stream))] + private void Initialize() + { + Debug.Assert(_stream is null); + + var compressionLevel = CompressionLevel.DefaultCompression; + var memLevel = Deflate_DefaultMemLevel; + var strategy = CompressionStrategy.DefaultStrategy; + + ErrorCode errorCode; + try + { + errorCode = CreateZLibStreamForDeflate(out _stream, compressionLevel, _windowBits, memLevel, strategy); + } + catch (Exception cause) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + + switch (errorCode) + { + case ErrorCode.Ok: + return; + case ErrorCode.MemError: + throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs new file mode 100644 index 0000000000000..3c05af4ac2734 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -0,0 +1,229 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; +using static System.IO.Compression.ZLibNative; + +namespace System.Net.WebSockets.Compression +{ + /// + /// Provides a wrapper around the ZLib decompression API. + /// + internal sealed class WebSocketInflater : IDisposable + { + internal static ReadOnlySpan FlushMarker => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; + + private ZLibStreamHandle? _stream; + private readonly int _windowBits; + private readonly bool _persisted; + + /// + /// There is no way of knowing, when decoding data, if the underlying deflater + /// has flushed all outstanding data to consumer other than to provide a buffer + /// and see whether any bytes are written. There are cases when the consumers + /// provide a buffer exactly the size of the uncompressed data and in this case + /// to avoid requiring another read we will use this field. + /// + private byte? _remainingByte; + + /// + /// When the inflater is persisted we need to manually append the flush marker + /// before finishing the decoding. + /// + private bool _needsFlushMarker; + + internal WebSocketInflater(int windowBits, bool persisted) + { + Debug.Assert(windowBits >= 9 && windowBits <= 15); + + // We use negative window bits to instruct deflater to expect raw deflate data + _windowBits = -windowBits; + _persisted = persisted; + } + + public void Dispose() => _stream?.Dispose(); + + public unsafe void Inflate(ReadOnlySpan input, Span output, out int consumed, out int written) + { + if (_stream is null) + { + Initialize(); + } + fixed (byte* fixedInput = &MemoryMarshal.GetReference(input)) + fixed (byte* fixedOutput = &MemoryMarshal.GetReference(output)) + { + _stream.NextIn = (IntPtr)fixedInput; + _stream.AvailIn = (uint)input.Length; + + _stream.NextOut = (IntPtr)fixedOutput; + _stream.AvailOut = (uint)output.Length; + + Inflate(_stream); + + consumed = input.Length - (int)_stream.AvailIn; + written = output.Length - (int)_stream.AvailOut; + } + + _needsFlushMarker = _persisted; + } + + /// + /// Finishes the decoding by writing any outstanding data to the output. + /// + /// true if the finish completed, false to indicate that there is more outstanding data. + public bool Finish(Span output, out int written) + { + Debug.Assert(_stream is not null); + + if (_needsFlushMarker) + { + Inflate(FlushMarker, output, out var _, out written); + _needsFlushMarker = false; + + if ( written < output.Length || IsFinished(_stream, out _remainingByte) ) + { + OnFinished(); + return true; + } + } + + written = 0; + + if (output.IsEmpty) + { + if (_remainingByte is not null) + { + return false; + } + if (IsFinished(_stream, out _remainingByte)) + { + OnFinished(); + return true; + } + } + else + { + if (_remainingByte is not null) + { + output[0] = _remainingByte.GetValueOrDefault(); + written = 1; + _remainingByte = null; + } + + written += Inflate(_stream, output[written..]); + + if (written < output.Length || IsFinished(_stream, out _remainingByte)) + { + OnFinished(); + return true; + } + } + + return false; + } + + private void OnFinished() + { + Debug.Assert(_stream is not null); + + if (!_persisted) + { + _stream.Dispose(); + _stream = null; + } + } + + private static unsafe bool IsFinished(ZLibStreamHandle stream, out byte? remainingByte) + { + if (stream.AvailIn > 0) + { + remainingByte = null; + return false; + } + + // There is no other way to make sure that we'e consumed all data + // but to try to inflate again with at least one byte of output buffer. + byte b; + if (Inflate(stream, new Span(&b, 1)) == 0) + { + remainingByte = null; + return true; + } + + remainingByte = b; + return false; + } + + private static unsafe int Inflate(ZLibStreamHandle stream, Span destination) + { + fixed (byte* bufPtr = &MemoryMarshal.GetReference(destination)) + { + stream.NextOut = (IntPtr)bufPtr; + stream.AvailOut = (uint)destination.Length; + + Inflate(stream); + return destination.Length - (int)stream.AvailOut; + } + } + + private static void Inflate(ZLibStreamHandle stream) + { + ErrorCode errorCode; + try + { + errorCode = stream.Inflate(FlushCode.NoFlush); + } + catch (Exception cause) // could not load the Zlib DLL correctly + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + switch (errorCode) + { + case ErrorCode.Ok: // progress has been made inflating + case ErrorCode.StreamEnd: // The end of the input stream has been reached + case ErrorCode.BufError: // No room in the output buffer - inflate() can be called again with more space to continue + break; + + case ErrorCode.MemError: // Not enough memory to complete the operation + throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); + + case ErrorCode.DataError: // The input data was corrupted (input stream not conforming to the zlib format or incorrect check value) + throw new WebSocketException(SR.UnsupportedCompression); + + case ErrorCode.StreamError: //the stream structure was inconsistent (for example if next_in or next_out was NULL), + throw new WebSocketException(SR.ZLibErrorInconsistentStream); + + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)errorCode)); + } + } + + [MemberNotNull(nameof(_stream))] + private void Initialize() + { + Debug.Assert(_stream is null); + + ErrorCode error; + try + { + error = CreateZLibStreamForInflate(out _stream, _windowBits); + } + catch (Exception exception) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception); + } + + switch (error) + { + case ErrorCode.Ok: + return; + case ErrorCode.MemError: + throw new WebSocketException(SR.ZLibErrorNotEnoughMemory); + default: + throw new WebSocketException(string.Format(SR.ZLibErrorUnexpected, (int)error)); + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs new file mode 100644 index 0000000000000..8fab03fd081f6 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Receiver.cs @@ -0,0 +1,488 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using System.IO; +using System.Net.WebSockets.Compression; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets +{ + internal partial class ManagedWebSocket + { + private enum ReceiveResultType + { + Message, + ConnectionClose, + ControlMessage, + HeaderError + } + + [StructLayout(LayoutKind.Auto)] + private readonly struct ReceiveResult + { + public int Count { get; init; } + public bool EndOfMessage { get; init; } + public ReceiveResultType ResultType { get; init; } + public WebSocketMessageType MessageType { get; init; } + } + + private sealed class Receiver : IDisposable + { + private readonly bool _isServer; + private readonly Stream _stream; + private readonly WebSocketInflater? _inflater; + + /// + /// Because a message might be split into multiple fragments we need to + /// keep in mind that even if we've processed all of them, we need to call + /// finish on the decoder to flush any left over data. + /// + private bool _inflateFinished = true; + + /// + /// If we have compression we cannot use the buffer provided from clients because + /// we cannot guarantee that the decoding can happen in place. This buffer is rent'ed + /// and returned when consumed. + /// + private Memory _inflateBuffer; + + /// + /// The last header received in a ReceiveAsync. If ReceiveAsync got a header but then + /// returned fewer bytes than was indicated in the header, subsequent ReceiveAsync calls + /// will use the data from the header to construct the subsequent receive results, and + /// the payload length in this header will be decremented to indicate the number of bytes + /// remaining to be received for that header. As a result, between fragments, the payload + /// length in this header should be 0. + /// + private MessageHeader _lastHeader = new() { Opcode = MessageOpcode.Text, Fin = true }; + + /// + /// Buffer used for reading data from the network. + /// Not readonly here because the buffer is mutable and is a struct. + /// + private Buffer _readBuffer; + + /// + /// When dealing with partially read fragments of binary/text messages, a mask previously received may still + /// apply, and the first new byte received may not correspond to the 0th position in the mask. This value is + /// the next offset into the mask that should be applied. + /// + private int _receivedMaskOffset; + + /// + /// When parsing message header if an error occurs the websocket is notified and this + /// will contain the error message. + /// + private string? _headerError; + + public Receiver(Stream stream, WebSocketCreationOptions options) + { + _stream = stream; + _isServer = options.IsServer; + + // Create a buffer just large enough to handle received packet headers (at most 14 bytes) and + // control payloads (at most 125 bytes). Message payloads are read directly into the buffer + // supplied to ReceiveAsync. + _readBuffer = new Buffer(MaxControlPayloadLength + MaxMessageHeaderLength); + + var deflate = options.DeflateOptions; + + if (deflate is not null) + { + // Important note here is that we must use negative window bits + // which will instruct the underlying implementation to not expect deflate headers + _inflater = options.IsServer ? + new WebSocketInflater(deflate.ServerMaxWindowBits, deflate.ServerContextTakeover) : + new WebSocketInflater(deflate.ClientMaxWindowBits, deflate.ClientContextTakeover); + } + } + + public void Dispose() + { + _inflater?.Dispose(); + ReturnInflateBuffer(); + } + + public string? GetHeaderError() => _headerError; + + /// Issues a read on the stream to wait for EOF. + public async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken cancellationToken) + { + if (_readBuffer.FreeLength == 0) + { + // Because we are going to need only 1 byte buffer, do a discard + // only when necessary (avoiding needless copying). + _readBuffer.DiscardConsumed(); + } + // Per RFC 6455 7.1.1, try to let the server close the connection. We give it up to a second. + // We simply issue a read and don't care what we get back; we could validate that we don't get + // additional data, but at this point we're about to close the connection and we're just stalling + // to try to get the server to close first. + ValueTask finalReadTask = _stream.ReadAsync(_readBuffer.FreeMemory.Slice(start: 0, length: 1), cancellationToken); + + if (!finalReadTask.IsCompletedSuccessfully) + { + // Wait an arbitrary amount of time to give the server (same as netfx, 1 second) + using var cts = finalReadTask.IsCompleted ? null : new CancellationTokenSource(TimeSpan.FromSeconds(1)); + using var cancellation = cts is not null ? cts.Token.UnsafeRegister(static s => ((Stream)s!).Dispose(), _stream) : default; + + // TODO: Once this is merged https://github.com/dotnet/runtime/issues/47525 + // use configure await with the option to suppress exceptions and remove the try catch + try + { + await finalReadTask.ConfigureAwait(false); + } + catch + { + // Eat any resulting exceptions. We were going to close the connection, anyway. + } + } + } + + public async ValueTask ReceiveControlMessageAsync(CancellationToken cancellationToken) + { + Debug.Assert(_lastHeader.Opcode > MessageOpcode.Binary); + + if (_lastHeader.PayloadLength == 0) + { + return new ControlMessage(_lastHeader.Opcode, ReadOnlyMemory.Empty); + } + _readBuffer.DiscardConsumed(); + + while (_lastHeader.PayloadLength > _readBuffer.AvailableLength) + { + int byteCount = await _stream.ReadAsync(_readBuffer.FreeMemory, cancellationToken).ConfigureAwait(false); + if (byteCount <= 0) + { + return null; + } + ApplyMask(_readBuffer.FreeMemory.Span.Slice(0, (int)Math.Min(_lastHeader.PayloadLength, byteCount))); + _readBuffer.Commit(byteCount); + } + + // Update the payload length in the header to indicate + // that we've received everything we need. + ReadOnlyMemory payload = _readBuffer.AvailableMemory.Slice(0, (int)_lastHeader.PayloadLength); + + _readBuffer.Consume(payload.Length); + _lastHeader.PayloadLength = 0; + + return new ControlMessage(_lastHeader.Opcode, payload); + } + + public async ValueTask ReceiveAsync(Memory output, CancellationToken cancellationToken) + { + // When there's nothing left over to receive, start a new + if (_lastHeader.PayloadLength == 0) + { + if (!_inflateFinished) + { + Debug.Assert(_inflater is not null); + _inflateFinished = _inflater.Finish(output.Span, out int written); + + return Result(written); + } + + _readBuffer.DiscardConsumed(); + + if (!await ReceiveHeaderAsync(cancellationToken).ConfigureAwait(false)) + { + return Result(_headerError is not null ? ReceiveResultType.HeaderError : ReceiveResultType.ConnectionClose); + } + if (_lastHeader.Opcode > MessageOpcode.Binary) + { + // The received message is a control message and it's up + // to the websocket how to handle it. + return Result(ReceiveResultType.ControlMessage); + } + } + + if (output.IsEmpty) + { + return Result(count: 0); + } + // The number of bytes that are written to the output buffer + int outputByteCount = 0; + + if (_readBuffer.AvailableLength > 0) + { + if (!ConsumeReadBuffer(output.Span, out int written)) + { + return Result(written); + } + outputByteCount += written; + output = output[written..]; + } + + // At this point we should have consumed everything from the read buffer + // and should start issuing reads on the stream. + Debug.Assert(_readBuffer.AvailableLength == 0 && _lastHeader.PayloadLength > 0); + + int receivedByteCount = _lastHeader.Compressed ? + await ReceiveCompressedAsync(output, cancellationToken).ConfigureAwait(false) : + await ReceiveUncompressedAsync(output, cancellationToken).ConfigureAwait(false); + + if (receivedByteCount == 0) + { + return Result(ReceiveResultType.ConnectionClose); + } + + return Result(outputByteCount + receivedByteCount); + } + + private async ValueTask ReceiveUncompressedAsync(Memory output, CancellationToken cancellationToken) + { + Debug.Assert(!_lastHeader.Compressed); + + if (output.Length > _lastHeader.PayloadLength) + { + // We don't want to receive more than we need + output = output.Slice(0, (int)_lastHeader.PayloadLength); + } + + int bytesRead = await _stream.ReadAsync(output, cancellationToken).ConfigureAwait(false); + if (bytesRead > 0) + { + _lastHeader.PayloadLength -= bytesRead; + ApplyMask(output.Span.Slice(0, bytesRead)); + } + + return bytesRead; + } + + private async ValueTask ReceiveCompressedAsync(Memory output, CancellationToken cancellationToken) + { + Debug.Assert(_lastHeader.Compressed); + Debug.Assert(_inflater is not null); + + if (_inflateBuffer.IsEmpty) + { + if (!await LoadInflateBufferAsync(cancellationToken).ConfigureAwait(false)) + { + return 0; + } + } + + _inflater.Inflate(_inflateBuffer.Span, output.Span, out int consumed, out int outputByteCount); + _lastHeader.PayloadLength -= consumed; + _inflateBuffer = _inflateBuffer.Slice(consumed); + + if (_inflateBuffer.IsEmpty) + { + ReturnInflateBuffer(); + + if (_lastHeader.PayloadLength == 0 && _lastHeader.Fin) + { + _inflateFinished = _inflater.Finish(output.Span.Slice(outputByteCount), out var written); + outputByteCount += written; + } + } + + return outputByteCount; + } + + private async ValueTask ReceiveHeaderAsync(CancellationToken cancellationToken) + { + Debug.Assert(_lastHeader.PayloadLength == 0); + + _receivedMaskOffset = 0; + + while (true) + { + if (TryParseMessageHeader(_readBuffer.AvailableSpan, _lastHeader, _isServer, + out MessageHeader header, out string? error, out int consumedBytes)) + { + if (header.Compressed && _inflater is null) + { + _headerError = SR.net_Websockets_PerMessageCompressedFlagWhenNotEnabled; + return false; + } + + // If this is a continuation, replace the opcode with the one of the message it's continuing + if (header.Opcode == MessageOpcode.Continuation) + { + header.Opcode = _lastHeader.Opcode; + header.Compressed = _lastHeader.Compressed; + } + + _lastHeader = header; + _readBuffer.Consume(consumedBytes); + + if (_isServer) + { + // Unmask any payload that we've received + if (header.PayloadLength > 0 && _readBuffer.AvailableLength > 0) + { + ApplyMask(_readBuffer.AvailableSpan.Slice(0, (int)Math.Min(_readBuffer.AvailableLength, header.PayloadLength))); + } + } + + break; + } + else if (error is not null) + { + _headerError = error; + return false; + } + + // More data is neeed to parse the header + int byteCount = await _stream.ReadAsync(_readBuffer.FreeMemory, cancellationToken).ConfigureAwait(false); + if (byteCount <= 0) + { + return false; + } + _readBuffer.Commit(byteCount); + } + + return true; + } + + private ReceiveResult Result(int count) => new ReceiveResult + { + Count = count, + ResultType = ReceiveResultType.Message, + MessageType = _lastHeader.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, + EndOfMessage = _lastHeader.Fin && _lastHeader.PayloadLength == 0 && _inflateFinished + }; + + private ReceiveResult Result(ReceiveResultType resultType) => new ReceiveResult + { + ResultType = resultType + }; + + private void ApplyMask(Span input) + { + if (_isServer) + { + _receivedMaskOffset = ManagedWebSocket.ApplyMask(input, _lastHeader.Mask, _receivedMaskOffset); + } + } + + /// + /// Tries to consume anything remaining in _readBuffer for the current message.s + /// + /// + /// True when the read buffer is consumed and there's more to be processed, + /// and the output buffer is not full. + /// + private bool ConsumeReadBuffer(Span output, out int outputByteCount) + { + Debug.Assert(_readBuffer.AvailableLength > 0); + + int consumed, written; + int available = (int)Math.Min(_readBuffer.AvailableLength, _lastHeader.PayloadLength); + + if (_lastHeader.Compressed) + { + Debug.Assert(_inflater is not null); + _inflater.Inflate(input: _readBuffer.AvailableSpan.Slice(0, available), + output, out consumed, out written); + } + else + { + // We can copy directly to output + written = Math.Min(available, output.Length); + consumed = written; + _readBuffer.AvailableSpan.Slice(0, written).CopyTo(output); + } + + _readBuffer.Consume(consumed); + _lastHeader.PayloadLength -= consumed; + + outputByteCount = written; + + if (_lastHeader.PayloadLength == 0 || output.Length == written) + { + // We have either consumed everything or the output is full. + // In this case we try to finish inflating if needed and return. + if (_inflater is not null && _lastHeader.Compressed && _lastHeader.PayloadLength == 0 && _lastHeader.Fin) + { + _inflateFinished = _inflater.Finish(output.Slice(written), out written); + outputByteCount += written; + } + + return false; + } + + return true; + } + + private async ValueTask LoadInflateBufferAsync(CancellationToken cancellationToken) + { + // Rent a buffer but restrict it's max size to 1MB + int decoderBufferLength = (int)Math.Min(_lastHeader.PayloadLength, 1_000_000); + + _inflateBuffer = ArrayPool.Shared.Rent(decoderBufferLength); + int byteCount = await _stream.ReadAsync(_inflateBuffer, cancellationToken).ConfigureAwait(false); + + if (byteCount <= 0) + { + ReturnInflateBuffer(); + return false; + } + + _inflateBuffer = _inflateBuffer.Slice(0, byteCount); + ApplyMask(_inflateBuffer.Span); + + return true; + } + + private void ReturnInflateBuffer() + { + if (MemoryMarshal.TryGetArray(_inflateBuffer, out ArraySegment arraySegment) + && arraySegment.Array!.Length > 0) + { + Debug.Assert(arraySegment.Array is not null); + ArrayPool.Shared.Return(arraySegment.Array); + _inflateBuffer = null; + } + } + + [StructLayout(LayoutKind.Auto)] + private struct Buffer + { + private readonly byte[] _bytes; + private int _position; + private int _consumed; + + public Buffer(int capacity) + { + _bytes = GC.AllocateUninitializedArray(capacity, pinned: true); + _position = 0; + _consumed = 0; + } + + public int AvailableLength => _position - _consumed; + + public Span AvailableSpan => + new Span(_bytes, start: _consumed, length: _position - _consumed); + + public Memory AvailableMemory => + new Memory(_bytes, start: _consumed, length: _position - _consumed); + + public Memory FreeMemory => _bytes.AsMemory(_position); + + public int FreeLength => _bytes.Length - _position; + + public void Commit(int count) => _position += count; + + public void Consume(int count) => _consumed += count; + + public void DiscardConsumed() + { + if (AvailableLength > 0) + { + AvailableMemory.CopyTo(_bytes); + } + + _position -= _consumed; + _consumed = 0; + } + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs new file mode 100644 index 0000000000000..fab2a0c544d76 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.Sender.cs @@ -0,0 +1,267 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Net.WebSockets.Compression; +using System.Security.Cryptography; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets +{ + internal partial class ManagedWebSocket + { + private sealed class Sender : IDisposable + { + private readonly int _maskLength; + private readonly WebSocketDeflater? _deflater; + private readonly Stream _stream; + + private readonly Buffer _buffer = new(); + + public Sender(Stream stream, WebSocketCreationOptions options) + { + _maskLength = options.IsServer ? 0 : MaskLength; + _stream = stream; + + var deflate = options.DeflateOptions; + + if (deflate is not null) + { + // If we are the server we must use the client options + _deflater = options.IsServer ? + new WebSocketDeflater(deflate.ClientMaxWindowBits, deflate.ClientContextTakeover) : + new WebSocketDeflater(deflate.ServerMaxWindowBits, deflate.ServerContextTakeover); + } + } + + public void Dispose() => _deflater?.Dispose(); + + public ValueTask SendAsync(MessageOpcode opcode, bool endOfMessage, ReadOnlyMemory content, CancellationToken cancellationToken = default) + { + bool compressed = false; + + // Compression is only supported for user messages + if (_deflater is not null && opcode <= MessageOpcode.Binary) + { + _buffer.EnsureFreeCapacity(MaxMessageHeaderLength + (int)(content.Length * 0.6)); + _buffer.Advance(MaxMessageHeaderLength); + + _deflater.Deflate(content.Span, _buffer, continuation: opcode == MessageOpcode.Continuation, endOfMessage); + compressed = true; + } + else if (!content.IsEmpty) + { + _buffer.EnsureFreeCapacity(MaxMessageHeaderLength + content.Length); + _buffer.Advance(MaxMessageHeaderLength); + + content.Span.CopyTo(_buffer.GetSpan(content.Length)); + _buffer.Advance(content.Length); + } + else + { + _buffer.EnsureFreeCapacity(MaxMessageHeaderLength); + _buffer.Advance(MaxMessageHeaderLength); + } + + Span payload = _buffer.WrittenSpan.Slice(MaxMessageHeaderLength); + int headerLength = CalculateHeaderLength(payload.Length); + + // Because we want the header to come just before to the payload + // we will use a slice that offsets the unused part. + int headerOffset = MaxMessageHeaderLength - headerLength; + Span header = _buffer.WrittenSpan.Slice(headerOffset, headerLength); + + // Write the message header data to the buffer. + EncodeHeader(header, opcode, endOfMessage, payload.Length, compressed); + + // If we added a mask to the header, XOR the payload with the mask. + if (!payload.IsEmpty && _maskLength > 0) + { + ApplyMask(payload, BitConverter.ToInt32(header.Slice(header.Length - MaskLength)), 0); + } + + bool resetBuffer = true; + + try + { + ValueTask sendTask = _stream.WriteAsync(_buffer.WrittenMemory.Slice(headerOffset), cancellationToken); + + if (sendTask.IsCompleted) + { + return sendTask; + } + resetBuffer = false; + return WaitAsync(sendTask); + } + finally + { + if (resetBuffer) + { + _buffer.Reset(); + } + } + } + + private async ValueTask WaitAsync(ValueTask sendTask) + { + try + { + await sendTask.ConfigureAwait(false); + } + finally + { + _buffer.Reset(); + } + } + + private int CalculateHeaderLength(int payloadLength) => _maskLength + (payloadLength switch + { + <= 125 => 2, + <= ushort.MaxValue => 4, + _ => 10 + }); + + private void EncodeHeader(Span header, MessageOpcode opcode, bool endOfMessage, int payloadLength, bool compressed) + { + // Client header format: + // 1 bit - FIN - 1 if this is the final fragment in the message (it could be the only fragment), otherwise 0 + // 1 bit - RSV1 - Per-Message Deflate Compress + // 1 bit - RSV2 - Reserved - 0 + // 1 bit - RSV3 - Reserved - 0 + // 4 bits - Opcode - How to interpret the payload + // - 0x0 - continuation + // - 0x1 - text + // - 0x2 - binary + // - 0x8 - connection close + // - 0x9 - ping + // - 0xA - pong + // - (0x3 to 0x7, 0xB-0xF - reserved) + // 1 bit - Masked - 1 if the payload is masked, 0 if it's not. Must be 1 for the client + // 7 bits, 7+16 bits, or 7+64 bits - Payload length + // - For length 0 through 125, 7 bits storing the length + // - For lengths 126 through 2^16, 7 bits storing the value 126, followed by 16 bits storing the length + // - For lengths 2^16+1 through 2^64, 7 bits storing the value 127, followed by 64 bytes storing the length + // 0 or 4 bytes - Mask, if Masked is 1 - random value XOR'd with each 4 bytes of the payload, round-robin + // Length bytes - Payload data + header[0] = (byte)opcode; // 4 bits for the opcode + + if (compressed && opcode != MessageOpcode.Continuation) + { + header[0] |= 0b0100_0000; + } + + if (endOfMessage) + { + header[0] |= 0b1000_0000; // 1 bit for FIN + } + + // Store the payload length. + if (payloadLength <= 125) + { + header[1] = (byte)payloadLength; + } + else if (payloadLength <= ushort.MaxValue) + { + header[1] = 126; + header[2] = (byte)(payloadLength / 256); + header[3] = unchecked((byte)payloadLength); + } + else + { + header[1] = 127; + for (int i = 9; i >= 2; i--) + { + header[i] = unchecked((byte)payloadLength); + payloadLength = payloadLength / 256; + } + } + + if (_maskLength > 0) + { + // Generate the mask. + header[1] |= 0x80; + RandomNumberGenerator.Fill(header.Slice(header.Length - MaskLength)); + } + } + + /// + /// Helper class which allows writing to a rent'ed byte array + /// and auto-grow functionality. + /// + private sealed class Buffer : IBufferWriter + { + private readonly ArrayPool _arrayPool; + + private byte[]? _array; + private int _index; + + public Buffer() + { + _arrayPool = ArrayPool.Shared; + } + + public Span WrittenSpan => _array.AsSpan(0, _index); + + public ReadOnlyMemory WrittenMemory => new ReadOnlyMemory(_array, 0, _index); + + public void Advance(int count) + { + Debug.Assert(_array is not null); + Debug.Assert(count >= 0); + Debug.Assert(_index + count <= _array.Length); + + _index += count; + } + + public Memory GetMemory(int sizeHint = 0) + { + EnsureFreeCapacity(sizeHint); + return _array.AsMemory(_index); + } + + public Span GetSpan(int sizeHint = 0) + { + EnsureFreeCapacity(sizeHint); + return _array.AsSpan(_index); + } + + public void Reset() + { + if (_array is not null) + { + _arrayPool.Return(_array); + _array = null; + _index = 0; + } + } + + [MemberNotNull(nameof(_array))] + public void EnsureFreeCapacity(int sizeHint) + { + if (sizeHint == 0) + { + sizeHint = 1; + } + if (_array is null) + { + _array = _arrayPool.Rent(sizeHint); + return; + } + + if (sizeHint > (_array.Length - _index)) + { + byte[] newArray = _arrayPool.Rent(_array.Length + sizeHint); + _array.AsSpan().CopyTo(newArray); + + _arrayPool.Return(_array); + _array = newArray; + } + } + } + } + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs index bf4888a57b057..28d1c5df2d56c 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -8,7 +8,6 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Versioning; -using System.Security.Cryptography; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -29,18 +28,13 @@ internal sealed partial class ManagedWebSocket : WebSocket { /// Creates a from a connected to a websocket endpoint. /// The connected Stream. - /// true if this is the server-side of the connection; false if this is the client-side of the connection. - /// The agreed upon subprotocol for the connection. - /// The interval to use for keep-alive pings. + /// The options with which the websocket must be created. /// The created instance. - public static ManagedWebSocket CreateFromConnectedStream( - Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval) + public static ManagedWebSocket CreateFromConnectedStream(Stream stream, WebSocketCreationOptions options) { - return new ManagedWebSocket(stream, isServer, subprotocol, keepAliveInterval); + return new ManagedWebSocket(stream, options); } - /// Thread-safe random number generator used to generate masks for each send. - private static readonly RandomNumberGenerator s_random = RandomNumberGenerator.Create(); /// Encoding for the payload of text messages: UTF8 encoding that throws if invalid bytes are discovered, per the RFC. private static readonly UTF8Encoding s_textEncoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true); @@ -76,18 +70,13 @@ public static ManagedWebSocket CreateFromConnectedStream( private readonly string? _subprotocol; /// Timer used to send periodic pings to the server, at the interval specified private readonly Timer? _keepAliveTimer; - /// CancellationTokenSource used to abort all current and future operations when anything is canceled or any error occurs. - private readonly CancellationTokenSource _abortSource = new CancellationTokenSource(); - /// Buffer used for reading data from the network. - private readonly Memory _receiveBuffer; - /// - /// Tracks the state of the validity of the UTF8 encoding of text payloads. Text may be split across fragments. - /// - private readonly Utf8MessageState _utf8TextState = new Utf8MessageState(); + /// /// Semaphore used to ensure that calls to SendFrameAsync don't run concurrently. /// private readonly SemaphoreSlim _sendFrameAsyncLock = new SemaphoreSlim(1, 1); + private readonly Sender _sender; + private readonly Receiver _receiver; // We maintain the current WebSocketState in _state. However, we separately maintain _sentCloseFrame and _receivedCloseFrame // as there isn't a strict ordering between CloseSent and CloseReceived. If we receive a close frame from the server, we need to @@ -108,31 +97,10 @@ public static ManagedWebSocket CreateFromConnectedStream( private string? _closeStatusDescription; /// - /// The last header received in a ReceiveAsync. If ReceiveAsync got a header but then - /// returned fewer bytes than was indicated in the header, subsequent ReceiveAsync calls - /// will use the data from the header to construct the subsequent receive results, and - /// the payload length in this header will be decremented to indicate the number of bytes - /// remaining to be received for that header. As a result, between fragments, the payload - /// length in this header should be 0. - /// - private MessageHeader _lastReceiveHeader = new MessageHeader { Opcode = MessageOpcode.Text, Fin = true }; - /// The offset of the next available byte in the _receiveBuffer. - private int _receiveBufferOffset; - /// The number of bytes available in the _receiveBuffer. - private int _receiveBufferCount; - /// - /// When dealing with partially read fragments of binary/text messages, a mask previously received may still - /// apply, and the first new byte received may not correspond to the 0th position in the mask. This value is - /// the next offset into the mask that should be applied. - /// - private int _receivedMaskOffsetOffset; - /// - /// Temporary send buffer. This should be released back to the ArrayPool once it's - /// no longer needed for the current send operation. It is stored as an instance - /// field to minimize needing to pass it around and to avoid it becoming a field on - /// various async state machine objects. + /// Tracks the state of the validity of the UTF8 encoding of text payloads. Text may be split across fragments. /// - private byte[]? _sendBuffer; + private readonly Utf8MessageState _utf8TextState = new Utf8MessageState(); + /// /// Whether the last SendAsync had endOfMessage==false. We need to track this so that we /// can send the subsequent message with a continuation opcode if the last message was a fragment. @@ -145,21 +113,19 @@ public static ManagedWebSocket CreateFromConnectedStream( private Task _lastReceiveAsync = Task.CompletedTask; /// Lock used to protect update and check-and-update operations on _state. - private object StateUpdateLock => _abortSource; + private object StateUpdateLock => _sender; /// /// We need to coordinate between receives and close operations happening concurrently, as a ReceiveAsync may /// be pending while a Close{Output}Async is issued, which itself needs to loop until a close frame is received. /// As such, we need thread-safety in the management of . /// - private object ReceiveAsyncLock => _utf8TextState; // some object, as we're simply lock'ing on it + private object ReceiveAsyncLock => _receiver; // some object, as we're simply lock'ing on it - /// Initializes the websocket. - /// The connected Stream. - /// true if this is the server-side of the connection; false if this is the client-side of the connection. - /// The agreed upon subprotocol for the connection. - /// The interval to use for keep-alive pings. - private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, TimeSpan keepAliveInterval) + private ManagedWebSocket(Stream stream, WebSocketCreationOptions options) { + _sender = new Sender(stream, options); + _receiver = new Receiver(stream, options); + Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null"); Debug.Assert(ReceiveAsyncLock != null, $"Expected {nameof(ReceiveAsyncLock)} to be non-null"); Debug.Assert(StateUpdateLock != ReceiveAsyncLock, "Locks should be different objects"); @@ -167,41 +133,15 @@ private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Time Debug.Assert(stream != null, $"Expected non-null stream"); Debug.Assert(stream.CanRead, $"Expected readable stream"); Debug.Assert(stream.CanWrite, $"Expected writeable stream"); - Debug.Assert(keepAliveInterval == Timeout.InfiniteTimeSpan || keepAliveInterval >= TimeSpan.Zero, $"Invalid keepalive interval: {keepAliveInterval}"); _stream = stream; - _isServer = isServer; - _subprotocol = subprotocol; - - // Create a buffer just large enough to handle received packet headers (at most 14 bytes) and - // control payloads (at most 125 bytes). Message payloads are read directly into the buffer - // supplied to ReceiveAsync. - const int ReceiveBufferMinLength = MaxControlPayloadLength; - _receiveBuffer = new byte[ReceiveBufferMinLength]; - - // Set up the abort source so that if it's triggered, we transition the instance appropriately. - // There's no need to store the resulting CancellationTokenRegistration, as this instance owns - // the CancellationTokenSource, and the lifetime of that CTS matches the lifetime of the registration. - _abortSource.Token.UnsafeRegister(static s => - { - var thisRef = (ManagedWebSocket)s!; - - lock (thisRef.StateUpdateLock) - { - WebSocketState state = thisRef._state; - if (state != WebSocketState.Closed && state != WebSocketState.Aborted) - { - thisRef._state = state != WebSocketState.None && state != WebSocketState.Connecting ? - WebSocketState.Aborted : - WebSocketState.Closed; - } - } - }, this); + _isServer = options.IsServer; + _subprotocol = options.SubProtocol; // Now that we're opened, initiate the keep alive timer to send periodic pings. // We use a weak reference from the timer to the web socket to avoid a cycle // that could keep the web socket rooted in erroneous cases. - if (keepAliveInterval > TimeSpan.Zero) + if (options.KeepAliveInterval > TimeSpan.Zero) { _keepAliveTimer = new Timer(static s => { @@ -210,7 +150,7 @@ private ManagedWebSocket(Stream stream, bool isServer, string? subprotocol, Time { thisRef.SendKeepAliveFrameAsync(); } - }, new WeakReference(this), keepAliveInterval, keepAliveInterval); + }, new WeakReference(this), options.KeepAliveInterval, options.KeepAliveInterval); } } @@ -229,7 +169,10 @@ private void DisposeCore() { _disposed = true; _keepAliveTimer?.Dispose(); - _stream?.Dispose(); + _stream.Dispose(); + _sender.Dispose(); + _receiver.Dispose(); + if (_state < WebSocketState.Aborted) { _state = WebSocketState.Closed; @@ -247,14 +190,6 @@ private void DisposeCore() public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { - if (messageType != WebSocketMessageType.Text && messageType != WebSocketMessageType.Binary) - { - throw new ArgumentException(SR.Format( - SR.net_WebSockets_Argument_InvalidMessageType, - nameof(WebSocketMessageType.Close), nameof(SendAsync), nameof(WebSocketMessageType.Binary), nameof(WebSocketMessageType.Text), nameof(CloseOutputAsync)), - nameof(messageType)); - } - WebSocketValidate.ValidateArraySegment(buffer, nameof(buffer)); return SendPrivateAsync(buffer, messageType, endOfMessage, cancellationToken).AsTask(); @@ -353,10 +288,23 @@ private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, string public override void Abort() { - _abortSource.Cancel(); + OnAborted(); Dispose(); // forcibly tear down connection } + private void OnAborted() + { + lock (StateUpdateLock) + { + if (_state is not (WebSocketState.Closed or WebSocketState.Aborted)) + { + _state = _state is not (WebSocketState.None or WebSocketState.Connecting) ? + WebSocketState.Aborted : WebSocketState.Closed; + } + } + } + + public override ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { return SendPrivateAsync(buffer, messageType, endOfMessage, cancellationToken); @@ -395,14 +343,13 @@ public override ValueTask ReceiveAsync(Memory } } - private Task ValidateAndReceiveAsync(Task receiveTask, byte[] buffer, CancellationToken cancellationToken) + private Task ValidateAndReceiveAsync(Task receiveTask, CancellationToken cancellationToken) { - if (receiveTask == null || - (receiveTask.IsCompletedSuccessfully && - !(receiveTask is Task wsrr && wsrr.Result.MessageType == WebSocketMessageType.Close) && - !(receiveTask is Task vwsrr && vwsrr.Result.MessageType == WebSocketMessageType.Close))) + if (receiveTask.IsCompletedSuccessfully && + !(receiveTask is Task wsrr && wsrr.Result.MessageType == WebSocketMessageType.Close) && + !(receiveTask is Task vwsrr && vwsrr.Result.MessageType == WebSocketMessageType.Close)) { - ValueTask vt = ReceiveAsyncPrivate(buffer, cancellationToken); + ValueTask vt = ReceiveAsyncPrivate(Memory.Empty, cancellationToken); receiveTask = vt.IsCompletedSuccessfully ? (vt.Result.MessageType == WebSocketMessageType.Close ? s_cachedCloseTask : Task.CompletedTask) : vt.AsTask(); @@ -445,12 +392,10 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, // If we get here, the cancellation token is not cancelable so we don't have to worry about it, // and we own the semaphore, so we don't need to asynchronously wait for it. ValueTask writeTask = default; - bool releaseSendBufferAndSemaphore = true; + bool releaseSemaphore = true; try { - // Write the payload synchronously to the buffer, then write that buffer out to the network. - int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, payloadBuffer.Span); - writeTask = _stream.WriteAsync(new ReadOnlyMemory(_sendBuffer, 0, sendBytes)); + writeTask = _sender.SendAsync(opcode, endOfMessage, payloadBuffer); // If the operation happens to complete synchronously (or, more specifically, by // the time we get from the previous line to here), release the semaphore, return @@ -463,7 +408,7 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, // Up until this point, if an exception occurred (such as when accessing _stream or when // calling GetResult), we want to release the semaphore and the send buffer. After this point, // both need to be held until writeTask completes. - releaseSendBufferAndSemaphore = false; + releaseSemaphore = false; } catch (Exception exc) { @@ -474,9 +419,8 @@ private ValueTask SendFrameLockAcquiredNonCancelableAsync(MessageOpcode opcode, } finally { - if (releaseSendBufferAndSemaphore) + if (releaseSemaphore) { - ReleaseSendBuffer(); _sendFrameAsyncLock.Release(); } } @@ -498,7 +442,6 @@ private async ValueTask WaitForWriteTaskAsync(ValueTask writeTask) } finally { - ReleaseSendBuffer(); _sendFrameAsyncLock.Release(); } } @@ -508,10 +451,9 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM await _sendFrameAsyncLock.WaitAsync(cancellationToken).ConfigureAwait(false); try { - int sendBytes = WriteFrameToSendBuffer(opcode, endOfMessage, payloadBuffer.Span); using (cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this)) { - await _stream.WriteAsync(new ReadOnlyMemory(_sendBuffer, 0, sendBytes), cancellationToken).ConfigureAwait(false); + await _sender.SendAsync(opcode, endOfMessage, payloadBuffer, cancellationToken).ConfigureAwait(false); } } catch (Exception exc) when (!(exc is OperationCanceledException)) @@ -522,52 +464,10 @@ private async ValueTask SendFrameFallbackAsync(MessageOpcode opcode, bool endOfM } finally { - ReleaseSendBuffer(); _sendFrameAsyncLock.Release(); } } - /// Writes a frame into the send buffer, which can then be sent over the network. - private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, ReadOnlySpan payloadBuffer) - { - // Ensure we have a _sendBuffer. - AllocateSendBuffer(payloadBuffer.Length + MaxMessageHeaderLength); - Debug.Assert(_sendBuffer != null); - - // Write the message header data to the buffer. - int headerLength; - int? maskOffset = null; - if (_isServer) - { - // The server doesn't send a mask, so the mask offset returned by WriteHeader - // is actually the end of the header. - headerLength = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: false); - } - else - { - // We need to know where the mask starts so that we can use the mask to manipulate the payload data, - // and we need to know the total length for sending it on the wire. - maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true); - headerLength = maskOffset.GetValueOrDefault() + MaskLength; - } - - // Write the payload - if (payloadBuffer.Length > 0) - { - payloadBuffer.CopyTo(new Span(_sendBuffer, headerLength, payloadBuffer.Length)); - - // If we added a mask to the header, XOR the payload with the mask. We do the manipulation in the send buffer so as to avoid - // changing the data in the caller-supplied payload buffer. - if (maskOffset.HasValue) - { - ApplyMask(new Span(_sendBuffer, headerLength, payloadBuffer.Length), _sendBuffer, maskOffset.Value, 0); - } - } - - // Return the number of bytes in the send buffer - return headerLength + payloadBuffer.Length; - } - private void SendKeepAliveFrameAsync() { bool acquiredLock = _sendFrameAsyncLock.Wait(0); @@ -597,80 +497,6 @@ private void SendKeepAliveFrameAsync() } } - private static int WriteHeader(MessageOpcode opcode, byte[] sendBuffer, ReadOnlySpan payload, bool endOfMessage, bool useMask) - { - // Client header format: - // 1 bit - FIN - 1 if this is the final fragment in the message (it could be the only fragment), otherwise 0 - // 1 bit - RSV1 - Reserved - 0 - // 1 bit - RSV2 - Reserved - 0 - // 1 bit - RSV3 - Reserved - 0 - // 4 bits - Opcode - How to interpret the payload - // - 0x0 - continuation - // - 0x1 - text - // - 0x2 - binary - // - 0x8 - connection close - // - 0x9 - ping - // - 0xA - pong - // - (0x3 to 0x7, 0xB-0xF - reserved) - // 1 bit - Masked - 1 if the payload is masked, 0 if it's not. Must be 1 for the client - // 7 bits, 7+16 bits, or 7+64 bits - Payload length - // - For length 0 through 125, 7 bits storing the length - // - For lengths 126 through 2^16, 7 bits storing the value 126, followed by 16 bits storing the length - // - For lengths 2^16+1 through 2^64, 7 bits storing the value 127, followed by 64 bytes storing the length - // 0 or 4 bytes - Mask, if Masked is 1 - random value XOR'd with each 4 bytes of the payload, round-robin - // Length bytes - Payload data - - Debug.Assert(sendBuffer.Length >= MaxMessageHeaderLength, $"Expected sendBuffer to be at least {MaxMessageHeaderLength}, got {sendBuffer.Length}"); - - sendBuffer[0] = (byte)opcode; // 4 bits for the opcode - if (endOfMessage) - { - sendBuffer[0] |= 0x80; // 1 bit for FIN - } - - // Store the payload length. - int maskOffset; - if (payload.Length <= 125) - { - sendBuffer[1] = (byte)payload.Length; - maskOffset = 2; // no additional payload length - } - else if (payload.Length <= ushort.MaxValue) - { - sendBuffer[1] = 126; - sendBuffer[2] = (byte)(payload.Length / 256); - sendBuffer[3] = unchecked((byte)payload.Length); - maskOffset = 2 + sizeof(ushort); // additional 2 bytes for 16-bit length - } - else - { - sendBuffer[1] = 127; - int length = payload.Length; - for (int i = 9; i >= 2; i--) - { - sendBuffer[i] = unchecked((byte)length); - length = length / 256; - } - maskOffset = 2 + sizeof(ulong); // additional 8 bytes for 64-bit length - } - - if (useMask) - { - // Generate the mask. - sendBuffer[1] |= 0x80; - WriteRandomMask(sendBuffer, maskOffset); - } - - // Return the position of the mask. - return maskOffset; - } - - /// Writes a 4-byte random mask to the specified buffer at the specified offset. - /// The buffer to which to write the mask. - /// The offset into the buffer at which to write the mask. - private static void WriteRandomMask(byte[] buffer, int offset) => - s_random.GetBytes(buffer, offset, MaskLength); - /// /// Receive the next text, binary, continuation, or close message, returning information about it and /// writing its payload into the supplied buffer. Other control messages may be consumed and processed @@ -699,140 +525,73 @@ private async ValueTask ReceiveAsyncPrivate 125) + // If the header represents a ping or a pong, it's a control message meant + // to be transparent to the user, so handle it and then loop around to read again. + // Alternatively, if it's a close message, handle it and exit. + if (message.Opcode is MessageOpcode.Ping or MessageOpcode.Pong) { - int minNeeded = - 2 + - (_isServer ? MaskLength : 0) + - (payloadLength <= 125 ? 0 : payloadLength == 126 ? sizeof(ushort) : sizeof(ulong)); // additional 2 or 8 bytes for 16-bit or 64-bit length - await EnsureBufferContainsAsync(minNeeded, cancellationToken).ConfigureAwait(false); + // If this was a ping, send back a pong response. + if (message.Opcode == MessageOpcode.Ping) + { + await SendFrameAsync(MessageOpcode.Pong, endOfMessage: true, message.Payload, cancellationToken).ConfigureAwait(false); + } + continue; } - } + else + { + Debug.Assert(message.Opcode == MessageOpcode.Close); - string? headerErrorMessage = TryParseMessageHeaderFromReceiveBuffer(out header); - if (headerErrorMessage != null) - { - await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, headerErrorMessage).ConfigureAwait(false); + await HandleReceivedCloseAsync(message.Payload, cancellationToken).ConfigureAwait(false); + return resultGetter.GetResult(0, WebSocketMessageType.Close, true, _closeStatus, _closeStatusDescription); + } } - _receivedMaskOffsetOffset = 0; - } - - // If the header represents a ping or a pong, it's a control message meant - // to be transparent to the user, so handle it and then loop around to read again. - // Alternatively, if it's a close message, handle it and exit. - if (header.Opcode == MessageOpcode.Ping || header.Opcode == MessageOpcode.Pong) - { - await HandleReceivedPingPongAsync(header, cancellationToken).ConfigureAwait(false); - continue; - } - else if (header.Opcode == MessageOpcode.Close) - { - await HandleReceivedCloseAsync(header, cancellationToken).ConfigureAwait(false); - return resultGetter.GetResult(0, WebSocketMessageType.Close, true, _closeStatus, _closeStatusDescription); - } - - // If this is a continuation, replace the opcode with the one of the message it's continuing - if (header.Opcode == MessageOpcode.Continuation) - { - header.Opcode = _lastReceiveHeader.Opcode; - } - - // The message should now be a binary or text message. Handle it by reading the payload and returning the contents. - Debug.Assert(header.Opcode == MessageOpcode.Binary || header.Opcode == MessageOpcode.Text, $"Unexpected opcode {header.Opcode}"); - - // If there's no data to read, return an appropriate result. - if (header.PayloadLength == 0 || payloadBuffer.Length == 0) - { - _lastReceiveHeader = header; - return resultGetter.GetResult( - 0, - header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, - header.Fin && header.PayloadLength == 0, - null, null); - } - - // Otherwise, read as much of the payload as we can efficiently, and update the header to reflect how much data - // remains for future reads. We first need to copy any data that may be lingering in the receive buffer - // into the destination; then to minimize ReceiveAsync calls, we want to read as much as we can, stopping - // only when we've either read the whole message or when we've filled the payload buffer. - - // First copy any data lingering in the receive buffer. - int totalBytesReceived = 0; - if (_receiveBufferCount > 0) - { - int receiveBufferBytesToCopy = Math.Min(payloadBuffer.Length, (int)Math.Min(header.PayloadLength, _receiveBufferCount)); - Debug.Assert(receiveBufferBytesToCopy > 0); - _receiveBuffer.Span.Slice(_receiveBufferOffset, receiveBufferBytesToCopy).CopyTo(payloadBuffer.Span); - ConsumeFromBuffer(receiveBufferBytesToCopy); - totalBytesReceived += receiveBufferBytesToCopy; - Debug.Assert( - _receiveBufferCount == 0 || - totalBytesReceived == payloadBuffer.Length || - totalBytesReceived == header.PayloadLength); - } - - // Then read directly into the payload buffer until we've hit a limit. - while (totalBytesReceived < payloadBuffer.Length && - totalBytesReceived < header.PayloadLength) - { - int numBytesRead = await _stream.ReadAsync(payloadBuffer.Slice( - totalBytesReceived, - (int)Math.Min(payloadBuffer.Length, header.PayloadLength) - totalBytesReceived), cancellationToken).ConfigureAwait(false); - if (numBytesRead <= 0) + else if (result.ResultType == ReceiveResultType.ConnectionClose) { - ThrowIfEOFUnexpected(throwOnPrematureClosure: true); - break; + ThrowIfEOFUnexpected(true); } - totalBytesReceived += numBytesRead; - } + else + { + Debug.Assert(result.ResultType == ReceiveResultType.HeaderError); - if (_isServer) - { - _receivedMaskOffsetOffset = ApplyMask(payloadBuffer.Span.Slice(0, totalBytesReceived), header.Mask, _receivedMaskOffsetOffset); + string? error = _receiver.GetHeaderError(); + await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, error).ConfigureAwait(false); + } } - header.PayloadLength -= totalBytesReceived; // If this a text message, validate that it contains valid UTF8. - if (header.Opcode == MessageOpcode.Text && - !TryValidateUtf8(payloadBuffer.Span.Slice(0, totalBytesReceived), header.Fin && header.PayloadLength == 0, _utf8TextState)) + if (result.MessageType == WebSocketMessageType.Text && result.Count > 0 && + !TryValidateUtf8(payloadBuffer.Span.Slice(0, result.Count), result.EndOfMessage, _utf8TextState)) { await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.InvalidPayloadData, WebSocketError.Faulted).ConfigureAwait(false); } - _lastReceiveHeader = header; return resultGetter.GetResult( - totalBytesReceived, - header.Opcode == MessageOpcode.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, - header.Fin && header.PayloadLength == 0, - null, null); + count: result.Count, + messageType: result.MessageType, + endOfMessage: result.EndOfMessage, + closeStatus: null, closeDescription: null); } } - catch (Exception exc) when (!(exc is OperationCanceledException)) + catch (Exception exc) when (exc is not OperationCanceledException) { if (_state == WebSocketState.Aborted) { throw new OperationCanceledException(nameof(WebSocketState.Aborted), exc); } - _abortSource.Cancel(); + OnAborted(); if (exc is WebSocketException) { @@ -848,10 +607,7 @@ private async ValueTask ReceiveAsyncPrivateProcesses a received close message. - /// The message header. - /// The CancellationToken used to cancel the websocket operation. - /// The received result message. - private async ValueTask HandleReceivedCloseAsync(MessageHeader header, CancellationToken cancellationToken) + private async ValueTask HandleReceivedCloseAsync(ReadOnlyMemory payload, CancellationToken cancellationToken) { lock (StateUpdateLock) { @@ -870,41 +626,30 @@ private async ValueTask HandleReceivedCloseAsync(MessageHeader header, Cancellat string closeStatusDescription = string.Empty; // Handle any payload by parsing it into the close status and description. - if (header.PayloadLength == 1) + if (payload.Length == 1) { // The close payload length can be 0 or >= 2, but not 1. await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted).ConfigureAwait(false); } - else if (header.PayloadLength >= 2) + else if (payload.Length >= 2) { - if (_receiveBufferCount < header.PayloadLength) - { - await EnsureBufferContainsAsync((int)header.PayloadLength, cancellationToken).ConfigureAwait(false); - } - - if (_isServer) - { - ApplyMask(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength), header.Mask, 0); - } - - closeStatus = (WebSocketCloseStatus)(_receiveBuffer.Span[_receiveBufferOffset] << 8 | _receiveBuffer.Span[_receiveBufferOffset + 1]); + closeStatus = (WebSocketCloseStatus)(payload.Span[0] << 8 | payload.Span[1]); if (!IsValidCloseStatus(closeStatus)) { await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted).ConfigureAwait(false); } - if (header.PayloadLength > 2) + if (payload.Length > 2) { try { - closeStatusDescription = s_textEncoding.GetString(_receiveBuffer.Span.Slice(_receiveBufferOffset + 2, (int)header.PayloadLength - 2)); + closeStatusDescription = s_textEncoding.GetString(payload.Span.Slice(2)); } catch (DecoderFallbackException exc) { await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, innerException: exc).ConfigureAwait(false); } } - ConsumeFromBuffer((int)header.PayloadLength); } // Store the close status and description onto the instance. @@ -913,66 +658,7 @@ private async ValueTask HandleReceivedCloseAsync(MessageHeader header, Cancellat if (!_isServer && _sentCloseFrame) { - await WaitForServerToCloseConnectionAsync(cancellationToken).ConfigureAwait(false); - } - } - - /// Issues a read on the stream to wait for EOF. - private async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken cancellationToken) - { - // Per RFC 6455 7.1.1, try to let the server close the connection. We give it up to a second. - // We simply issue a read and don't care what we get back; we could validate that we don't get - // additional data, but at this point we're about to close the connection and we're just stalling - // to try to get the server to close first. - ValueTask finalReadTask = _stream.ReadAsync(_receiveBuffer, cancellationToken); - if (!finalReadTask.IsCompletedSuccessfully) - { - const int WaitForCloseTimeoutMs = 1_000; // arbitrary amount of time to give the server (same as netfx) - using (var finalCts = new CancellationTokenSource(WaitForCloseTimeoutMs)) - using (finalCts.Token.Register(static s => ((ManagedWebSocket)s!).Abort(), this)) - { - try - { - await finalReadTask.ConfigureAwait(false); - } - catch - { - // Eat any resulting exceptions. We were going to close the connection, anyway. - } - } - } - } - - /// Processes a received ping or pong message. - /// The message header. - /// The CancellationToken used to cancel the websocket operation. - private async ValueTask HandleReceivedPingPongAsync(MessageHeader header, CancellationToken cancellationToken) - { - // Consume any (optional) payload associated with the ping/pong. - if (header.PayloadLength > 0 && _receiveBufferCount < header.PayloadLength) - { - await EnsureBufferContainsAsync((int)header.PayloadLength, cancellationToken).ConfigureAwait(false); - } - - // If this was a ping, send back a pong response. - if (header.Opcode == MessageOpcode.Ping) - { - if (_isServer) - { - ApplyMask(_receiveBuffer.Span.Slice(_receiveBufferOffset, (int)header.PayloadLength), header.Mask, 0); - } - - await SendFrameAsync( - MessageOpcode.Pong, - endOfMessage: true, - _receiveBuffer.Slice(_receiveBufferOffset, (int)header.PayloadLength), - cancellationToken).ConfigureAwait(false); - } - - // Regardless of whether it was a ping or pong, we no longer need the payload. - if (header.PayloadLength > 0) - { - ConsumeFromBuffer((int)header.PayloadLength); + await _receiver.WaitForServerToCloseConnectionAsync(cancellationToken).ConfigureAwait(false); } } @@ -1029,90 +715,106 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( await CloseOutputAsync(closeStatus, string.Empty, default).ConfigureAwait(false); } - // Dump our receive buffer; we're in a bad state to do any further processing - _receiveBufferCount = 0; - // Let the caller know we've failed throw errorMessage != null ? new WebSocketException(error, errorMessage, innerException) : new WebSocketException(error, innerException); } - /// Parses a message header from the buffer. This assumes the header is in the buffer. - /// The read header. - /// null if a valid header was read; non-null containing the string error message to use if the header was invalid. - private string? TryParseMessageHeaderFromReceiveBuffer(out MessageHeader resultHeader) + private static bool TryParseMessageHeader( + ReadOnlySpan buffer, + MessageHeader previousHeader, + bool isServer, + out MessageHeader header, + out string? error, + out int consumedBytes) { - Debug.Assert(_receiveBufferCount >= 2, $"Expected to at least have the first two bytes of the header."); + header = default; + consumedBytes = 0; + error = null; - MessageHeader header = default; - Span receiveBufferSpan = _receiveBuffer.Span; - - header.Fin = (receiveBufferSpan[_receiveBufferOffset] & 0x80) != 0; - bool reservedSet = (receiveBufferSpan[_receiveBufferOffset] & 0x70) != 0; - header.Opcode = (MessageOpcode)(receiveBufferSpan[_receiveBufferOffset] & 0xF); + if (buffer.Length < 2) + { + return false; + } + // Check first for reserved bits that should always be unset + if ((buffer[0] & 0b0011_0000) != 0) + { + return Error(ref error, SR.net_Websockets_ReservedBitsSet); + } + header.Fin = (buffer[0] & 0x80) != 0; + header.Opcode = (MessageOpcode)(buffer[0] & 0xF); + header.Compressed = (buffer[0] & 0b0100_0000) != 0; - bool masked = (receiveBufferSpan[_receiveBufferOffset + 1] & 0x80) != 0; - header.PayloadLength = receiveBufferSpan[_receiveBufferOffset + 1] & 0x7F; + bool masked = (buffer[1] & 0x80) != 0; + if (masked && !isServer) + { + return Error(ref error, SR.net_Websockets_ClientReceivedMaskedFrame); + } + header.PayloadLength = buffer[1] & 0x7F; - ConsumeFromBuffer(2); + // We've consumed the first 2 bytes + buffer = buffer.Slice(2); + consumedBytes += 2; // Read the remainder of the payload length, if necessary if (header.PayloadLength == 126) { - Debug.Assert(_receiveBufferCount >= 2, $"Expected to have two bytes for the payload length."); - header.PayloadLength = (receiveBufferSpan[_receiveBufferOffset] << 8) | receiveBufferSpan[_receiveBufferOffset + 1]; - ConsumeFromBuffer(2); + if (buffer.Length < 2) + { + return false; + } + header.PayloadLength = (buffer[0] << 8) | buffer[1]; + buffer = buffer.Slice(2); + consumedBytes += 2; } else if (header.PayloadLength == 127) { - Debug.Assert(_receiveBufferCount >= 8, $"Expected to have eight bytes for the payload length."); + if (buffer.Length < 8) + { + return false; + } header.PayloadLength = 0; - for (int i = 0; i < 8; i++) + for (int i = 0; i < 8; ++i) { - header.PayloadLength = (header.PayloadLength << 8) | receiveBufferSpan[_receiveBufferOffset + i]; + header.PayloadLength = (header.PayloadLength << 8) | buffer[i]; } - ConsumeFromBuffer(8); - } - - if (reservedSet) - { - resultHeader = default; - return SR.net_Websockets_ReservedBitsSet; + buffer = buffer.Slice(8); + consumedBytes += 8; } if (masked) { - if (!_isServer) + if (buffer.Length < MaskLength) { - resultHeader = default; - return SR.net_Websockets_ClientReceivedMaskedFrame; + return false; } - header.Mask = CombineMaskBytes(receiveBufferSpan, _receiveBufferOffset); - - // Consume the mask bytes - ConsumeFromBuffer(4); + header.Mask = BitConverter.ToInt32(buffer); + consumedBytes += MaskLength; } // Do basic validation of the header switch (header.Opcode) { case MessageOpcode.Continuation: - if (_lastReceiveHeader.Fin) + if (previousHeader.Fin) { // Can't continue from a final message - resultHeader = default; - return SR.net_Websockets_ContinuationFromFinalFrame; + return Error(ref error, SR.net_Websockets_ContinuationFromFinalFrame); + } + if (header.Compressed) + { + // Per-Message Compressed flag must be set only in the first frame + return Error(ref error, SR.net_Websockets_PerMessageCompressedFlagInContinuation); } break; case MessageOpcode.Binary: case MessageOpcode.Text: - if (!_lastReceiveHeader.Fin) + if (!previousHeader.Fin) { // Must continue from a non-final message - resultHeader = default; - return SR.net_Websockets_NonContinuationAfterNonFinalFrame; + return Error(ref error, SR.net_Websockets_NonContinuationAfterNonFinalFrame); } break; @@ -1122,20 +824,22 @@ private async ValueTask CloseWithReceiveErrorAndThrowAsync( if (header.PayloadLength > MaxControlPayloadLength || !header.Fin) { // Invalid control messgae - resultHeader = default; - return SR.net_Websockets_InvalidControlMessage; + return Error(ref error, SR.net_Websockets_InvalidControlMessage); } break; default: // Unknown opcode - resultHeader = default; - return SR.Format(SR.net_Websockets_UnknownOpcode, header.Opcode); + return Error(ref error, SR.Format(SR.net_Websockets_UnknownOpcode, header.Opcode)); } - // Return the read header - resultHeader = header; - return null; + return true; + + static bool Error(ref string? target, string error) + { + target = error; + return false; + } } /// Send a close message, then receive until we get a close response message. @@ -1166,48 +870,40 @@ private async Task CloseAsyncPrivate(WebSocketCloseStatus closeStatus, string? s if (State == WebSocketState.CloseSent) { // Wait until we've received a close response - byte[] closeBuffer = ArrayPool.Shared.Rent(MaxMessageHeaderLength + MaxControlPayloadLength); - try + while (!_receivedCloseFrame) { - while (!_receivedCloseFrame) + Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}"); + Task receiveTask; + bool usingExistingReceive; + lock (ReceiveAsyncLock) { - Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}"); - Task receiveTask; - bool usingExistingReceive; - lock (ReceiveAsyncLock) + // Now that we're holding the ReceiveAsyncLock, double-check that we've not yet received the close frame. + // It could have been received between our check above and now due to a concurrent receive completing. + if (_receivedCloseFrame) { - // Now that we're holding the ReceiveAsyncLock, double-check that we've not yet received the close frame. - // It could have been received between our check above and now due to a concurrent receive completing. - if (_receivedCloseFrame) - { - break; - } - - // We've not yet processed a received close frame, which means we need to wait for a received close to complete. - // There may already be one in flight, in which case we want to just wait for that one rather than kicking off - // another (we don't support concurrent receive operations). We need to kick off a new receive if either we've - // never issued a receive or if the last issued receive completed for reasons other than a close frame. There is - // a race condition here, e.g. if there's a in-flight receive that completes after we check, but that's fine: worst - // case is we then await it, find that it's not what we need, and try again. - receiveTask = _lastReceiveAsync; - Task newReceiveTask = ValidateAndReceiveAsync(receiveTask, closeBuffer, cancellationToken); - usingExistingReceive = ReferenceEquals(receiveTask, newReceiveTask); - _lastReceiveAsync = receiveTask = newReceiveTask; + break; } - // Wait for whatever receive task we have. We'll then loop around again to re-check our state. - // If this is an existing receive, and if we have a cancelable token, we need to register with that - // token while we wait, since it may not be the same one that was given to the receive initially. - Debug.Assert(receiveTask != null); - using (usingExistingReceive ? cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this) : default) - { - await receiveTask.ConfigureAwait(false); - } + // We've not yet processed a received close frame, which means we need to wait for a received close to complete. + // There may already be one in flight, in which case we want to just wait for that one rather than kicking off + // another (we don't support concurrent receive operations). We need to kick off a new receive if either we've + // never issued a receive or if the last issued receive completed for reasons other than a close frame. There is + // a race condition here, e.g. if there's a in-flight receive that completes after we check, but that's fine: worst + // case is we then await it, find that it's not what we need, and try again. + receiveTask = _lastReceiveAsync; + Task newReceiveTask = ValidateAndReceiveAsync(receiveTask, cancellationToken); + usingExistingReceive = ReferenceEquals(receiveTask, newReceiveTask); + _lastReceiveAsync = receiveTask = newReceiveTask; + } + + // Wait for whatever receive task we have. We'll then loop around again to re-check our state. + // If this is an existing receive, and if we have a cancelable token, we need to register with that + // token while we wait, since it may not be the same one that was given to the receive initially. + Debug.Assert(receiveTask != null); + using (usingExistingReceive ? cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this) : default) + { + await receiveTask.ConfigureAwait(false); } - } - finally - { - ArrayPool.Shared.Return(closeBuffer); } } @@ -1271,44 +967,7 @@ private async ValueTask SendCloseFrameAsync(WebSocketCloseStatus closeStatus, st if (!_isServer && _receivedCloseFrame) { - await WaitForServerToCloseConnectionAsync(cancellationToken).ConfigureAwait(false); - } - } - - private void ConsumeFromBuffer(int count) - { - Debug.Assert(count >= 0, $"Expected non-negative count, got {count}"); - Debug.Assert(count <= _receiveBufferCount, $"Trying to consume {count}, which is more than exists {_receiveBufferCount}"); - _receiveBufferCount -= count; - _receiveBufferOffset += count; - } - - private async ValueTask EnsureBufferContainsAsync(int minimumRequiredBytes, CancellationToken cancellationToken, bool throwOnPrematureClosure = true) - { - Debug.Assert(minimumRequiredBytes <= _receiveBuffer.Length, $"Requested number of bytes {minimumRequiredBytes} must not exceed {_receiveBuffer.Length}"); - - // If we don't have enough data in the buffer to satisfy the minimum required, read some more. - if (_receiveBufferCount < minimumRequiredBytes) - { - // If there's any data in the buffer, shift it down. - if (_receiveBufferCount > 0) - { - _receiveBuffer.Span.Slice(_receiveBufferOffset, _receiveBufferCount).CopyTo(_receiveBuffer.Span); - } - _receiveBufferOffset = 0; - - // While we don't have enough data, read more. - while (_receiveBufferCount < minimumRequiredBytes) - { - int numRead = await _stream.ReadAsync(_receiveBuffer.Slice(_receiveBufferCount, _receiveBuffer.Length - _receiveBufferCount), cancellationToken).ConfigureAwait(false); - Debug.Assert(numRead >= 0, $"Expected non-negative bytes read, got {numRead}"); - if (numRead <= 0) - { - ThrowIfEOFUnexpected(throwOnPrematureClosure); - break; - } - _receiveBufferCount += numRead; - } + await _receiver.WaitForServerToCloseConnectionAsync(cancellationToken).ConfigureAwait(false); } } @@ -1328,42 +987,6 @@ private void ThrowIfEOFUnexpected(bool throwOnPrematureClosure) } } - /// Gets a send buffer from the pool. - private void AllocateSendBuffer(int minLength) - { - Debug.Assert(_sendBuffer == null); // would only fail if had some catastrophic error previously that prevented cleaning up - _sendBuffer = ArrayPool.Shared.Rent(minLength); - } - - /// Releases the send buffer to the pool. - private void ReleaseSendBuffer() - { - Debug.Assert(_sendFrameAsyncLock.CurrentCount == 0, "Caller should hold the _sendFrameAsyncLock"); - - byte[]? old = _sendBuffer; - if (old != null) - { - _sendBuffer = null; - ArrayPool.Shared.Return(old); - } - } - - private static int CombineMaskBytes(Span buffer, int maskOffset) => - BitConverter.ToInt32(buffer.Slice(maskOffset)); - - /// Applies a mask to a portion of a byte array. - /// The buffer to which the mask should be applied. - /// The array containing the mask to apply. - /// The offset into of the mask to apply of length . - /// The next position offset from of which by to apply next from the mask. - /// The updated maskOffsetOffset value. - private static int ApplyMask(Span toMask, byte[] mask, int maskOffset, int maskOffsetIndex) - { - Debug.Assert(maskOffsetIndex < MaskLength, $"Unexpected {nameof(maskOffsetIndex)}: {maskOffsetIndex}"); - Debug.Assert(mask.Length >= MaskLength + maskOffset, $"Unexpected inputs: {mask.Length}, {maskOffset}"); - return ApplyMask(toMask, CombineMaskBytes(mask, maskOffset), maskOffsetIndex); - } - /// Applies a mask to a portion of a byte array. /// The buffer to which the mask should be applied. /// The four-byte mask, stored as an Int32. @@ -1579,6 +1202,18 @@ private struct MessageHeader internal bool Fin; internal long PayloadLength; internal int Mask; + internal bool Compressed; + } + + private readonly struct ControlMessage + { + internal ControlMessage(MessageOpcode opcode, ReadOnlyMemory payload) + { + Opcode = opcode; + Payload = payload; + } + internal MessageOpcode Opcode { get; } + internal ReadOnlyMemory Payload { get; } } /// diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs index 14732fb6b77ca..8ba34ce1af27e 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocket.cs @@ -137,29 +137,34 @@ public static ArraySegment CreateServerBuffer(int receiveBufferSize) [UnsupportedOSPlatform("browser")] public static WebSocket CreateFromStream(Stream stream, bool isServer, string? subProtocol, TimeSpan keepAliveInterval) { - if (stream == null) + if (!WebSocketCreationOptions.IsKeepAliveValid(keepAliveInterval)) + throw new ArgumentOutOfRangeException(nameof(keepAliveInterval), keepAliveInterval, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, 0)); + + return CreateFromStream(stream, new WebSocketCreationOptions { + IsServer = isServer, + SubProtocol = subProtocol, + KeepAliveInterval = keepAliveInterval + }); + } + + /// Creates a that operates on a representing a web socket connection. + /// The for the connection. + /// The options with which the websocket must be created. + [UnsupportedOSPlatform("browser")] + public static WebSocket CreateFromStream(Stream stream, WebSocketCreationOptions options) + { + if (stream is null) throw new ArgumentNullException(nameof(stream)); - } + + if (options is null) + throw new ArgumentNullException(nameof(options)); if (!stream.CanRead || !stream.CanWrite) - { throw new ArgumentException(!stream.CanRead ? SR.NotReadableStream : SR.NotWriteableStream, nameof(stream)); - } - - if (subProtocol != null) - { - WebSocketValidate.ValidateSubprotocol(subProtocol); - } - if (keepAliveInterval != Timeout.InfiniteTimeSpan && keepAliveInterval < TimeSpan.Zero) - { - throw new ArgumentOutOfRangeException(nameof(keepAliveInterval), keepAliveInterval, - SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, - 0)); - } - - return ManagedWebSocket.CreateFromConnectedStream(stream, isServer, subProtocol, keepAliveInterval); + return ManagedWebSocket.CreateFromConnectedStream(stream, options); } [EditorBrowsable(EditorBrowsableState.Never)] @@ -190,18 +195,6 @@ public static WebSocket CreateClientWebSocket(Stream innerStream, throw new ArgumentException(!innerStream.CanRead ? SR.NotReadableStream : SR.NotWriteableStream, nameof(innerStream)); } - if (subProtocol != null) - { - WebSocketValidate.ValidateSubprotocol(subProtocol); - } - - if (keepAliveInterval != Timeout.InfiniteTimeSpan && keepAliveInterval < TimeSpan.Zero) - { - throw new ArgumentOutOfRangeException(nameof(keepAliveInterval), keepAliveInterval, - SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, - 0)); - } - if (receiveBufferSize <= 0 || sendBufferSize <= 0) { throw new ArgumentOutOfRangeException( @@ -213,7 +206,12 @@ public static WebSocket CreateClientWebSocket(Stream innerStream, // Ignore useZeroMaskingKey. ManagedWebSocket doesn't currently support that debugging option. // Ignore internalBuffer. ManagedWebSocket uses its own small buffer for headers/control messages. - return ManagedWebSocket.CreateFromConnectedStream(innerStream, false, subProtocol, keepAliveInterval); + return ManagedWebSocket.CreateFromConnectedStream(innerStream, new WebSocketCreationOptions + { + IsServer = false, + KeepAliveInterval = keepAliveInterval, + SubProtocol = subProtocol + }); } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs new file mode 100644 index 0000000000000..2f7e5f410ddad --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketCreationOptions.cs @@ -0,0 +1,63 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading; + +namespace System.Net.WebSockets +{ + public sealed class WebSocketCreationOptions + { + private string? _subProtocol; + private TimeSpan _keepAliveInterval; + + /// + /// Defines if this websocket is the server-side of the connection. The default value is false. + /// + public bool IsServer { get; set; } + + /// + /// The agreed upon sub-protocol that was used when creating the connection. + /// + public string? SubProtocol + { + get => _subProtocol; + set + { + if (value is not null) + { + WebSocketValidate.ValidateSubprotocol(value); + } + _subProtocol = value; + } + } + + /// + /// The keep-alive interval to use, or or to disable keep-alives. + /// The default is . + /// + public TimeSpan KeepAliveInterval + { + get => _keepAliveInterval; + set + { + if (!IsKeepAliveValid(value)) + { + throw new ArgumentOutOfRangeException(nameof(KeepAliveInterval), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, 0)); + } + _keepAliveInterval = value; + } + } + + /// + /// The agreed upon options for per message deflate. + /// + public WebSocketDeflateOptions? DeflateOptions { get; set; } + + /// + /// Returns whether the provided value is valid websocket keep-alive interval. + /// + internal static bool IsKeepAliveValid(TimeSpan value) => + value == Timeout.InfiniteTimeSpan || value >= TimeSpan.Zero; + } +} diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs new file mode 100644 index 0000000000000..6ddb82c1b0c7e --- /dev/null +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/WebSocketDeflateOptions.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Net.WebSockets +{ + /// + /// Options to enable per-message deflate compression for . + /// + /// + /// Although the WebSocket spec allows window bits from 8 to 15, the current implementation doesn't support 8 bits. + /// For more information refer to the zlib manual https://zlib.net/manual.html. + /// + public sealed class WebSocketDeflateOptions + { + private int _clientMaxWindowBits = 15; + private int _serverMaxWindowBits = 15; + + /// + /// This parameter indicates the base-2 logarithm of the LZ77 sliding window size of the client context. + /// Must be a value between 9 and 15. The default is 15. + /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.2.2 + public int ClientMaxWindowBits + { + get => _clientMaxWindowBits; + set + { + // The underlying zlib component doesn't support 8 bits in deflater (see https://github.com/madler/zlib/issues/94#issuecomment-125832411 + // and https://zlib.net/manual.html). Quote from the manual "For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported.". + // We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream + // and thus it needs to know the window bits in advance. Also take a look at https://github.com/madler/zlib/issues/171. + if (value < 9 || value > 15) + { + throw new ArgumentOutOfRangeException(nameof(ClientMaxWindowBits), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 9, 15)); + } + _clientMaxWindowBits = value; + } + } + + /// + /// When true the client-side of the connection indicates that it will persist the deflate context accross messages. + /// The default is true. + /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.1.2 + public bool ClientContextTakeover { get; set; } = true; + + /// + /// This parameter indicates the base-2 logarithm of the LZ77 sliding window size of the server context. + /// Must be a value between 9 and 15. The default is 15. + /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.2.1 + public int ServerMaxWindowBits + { + get => _serverMaxWindowBits; + set + { + // The underlying zlib component doesn't support 8 bits in deflater (see https://github.com/madler/zlib/issues/94#issuecomment-125832411 + // and https://zlib.net/manual.html). Quote from the manual "For the current implementation of deflate(), a windowBits value of 8 (a window size of 256 bytes) is not supported.". + // We cannot use silently 9 instead of 8, because the websocket produces raw deflate stream + // and thus it needs to know the window bits in advance. Also take a look at https://github.com/madler/zlib/issues/171. + if (value < 9 || value > 15) + { + throw new ArgumentOutOfRangeException(nameof(ServerMaxWindowBits), value, + SR.Format(SR.net_WebSockets_ArgumentOutOfRange, 9, 15)); + } + _serverMaxWindowBits = value; + } + } + + /// + /// When true the server-side of the connection indicates that it will persist the deflate context accross messages. + /// The default is true. + /// + /// https://tools.ietf.org/html/rfc7692#section-7.1.1.1 + public bool ServerContextTakeover { get; set; } = true; + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj index 7cf0328df31ca..c7691fa535199 100644 --- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj +++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj @@ -1,8 +1,11 @@ - $(NetCoreAppCurrent) + $(NetCoreAppCurrent)-windows;$(NetCoreAppCurrent)-Unix;$(NetCoreAppCurrent)-Browser + + + diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs index 7f391f7754ec6..0f28f1b4171b0 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketCreateTest.cs @@ -103,12 +103,10 @@ public async Task ReceiveAsync_UTF8SplitAcrossMultipleBuffers_ValidDataReceived( [Theory] [PlatformSpecific(~TestPlatforms.Browser)] // System.Net.Sockets is not supported on this platform. - [ActiveIssue("https://github.com/dotnet/runtime/issues/34690", TestPlatforms.Windows, TargetFrameworkMonikers.Netcoreapp, TestRuntimes.Mono)] [InlineData(0b_1000_0001, 0b_0_000_0001, false)] // fin + text, no mask + length == 1 - [InlineData(0b_1100_0001, 0b_0_000_0001, true)] // fin + rsv1 + text, no mask + length == 1 [InlineData(0b_1010_0001, 0b_0_000_0001, true)] // fin + rsv2 + text, no mask + length == 1 [InlineData(0b_1001_0001, 0b_0_000_0001, true)] // fin + rsv3 + text, no mask + length == 1 - [InlineData(0b_1111_0001, 0b_0_000_0001, true)] // fin + rsv1 + rsv2 + rsv3 + text, no mask + length == 1 + [InlineData(0b_1011_0001, 0b_0_000_0001, true)] // fin + rsv2 + rsv3 + text, no mask + length == 1 [InlineData(0b_1000_0001, 0b_1_000_0001, true)] // fin + text, mask + length == 1 [InlineData(0b_1000_0011, 0b_0_000_0001, true)] // fin + opcode==3, no mask + length == 1 [InlineData(0b_1000_0100, 0b_0_000_0001, true)] // fin + opcode==4, no mask + length == 1 @@ -117,34 +115,24 @@ public async Task ReceiveAsync_UTF8SplitAcrossMultipleBuffers_ValidDataReceived( [InlineData(0b_1000_0111, 0b_0_000_0001, true)] // fin + opcode==7, no mask + length == 1 public async Task ReceiveAsync_InvalidFrameHeader_AbortsAndThrowsException(byte firstByte, byte secondByte, bool shouldFail) { - using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - listener.Bind(new IPEndPoint(IPAddress.Loopback, 0)); - listener.Listen(1); - - await client.ConnectAsync(listener.LocalEndPoint); - using (Socket server = await listener.AcceptAsync()) - { - WebSocket websocket = CreateFromStream(new NetworkStream(client, ownsSocket: false), isServer: false, null, Timeout.InfiniteTimeSpan); + var stream = new WebSocketStream(); - await server.SendAsync(new ArraySegment(new byte[3] { firstByte, secondByte, (byte)'a' }), SocketFlags.None); + stream.Enqueue(firstByte, secondByte, (byte)'a'); + using var websocket = CreateFromStream(stream, isServer: false, null, Timeout.InfiniteTimeSpan); - var buffer = new byte[1]; - Task t = websocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); - if (shouldFail) - { - await Assert.ThrowsAsync(() => t); - Assert.Equal(WebSocketState.Aborted, websocket.State); - } - else - { - WebSocketReceiveResult result = await t; - Assert.True(result.EndOfMessage); - Assert.Equal(1, result.Count); - Assert.Equal('a', (char)buffer[0]); - } - } + var buffer = new byte[1]; + Task t = websocket.ReceiveAsync(buffer, CancellationToken.None); + if (shouldFail) + { + await Assert.ThrowsAsync(() => t); + Assert.Equal(WebSocketState.Aborted, websocket.State); + } + else + { + WebSocketReceiveResult result = await t; + Assert.True(result.EndOfMessage); + Assert.Equal(1, result.Count); + Assert.Equal('a', (char)buffer[0]); } } @@ -309,7 +297,7 @@ private static async Task CreateWebSocketStream(Uri echoUri, Socket clie string statusLine = await reader.ReadLineAsync(); Assert.NotEmpty(statusLine); Assert.Equal("HTTP/1.1 101 Switching Protocols", statusLine); - while (!string.IsNullOrEmpty(await reader.ReadLineAsync())); + while (!string.IsNullOrEmpty(await reader.ReadLineAsync())) ; } return stream; diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateOptionsTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateOptionsTests.cs new file mode 100644 index 0000000000000..43d9eea7ba297 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateOptionsTests.cs @@ -0,0 +1,48 @@ +using Xunit; + +namespace System.Net.WebSockets.Tests +{ + public class WebSocketDeflateOptionsTests + { + [Fact] + public void ClientMaxWindowBits() + { + var options = new WebSocketDeflateOptions(); + Assert.Equal(15, options.ClientMaxWindowBits); + + Assert.Throws(() => options.ClientMaxWindowBits = 8); + Assert.Throws(() => options.ClientMaxWindowBits = 16); + + options.ClientMaxWindowBits = 14; + Assert.Equal(14, options.ClientMaxWindowBits); + } + + [Fact] + public void ServerMaxWindowBits() + { + var options = new WebSocketDeflateOptions(); + Assert.Equal(15, options.ServerMaxWindowBits); + + Assert.Throws(() => options.ServerMaxWindowBits = 8); + Assert.Throws(() => options.ServerMaxWindowBits = 16); + + options.ServerMaxWindowBits = 14; + Assert.Equal(14, options.ServerMaxWindowBits); + } + + [Fact] + public void ContextTakeover() + { + var options = new WebSocketDeflateOptions(); + + Assert.True(options.ClientContextTakeover); + Assert.True(options.ServerContextTakeover); + + options.ClientContextTakeover = false; + Assert.False(options.ClientContextTakeover); + + options.ServerContextTakeover = false; + Assert.False(options.ServerContextTakeover); + } + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs new file mode 100644 index 0000000000000..68768acb054aa --- /dev/null +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketDeflateTests.cs @@ -0,0 +1,343 @@ +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Net.WebSockets.Tests +{ + [PlatformSpecific(~TestPlatforms.Browser)] + public class WebSocketDeflateTests + { + private CancellationTokenSource? _cancellation; + + public WebSocketDeflateTests() + { + if (!Debugger.IsAttached) + { + _cancellation = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + } + } + + public CancellationToken CancellationToken => _cancellation?.Token ?? default; + + public static IEnumerable SupportedWindowBits + { + get + { + for (var i = 9; i <= 15; ++i) + { + yield return new object[] { i }; + } + } + } + + [Fact] + public async Task HelloWithContextTakeover() + { + var stream = new WebSocketStream(); + + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + using var websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DeflateOptions = new() + }); + + var buffer = new byte[5]; + var result = await websocket.ReceiveAsync(buffer, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(buffer.Length, result.Count); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer)); + + // Because context takeover is set by default if we try to send + // the same message it would take fewer bytes. + stream.Enqueue(0xc1, 0x05, 0xf2, 0x00, 0x11, 0x00, 0x00); + + buffer.AsSpan().Clear(); + result = await websocket.ReceiveAsync(buffer, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(buffer.Length, result.Count); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer)); + } + + [Fact] + public async Task HelloWithoutContextTakeover() + { + var stream = new WebSocketStream(); + + using var websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DeflateOptions = new() + { + ClientContextTakeover = false + } + }); + + var buffer = new byte[5]; + + for (var i = 0; i < 100; ++i) + { + // Without context takeover the message should look the same every time + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + buffer.AsSpan().Clear(); + + var result = await websocket.ReceiveAsync(buffer, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(buffer.Length, result.Count); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer)); + } + } + + [Fact] + public async Task TwoDeflateBlocksInOneMessage() + { + // Two or more DEFLATE blocks may be used in one message. + var stream = new WebSocketStream(); + using var websocket = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DeflateOptions = new() + }); + // The first 3 octets(0xf2 0x48 0x05) and the least significant two + // bits of the 4th octet(0x00) constitute one DEFLATE block with + // "BFINAL" set to 0 and "BTYPE" set to 01 containing "He". The rest of + // the 4th octet contains the header bits with "BFINAL" set to 0 and + // "BTYPE" set to 00, and the 3 padding bits of 0. Together with the + // following 4 octets(0x00 0x00 0xff 0xff), the header bits constitute + // an empty DEFLATE block with no compression. A DEFLATE block + // containing "llo" follows the empty DEFLATE block. + stream.Enqueue(0x41, 0x08, 0xf2, 0x48, 0x05, 0x00, 0x00, 0x00, 0xff, 0xff); + stream.Enqueue(0x80, 0x05, 0xca, 0xc9, 0xc9, 0x07, 0x00); + + Memory buffer = new byte[5]; + var result = await websocket.ReceiveAsync(buffer, CancellationToken); + + Assert.Equal(2, result.Count); + Assert.False(result.EndOfMessage); + + result = await websocket.ReceiveAsync(buffer.Slice(result.Count), CancellationToken); + + Assert.Equal(3, result.Count); + Assert.True(result.EndOfMessage); + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.Span)); + } + + [Theory] + [InlineData(false, false)] + [InlineData(true, true)] + [InlineData(false, true)] + [InlineData(true, false)] + public async Task Duplex(bool clientContextTakover, bool serverContextTakover) + { + var stream = new WebSocketStream(); + using var server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DeflateOptions = new WebSocketDeflateOptions + { + ClientContextTakeover = clientContextTakover, + ServerContextTakeover = serverContextTakover + } + }); + using var client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DeflateOptions = new WebSocketDeflateOptions + { + ClientContextTakeover = clientContextTakover, + ServerContextTakeover = serverContextTakover + } + }); + + var buffer = new byte[1024]; + + for (var i = 0; i < 10; ++i) + { + var message = $"Sending number {i} from server."; + await SendTextAsync(message, server); + + var result = await client.ReceiveAsync(buffer.AsMemory(), CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + + Assert.Equal(message, Encoding.UTF8.GetString(buffer.AsSpan(0, result.Count))); + } + + for (var i = 0; i < 10; ++i) + { + var message = $"Sending number {i} from client."; + await SendTextAsync(message, client); + + var result = await server.ReceiveAsync(buffer.AsMemory(), CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + + Assert.Equal(message, Encoding.UTF8.GetString(buffer.AsSpan(0, result.Count))); + } + } + + [Theory] + [MemberData(nameof(SupportedWindowBits))] + public async Task LargeMessageSplitInMultipleFrames(int windowBits) + { + var stream = new WebSocketStream(); + using var server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DeflateOptions = new() + { + ClientMaxWindowBits = windowBits + } + }); + using var client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DeflateOptions = new() + { + ClientMaxWindowBits = windowBits + } + }); + + Memory testData = new byte[ushort.MaxValue]; + Memory receivedData = new byte[testData.Length]; + + // Make the data incompressible to make sure that the output is larger than the input + var rng = new Random(0); + rng.NextBytes(testData.Span); + + // Test it a few times with different frame sizes + for (var i = 0; i < 10; ++i) + { + var frameSize = rng.Next(1024, 2048); + var position = 0; + + while (position < testData.Length) + { + var currentFrameSize = Math.Min(frameSize, testData.Length - position); + var eof = position + currentFrameSize == testData.Length; + + await server.SendAsync(testData.Slice(position, currentFrameSize), WebSocketMessageType.Binary, eof, CancellationToken); + position += currentFrameSize; + } + + Assert.True(testData.Length < stream.Remote.Available, "The compressed data should be bigger."); + Assert.Equal(testData.Length, position); + + // Receive the data from the client side + receivedData.Span.Clear(); + position = 0; + + // Intentionally receive with a frame size that is less than what the sender used + frameSize /= 3; + + while (true) + { + var currentFrameSize = Math.Min(frameSize, testData.Length - position); + var result = await client.ReceiveAsync(receivedData.Slice(position, currentFrameSize), CancellationToken); + + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + position += result.Count; + + if (result.EndOfMessage) + break; + } + + Assert.Equal(0, stream.Remote.Available); + Assert.Equal(testData.Length, position); + Assert.True(testData.Span.SequenceEqual(receivedData.Span)); + } + } + + [Fact] + public async Task WebSocketWithoutDeflateShouldThrowOnCompressedMessage() + { + var stream = new WebSocketStream(); + + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + using var websocket = WebSocket.CreateFromStream(stream, new()); + + var exception = await Assert.ThrowsAsync(() => + websocket.ReceiveAsync(Memory.Empty, CancellationToken).AsTask()); + + Assert.Equal("The WebSocket received compressed frame when compression is not enabled.", exception.Message); + } + + [Fact] + public async Task ReceiveUncompressedMessageWhenCompressionEnabled() + { + // We should be able to handle the situation where even if we have + // deflate compression enabled, uncompressed messages are OK + var stream = new WebSocketStream(); + var server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true, + DeflateOptions = null + }); + var client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + DeflateOptions = new WebSocketDeflateOptions() + }); + + // Server sends uncompressed + await SendTextAsync("Hello", server); + + // Although client has deflate options, it should still be able + // to handle uncompressed messages. + Assert.Equal("Hello", await ReceiveTextAsync(client)); + + // Client sends compressed, but server compression is disabled and should throw on receive + await SendTextAsync("Hello back", client); + var exception = await Assert.ThrowsAsync(() => ReceiveTextAsync(server)); + Assert.Equal("The WebSocket received compressed frame when compression is not enabled.", exception.Message); + Assert.Equal(WebSocketState.Aborted, server.State); + + // The client should close if we try to receive + var result = await client.ReceiveAsync(Memory.Empty, CancellationToken); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketCloseStatus.ProtocolError, client.CloseStatus); + Assert.Equal(WebSocketState.CloseReceived, client.State); + } + + [Fact] + public async Task ReceiveInvalidCompressedData() + { + var stream = new WebSocketStream(); + var client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + DeflateOptions = new WebSocketDeflateOptions() + }); + + stream.Enqueue(0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00); + Assert.Equal("Hello", await ReceiveTextAsync(client)); + + stream.Enqueue(0xc1, 0x07, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00); + var exception = await Assert.ThrowsAsync(() => ReceiveTextAsync(client)); + + Assert.Equal("The message was compressed using an unsupported compression method.", exception.Message); + Assert.Equal(WebSocketState.Aborted, client.State); + } + + private ValueTask SendTextAsync(string text, WebSocket websocket) + { + var bytes = Encoding.UTF8.GetBytes(text); + return websocket.SendAsync(bytes.AsMemory(), WebSocketMessageType.Text, true, CancellationToken); + } + + private async Task ReceiveTextAsync(WebSocket websocket) + { + using var buffer = MemoryPool.Shared.Rent(1024 * 32); + var result = await websocket.ReceiveAsync(buffer.Memory, CancellationToken); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Text, result.MessageType); + + return Encoding.UTF8.GetString(buffer.Memory.Span.Slice(0, result.Count)); + } + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs new file mode 100644 index 0000000000000..bd46e40c73db0 --- /dev/null +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketStream.cs @@ -0,0 +1,192 @@ +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.WebSockets.Tests +{ + /// + /// A helper stream class that can be used simulate + /// sending / receiving (duplex) data in a websocket. + /// + public class WebSocketStream : Stream + { + private readonly SemaphoreSlim _inputLock = new(initialCount: 0); + private readonly Queue _inputQueue = new(); + private readonly CancellationTokenSource _disposed = new(); + + public WebSocketStream() + { + GC.SuppressFinalize(this); + Remote = new WebSocketStream(this); + } + + private WebSocketStream(WebSocketStream remote) + { + GC.SuppressFinalize(this); + Remote = remote; + } + + public WebSocketStream Remote { get; } + + /// + /// Returns the number of unread bytes. + /// + public int Available + { + get + { + var available = 0; + + lock (_inputQueue) + { + foreach (var x in _inputQueue) + { + available += x.AvailableLength; + } + } + + return available; + } + } + + public Span NextAvailableBytes + { + get + { + lock (_inputQueue) + { + var block = _inputQueue.Peek(); + + if (block is null) + { + return default; + } + return block.Available; + } + } + } + + public override bool CanRead => true; + + public override bool CanSeek => false; + + public override bool CanWrite => true; + + public override long Length => -1; + + public override long Position { get => -1; set => throw new NotSupportedException(); } + + protected override void Dispose(bool disposing) + { + if (!_disposed.IsCancellationRequested) + { + _disposed.Cancel(); + + lock (Remote._inputQueue) + { + Remote._inputLock.Release(); + Remote._inputQueue.Enqueue(Block.ConnectionClosed); + } + } + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + using (var cancellation = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _disposed.Token)) + { + try + { + await _inputLock.WaitAsync(cancellation.Token).ConfigureAwait(false); + } + catch (OperationCanceledException) when (_disposed.IsCancellationRequested) + { + return 0; + } + } + + lock (_inputQueue) + { + var block = _inputQueue.Peek(); + if (block == Block.ConnectionClosed) + { + return 0; + } + var count = Math.Min(block.AvailableLength, buffer.Length); + + block.Available.Slice(0, count).CopyTo(buffer.Span); + block.Advance(count); + + if (block.AvailableLength == 0) + { + _inputQueue.Dequeue(); + } + else + { + // Because we haven't fully consumed the buffer + // we should release once the input lock so we can acquire + // it again on consequent receive. + _inputLock.Release(); + } + + return count; + } + } + + /// + /// Receives the data and enqueues it for processing. + /// + public void Enqueue(params byte[] data) + { + lock (_inputQueue) + { + _inputLock.Release(); + _inputQueue.Enqueue(new Block(data)); + } + } + + public override void Write(ReadOnlySpan buffer) + { + lock (Remote._inputQueue) + { + Remote._inputLock.Release(); + Remote._inputQueue.Enqueue(new Block(buffer.ToArray())); + } + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + Write(buffer.Span); + return ValueTask.CompletedTask; + } + + public override void Flush() => throw new NotSupportedException(); + + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + + public override void SetLength(long value) => throw new NotSupportedException(); + + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + private sealed class Block + { + public static readonly Block ConnectionClosed = new(Array.Empty()); + + private readonly byte[] _data; + private int _position; + + public Block(byte[] data) + { + _data = data; + } + + public Span Available => _data.AsSpan(_position); + + public int AvailableLength => _data.Length - _position; + + public void Advance(int count) => _position += count; + } + } +} diff --git a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs index ad738a00ec864..132ad560c81c9 100644 --- a/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs +++ b/src/libraries/System.Net.WebSockets/tests/WebSocketTests.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.IO; +using System.Threading; +using System.Threading.Tasks; using Xunit; namespace System.Net.WebSockets.Tests @@ -171,6 +173,100 @@ public void ValueWebSocketReceiveResult_Ctor_ValidArguments_Roundtrip(int count, Assert.Equal(endOfMessage, r.EndOfMessage); } + [Theory] + [InlineData(0)] + [InlineData(125)] + [InlineData(ushort.MaxValue)] + [InlineData(ushort.MaxValue * 2)] + public async Task SendUncompressedClientMessage(int messageSize) + { + var stream = new WebSocketStream(); + using var server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + IsServer = true + }); + using var client = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions()); + + var message = new byte[messageSize]; + new Random(0).NextBytes(message); + + await client.SendAsync(message, WebSocketMessageType.Binary, true, default); + + var buffer = new byte[messageSize]; + var result = await server.ReceiveAsync(buffer, default); + + Assert.Equal(messageSize, result.Count); + Assert.True(result.EndOfMessage); + Assert.True(message.AsSpan().SequenceEqual(buffer)); + } + + [Fact] + public async Task WhenPingReceivedPongMessageMustBeSent() + { + var stream = new WebSocketStream(); + using var server = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions + { + // Use server so we don't receive any masked payload + IsServer = true + }); + using var cancellation = new CancellationTokenSource(); + + stream.Enqueue(0b1000_1001, 0x00); + var receiveTask = server.ReceiveAsync(Memory.Empty, cancellation.Token).AsTask(); + + Assert.Equal(0, stream.Available); + Assert.Equal(2, stream.Remote.Available); + Assert.Equal(new byte[] { 0b1000_1010, 0x00 }, stream.Remote.NextAvailableBytes.ToArray()); + + cancellation.Cancel(); + await Assert.ThrowsAsync(async () => await receiveTask.ConfigureAwait(false)); + } + + [Fact] + public async Task WhenPongReceivedNothingShouldBeSentBack() + { + var stream = new WebSocketStream(); + using var client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions()); + + using var cancellation = new CancellationTokenSource(); + + stream.Enqueue(0b1000_1010, 0x00); + var receiveTask = client.ReceiveAsync(Memory.Empty, cancellation.Token).AsTask(); + + Assert.Equal(0, stream.Available); + Assert.Equal(0, stream.Remote.Available); + + cancellation.Cancel(); + await Assert.ThrowsAsync(async () => await receiveTask.ConfigureAwait(false)); + } + + [Fact] + public async Task ClosingWebSocketsGracefully() + { + var stream = new WebSocketStream(); + using var cancellation = new CancellationTokenSource(TimeSpan.FromSeconds(155)); + using var client = WebSocket.CreateFromStream(stream, new WebSocketCreationOptions()); + using var server = WebSocket.CreateFromStream(stream.Remote, new WebSocketCreationOptions + { + IsServer = true + }); + + var clientClose = client.CloseAsync(WebSocketCloseStatus.PolicyViolation, "Yeet", cancellation.Token); + var result = await server.ReceiveAsync(Memory.Empty, cancellation.Token); + + Assert.True(result.EndOfMessage); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(0, result.Count); + Assert.Equal("Yeet", server.CloseStatusDescription); + Assert.Equal(WebSocketCloseStatus.PolicyViolation, server.CloseStatus); + + await server.CloseAsync(WebSocketCloseStatus.NormalClosure, null, cancellation.Token); + await clientClose; + + Assert.Equal(WebSocketState.Closed, server.State); + Assert.Equal(WebSocketState.Closed, client.State); + } + public abstract class ExposeProtectedWebSocket : WebSocket { public static new bool IsStateTerminal(WebSocketState state) =>