From bc8651097eb24e70a6adb7ed9346e284d079c52b Mon Sep 17 00:00:00 2001 From: Vlad Emelianov Date: Fri, 27 Sep 2024 05:06:30 +0300 Subject: [PATCH 1/3] Support Python 3.8+ in backported asyncio.staggered --- src/aiohappyeyeballs/_staggered.py | 90 +++++++++++++++++++++--------- 1 file changed, 65 insertions(+), 25 deletions(-) diff --git a/src/aiohappyeyeballs/_staggered.py b/src/aiohappyeyeballs/_staggered.py index b5c6798..0d5727c 100644 --- a/src/aiohappyeyeballs/_staggered.py +++ b/src/aiohappyeyeballs/_staggered.py @@ -1,17 +1,17 @@ import asyncio import contextlib +from asyncio import events, locks, tasks +from asyncio import exceptions as exceptions_mod from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, TypeVar - -class _Done(Exception): - pass - - _T = TypeVar("_T") async def staggered_race( - coro_fns: Iterable[Callable[[], Awaitable[_T]]], delay: Optional[float] + coro_fns: Iterable[Callable[[], Awaitable[_T]]], + delay: Optional[float], + *, + loop: Optional[asyncio.AbstractEventLoop] = None, ) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]: """ Run coroutines with staggered start times and take the first to finish. @@ -45,6 +45,8 @@ async def staggered_race( delay: amount of time, in seconds, between starting coroutines. If ``None``, the coroutines will run sequentially. + loop: the event loop to use. + Returns: tuple *(winner_result, winner_index, exceptions)* where @@ -63,15 +65,36 @@ async def staggered_race( """ # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. + loop = loop or events.get_running_loop() + enum_coro_fns = enumerate(coro_fns) winner_result = None winner_index = None exceptions: List[Optional[BaseException]] = [] + running_tasks: List[tasks.Task] = [] + + async def run_one_coro(previous_failed) -> None: + # Wait for the previous task to finish, or for delay seconds + if previous_failed is not None: + with contextlib.suppress(exceptions_mod.TimeoutError): + # Use asyncio.wait_for() instead of asyncio.wait() here, so + # that if we get cancelled at this point, Event.wait() is also + # cancelled, otherwise there will be a "Task destroyed but it is + # pending" later. + await tasks.wait_for(previous_failed.wait(), delay) + # Get the next coroutine to run + try: + this_index, coro_fn = next(enum_coro_fns) + except StopIteration: + return + # Start task that will run the next coroutine + this_failed = locks.Event() + next_task = loop.create_task(run_one_coro(this_failed)) + running_tasks.append(next_task) + assert len(running_tasks) == this_index + 2 + # Prepare place to put this coroutine's exceptions if not won + exceptions.append(None) + assert len(exceptions) == this_index + 1 - async def run_one_coro( - this_index: int, - coro_fn: Callable[[], Awaitable[_T]], - this_failed: asyncio.Event, - ) -> None: try: result = await coro_fn() except (SystemExit, KeyboardInterrupt): @@ -82,20 +105,37 @@ async def run_one_coro( else: # Store winner's results nonlocal winner_index, winner_result - assert winner_index is None # noqa: S101 + assert winner_index is None winner_index = this_index winner_result = result - raise _Done - + # Cancel all other tasks. We take care to not cancel the current + # task as well. If we do so, then since there is no `await` after + # here and CancelledError are usually thrown at one, we will + # encounter a curious corner case where the current task will end + # up as done() == True, cancelled() == False, exception() == + # asyncio.CancelledError. This behavior is specified in + # https://bugs.python.org/issue30048 + for i, t in enumerate(running_tasks): + if i != this_index: + t.cancel() + + first_task = loop.create_task(run_one_coro(None)) + running_tasks.append(first_task) try: - async with asyncio.TaskGroup() as tg: - for this_index, coro_fn in enumerate(coro_fns): - this_failed = asyncio.Event() - exceptions.append(None) - tg.create_task(run_one_coro(this_index, coro_fn, this_failed)) - with contextlib.suppress(TimeoutError): - await asyncio.wait_for(this_failed.wait(), delay) - except* _Done: - pass - - return winner_result, winner_index, exceptions + # Wait for a growing list of tasks to all finish: poor man's version of + # curio's TaskGroup or trio's nursery + done_count = 0 + while done_count != len(running_tasks): + done, _ = await tasks.wait(running_tasks) + done_count = len(done) + # If run_one_coro raises an unhandled exception, it's probably a + # programming error, and I want to see it. + if __debug__: + for d in done: + if d.done() and not d.cancelled() and d.exception(): + raise d.exception() + return winner_result, winner_index, exceptions + finally: + # Make sure no tasks are left running if we leave this function + for t in running_tasks: + t.cancel() From 93c252a64a80213304f66d9e7a8ee39a62532b91 Mon Sep 17 00:00:00 2001 From: Vlad Emelianov Date: Fri, 27 Sep 2024 05:21:08 +0300 Subject: [PATCH 2/3] Fix linting issues --- src/aiohappyeyeballs/_staggered.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/aiohappyeyeballs/_staggered.py b/src/aiohappyeyeballs/_staggered.py index 0d5727c..f40a1d3 100644 --- a/src/aiohappyeyeballs/_staggered.py +++ b/src/aiohappyeyeballs/_staggered.py @@ -70,9 +70,9 @@ async def staggered_race( winner_result = None winner_index = None exceptions: List[Optional[BaseException]] = [] - running_tasks: List[tasks.Task] = [] + running_tasks: List[tasks.Task[None]] = [] - async def run_one_coro(previous_failed) -> None: + async def run_one_coro(previous_failed: Optional[locks.Event]) -> None: # Wait for the previous task to finish, or for delay seconds if previous_failed is not None: with contextlib.suppress(exceptions_mod.TimeoutError): @@ -90,10 +90,10 @@ async def run_one_coro(previous_failed) -> None: this_failed = locks.Event() next_task = loop.create_task(run_one_coro(this_failed)) running_tasks.append(next_task) - assert len(running_tasks) == this_index + 2 + assert len(running_tasks) == this_index + 2 # noqa: S101 # Prepare place to put this coroutine's exceptions if not won exceptions.append(None) - assert len(exceptions) == this_index + 1 + assert len(exceptions) == this_index + 1 # noqa: S101 try: result = await coro_fn() @@ -105,7 +105,7 @@ async def run_one_coro(previous_failed) -> None: else: # Store winner's results nonlocal winner_index, winner_result - assert winner_index is None + assert winner_index is None # noqa: S101 winner_index = this_index winner_result = result # Cancel all other tasks. We take care to not cancel the current @@ -133,7 +133,7 @@ async def run_one_coro(previous_failed) -> None: if __debug__: for d in done: if d.done() and not d.cancelled() and d.exception(): - raise d.exception() + raise d.exception() # type: ignore return winner_result, winner_index, exceptions finally: # Make sure no tasks are left running if we leave this function From 391fd026cc96dca8fb26a0e25892803319bda382 Mon Sep 17 00:00:00 2001 From: Vlad Emelianov Date: Fri, 27 Sep 2024 05:22:47 +0300 Subject: [PATCH 3/3] Add missing type annotations --- src/aiohappyeyeballs/_staggered.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aiohappyeyeballs/_staggered.py b/src/aiohappyeyeballs/_staggered.py index f40a1d3..f0aa13a 100644 --- a/src/aiohappyeyeballs/_staggered.py +++ b/src/aiohappyeyeballs/_staggered.py @@ -67,8 +67,8 @@ async def staggered_race( # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. loop = loop or events.get_running_loop() enum_coro_fns = enumerate(coro_fns) - winner_result = None - winner_index = None + winner_result: Optional[_T] = None + winner_index: Optional[int] = None exceptions: List[Optional[BaseException]] = [] running_tasks: List[tasks.Task[None]] = []