diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs index 2da5179682d37..b96ce3a5f1d25 100644 --- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs +++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs @@ -338,6 +338,7 @@ public void Connect(string host, int port) { } public static bool ConnectAsync(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType, System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } public void Disconnect(bool reuseSocket) { } public bool DisconnectAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } + public System.Threading.Tasks.ValueTask DisconnectAsync(bool reuseSocket, System.Threading.CancellationToken cancellationToken = default) { throw null; } public void Dispose() { } protected virtual void Dispose(bool disposing) { } [System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")] diff --git a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj index bbddbb391d54b..42314da6c7e9d 100644 --- a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj +++ b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj @@ -45,7 +45,6 @@ - - diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.Unix.cs deleted file mode 100644 index 856962b56b7a8..0000000000000 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.Unix.cs +++ /dev/null @@ -1,11 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace System.Net.Sockets -{ - // DisconnectOverlappedAsyncResult - used to take care of storage for async Socket BeginDisconnect call. - internal sealed partial class DisconnectOverlappedAsyncResult : BaseOverlappedAsyncResult - { - internal void PostCompletion(SocketError errorCode) => CompletionCallback(0, errorCode); - } -} diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.cs deleted file mode 100644 index 2d6c81ac584f0..0000000000000 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.cs +++ /dev/null @@ -1,27 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace System.Net.Sockets -{ - // DisconnectOverlappedAsyncResult - used to take care of storage for async Socket BeginDisconnect call. - internal sealed partial class DisconnectOverlappedAsyncResult : BaseOverlappedAsyncResult - { - internal DisconnectOverlappedAsyncResult(Socket socket, object? asyncState, AsyncCallback? asyncCallback) : - base(socket, asyncState, asyncCallback) - { - } - - // This method will be called by us when the IO completes synchronously and - // by the ThreadPool when the IO completes asynchronously. - internal override object? PostCompletion(int numBytes) - { - if (ErrorCode == (int)SocketError.Success) - { - Socket socket = (Socket)AsyncObject!; - socket.SetToDisconnected(); - socket._remoteEndPoint = null; - } - return base.PostCompletion(numBytes); - } - } -} diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index 461122cfd09a6..ac604ebc4a059 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -263,6 +263,29 @@ public ValueTask ConnectAsync(string host, int port, CancellationToken cancellat return ConnectAsync(ep, cancellationToken); } + /// + /// Disconnects a connected socket from the remote host. + /// + /// Indicates whether the socket should be available for reuse after disconnect. + /// A cancellation token that can be used to cancel the asynchronous operation. + /// An asynchronous task that completes when the socket is disconnected. + public ValueTask DisconnectAsync(bool reuseSocket, CancellationToken cancellationToken = default) + { + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } + + AwaitableSocketAsyncEventArgs saea = + Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ?? + new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false); + + saea.DisconnectReuseSocket = reuseSocket; + saea.WrapExceptionsForNetworkStream = false; + + return saea.DisconnectAsync(this, cancellationToken); + } + /// /// Receives data from a connected socket. /// @@ -1028,6 +1051,25 @@ public ValueTask ConnectAsync(Socket socket) ValueTask.FromException(CreateException(error)); } + public ValueTask DisconnectAsync(Socket socket, CancellationToken cancellationToken) + { + Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use"); + + if (socket.DisconnectAsync(this, cancellationToken)) + { + _cancellationToken = cancellationToken; + return new ValueTask(this, _token); + } + + SocketError error = SocketError; + + Release(); + + return error == SocketError.Success ? + ValueTask.CompletedTask : + ValueTask.FromException(CreateException(error)); + } + /// Gets the status of the operation. public ValueTaskSourceStatus GetStatus(short token) { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 2671112ffa139..d39abf6609ba5 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -73,7 +73,6 @@ private sealed class CacheSet private int _closeTimeout = Socket.DefaultCloseTimeout; private int _disposed; // 0 == false, anything else == true - #region Constructors public Socket(SocketType socketType, ProtocolType protocolType) : this(OSSupportsIPv6 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork, socketType, protocolType) { @@ -242,9 +241,10 @@ private static SafeSocketHandle ValidateHandle(SafeSocketHandle handle) => handle is null ? throw new ArgumentNullException(nameof(handle)) : handle.IsInvalid ? throw new ArgumentException(SR.Arg_InvalidHandle, nameof(handle)) : handle; - #endregion - #region Properties + // + // Properties + // // The CLR allows configuration of these properties, separately from whether the OS supports IPv4/6. We // do not provide these config options, so SupportsIPvX === OSSupportsIPvX. @@ -761,9 +761,10 @@ internal bool CanTryAddressFamily(AddressFamily family) { return (family == _addressFamily) || (family == AddressFamily.InterNetwork && IsDualMode); } - #endregion - #region Public Methods + // + // Public Methods + // // Associates a socket with an end point. public void Bind(EndPoint localEP) @@ -2116,43 +2117,14 @@ public IAsyncResult BeginConnect(IPAddress address, int port, AsyncCallback? req public IAsyncResult BeginConnect(IPAddress[] addresses, int port, AsyncCallback? requestCallback, object? state) => TaskToApm.Begin(ConnectAsync(addresses, port), requestCallback, state); - public IAsyncResult BeginDisconnect(bool reuseSocket, AsyncCallback? callback, object? state) + public void EndConnect(IAsyncResult asyncResult) { ThrowIfDisposed(); - - // Start context-flowing op. No need to lock - we don't use the context till the callback. - DisconnectOverlappedAsyncResult asyncResult = new DisconnectOverlappedAsyncResult(this, state, callback); - asyncResult.StartPostingAsyncOp(false); - - // Post the disconnect. - DoBeginDisconnect(reuseSocket, asyncResult); - - // Finish flowing (or call the callback), and return. - asyncResult.FinishPostingAsyncOp(); - return asyncResult; + TaskToApm.End(asyncResult); } - private void DoBeginDisconnect(bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult) - { - SocketError errorCode = SocketError.Success; - - errorCode = SocketPal.DisconnectAsync(this, _handle, reuseSocket, asyncResult); - - if (errorCode == SocketError.Success) - { - SetToDisconnected(); - _remoteEndPoint = null; - _localEndPoint = null; - } - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"UnsafeNclNativeMethods.OSSOCK.DisConnectEx returns:{errorCode}"); - - // If the call failed, update our status and throw - if (!CheckErrorAndUpdateStatus(errorCode)) - { - throw new SocketException((int)errorCode); - } - } + public IAsyncResult BeginDisconnect(bool reuseSocket, AsyncCallback? callback, object? state) => + TaskToApmBeginWithSyncExceptions(DisconnectAsync(reuseSocket).AsTask(), callback, state); public void Disconnect(bool reuseSocket) { @@ -2175,47 +2147,12 @@ public void Disconnect(bool reuseSocket) _localEndPoint = null; } - public void EndConnect(IAsyncResult asyncResult) + public void EndDisconnect(IAsyncResult asyncResult) { ThrowIfDisposed(); TaskToApm.End(asyncResult); } - public void EndDisconnect(IAsyncResult asyncResult) - { - ThrowIfDisposed(); - - if (asyncResult == null) - { - throw new ArgumentNullException(nameof(asyncResult)); - } - - //get async result and check for errors - LazyAsyncResult? castedAsyncResult = asyncResult as LazyAsyncResult; - if (castedAsyncResult == null || castedAsyncResult.AsyncObject != this) - { - throw new ArgumentException(SR.net_io_invalidasyncresult, nameof(asyncResult)); - } - if (castedAsyncResult.EndCalled) - { - throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, nameof(EndDisconnect))); - } - - //wait for completion if it hasn't occurred - castedAsyncResult.InternalWaitForCompletion(); - castedAsyncResult.EndCalled = true; - - if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this); - - // - // if the asynchronous native call failed asynchronously - // we'll throw a SocketException - // - if ((SocketError)castedAsyncResult.ErrorCode != SocketError.Success) - { - UpdateStatusAfterSocketErrorAndThrowException((SocketError)castedAsyncResult.ErrorCode); - } - } public IAsyncResult BeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags, AsyncCallback? callback, object? state) { @@ -2668,7 +2605,10 @@ public void Shutdown(SocketShutdown how) InternalSetBlocking(_willBlockInternal); } - #region Async methods + // + // Async methods + // + public bool AcceptAsync(SocketAsyncEventArgs e) { ThrowIfDisposed(); @@ -2889,7 +2829,9 @@ public static void CancelConnectAsync(SocketAsyncEventArgs e) e.CancelConnectAsync(); } - public bool DisconnectAsync(SocketAsyncEventArgs e) + public bool DisconnectAsync(SocketAsyncEventArgs e) => DisconnectAsync(e, default); + + private bool DisconnectAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken) { // Throw if socket disposed ThrowIfDisposed(); @@ -2904,7 +2846,7 @@ public bool DisconnectAsync(SocketAsyncEventArgs e) SocketError socketError = SocketError.Success; try { - socketError = e.DoOperationDisconnect(this, _handle); + socketError = e.DoOperationDisconnect(this, _handle, cancellationToken); } catch { @@ -3155,10 +3097,10 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT return socketError == SocketError.IOPending; } - #endregion - #endregion - #region Internal and private properties + // + // Internal and private properties + // private CacheSet Caches { @@ -3174,9 +3116,10 @@ private CacheSet Caches } internal bool Disposed => _disposed != 0; - #endregion - #region Internal and private methods + // + // Internal and private methods + // internal static void GetIPProtocolInformation(AddressFamily addressFamily, Internals.SocketAddress socketAddress, out bool isIPv4, out bool isIPv6) { @@ -3889,6 +3832,16 @@ private static SocketError GetSocketErrorFromFaultedTask(Task t) }; } - #endregion + // Helper to maintain existing behavior of Socket APM methods to throw synchronously from Begin*. + private static IAsyncResult TaskToApmBeginWithSyncExceptions(Task task, AsyncCallback? callback, object? state) + { + if (task.IsFaulted) + { + task.GetAwaiter().GetResult(); + Debug.Fail("Task faulted but GetResult did not throw???"); + } + + return TaskToApm.Begin(task, callback, state); + } } } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs index b5d6b0b3b2e17..88120241d2fcc 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs @@ -93,7 +93,7 @@ internal unsafe SocketError DoOperationConnect(Socket socket, SafeSocketHandle h return socketError; } - internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle) + internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken) { SocketError socketError = SocketPal.Disconnect(socket, handle, _disconnectReuseSocket); FinishOperationSync(socketError, 0, SocketFlags.None); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs index 8e9981380be4a..7bca8498ebca1 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs @@ -364,8 +364,11 @@ internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle } } - internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle) + internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken) { + // Note: CancellationToken is ignored for now. + // See https://github.com/dotnet/runtime/issues/51452 + NativeOverlapped* overlapped = AllocateNativeOverlapped(); try { @@ -1188,6 +1191,7 @@ private unsafe SocketError FinishOperationConnect() private void CompleteCore() { _strongThisRef.Value = null; // null out this reference from the overlapped so this isn't kept alive artificially + if (_singleBufferHandleState != SingleBufferHandleState.None) { // If the state isn't None, then either it's Set, in which case there's state to cleanup, @@ -1213,6 +1217,8 @@ void CompleteCoreSpin() sw.SpinOnce(); } + Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set); + // Remove any cancellation registration. First dispose the registration // to ensure that cancellation will either never fine or will have completed // firing before we continue. Only then can we safely null out the overlapped. @@ -1223,6 +1229,8 @@ void CompleteCoreSpin() } // Release any GC handles. + Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set); + if (_singleBufferHandleState == SingleBufferHandleState.Set) { _singleBufferHandleState = SingleBufferHandleState.None; diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index 735ece8f8d0a7..3377991fd02cf 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -1976,13 +1976,6 @@ public static SocketError AcceptAsync(Socket socket, SafeSocketHandle handle, Sa return socketError; } - internal static SocketError DisconnectAsync(Socket socket, SafeSocketHandle handle, bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult) - { - SocketError socketError = Disconnect(socket, handle, reuseSocket); - asyncResult.PostCompletion(socketError); - return socketError; - } - internal static SocketError Disconnect(Socket socket, SafeSocketHandle handle, bool reuseSocket) { handle.SetToDisconnected(); diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs index cf6e12e66dbba..e902d22c049bc 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs @@ -1137,27 +1137,6 @@ public static void CheckDualModeReceiveSupport(Socket socket) // Dual-mode sockets support received packet info on Windows. } - internal static unsafe SocketError DisconnectAsync(Socket socket, SafeSocketHandle handle, bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult) - { - asyncResult.SetUnmanagedStructures(null); - try - { - // This can throw ObjectDisposedException - bool success = socket.DisconnectEx( - handle, - asyncResult.DangerousOverlappedPointer, // SafeHandle was just created in SetUnmanagedStructures - (int)(reuseSocket ? TransmitFileOptions.ReuseSocket : 0), - 0); - - return asyncResult.ProcessOverlappedResult(success, 0); - } - catch - { - asyncResult.ReleaseUnmanagedStructures(); - throw; - } - } - internal static SocketError Disconnect(Socket socket, SafeSocketHandle handle, bool reuseSocket) { SocketError errorCode = SocketError.Success; diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/DisconnectTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/DisconnectTest.cs index 603e0c3b84ae0..1d343cb434f6d 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/DisconnectTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/DisconnectTest.cs @@ -9,124 +9,66 @@ namespace System.Net.Sockets.Tests { - public class DisconnectTest + public abstract class Disconnect : SocketTestHelperBase where T : SocketHelperBase, new() { - private readonly ITestOutputHelper _log; - - public DisconnectTest(ITestOutputHelper output) - { - _log = TestLogging.GetInstance(); - Assert.True(Capability.IPv4Support() || Capability.IPv6Support()); - } - - private static void OnCompleted(object sender, SocketAsyncEventArgs args) - { - EventWaitHandle handle = (EventWaitHandle)args.UserToken; - handle.Set(); - } - - [Fact] - public void InvalidArguments_Throw() - { - using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - AssertExtensions.Throws("asyncResult", () => s.EndDisconnect(null)); - AssertExtensions.Throws("e", () => s.DisconnectAsync(null)); - AssertExtensions.Throws("asyncResult", () => s.EndDisconnect(Task.CompletedTask)); - s.Dispose(); - Assert.Throws(() => s.Disconnect(true)); - Assert.Throws(() => s.BeginDisconnect(true, null, null)); - Assert.Throws(() => s.EndDisconnect(null)); - Assert.Throws(() => { s.DisconnectAsync(null); }); - } - } + protected Disconnect(ITestOutputHelper output) : base(output) { } [Theory] [InlineData(true)] [InlineData(false)] - [OuterLoop("https://github.com/dotnet/runtime/issues/18406")] - public void Disconnect_Success(bool reuseSocket) + public async Task Disconnect_Success(bool reuseSocket) { - AutoResetEvent completed = new AutoResetEvent(false); - IPEndPoint loopback = new IPEndPoint(IPAddress.Loopback, 0); using (var server1 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback)) using (var server2 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback)) { - SocketAsyncEventArgs args = new SocketAsyncEventArgs(); - args.Completed += OnCompleted; - args.UserToken = completed; - args.RemoteEndPoint = server1.EndPoint; - args.DisconnectReuseSocket = reuseSocket; - using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) { - if (client.ConnectAsync(args)) - { - completed.WaitOne(); - } - - Assert.Equal(SocketError.Success, args.SocketError); - - client.Disconnect(reuseSocket); + await ConnectAsync(client, server1.EndPoint); + Assert.True(client.Connected); + await DisconnectAsync(client, reuseSocket); Assert.False(client.Connected); - args.RemoteEndPoint = server2.EndPoint; - - if (client.ConnectAsync(args)) + if (reuseSocket) { - completed.WaitOne(); + // Note that the new connect operation must be asynchronous + // (why? I'm not sure, but that's the way it works currently) + await client.ConnectAsync(server2.EndPoint); + Assert.True(client.Connected); + } + else if (UsesSync) + { + await Assert.ThrowsAsync(async () => await ConnectAsync(client, server2.EndPoint)); + } + else + { + SocketException se = await Assert.ThrowsAsync(async () => await ConnectAsync(client, server2.EndPoint)); + Assert.Equal(SocketError.IsConnected, se.SocketErrorCode); } - - Assert.Equal(reuseSocket ? SocketError.Success : SocketError.IsConnected, args.SocketError); } } } - [Theory] - [InlineData(true)] - [InlineData(false)] - [OuterLoop("https://github.com/dotnet/runtime/issues/18406")] - public void DisconnectAsync_Success(bool reuseSocket) + [Fact] + public async Task DisconnectAndReuse_ReconnectSync_ThrowsInvalidOperationException() { - AutoResetEvent completed = new AutoResetEvent(false); - IPEndPoint loopback = new IPEndPoint(IPAddress.Loopback, 0); using (var server1 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback)) using (var server2 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback)) { - SocketAsyncEventArgs args = new SocketAsyncEventArgs(); - args.Completed += OnCompleted; - args.UserToken = completed; - args.RemoteEndPoint = server1.EndPoint; - args.DisconnectReuseSocket = reuseSocket; - using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) { - if (client.ConnectAsync(args)) - { - completed.WaitOne(); - } - - Assert.Equal(SocketError.Success, args.SocketError); + await ConnectAsync(client, server1.EndPoint); + Assert.True(client.Connected); - if (client.DisconnectAsync(args)) - { - completed.WaitOne(); - } - - Assert.Equal(SocketError.Success, args.SocketError); + await DisconnectAsync(client, reuseSocket: true); Assert.False(client.Connected); - args.RemoteEndPoint = server2.EndPoint; - - if (client.ConnectAsync(args)) - { - completed.WaitOne(); - } - - Assert.Equal(reuseSocket ? SocketError.Success : SocketError.IsConnected, args.SocketError); + // Note that the new connect operation must be asynchronous + // (why? I'm not sure, but that's the way it works currently) + // So try connecting synchronously, and it should fail + Assert.Throws(() => client.Connect(server2.EndPoint)); } } } @@ -134,46 +76,116 @@ public void DisconnectAsync_Success(bool reuseSocket) [Theory] [InlineData(true)] [InlineData(false)] - [OuterLoop("https://github.com/dotnet/runtime/issues/18406")] - public void BeginDisconnect_Success(bool reuseSocket) + public void Disconnect_NotConnected_ThrowsInvalidOperationException(bool reuseSocket) { - AutoResetEvent completed = new AutoResetEvent(false); + using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + Assert.ThrowsAsync(async () => await DisconnectAsync(s, reuseSocket)); + } + } - IPEndPoint loopback = new IPEndPoint(IPAddress.Loopback, 0); - using (var server1 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback)) - using (var server2 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback)) + [Theory] + [InlineData(true)] + [InlineData(false)] + public void Disconnect_ObjectDisposed_ThrowsObjectDisposedException(bool reuseSocket) + { + using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) { - SocketAsyncEventArgs args = new SocketAsyncEventArgs(); - args.Completed += OnCompleted; - args.UserToken = completed; - args.RemoteEndPoint = server1.EndPoint; + s.Dispose(); + Assert.ThrowsAsync(async () => await DisconnectAsync(s, reuseSocket)); + } + } + } - using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) - { - if (client.ConnectAsync(args)) - { - completed.WaitOne(); - } + public sealed class Disconnect_Sync : Disconnect + { + public Disconnect_Sync(ITestOutputHelper output) : base(output) { } + } - Assert.Equal(SocketError.Success, args.SocketError); + public sealed class Disconnect_SyncForceNonBlocking : Disconnect + { + public Disconnect_SyncForceNonBlocking(ITestOutputHelper output) : base(output) { } + } - IAsyncResult ar = client.BeginDisconnect(reuseSocket, null, null); - client.EndDisconnect(ar); + public sealed class Disconnect_Apm : Disconnect + { + public Disconnect_Apm(ITestOutputHelper output) : base(output) { } - Assert.False(client.Connected); + [Fact] + public void EndDisconnect_InvalidArguments_Throws() + { + using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + AssertExtensions.Throws("asyncResult", () => s.EndDisconnect(null)); + AssertExtensions.Throws("asyncResult", () => s.EndDisconnect(Task.CompletedTask)); + } + } + + [Fact] + public void BeginDisconnect_NotConnected_ThrowSync() + { + using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + Assert.Throws(() => s.BeginDisconnect(true, null, null)); + Assert.Throws(() => s.BeginDisconnect(false, null, null)); + } + } + + [Fact] + public void BeginDisconnection_ObjectDisposed_ThrowSync() + { + using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + s.Dispose(); + Assert.Throws(() => s.BeginDisconnect(true, null, null)); + Assert.Throws(() => s.BeginDisconnect(false, null, null)); + } + } + } - Assert.Throws(() => client.EndDisconnect(ar)); + public sealed class Disconnect_Task : Disconnect + { + public Disconnect_Task(ITestOutputHelper output) : base(output) { } + } - args.RemoteEndPoint = server2.EndPoint; + public sealed class Disconnect_CancellableTask : Disconnect + { + public Disconnect_CancellableTask(ITestOutputHelper output) : base(output) { } - if (client.ConnectAsync(args)) - { - completed.WaitOne(); - } + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task Disconnect_Precanceled_ThrowsOperationCanceledException(bool reuseSocket) + { + IPEndPoint loopback = new IPEndPoint(IPAddress.Loopback, 0); + using (var server1 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback)) + { + using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + await ConnectAsync(client, server1.EndPoint); + Assert.True(client.Connected); - Assert.Equal(reuseSocket ? SocketError.Success : SocketError.IsConnected, args.SocketError); + CancellationTokenSource precanceledSource = new CancellationTokenSource(); + precanceledSource.Cancel(); + + OperationCanceledException oce = await Assert.ThrowsAnyAsync(async () => await client.DisconnectAsync(reuseSocket, precanceledSource.Token)); + Assert.Equal(precanceledSource.Token, oce.CancellationToken); } } } } + + public sealed class Disconnect_Eap : Disconnect + { + public Disconnect_Eap(ITestOutputHelper output) : base(output) { } + + [Fact] + public void InvalidArguments_Throw() + { + using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + AssertExtensions.Throws("e", () => s.DisconnectAsync(null)); + } + } + } } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs index bdaf68ffb497e..b1de69b736601 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs @@ -35,6 +35,7 @@ public abstract Task ReceiveMessageFromAsync( public abstract Task SendToAsync(Socket s, ArraySegment buffer, EndPoint endpoint); public abstract Task SendFileAsync(Socket s, string fileName); public abstract Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags); + public abstract Task DisconnectAsync(Socket s, bool reuseSocket); public virtual bool GuaranteedSendOrdering => true; public virtual bool ValidatesArrayArguments => true; public virtual bool UsesSync => false; @@ -97,6 +98,9 @@ public override Task SendToAsync(Socket s, ArraySegment buffer, EndPo public override Task SendFileAsync(Socket s, string fileName) => Task.Run(() => s.SendFile(fileName)); public override Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags) => Task.Run(() => s.SendFile(fileName, preBuffer.Array, postBuffer.Array, flags)); + public override Task DisconnectAsync(Socket s, bool reuseSocket) => + Task.Run(() => s.Disconnect(reuseSocket)); + public override bool GuaranteedSendOrdering => false; public override bool UsesSync => true; public override bool ConnectAfterDisconnectResultsInInvalidOperationException => true; @@ -205,6 +209,11 @@ public override Task SendFileAsync(Socket s, string fileName, ArraySegment Task.Factory.FromAsync( (callback, state) => s.BeginSendFile(fileName, preBuffer.Array, postBuffer.Array, flags, callback, state), s.EndSendFile, null); + public override Task DisconnectAsync(Socket s, bool reuseSocket) => + Task.Factory.FromAsync( + (callback, state) => s.BeginDisconnect(reuseSocket, callback, state), + s.EndDisconnect, null); + public override bool UsesApm => true; } @@ -236,6 +245,8 @@ public override Task SendToAsync(Socket s, ArraySegment buffer, EndPo s.SendToAsync(buffer, SocketFlags.None, endPoint); public override Task SendFileAsync(Socket s, string fileName) => throw new NotSupportedException(); public override Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags) => throw new NotSupportedException(); + public override Task DisconnectAsync(Socket s, bool reuseSocket) => + s.DisconnectAsync(reuseSocket).AsTask(); } // Same as above, but call the CancellationToken overloads where possible @@ -272,6 +283,8 @@ public override Task SendToAsync(Socket s, ArraySegment buffer, EndPo s.SendToAsync(buffer, SocketFlags.None, endPoint, _cts.Token).AsTask() ; public override Task SendFileAsync(Socket s, string fileName) => throw new NotSupportedException(); public override Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags) => throw new NotSupportedException(); + public override Task DisconnectAsync(Socket s, bool reuseSocket) => + s.DisconnectAsync(reuseSocket, _cts.Token).AsTask(); } public sealed class SocketHelperEap : SocketHelperBase @@ -364,6 +377,13 @@ public override Task SendToAsync(Socket s, ArraySegment buffer, EndPo }); public override Task SendFileAsync(Socket s, string fileName) => throw new NotSupportedException(); public override Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags) => throw new NotSupportedException(); + public override Task DisconnectAsync(Socket s, bool reuseSocket) => + InvokeAsync(s, e => true, e => + { + e.DisconnectReuseSocket = reuseSocket; + return s.DisconnectAsync(e); + }); + private static Task InvokeAsync( Socket s, Func getResult, @@ -421,6 +441,7 @@ public Task ReceiveMessageFromAsync(Socket s, Ar public Task SendFileAsync(Socket s, string fileName) => _socketHelper.SendFileAsync(s, fileName); public Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags) => _socketHelper.SendFileAsync(s, fileName, preBuffer, postBuffer, flags); + public Task DisconnectAsync(Socket s, bool reuseSocket) => _socketHelper.DisconnectAsync(s, reuseSocket); public bool GuaranteedSendOrdering => _socketHelper.GuaranteedSendOrdering; public bool ValidatesArrayArguments => _socketHelper.ValidatesArrayArguments; public bool UsesSync => _socketHelper.UsesSync;