diff --git a/src/aiohappyeyeballs/_staggered.py b/src/aiohappyeyeballs/_staggered.py new file mode 100644 index 0000000..b5c6798 --- /dev/null +++ b/src/aiohappyeyeballs/_staggered.py @@ -0,0 +1,101 @@ +import asyncio +import contextlib +from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, TypeVar + + +class _Done(Exception): + pass + + +_T = TypeVar("_T") + + +async def staggered_race( + coro_fns: Iterable[Callable[[], Awaitable[_T]]], delay: Optional[float] +) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]: + """ + Run coroutines with staggered start times and take the first to finish. + + This method takes an iterable of coroutine functions. The first one is + started immediately. From then on, whenever the immediately preceding one + fails (raises an exception), or when *delay* seconds has passed, the next + coroutine is started. This continues until one of the coroutines complete + successfully, in which case all others are cancelled, or until all + coroutines fail. + + The coroutines provided should be well-behaved in the following way: + + * They should only ``return`` if completed successfully. + + * They should always raise an exception if they did not complete + successfully. In particular, if they handle cancellation, they should + probably reraise, like this:: + + try: + # do work + except asyncio.CancelledError: + # undo partially completed work + raise + + Args: + coro_fns: an iterable of coroutine functions, i.e. callables that + return a coroutine object when called. Use ``functools.partial`` or + lambdas to pass arguments. + + delay: amount of time, in seconds, between starting coroutines. If + ``None``, the coroutines will run sequentially. + + Returns: + tuple *(winner_result, winner_index, exceptions)* where + + - *winner_result*: the result of the winning coroutine, or ``None`` + if no coroutines won. + + - *winner_index*: the index of the winning coroutine in + ``coro_fns``, or ``None`` if no coroutines won. If the winning + coroutine may return None on success, *winner_index* can be used + to definitively determine whether any coroutine won. + + - *exceptions*: list of exceptions returned by the coroutines. + ``len(exceptions)`` is equal to the number of coroutines actually + started, and the order is the same as in ``coro_fns``. The winning + coroutine's entry is ``None``. + + """ + # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. + winner_result = None + winner_index = None + exceptions: List[Optional[BaseException]] = [] + + async def run_one_coro( + this_index: int, + coro_fn: Callable[[], Awaitable[_T]], + this_failed: asyncio.Event, + ) -> None: + try: + result = await coro_fn() + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as e: + exceptions[this_index] = e + this_failed.set() # Kickstart the next coroutine + else: + # Store winner's results + nonlocal winner_index, winner_result + assert winner_index is None # noqa: S101 + winner_index = this_index + winner_result = result + raise _Done + + try: + async with asyncio.TaskGroup() as tg: + for this_index, coro_fn in enumerate(coro_fns): + this_failed = asyncio.Event() + exceptions.append(None) + tg.create_task(run_one_coro(this_index, coro_fn, this_failed)) + with contextlib.suppress(TimeoutError): + await asyncio.wait_for(this_failed.wait(), delay) + except* _Done: + pass + + return winner_result, winner_index, exceptions diff --git a/src/aiohappyeyeballs/impl.py b/src/aiohappyeyeballs/impl.py index 39e7672..1017e82 100644 --- a/src/aiohappyeyeballs/impl.py +++ b/src/aiohappyeyeballs/impl.py @@ -6,9 +6,9 @@ import itertools import socket import sys -from asyncio import staggered from typing import List, Optional, Sequence +from . import staggered from .types import AddrInfoType if sys.version_info < (3, 8, 2): # noqa: UP036 diff --git a/src/aiohappyeyeballs/staggered.py b/src/aiohappyeyeballs/staggered.py new file mode 100644 index 0000000..6a8b391 --- /dev/null +++ b/src/aiohappyeyeballs/staggered.py @@ -0,0 +1,9 @@ +import sys + +if sys.version_info > (3, 11): + # https://github.com/python/cpython/issues/124639#issuecomment-2378129834 + from ._staggered import staggered_race +else: + from asyncio.staggered import staggered_race + +__all__ = ["staggered_race"] diff --git a/tests/test_impl.py b/tests/test_impl.py index 61d34d7..cf23ee2 100644 --- a/tests/test_impl.py +++ b/tests/test_impl.py @@ -1368,6 +1368,88 @@ async def _sock_connect( ] +@patch_socket +@pytest.mark.asyncio +@pytest.mark.xfail(reason="raises RuntimeError: coroutine ignored GeneratorExit") +async def test_handling_system_exit( + m_socket: ModuleType, +) -> None: + """Test handling SystemExit.""" + mock_socket = mock.MagicMock( + family=socket.AF_INET, + type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP, + fileno=mock.MagicMock(return_value=1), + ) + create_calls = [] + + def _socket(*args, **kw): + for attr in kw: + setattr(mock_socket, attr, kw[attr]) + return mock_socket + + async def _sock_connect( + sock: socket.socket, address: Tuple[str, int, int, int] + ) -> None: + create_calls.append(address) + raise SystemExit + + m_socket.socket = _socket # type: ignore + ipv6_addr_info = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:beef::", 80, 0, 0), + ) + ipv6_addr_info_2 = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:aaaa::", 80, 0, 0), + ) + ipv4_addr_info = ( + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("107.6.106.83", 80), + ) + addr_info = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info] + local_addr_infos = [ + ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("::1", 0, 0, 0), + ), + ( + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("127.0.0.1", 0), + ), + ] + loop = asyncio.get_running_loop() + with pytest.raises(SystemExit), mock.patch.object( + loop, "sock_connect", _sock_connect + ): + await start_connection( + addr_info, + happy_eyeballs_delay=0.3, + interleave=2, + local_addr_infos=local_addr_infos, + ) + + # Stopped after the first call + assert create_calls == [ + ("dead:beef::", 80, 0, 0), + ] + + @pytest.mark.asyncio @pytest.mark.skipif(sys.version_info >= (3, 8, 2), reason="requires < python 3.8.2") def test_python_38_compat() -> None: