From 133e929a791d209b578b4822a7a07f4570b3803b Mon Sep 17 00:00:00 2001 From: Peter Bierma Date: Mon, 30 Sep 2024 21:37:27 -0400 Subject: [PATCH] gh-124309: Revert eager task factory fix to prevent breaking downstream (#124810) * Revert "GH-124639: add back loop param to staggered_race (#124700)" This reverts commit e0a41a5dd12cb6e9277b05abebac5c70be684dd7. * Revert "gh-124309: Modernize the `staggered_race` implementation to support eager task factories (#124390)" This reverts commit de929f353c413459834a2a37b2d9b0240673d874. --- Lib/asyncio/base_events.py | 2 +- Lib/asyncio/staggered.py | 83 ++++++++++++++----- .../test_asyncio/test_eager_task_factory.py | 47 ----------- Lib/test/test_asyncio/test_staggered.py | 56 +------------ ...-09-23-18-18-23.gh-issue-124309.iFcarA.rst | 1 - 5 files changed, 65 insertions(+), 124 deletions(-) delete mode 100644 Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index ffcc0174e1e245..000647f57dd9e3 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -1144,7 +1144,7 @@ async def create_connection( (functools.partial(self._connect_sock, exceptions, addrinfo, laddr_infos) for addrinfo in infos), - happy_eyeballs_delay) + happy_eyeballs_delay, loop=self) if sock is None: exceptions = [exc for sub in exceptions for exc in sub] diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py index 6ccf5c3c269ff0..c3a7441a7b091d 100644 --- a/Lib/asyncio/staggered.py +++ b/Lib/asyncio/staggered.py @@ -4,12 +4,11 @@ import contextlib +from . import events +from . import exceptions as exceptions_mod from . import locks from . import tasks -from . import taskgroups -class _Done(Exception): - pass async def staggered_race(coro_fns, delay, *, loop=None): """Run coroutines with staggered start times and take the first to finish. @@ -43,6 +42,8 @@ async def staggered_race(coro_fns, delay, *, loop=None): delay: amount of time, in seconds, between starting coroutines. If ``None``, the coroutines will run sequentially. + loop: the event loop to use. + Returns: tuple *(winner_result, winner_index, exceptions)* where @@ -61,11 +62,36 @@ async def staggered_race(coro_fns, delay, *, loop=None): """ # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. + loop = loop or events.get_running_loop() + enum_coro_fns = enumerate(coro_fns) winner_result = None winner_index = None exceptions = [] + running_tasks = [] + + async def run_one_coro(previous_failed) -> None: + # Wait for the previous task to finish, or for delay seconds + if previous_failed is not None: + with contextlib.suppress(exceptions_mod.TimeoutError): + # Use asyncio.wait_for() instead of asyncio.wait() here, so + # that if we get cancelled at this point, Event.wait() is also + # cancelled, otherwise there will be a "Task destroyed but it is + # pending" later. + await tasks.wait_for(previous_failed.wait(), delay) + # Get the next coroutine to run + try: + this_index, coro_fn = next(enum_coro_fns) + except StopIteration: + return + # Start task that will run the next coroutine + this_failed = locks.Event() + next_task = loop.create_task(run_one_coro(this_failed)) + running_tasks.append(next_task) + assert len(running_tasks) == this_index + 2 + # Prepare place to put this coroutine's exceptions if not won + exceptions.append(None) + assert len(exceptions) == this_index + 1 - async def run_one_coro(this_index, coro_fn, this_failed): try: result = await coro_fn() except (SystemExit, KeyboardInterrupt): @@ -79,23 +105,34 @@ async def run_one_coro(this_index, coro_fn, this_failed): assert winner_index is None winner_index = this_index winner_result = result - raise _Done - + # Cancel all other tasks. We take care to not cancel the current + # task as well. If we do so, then since there is no `await` after + # here and CancelledError are usually thrown at one, we will + # encounter a curious corner case where the current task will end + # up as done() == True, cancelled() == False, exception() == + # asyncio.CancelledError. This behavior is specified in + # https://bugs.python.org/issue30048 + for i, t in enumerate(running_tasks): + if i != this_index: + t.cancel() + + first_task = loop.create_task(run_one_coro(None)) + running_tasks.append(first_task) try: - tg = taskgroups.TaskGroup() - # Intentionally override the loop in the TaskGroup to avoid - # using the running loop, preserving backwards compatibility - # TaskGroup only starts using `_loop` after `__aenter__` - # so overriding it here is safe. - tg._loop = loop - async with tg: - for this_index, coro_fn in enumerate(coro_fns): - this_failed = locks.Event() - exceptions.append(None) - tg.create_task(run_one_coro(this_index, coro_fn, this_failed)) - with contextlib.suppress(TimeoutError): - await tasks.wait_for(this_failed.wait(), delay) - except* _Done: - pass - - return winner_result, winner_index, exceptions + # Wait for a growing list of tasks to all finish: poor man's version of + # curio's TaskGroup or trio's nursery + done_count = 0 + while done_count != len(running_tasks): + done, _ = await tasks.wait(running_tasks) + done_count = len(done) + # If run_one_coro raises an unhandled exception, it's probably a + # programming error, and I want to see it. + if __debug__: + for d in done: + if d.done() and not d.cancelled() and d.exception(): + raise d.exception() + return winner_result, winner_index, exceptions + finally: + # Make sure no tasks are left running if we leave this function + for t in running_tasks: + t.cancel() diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py index 1579ad1188d725..0777f39b572486 100644 --- a/Lib/test/test_asyncio/test_eager_task_factory.py +++ b/Lib/test/test_asyncio/test_eager_task_factory.py @@ -213,53 +213,6 @@ async def run(): self.run_coro(run()) - def test_staggered_race_with_eager_tasks(self): - # See https://github.com/python/cpython/issues/124309 - - async def fail(): - await asyncio.sleep(0) - raise ValueError("no good") - - async def run(): - winner, index, excs = await asyncio.staggered.staggered_race( - [ - lambda: asyncio.sleep(2, result="sleep2"), - lambda: asyncio.sleep(1, result="sleep1"), - lambda: fail() - ], - delay=0.25 - ) - self.assertEqual(winner, 'sleep1') - self.assertEqual(index, 1) - self.assertIsNone(excs[index]) - self.assertIsInstance(excs[0], asyncio.CancelledError) - self.assertIsInstance(excs[2], ValueError) - - self.run_coro(run()) - - def test_staggered_race_with_eager_tasks_no_delay(self): - # See https://github.com/python/cpython/issues/124309 - async def fail(): - raise ValueError("no good") - - async def run(): - winner, index, excs = await asyncio.staggered.staggered_race( - [ - lambda: fail(), - lambda: asyncio.sleep(1, result="sleep1"), - lambda: asyncio.sleep(0, result="sleep0"), - ], - delay=None - ) - self.assertEqual(winner, 'sleep1') - self.assertEqual(index, 1) - self.assertIsNone(excs[index]) - self.assertIsInstance(excs[0], ValueError) - self.assertEqual(len(excs), 2) - - self.run_coro(run()) - - class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): Task = tasks._PyTask diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py index 8cd98394aea8f8..e6e32f7dbbbcba 100644 --- a/Lib/test/test_asyncio/test_staggered.py +++ b/Lib/test/test_asyncio/test_staggered.py @@ -82,64 +82,16 @@ async def test_none_successful(self): async def coro(index): raise ValueError(index) - for delay in [None, 0, 0.1, 1]: - with self.subTest(delay=delay): - winner, index, excs = await staggered_race( - [ - lambda: coro(0), - lambda: coro(1), - ], - delay=delay, - ) - - self.assertIs(winner, None) - self.assertIs(index, None) - self.assertEqual(len(excs), 2) - self.assertIsInstance(excs[0], ValueError) - self.assertIsInstance(excs[1], ValueError) - - async def test_long_delay_early_failure(self): - async def coro(index): - await asyncio.sleep(0) # Dummy coroutine for the 1 case - if index == 0: - await asyncio.sleep(0.1) # Dummy coroutine - raise ValueError(index) - - return f'Res: {index}' - winner, index, excs = await staggered_race( [ lambda: coro(0), lambda: coro(1), ], - delay=10, + delay=None, ) - self.assertEqual(winner, 'Res: 1') - self.assertEqual(index, 1) + self.assertIs(winner, None) + self.assertIs(index, None) self.assertEqual(len(excs), 2) self.assertIsInstance(excs[0], ValueError) - self.assertIsNone(excs[1]) - - def test_loop_argument(self): - loop = asyncio.new_event_loop() - async def coro(): - self.assertEqual(loop, asyncio.get_running_loop()) - return 'coro' - - async def main(): - winner, index, excs = await staggered_race( - [coro], - delay=0.1, - loop=loop - ) - - self.assertEqual(winner, 'coro') - self.assertEqual(index, 0) - - loop.run_until_complete(main()) - loop.close() - - -if __name__ == "__main__": - unittest.main() + self.assertIsInstance(excs[1], ValueError) diff --git a/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst b/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst deleted file mode 100644 index 89610fa44bf743..00000000000000 --- a/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst +++ /dev/null @@ -1 +0,0 @@ -Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.