Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-124309: fix staggered race on eager tasks #124847

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions Lib/asyncio/staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ async def staggered_race(coro_fns, delay, *, loop=None):
exceptions = []
running_tasks = []

async def run_one_coro(previous_failed) -> None:
async def run_one_coro(ok_to_start, previous_failed) -> None:
await ok_to_start.wait()
# Wait for the previous task to finish, or for delay seconds
if previous_failed is not None:
with contextlib.suppress(exceptions_mod.TimeoutError):
Expand All @@ -85,8 +86,10 @@ async def run_one_coro(previous_failed) -> None:
return
# Start task that will run the next coroutine
this_failed = locks.Event()
next_task = loop.create_task(run_one_coro(this_failed))
next_ok_to_start = locks.Event()
next_task = loop.create_task(run_one_coro(next_ok_to_start, this_failed))
running_tasks.append(next_task)
next_ok_to_start.set()
assert len(running_tasks) == this_index + 2
# Prepare place to put this coroutine's exceptions if not won
exceptions.append(None)
Expand Down Expand Up @@ -116,8 +119,10 @@ async def run_one_coro(previous_failed) -> None:
if i != this_index:
t.cancel()

first_task = loop.create_task(run_one_coro(None))
ok_to_start = locks.Event()
first_task = loop.create_task(run_one_coro(ok_to_start, None))
running_tasks.append(first_task)
ok_to_start.set()
try:
# Wait for a growing list of tasks to all finish: poor man's version of
# curio's TaskGroup or trio's nursery
Expand Down
46 changes: 46 additions & 0 deletions Lib/test/test_asyncio/test_eager_task_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,52 @@ async def run():

self.run_coro(run())

def test_staggered_race_with_eager_tasks(self):
# See https://github.com/python/cpython/issues/124309

async def fail():
await asyncio.sleep(0)
raise ValueError("no good")

async def run():
winner, index, excs = await asyncio.staggered.staggered_race(
[
lambda: asyncio.sleep(2, result="sleep2"),
lambda: asyncio.sleep(1, result="sleep1"),
lambda: fail()
],
delay=0.25
)
self.assertEqual(winner, 'sleep1')
self.assertEqual(index, 1)
self.assertIsNone(excs[index])
self.assertIsInstance(excs[0], asyncio.CancelledError)
self.assertIsInstance(excs[2], ValueError)

self.run_coro(run())

def test_staggered_race_with_eager_tasks_no_delay(self):
# See https://github.com/python/cpython/issues/124309
async def fail():
raise ValueError("no good")

async def run():
winner, index, excs = await asyncio.staggered.staggered_race(
[
lambda: fail(),
lambda: asyncio.sleep(1, result="sleep1"),
lambda: asyncio.sleep(0, result="sleep0"),
],
delay=None
)
self.assertEqual(winner, 'sleep1')
self.assertEqual(index, 1)
self.assertIsNone(excs[index])
self.assertIsInstance(excs[0], ValueError)
self.assertEqual(len(excs), 2)

self.run_coro(run())


class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = tasks._PyTask
Expand Down
27 changes: 27 additions & 0 deletions Lib/test/test_asyncio/test_staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,30 @@ async def coro(index):
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
self.assertIsInstance(excs[1], ValueError)


async def test_multiple_winners(self):
event = asyncio.Event()

async def coro(index):
await event.wait()
return index

async def do_set():
event.set()
await asyncio.Event().wait()

winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
do_set,
],
delay=0.1,
)
self.assertIs(winner, 0)
self.assertIs(index, 0)
self.assertEqual(len(excs), 3)
self.assertIsNone(excs[0], None)
self.assertIsInstance(excs[1], asyncio.CancelledError)
self.assertIsInstance(excs[2], asyncio.CancelledError)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.
Loading