diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index f8c69bfef12..3ae81ffce96 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -14,14 +14,19 @@ # ============================================================================== """Fleet Simulation Engine API.""" + import asyncio import json import sys +import threading import time import traceback +from concurrent.futures import ThreadPoolExecutor from logging import DEBUG, ERROR, INFO, WARN from pathlib import Path -from typing import Callable, Dict, List, Optional +from queue import Empty, Queue +from time import sleep +from typing import Callable, Dict, Optional from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.node_state import NodeState @@ -31,7 +36,7 @@ from flwr.common.object_ref import load_app from flwr.common.serde import message_from_taskins, message_to_taskres from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 -from flwr.server.superlink.state import StateFactory +from flwr.server.superlink.state import State, StateFactory from .backend import Backend, error_messages_backends, supported_backends @@ -52,18 +57,21 @@ def _register_nodes( # pylint: disable=too-many-arguments,too-many-locals -async def worker( +def worker( app_fn: Callable[[], ClientApp], - taskins_queue: "asyncio.Queue[TaskIns]", - taskres_queue: "asyncio.Queue[TaskRes]", + taskins_queue: Queue[TaskIns], + taskres_queue: Queue[TaskRes], node_states: Dict[int, NodeState], backend: Backend, + f_stop: asyncio.Event, ) -> None: """Get TaskIns from queue and pass it to an actor in the pool to execute it.""" - while True: + while not f_stop.is_set(): out_mssg = None try: - task_ins: TaskIns = await taskins_queue.get() + # Fetch from queue with timeout. We use a timeout so + # the stopping event can be evaluated even when the queue is empty. + task_ins: TaskIns = taskins_queue.get(timeout=1.0) node_id = task_ins.task.consumer.node_id # Register and retrieve runstate @@ -82,11 +90,9 @@ async def worker( node_states[node_id].update_context( task_ins.run_id, context=updated_context ) - - except asyncio.CancelledError as e: - log(DEBUG, "Terminating async worker: %s", e) - break - + except Empty: + # An exception raised if queue.get times out + pass # Exceptions aren't raised but reported as an error message except Exception as ex: # pylint: disable=broad-exception-caught log(ERROR, ex) @@ -110,73 +116,38 @@ async def worker( task_res = message_to_taskres(out_mssg) # Store TaskRes in state task_res.task.pushed_at = time.time() - await taskres_queue.put(task_res) + taskres_queue.put(task_res) -async def add_taskins_to_queue( - queue: "asyncio.Queue[TaskIns]", - state_factory: StateFactory, +def add_taskins_to_queue( + state: State, + queue: Queue[TaskIns], nodes_mapping: NodeToPartitionMapping, - backend: Backend, - consumers: List["asyncio.Task[None]"], f_stop: asyncio.Event, ) -> None: - """Retrieve TaskIns and add it to the queue.""" - state = state_factory.state() - num_initial_consumers = len(consumers) + """Put TaskIns in a queue from State.""" while not f_stop.is_set(): for node_id in nodes_mapping.keys(): - task_ins = state.get_task_ins(node_id=node_id, limit=1) - if task_ins: - await queue.put(task_ins[0]) - - # Count consumers that are running - num_active = sum(not (cc.done()) for cc in consumers) - - # Alert if number of consumers decreased by half - if num_active < num_initial_consumers // 2: - log( - WARN, - "Number of active workers has more than halved: (%i/%i active)", - num_active, - num_initial_consumers, - ) - - # Break if consumers died - if num_active == 0: - raise RuntimeError("All workers have died. Ending Simulation.") - - # Log some stats - log( - DEBUG, - "Simulation Engine stats: " - "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)", - num_active, - num_initial_consumers, - backend.__class__.__name__, - backend.num_workers, - queue.qsize(), - ) - await asyncio.sleep(1.0) - log(DEBUG, "Async producer: Stopped pulling from StateFactory.") + task_ins_list = state.get_task_ins(node_id=node_id, limit=1) + for task_ins in task_ins_list: + queue.put(task_ins) + sleep(0.1) -async def put_taskres_into_state( - queue: "asyncio.Queue[TaskRes]", - state_factory: StateFactory, - f_stop: asyncio.Event, +def put_taskres_into_state( + state: State, queue: Queue[TaskRes], f_stop: threading.Event ) -> None: - """Remove TaskRes from queue and add into State.""" - state = state_factory.state() + """Put TaskRes into State from a queue.""" while not f_stop.is_set(): - if queue.qsize(): - task_res = await queue.get() - state.store_task_res(task_res) - else: - await asyncio.sleep(0.1) + try: + taskres = queue.get(timeout=1.0) + state.store_task_res(taskres) + except Empty: + # queue is empty when timeout was triggered + pass -async def run( +def run( app_fn: Callable[[], ClientApp], backend_fn: Callable[[], Backend], nodes_mapping: NodeToPartitionMapping, @@ -184,9 +155,9 @@ async def run( node_states: Dict[int, NodeState], f_stop: asyncio.Event, ) -> None: - """Run the VCE async.""" - taskins_queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128) - taskres_queue: "asyncio.Queue[TaskRes]" = asyncio.Queue(128) + """Run the VCE.""" + taskins_queue: "Queue[TaskIns]" = Queue() + taskres_queue: "Queue[TaskRes]" = Queue() try: @@ -197,39 +168,45 @@ async def run( backend.build() # Add workers (they submit Messages to Backend) - worker_tasks = [ - asyncio.create_task( - worker( - app_fn, - taskins_queue, - taskres_queue, - node_states, - backend, - ) - ) - for _ in range(backend.num_workers) - ] - # Create producer (adds TaskIns into Queue) - taskins_producer = asyncio.create_task( - add_taskins_to_queue( + state = state_factory.state() + + extractor_th = threading.Thread( + target=add_taskins_to_queue, + args=( + state, taskins_queue, - state_factory, nodes_mapping, - backend, - worker_tasks, f_stop, - ) + ), ) + extractor_th.start() - taskres_consumer = asyncio.create_task( - put_taskres_into_state(taskres_queue, state_factory, f_stop) + injector_th = threading.Thread( + target=put_taskres_into_state, + args=( + state, + taskres_queue, + f_stop, + ), ) + injector_th.start() - # Wait for asyncio taks pulling/pushing TaskIns/TaskRes. - # These run forever until f_stop is set or until - # all worker (consumer) coroutines are completed. Workers - # also run forever and only end if an exception is raised. - await asyncio.gather(*(taskins_producer, taskres_consumer)) + with ThreadPoolExecutor() as executor: + _ = [ + executor.submit( + worker, + app_fn, + taskins_queue, + taskres_queue, + node_states, + backend, + f_stop, + ) + for _ in range(backend.num_workers) + ] + + extractor_th.join() + injector_th.join() except Exception as ex: @@ -244,15 +221,6 @@ async def run( raise RuntimeError("Simulation Engine crashed.") from ex finally: - # Produced task terminated, now cancel worker tasks - for w_t in worker_tasks: - _ = w_t.cancel() - - while not all(w_t.done() for w_t in worker_tasks): - log(DEBUG, "Terminating async workers...") - await asyncio.sleep(0.5) - - await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()]) # Terminate backend backend.terminate() @@ -368,15 +336,13 @@ def _load() -> ClientApp: _ = app_fn() # Run main simulation loop - asyncio.run( - run( - app_fn, - backend_fn, - nodes_mapping, - state_factory, - node_states, - f_stop, - ) + run( + app_fn, + backend_fn, + nodes_mapping, + state_factory, + node_states, + f_stop, ) except LoadClientAppError as loadapp_ex: f_stop_delay = 10 diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index c0bf506fd2b..6c247b91a4e 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -24,7 +24,7 @@ from pathlib import Path from time import sleep from typing import Dict, Optional, Set, Tuple -from unittest import IsolatedAsyncioTestCase +from unittest import TestCase from uuid import UUID from flwr.client.client_app import LoadClientAppError @@ -149,7 +149,7 @@ def start_and_shutdown( """Start Simulation Engine and terminate after specified number of seconds. Some tests need to be terminated by triggering externally an asyncio.Event. This - is enabled whtn passing `duration`>0. + is enabled when passing `duration`>0. """ f_stop = asyncio.Event() @@ -181,8 +181,8 @@ def start_and_shutdown( termination_th.join() -class AsyncTestFleetSimulationEngineRayBackend(IsolatedAsyncioTestCase): - """A basic class that enables testing asyncio functionalities.""" +class TestFleetSimulationEngineRayBackend(TestCase): + """A basic class that enables testing functionalities.""" def test_erroneous_no_supernodes_client_mapping(self) -> None: """Test with unset arguments."""