diff --git a/tests/_async/test_http2.py b/tests/_async/test_http2.py index b4ec66488..42542481c 100644 --- a/tests/_async/test_http2.py +++ b/tests/_async/test_http2.py @@ -1,8 +1,12 @@ +import time + import hpack import hyperframe.frame import pytest +import trio as concurrency import httpcore +from tests import h2server @pytest.mark.anyio @@ -380,3 +384,28 @@ async def test_http2_remote_max_streams_update(): conn._h2_state.local_settings.max_concurrent_streams, ) i += 1 + + +@pytest.mark.trio +@pytest.mark.xfail(reason="https://github.com/encode/httpx/discussions/3278") +async def test_slow_overlapping_requests(): + fetches = [] + + with h2server.run() as server: + url = f"http://127.0.0.1:{server.port}/" + + async with httpcore.AsyncConnectionPool(http1=False, http2=True) as pool: + + async def fetch(start_delay): + await concurrency.sleep(start_delay) + + start = time.time() + await pool.request("GET", url) + end = time.time() + fetches.append(round(end - start, 1)) + + async with concurrency.open_nursery() as nursery: + for start_delay in [0, 0.2, 0.4, 0.6, 0.8]: + nursery.start_soon(fetch, start_delay) + + assert fetches == [1.0] * 5 diff --git a/tests/_sync/test_http2.py b/tests/_sync/test_http2.py index 695359bd6..334da3970 100644 --- a/tests/_sync/test_http2.py +++ b/tests/_sync/test_http2.py @@ -1,8 +1,12 @@ +import time + import hpack import hyperframe.frame import pytest +from tests import concurrency import httpcore +from tests import h2server @@ -380,3 +384,28 @@ def test_http2_remote_max_streams_update(): conn._h2_state.local_settings.max_concurrent_streams, ) i += 1 + + + +@pytest.mark.xfail(reason="https://github.com/encode/httpx/discussions/3278") +def test_slow_overlapping_requests(): + fetches = [] + + with h2server.run() as server: + url = f"http://127.0.0.1:{server.port}/" + + with httpcore.ConnectionPool(http1=False, http2=True) as pool: + + def fetch(start_delay): + concurrency.sleep(start_delay) + + start = time.time() + pool.request("GET", url) + end = time.time() + fetches.append(round(end - start, 1)) + + with concurrency.open_nursery() as nursery: + for start_delay in [0, 0.2, 0.4, 0.6, 0.8]: + nursery.start_soon(fetch, start_delay) + + assert fetches == [1.0] * 5 diff --git a/tests/concurrency.py b/tests/concurrency.py index a0572d531..382d206a3 100644 --- a/tests/concurrency.py +++ b/tests/concurrency.py @@ -10,6 +10,7 @@ """ import threading +import time from types import TracebackType from typing import Any, Callable, List, Optional, Type @@ -39,3 +40,7 @@ def start_soon(self, func: Callable[..., object], *args: Any) -> None: def open_nursery() -> Nursery: return Nursery() + + +def sleep(seconds: float) -> None: + time.sleep(seconds) diff --git a/tests/h2server.py b/tests/h2server.py new file mode 100644 index 000000000..6b910c72f --- /dev/null +++ b/tests/h2server.py @@ -0,0 +1,93 @@ +import contextlib +import logging +import socket +import threading +import time + +import h2.config +import h2.connection +import h2.events + + +def send_response(sock, conn, event): + start = time.time() + logging.info("Starting %s.", event) + + time.sleep(1) + + stream_id = event.stream_id + conn.send_headers( + stream_id=stream_id, + headers=[(":status", "200"), ("server", "basic-h2-server/1.0")], + ) + data_to_send = conn.data_to_send() + if data_to_send: + sock.sendall(data_to_send) + + conn.send_data(stream_id=stream_id, data=b"it works!", end_stream=True) + data_to_send = conn.data_to_send() + if data_to_send: + sock.sendall(data_to_send) + + end = time.time() + logging.info("Finished %s in %.03fs.", event, end - start) + + +def handle(sock: socket.socket) -> None: + config = h2.config.H2Configuration(client_side=False) + conn = h2.connection.H2Connection(config=config) + conn.initiate_connection() + sock.sendall(conn.data_to_send()) + + while True: + data = sock.recv(65535) + if not data: + sock.close() + break + + events = conn.receive_data(data) + for event in events: + if isinstance(event, h2.events.RequestReceived): + threading.Thread(target=send_response, args=(sock, conn, event)).start() + + +class HTTP2Server: + def __init__( + self, *, host: str = "127.0.0.1", port: int = 0, timeout: float = 0.2 + ) -> None: + self.sock = socket.socket() + self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.sock.settimeout(timeout) + self.sock.bind((host, port)) + self.port = self.sock.getsockname()[1] + self.sock.listen(5) + + def run(self) -> None: + while True: + try: + handle(self.sock.accept()[0]) + except socket.timeout: # pragma: no cover + pass + except OSError: + break + + +@contextlib.contextmanager +def run(**kwargs): + server = HTTP2Server(**kwargs) + thr = threading.Thread(target=server.run) + thr.start() + try: + yield server + finally: + server.sock.close() + thr.join() + + +if __name__ == "__main__": # pragma: no cover + logging.basicConfig( + format="%(relativeCreated)5i <%(threadName)s> %(filename)s:%(lineno)s] %(message)s", + level=logging.INFO, + ) + + HTTP2Server(port=8100).run()