diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
index 2da5179682d37..b96ce3a5f1d25 100644
--- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
+++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
@@ -338,6 +338,7 @@ public void Connect(string host, int port) { }
public static bool ConnectAsync(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType, System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public void Disconnect(bool reuseSocket) { }
public bool DisconnectAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
+ public System.Threading.Tasks.ValueTask DisconnectAsync(bool reuseSocket, System.Threading.CancellationToken cancellationToken = default) { throw null; }
public void Dispose() { }
protected virtual void Dispose(bool disposing) { }
[System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")]
diff --git a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj
index bbddbb391d54b..42314da6c7e9d 100644
--- a/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj
+++ b/src/libraries/System.Net.Sockets/src/System.Net.Sockets.csproj
@@ -45,7 +45,6 @@
-
-
diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.Unix.cs
deleted file mode 100644
index 856962b56b7a8..0000000000000
--- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.Unix.cs
+++ /dev/null
@@ -1,11 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-namespace System.Net.Sockets
-{
- // DisconnectOverlappedAsyncResult - used to take care of storage for async Socket BeginDisconnect call.
- internal sealed partial class DisconnectOverlappedAsyncResult : BaseOverlappedAsyncResult
- {
- internal void PostCompletion(SocketError errorCode) => CompletionCallback(0, errorCode);
- }
-}
diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.cs
deleted file mode 100644
index 2d6c81ac584f0..0000000000000
--- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/DisconnectOverlappedAsyncResult.cs
+++ /dev/null
@@ -1,27 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-namespace System.Net.Sockets
-{
- // DisconnectOverlappedAsyncResult - used to take care of storage for async Socket BeginDisconnect call.
- internal sealed partial class DisconnectOverlappedAsyncResult : BaseOverlappedAsyncResult
- {
- internal DisconnectOverlappedAsyncResult(Socket socket, object? asyncState, AsyncCallback? asyncCallback) :
- base(socket, asyncState, asyncCallback)
- {
- }
-
- // This method will be called by us when the IO completes synchronously and
- // by the ThreadPool when the IO completes asynchronously.
- internal override object? PostCompletion(int numBytes)
- {
- if (ErrorCode == (int)SocketError.Success)
- {
- Socket socket = (Socket)AsyncObject!;
- socket.SetToDisconnected();
- socket._remoteEndPoint = null;
- }
- return base.PostCompletion(numBytes);
- }
- }
-}
diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
index 461122cfd09a6..ac604ebc4a059 100644
--- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
+++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
@@ -263,6 +263,29 @@ public ValueTask ConnectAsync(string host, int port, CancellationToken cancellat
return ConnectAsync(ep, cancellationToken);
}
+ ///
+ /// Disconnects a connected socket from the remote host.
+ ///
+ /// Indicates whether the socket should be available for reuse after disconnect.
+ /// A cancellation token that can be used to cancel the asynchronous operation.
+ /// An asynchronous task that completes when the socket is disconnected.
+ public ValueTask DisconnectAsync(bool reuseSocket, CancellationToken cancellationToken = default)
+ {
+ if (cancellationToken.IsCancellationRequested)
+ {
+ return ValueTask.FromCanceled(cancellationToken);
+ }
+
+ AwaitableSocketAsyncEventArgs saea =
+ Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ??
+ new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false);
+
+ saea.DisconnectReuseSocket = reuseSocket;
+ saea.WrapExceptionsForNetworkStream = false;
+
+ return saea.DisconnectAsync(this, cancellationToken);
+ }
+
///
/// Receives data from a connected socket.
///
@@ -1028,6 +1051,25 @@ public ValueTask ConnectAsync(Socket socket)
ValueTask.FromException(CreateException(error));
}
+ public ValueTask DisconnectAsync(Socket socket, CancellationToken cancellationToken)
+ {
+ Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use");
+
+ if (socket.DisconnectAsync(this, cancellationToken))
+ {
+ _cancellationToken = cancellationToken;
+ return new ValueTask(this, _token);
+ }
+
+ SocketError error = SocketError;
+
+ Release();
+
+ return error == SocketError.Success ?
+ ValueTask.CompletedTask :
+ ValueTask.FromException(CreateException(error));
+ }
+
/// Gets the status of the operation.
public ValueTaskSourceStatus GetStatus(short token)
{
diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
index 2671112ffa139..d39abf6609ba5 100644
--- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
+++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
@@ -73,7 +73,6 @@ private sealed class CacheSet
private int _closeTimeout = Socket.DefaultCloseTimeout;
private int _disposed; // 0 == false, anything else == true
- #region Constructors
public Socket(SocketType socketType, ProtocolType protocolType)
: this(OSSupportsIPv6 ? AddressFamily.InterNetworkV6 : AddressFamily.InterNetwork, socketType, protocolType)
{
@@ -242,9 +241,10 @@ private static SafeSocketHandle ValidateHandle(SafeSocketHandle handle) =>
handle is null ? throw new ArgumentNullException(nameof(handle)) :
handle.IsInvalid ? throw new ArgumentException(SR.Arg_InvalidHandle, nameof(handle)) :
handle;
- #endregion
- #region Properties
+ //
+ // Properties
+ //
// The CLR allows configuration of these properties, separately from whether the OS supports IPv4/6. We
// do not provide these config options, so SupportsIPvX === OSSupportsIPvX.
@@ -761,9 +761,10 @@ internal bool CanTryAddressFamily(AddressFamily family)
{
return (family == _addressFamily) || (family == AddressFamily.InterNetwork && IsDualMode);
}
- #endregion
- #region Public Methods
+ //
+ // Public Methods
+ //
// Associates a socket with an end point.
public void Bind(EndPoint localEP)
@@ -2116,43 +2117,14 @@ public IAsyncResult BeginConnect(IPAddress address, int port, AsyncCallback? req
public IAsyncResult BeginConnect(IPAddress[] addresses, int port, AsyncCallback? requestCallback, object? state) =>
TaskToApm.Begin(ConnectAsync(addresses, port), requestCallback, state);
- public IAsyncResult BeginDisconnect(bool reuseSocket, AsyncCallback? callback, object? state)
+ public void EndConnect(IAsyncResult asyncResult)
{
ThrowIfDisposed();
-
- // Start context-flowing op. No need to lock - we don't use the context till the callback.
- DisconnectOverlappedAsyncResult asyncResult = new DisconnectOverlappedAsyncResult(this, state, callback);
- asyncResult.StartPostingAsyncOp(false);
-
- // Post the disconnect.
- DoBeginDisconnect(reuseSocket, asyncResult);
-
- // Finish flowing (or call the callback), and return.
- asyncResult.FinishPostingAsyncOp();
- return asyncResult;
+ TaskToApm.End(asyncResult);
}
- private void DoBeginDisconnect(bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult)
- {
- SocketError errorCode = SocketError.Success;
-
- errorCode = SocketPal.DisconnectAsync(this, _handle, reuseSocket, asyncResult);
-
- if (errorCode == SocketError.Success)
- {
- SetToDisconnected();
- _remoteEndPoint = null;
- _localEndPoint = null;
- }
-
- if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this, $"UnsafeNclNativeMethods.OSSOCK.DisConnectEx returns:{errorCode}");
-
- // If the call failed, update our status and throw
- if (!CheckErrorAndUpdateStatus(errorCode))
- {
- throw new SocketException((int)errorCode);
- }
- }
+ public IAsyncResult BeginDisconnect(bool reuseSocket, AsyncCallback? callback, object? state) =>
+ TaskToApmBeginWithSyncExceptions(DisconnectAsync(reuseSocket).AsTask(), callback, state);
public void Disconnect(bool reuseSocket)
{
@@ -2175,47 +2147,12 @@ public void Disconnect(bool reuseSocket)
_localEndPoint = null;
}
- public void EndConnect(IAsyncResult asyncResult)
+ public void EndDisconnect(IAsyncResult asyncResult)
{
ThrowIfDisposed();
TaskToApm.End(asyncResult);
}
- public void EndDisconnect(IAsyncResult asyncResult)
- {
- ThrowIfDisposed();
-
- if (asyncResult == null)
- {
- throw new ArgumentNullException(nameof(asyncResult));
- }
-
- //get async result and check for errors
- LazyAsyncResult? castedAsyncResult = asyncResult as LazyAsyncResult;
- if (castedAsyncResult == null || castedAsyncResult.AsyncObject != this)
- {
- throw new ArgumentException(SR.net_io_invalidasyncresult, nameof(asyncResult));
- }
- if (castedAsyncResult.EndCalled)
- {
- throw new InvalidOperationException(SR.Format(SR.net_io_invalidendcall, nameof(EndDisconnect)));
- }
-
- //wait for completion if it hasn't occurred
- castedAsyncResult.InternalWaitForCompletion();
- castedAsyncResult.EndCalled = true;
-
- if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(this);
-
- //
- // if the asynchronous native call failed asynchronously
- // we'll throw a SocketException
- //
- if ((SocketError)castedAsyncResult.ErrorCode != SocketError.Success)
- {
- UpdateStatusAfterSocketErrorAndThrowException((SocketError)castedAsyncResult.ErrorCode);
- }
- }
public IAsyncResult BeginSend(byte[] buffer, int offset, int size, SocketFlags socketFlags, AsyncCallback? callback, object? state)
{
@@ -2668,7 +2605,10 @@ public void Shutdown(SocketShutdown how)
InternalSetBlocking(_willBlockInternal);
}
- #region Async methods
+ //
+ // Async methods
+ //
+
public bool AcceptAsync(SocketAsyncEventArgs e)
{
ThrowIfDisposed();
@@ -2889,7 +2829,9 @@ public static void CancelConnectAsync(SocketAsyncEventArgs e)
e.CancelConnectAsync();
}
- public bool DisconnectAsync(SocketAsyncEventArgs e)
+ public bool DisconnectAsync(SocketAsyncEventArgs e) => DisconnectAsync(e, default);
+
+ private bool DisconnectAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
// Throw if socket disposed
ThrowIfDisposed();
@@ -2904,7 +2846,7 @@ public bool DisconnectAsync(SocketAsyncEventArgs e)
SocketError socketError = SocketError.Success;
try
{
- socketError = e.DoOperationDisconnect(this, _handle);
+ socketError = e.DoOperationDisconnect(this, _handle, cancellationToken);
}
catch
{
@@ -3155,10 +3097,10 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT
return socketError == SocketError.IOPending;
}
- #endregion
- #endregion
- #region Internal and private properties
+ //
+ // Internal and private properties
+ //
private CacheSet Caches
{
@@ -3174,9 +3116,10 @@ private CacheSet Caches
}
internal bool Disposed => _disposed != 0;
- #endregion
- #region Internal and private methods
+ //
+ // Internal and private methods
+ //
internal static void GetIPProtocolInformation(AddressFamily addressFamily, Internals.SocketAddress socketAddress, out bool isIPv4, out bool isIPv6)
{
@@ -3889,6 +3832,16 @@ private static SocketError GetSocketErrorFromFaultedTask(Task t)
};
}
- #endregion
+ // Helper to maintain existing behavior of Socket APM methods to throw synchronously from Begin*.
+ private static IAsyncResult TaskToApmBeginWithSyncExceptions(Task task, AsyncCallback? callback, object? state)
+ {
+ if (task.IsFaulted)
+ {
+ task.GetAwaiter().GetResult();
+ Debug.Fail("Task faulted but GetResult did not throw???");
+ }
+
+ return TaskToApm.Begin(task, callback, state);
+ }
}
}
diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs
index b5d6b0b3b2e17..88120241d2fcc 100644
--- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs
+++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs
@@ -93,7 +93,7 @@ internal unsafe SocketError DoOperationConnect(Socket socket, SafeSocketHandle h
return socketError;
}
- internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle)
+ internal SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
{
SocketError socketError = SocketPal.Disconnect(socket, handle, _disconnectReuseSocket);
FinishOperationSync(socketError, 0, SocketFlags.None);
diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs
index 8e9981380be4a..7bca8498ebca1 100644
--- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs
+++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs
@@ -364,8 +364,11 @@ internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle
}
}
- internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle)
+ internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
{
+ // Note: CancellationToken is ignored for now.
+ // See https://github.com/dotnet/runtime/issues/51452
+
NativeOverlapped* overlapped = AllocateNativeOverlapped();
try
{
@@ -1188,6 +1191,7 @@ private unsafe SocketError FinishOperationConnect()
private void CompleteCore()
{
_strongThisRef.Value = null; // null out this reference from the overlapped so this isn't kept alive artificially
+
if (_singleBufferHandleState != SingleBufferHandleState.None)
{
// If the state isn't None, then either it's Set, in which case there's state to cleanup,
@@ -1213,6 +1217,8 @@ void CompleteCoreSpin()
sw.SpinOnce();
}
+ Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set);
+
// Remove any cancellation registration. First dispose the registration
// to ensure that cancellation will either never fine or will have completed
// firing before we continue. Only then can we safely null out the overlapped.
@@ -1223,6 +1229,8 @@ void CompleteCoreSpin()
}
// Release any GC handles.
+ Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set);
+
if (_singleBufferHandleState == SingleBufferHandleState.Set)
{
_singleBufferHandleState = SingleBufferHandleState.None;
diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs
index 735ece8f8d0a7..3377991fd02cf 100644
--- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs
+++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs
@@ -1976,13 +1976,6 @@ public static SocketError AcceptAsync(Socket socket, SafeSocketHandle handle, Sa
return socketError;
}
- internal static SocketError DisconnectAsync(Socket socket, SafeSocketHandle handle, bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult)
- {
- SocketError socketError = Disconnect(socket, handle, reuseSocket);
- asyncResult.PostCompletion(socketError);
- return socketError;
- }
-
internal static SocketError Disconnect(Socket socket, SafeSocketHandle handle, bool reuseSocket)
{
handle.SetToDisconnected();
diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs
index cf6e12e66dbba..e902d22c049bc 100644
--- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs
+++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs
@@ -1137,27 +1137,6 @@ public static void CheckDualModeReceiveSupport(Socket socket)
// Dual-mode sockets support received packet info on Windows.
}
- internal static unsafe SocketError DisconnectAsync(Socket socket, SafeSocketHandle handle, bool reuseSocket, DisconnectOverlappedAsyncResult asyncResult)
- {
- asyncResult.SetUnmanagedStructures(null);
- try
- {
- // This can throw ObjectDisposedException
- bool success = socket.DisconnectEx(
- handle,
- asyncResult.DangerousOverlappedPointer, // SafeHandle was just created in SetUnmanagedStructures
- (int)(reuseSocket ? TransmitFileOptions.ReuseSocket : 0),
- 0);
-
- return asyncResult.ProcessOverlappedResult(success, 0);
- }
- catch
- {
- asyncResult.ReleaseUnmanagedStructures();
- throw;
- }
- }
-
internal static SocketError Disconnect(Socket socket, SafeSocketHandle handle, bool reuseSocket)
{
SocketError errorCode = SocketError.Success;
diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/DisconnectTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/DisconnectTest.cs
index 603e0c3b84ae0..1d343cb434f6d 100644
--- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/DisconnectTest.cs
+++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/DisconnectTest.cs
@@ -9,124 +9,66 @@
namespace System.Net.Sockets.Tests
{
- public class DisconnectTest
+ public abstract class Disconnect : SocketTestHelperBase where T : SocketHelperBase, new()
{
- private readonly ITestOutputHelper _log;
-
- public DisconnectTest(ITestOutputHelper output)
- {
- _log = TestLogging.GetInstance();
- Assert.True(Capability.IPv4Support() || Capability.IPv6Support());
- }
-
- private static void OnCompleted(object sender, SocketAsyncEventArgs args)
- {
- EventWaitHandle handle = (EventWaitHandle)args.UserToken;
- handle.Set();
- }
-
- [Fact]
- public void InvalidArguments_Throw()
- {
- using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
- {
- AssertExtensions.Throws("asyncResult", () => s.EndDisconnect(null));
- AssertExtensions.Throws("e", () => s.DisconnectAsync(null));
- AssertExtensions.Throws("asyncResult", () => s.EndDisconnect(Task.CompletedTask));
- s.Dispose();
- Assert.Throws(() => s.Disconnect(true));
- Assert.Throws(() => s.BeginDisconnect(true, null, null));
- Assert.Throws(() => s.EndDisconnect(null));
- Assert.Throws(() => { s.DisconnectAsync(null); });
- }
- }
+ protected Disconnect(ITestOutputHelper output) : base(output) { }
[Theory]
[InlineData(true)]
[InlineData(false)]
- [OuterLoop("https://github.com/dotnet/runtime/issues/18406")]
- public void Disconnect_Success(bool reuseSocket)
+ public async Task Disconnect_Success(bool reuseSocket)
{
- AutoResetEvent completed = new AutoResetEvent(false);
-
IPEndPoint loopback = new IPEndPoint(IPAddress.Loopback, 0);
using (var server1 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback))
using (var server2 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback))
{
- SocketAsyncEventArgs args = new SocketAsyncEventArgs();
- args.Completed += OnCompleted;
- args.UserToken = completed;
- args.RemoteEndPoint = server1.EndPoint;
- args.DisconnectReuseSocket = reuseSocket;
-
using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
- if (client.ConnectAsync(args))
- {
- completed.WaitOne();
- }
-
- Assert.Equal(SocketError.Success, args.SocketError);
-
- client.Disconnect(reuseSocket);
+ await ConnectAsync(client, server1.EndPoint);
+ Assert.True(client.Connected);
+ await DisconnectAsync(client, reuseSocket);
Assert.False(client.Connected);
- args.RemoteEndPoint = server2.EndPoint;
-
- if (client.ConnectAsync(args))
+ if (reuseSocket)
{
- completed.WaitOne();
+ // Note that the new connect operation must be asynchronous
+ // (why? I'm not sure, but that's the way it works currently)
+ await client.ConnectAsync(server2.EndPoint);
+ Assert.True(client.Connected);
+ }
+ else if (UsesSync)
+ {
+ await Assert.ThrowsAsync(async () => await ConnectAsync(client, server2.EndPoint));
+ }
+ else
+ {
+ SocketException se = await Assert.ThrowsAsync(async () => await ConnectAsync(client, server2.EndPoint));
+ Assert.Equal(SocketError.IsConnected, se.SocketErrorCode);
}
-
- Assert.Equal(reuseSocket ? SocketError.Success : SocketError.IsConnected, args.SocketError);
}
}
}
- [Theory]
- [InlineData(true)]
- [InlineData(false)]
- [OuterLoop("https://github.com/dotnet/runtime/issues/18406")]
- public void DisconnectAsync_Success(bool reuseSocket)
+ [Fact]
+ public async Task DisconnectAndReuse_ReconnectSync_ThrowsInvalidOperationException()
{
- AutoResetEvent completed = new AutoResetEvent(false);
-
IPEndPoint loopback = new IPEndPoint(IPAddress.Loopback, 0);
using (var server1 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback))
using (var server2 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback))
{
- SocketAsyncEventArgs args = new SocketAsyncEventArgs();
- args.Completed += OnCompleted;
- args.UserToken = completed;
- args.RemoteEndPoint = server1.EndPoint;
- args.DisconnectReuseSocket = reuseSocket;
-
using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
- if (client.ConnectAsync(args))
- {
- completed.WaitOne();
- }
-
- Assert.Equal(SocketError.Success, args.SocketError);
+ await ConnectAsync(client, server1.EndPoint);
+ Assert.True(client.Connected);
- if (client.DisconnectAsync(args))
- {
- completed.WaitOne();
- }
-
- Assert.Equal(SocketError.Success, args.SocketError);
+ await DisconnectAsync(client, reuseSocket: true);
Assert.False(client.Connected);
- args.RemoteEndPoint = server2.EndPoint;
-
- if (client.ConnectAsync(args))
- {
- completed.WaitOne();
- }
-
- Assert.Equal(reuseSocket ? SocketError.Success : SocketError.IsConnected, args.SocketError);
+ // Note that the new connect operation must be asynchronous
+ // (why? I'm not sure, but that's the way it works currently)
+ // So try connecting synchronously, and it should fail
+ Assert.Throws(() => client.Connect(server2.EndPoint));
}
}
}
@@ -134,46 +76,116 @@ public void DisconnectAsync_Success(bool reuseSocket)
[Theory]
[InlineData(true)]
[InlineData(false)]
- [OuterLoop("https://github.com/dotnet/runtime/issues/18406")]
- public void BeginDisconnect_Success(bool reuseSocket)
+ public void Disconnect_NotConnected_ThrowsInvalidOperationException(bool reuseSocket)
{
- AutoResetEvent completed = new AutoResetEvent(false);
+ using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+ {
+ Assert.ThrowsAsync(async () => await DisconnectAsync(s, reuseSocket));
+ }
+ }
- IPEndPoint loopback = new IPEndPoint(IPAddress.Loopback, 0);
- using (var server1 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback))
- using (var server2 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback))
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public void Disconnect_ObjectDisposed_ThrowsObjectDisposedException(bool reuseSocket)
+ {
+ using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
- SocketAsyncEventArgs args = new SocketAsyncEventArgs();
- args.Completed += OnCompleted;
- args.UserToken = completed;
- args.RemoteEndPoint = server1.EndPoint;
+ s.Dispose();
+ Assert.ThrowsAsync(async () => await DisconnectAsync(s, reuseSocket));
+ }
+ }
+ }
- using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
- {
- if (client.ConnectAsync(args))
- {
- completed.WaitOne();
- }
+ public sealed class Disconnect_Sync : Disconnect
+ {
+ public Disconnect_Sync(ITestOutputHelper output) : base(output) { }
+ }
- Assert.Equal(SocketError.Success, args.SocketError);
+ public sealed class Disconnect_SyncForceNonBlocking : Disconnect
+ {
+ public Disconnect_SyncForceNonBlocking(ITestOutputHelper output) : base(output) { }
+ }
- IAsyncResult ar = client.BeginDisconnect(reuseSocket, null, null);
- client.EndDisconnect(ar);
+ public sealed class Disconnect_Apm : Disconnect
+ {
+ public Disconnect_Apm(ITestOutputHelper output) : base(output) { }
- Assert.False(client.Connected);
+ [Fact]
+ public void EndDisconnect_InvalidArguments_Throws()
+ {
+ using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+ {
+ AssertExtensions.Throws("asyncResult", () => s.EndDisconnect(null));
+ AssertExtensions.Throws("asyncResult", () => s.EndDisconnect(Task.CompletedTask));
+ }
+ }
+
+ [Fact]
+ public void BeginDisconnect_NotConnected_ThrowSync()
+ {
+ using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+ {
+ Assert.Throws(() => s.BeginDisconnect(true, null, null));
+ Assert.Throws(() => s.BeginDisconnect(false, null, null));
+ }
+ }
+
+ [Fact]
+ public void BeginDisconnection_ObjectDisposed_ThrowSync()
+ {
+ using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+ {
+ s.Dispose();
+ Assert.Throws(() => s.BeginDisconnect(true, null, null));
+ Assert.Throws(() => s.BeginDisconnect(false, null, null));
+ }
+ }
+ }
- Assert.Throws(() => client.EndDisconnect(ar));
+ public sealed class Disconnect_Task : Disconnect
+ {
+ public Disconnect_Task(ITestOutputHelper output) : base(output) { }
+ }
- args.RemoteEndPoint = server2.EndPoint;
+ public sealed class Disconnect_CancellableTask : Disconnect
+ {
+ public Disconnect_CancellableTask(ITestOutputHelper output) : base(output) { }
- if (client.ConnectAsync(args))
- {
- completed.WaitOne();
- }
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async Task Disconnect_Precanceled_ThrowsOperationCanceledException(bool reuseSocket)
+ {
+ IPEndPoint loopback = new IPEndPoint(IPAddress.Loopback, 0);
+ using (var server1 = SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, loopback))
+ {
+ using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+ {
+ await ConnectAsync(client, server1.EndPoint);
+ Assert.True(client.Connected);
- Assert.Equal(reuseSocket ? SocketError.Success : SocketError.IsConnected, args.SocketError);
+ CancellationTokenSource precanceledSource = new CancellationTokenSource();
+ precanceledSource.Cancel();
+
+ OperationCanceledException oce = await Assert.ThrowsAnyAsync(async () => await client.DisconnectAsync(reuseSocket, precanceledSource.Token));
+ Assert.Equal(precanceledSource.Token, oce.CancellationToken);
}
}
}
}
+
+ public sealed class Disconnect_Eap : Disconnect
+ {
+ public Disconnect_Eap(ITestOutputHelper output) : base(output) { }
+
+ [Fact]
+ public void InvalidArguments_Throw()
+ {
+ using (Socket s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+ {
+ AssertExtensions.Throws("e", () => s.DisconnectAsync(null));
+ }
+ }
+ }
}
diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs
index bdaf68ffb497e..b1de69b736601 100644
--- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs
+++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs
@@ -35,6 +35,7 @@ public abstract Task ReceiveMessageFromAsync(
public abstract Task SendToAsync(Socket s, ArraySegment buffer, EndPoint endpoint);
public abstract Task SendFileAsync(Socket s, string fileName);
public abstract Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags);
+ public abstract Task DisconnectAsync(Socket s, bool reuseSocket);
public virtual bool GuaranteedSendOrdering => true;
public virtual bool ValidatesArrayArguments => true;
public virtual bool UsesSync => false;
@@ -97,6 +98,9 @@ public override Task SendToAsync(Socket s, ArraySegment buffer, EndPo
public override Task SendFileAsync(Socket s, string fileName) => Task.Run(() => s.SendFile(fileName));
public override Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags) =>
Task.Run(() => s.SendFile(fileName, preBuffer.Array, postBuffer.Array, flags));
+ public override Task DisconnectAsync(Socket s, bool reuseSocket) =>
+ Task.Run(() => s.Disconnect(reuseSocket));
+
public override bool GuaranteedSendOrdering => false;
public override bool UsesSync => true;
public override bool ConnectAfterDisconnectResultsInInvalidOperationException => true;
@@ -205,6 +209,11 @@ public override Task SendFileAsync(Socket s, string fileName, ArraySegment
Task.Factory.FromAsync(
(callback, state) => s.BeginSendFile(fileName, preBuffer.Array, postBuffer.Array, flags, callback, state),
s.EndSendFile, null);
+ public override Task DisconnectAsync(Socket s, bool reuseSocket) =>
+ Task.Factory.FromAsync(
+ (callback, state) => s.BeginDisconnect(reuseSocket, callback, state),
+ s.EndDisconnect, null);
+
public override bool UsesApm => true;
}
@@ -236,6 +245,8 @@ public override Task SendToAsync(Socket s, ArraySegment buffer, EndPo
s.SendToAsync(buffer, SocketFlags.None, endPoint);
public override Task SendFileAsync(Socket s, string fileName) => throw new NotSupportedException();
public override Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags) => throw new NotSupportedException();
+ public override Task DisconnectAsync(Socket s, bool reuseSocket) =>
+ s.DisconnectAsync(reuseSocket).AsTask();
}
// Same as above, but call the CancellationToken overloads where possible
@@ -272,6 +283,8 @@ public override Task SendToAsync(Socket s, ArraySegment buffer, EndPo
s.SendToAsync(buffer, SocketFlags.None, endPoint, _cts.Token).AsTask() ;
public override Task SendFileAsync(Socket s, string fileName) => throw new NotSupportedException();
public override Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags) => throw new NotSupportedException();
+ public override Task DisconnectAsync(Socket s, bool reuseSocket) =>
+ s.DisconnectAsync(reuseSocket, _cts.Token).AsTask();
}
public sealed class SocketHelperEap : SocketHelperBase
@@ -364,6 +377,13 @@ public override Task SendToAsync(Socket s, ArraySegment buffer, EndPo
});
public override Task SendFileAsync(Socket s, string fileName) => throw new NotSupportedException();
public override Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags) => throw new NotSupportedException();
+ public override Task DisconnectAsync(Socket s, bool reuseSocket) =>
+ InvokeAsync(s, e => true, e =>
+ {
+ e.DisconnectReuseSocket = reuseSocket;
+ return s.DisconnectAsync(e);
+ });
+
private static Task InvokeAsync(
Socket s,
Func getResult,
@@ -421,6 +441,7 @@ public Task ReceiveMessageFromAsync(Socket s, Ar
public Task SendFileAsync(Socket s, string fileName) => _socketHelper.SendFileAsync(s, fileName);
public Task SendFileAsync(Socket s, string fileName, ArraySegment preBuffer, ArraySegment postBuffer, TransmitFileOptions flags) =>
_socketHelper.SendFileAsync(s, fileName, preBuffer, postBuffer, flags);
+ public Task DisconnectAsync(Socket s, bool reuseSocket) => _socketHelper.DisconnectAsync(s, reuseSocket);
public bool GuaranteedSendOrdering => _socketHelper.GuaranteedSendOrdering;
public bool ValidatesArrayArguments => _socketHelper.ValidatesArrayArguments;
public bool UsesSync => _socketHelper.UsesSync;