Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: rewrite staggered_race to be race safe #101

Merged
merged 59 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
4c96188
feat: rewrite staggered_race to be race safe
bdraco Sep 30, 2024
67743c3
feat: rewrite staggered_race to be race safe
bdraco Sep 30, 2024
f13db9f
wip
bdraco Sep 30, 2024
1056e9e
wip
bdraco Sep 30, 2024
47c4b1d
wip
bdraco Sep 30, 2024
5338f38
wip
bdraco Sep 30, 2024
9d0838c
tweak
bdraco Sep 30, 2024
e2d38d7
comments
bdraco Sep 30, 2024
41c4959
comments
bdraco Sep 30, 2024
6c87e32
comments
bdraco Sep 30, 2024
7e49064
cleanup
bdraco Sep 30, 2024
11bf906
cleanup
bdraco Sep 30, 2024
26a6b6d
fix: py38
bdraco Sep 30, 2024
c19d61a
fix: py38
bdraco Sep 30, 2024
36fc530
fix: py38
bdraco Sep 30, 2024
4623300
fix: py38
bdraco Sep 30, 2024
e410565
fix: tests
bdraco Sep 30, 2024
013b951
tweak
bdraco Sep 30, 2024
aa55328
fix: coverage
bdraco Sep 30, 2024
18e2912
fix: coverage
bdraco Sep 30, 2024
cbe6f7c
fix: add eager task factory tests
bdraco Sep 30, 2024
df0171d
fix: add eager task factory tests
bdraco Sep 30, 2024
6a7fd91
naming
bdraco Sep 30, 2024
7a4b659
fix: lint
bdraco Sep 30, 2024
44a6e8b
comments
bdraco Sep 30, 2024
42e2b99
rename tests
bdraco Sep 30, 2024
7c66687
rename tests
bdraco Sep 30, 2024
cf347f1
needs a guard
bdraco Sep 30, 2024
580b0cd
tweak
bdraco Sep 30, 2024
5d3fd43
tweak
bdraco Sep 30, 2024
ec88d1e
preen
bdraco Sep 30, 2024
450f183
cleanup
bdraco Sep 30, 2024
86b768d
cleanup
bdraco Sep 30, 2024
d3c2ecd
cleanup
bdraco Sep 30, 2024
4dd8a72
avoid consuming iterable
bdraco Sep 30, 2024
88de10f
avoid consuming iterable
bdraco Sep 30, 2024
1569c42
avoid consuming iterable
bdraco Sep 30, 2024
7d3bc9e
avoid consuming iterable
bdraco Sep 30, 2024
c7bc061
avoid consuming iterable
bdraco Sep 30, 2024
f7ea011
avoid consuming iterable
bdraco Sep 30, 2024
591ca0d
avoid consuming iterable
bdraco Sep 30, 2024
a3fa3ef
avoid consuming iterable
bdraco Sep 30, 2024
9e5edc4
avoid consuming iterable
bdraco Sep 30, 2024
b71b000
avoid consuming iterable
bdraco Sep 30, 2024
c302bf8
comments
bdraco Sep 30, 2024
9f1e400
comments
bdraco Sep 30, 2024
604eb5e
comments
bdraco Sep 30, 2024
964cf86
comments
bdraco Sep 30, 2024
faf5353
comments
bdraco Sep 30, 2024
39746b6
reduce
bdraco Sep 30, 2024
c4b7053
Revert "reduce"
bdraco Sep 30, 2024
b7e69ce
reduce
bdraco Sep 30, 2024
b7745fd
Merge branch 'main' into stag
bdraco Sep 30, 2024
29c4e46
chore(pre-commit.ci): auto fixes
pre-commit-ci[bot] Sep 30, 2024
28d9c91
more coverage
bdraco Sep 30, 2024
31a5a77
Merge remote-tracking branch 'origin/stag' into stag
bdraco Sep 30, 2024
c59f344
chore: add tests for multiple winners
bdraco Sep 30, 2024
c3e7170
chore: add tests for multiple winners
bdraco Sep 30, 2024
9fd0030
Merge branch 'main' into stag
bdraco Sep 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 130 additions & 29 deletions src/aiohappyeyeballs/_staggered.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,54 @@
import asyncio
import contextlib
from typing import Awaitable, Callable, Iterable, List, Optional, Tuple, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)

_T = TypeVar("_T")

class _Done(Exception):
pass

def _set_result(wait_next: "asyncio.Future[None]") -> None:
"""Set the result of a future if it is not already done."""
if not wait_next.done():
wait_next.set_result(None)

_T = TypeVar("_T")

async def _wait_one(
futures: "Iterable[asyncio.Future[Any]]",
loop: asyncio.AbstractEventLoop,
) -> _T:
"""Wait for the first future to complete."""
wait_next = loop.create_future()

def _on_completion(fut: "asyncio.Future[Any]") -> None:
if not wait_next.done():
wait_next.set_result(fut)

for f in futures:
f.add_done_callback(_on_completion)

try:
return await wait_next
finally:
for f in futures:
f.remove_done_callback(_on_completion)


async def staggered_race(
coro_fns: Iterable[Callable[[], Awaitable[_T]]], delay: Optional[float]
coro_fns: Iterable[Callable[[], Awaitable[_T]]],
delay: Optional[float],
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Tuple[Optional[_T], Optional[int], List[Optional[BaseException]]]:
"""
Run coroutines with staggered start times and take the first to finish.
Expand All @@ -38,14 +75,18 @@ async def staggered_race(
raise

Args:
----
coro_fns: an iterable of coroutine functions, i.e. callables that
return a coroutine object when called. Use ``functools.partial`` or
lambdas to pass arguments.

delay: amount of time, in seconds, between starting coroutines. If
``None``, the coroutines will run sequentially.

loop: the event loop to use. If ``None``, the running loop is used.

Returns:
-------
tuple *(winner_result, winner_index, exceptions)* where

- *winner_result*: the result of the winning coroutine, or ``None``
Expand All @@ -62,40 +103,100 @@ async def staggered_race(
coroutine's entry is ``None``.

"""
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
winner_result = None
winner_index = None
loop = loop or asyncio.get_running_loop()
exceptions: List[Optional[BaseException]] = []
tasks: Set[asyncio.Task[Optional[Tuple[_T, int]]]] = set()

async def run_one_coro(
this_index: int,
coro_fn: Callable[[], Awaitable[_T]],
this_failed: asyncio.Event,
) -> None:
this_index: int,
start_next: "asyncio.Future[None]",
) -> Optional[Tuple[_T, int]]:
"""
Run a single coroutine.

If the coroutine fails, set the exception in the exceptions list and
start the next coroutine by setting the result of the start_next.

If the coroutine succeeds, return the result and the index of the
coroutine in the coro_fns list.

If SystemExit or KeyboardInterrupt is raised, re-raise it.
"""
try:
result = await coro_fn()
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as e:
exceptions[this_index] = e
this_failed.set() # Kickstart the next coroutine
else:
# Store winner's results
nonlocal winner_index, winner_result
assert winner_index is None # noqa: S101
winner_index = this_index
winner_result = result
raise _Done
_set_result(start_next) # Kickstart the next coroutine
return None

return result, this_index

start_next_timer: Optional[asyncio.TimerHandle] = None
start_next: Optional[asyncio.Future[None]]
task: asyncio.Task[Optional[Tuple[_T, int]]]
done: Union[asyncio.Future[None], asyncio.Task[Optional[Tuple[_T, int]]]]
coro_iter = iter(coro_fns)
this_index = -1
try:
async with asyncio.TaskGroup() as tg:
for this_index, coro_fn in enumerate(coro_fns):
this_failed = asyncio.Event()
while True:
if coro_fn := next(coro_iter, None):
this_index += 1
exceptions.append(None)
tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
with contextlib.suppress(TimeoutError):
await asyncio.wait_for(this_failed.wait(), delay)
except* _Done:
pass

return winner_result, winner_index, exceptions
start_next = loop.create_future()
task = loop.create_task(run_one_coro(coro_fn, this_index, start_next))
tasks.add(task)
start_next_timer = (
loop.call_later(delay, _set_result, start_next) if delay else None
)
elif not tasks:
# We exhausted the coro_fns list and no tasks are running
# so we have no winner and all coroutines failed.
break

while tasks:
done = await _wait_one(
[*tasks, start_next] if start_next else tasks, loop
)
if done is start_next:
# The current task has failed or the timer has expired
# so we need to start the next task.
start_next = None
if start_next_timer:
start_next_timer.cancel()
start_next_timer = None

# Break out of the task waiting loop to start the next
# task.
break

if TYPE_CHECKING:
assert isinstance(done, asyncio.Task)

tasks.remove(done)
if winner := done.result():
return *winner, exceptions
finally:
# We either have:
# - a winner
# - all tasks failed
# - a KeyboardInterrupt or SystemExit.

#
# If the timer is still running, cancel it.
#
if start_next_timer:
start_next_timer.cancel()

#
# If there are any tasks left, cancel them and than
# wait them so they fill the exceptions list.
#
for task in tasks:
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task

return None, None, exceptions
4 changes: 2 additions & 2 deletions src/aiohappyeyeballs/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from typing import List, Optional, Sequence, Union

from . import staggered
from . import _staggered
from .types import AddrInfoType

if sys.version_info < (3, 8, 2): # noqa: UP036
Expand Down Expand Up @@ -86,7 +86,7 @@ async def start_connection(
except (RuntimeError, OSError):
continue
else: # using happy eyeballs
sock, _, _ = await staggered.staggered_race(
sock, _, _ = await _staggered.staggered_race(
(
functools.partial(
_connect_sock, current_loop, exceptions, addrinfo, local_addr_infos
Expand Down
9 changes: 0 additions & 9 deletions src/aiohappyeyeballs/staggered.py

This file was deleted.

32 changes: 32 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Configuration for the tests."""

import asyncio
import threading
from typing import Generator

import pytest


@pytest.fixture(autouse=True)
def verify_threads_ended():
"""Verify that the threads are not running after the test."""
threads_before = frozenset(threading.enumerate())
yield
threads = frozenset(threading.enumerate()) - threads_before
assert not threads


@pytest.fixture(autouse=True)
def verify_no_lingering_tasks(
event_loop: asyncio.AbstractEventLoop,
) -> Generator[None, None, None]:
"""Verify that all tasks are cleaned up."""
tasks_before = asyncio.all_tasks(event_loop)
yield

tasks = asyncio.all_tasks(event_loop) - tasks_before
for task in tasks:
pytest.fail(f"Task still running: {task!r}")
task.cancel()
if tasks:
event_loop.run_until_complete(asyncio.wait(tasks))
86 changes: 86 additions & 0 deletions tests/test_staggered.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import asyncio
import sys
from functools import partial

import pytest

from aiohappyeyeballs._staggered import staggered_race


@pytest.mark.asyncio
async def test_one_winners():
"""Test that there is only one winner when there is no await in the coro."""
winners = []

async def coro(idx):
winners.append(idx)
return idx

coros = [partial(coro, idx) for idx in range(4)]

winner, index, excs = await staggered_race(
coros,
delay=None,
)
assert len(winners) == 1
assert winners == [0]
assert winner == 0
assert index == 0
assert excs == [None]


@pytest.mark.asyncio
async def test_multiple_winners():
"""Test multiple winners are handled correctly."""
loop = asyncio.get_running_loop()
winners = []
finish = loop.create_future()

async def coro(idx):
await finish
winners.append(idx)
return idx

coros = [partial(coro, idx) for idx in range(4)]

task = loop.create_task(staggered_race(coros, delay=0.00001))
await asyncio.sleep(0.1)
loop.call_soon(finish.set_result, None)
winner, index, excs = await task
assert len(winners) == 4
assert winners == [0, 1, 2, 3]
assert winner == 0
assert index == 0
assert excs == [None, None, None, None]


@pytest.mark.skipif(sys.version_info < (3, 12), reason="requires python3.12 or higher")
def test_multiple_winners_eager_task_factory():
"""Test multiple winners are handled correctly."""
loop = asyncio.new_event_loop()
eager_task_factory = asyncio.create_eager_task_factory(asyncio.Task)
loop.set_task_factory(eager_task_factory)
asyncio.set_event_loop(None)

async def run():
winners = []
finish = loop.create_future()

async def coro(idx):
await finish
winners.append(idx)
return idx

coros = [partial(coro, idx) for idx in range(4)]

task = loop.create_task(staggered_race(coros, delay=0.00001))
await asyncio.sleep(0.1)
loop.call_soon(finish.set_result, None)
winner, index, excs = await task
assert len(winners) == 4
assert winners == [0, 1, 2, 3]
assert winner == 0
assert index == 0
assert excs == [None, None, None, None]

loop.run_until_complete(run())
Loading
Loading