Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change how futures are set up in the worker events test #666

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 52 additions & 37 deletions tests/unit_tests/worker/test_task_worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import itertools
import threading
from collections.abc import Callable, Iterable
from collections.abc import AsyncIterator, Callable, Iterable
from concurrent.futures import Future
from queue import Full
from typing import Any, TypeVar
Expand Down Expand Up @@ -358,15 +359,32 @@ def assert_running_count_plan_produces_ordered_worker_and_data_events(
]

count = itertools.count()
events: Future[list[Any]] = take_events_from_streams(
event_streams,
lambda _: next(count) >= len(expected_events) - 1,
)
events = []

async def collect_events():
events_iterator = take_events_from_streams(
event_streams,
lambda _: next(count) >= len(expected_events) - 1,
)
async for event in events_iterator:
events.append(event)
if len(events) >= len(expected_events):
break

task_id = worker.submit_task(task)
worker.begin_task(task_id)
results = events.result(timeout=timeout)
# Await for events to be collected with proper timeout
try:
asyncio.run(asyncio.wait_for(collect_events(), timeout=timeout))
except asyncio.TimeoutError:
pytest.fail(f"Test timed out after {timeout} seconds while waiting for events.")

_compare_events(expected_events, task_id, events)


def _compare_events(
expected_events: list[DataEvent | WorkerEvent], task_id: str, results: list[Any]
) -> None:
for actual, expected in itertools.zip_longest(results, expected_events):
if isinstance(expected, WorkerEvent):
if expected.task_status:
Expand Down Expand Up @@ -405,47 +423,44 @@ def on_event(event: E, event_id: str | None) -> None:
return future


def take_events_from_streams(
streams: list[EventStream[Any, int]],
async def take_events_from_streams(
streams: list["EventStream[Any, Any]"],
cutoff_predicate: Callable[[Any], bool],
) -> Future[list[Any]]:
"""Returns a collated list of futures for events in numerous event streams.

The support for generic and algebraic types doesn't appear to extend to
taking an arbitrary list of concrete types with single but differing
generic arguments while also maintaining the generality of the argument
types.

The type for streams will be any combination of event streams each of a
given event type, where the event type is generic:

List[
Union[
EventStream[WorkerEvent, int],
EventStream[DataEvent, int],
EventStream[ProgressEvent, int]
]
]
) -> AsyncIterator[Any]:
"""Returns an async generator that yields events from multiple event streams."""

"""
events: list[Any] = []
future: Future[list[Any]] = Future()
event_queue = asyncio.Queue() # Queue to store events from the streams
cutoff_reached = asyncio.Event() # Event to signal when to stop listening

def on_event(event: Any, event_id: str | None) -> None:
print(event)
events.append(event)
"""Callback for events."""
event_queue.put_nowait(event) # Add the event to the async queue
if cutoff_predicate(event):
future.set_result(events)
cutoff_reached.set() # Signal the cutoff event

subscriptions = []

# Subscribe to all the event streams
for stream in streams:
sub = stream.subscribe(on_event)

def callback(unused: Future[list[Any]], stream=stream, sub=sub):
subscriptions.append((stream, sub))

async def event_producer():
"""Asynchronously yield events from the queue."""
while not cutoff_reached.is_set():
event = await event_queue.get() # Wait for the next event
yield event

try:
# Yield events using the event_producer async generator
async for event in event_producer():
yield event

finally:
# Ensure we unsubscribe from all streams once done
for stream, sub in subscriptions:
stream.unsubscribe(sub)

future.add_done_callback(callback)
return future


@pytest.mark.parametrize(
"status, expected_task_ids",
Expand Down
Loading