diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index 3a8593d0b0745..cde90ae0931cc 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -119,66 +119,70 @@ private Task AcceptAsyncApm(Socket? acceptSocket) internal Task ConnectAsync(EndPoint remoteEP) { - var tcs = new TaskCompletionSource(this); - BeginConnect(remoteEP, iar => + // Use ValueTaskReceive so the AwaitableSocketAsyncEventArgs can be re-used later. + AwaitableSocketAsyncEventArgs saea = LazyInitializer.EnsureInitialized(ref EventArgs.ValueTaskReceive, () => new AwaitableSocketAsyncEventArgs()); + + if (!saea.Reserve()) { - var innerTcs = (TaskCompletionSource)iar.AsyncState!; - try - { - ((Socket)innerTcs.Task.AsyncState!).EndConnect(iar); - innerTcs.TrySetResult(true); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; + saea = new AwaitableSocketAsyncEventArgs(); + saea.Reserve(); + } + + saea.RemoteEndPoint = remoteEP; + return saea.ConnectAsync(this).AsTask(); } internal Task ConnectAsync(IPAddress address, int port) + => ConnectAsync(new IPEndPoint(address, port)); + + internal Task ConnectAsync(IPAddress[] addresses, int port) { - var tcs = new TaskCompletionSource(this); - BeginConnect(address, port, iar => + if (addresses == null) { - var innerTcs = (TaskCompletionSource)iar.AsyncState!; - try - { - ((Socket)innerTcs.Task.AsyncState!).EndConnect(iar); - innerTcs.TrySetResult(true); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; + throw new ArgumentNullException(nameof(addresses)); + } + if (addresses.Length == 0) + { + throw new ArgumentException(SR.net_invalidAddressList, nameof(addresses)); + } + + return DoConnectAsync(addresses, port); } - internal Task ConnectAsync(IPAddress[] addresses, int port) + private async Task DoConnectAsync(IPAddress[] addresses, int port) { - var tcs = new TaskCompletionSource(this); - BeginConnect(addresses, port, iar => + Exception? lastException = null; + foreach (IPAddress address in addresses) { - var innerTcs = (TaskCompletionSource)iar.AsyncState!; try { - ((Socket)innerTcs.Task.AsyncState!).EndConnect(iar); - innerTcs.TrySetResult(true); + await ConnectAsync(address, port).ConfigureAwait(false); + return; } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; + catch (Exception ex) + { + lastException = ex; + } + } + Debug.Assert(lastException != null); + ExceptionDispatchInfo.Throw(lastException); } internal Task ConnectAsync(string host, int port) { - var tcs = new TaskCompletionSource(this); - BeginConnect(host, port, iar => + if (host == null) { - var innerTcs = (TaskCompletionSource)iar.AsyncState!; - try - { - ((Socket)innerTcs.Task.AsyncState!).EndConnect(iar); - innerTcs.TrySetResult(true); - } - catch (Exception e) { innerTcs.TrySetException(e); } - }, tcs); - return tcs.Task; + throw new ArgumentNullException(nameof(host)); + } + + if (IPAddress.TryParse(host, out IPAddress? parsedAddress)) + { + return ConnectAsync(new IPEndPoint(parsedAddress, port)); + } + else + { + return ConnectAsync(new DnsEndPoint(host, port)); + } } internal Task ReceiveAsync(ArraySegment buffer, SocketFlags socketFlags, bool fromNetworkStream) @@ -946,6 +950,32 @@ public ValueTask SendAsyncForNetworkStream(Socket socket, CancellationToken canc new ValueTask(Task.FromException(CreateException(error))); } + public ValueTask ConnectAsync(Socket socket) + { + Debug.Assert(Volatile.Read(ref _continuation) == null, $"Expected null continuation to indicate reserved for use"); + + try + { + if (socket.ConnectAsync(this)) + { + return new ValueTask(this, _token); + } + } + catch + { + Release(); + throw; + } + + SocketError error = SocketError; + + Release(); + + return error == SocketError.Success ? + default : + new ValueTask(Task.FromException(CreateException(error))); + } + /// Gets the status of the operation. public ValueTaskSourceStatus GetStatus(short token) { diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ArgumentValidationTests.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ArgumentValidationTests.cs index e11e5f7778b0b..655534f857ca0 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/ArgumentValidationTests.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/ArgumentValidationTests.cs @@ -1114,7 +1114,14 @@ public void BeginConnect_IPAddresses_EmptyIPAddresses_Throws_Argument() public void BeginConnect_IPAddresses_InvalidPort_Throws_ArgumentOutOfRange(int port) { Assert.Throws(() => GetSocket().BeginConnect(new[] { IPAddress.Loopback }, port, TheAsyncCallback, null)); - Assert.Throws(() => { GetSocket().ConnectAsync(new[] { IPAddress.Loopback }, port); }); + } + + [Theory] + [InlineData(-1)] + [InlineData(65536)] + public async Task ConnectAsync_IPAddresses_InvalidPort_Throws_ArgumentOutOfRange(int port) + { + await Assert.ThrowsAsync(() => GetSocket().ConnectAsync(new[] { IPAddress.Loopback }, port)); } [Fact] @@ -1126,12 +1133,16 @@ public void BeginConnect_IPAddresses_ListeningSocket_Throws_InvalidOperation() socket.Listen(1); Assert.Throws(() => socket.BeginConnect(new[] { IPAddress.Loopback }, 1, TheAsyncCallback, null)); } + } + [Fact] + public async Task ConnectAsync_IPAddresses_ListeningSocket_Throws_InvalidOperation() + { using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) { socket.Bind(new IPEndPoint(IPAddress.Loopback, 0)); socket.Listen(1); - Assert.Throws(() => { socket.ConnectAsync(new[] { IPAddress.Loopback }, 1); }); + await Assert.ThrowsAsync(() => socket.ConnectAsync(new[] { IPAddress.Loopback }, 1)); } } diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs index 6f01d22ca57bb..d4bdcfed80a0f 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs @@ -125,9 +125,6 @@ public async Task Connect_AfterDisconnect_Fails() [PlatformSpecific(~(TestPlatforms.OSX | TestPlatforms.FreeBSD))] // Not supported on BSD like OSes. public async Task ConnectGetsCanceledByDispose() { - bool usesApm = UsesApm || - (this is ConnectTask); // .NET Core ConnectAsync Task API is implemented using Apm - // We try this a couple of times to deal with a timing race: if the Dispose happens // before the operation is started, we won't see a SocketException. int msDelay = 100; @@ -167,7 +164,7 @@ await RetryHelper.ExecuteAsync(async () => disposedException = true; } - if (usesApm) + if (UsesApm) { Assert.Null(localSocketError); Assert.True(disposedException);