diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py index 1d5e3a6a51a..31c64bd3b23 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/backend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/backend.py @@ -33,8 +33,8 @@ def __init__(self, backend_config: BackendConfig, work_dir: str) -> None: """Construct a backend.""" @abstractmethod - async def build(self) -> None: - """Build backend asynchronously. + def build(self) -> None: + """Build backend. Different components need to be in place before workers in a backend are ready to accept jobs. When this method finishes executing, the backend should be fully @@ -54,11 +54,11 @@ def is_worker_idle(self) -> bool: """Report whether a backend worker is idle and can therefore run a ClientApp.""" @abstractmethod - async def terminate(self) -> None: + def terminate(self) -> None: """Terminate backend.""" @abstractmethod - async def process_message( + def process_message( self, app: Callable[[], ClientApp], message: Message, diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py index b6b9e248a65..0d2f4d193f0 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend.py @@ -153,12 +153,12 @@ def is_worker_idle(self) -> bool: """Report whether the pool has idle actors.""" return self.pool.is_actor_available() - async def build(self) -> None: + def build(self) -> None: """Build pool of Ray actors that this backend will submit jobs to.""" - await self.pool.add_actors_to_pool(self.pool.actors_capacity) + self.pool.add_actors_to_pool(self.pool.actors_capacity) log(DEBUG, "Constructed ActorPool with: %i actors", self.pool.num_actors) - async def process_message( + def process_message( self, app: Callable[[], ClientApp], message: Message, @@ -172,17 +172,16 @@ async def process_message( try: # Submit a task to the pool - future = await self.pool.submit( + future = self.pool.submit( lambda a, a_fn, mssg, cid, state: a.run.remote(a_fn, mssg, cid, state), (app, message, str(partition_id), context), ) - await future # Fetch result ( out_mssg, updated_context, - ) = await self.pool.fetch_result_and_return_actor_to_pool(future) + ) = self.pool.fetch_result_and_return_actor_to_pool(future) return out_mssg, updated_context @@ -193,11 +192,11 @@ async def process_message( self.__class__.__name__, ) # add actor back into pool - await self.pool.add_actor_back_to_pool(future) + self.pool.add_actor_back_to_pool(future) raise ex - async def terminate(self) -> None: + def terminate(self) -> None: """Terminate all actors in actor pool.""" - await self.pool.terminate_all_actors() + self.pool.terminate_all_actors() ray.shutdown() log(DEBUG, "Terminated %s", self.__class__.__name__) diff --git a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py index b822db87673..287983003f8 100644 --- a/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/backend/raybackend_test.py @@ -14,11 +14,10 @@ # ============================================================================== """Test for Ray backend for the Fleet API using the Simulation Engine.""" -import asyncio from math import pi from pathlib import Path from typing import Callable, Dict, Optional, Tuple, Union -from unittest import IsolatedAsyncioTestCase +from unittest import TestCase import ray @@ -84,18 +83,18 @@ def _load_app() -> ClientApp: return _load_app -async def backend_build_process_and_termination( +def backend_build_process_and_termination( backend: RayBackend, process_args: Optional[Tuple[Callable[[], ClientApp], Message, Context]] = None, ) -> Union[Tuple[Message, Context], None]: """Build, process job and terminate RayBackend.""" - await backend.build() + backend.build() to_return = None if process_args: - to_return = await backend.process_message(*process_args) + to_return = backend.process_message(*process_args) - await backend.terminate() + backend.terminate() return to_return @@ -129,10 +128,10 @@ def _create_message_and_context() -> Tuple[Message, Context, float]: return message, context, expected_output -class AsyncTestRayBackend(IsolatedAsyncioTestCase): - """A basic class that allows runnig multliple asyncio tests.""" +class TestRayBackend(TestCase): + """A basic class that allows runnig multliple tests.""" - async def on_cleanup(self) -> None: + def doCleanups(self) -> None: """Ensure Ray has shutdown.""" if ray.is_initialized(): ray.shutdown() @@ -140,9 +139,7 @@ async def on_cleanup(self) -> None: def test_backend_creation_and_termination(self) -> None: """Test creation of RayBackend and its termination.""" backend = RayBackend(backend_config={}, work_dir="") - asyncio.run( - backend_build_process_and_termination(backend=backend, process_args=None) - ) + backend_build_process_and_termination(backend=backend, process_args=None) def test_backend_creation_submit_and_termination( self, @@ -157,10 +154,8 @@ def test_backend_creation_submit_and_termination( message, context, expected_output = _create_message_and_context() - res = asyncio.run( - backend_build_process_and_termination( - backend=backend, process_args=(client_app_callable, message, context) - ) + res = backend_build_process_and_termination( + backend=backend, process_args=(client_app_callable, message, context) ) if res is None: @@ -189,7 +184,6 @@ def test_backend_creation_submit_and_termination_non_existing_client_app( self.test_backend_creation_submit_and_termination( client_app_loader=_load_from_module("a_non_existing_module:app") ) - self.addAsyncCleanup(self.on_cleanup) def test_backend_creation_submit_and_termination_existing_client_app( self, @@ -217,7 +211,6 @@ def test_backend_creation_submit_and_termination_existing_client_app_unsetworkdi client_app_loader=_load_from_module("raybackend_test:client_app"), workdir="/?&%$^#%@$!", ) - self.addAsyncCleanup(self.on_cleanup) def test_backend_creation_with_init_arguments(self) -> None: """Testing whether init args are properly parsed to Ray.""" @@ -248,5 +241,3 @@ def test_backend_creation_with_init_arguments(self) -> None: nodes = ray.nodes() assert nodes[0]["Resources"]["CPU"] == backend_config_2["init_args"]["num_cpus"] - - self.addAsyncCleanup(self.on_cleanup) diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api.py b/src/py/flwr/server/superlink/fleet/vce/vce_api.py index 0e8171485a5..4eeaa519700 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api.py @@ -14,14 +14,19 @@ # ============================================================================== """Fleet Simulation Engine API.""" + import asyncio import json import sys +import threading import time import traceback +from concurrent.futures import ThreadPoolExecutor from logging import DEBUG, ERROR, INFO, WARN from pathlib import Path -from typing import Callable, Dict, List, Optional +from queue import Empty, Queue +from time import sleep +from typing import Callable, Dict, Optional from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError from flwr.client.node_state import NodeState @@ -31,7 +36,7 @@ from flwr.common.object_ref import load_app from flwr.common.serde import message_from_taskins, message_to_taskres from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 -from flwr.server.superlink.state import StateFactory +from flwr.server.superlink.state import State, StateFactory from .backend import Backend, error_messages_backends, supported_backends @@ -52,18 +57,21 @@ def _register_nodes( # pylint: disable=too-many-arguments,too-many-locals -async def worker( +def worker( app_fn: Callable[[], ClientApp], - taskins_queue: "asyncio.Queue[TaskIns]", - taskres_queue: "asyncio.Queue[TaskRes]", + taskins_queue: "Queue[TaskIns]", + taskres_queue: "Queue[TaskRes]", node_states: Dict[int, NodeState], backend: Backend, + f_stop: asyncio.Event, ) -> None: """Get TaskIns from queue and pass it to an actor in the pool to execute it.""" - while True: + while not f_stop.is_set(): out_mssg = None try: - task_ins: TaskIns = await taskins_queue.get() + # Fetch from queue with timeout. We use a timeout so + # the stopping event can be evaluated even when the queue is empty. + task_ins: TaskIns = taskins_queue.get(timeout=1.0) node_id = task_ins.task.consumer.node_id # Register and retrieve runstate @@ -74,7 +82,7 @@ async def worker( message = message_from_taskins(task_ins) # Let backend process message - out_mssg, updated_context = await backend.process_message( + out_mssg, updated_context = backend.process_message( app_fn, message, context ) @@ -82,11 +90,9 @@ async def worker( node_states[node_id].update_context( task_ins.run_id, context=updated_context ) - - except asyncio.CancelledError as e: - log(DEBUG, "Terminating async worker: %s", e) - break - + except Empty: + # An exception raised if queue.get times out + pass # Exceptions aren't raised but reported as an error message except Exception as ex: # pylint: disable=broad-exception-caught log(ERROR, ex) @@ -110,73 +116,38 @@ async def worker( task_res = message_to_taskres(out_mssg) # Store TaskRes in state task_res.task.pushed_at = time.time() - await taskres_queue.put(task_res) + taskres_queue.put(task_res) -async def add_taskins_to_queue( - queue: "asyncio.Queue[TaskIns]", - state_factory: StateFactory, +def add_taskins_to_queue( + state: State, + queue: "Queue[TaskIns]", nodes_mapping: NodeToPartitionMapping, - backend: Backend, - consumers: List["asyncio.Task[None]"], f_stop: asyncio.Event, ) -> None: - """Retrieve TaskIns and add it to the queue.""" - state = state_factory.state() - num_initial_consumers = len(consumers) + """Put TaskIns in a queue from State.""" while not f_stop.is_set(): for node_id in nodes_mapping.keys(): - task_ins = state.get_task_ins(node_id=node_id, limit=1) - if task_ins: - await queue.put(task_ins[0]) - - # Count consumers that are running - num_active = sum(not (cc.done()) for cc in consumers) - - # Alert if number of consumers decreased by half - if num_active < num_initial_consumers // 2: - log( - WARN, - "Number of active workers has more than halved: (%i/%i active)", - num_active, - num_initial_consumers, - ) - - # Break if consumers died - if num_active == 0: - raise RuntimeError("All workers have died. Ending Simulation.") - - # Log some stats - log( - DEBUG, - "Simulation Engine stats: " - "Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)", - num_active, - num_initial_consumers, - backend.__class__.__name__, - backend.num_workers, - queue.qsize(), - ) - await asyncio.sleep(1.0) - log(DEBUG, "Async producer: Stopped pulling from StateFactory.") + task_ins_list = state.get_task_ins(node_id=node_id, limit=1) + for task_ins in task_ins_list: + queue.put(task_ins) + sleep(0.1) -async def put_taskres_into_state( - queue: "asyncio.Queue[TaskRes]", - state_factory: StateFactory, - f_stop: asyncio.Event, +def put_taskres_into_state( + state: State, queue: "Queue[TaskRes]", f_stop: threading.Event ) -> None: - """Remove TaskRes from queue and add into State.""" - state = state_factory.state() + """Put TaskRes into State from a queue.""" while not f_stop.is_set(): - if queue.qsize(): - task_res = await queue.get() - state.store_task_res(task_res) - else: - await asyncio.sleep(0.1) + try: + taskres = queue.get(timeout=1.0) + state.store_task_res(taskres) + except Empty: + # queue is empty when timeout was triggered + pass -async def run( +def run( app_fn: Callable[[], ClientApp], backend_fn: Callable[[], Backend], nodes_mapping: NodeToPartitionMapping, @@ -184,9 +155,9 @@ async def run( node_states: Dict[int, NodeState], f_stop: asyncio.Event, ) -> None: - """Run the VCE async.""" - taskins_queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128) - taskres_queue: "asyncio.Queue[TaskRes]" = asyncio.Queue(128) + """Run the VCE.""" + taskins_queue: "Queue[TaskIns]" = Queue() + taskres_queue: "Queue[TaskRes]" = Queue() try: @@ -194,42 +165,48 @@ async def run( backend = backend_fn() # Build backend - await backend.build() + backend.build() # Add workers (they submit Messages to Backend) - worker_tasks = [ - asyncio.create_task( - worker( - app_fn, - taskins_queue, - taskres_queue, - node_states, - backend, - ) - ) - for _ in range(backend.num_workers) - ] - # Create producer (adds TaskIns into Queue) - taskins_producer = asyncio.create_task( - add_taskins_to_queue( + state = state_factory.state() + + extractor_th = threading.Thread( + target=add_taskins_to_queue, + args=( + state, taskins_queue, - state_factory, nodes_mapping, - backend, - worker_tasks, f_stop, - ) + ), ) + extractor_th.start() - taskres_consumer = asyncio.create_task( - put_taskres_into_state(taskres_queue, state_factory, f_stop) + injector_th = threading.Thread( + target=put_taskres_into_state, + args=( + state, + taskres_queue, + f_stop, + ), ) + injector_th.start() - # Wait for asyncio taks pulling/pushing TaskIns/TaskRes. - # These run forever until f_stop is set or until - # all worker (consumer) coroutines are completed. Workers - # also run forever and only end if an exception is raised. - await asyncio.gather(*(taskins_producer, taskres_consumer)) + with ThreadPoolExecutor() as executor: + _ = [ + executor.submit( + worker, + app_fn, + taskins_queue, + taskres_queue, + node_states, + backend, + f_stop, + ) + for _ in range(backend.num_workers) + ] + + extractor_th.join() + injector_th.join() except Exception as ex: @@ -244,18 +221,9 @@ async def run( raise RuntimeError("Simulation Engine crashed.") from ex finally: - # Produced task terminated, now cancel worker tasks - for w_t in worker_tasks: - _ = w_t.cancel() - - while not all(w_t.done() for w_t in worker_tasks): - log(DEBUG, "Terminating async workers...") - await asyncio.sleep(0.5) - - await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()]) # Terminate backend - await backend.terminate() + backend.terminate() # pylint: disable=too-many-arguments,unused-argument,too-many-locals,too-many-branches @@ -368,15 +336,13 @@ def _load() -> ClientApp: _ = app_fn() # Run main simulation loop - asyncio.run( - run( - app_fn, - backend_fn, - nodes_mapping, - state_factory, - node_states, - f_stop, - ) + run( + app_fn, + backend_fn, + nodes_mapping, + state_factory, + node_states, + f_stop, ) except LoadClientAppError as loadapp_ex: f_stop_delay = 10 diff --git a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py index c0bf506fd2b..6c247b91a4e 100644 --- a/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/src/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -24,7 +24,7 @@ from pathlib import Path from time import sleep from typing import Dict, Optional, Set, Tuple -from unittest import IsolatedAsyncioTestCase +from unittest import TestCase from uuid import UUID from flwr.client.client_app import LoadClientAppError @@ -149,7 +149,7 @@ def start_and_shutdown( """Start Simulation Engine and terminate after specified number of seconds. Some tests need to be terminated by triggering externally an asyncio.Event. This - is enabled whtn passing `duration`>0. + is enabled when passing `duration`>0. """ f_stop = asyncio.Event() @@ -181,8 +181,8 @@ def start_and_shutdown( termination_th.join() -class AsyncTestFleetSimulationEngineRayBackend(IsolatedAsyncioTestCase): - """A basic class that enables testing asyncio functionalities.""" +class TestFleetSimulationEngineRayBackend(TestCase): + """A basic class that enables testing functionalities.""" def test_erroneous_no_supernodes_client_mapping(self) -> None: """Test with unset arguments.""" diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 7afffb86533..b1c9d2b9c0c 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -14,7 +14,6 @@ # ============================================================================== """Ray-based Flower Actor and ActorPool implementation.""" -import asyncio import threading from abc import ABC from logging import DEBUG, ERROR, WARNING @@ -411,9 +410,7 @@ def __init__( self.client_resources = client_resources # Queue of idle actors - self.pool: "asyncio.Queue[Type[VirtualClientEngineActor]]" = asyncio.Queue( - maxsize=1024 - ) + self.pool: List[VirtualClientEngineActor] = [] self.num_actors = 0 # Resolve arguments to pass during actor init @@ -427,38 +424,37 @@ def __init__( # Figure out how many actors can be created given the cluster resources # and the resources the user indicates each VirtualClient will need self.actors_capacity = pool_size_from_resources(client_resources) - self._future_to_actor: Dict[Any, Type[VirtualClientEngineActor]] = {} + self._future_to_actor: Dict[Any, VirtualClientEngineActor] = {} def is_actor_available(self) -> bool: """Return true if there is an idle actor.""" - return self.pool.qsize() > 0 + return len(self.pool) > 0 - async def add_actors_to_pool(self, num_actors: int) -> None: + def add_actors_to_pool(self, num_actors: int) -> None: """Add actors to the pool. This method may be executed also if new resources are added to your Ray cluster (e.g. you add a new node). """ for _ in range(num_actors): - await self.pool.put(self.create_actor_fn()) # type: ignore + self.pool.append(self.create_actor_fn()) # type: ignore self.num_actors += num_actors - async def terminate_all_actors(self) -> None: + def terminate_all_actors(self) -> None: """Terminate actors in pool.""" num_terminated = 0 - while self.pool.qsize(): - actor = await self.pool.get() + for actor in self.pool: actor.terminate.remote() # type: ignore num_terminated += 1 log(DEBUG, "Terminated %i actors", num_terminated) - async def submit( + def submit( self, actor_fn: Any, job: Tuple[ClientAppFn, Message, str, Context] ) -> Any: """On idle actor, submit job and return future.""" # Remove idle actor from pool - actor = await self.pool.get() + actor = self.pool.pop() # Submit job to actor app_fn, mssg, cid, context = job future = actor_fn(actor, app_fn, mssg, cid, context) @@ -467,18 +463,18 @@ async def submit( self._future_to_actor[future] = actor return future - async def add_actor_back_to_pool(self, future: Any) -> None: + def add_actor_back_to_pool(self, future: Any) -> None: """Ad actor assigned to run future back into the pool.""" actor = self._future_to_actor.pop(future) - await self.pool.put(actor) + self.pool.append(actor) - async def fetch_result_and_return_actor_to_pool( + def fetch_result_and_return_actor_to_pool( self, future: Any ) -> Tuple[Message, Context]: """Pull result given a future and add actor back to pool.""" - # Get actor that ran job - await self.add_actor_back_to_pool(future) # Retrieve result for object store # Instead of doing ray.get(future) we await it - _, out_mssg, updated_context = await future + _, out_mssg, updated_context = ray.get(future) + # Get actor that ran job + self.add_actor_back_to_pool(future) return out_mssg, updated_context