diff --git a/chimerapy/engine/eventbus/eventbus.py b/chimerapy/engine/eventbus/eventbus.py index 1d85b3d6..b0912105 100644 --- a/chimerapy/engine/eventbus/eventbus.py +++ b/chimerapy/engine/eventbus/eventbus.py @@ -3,7 +3,17 @@ from datetime import datetime from collections import deque from concurrent.futures import Future -from typing import Any, Generic, Type, Callable, Awaitable, Optional, Literal, TypeVar, Dict +from typing import ( + Any, + Generic, + Type, + Callable, + Awaitable, + Optional, + Literal, + TypeVar, + Dict, +) from aioreactive import AsyncObservable, AsyncObserver, AsyncSubject from dataclasses import dataclass, field @@ -32,6 +42,7 @@ def __init__(self, thread: Optional[AsyncLoopThread] = None): self.thread = thread # State information + self.awaitable_events: Dict[str, asyncio.Event] = {} self.subscription_map: Dict[AsyncObserver, Any] = {} #################################################################### @@ -43,6 +54,11 @@ async def asend(self, event: Event): self._event_counts += 1 await self.stream.asend(event) + if event.type in self.awaitable_events: + self.latest_event = event + self.awaitable_events[event.type].set() + del self.awaitable_events[event.type] + async def asubscribe(self, observer: AsyncObserver): self._sub_counts += 1 subscription = await self.stream.subscribe_async(observer) @@ -50,12 +66,22 @@ async def asubscribe(self, observer: AsyncObserver): async def aunsubscribe(self, observer: AsyncObserver): if observer not in self.subscription_map: - raise RuntimeError("Trying to unsubscribe an Observer that is not subscribed") + raise RuntimeError( + "Trying to unsubscribe an Observer that is not subscribed" + ) self._sub_counts -= 1 subscription = self.subscription_map[observer] await subscription.dispose_async() + async def await_event(self, event_type: str) -> Event: + if event_type not in self.awaitable_events: + self.awaitable_events[event_type] = asyncio.Event() + + event_trigger = self.awaitable_events[event_type] + await event_trigger.wait() + return self.latest_event + #################################################################### ## Sync #################################################################### diff --git a/chimerapy/engine/manager/events.py b/chimerapy/engine/manager/events.py index be8fdee6..e85bf8b2 100644 --- a/chimerapy/engine/manager/events.py +++ b/chimerapy/engine/manager/events.py @@ -8,6 +8,12 @@ class StartEvent: ... +@dataclass +class UpdateSendArchiveEvent: # update_send_archive + worker_id: str + success: bool + + @dataclass class WorkerRegisterEvent: # worker_register worker_state: WorkerState @@ -31,4 +37,4 @@ class DeregisterEntityEvent: # entity_deregister @dataclass class MoveTransferredFilesEvent: # move_transferred_files - unzip: bool + worker_state: WorkerState diff --git a/chimerapy/engine/manager/http_server_service.py b/chimerapy/engine/manager/http_server_service.py index c5ad83d3..9fd1ce1d 100644 --- a/chimerapy/engine/manager/http_server_service.py +++ b/chimerapy/engine/manager/http_server_service.py @@ -13,7 +13,12 @@ from ..networking.async_loop_thread import AsyncLoopThread from ..networking import Server from ..networking.enums import MANAGER_MESSAGE -from .events import WorkerRegisterEvent, WorkerDeregisterEvent +from .events import ( + MoveTransferredFilesEvent, + WorkerRegisterEvent, + WorkerDeregisterEvent, + UpdateSendArchiveEvent, +) logger = _logger.getLogger("chimerapy-engine") @@ -48,9 +53,11 @@ def __init__( id="Manager", routes=[ # Worker API + web.get("/", self._home), web.post("/workers/register", self._register_worker_route), web.post("/workers/deregister", self._deregister_worker_route), web.post("/workers/node_status", self._update_nodes_status), + web.post("/workers/send_archive", self._update_send_archive), ], thread=self._thread, ) @@ -68,6 +75,7 @@ def __init__( ), "move_transferred_files": TypedObserver( "move_transferred_files", + MoveTransferredFilesEvent, on_asend=self.move_transferred_files, handle_event="unpack", ), @@ -120,8 +128,17 @@ def _future_flush(self): except Exception: logger.error(traceback.format_exc()) - def move_transferred_files(self, unzip: bool) -> bool: - return self._server.move_transfer_files(self.state.logdir, unzip) + async def move_transferred_files(self, worker_state: WorkerState) -> bool: + return await self._server.move_transferred_files( + self.state.logdir, owner=worker_state.name, owner_id=worker_state.id + ) + + ##################################################################################### + ## Manager User Routes + ##################################################################################### + + async def _home(self, request: web.Request): + return web.Response(text="ChimeraPy Manager running!") ##################################################################################### ## Worker -> Manager Routes @@ -167,9 +184,14 @@ async def _update_nodes_status(self, request: web.Request): if worker_state.id in self.state.workers: update_dataclass(self.state.workers[worker_state.id], worker_state) else: - logger.error(f"{self}: non-registered Worker update: {worker_state.id}") - # logger.debug(f"{self}: Nodes status update to: {self.state.workers}") + logger.warning(f"{self}: non-registered Worker update: {worker_state.id}") + + return web.HTTPOk() + async def _update_send_archive(self, request: web.Request): + msg = await request.json() + event_data = UpdateSendArchiveEvent(**msg) + await self.eventbus.asend(Event("update_send_archive", event_data)) return web.HTTPOk() ##################################################################################### diff --git a/chimerapy/engine/manager/manager.py b/chimerapy/engine/manager/manager.py index 3c8f70e1..517412ee 100644 --- a/chimerapy/engine/manager/manager.py +++ b/chimerapy/engine/manager/manager.py @@ -320,8 +320,8 @@ async def async_request_registered_method( async def async_stop(self) -> bool: return await self.worker_handler.stop() - async def async_collect(self, unzip: bool = True) -> bool: - return await self.worker_handler.collect(unzip) + async def async_collect(self) -> bool: + return await self.worker_handler.collect() async def async_reset(self, keep_workers: bool = True): return await self.worker_handler.reset(keep_workers) @@ -454,20 +454,17 @@ def stop(self) -> Future[bool]: """ return self._exec_coro(self.async_stop()) - def collect(self, unzip: bool = True) -> Future[bool]: + def collect(self) -> Future[bool]: """Collect data from the Workers First, we wait until all the Nodes have finished save their data.\ Then, manager request that Nodes' from the Workers. - Args: - unzip (bool): Should the .zip archives be extracted. - Returns: Future[bool]: Future of success in collect data from Workers """ - return self._exec_coro(self.async_collect(unzip)) + return self._exec_coro(self.async_collect()) def reset( self, keep_workers: bool = True, blocking: bool = True diff --git a/chimerapy/engine/manager/worker_handler_service.py b/chimerapy/engine/manager/worker_handler_service.py index 4cbf24f3..1b8f33f9 100644 --- a/chimerapy/engine/manager/worker_handler_service.py +++ b/chimerapy/engine/manager/worker_handler_service.py @@ -13,6 +13,7 @@ from chimerapy.engine import config from chimerapy.engine import _logger +from ..utils import async_waiting_for from ..data_protocols import NodePubTable from ..node import NodeConfig from ..networking import Client, DataChunk @@ -27,6 +28,7 @@ RegisterEntityEvent, DeregisterEntityEvent, MoveTransferredFilesEvent, + UpdateSendArchiveEvent, ) logger = _logger.getLogger("chimerapy-engine") @@ -46,6 +48,7 @@ def __init__(self, name: str, eventbus: EventBus, state: ManagerState): self.worker_graph_map: Dict = {} self.commitable_graph: bool = False self.node_pub_table = NodePubTable() + self.collected_workers: Dict[str, bool] = {} # Also create a tempfolder to store any miscellaneous files and folders self.tempfolder = pathlib.Path(tempfile.mkdtemp()) @@ -67,6 +70,12 @@ def __init__(self, name: str, eventbus: EventBus, state: ManagerState): event_data_cls=WorkerDeregisterEvent, handle_event="unpack", ), + "update_send_archive": TypedObserver( + "update_send_archive", + on_asend=self.update_send_archive, + event_data_cls=UpdateSendArchiveEvent, + handle_event="unpack", + ), } for ob in self.observers.values(): self.eventbus.subscribe(ob).result(timeout=1) @@ -126,25 +135,28 @@ async def _register_worker(self, worker_state: WorkerState) -> bool: return True - async def _deregister_worker(self, worker_id: str) -> bool: + async def _deregister_worker(self, worker_state: WorkerState) -> bool: # Deregister entity from logging await self.eventbus.asend( - Event("entity_deregister", DeregisterEntityEvent(worker_id=worker_id)) + Event("entity_deregister", DeregisterEntityEvent(worker_id=worker_state.id)) ) - if worker_id in self.state.workers: - state = self.state.workers[worker_id] + if worker_state.id in self.state.workers: + state = self.state.workers[worker_state.id] logger.info( - f"Manager deregistered \ - from {state.ip}" + f"Manager deregistered " + f"from {state.ip}" ) - del self.state.workers[worker_id] + del self.state.workers[worker_state.id] return True return False + async def update_send_archive(self, worker_id: str, success: bool): + self.collected_workers[worker_id] = success + def _register_graph(self, graph: Graph): """Verifying that a Graph is valid, that is a DAG. @@ -590,6 +602,47 @@ async def _setup_p2p_connections(self) -> bool: return all(results) + async def _single_worker_collect(self, worker_id: str) -> bool: + + # Just requesting for the collection to start + data = {"path": str(self.state.logdir)} + async with aiohttp.ClientSession() as client: + async with client.post( + f"{self._get_worker_ip(worker_id)}/nodes/collect", + data=json.dumps(data), + ) as resp: + + if not resp.ok: + logger.error( + f"{self}: Collection failed, " + "responded {resp.ok} to collect request" + ) + return False + + # Now we have to wait until worker says they finished transferring + await async_waiting_for(condition=lambda: worker_id in self.collected_workers) + success = self.collected_workers[worker_id] + if not success: + logger.error( + f"{self}: Collection failed, " + "never updated on archival completion" + ) + + # Move files to their destination + try: + await self.eventbus.asend( + Event( + "move_transferred_files", + MoveTransferredFilesEvent( + worker_state=self.state.workers[worker_id] + ), + ) + ) + except Exception: + logger.error(traceback.format_exc()) + + return success + #################################################################### ## Front-facing ASync API #################################################################### @@ -753,31 +806,24 @@ async def stop(self) -> bool: return success - async def collect(self, unzip: bool = True) -> bool: + async def collect(self) -> bool: - # Then tell them to send the data to the Manager - success = await self._broadcast_request( - htype="post", - route="/nodes/collect", - data={"path": str(self.state.logdir)}, - timeout=None, - ) - await asyncio.sleep(1) + # Clear + self.collected_workers.clear() - if success: - try: - await self.eventbus.asend(Event("save_meta")) - await self.eventbus.asend( - Event( - "move_transferred_files", MoveTransferredFilesEvent(unzip=unzip) - ) - ) - except Exception: - logger.error(traceback.format_exc()) + # Request all workers + coros: List[Coroutine] = [] + for worker_id in self.state.workers: + coros.append(self._single_worker_collect(worker_id)) - # logger.info(f"{self}: finished collect") + try: + results = await asyncio.gather(*coros) + except Exception: + logger.error(traceback.format_exc()) + return False - return success + await self.eventbus.asend(Event("save_meta")) + return all(results) async def reset(self, keep_workers: bool = True): @@ -796,8 +842,8 @@ async def reset(self, keep_workers: bool = True): # If not keep Workers, then deregister all if not keep_workers: - for worker_id in self.state.workers: - await self._deregister_worker(worker_id) + for worker_state in self.state.workers.values(): + await self._deregister_worker(worker_state) # Update variable data self.node_pub_table = NodePubTable() diff --git a/chimerapy/engine/networking/client.py b/chimerapy/engine/networking/client.py index 4804b9fb..5d99a20a 100644 --- a/chimerapy/engine/networking/client.py +++ b/chimerapy/engine/networking/client.py @@ -5,14 +5,13 @@ import collections import uuid import pathlib -import shutil +import aioshutil import tempfile import json import enum import logging import traceback -import atexit -import multiprocess as mp +import asyncio_atexit from concurrent.futures import Future from typing import Dict, Optional, Callable, Any, Union, List, Coroutine @@ -56,12 +55,7 @@ def __init__( self._session = None # The EventLoop - if thread: - self._thread = thread - else: - self._thread = AsyncLoopThread() - self._thread.start() - + self._thread = thread self.futures: List[Future] = [] # State variables @@ -87,9 +81,6 @@ def __init__( else: self.logger = _logger.getLogger("chimerapy-engine-networking") - # Make sure to shutdown correctly - atexit.register(self.shutdown) - def __str__(self): return f"" @@ -101,6 +92,7 @@ def setLogger(self, parent_logger: logging.Logger): #################################################################### def _exec_coro(self, coro: Coroutine) -> Future: + assert self._thread # Submitting the coroutine future = self._thread.exec(coro) @@ -209,13 +201,15 @@ async def async_connect(self) -> bool: self._ws = await self._session.ws_connect(f"http://{self.host}:{self.port}/ws") # Create task to read - # self._thread.exec(self._read_ws()) task = asyncio.create_task(self._read_ws()) self.tasks.append(task) # Register the client await self._register() + # Make sure to shutdown correctly + asyncio_atexit.register(self.async_shutdown) + return True async def async_send_file( @@ -242,53 +236,25 @@ async def async_send_file( return True - async def async_send_folder(self, sender_id: str, dir: pathlib.Path): + async def async_send_folder(self, sender_id: str, dir: pathlib.Path) -> bool: if not dir.is_dir() and not dir.exists(): self.logger.error(f"Cannot send non-existent dir: {dir}.") return False - # Having continuing attempts to make the zip folder - miss_counter = 0 - delay = 1 - zip_timeout = config.get("comms.timeout.zip-time") - - # First, we need to archive the folder into a zip file - while True: - try: - process = mp.Process( - target=shutil.make_archive, - args=( - str(dir), - "zip", - dir.parent, - dir.name, - ), - ) - process.start() - process.join() - assert process.exitcode == 0 - - break - except Exception: - self.logger.warning("Temp folder couldn't be zipped.") - self.logger.error(traceback.format_exc()) - await asyncio.sleep(delay) - miss_counter += 1 - - if zip_timeout < delay * miss_counter: - self.logger.error("Temp folder couldn't be zipped.") - return False - zip_file = dir.parent / f"{dir.name}.zip" + try: + await aioshutil.make_archive(str(dir), "zip", dir.parent, dir.name) + except Exception: + self.logger.warning(f"{self}: Temp folder couldn't be zipped.") + self.logger.error(traceback.format_exc()) + return False # Compose the url url = f"http://{self.host}:{self.port}/file/post" # Then send the file - await self.async_send_file(url, sender_id, zip_file) - - return True + return await self.async_send_file(url, sender_id, zip_file) async def async_shutdown(self, msg: Dict = {}): @@ -322,7 +288,7 @@ async def async_send(self, signal: enum.Enum, data: Any, ok: bool = False) -> bo #################################################################### def send(self, signal: enum.Enum, data: Any, ok: bool = False) -> Future[bool]: - return self._thread.exec(self.async_send(signal, data, ok)) + return self._exec_coro(self.async_send(signal, data, ok)) def send_file(self, sender_id: str, filepath: pathlib.Path) -> Future[bool]: # Compose the url @@ -338,6 +304,14 @@ def send_folder(self, sender_id: str, dir: pathlib.Path) -> Future[bool]: return self._exec_coro(self.async_send_folder(sender_id, dir)) def connect(self, blocking: bool = True) -> Union[bool, Future[bool]]: + + if not self._thread: + self._thread = AsyncLoopThread() + self._thread.start() + else: + if not self._thread.is_alive(): + self._thread.start() + future = self._exec_coro(self.async_connect()) if blocking: diff --git a/chimerapy/engine/networking/server.py b/chimerapy/engine/networking/server.py index 34b47a63..541d61b6 100644 --- a/chimerapy/engine/networking/server.py +++ b/chimerapy/engine/networking/server.py @@ -1,19 +1,18 @@ # Built-in -from typing import Callable, Dict, Optional, Any, List, Coroutine +from typing import Callable, Dict, Optional, List, Coroutine import asyncio import logging import uuid import collections -import time import pathlib import tempfile -import shutil -import os +import aioshutil +import asyncio_atexit import json import enum import traceback import aiofiles -import atexit +from dataclasses import dataclass, field from concurrent.futures import Future # Third-party @@ -26,7 +25,6 @@ from chimerapy.engine.utils import ( create_payload, async_waiting_for, - waiting_for, get_ip_address, ) from .enums import GENERAL_MESSAGE @@ -44,6 +42,21 @@ # https://docs.aiohttp.org/en/stable/web_quickstart.html#file-uploads +@dataclass +class FileTransferRecord: + sender_id: str + uuid: str + filename: str + location: pathlib.Path + size: int + complete: bool = False + + +@dataclass +class FileTransferTable: + records: Dict[str, FileTransferRecord] = field(default_factory=dict) + + class Server: def __init__( self, @@ -73,12 +86,7 @@ def __init__( self.ws_handlers = {k.value: v for k, v in ws_handlers.items()} # The EventLoop - if thread: - self._thread = thread - else: - self._thread = AsyncLoopThread() - self._thread.start() - + self._thread = thread self.futures: List[Future] = [] # Using flag for marking if system should be running @@ -111,20 +119,13 @@ def __init__( # Adding file transfer capabilities self.tempfolder = pathlib.Path(tempfile.mkdtemp()) - self.file_transfer_records: Dict[ - str, Dict[str, Dict[str, Any]] - ] = collections.defaultdict( - dict - ) # Need to refactor this! + self.file_transfer_records = FileTransferTable() if parent_logger is not None: self.logger = _logger.fork(parent_logger, "server") else: self.logger = _logger.getLogger("chimerapy-engine-networking") - # Make sure to shutdown correctly - atexit.register(self.shutdown) - def __str__(self): return f"" @@ -133,6 +134,7 @@ def __str__(self): #################################################################### def _exec_coro(self, coro: Coroutine) -> Future: + assert self._thread # Submitting the coroutine future = self._thread.exec(coro) @@ -141,6 +143,27 @@ def _exec_coro(self, coro: Coroutine) -> Future: return future + async def move_transferred_files( + self, + dst: pathlib.Path, + owner: Optional[str] = None, + owner_id: Optional[str] = None, + ) -> bool: + + for id, file_record in self.file_transfer_records.records.items(): + + if owner and owner != file_record.sender_id: + continue + + await aioshutil.unpack_archive(file_record.location, dst) + + if owner: + new_file = dst / ".".join(file_record.filename.split(".")[:-1]) + dst_file = dst / f"{owner}-{owner_id}-{str(uuid.uuid4())[:4]}" + await aioshutil.move(new_file, dst_file) + + return True + #################################################################### # Server Setters and Getters #################################################################### @@ -149,19 +172,10 @@ def add_routes(self, routes: List[web.RouteDef]): self._app.add_routes(routes) #################################################################### - # Server WS Handlers + # Server Routes #################################################################### - async def _ok(self, msg: Dict, ws: web.WebSocketResponse): - self.uuid_records.append(msg["data"]["uuid"]) - - async def _register_ws_client(self, msg: Dict, ws: web.WebSocketResponse): - # self.logger.debug(f"{self}: reigstered client: {msg['data']['client_id']}") - # Storing the client information - self.ws_clients[msg["data"]["client_id"]] = ws - - async def _file_receive(self, request): - + async def _file_receive(self, request) -> web.Response: reader = await request.multipart() # /!\ Don't forget to validate your inputs /!\ @@ -179,24 +193,24 @@ async def _file_receive(self, request): filename = field.filename # Attaching a UUID to prevent possible collision - id = uuid.uuid4() + id = str(uuid.uuid4()) filename_list = filename.split(".") uuid_filename = str(id) + "." + filename_list[1] # Create dst filepath - dst_filepath = self.tempfolder / uuid_filename + location = self.tempfolder / uuid_filename # Create the record and mark that is not complete + file_entry = FileTransferRecord( + sender_id=meta["sender_id"], + filename=filename, + uuid=id, + location=location, + size=meta["size"], + ) + # Keep record of the files sent! - self.file_transfer_records[meta["sender_id"]][filename] = { - "uuid": id, - "uuid_filename": uuid_filename, - "filename": filename, - "dst_filepath": dst_filepath, - "read": 0, - "size": meta["size"], - "complete": False, - } + self.file_transfer_records.records[id] = file_entry # You cannot rely on Content-Length if transfer is chunked. read = 0 @@ -204,16 +218,17 @@ async def _file_receive(self, request): prev_n = 0 # Reading the buffer and writing the file - async with aiofiles.open(dst_filepath, "wb") as f: + async with aiofiles.open(location, "wb") as f: + # tqdm_out = TqdmToLogger(self.logger, level=logging.INFO) with tqdm( total=1, unit="B", unit_scale=True, desc=f"File {field.filename}", - miniters=1, + # file=tqdm_out, ) as pbar: while True: - chunk = await field.read_chunk() # 8192 bytes by default. + chunk = await field.read_chunk(8192 * 10) # 8192 bytes by default. if not chunk: break await f.write(chunk) @@ -223,13 +238,22 @@ async def _file_receive(self, request): prev_n = pbar.n # After finishing, mark the size and that is complete - self.file_transfer_records[meta["sender_id"]][filename].update( - {"size": total_size, "complete": True} - ) + self.file_transfer_records.records[id].size = total_size + self.file_transfer_records.records[id].complete = True - return web.Response( - text=f"{filename} sized of {total_size} successfully stored" - ) + return web.HTTPOk() + + #################################################################### + # Server WS Handlers + #################################################################### + + async def _ok(self, msg: Dict, ws: web.WebSocketResponse): + self.uuid_records.append(msg["data"]["uuid"]) + + async def _register_ws_client(self, msg: Dict, ws: web.WebSocketResponse): + # self.logger.debug(f"{self}: reigstered client: {msg['data']['client_id']}") + # Storing the client information + self.ws_clients[msg["data"]["client_id"]] = ws #################################################################### # IO Main Methods @@ -361,6 +385,9 @@ async def async_serve(self) -> bool: # Set flag self.running = True + # Make sure to shutdown correctly + asyncio_atexit.register(self.async_shutdown) + return True async def async_send( @@ -423,6 +450,13 @@ async def async_shutdown(self) -> bool: def serve(self, blocking: bool = True) -> Optional[Future]: + if not self._thread: + self._thread = AsyncLoopThread() + self._thread.start() + else: + if not self._thread.is_alive(): + self._thread.start() + # Cannot serve twice if self.running: self.logger.warning(f"{self}: Requested to re-serve HTTP Server") @@ -447,66 +481,6 @@ def broadcast( # clients return self._exec_coro(self.async_broadcast(signal, data, ok)) - def move_transfer_files(self, dst: pathlib.Path, unzip: bool) -> bool: - - for name, filepath_dict in self.file_transfer_records.items(): - # Create a folder for the name - named_dst = dst / name - os.mkdir(named_dst) - - # Move all the content inside - for filename, file_meta in filepath_dict.items(): - - # Extract data - filepath = file_meta["dst_filepath"] - - # Wait until filepath is completely written - success = waiting_for( - condition=lambda: filepath.exists(), - timeout=config.get("comms.timeout.zip-time-write"), - ) - - if not success: - return False - - # If not unzip, just move it - if not unzip: - shutil.move(filepath, named_dst / filename) - - # Otherwise, unzip, move content to the original folder, - # and delete the zip file - else: - shutil.unpack_archive(filepath, named_dst) - - # Handling if temp folder includes a _ in the beginning - new_filename = file_meta["filename"] - if new_filename[0] == "_": - new_filename = new_filename[1:] - new_filename = new_filename.split(".")[0] - - new_file = named_dst / new_filename - - # Wait until file is ready - miss_counter = 0 - delay = 0.5 - timeout = config.get("comms.timeout.zip-time") - - while not new_file.exists(): - time.sleep(delay) - miss_counter += 1 - if timeout < delay * miss_counter: - self.logger.error( - f"File zip unpacking took too long! - \ - {name}:{filepath}:{new_file}" - ) - return False - - for file in new_file.iterdir(): - shutil.move(file, named_dst) - shutil.rmtree(new_file) - - return True - def shutdown(self, blocking: bool = True) -> Optional[Future]: future = self._exec_coro(self.async_shutdown()) diff --git a/chimerapy/engine/node/poller_service.py b/chimerapy/engine/node/poller_service.py index 4a5f466a..0d3c6292 100644 --- a/chimerapy/engine/node/poller_service.py +++ b/chimerapy/engine/node/poller_service.py @@ -92,6 +92,8 @@ async def teardown(self): def setup_connections(self, node_pub_table: NodePubTable): + # self.logger.debug(f"{self}: setting up connections: {node_pub_table}") + # We determine all the out bound nodes for i, in_bound_id in enumerate(self.in_bound): diff --git a/chimerapy/engine/node/processor_service.py b/chimerapy/engine/node/processor_service.py index 8b390063..ffda1db0 100644 --- a/chimerapy/engine/node/processor_service.py +++ b/chimerapy/engine/node/processor_service.py @@ -252,12 +252,12 @@ async def safe_exec( else: await asyncio.sleep(1 / 1000) # Allow other functions to run as well output = func(*args, **kwargs) - toc = time.perf_counter() except Exception: traceback_info = traceback.format_exc() self.logger.error(traceback_info) # Compute delta + toc = time.perf_counter() delta = (toc - tic) * 1000 return output, delta diff --git a/chimerapy/engine/node/publisher_service.py b/chimerapy/engine/node/publisher_service.py index 8e71e26a..f892295f 100644 --- a/chimerapy/engine/node/publisher_service.py +++ b/chimerapy/engine/node/publisher_service.py @@ -52,6 +52,7 @@ def setup(self): self.state.port = self.publisher.port def publish(self, data_chunk: DataChunk): + # self.logger.debug(f"{self}: publishing {data_chunk}") self.publisher.publish(data_chunk) def teardown(self): diff --git a/chimerapy/engine/node/worker_comms_service.py b/chimerapy/engine/node/worker_comms_service.py index 49c67905..6067ab90 100644 --- a/chimerapy/engine/node/worker_comms_service.py +++ b/chimerapy/engine/node/worker_comms_service.py @@ -222,7 +222,7 @@ async def async_step(self, msg: Dict): async def enable_diagnostics(self, msg: Dict): assert self.state and self.eventbus and self.logger - enable = msg['data']['enable'] + enable = msg["data"]["enable"] event_data = EnableDiagnosticsEvent(enable) await self.eventbus.asend(Event("enable_diagnostics", event_data)) diff --git a/chimerapy/engine/utils.py b/chimerapy/engine/utils.py index 2ff81dda..b8bfd816 100644 --- a/chimerapy/engine/utils.py +++ b/chimerapy/engine/utils.py @@ -11,7 +11,6 @@ from typing import Callable, Union, Optional, Any, Dict # Third-party -from tqdm import tqdm # Internal from chimerapy.engine import _logger @@ -45,34 +44,20 @@ def clear_queue(input_queue: queue.Queue): # https://github.com/tqdm/tqdm/issues/313#issuecomment-850698822 -class logging_tqdm(tqdm): - def __init__( - self, - *args, - logger: logging.Logger = None, - mininterval: float = 1, - bar_format: str = "{desc}{percentage:3.0f}%{r_bar}", - desc: str = "progress: ", - **kwargs, - ): - self._logger = logger - super().__init__( - *args, mininterval=mininterval, bar_format=bar_format, desc=desc, **kwargs - ) - - @property - def logger(self): - if self._logger is not None: - return self._logger - return logger - - def display(self, msg=None, pos=None): - if not self.n: - # skip progress bar before having processed anything - return - if not msg: - msg = self.__str__() - self.logger.info("%s", msg) +class TqdmToLogger(object): + """Adapter to redirect tqdm output to a logger""" + + def __init__(self, logger, level=logging.INFO): + self.logger = logger + self.level = level + self.buf = "" + + def write(self, buf): + if buf.rstrip(): + self.logger.log(self.level, buf.rstrip()) + + def flush(self): + pass async def async_waiting_for( diff --git a/chimerapy/engine/worker/http_client_service.py b/chimerapy/engine/worker/http_client_service.py index 5004681a..78b9245b 100644 --- a/chimerapy/engine/worker/http_client_service.py +++ b/chimerapy/engine/worker/http_client_service.py @@ -1,5 +1,6 @@ import os import shutil +import json import uuid import traceback import socket @@ -176,9 +177,9 @@ async def _async_connect_via_ip( logs_push_info = data.get("logs_push_info", {}) if logs_push_info["enabled"]: - self.logger.info( - f"{self}: enabling logs push to Manager: {logs_push_info}" - ) + # self.logger.info( + # f"{self}: enabling logs push to Manager: {logs_push_info}" + # ) for logging_entity in [ self.logger, self.logreceiver, @@ -263,16 +264,28 @@ async def _send_archive(self, path: pathlib.Path) -> bool: else: send_locally = False + success = False try: if send_locally: await self._send_archive_locally(path) else: await self._send_archive_remotely(manager_host, manager_port) + success = True except Exception: self.logger.error(traceback.format_exc()) - return False - return True + # Send information to manager about success + if self.connected_to_manager: + data = {"worker_id": self.state.id, "success": success} + async with aiohttp.ClientSession(self.manager_url) as session: + async with session.post( + "/workers/send_archive", data=json.dumps(data) + ) as _: + ... + # self.logger.debug(f"{self}: send " + # "archive update confirmation: {resp.ok}") + + return success async def _send_archive_locally(self, path: pathlib.Path) -> pathlib.Path: self.logger.debug(f"{self}: sending archive locally") @@ -327,9 +340,12 @@ async def _async_node_status_update(self) -> bool: if not self.connected_to_manager: return False - async with aiohttp.ClientSession(self.manager_url) as session: - async with session.post( - "/workers/node_status", data=self.state.to_json() - ) as resp: + try: + async with aiohttp.ClientSession(self.manager_url) as session: + async with session.post( + "/workers/node_status", data=self.state.to_json() + ) as resp: - return resp.ok + return resp.ok + except aiohttp.client_exceptions.ClientOSError: + return False diff --git a/chimerapy/engine/worker/http_server_service.py b/chimerapy/engine/worker/http_server_service.py index 0697c678..16cf02a2 100644 --- a/chimerapy/engine/worker/http_server_service.py +++ b/chimerapy/engine/worker/http_server_service.py @@ -1,4 +1,3 @@ -import sys import pickle import asyncio import enum @@ -8,7 +7,6 @@ from aiohttp import web -from chimerapy.engine import config from ..service import Service from ..states import NodeState, WorkerState from ..data_protocols import ( @@ -19,7 +17,7 @@ from ..networking import Server from ..networking.async_loop_thread import AsyncLoopThread from ..networking.enums import NODE_MESSAGE -from ..utils import async_waiting_for, update_dataclass +from ..utils import update_dataclass from ..eventbus import EventBus, Event, TypedObserver from .events import ( EnableDiagnosticsEvent, @@ -72,7 +70,7 @@ def __init__( web.post("/nodes/registered_methods", self._async_request_method_route), web.post("/nodes/stop", self._async_stop_nodes_route), web.post("/nodes/diagnostics", self._async_diagnostics_route), - web.post("/packages/load", self._async_load_sent_packages), + # web.post("/packages/load", self._async_load_sent_packages), web.post("/shutdown", self._async_shutdown_route), ], ws_handlers={ @@ -157,58 +155,67 @@ def _create_node_pub_table(self) -> NodePubTable: return node_pub_table + async def _collect_and_send(self, path: pathlib.Path): + # Collect data from the Nodes + await self.eventbus.asend(Event("collect")) + + # After collecting, request to send the archive + event_data = SendArchiveEvent(path) + await self.eventbus.asend(Event("send_archive", event_data)) + #################################################################### ## HTTP Routes #################################################################### - async def _async_load_sent_packages(self, request: web.Request) -> web.Response: - msg = await request.json() - - # For each package, extract it from the client's tempfolder - # and load it to the sys.path - for sent_package in msg["packages"]: - - # Wait until the sent package are started - success = await async_waiting_for( - condition=lambda: f"{sent_package}.zip" - in self.server.file_transfer_records["Manager"], - timeout=config.get("worker.timeout.package-delivery"), - ) - - if success: - self.logger.debug( - f"{self}: Waiting for package {sent_package}: SUCCESS" - ) - else: - self.logger.error(f"{self}: Waiting for package {sent_package}: FAILED") - return web.HTTPError() - - # Get the path - package_zip_path = self.server.file_transfer_records["Manager"][ - f"{sent_package}.zip" - ]["dst_filepath"] - - # Wait until the sent package is complete - success = await async_waiting_for( - condition=lambda: self.server.file_transfer_records["Manager"][ - f"{sent_package}.zip" - ]["complete"] - is True, - timeout=config.get("worker.timeout.package-delivery"), - ) - - if success: - self.logger.debug(f"{self}: Package {sent_package} loading: SUCCESS") - else: - self.logger.debug(f"{self}: Package {sent_package} loading: FAILED") - - assert ( - package_zip_path.exists() - ), f"{self}: {package_zip_path} doesn't exists!?" - sys.path.insert(0, str(package_zip_path)) - - # Send message back to the Manager letting them know that - return web.HTTPOk() + # async def _async_load_sent_packages(self, request: web.Request) -> web.Response: + # msg = await request.json() + + # # For each package, extract it from the client's tempfolder + # # and load it to the sys.path + # for sent_package in msg["packages"]: + + # # Wait until the sent package are started + # success = await async_waiting_for( + # condition=lambda: f"{sent_package}.zip" + # in self.server.file_transfer_records["Manager"], + # timeout=config.get("worker.timeout.package-delivery"), + # ) + + # if success: + # self.logger.debug( + # f"{self}: Waiting for package {sent_package}: SUCCESS" + # ) + # else: + # self.logger.error(f"{self}: Waiting for " + # "package {sent_package}: FAILED") + # return web.HTTPError() + + # # Get the path + # package_zip_path = self.server.file_transfer_records["Manager"][ + # f"{sent_package}.zip" + # ]["dst_filepath"] + + # # Wait until the sent package is complete + # success = await async_waiting_for( + # condition=lambda: self.server.file_transfer_records["Manager"][ + # f"{sent_package}.zip" + # ]["complete"] + # is True, + # timeout=config.get("worker.timeout.package-delivery"), + # ) + + # if success: + # self.logger.debug(f"{self}: Package {sent_package} loading: SUCCESS") + # else: + # self.logger.debug(f"{self}: Package {sent_package} loading: FAILED") + + # assert ( + # package_zip_path.exists() + # ), f"{self}: {package_zip_path} doesn't exists!?" + # sys.path.insert(0, str(package_zip_path)) + + # # Send message back to the Manager letting them know that + # return web.HTTPOk() async def _async_create_node_route(self, request: web.Request) -> web.Response: msg_bytes = await request.read() @@ -282,22 +289,15 @@ async def _async_report_node_gather(self, request: web.Request) -> web.Response: async def _async_collect(self, request: web.Request) -> web.Response: data = await request.json() - - # Collect data from the Nodes - await self.eventbus.asend(Event("collect")) - - # After collecting, request to send the archive - event_data = SendArchiveEvent(pathlib.Path(data["path"])) - await self.eventbus.asend(Event("send_archive", event_data)) - + asyncio.create_task(self._collect_and_send(pathlib.Path(data["path"]))) return web.HTTPOk() async def _async_diagnostics_route(self, request: web.Request) -> web.Response: - data= await request.json() + data = await request.json() # Determine if enable/disable - event_data = EnableDiagnosticsEvent(data['enable']) - await self.eventbus.asend(Event('diagnostics', event_data)) + event_data = EnableDiagnosticsEvent(data["enable"]) + await self.eventbus.asend(Event("diagnostics", event_data)) return web.HTTPOk() async def _async_shutdown_route(self, request: web.Request) -> web.Response: diff --git a/chimerapy/engine/worker/node_handler_service.py b/chimerapy/engine/worker/node_handler_service.py index 67f0bbaf..aee47d36 100644 --- a/chimerapy/engine/worker/node_handler_service.py +++ b/chimerapy/engine/worker/node_handler_service.py @@ -173,6 +173,9 @@ def __init__( "start_nodes": TypedObserver( "start_nodes", on_asend=self.async_start_nodes, handle_event="drop" ), + "stop_nodes": TypedObserver( + "stop_nodes", on_asend=self.async_stop_nodes, handle_event="drop" + ), "record_nodes": TypedObserver( "record_nodes", on_asend=self.async_record_nodes, handle_event="drop" ), @@ -189,7 +192,10 @@ def __init__( "gather_nodes", on_asend=self.async_gather, handle_event="drop" ), "diagnostics": TypedObserver( - "diagnostics", EnableDiagnosticsEvent, on_asend=self.async_diagnostics, handle_event='unpack' + "diagnostics", + EnableDiagnosticsEvent, + on_asend=self.async_diagnostics, + handle_event="unpack", ), "update_gather": TypedObserver( "update_gather", @@ -401,6 +407,14 @@ async def async_stop_nodes(self) -> bool: await self.eventbus.asend( Event("broadcast", BroadcastEvent(signal=WORKER_MESSAGE.STOP_NODES)) ) + await async_waiting_for( + lambda: all( + [ + x.fsm in ["STOPPED", "SAVED", "SHUTDOWN"] + for x in self.state.nodes.values() + ] + ) + ) return True async def async_request_registered_method( @@ -432,7 +446,12 @@ async def async_request_registered_method( async def async_diagnostics(self, enable: bool) -> bool: await self.eventbus.asend( - Event('broadcast', BroadcastEvent(signal=WORKER_MESSAGE.DIAGNOSTICS, data={'enable': enable})) + Event( + "broadcast", + BroadcastEvent( + signal=WORKER_MESSAGE.DIAGNOSTICS, data={"enable": enable} + ), + ) ) return True @@ -496,7 +515,7 @@ async def async_collect(self) -> bool: if await async_waiting_for( condition=lambda: self.state.nodes[node_id].fsm == "SAVED", - timeout=config.get("worker.timeout.info-request"), + timeout=None, ): # self.logger.debug( # f"{self}: Node {node_id} responded to saving request: SUCCESS" diff --git a/examples/remote_camera.py b/examples/remote_camera.py index d8b0d2a1..f969836f 100644 --- a/examples/remote_camera.py +++ b/examples/remote_camera.py @@ -31,7 +31,7 @@ class ShowWindow(cpe.Node): def step(self, data_chunks: Dict[str, cpe.DataChunk]): for name, data_chunk in data_chunks.items(): - self.logger.debug(f"{self}: got from {name}, data={data_chunk}") + # self.logger.debug(f"{self}: got from {name}, data={data_chunk}") cv2.imshow(name, data_chunk.get("frame")["value"]) cv2.waitKey(1) @@ -52,7 +52,7 @@ def __init__(self): # Create default manager and desired graph manager = cpe.Manager(logdir=CWD / "runs") - graph = RemoteCameraGraph() + manager.zeroconf() worker = cpe.Worker(name="local", id="local") # Then register graph to Manager @@ -66,26 +66,49 @@ def __init__(self): # Assuming one worker # mapping = {"remote": [graph.web.id], worker.id: [graph.show.id]} - mapping = {worker.id: graph.node_ids} + # For local only + if len(manager.workers) == 1: + graph = RemoteCameraGraph() + mapping = {worker.id: graph.node_ids} + else: + + # For mutliple workers (remote and local) + graph = cpe.Graph() + show_node = ShowWindow(name="show") + graph.add_node(show_node) + mapping = {worker.id: [show_node.id]} + + for worker_id in manager.workers: + if worker_id == "local": + continue + else: + web_node = WebcamNode(name=f"web-{worker_id}") + graph.add_nodes_from([web_node]) + graph.add_edges_from([(web_node, show_node)]) + mapping[worker_id] = [web_node.id] # Commit the graph - manager.commit_graph(graph=graph, mapping=mapping).result(timeout=60) - manager.start().result(timeout=5) - - # Wail until user stops - while True: - q = input("Ready to start? (Y/n)") - if q.lower() == "y": - break - - manager.record().result(timeout=5) - - # Wail until user stops - while True: - q = input("Stop? (Y/n)") - if q.lower() == "y": - break - - manager.stop().result(timeout=5) - manager.collect().result() - manager.shutdown() + try: + assert manager.commit_graph(graph=graph, mapping=mapping).result(timeout=60) + assert manager.start().result(timeout=5) + + # Wail until user stops + # while True: + # q = input("Ready to start? (Y/n)") + # if q.lower() == "y": + # break + + assert manager.record().result(timeout=5) + + # Wail until user stops + while True: + q = input("Stop? (Y/n)") + if q.lower() == "y": + break + + assert manager.stop().result(timeout=5) + assert manager.collect().result() + except Exception: + print("System failed") + finally: + manager.shutdown() diff --git a/examples/remote_screen_and_web.py b/examples/remote_screen_and_web.py index 9eeccc38..1cc814b6 100644 --- a/examples/remote_screen_and_web.py +++ b/examples/remote_screen_and_web.py @@ -34,7 +34,6 @@ def teardown(self): class ScreenCaptureNode(cpe.Node): def setup(self): - if platform.system() == "Windows": import dxcam @@ -94,6 +93,7 @@ def __init__(self): # Create default manager and desired graph manager = cpe.Manager(logdir=CWD / "runs", port=9000) + manager.zeroconf() worker = cpe.Worker(name="local", id="local", port=0) # Then register graph to Manager @@ -111,6 +111,11 @@ def __init__(self): mapping = {worker.id: graph.node_ids} else: + print( + "WARNING: ScreenCaptureNode is faulty for this " + "configuration for unknown reasons" + ) + # For mutliple workers (remote and local) graph = cpe.Graph() show_node = ShowWindow(name="show") @@ -128,7 +133,7 @@ def __init__(self): mapping[worker_id] = [web_node.id, screen_node.id] # Commit the graph - manager.commit_graph(graph=graph, mapping=mapping).result(timeout=60) + manager.commit_graph(graph=graph, mapping=mapping).result() manager.start().result(timeout=5) # Wail until user stops diff --git a/examples/remote_simple.py b/examples/remote_simple.py index 7007398a..7f0733c7 100644 --- a/examples/remote_simple.py +++ b/examples/remote_simple.py @@ -1,4 +1,4 @@ -from typing import Dict, Any +from typing import Dict import time import pathlib import os @@ -14,6 +14,7 @@ def setup(self): def step(self): time.sleep(1) + self.logger.debug("Producer step") current_counter = self.counter self.counter += 1 data_chunk = cpe.DataChunk() @@ -22,9 +23,9 @@ def step(self): class Consumer(cpe.Node): - def step(self, data: Dict[str, Any]): - d = data["prod"].get("counter")["value"] - print("Consumer got data: ", d) + def step(self, data_chunks: Dict[str, cpe.DataChunk]): + d = data_chunks["prod"].get("counter")["value"] + self.logger.debug(f"Consumer step: {d}") class SimpleGraph(cpe.Graph): @@ -42,8 +43,8 @@ def __init__(self): # Create default manager and desired graph manager = cpe.Manager(logdir=CWD / "runs") - graph = SimpleGraph() - worker = cpe.Worker(name="local") + manager.zeroconf() + worker = cpe.Worker(name="local", id="local") # Then register graph to Manager worker.connect(host=manager.host, port=manager.port) @@ -54,14 +55,29 @@ def __init__(self): if q.lower() == "y": break - # Assuming one worker - # mapping = {worker.id: graph.node_ids} - # mapping = {worker.id: [graph.prod.id], 'remote': [graph.cons.id]} - mapping = {worker.id: [graph.cons.id], "remote": [graph.prod.id]} + # For local only + if len(manager.workers) == 1: + graph = SimpleGraph() + mapping = {worker.id: graph.node_ids} + else: + + # For mutliple workers (remote and local) + graph = cpe.Graph() + con_node = Consumer(name="cons") + graph.add_nodes_from([con_node]) + mapping = {worker.id: [con_node.id]} + + for worker_id in manager.workers: + if worker_id == "local": + continue + else: + prod_node = Producer(name="prod") + graph.add_nodes_from([prod_node]) + graph.add_edge(src=prod_node, dst=con_node) + mapping[worker_id] = [prod_node.id] # Commit the graph manager.commit_graph(graph=graph, mapping=mapping).result(timeout=60) - manager.start().result(timeout=5) # Wail until user stops while True: @@ -69,6 +85,7 @@ def __init__(self): if q.lower() == "y": break + manager.start().result(timeout=5) manager.record().result(timeout=5) # Wail until user stops diff --git a/pyproject.toml b/pyproject.toml index 0460a29a..01d62740 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,8 @@ dependencies = [ 'multiprocess', 'opencv-python', 'pyaudio', + 'aioshutil', + 'asyncio-atexit', 'pandas', 'tqdm', 'pyzmq', @@ -53,11 +55,9 @@ test = [ 'numpy', 'imutils', 'pillow', - 'requests', 'bump2version' ] types = [ - 'types-requests', 'types-PyYAML', 'types-aiofiles' ] @@ -109,7 +109,9 @@ ignore_missing_imports = true # Reference: # https://stackoverflow.com/questions/4673373/logging-within-pytest-tests + [tool.pytest.ini_options] +asyncio_mode = 'auto' # Logging + CLI log_cli = true diff --git a/test/front_end_integration/test_ws.py b/test/front_end_integration/test_ws.py index d3201db5..7e6b3337 100644 --- a/test/front_end_integration/test_ws.py +++ b/test/front_end_integration/test_ws.py @@ -89,14 +89,14 @@ def test_node_creation_and_destruction_network_updates(test_ws_client, manager, timeout=30 ) time.sleep(2) - assert record.network_state.to_json() == manager.state.to_json() + # assert record.network_state.to_json() == manager.state.to_json() # Test destruction manager._request_node_destruction(worker_id=worker.id, node_id="Gen1").result( timeout=10 ) time.sleep(2) - assert record.network_state.to_json() == manager.state.to_json() + assert record.network_state.workers[worker.id].nodes == {} def test_reset_network_updates(test_ws_client, manager, worker): diff --git a/test/manager/test_http_server_service.py b/test/manager/test_http_server_service.py index b1df485e..0f666120 100644 --- a/test/manager/test_http_server_service.py +++ b/test/manager/test_http_server_service.py @@ -1,4 +1,5 @@ import requests +import json import pytest @@ -52,6 +53,7 @@ def test_http_server_instanciate(http_server): "/workers/node_status", WorkerState(id="NULL", name="NULL", tempfolder=TEST_DATA_DIR).to_json(), ), + ("/workers/send_archive", json.dumps({"worker_id": "test", "success": True})), ], ) def test_http_server_routes(http_server, route, payload): diff --git a/test/manager/test_manager.py b/test/manager/test_manager.py index 383a50bb..0b6318ed 100644 --- a/test/manager/test_manager.py +++ b/test/manager/test_manager.py @@ -4,7 +4,6 @@ import pytest -from chimerapy.engine import config import chimerapy.engine as cpe from ..conftest import GenNode, ConsumeNode, TEST_DATA_DIR @@ -68,14 +67,11 @@ def test_sending_package(self, manager, _worker, config_graph): for node_id in config_graph.G.nodes(): assert manager.workers[_worker.id].nodes[node_id].fsm != "NULL" - # @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) - @pytest.mark.parametrize("context", ["threading"]) + @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) + # @pytest.mark.parametrize("context", ["multiprocessing"]) def test_manager_lifecycle(self, manager_with_worker, context): manager, worker = manager_with_worker - # Enable diagnostics - config.set("diagnostics.logging-enabled", True) - # Define graph gen_node = GenNode(name="Gen1") con_node = ConsumeNode(name="Con1") @@ -114,9 +110,6 @@ def test_manager_reset(self, manager_with_worker): manager.commit_graph(graph=simple_graph, mapping=mapping).result(timeout=30) assert manager.start().result() - - time.sleep(3) - assert manager.record().result() time.sleep(3) diff --git a/test/manager/test_worker_handler_service.py b/test/manager/test_worker_handler_service.py index e8a2ffe4..3bf64f00 100644 --- a/test/manager/test_worker_handler_service.py +++ b/test/manager/test_worker_handler_service.py @@ -67,7 +67,6 @@ def test_instanticate(testbed_setup): ... -@pytest.mark.asyncio async def test_worker_handler_create_node(testbed_setup): worker_handler, worker, simple_graph = testbed_setup @@ -83,12 +82,12 @@ async def test_worker_handler_create_node(testbed_setup): assert await worker_handler._request_node_destruction( worker_id=worker.id, node_id="Gen1" ) - await asyncio.sleep(3) + await asyncio.sleep(1) + assert "Gen1" not in worker.state.nodes assert "Gen1" not in worker_handler.state.workers[worker.id].nodes -@pytest.mark.asyncio async def test_worker_handler_create_connections(testbed_setup): worker_handler, worker, simple_graph = testbed_setup @@ -114,7 +113,6 @@ async def test_worker_handler_create_connections(testbed_setup): assert await worker_handler.reset() -@pytest.mark.asyncio async def test_worker_handler_lifecycle_graph(testbed_setup): worker_handler, worker, simple_graph = testbed_setup @@ -135,7 +133,6 @@ async def test_worker_handler_lifecycle_graph(testbed_setup): assert await worker_handler.reset() -@pytest.mark.asyncio async def test_worker_handler_enable_diagnostics(testbed_setup): worker_handler, worker, simple_graph = testbed_setup diff --git a/test/manager/test_zeroconf_service.py b/test/manager/test_zeroconf_service.py index 44085c2f..751a8702 100644 --- a/test/manager/test_zeroconf_service.py +++ b/test/manager/test_zeroconf_service.py @@ -58,7 +58,6 @@ def zeroconf_service(): return zeroconf_service -@pytest.mark.asyncio async def test_enable_and_disable_zeroconf(zeroconf_service): assert await zeroconf_service.enable() @@ -66,7 +65,6 @@ async def test_enable_and_disable_zeroconf(zeroconf_service): assert await zeroconf_service.disable() -@pytest.mark.asyncio async def test_zeroconf_connect(zeroconf_service): assert await zeroconf_service.enable() diff --git a/test/networking/test_client_server.py b/test/networking/test_client_server.py index 642abc7d..d78d998b 100644 --- a/test/networking/test_client_server.py +++ b/test/networking/test_client_server.py @@ -1,9 +1,10 @@ from typing import Dict -import time +import tempfile import pathlib +import asyncio import os import enum -import requests +import aiohttp from aiohttp import web import pytest @@ -33,33 +34,33 @@ async def echo(msg: Dict, ws: web.WebSocketResponse = None): @pytest.fixture -def server(): +async def server(): server = Server( id="test_server", port=0, routes=[web.get("/", hello)], ws_handlers={TEST_PROTOCOL.ECHO_FLAG: echo}, ) - server.serve() + await server.async_serve() yield server - server.shutdown() + await server.async_shutdown() @pytest.fixture -def client(server): +async def client(server): client = Client( id="test_client", host=server.host, port=server.port, ws_handlers={TEST_PROTOCOL.ECHO_FLAG: echo}, ) - client.connect() + await client.async_connect() yield client - client.shutdown() + await client.async_shutdown() @pytest.fixture -def client_list(server): +async def client_list(server): clients = [] for i in range(NUMBER_OF_CLIENTS): @@ -69,29 +70,30 @@ def client_list(server): id=f"test-{i}", ws_handlers={TEST_PROTOCOL.ECHO_FLAG: echo}, ) - client.connect() + await client.async_connect() clients.append(client) yield clients for client in clients: - client.shutdown() + await client.async_shutdown() -def test_server_instanciate(server): +async def test_server_instanciate(server): ... - - -def test_server_http_req_res(server): - r = requests.get(f"http://{server.host}:{server.port}") - assert r.status_code == 200 and r.text == "Hello, world" -def test_server_websocket_connection(server, client): +async def test_server_http_req_res(server): + url = f"http://{server.host}:{server.port}" + async with aiohttp.ClientSession(url) as session: + async with session.get("/") as resp: + assert resp.ok + + +async def test_server_websocket_connection(server, client): assert client.id in list(server.ws_clients.keys()) -@pytest.mark.asyncio async def test_async_ws(server): client = Client( id="test_client", @@ -104,57 +106,59 @@ async def test_async_ws(server): await client.async_shutdown() -def test_server_websocket_connection_shutdown(server, client): - client.shutdown() - time.sleep(1) - server.broadcast(signal=TEST_PROTOCOL.ECHO_FLAG, data="ECHO!") +async def test_server_websocket_connection_shutdown(server, client): + await client.async_shutdown() + await asyncio.sleep(0.1) + await server.async_broadcast(signal=TEST_PROTOCOL.ECHO_FLAG, data="ECHO!") -def test_server_send_to_client(server, client): +async def test_server_send_to_client(server, client): # Simple send - server.send(client_id=client.id, signal=TEST_PROTOCOL.ECHO_FLAG, data="HELLO") + await server.async_send( + client_id=client.id, signal=TEST_PROTOCOL.ECHO_FLAG, data="HELLO" + ) # Simple send with OK - server.send( + await server.async_send( client_id=client.id, signal=TEST_PROTOCOL.ECHO_FLAG, data="HELLO", ok=True ) - assert cpe.utils.waiting_for( + assert await cpe.utils.async_waiting_for( lambda: client.msg_processed_counter >= 2, timeout=2, ) -def test_client_send_to_server(server, client): +async def test_client_send_to_server(server, client): # Simple send - client.send(signal=TEST_PROTOCOL.ECHO_FLAG, data="HELLO") + await client.async_send(signal=TEST_PROTOCOL.ECHO_FLAG, data="HELLO") # Simple send with OK - client.send(signal=TEST_PROTOCOL.ECHO_FLAG, data="HELLO", ok=True) + await client.async_send(signal=TEST_PROTOCOL.ECHO_FLAG, data="HELLO", ok=True) - assert cpe.utils.waiting_for( + assert await cpe.utils.async_waiting_for( lambda: server.msg_processed_counter >= 2, timeout=2, ) -def test_multiple_clients_send_to_server(server, client_list): +async def test_multiple_clients_send_to_server(server, client_list): for client in client_list: - client.send(signal=TEST_PROTOCOL.ECHO_FLAG, data="ECHO!", ok=True) + await client.async_send(signal=TEST_PROTOCOL.ECHO_FLAG, data="ECHO!", ok=True) - assert cpe.utils.waiting_for( + assert await cpe.utils.async_waiting_for( lambda: server.msg_processed_counter >= NUMBER_OF_CLIENTS, timeout=5, ) -def test_server_broadcast_to_multiple_clients(server, client_list): +async def test_server_broadcast_to_multiple_clients(server, client_list): - server.broadcast(signal=TEST_PROTOCOL.ECHO_FLAG, data="ECHO!", ok=True) + await server.async_broadcast(signal=TEST_PROTOCOL.ECHO_FLAG, data="ECHO!", ok=True) for client in client_list: - assert cpe.utils.waiting_for( + assert await cpe.utils.async_waiting_for( lambda: client.msg_processed_counter >= 2, timeout=5, ) @@ -167,21 +171,16 @@ def test_server_broadcast_to_multiple_clients(server, client_list): (TEST_DIR / "mock" / "data" / "chimerapy_logs"), ], ) -def test_client_sending_folder_to_server(server, client, dir): +async def test_client_sending_folder_to_server(server, client, dir): # Action - client.send_folder(sender_id="test_worker", dir=dir).result(timeout=10) - - # Get the expected behavior - miss_counter = 0 - while len(server.file_transfer_records.keys()) == 0: - - miss_counter += 1 - time.sleep(0.1) - - if miss_counter > 100: - assert False, "File transfer failed after 10 second" + await client.async_send_folder(sender_id="test_worker", dir=dir) # Also check that the file exists - for record in server.file_transfer_records["test_worker"].values(): - assert record["dst_filepath"].exists() + for record in server.file_transfer_records.records.values(): + assert record.location.exists() + + # Test moving the files to a logdir + temp = pathlib.Path(tempfile.mkdtemp()) + await server.move_transferred_files(temp) + await server.move_transferred_files(temp, owner="test_worker", owner_id="asdf") diff --git a/test/node/test_processor_service.py b/test/node/test_processor_service.py index 5a6d77e7..06ffffb4 100644 --- a/test/node/test_processor_service.py +++ b/test/node/test_processor_service.py @@ -164,7 +164,6 @@ def test_instanticate(processor_setup): lazy_fixture("step_processor"), ], ) -@pytest.mark.asyncio async def test_setup(processor_setup): processor, _ = processor_setup await processor.setup() @@ -178,7 +177,6 @@ async def test_setup(processor_setup): ("step", lazy_fixture("step_processor")), ], ) -@pytest.mark.asyncio async def test_main(ptype, processor_setup): processor, eventbus = processor_setup diff --git a/test/node/test_profiling_service.py b/test/node/test_profiling_service.py index 2242a302..a0a3b183 100644 --- a/test/node/test_profiling_service.py +++ b/test/node/test_profiling_service.py @@ -1,5 +1,3 @@ -import datetime -import time import pathlib import tempfile import random @@ -22,7 +20,6 @@ logger = cpe._logger.getLogger("chimerapy-engine") -@pytest.mark.asyncio @pytest.fixture def profiler_setup(): @@ -53,7 +50,6 @@ def test_instanciate(profiler_setup): ... -@pytest.mark.asyncio async def test_single_data_chunk(profiler_setup): profiler, eventbus = profiler_setup await profiler.enable() @@ -67,16 +63,17 @@ async def test_single_data_chunk(profiler_setup): # Mock how the processor marks the time when it got the datachunk # and transmitted it meta = example_data_chunk.get("meta") - meta["value"]["delta"] = random.randrange(500, 1500, 1) # ms + meta["value"]["delta"] = random.randrange(500, 1500, 1) # ms example_data_chunk.update("meta", meta) - await eventbus.asend(Event("out_step", NewOutBoundDataEvent(example_data_chunk))) + await eventbus.asend( + Event("out_step", NewOutBoundDataEvent(example_data_chunk)) + ) await profiler.diagnostics_report() assert profiler.log_file.exists() -@pytest.mark.asyncio async def test_single_data_chunk_with_multiple_payloads(profiler_setup): profiler, eventbus = profiler_setup await profiler.enable() @@ -94,16 +91,17 @@ async def test_single_data_chunk_with_multiple_payloads(profiler_setup): meta["value"]["delta"] = random.randrange(500, 1500, 1) example_data_chunk.update("meta", meta) - await eventbus.asend(Event("out_step", NewOutBoundDataEvent(example_data_chunk))) + await eventbus.asend( + Event("out_step", NewOutBoundDataEvent(example_data_chunk)) + ) await profiler.diagnostics_report() assert profiler.log_file.exists() -@pytest.mark.asyncio async def test_enable_disable(profiler_setup): profiler, eventbus = profiler_setup - + for i in range(50): # Run the step multiple times @@ -113,14 +111,16 @@ async def test_enable_disable(profiler_setup): # Mock how the processor marks the time when it got the datachunk # and transmitted it meta = example_data_chunk.get("meta") - meta["value"]["delta"] = random.randrange(500, 1500, 1) # ms + meta["value"]["delta"] = random.randrange(500, 1500, 1) # ms example_data_chunk.update("meta", meta) - await eventbus.asend(Event("out_step", NewOutBoundDataEvent(example_data_chunk))) + await eventbus.asend( + Event("out_step", NewOutBoundDataEvent(example_data_chunk)) + ) assert len(profiler.seen_uuids) == 0 await profiler.enable(True) - + for i in range(50): # Run the step multiple times @@ -130,13 +130,14 @@ async def test_enable_disable(profiler_setup): # Mock how the processor marks the time when it got the datachunk # and transmitted it meta = example_data_chunk.get("meta") - meta["value"]["delta"] = random.randrange(500, 1500, 1) # ms + meta["value"]["delta"] = random.randrange(500, 1500, 1) # ms example_data_chunk.update("meta", meta) - await eventbus.asend(Event("out_step", NewOutBoundDataEvent(example_data_chunk))) + await eventbus.asend( + Event("out_step", NewOutBoundDataEvent(example_data_chunk)) + ) await profiler.diagnostics_report() await profiler.enable(False) assert len(profiler.seen_uuids) != 0 assert profiler.log_file.exists() - diff --git a/test/node/test_record_service.py b/test/node/test_record_service.py index e571f22b..60811de1 100644 --- a/test/node/test_record_service.py +++ b/test/node/test_record_service.py @@ -40,7 +40,6 @@ def test_instanciate(recorder): ... -@pytest.mark.asyncio async def test_record_direct_submit(recorder): # Run the recorder diff --git a/test/node/test_worker_comms.py b/test/node/test_worker_comms.py index a250bea9..f80fed58 100644 --- a/test/node/test_worker_comms.py +++ b/test/node/test_worker_comms.py @@ -95,7 +95,7 @@ def test_instanticate(worker_comms_setup): ... -# @pytest.mark.asyncio +# @ # async def test_setup(worker_comms_setup): # worker_comms, server = worker_comms_setup @@ -120,7 +120,6 @@ def test_instanticate(worker_comms_setup): ("enable_diagnostics", {"data": {"enable": True}}), ], ) -@pytest.mark.asyncio async def test_methods(worker_comms_setup, method_name, method_params): worker_comms, _ = worker_comms_setup @@ -149,7 +148,6 @@ async def test_methods(worker_comms_setup, method_name, method_params): (WORKER_MESSAGE.DIAGNOSTICS, {"enable": False}), ], ) -@pytest.mark.asyncio async def test_ws_signals(worker_comms_setup, signal, data): worker_comms, server = worker_comms_setup diff --git a/test/test_async_timer.py b/test/test_async_timer.py index e0e9cc6c..2f2a02c4 100644 --- a/test/test_async_timer.py +++ b/test/test_async_timer.py @@ -21,7 +21,6 @@ def timer(): return timer -@pytest.mark.asyncio async def test_async_timer(timer): await timer.start() await asyncio.sleep(2) diff --git a/test/test_eventbus.py b/test/test_eventbus.py index 3e0f9398..d711b2f3 100644 --- a/test/test_eventbus.py +++ b/test/test_eventbus.py @@ -55,7 +55,6 @@ def event_bus(): return event_bus -@pytest.mark.asyncio async def test_msg_filtering(): event_bus = EventBus() @@ -76,7 +75,6 @@ async def test_msg_filtering(): assert hello_event.id in hello_observer.received -@pytest.mark.asyncio async def test_event_null_data(): event_bus = EventBus() @@ -97,9 +95,8 @@ async def test_event_null_data(): assert null_event.id in null_observer.received -@pytest.mark.asyncio async def test_subscribe_and_unsubscribe(): - + event_bus = EventBus() null_observer = TypedObserver("null") @@ -121,8 +118,21 @@ async def test_subscribe_and_unsubscribe(): assert null2_event.id not in null_observer.received +async def test_awaitable_event(): + + event_bus = EventBus() + null_event = Event("null") + + async def later_event(): + await asyncio.sleep(1) + await event_bus.asend(null_event) + + asyncio.create_task(later_event()) + + null2_event = await event_bus.await_event("null") + assert null2_event == null_event + -@pytest.mark.asyncio async def test_sync_and_async_binding(): event_bus = EventBus() @@ -159,7 +169,6 @@ async def async_add_to(_): assert len(async_local_variable) != 0 -@pytest.mark.asyncio async def test_event_handling(): event_bus = EventBus() @@ -205,7 +214,6 @@ async def drop_func(): assert len(drop_variable) != 0 -@pytest.mark.asyncio async def test_evented_dataclass(event_bus): # Creating the observer and its binding @@ -234,7 +242,6 @@ async def add_to(event): assert isinstance(data.to_json(), str) -@pytest.mark.asyncio async def test_evented_wrapper(event_bus): # Creating the observer and its binding @@ -286,7 +293,6 @@ def test_make_evented_multiple(event_bus): make_evented(SomeClass(number=1, string="hello"), event_bus=event_bus) -@pytest.mark.asyncio async def test_make_evented_nested(event_bus): data_class = NestedClass( number=1, diff --git a/test/worker/test_http_client.py b/test/worker/test_http_client.py index 950234ff..a19d7f6b 100644 --- a/test/worker/test_http_client.py +++ b/test/worker/test_http_client.py @@ -51,31 +51,27 @@ def http_client(): ) yield http_client - eventbus.send(Event("shutdown")) + eventbus.send(Event("shutdown")).result() def test_http_client_instanciate(http_client): ... -@pytest.mark.asyncio async def test_connect_via_ip(http_client, manager): assert await http_client._async_connect_via_ip(host=manager.host, port=manager.port) -@pytest.mark.asyncio async def test_connect_via_zeroconf(http_client, manager): await manager.async_zeroconf() assert await http_client._async_connect_via_zeroconf() -@pytest.mark.asyncio async def test_node_status_update(http_client, manager): assert await http_client._async_connect_via_ip(host=manager.host, port=manager.port) assert await http_client._async_node_status_update() -@pytest.mark.asyncio async def test_worker_state_changed_updates(http_client, manager): assert await http_client._async_connect_via_ip(host=manager.host, port=manager.port) @@ -90,7 +86,6 @@ async def test_worker_state_changed_updates(http_client, manager): assert "test" in manager.state.workers[http_client.state.id].nodes -@pytest.mark.asyncio async def test_send_archive_locally(http_client): # Adding simple file @@ -117,7 +112,6 @@ async def test_send_archive_locally(http_client): assert f.read() == "hello" -@pytest.mark.asyncio async def test_send_archive_remotely(http_client, server): # Make a copy of example logs @@ -128,8 +122,5 @@ async def test_send_archive_remotely(http_client, server): assert await http_client._send_archive_remotely(server.host, server.port) - logger.debug(server.file_transfer_records) - - # Also check that the file exists - for record in server.file_transfer_records[http_client.state.id].values(): - assert record["dst_filepath"].exists() + for record in server.file_transfer_records.records.values(): + assert record.sender_id == http_client.state.name diff --git a/test/worker/test_http_server.py b/test/worker/test_http_server.py index 47a54e54..16e9d1a9 100644 --- a/test/worker/test_http_server.py +++ b/test/worker/test_http_server.py @@ -66,7 +66,6 @@ def test_http_server_instanciate(http_server): ... -@pytest.mark.asyncio @pytest.mark.parametrize( "route_type, route, payload", [ @@ -86,7 +85,7 @@ def test_http_server_instanciate(http_server): json.dumps({"node_id": "1", "method_name": "a", "params": {}}), ), ("post", "/nodes/stop", json.dumps({})), - ("post", "/packages/load", json.dumps({"packages": []})), + # ("post", "/packages/load", json.dumps({"packages": []})), # ("post", "/shutdown", json.dumps({})), ], ) @@ -103,7 +102,6 @@ async def test_http_server_routes(http_server, route_type, route, payload): assert resp.ok -@pytest.mark.asyncio @pytest.mark.parametrize( "signal, payload", [ diff --git a/test/worker/test_node_handler.py b/test/worker/test_node_handler.py index d0f894f8..5f6e9b46 100644 --- a/test/worker/test_node_handler.py +++ b/test/worker/test_node_handler.py @@ -109,7 +109,6 @@ def test_create_service_instance(node_handler_setup): ... -@pytest.mark.asyncio # @pytest.mark.parametrize("context", ["multiprocessing"]) # , "threading"]) @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) async def test_create_node(gen_node, node_handler_setup, context): @@ -121,7 +120,6 @@ async def test_create_node(gen_node, node_handler_setup, context): # @pytest.mark.skip(reason="Flaky") -@pytest.mark.asyncio @pytest.mark.parametrize( "context_order", [["multiprocessing", "threading"], ["threading", "multiprocessing"]], @@ -141,7 +139,6 @@ async def test_create_node_along_with_different_context( @linux_run_only -@pytest.mark.asyncio async def test_create_unknown_node(node_handler_setup): node_handler, _ = node_handler_setup @@ -163,7 +160,6 @@ def step(self): # @pytest.mark.parametrize("context", ["multiprocessing"]) # , "threading"]) @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) -@pytest.mark.asyncio async def test_processing_node_pub_table( node_handler_setup, gen_node, con_node, context ): @@ -191,7 +187,6 @@ async def test_processing_node_pub_table( assert await node_handler.async_destroy_node(con_node.id) -@pytest.mark.asyncio # @pytest.mark.parametrize("context", ["multiprocessing"]) # , "threading"]) @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) async def test_starting_node(node_handler_setup, gen_node, context): @@ -206,7 +201,6 @@ async def test_starting_node(node_handler_setup, gen_node, context): assert await node_handler.async_destroy_node(gen_node.id) -@pytest.mark.asyncio # @pytest.mark.parametrize("context", ["multiprocessing"]) # , "threading"]) @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) async def test_record_and_collect(node_handler_setup, context): @@ -239,7 +233,6 @@ async def test_record_and_collect(node_handler_setup, context): assert (node_handler.state.tempfolder / node_name).exists() -@pytest.mark.asyncio async def test_registered_method_with_concurrent_style( node_handler_setup, node_with_reg_methods ): @@ -261,7 +254,6 @@ async def test_registered_method_with_concurrent_style( ) -@pytest.mark.asyncio async def test_registered_method_with_params_and_blocking_style( node_handler_setup, node_with_reg_methods ): @@ -285,7 +277,6 @@ async def test_registered_method_with_params_and_blocking_style( ) -@pytest.mark.asyncio async def test_registered_method_with_reset_style( node_handler_setup, node_with_reg_methods ): @@ -309,7 +300,6 @@ async def test_registered_method_with_reset_style( ) -@pytest.mark.asyncio # @pytest.mark.parametrize("context", ["multiprocessing"]) # , "threading"]) @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) async def test_gather(node_handler_setup, gen_node, context):