diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs index 602d177f5f1be..888b38b813127 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTestBase.SocketsHttpHandler.cs @@ -3,14 +3,31 @@ using System.IO; using System.Net.Quic; +using System.Net.Sockets; using System.Net.Test.Common; using System.Reflection; +using System.Threading; using System.Threading.Tasks; namespace System.Net.Http.Functional.Tests { public abstract partial class HttpClientHandlerTestBase : FileCleanupTestBase { + protected static async Task DefaultConnectCallback(EndPoint endPoint, CancellationToken cancellationToken) + { + Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; + try + { + await socket.ConnectAsync(endPoint, cancellationToken); + return new NetworkStream(socket, ownsSocket: true); + } + catch + { + socket.Dispose(); + throw; + } + } + protected static bool IsWinHttpHandler => false; public static bool IsQuicSupported diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/MetricsTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/MetricsTest.cs index f3a97d2f15d53..da36366246f41 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/MetricsTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/MetricsTest.cs @@ -292,17 +292,8 @@ await LoopbackServerFactory.CreateClientAndServerAsync(async uri => GetUnderlyingSocketsHttpHandler(Handler).ConnectCallback = async (ctx, cancellationToken) => { connectionStarted.SetResult(); - Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; - try - { - await socket.ConnectAsync(ctx.DnsEndPoint, cancellationToken); - return new NetworkStream(socket, ownsSocket: true); - } - catch - { - socket.Dispose(); - throw; - } + + return await DefaultConnectCallback(ctx.DnsEndPoint, cancellationToken); }; // Enable recording request-duration to test the path with metrics enabled. diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs index c793a1d55d6e7..76d7086c37c17 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.Cancellation.cs @@ -165,9 +165,7 @@ await LoopbackServerFactory.CreateClientAndServerAsync(async uri => else { // Succeed the second connection attempt - Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; - await socket.ConnectAsync(context.DnsEndPoint, token); - return new NetworkStream(socket, ownsSocket: true); + return await DefaultConnectCallback(context.DnsEndPoint, token); } }; diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs index a6869ed981b57..2613b451be455 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/SocketsHttpHandlerTest.cs @@ -1369,17 +1369,7 @@ await RetryHelper.ExecuteAsync(async () => { Assert.Equal("foo", context.DnsEndPoint.Host); - Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; - try - { - await socket.ConnectAsync(lastServerUri.IdnHost, lastServerUri.Port); - return new NetworkStream(socket, ownsSocket: true); - } - catch - { - socket.Dispose(); - throw; - } + return await DefaultConnectCallback(new DnsEndPoint(lastServerUri.IdnHost, lastServerUri.Port), ct); }; TaskCompletionSource waitingForLastRequest = new(TaskCreationOptions.RunContinuationsAsynchronously); @@ -2659,30 +2649,18 @@ public async Task Http2_MultipleConnectionsEnabled_ManyRequestsEnqueuedSimultane AcquireAllStreamSlots(server, client, sendTasks, RequestCount); - List<(Http2LoopbackConnection connection, int streamId)> acceptedRequests = new(); - await using Http2LoopbackConnection c1 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 }); - for (int i = 0; i < MaxConcurrentStreams; i++) - { - (int streamId, _) = await c1.ReadAndParseRequestHeaderAsync(); - acceptedRequests.Add((c1, streamId)); - } + int[] streamIds1 = await AcceptRequests(c1, MaxConcurrentStreams); await using Http2LoopbackConnection c2 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 }); - for (int i = 0; i < MaxConcurrentStreams; i++) - { - (int streamId, _) = await c2.ReadAndParseRequestHeaderAsync(); - acceptedRequests.Add((c2, streamId)); - } + int[] streamIds2 = await AcceptRequests(c2, MaxConcurrentStreams); await using Http2LoopbackConnection c3 = await server.EstablishConnectionAsync(new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = 100 }); (int finalStreamId, _) = await c3.ReadAndParseRequestHeaderAsync(); - acceptedRequests.Add((c3, finalStreamId)); - foreach ((Http2LoopbackConnection connection, int streamId) request in acceptedRequests) - { - await request.connection.SendDefaultResponseAsync(request.streamId); - } + await SendResponses(c1, streamIds1); + await SendResponses(c2, streamIds2); + await c3.SendDefaultResponseAsync(finalStreamId); await VerifySendTasks(sendTasks); } @@ -2702,25 +2680,22 @@ public async Task Http2_MultipleConnectionsEnabled_InfiniteRequestsCompletelyBlo Http2LoopbackConnection connection0 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); - // Block the first connection on infinite requests. + // Accept requests but don't send responses on connection 0 int[] blockedStreamIds = await AcceptRequests(connection0, MaxConcurrentStreams).ConfigureAwait(false); - Assert.Equal(MaxConcurrentStreams, blockedStreamIds.Length); Http2LoopbackConnection connection1 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); - await HandleAllPendingRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false); + // Send responses on connection 1 + await SendResponses(connection1, await AcceptRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false)); - // Complete infinite requests. - int handledRequestCount = await SendResponses(connection0, blockedStreamIds); - - Assert.Equal(MaxConcurrentStreams, handledRequestCount); + // Send responses on connection 0 + await SendResponses(connection0, blockedStreamIds); await VerifySendTasks(sendTasks).ConfigureAwait(false); } [ConditionalFact(nameof(SupportsAlpn))] - [ActiveIssue("https://github.com/dotnet/runtime/issues/91075", TestPlatforms.AnyUnix)] public async Task Http2_MultipleConnectionsEnabled_OpenAndCloseMultipleConnections_Success() { if (PlatformDetection.IsAndroid && (PlatformDetection.IsX86Process || PlatformDetection.IsX64Process)) @@ -2730,44 +2705,62 @@ public async Task Http2_MultipleConnectionsEnabled_OpenAndCloseMultipleConnectio const int MaxConcurrentStreams = 2; using Http2LoopbackServer server = Http2LoopbackServer.CreateServer(); + server.AllowMultipleConnections = true; + + // Allow 5 connections through the ConnectCallback. + SemaphoreSlim connectCallbackSemaphore = new(initialCount: 5); + using SocketsHttpHandler handler = CreateHandler(); + + handler.ConnectCallback = async (context, ct) => + { + await connectCallbackSemaphore.WaitAsync(ct); + + return await DefaultConnectCallback(context.DnsEndPoint, ct); + }; + using (HttpClient client = CreateHttpClient(handler)) { - server.AllowMultipleConnections = true; - List> sendTasks = new List>(); + List> sendTasks = new(); + Http2LoopbackConnection connection0 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + int[] streamIds0 = await AcceptRequests(connection0, MaxConcurrentStreams).ConfigureAwait(false); + Http2LoopbackConnection connection1 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + int[] streamIds1 = await AcceptRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false); + Http2LoopbackConnection connection2 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + int[] streamIds2 = await AcceptRequests(connection2, MaxConcurrentStreams).ConfigureAwait(false); - Task[] handleRequestTasks = new[] { - HandleAllPendingRequests(connection0, MaxConcurrentStreams), - HandleAllPendingRequests(connection1, MaxConcurrentStreams), - HandleAllPendingRequests(connection2, MaxConcurrentStreams) - }; - - await TestHelper.WhenAllCompletedOrAnyFailed(handleRequestTasks).ConfigureAwait(false); + await TestHelper.WhenAllCompletedOrAnyFailed( + SendResponses(connection0, streamIds0), + SendResponses(connection1, streamIds1), + SendResponses(connection2, streamIds2)) + .ConfigureAwait(false); - await connection0.ShutdownIgnoringErrorsAsync(await handleRequestTasks[0]).ConfigureAwait(false); - await connection2.ShutdownIgnoringErrorsAsync(await handleRequestTasks[2]).ConfigureAwait(false); + await connection0.ShutdownIgnoringErrorsAsync(streamIds0[^1]).ConfigureAwait(false); + await connection2.ShutdownIgnoringErrorsAsync(streamIds2[^1]).ConfigureAwait(false); - //Fill all connection1's stream slots + // Fill all connection1's stream slots AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + streamIds1 = await AcceptRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false); Http2LoopbackConnection connection3 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + int[] streamIds3 = await AcceptRequests(connection3, MaxConcurrentStreams).ConfigureAwait(false); + Http2LoopbackConnection connection4 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + int[] streamIds4 = await AcceptRequests(connection4, MaxConcurrentStreams).ConfigureAwait(false); - Task[] finalHandleTasks = new[] { - HandleAllPendingRequests(connection1, MaxConcurrentStreams), - HandleAllPendingRequests(connection3, MaxConcurrentStreams), - HandleAllPendingRequests(connection4, MaxConcurrentStreams) - }; - - await TestHelper.WhenAllCompletedOrAnyFailed(finalHandleTasks).ConfigureAwait(false); + await TestHelper.WhenAllCompletedOrAnyFailed( + SendResponses(connection1, streamIds1), + SendResponses(connection3, streamIds3), + SendResponses(connection4, streamIds4)) + .ConfigureAwait(false); await VerifySendTasks(sendTasks).ConfigureAwait(false); } @@ -2775,29 +2768,40 @@ public async Task Http2_MultipleConnectionsEnabled_OpenAndCloseMultipleConnectio [ConditionalFact(nameof(SupportsAlpn))] [OuterLoop("Incurs long delay")] - [ActiveIssue("https://github.com/dotnet/runtime/issues/91075", TestPlatforms.AnyUnix)] public async Task Http2_MultipleConnectionsEnabled_IdleConnectionTimeoutExpired_ConnectionRemovedAndNewCreated() { const int MaxConcurrentStreams = 2; using Http2LoopbackServer server = Http2LoopbackServer.CreateServer(); + server.AllowMultipleConnections = true; + + SemaphoreSlim connectCallbackSemaphore = new(initialCount: 2); + using SocketsHttpHandler handler = CreateHandler(); handler.PooledConnectionIdleTimeout = TimeSpan.FromSeconds(20); + + handler.ConnectCallback = async (context, ct) => + { + await connectCallbackSemaphore.WaitAsync(ct); + + return await DefaultConnectCallback(context.DnsEndPoint, ct); + }; + using (HttpClient client = CreateHttpClient(handler)) { - server.AllowMultipleConnections = true; - List> sendTasks = new List>(); + List> sendTasks0 = new(); + List> sendTasks1 = new(); + List> sendTasks2 = new(); + Http2LoopbackConnection connection0 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); - AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); - int[] acceptedStreamIds = await AcceptRequests(connection0, MaxConcurrentStreams).ConfigureAwait(false); - Assert.Equal(MaxConcurrentStreams, acceptedStreamIds.Length); + AcquireAllStreamSlots(server, client, sendTasks0, MaxConcurrentStreams); + int[] streamIds0 = await AcceptRequests(connection0, MaxConcurrentStreams).ConfigureAwait(false); - List> connection1SendTasks = new List>(); Http2LoopbackConnection connection1 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); - AcquireAllStreamSlots(server, client, connection1SendTasks, MaxConcurrentStreams); - await HandleAllPendingRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false); + AcquireAllStreamSlots(server, client, sendTasks1, MaxConcurrentStreams); + await SendResponses(connection1, await AcceptRequests(connection1, MaxConcurrentStreams).ConfigureAwait(false)); - // Complete all the requests. - await VerifySendTasks(connection1SendTasks).ConfigureAwait(false); + // Complete all the requests on connection1. + await VerifySendTasks(sendTasks1).ConfigureAwait(false); // Wait until the idle connection timeout expires. await connection1.WaitForClientDisconnectAsync(false).WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); @@ -2805,17 +2809,20 @@ public async Task Http2_MultipleConnectionsEnabled_IdleConnectionTimeoutExpired_ Assert.True(connection1.IsInvalid); Assert.False(connection0.IsInvalid); - Http2LoopbackConnection connection2 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); - - AcquireAllStreamSlots(server, client, sendTasks, MaxConcurrentStreams); + // Due to a race condition in how a new Http2 connection is returned to the pool, we may have started a third connection attempt in the background. + // We were blocking such attempts from going through to the Socket layer until now to avoid having to deal with the extra connect when accepting connection2 below. + // Allow the third connection through the ConnectCallback now. + connectCallbackSemaphore.Release(); - await HandleAllPendingRequests(connection2, MaxConcurrentStreams).ConfigureAwait(false); + Http2LoopbackConnection connection2 = await PrepareConnection(server, client, MaxConcurrentStreams).ConfigureAwait(false); + AcquireAllStreamSlots(server, client, sendTasks2, MaxConcurrentStreams); + await SendResponses(connection2, await AcceptRequests(connection2, MaxConcurrentStreams).ConfigureAwait(false)); - //Make sure connection0 is still alive. - int handledRequests0 = await SendResponses(connection0, acceptedStreamIds).ConfigureAwait(false); - Assert.Equal(MaxConcurrentStreams, handledRequests0); + // Make sure connection0 is still alive. + await SendResponses(connection0, streamIds0).ConfigureAwait(false); - await VerifySendTasks(sendTasks).ConfigureAwait(false); + await VerifySendTasks(sendTasks0).ConfigureAwait(false); + await VerifySendTasks(sendTasks2).ConfigureAwait(false); } } @@ -2844,7 +2851,10 @@ private async Task PrepareConnection(Http2LoopbackServe Task warmUpTask = client.GetAsync(server.Address); - Http2LoopbackConnection connection = await GetConnection(server, maxConcurrentStreams).WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); + var concurrentStreamsSetting = new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = maxConcurrentStreams }; + + Http2LoopbackConnection connection = await server.EstablishConnectionAsync(timeout: null, ackTimeout: TimeSpan.FromSeconds(10), concurrentStreamsSetting) + .WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync().WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); await connection.SendDefaultResponseAsync(streamId).WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); @@ -2864,49 +2874,25 @@ private static void AcquireAllStreamSlots(Http2LoopbackServer server, HttpClient } } - private static async Task GetConnection(Http2LoopbackServer server, uint maxConcurrentStreams) - { - var concurrentStreamsSetting = new SettingsEntry { SettingId = SettingId.MaxConcurrentStreams, Value = maxConcurrentStreams }; - - return await server.EstablishConnectionAsync(timeout: null, ackTimeout: TimeSpan.FromSeconds(10), concurrentStreamsSetting).ConfigureAwait(false); - } - - private async Task HandleAllPendingRequests(Http2LoopbackConnection connection, int totalRequestCount) - { - int lastStreamId = -1; - for (int i = 0; i < totalRequestCount; i++) - { - (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync().ConfigureAwait(false); - await connection.SendDefaultResponseAsync(streamId).ConfigureAwait(false); - lastStreamId = streamId; - } - - return lastStreamId; - } - private async Task AcceptRequests(Http2LoopbackConnection connection, int requestCount) { int[] streamIds = new int[requestCount]; for (int i = 0; i < streamIds.Length; i++) { - (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync().ConfigureAwait(false); + (int streamId, _) = await connection.ReadAndParseRequestHeaderAsync().WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); streamIds[i] = streamId; } return streamIds; } - private async Task SendResponses(Http2LoopbackConnection connection, IEnumerable streamIds) + private async Task SendResponses(Http2LoopbackConnection connection, IEnumerable streamIds) { - int count = 0; foreach (int streamId in streamIds) { - count++; - await connection.SendDefaultResponseAsync(streamId).ConfigureAwait(false); + await connection.SendDefaultResponseAsync(streamId).WaitAsync(TestHelper.PassingTestTimeout).ConfigureAwait(false); } - - return count; } } @@ -3110,10 +3096,7 @@ public async Task ConnectCallback_ConnectionPrefix_Success(bool useSsl) var socketsHandler = (SocketsHttpHandler)GetUnderlyingSocketsHttpHandler(handler); socketsHandler.ConnectCallback = async (context, token) => { - Socket clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - await clientSocket.ConnectAsync(listenSocket.LocalEndPoint); - - Stream clientStream = new NetworkStream(clientSocket, ownsSocket: true); + Stream clientStream = await DefaultConnectCallback(listenSocket.LocalEndPoint, token); await clientStream.WriteAsync(RequestPrefix);