From e0a41a5dd12cb6e9277b05abebac5c70be684dd7 Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Sun, 29 Sep 2024 08:42:46 +0530 Subject: [PATCH] GH-124639: add back loop param to staggered_race (#124700) --- Lib/asyncio/staggered.py | 10 ++++++++-- Lib/test/test_asyncio/test_staggered.py | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py index 4458d01dece0e6..6ccf5c3c269ff0 100644 --- a/Lib/asyncio/staggered.py +++ b/Lib/asyncio/staggered.py @@ -11,7 +11,7 @@ class _Done(Exception): pass -async def staggered_race(coro_fns, delay): +async def staggered_race(coro_fns, delay, *, loop=None): """Run coroutines with staggered start times and take the first to finish. This method takes an iterable of coroutine functions. The first one is @@ -82,7 +82,13 @@ async def run_one_coro(this_index, coro_fn, this_failed): raise _Done try: - async with taskgroups.TaskGroup() as tg: + tg = taskgroups.TaskGroup() + # Intentionally override the loop in the TaskGroup to avoid + # using the running loop, preserving backwards compatibility + # TaskGroup only starts using `_loop` after `__aenter__` + # so overriding it here is safe. + tg._loop = loop + async with tg: for this_index, coro_fn in enumerate(coro_fns): this_failed = locks.Event() exceptions.append(None) diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py index 21a39b3f911747..8cd98394aea8f8 100644 --- a/Lib/test/test_asyncio/test_staggered.py +++ b/Lib/test/test_asyncio/test_staggered.py @@ -121,6 +121,25 @@ async def coro(index): self.assertIsInstance(excs[0], ValueError) self.assertIsNone(excs[1]) + def test_loop_argument(self): + loop = asyncio.new_event_loop() + async def coro(): + self.assertEqual(loop, asyncio.get_running_loop()) + return 'coro' + + async def main(): + winner, index, excs = await staggered_race( + [coro], + delay=0.1, + loop=loop + ) + + self.assertEqual(winner, 'coro') + self.assertEqual(index, 0) + + loop.run_until_complete(main()) + loop.close() + if __name__ == "__main__": unittest.main()