Skip to content

Commit

Permalink
refactor(framework) Remove asyncio from core Simulation Engine (#3470)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Jul 10, 2024
1 parent 02a11b9 commit 02c5d1d
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 117 deletions.
192 changes: 79 additions & 113 deletions src/py/flwr/server/superlink/fleet/vce/vce_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@
# ==============================================================================
"""Fleet Simulation Engine API."""


import asyncio
import json
import sys
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from logging import DEBUG, ERROR, INFO, WARN
from pathlib import Path
from typing import Callable, Dict, List, Optional
from queue import Empty, Queue
from time import sleep
from typing import Callable, Dict, Optional

from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
from flwr.client.node_state import NodeState
Expand All @@ -31,7 +36,7 @@
from flwr.common.object_ref import load_app
from flwr.common.serde import message_from_taskins, message_to_taskres
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.superlink.state import StateFactory
from flwr.server.superlink.state import State, StateFactory

from .backend import Backend, error_messages_backends, supported_backends

Expand All @@ -52,18 +57,21 @@ def _register_nodes(


# pylint: disable=too-many-arguments,too-many-locals
async def worker(
def worker(
app_fn: Callable[[], ClientApp],
taskins_queue: "asyncio.Queue[TaskIns]",
taskres_queue: "asyncio.Queue[TaskRes]",
taskins_queue: Queue[TaskIns],
taskres_queue: Queue[TaskRes],
node_states: Dict[int, NodeState],
backend: Backend,
f_stop: asyncio.Event,
) -> None:
"""Get TaskIns from queue and pass it to an actor in the pool to execute it."""
while True:
while not f_stop.is_set():
out_mssg = None
try:
task_ins: TaskIns = await taskins_queue.get()
# Fetch from queue with timeout. We use a timeout so
# the stopping event can be evaluated even when the queue is empty.
task_ins: TaskIns = taskins_queue.get(timeout=1.0)
node_id = task_ins.task.consumer.node_id

# Register and retrieve runstate
Expand All @@ -82,11 +90,9 @@ async def worker(
node_states[node_id].update_context(
task_ins.run_id, context=updated_context
)

except asyncio.CancelledError as e:
log(DEBUG, "Terminating async worker: %s", e)
break

except Empty:
# An exception raised if queue.get times out
pass
# Exceptions aren't raised but reported as an error message
except Exception as ex: # pylint: disable=broad-exception-caught
log(ERROR, ex)
Expand All @@ -110,83 +116,48 @@ async def worker(
task_res = message_to_taskres(out_mssg)
# Store TaskRes in state
task_res.task.pushed_at = time.time()
await taskres_queue.put(task_res)
taskres_queue.put(task_res)


async def add_taskins_to_queue(
queue: "asyncio.Queue[TaskIns]",
state_factory: StateFactory,
def add_taskins_to_queue(
state: State,
queue: Queue[TaskIns],
nodes_mapping: NodeToPartitionMapping,
backend: Backend,
consumers: List["asyncio.Task[None]"],
f_stop: asyncio.Event,
) -> None:
"""Retrieve TaskIns and add it to the queue."""
state = state_factory.state()
num_initial_consumers = len(consumers)
"""Put TaskIns in a queue from State."""
while not f_stop.is_set():
for node_id in nodes_mapping.keys():
task_ins = state.get_task_ins(node_id=node_id, limit=1)
if task_ins:
await queue.put(task_ins[0])

# Count consumers that are running
num_active = sum(not (cc.done()) for cc in consumers)

# Alert if number of consumers decreased by half
if num_active < num_initial_consumers // 2:
log(
WARN,
"Number of active workers has more than halved: (%i/%i active)",
num_active,
num_initial_consumers,
)

# Break if consumers died
if num_active == 0:
raise RuntimeError("All workers have died. Ending Simulation.")

# Log some stats
log(
DEBUG,
"Simulation Engine stats: "
"Active workers: (%i/%i) | %s (%i workers) | Tasks in queue: %i)",
num_active,
num_initial_consumers,
backend.__class__.__name__,
backend.num_workers,
queue.qsize(),
)
await asyncio.sleep(1.0)
log(DEBUG, "Async producer: Stopped pulling from StateFactory.")
task_ins_list = state.get_task_ins(node_id=node_id, limit=1)
for task_ins in task_ins_list:
queue.put(task_ins)
sleep(0.1)


async def put_taskres_into_state(
queue: "asyncio.Queue[TaskRes]",
state_factory: StateFactory,
f_stop: asyncio.Event,
def put_taskres_into_state(
state: State, queue: Queue[TaskRes], f_stop: threading.Event
) -> None:
"""Remove TaskRes from queue and add into State."""
state = state_factory.state()
"""Put TaskRes into State from a queue."""
while not f_stop.is_set():
if queue.qsize():
task_res = await queue.get()
state.store_task_res(task_res)
else:
await asyncio.sleep(0.1)
try:
taskres = queue.get(timeout=1.0)
state.store_task_res(taskres)
except Empty:
# queue is empty when timeout was triggered
pass


async def run(
def run(
app_fn: Callable[[], ClientApp],
backend_fn: Callable[[], Backend],
nodes_mapping: NodeToPartitionMapping,
state_factory: StateFactory,
node_states: Dict[int, NodeState],
f_stop: asyncio.Event,
) -> None:
"""Run the VCE async."""
taskins_queue: "asyncio.Queue[TaskIns]" = asyncio.Queue(128)
taskres_queue: "asyncio.Queue[TaskRes]" = asyncio.Queue(128)
"""Run the VCE."""
taskins_queue: "Queue[TaskIns]" = Queue()
taskres_queue: "Queue[TaskRes]" = Queue()

try:

Expand All @@ -197,39 +168,45 @@ async def run(
backend.build()

# Add workers (they submit Messages to Backend)
worker_tasks = [
asyncio.create_task(
worker(
app_fn,
taskins_queue,
taskres_queue,
node_states,
backend,
)
)
for _ in range(backend.num_workers)
]
# Create producer (adds TaskIns into Queue)
taskins_producer = asyncio.create_task(
add_taskins_to_queue(
state = state_factory.state()

extractor_th = threading.Thread(
target=add_taskins_to_queue,
args=(
state,
taskins_queue,
state_factory,
nodes_mapping,
backend,
worker_tasks,
f_stop,
)
),
)
extractor_th.start()

taskres_consumer = asyncio.create_task(
put_taskres_into_state(taskres_queue, state_factory, f_stop)
injector_th = threading.Thread(
target=put_taskres_into_state,
args=(
state,
taskres_queue,
f_stop,
),
)
injector_th.start()

# Wait for asyncio tasks pulling/pushing TaskIns/TaskRes.
# These run forever until f_stop is set or until
# all worker (consumer) coroutines are completed. Workers
# also run forever and only end if an exception is raised.
await asyncio.gather(*(taskins_producer, taskres_consumer))
with ThreadPoolExecutor() as executor:
_ = [
executor.submit(
worker,
app_fn,
taskins_queue,
taskres_queue,
node_states,
backend,
f_stop,
)
for _ in range(backend.num_workers)
]

extractor_th.join()
injector_th.join()

except Exception as ex:

Expand All @@ -244,15 +221,6 @@ async def run(
raise RuntimeError("Simulation Engine crashed.") from ex

finally:
# Produced task terminated, now cancel worker tasks
for w_t in worker_tasks:
_ = w_t.cancel()

while not all(w_t.done() for w_t in worker_tasks):
log(DEBUG, "Terminating async workers...")
await asyncio.sleep(0.5)

await asyncio.gather(*[w_t for w_t in worker_tasks if not w_t.done()])

# Terminate backend
backend.terminate()
Expand Down Expand Up @@ -368,15 +336,13 @@ def _load() -> ClientApp:
_ = app_fn()

# Run main simulation loop
asyncio.run(
run(
app_fn,
backend_fn,
nodes_mapping,
state_factory,
node_states,
f_stop,
)
run(
app_fn,
backend_fn,
nodes_mapping,
state_factory,
node_states,
f_stop,
)
except LoadClientAppError as loadapp_ex:
f_stop_delay = 10
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pathlib import Path
from time import sleep
from typing import Dict, Optional, Set, Tuple
from unittest import IsolatedAsyncioTestCase
from unittest import TestCase
from uuid import UUID

from flwr.client.client_app import LoadClientAppError
Expand Down Expand Up @@ -149,7 +149,7 @@ def start_and_shutdown(
"""Start Simulation Engine and terminate after specified number of seconds.
Some tests need to be terminated by triggering externally an asyncio.Event. This
is enabled whtn passing `duration`>0.
is enabled when passing `duration`>0.
"""
f_stop = asyncio.Event()

Expand Down Expand Up @@ -181,8 +181,8 @@ def start_and_shutdown(
termination_th.join()


class AsyncTestFleetSimulationEngineRayBackend(IsolatedAsyncioTestCase):
"""A basic class that enables testing asyncio functionalities."""
class TestFleetSimulationEngineRayBackend(TestCase):
"""A basic class that enables testing functionalities."""

def test_erroneous_no_supernodes_client_mapping(self) -> None:
"""Test with unset arguments."""
Expand Down

0 comments on commit 02c5d1d

Please sign in to comment.