diff --git a/examples/run_gym.py b/examples/run_gym.py index 91f9b2e0..1ad8c40c 100644 --- a/examples/run_gym.py +++ b/examples/run_gym.py @@ -26,7 +26,6 @@ import os import time -from malib.runner import run from malib.learner import IndependentAgent from malib.scenarios.marl_scenario import MARLScenario from malib.rl.dqn import DQNPolicy, DQNTrainer, DEFAULT_CONFIG diff --git a/malib/backend/dataset_server/data_loader.py b/malib/backend/dataset_server/data_loader.py index f151190d..9d774ffc 100644 --- a/malib/backend/dataset_server/data_loader.py +++ b/malib/backend/dataset_server/data_loader.py @@ -2,6 +2,7 @@ import threading import grpc +import socket from concurrent import futures from torch.utils.data import DataLoader, Dataset @@ -33,6 +34,7 @@ def __init__( max_message_length, find_free_port(), ) + self.host = socket.gethostbyname(socket.gethostname()) def _start_servicer( self, max_workers: int, max_message_length: int, grpc_port: int @@ -52,6 +54,10 @@ def _start_servicer( return server + @property + def entrypoint(self) -> str: + return f"{self.host}:{self.server._state.port}" + def __len__(self): return self.feature_handler_caller.block_size diff --git a/malib/backend/league.py b/malib/backend/league.py index d33af6f8..39d6e5ce 100644 --- a/malib/backend/league.py +++ b/malib/backend/league.py @@ -8,39 +8,34 @@ from malib.utils.logging import Logger from malib.common.manager import Manager +from malib.common.task import Task, RolloutTask, OptimizationTask class League: def __init__( self, - training_manager: Manager, + learner_manager: Manager, rollout_manager: Manager, inference_manager: Manager, ) -> None: - self.training_manager = training_manager + self.learner_manager = learner_manager self.rollout_manager = rollout_manager self.inferenc_managfer = inference_manager - self.flight_servers = [] self.rw_lock = rwlock.RWLockFair() self.event = threading.Event() self.thread_pool = futures.ThreadPoolExecutor() - def register_flight_server(self, flight_server_address: str): - raise NotImplementedError - - def list_flight_servers(self) -> List[str]: - raise NotImplementedError - - def _flight_server_check(self): - while not self.event.is_set(): - with self.rw_lock.gen_rlock(): - for flight_server in self.flight_servers: - if not ray.util.check_connection(flight_server): - self.flight_servers.remove(flight_server) - self.event.wait(10) - def list_learners(self): - return self.training_manager.workers() + return self.learner_manager.workers() + + def submit(self, task_desc: Task, wait: bool = False): + if isinstance(task_desc, RolloutTask): + res = self.rollout_manager.submit(task_desc, wait) + elif isinstance(task_desc, OptimizationTask): + res = self.learner_manager.submit(task_desc, wait) + else: + raise ValueError(f"Unexpected task type: {type(task_desc)}") + return res def list_rollout_workers(self): return self.rollout_manager.workers() @@ -62,7 +57,7 @@ def get_results(self) -> Dict[str, Dict[str, Any]]: while True: for result in self.rollout_manager.get_results(): rollout_results.append(result) - for result in self.training_manager.get_results(): + for result in self.learner_manager.get_results(): training_results.append(result) except KeyboardInterrupt: Logger.info("Keyboard interruption was detected, recalling resources ...") @@ -76,5 +71,5 @@ def get_results(self) -> Dict[str, Dict[str, Any]]: def terminate(self): self.event.set() self.thread_pool.shutdown() - self.training_manager.terminate() + self.learner_manager.terminate() self.rollout_manager.terminate()
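Reviewer note: a minimal usage sketch of the task routing that the new `League.submit` introduces — a `RolloutTask` is dispatched to the rollout manager, an `OptimizationTask` to the learner manager, and anything else raises a `ValueError`. The three managers and `strategy_specs` are assumed to be constructed as in `malib/scenarios/sarl_scenario.py`; the stopping-condition values are illustrative only.

```python
from malib.backend.league import League
from malib.common.task import OptimizationTask, RolloutTask

# Assumed setup: the managers are built elsewhere (see sarl_scenario.execution_plan).
league = League(learner_manager, rollout_manager, inference_manager)

# RolloutTask -> rollout manager; wait=True blocks until workers report back.
rollout_results = league.submit(
    RolloutTask(
        strategy_specs=strategy_specs,
        stopping_conditions={"max_iteration": 10},
    ),
    wait=True,
)

# OptimizationTask -> learner manager.
training_results = league.submit(
    OptimizationTask(
        active_agents=["default"],
        stop_conditions={"max_iteration": 100},
    ),
    wait=True,
)
```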
diff --git a/malib/common/manager.py b/malib/common/manager.py index 80d394f7..87e4c0ca 100644 --- a/malib/common/manager.py +++ b/malib/common/manager.py @@ -20,8 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import traceback -from typing import List, Generator, Any +from typing import List, Any from abc import abstractmethod, ABC import ray @@ -52,8 +51,18 @@ def namespace(self) -> str: def workers(self) -> List[RemoteInterface]: raise NotImplementedError + @abstractmethod def retrive_results(self): - raise NotImplementedError + """Retrieve execution results.""" + + @abstractmethod + def submit(self, task: Any, wait: bool = False) -> Any: + """Submit task to workers. + + Args: + task (Any): Task description. + wait (bool, optional): Block or not. Defaults to False. + """ def wait(self) -> List[Any]: """Wait workers to be terminated, and retrieve the executed results. diff --git a/malib/common/task.py b/malib/common/task.py index e59df234..1155e967 100644 --- a/malib/common/task.py +++ b/malib/common/task.py @@ -12,9 +12,12 @@ class TaskType(IntEnum): OPTIMIZATION = 2 +class Task: + pass + + @dataclass -class RolloutTask: - task_type: int +class RolloutTask(Task): strategy_specs: Dict[str, Any] = field(default_factory=dict()) stopping_conditions: Dict[str, Any] = field(default_factory=dict()) data_entrypoint_mapping: Dict[str, Any] = field(default_factory=dict()) @@ -32,7 +35,7 @@ def from_raw( @dataclass -class OptimizationTask: +class OptimizationTask(Task): stop_conditions: Dict[str, Any] """stopping conditions for optimization task, e.g., max iteration, max time, etc.""" diff --git a/malib/learner/learner.py b/malib/learner/learner.py index da87e2f4..56fc4114 100644 --- a/malib/learner/learner.py +++ b/malib/learner/learner.py @@ -33,7 +33,7 @@ import torch import ray -from ray.util.queue import Queue +from gym import spaces from torch.utils import tensorboard from torch.utils.data import DataLoader @@ -42,10 +42,11 @@ from malib.utils.tianshou_batch import Batch from malib.utils.monitor import write_to_tensorboard from malib.remote.interface import RemoteInterface -from malib.rl.common.trainer import Trainer from malib.common.task import OptimizationTask from malib.common.strategy_spec import StrategySpec from malib.backend.dataset_server.data_loader import DynamicDataset +from malib.rl.common.trainer import Trainer +from malib.rl.common.policy import Policy class Learner(RemoteInterface, ABC): @@ -57,8 +58,8 @@ def __init__( experiment_tag: str, runtime_id: str, log_dir: str, - observation_space: gym.Space, - action_space: gym.Space, + observation_space: spaces.Space, + action_space: spaces.Space, algorithms: Dict[str, Tuple[Type, Type, Dict, Dict]], agent_mapping_func: Callable[[AgentID], str], governed_agents: Tuple[AgentID], @@ -109,24 +110,32 @@ def __init__( self._strategy_spec = strategy_spec self._agent_mapping_func = agent_mapping_func self._custom_config = custom_config + self._policy = strategy_spec.gen_policy(device=device) self._summary_writer = tensorboard.SummaryWriter(log_dir=log_dir) self._trainer_config = trainer_config + + # load policy for trainer + self._trainer: Trainer = algorithms["default"][1](trainer_config, self._policy) self._total_step = 0 self._total_epoch = 0 - self._trainer: Trainer = algorithms["default"][1](trainer_config) - self._policies = {} dataset = dataset or self.create_dataset() self._data_loader = DataLoader(dataset, batch_size=trainer_config["batch_size"]) - self._active_tups = deque() 
- self._verbose = verbose @property def verbose(self) -> bool: return self._verbose + @property + def strategy_spec(self) -> StrategySpec: + return self._strategy_spec + + @property + def policy(self) -> Policy: + return self._policy + @property def data_loader(self) -> DataLoader: return self._data_loader @@ -151,9 +160,28 @@ def device(self) -> Union[str, torch.DeviceObjType]: return self._device + @property + def trainer(self) -> Trainer: + return self._trainer + + def get_data_entrypoint(self) -> str: + return self.data_loader.dataset.entrypoint + + def get_strategy_spec(self) -> StrategySpec: + return self._strategy_spec + + def get_state_dict(self) -> Dict[str, torch.Tensor]: + return self.policy.state_dict(device="cpu") + + @abstractmethod def create_dataset(self) -> DynamicDataset: - raise NotImplementedError + """Create the dataset. + + Returns: + DynamicDataset: Must be a subclass instance of DynamicDataset. + """ + @abstractmethod def add_policies(self, n: int) -> StrategySpec: """Construct `n` new policies and return the latest strategy spec. @@ -164,61 +192,6 @@ StrategySpec: The latest strategy spec instance. """ - for _ in range(n): - spec_pid = f"policy-{len(self._strategy_spec)}" - self._strategy_spec.register_policy_id(policy_id=spec_pid) - policy = self._strategy_spec.gen_policy() - policy_id = f"{self._strategy_spec.id}/{spec_pid}" - self._policies[policy_id] = policy - # active tups store the policy info tuple for training, the - # the data request relies on it. - self._active_tups.append((self._strategy_spec.id, spec_pid)) - self._trainer.reset(policy_instance=policy) - - ray.get(self._parameter_server.create_table.remote(self._strategy_spec)) - ray.get( - self._parameter_server.set_weights.remote( - spec_id=self._strategy_spec.id, - spec_policy_id=spec_pid, - state_dict=policy.state_dict(), - ) - ) - - return self._strategy_spec - - def push(self): - """Push local weights to remote server""" - - pending_tasks = [] - for spec_pid in self._strategy_spec.policy_ids: - pid = f"{self._strategy_spec.id}/{spec_pid}" - task = self._parameter_server.set_weights.remote( - spec_id=self._strategy_spec.id, - spec_policy_id=spec_pid, - state_dict=self._policies[pid].state_dict(), - ) - pending_tasks.append(task) - while len(pending_tasks) > 0: - dones, pending_tasks = ray.wait(pending_tasks) - - def pull(self): - """Pull remote weights to update local version.""" - - pending_tasks = [] - - for spec_pid in self._strategy_spec.policy_ids: - pid = f"{self._strategy_spec.id}/{spec_pid}" - task = self._parameter_server.get_weights.remote( - spec_id=self._strategy_spec.id, spec_policy_id=spec_pid - ) - pending_tasks.append(task) - - while len(pending_tasks) > 0: - dones, pending_tasks = ray.wait(pending_tasks) - for done in ray.get(dones): - pid = "{}/{}".format(done["spec_id"], done["spec_policy_id"]) - self._policies[pid].load_state_dict(done["weights"]) - @abstractmethod def multiagent_post_process( self, @@ -246,21 +219,8 @@ def get_interface_state(self) -> Dict[str, Any]: "total_step": self._total_step, "total_epoch": self._total_epoch, "policy_num": len(self._strategy_spec), - "active_tups": list(self._active_tups), } - def sync_remote_parameters(self): - """Push latest network parameters of active policies to remote parameter server.""" - - top_active_tup = self._active_tups[0] - ray.get( - self._parameter_server.set_weights.remote( - spec_id=top_active_tup[0], - spec_policy_id=top_active_tup[1], -
state_dict=self._trainer.policy.state_dict(device="cpu"), - ) - ) - def train(self, task: OptimizationTask) -> Dict[str, Any]: """Executes a optimization task and returns the final interface state. @@ -272,25 +232,21 @@ def train(self, task: OptimizationTask) -> Dict[str, Any]: Dict[str, Any]: A dict that describes the final state. """ - # XXX(ming): why we need to reset the state here? I think it is not necessary as - # an optimization task should be independent with other tasks. - self.set_running(True) try: while self.is_running(): for data in self.data_loader: batch_info = self.multiagent_post_process(data) - step_info_list = self._trainer(batch_info) + step_info_list = self.trainer(batch_info) for step_info in step_info_list: self._total_step += 1 write_to_tensorboard( self._summary_writer, info=step_info, global_step=self._total_step, - prefix=f"Training/{self._runtime_id}", + prefix=f"Learner/{self._runtime_id}", ) - self.sync_remote_parameters() self._total_epoch += 1 except Exception as e: Logger.warning( diff --git a/malib/learner/manager.py b/malib/learner/manager.py index c51e18b8..ea15b225 100644 --- a/malib/learner/manager.py +++ b/malib/learner/manager.py @@ -53,11 +53,11 @@ DEFAULT_RESOURCE_CONFIG = dict( - num_cpus=None, num_gpus=None, memory=None, object_store_memory=None, resources=None + num_cpus=None, num_gpus=None, memory=None, resources=None ) -class TrainingManager(Manager): +class LearnerManager(Manager): def __init__( self, experiment_tag: str, @@ -68,12 +68,11 @@ def __init__( group_info: Dict[str, Any], training_config: Union[Dict[str, Any], TrainingConfig], log_dir: str, - remote_mode: bool = True, resource_config: Dict[str, Any] = None, ray_actor_namespace: str = "learner", verbose: bool = True, ): - """Create an TrainingManager instance which is responsible for the multi agent training + """Create an LearnerManager instance which is responsible for the multi agent training tasks execution and rollout task requests sending. Args: @@ -86,7 +85,6 @@ def __init__( training_config (Dict[str, Any]): Training configuration, for agent interface, keys include \ `type`, `trainer_config` and `custom_config`. log_dir (str): Directory for logging. - remote_mode (bool, Optional): Init learners as remote actor or not. Default is True. 
""" super().__init__(verbose=verbose, namespace=ray_actor_namespace) @@ -96,7 +94,7 @@ def __init__( # interface config give the agent type used here and the group mapping if needed - # FIXME(ming): resource configuration is not available now, will open in the next version + # FIXME(ming): resource configuration is not available now, will turn-on in the next version if training_config.trainer_config.get("use_cuda", False): num_gpus = 1 / len(group_info["agent_groups"]) else: @@ -107,18 +105,19 @@ def __init__( learner_cls = training_config.learner_type # update num gpus resource_config["num_gpus"] = num_gpus - learner_cls = learner_cls.as_remote(**resource_config).options( - max_concurrency=10 - ) - learners: Dict[str, Union[Learner, ray.ObjectRef]] = {} + learner_cls = learner_cls.as_remote(**resource_config) + learners: Dict[str, ray.ObjectRef] = {} assert ( "training" in stopping_conditions ), f"Stopping conditions should contains `training` stoppong conditions: {stopping_conditions}" + ready_check = [] + for rid, agents in group_info["agent_groups"].items(): - _cls = learner_cls.remote if remote_mode else learner_cls - learners[rid] = _cls( + learners[rid] = learner_cls.options( + name=f"learner_{rid}", max_concurrency=10, namespace=self.namespace + ).remote( experiment_tag=experiment_tag, runtime_id=rid, log_dir=f"{log_dir}/learner_{rid}", @@ -131,11 +130,22 @@ def __init__( custom_config=training_config.custom_config, verbose=verbose, ) + ready_check.append(learners[rid].ready.remote()) # ensure all interfaces have been started up - tasks = list(learners.values()) - while len(tasks): - _, tasks = ray.wait(tasks, num_returns=1, timeout=1) + while len(ready_check): + _, ready_check = ray.wait(ready_check, num_returns=1, timeout=1) + + data_entrypoints = ray.get( + [x.get_data_entrypoint.remote() for x in learners.values()] + ) + self._data_entrypoints = dict(zip(learners.keys(), data_entrypoints)) + self._learner_entrypoints = dict( + zip( + learners.keys(), + [f"{self.namespace}:learner_{rid}" for rid in learners.keys()], + ) + ) # TODO(ming): collect data entrypoints from learners self._group_info = group_info @@ -146,7 +156,6 @@ def __init__( self._log_dir = log_dir self._agent_mapping_func = agent_mapping_func self._learners = learners - self._remote_mode = remote_mode self._thread_pool = ThreadPoolExecutor(max_workers=len(learners)) self._stopping_conditions = stopping_conditions @@ -165,14 +174,18 @@ def agent_groups(self) -> Dict[str, Set[AgentID]]: return self._group_info["agent_groups"] @property - def get_data_entrypoints(self) -> Dict[str, str]: + def data_entrypoints(self) -> Dict[str, str]: """Return a dict of data entrypoints, maps from runtime ids to data entrypoints. Returns: Dict[str, str]: A dict of data entrypoints. 
""" - return {rid: rid for rid in self._runtime_ids} + return self._data_entrypoints + + @property + def learner_entrypoints(self) -> Dict[str, str]: + return self._learner_entrypoints @property def workers(self) -> List[RemoteInterface]: @@ -194,9 +207,6 @@ def runtime_ids(self) -> Tuple[str]: return self._runtime_ids - def get_data_entrypoint_mapping(self) -> Dict[AgentID, str]: - raise NotImplementedError - def add_policies( self, interface_ids: Sequence[str] = None, n: Union[int, Dict[str, int]] = 1 ) -> Dict[str, Type[StrategySpec]]: @@ -217,21 +227,15 @@ def add_policies( policy_nums = dict.fromkeys(interface_ids, n) if isinstance(n, int) else n - if self._remote_mode: - strategy_spec_list: List[StrategySpec] = ray.get( - [ - self._learners[k].add_policies.remote(n=policy_nums[k]) - for k in interface_ids - ] - ) - strategy_spec_dict: Dict[str, StrategySpec] = dict( - zip(interface_ids, strategy_spec_list) - ) - else: - strategy_spec_dict = { - k: self._learners[k].add_policies(n=policy_nums[k]) + strategy_spec_list: List[StrategySpec] = ray.get( + [ + self._learners[k].add_policies.remote(n=policy_nums[k]) for k in interface_ids - } + ] + ) + strategy_spec_dict: Dict[str, StrategySpec] = dict( + zip(interface_ids, strategy_spec_list) + ) return strategy_spec_dict @@ -249,11 +253,8 @@ def submit(self, task: OptimizationTask): raise RuntimeError(f"Agent {aid} is not registered in training manager") else: learner = self._learners[rid] - if self._remote_mode: - ray_task = learner.train.remote(task) - self.pending_tasks.append(ray_task) - else: - raise NotImplementedError + ray_task = learner.train.remote(task) + self.pending_tasks.append(ray_task) def retrive_results(self) -> Generator: """Return a generator of results. @@ -262,36 +263,18 @@ def retrive_results(self) -> Generator: Generator: A generator for task results. """ - if self._remote_mode: - while len(self.pending_tasks) > 0: - dones, self.pending_tasks = ray.wait(self.pending_tasks) - for done in ray.get(dones): - yield done - else: - for task in self.pending_tasks: - assert isinstance(task, Future) - try: - if task.done(): - yield task.result(timeout=10) - except TimeoutError: - Logger.error( - f"Retrieving results of training task is timeout: {traceback.format_exc()}" - ) - except CancelledError: - Logger.error( - f"Try to retrieve results of a cancelled task: {traceback.format_exc()}" - ) - except Exception: - Logger.error(traceback.format_exc()) + while len(self.pending_tasks): + dones, self.pending_tasks = ray.wait(self.pending_tasks) + for done in ray.get(dones): + yield done def terminate(self) -> None: """Terminate all training actors.""" super().terminate() - if self._remote_mode: - for x in self._learners.values(): - ray.kill(x) + for x in self._learners.values(): + ray.kill(x) self._thread_pool.shutdown() del self._learners diff --git a/malib/models/model_client.py b/malib/models/model_client.py index 083b7092..e239fd45 100644 --- a/malib/models/model_client.py +++ b/malib/models/model_client.py @@ -31,15 +31,9 @@ def __init__(self, entry_point: str, model_config: ModelConfig): NotImplementedError: Unsupported cluster type. 
""" - cluster_type, name_or_address = entry_point.split(":") + namespace, name = entry_point.split(":") - if "ray" in cluster_type: - self.client = ray.get_actor(name_or_address) - else: - raise NotImplementedError - - self.cluster_type = cluster_type - self.server_address = name_or_address + self.client = ray.get_actor(name=name, namespace=namespace) self.thread_pool = futures.ThreadPoolExecutor(max_workers=10) self.event = threading.Event() @@ -59,10 +53,11 @@ def critic(self, *args, **kwargs): def _model_update(self, event: threading.Event): while not event.is_set(): - # TODO(ming): update model from remote server try: - state_dict = load_state_dict(self.client) - + state_dict = load_state_dict( + ray.get(self.client.get_state_dict.remote(), timeout=10) + ) + self.model.load_state_dict(state_dict) event.wait(0.5) except TimeoutError: # TODO(ming): count or reconnect diff --git a/malib/remote/interface.py b/malib/remote/interface.py index 43828237..4664ca4c 100644 --- a/malib/remote/interface.py +++ b/malib/remote/interface.py @@ -39,7 +39,6 @@ def as_remote( num_cpus: int = None, num_gpus: int = None, memory: int = None, - object_store_memory: int = None, resources: dict = None, ) -> type: """Return a remote class for Actor initialization""" @@ -48,10 +47,13 @@ def as_remote( num_cpus=num_cpus, num_gpus=num_gpus, memory=memory, - object_store_memory=object_store_memory, resources=resources, )(cls) + def ready(self): + """For initialization checking. Always return True.""" + return True + def stop_pending_tasks(self): """External object can call this method to stop all pending tasks.""" diff --git a/malib/rl/config.py b/malib/rl/config.py new file mode 100644 index 00000000..03e4c6d4 --- /dev/null +++ b/malib/rl/config.py @@ -0,0 +1,16 @@ +from typing import Dict, Any + +from dataclasses import dataclass + +from malib.rl.common.policy import Policy +from malib.rl.common.trainer import Trainer + + +@dataclass +class Algorithm: + + policy: Policy + + trainer: Trainer + + model_config: Dict[str, Any] diff --git a/malib/rollout/envs/env.py b/malib/rollout/envs/env.py index 2325322c..81492a77 100644 --- a/malib/rollout/envs/env.py +++ b/malib/rollout/envs/env.py @@ -22,6 +22,7 @@ from typing import Dict, List, Any, Union, Tuple, Sequence +import copy import uuid import gym import numpy as np @@ -55,6 +56,9 @@ def __init__(self, **configs): self._current_players = [] self._state: Dict[str, np.ndarray] = None self._deactivated = True + self._agents = self.register_agents() + self._observation_spaces = self.register_observation_spaces() + self._action_spaces = self.register_action_spaces() def record_episode_info_step( self, @@ -87,27 +91,40 @@ def record_episode_info_step( self.episode_metrics["episode_reward"] += sum(rewards.values()) @property - def possible_agents(self) -> List[AgentID]: + def configs(self): + return copy.deepcopy(self._configs) + + @property + def possible_agents(self) -> Tuple[AgentID]: """Return a list of environment agent ids""" - raise NotImplementedError + return tuple(self._agents) @property def observation_spaces(self) -> Dict[AgentID, gym.Space]: """A dict of agent observation spaces""" - raise NotImplementedError + return self._observation_spaces @property def action_spaces(self) -> Dict[AgentID, gym.Space]: """A dict of agent action spaces""" - raise NotImplementedError + return self._action_spaces @property def is_deactivated(self) -> bool: return self._deactivated + def register_observation_spaces(self): + raise NotImplementedError + + def 
register_action_spaces(self): + raise NotImplementedError + + def register_agents(self): + raise NotImplementedError + def deactivate(self): self._deactivated = True diff --git a/malib/rollout/envs/mdp/env.py b/malib/rollout/envs/mdp/env.py index 6b8f7fb3..e97c99ee 100644 --- a/malib/rollout/envs/mdp/env.py +++ b/malib/rollout/envs/mdp/env.py @@ -9,7 +9,6 @@ class MDPEnvironment(Environment): def __init__(self, **configs): - super().__init__(**configs) try: from blackhc import mdp @@ -35,18 +34,16 @@ def __init__(self, **configs): ) self.env = scenarios[env_id]().to_env() - self._possible_agents = ["default"] - @property - def possible_agents(self) -> List[AgentID]: - return self._possible_agents + super().__init__(**configs) + + def register_agents(self): + return ["default"] - @property - def observation_spaces(self) -> Dict[AgentID, gym.Space]: + def register_observation_spaces(self): return dict.fromkeys(self.possible_agents, self.env.observation_space) - @property - def action_spaces(self) -> Dict[AgentID, gym.Space]: + def register_action_spaces(self): return dict.fromkeys(self.possible_agents, self.env.action_space) def time_step( diff --git a/malib/rollout/envs/random/__init__.py b/malib/rollout/envs/random/__init__.py new file mode 100644 index 00000000..92e3686e --- /dev/null +++ b/malib/rollout/envs/random/__init__.py @@ -0,0 +1,36 @@ +# MIT License + +# Copyright (c) 2021 MARL @ SJTU + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +from .env import RandomEnv + + +def env_desc_gen(**config): + env = RandomEnv(**config) + env_desc = { + "creator": RandomEnv, + "possible_agents": env.possible_agents, + "action_spaces": env.action_spaces, + "observation_spaces": env.observation_spaces, + "config": config, + } + env.close() + return env_desc diff --git a/malib/rollout/envs/random/env.py b/malib/rollout/envs/random/env.py new file mode 100644 index 00000000..72ec8aa3 --- /dev/null +++ b/malib/rollout/envs/random/env.py @@ -0,0 +1,62 @@ +from typing import Any, Dict, List, Sequence, Tuple, Union + +import gym +import random + +from gym import spaces + +from malib.rollout.envs.env import Environment +from malib.utils.typing import AgentID + + +class RandomEnv(Environment): + def __init__(self, **configs): + assert "num_agents" in configs + super().__init__(**configs) + + def register_agents(self): + # use a list to keep agent ordering deterministic + return [f"agent_{i}" for i in range(self.configs["num_agents"])] + + def register_observation_spaces(self): + return { + agent: spaces.Box(low=-1, high=1, shape=(2,)) + for agent in self.possible_agents + } + + def register_action_spaces(self): + return {agent: spaces.Discrete(4) for agent in self.possible_agents} + + def get_state(self) -> Any: + return None + + def reset(self, max_step: int = None): + super().reset(max_step) + obs = {k: v.sample() for k, v in self.observation_spaces.items()} + return self.get_state(), obs + + def time_step( + self, actions: Dict[AgentID, Any] + ) -> Tuple[ + Dict[AgentID, Any], + Dict[AgentID, Any], + Dict[AgentID, float], + Dict[AgentID, bool], + Dict[AgentID, Any], + ]: + # check that each action is contained in its action space + for k, v in actions.items(): + _space = self.action_spaces[k] + assert _space.contains(v), (k, v, _space) + obs = {k: v.sample() for k, v in self.observation_spaces.items()} + rews = {k: random.random() for k in self.possible_agents} + state = self.get_state() + + return ( + state, + obs, + rews, + {k: False for k in self.possible_agents}, + {k: {} for k in self.possible_agents}, + ) + + def close(self): + pass diff --git a/malib/rollout/inference/client.py b/malib/rollout/inference/client.py index 4892ecf3..82715e71 100644 --- a/malib/rollout/inference/client.py +++ b/malib/rollout/inference/client.py @@ -34,10 +34,6 @@ import numpy as np from malib.remote.interface import RemoteInterface -from malib.utils.typing import AgentID, DataFrame -from malib.utils.timing import Timing -from malib.utils.episode import Episode -from malib.common.strategy_spec import StrategySpec from malib.models.config import ModelConfig from malib.rl.common.policy import Policy, PolicyReturn @@ -50,7 +46,7 @@ class InferenceClient(RemoteInterface): def __init__( self, - entry_point: str, + model_entry_point: str, policy_cls: Type, observation_space: gym.Space, action_space: gym.Space, @@ -74,7 +70,7 @@ def __init__( observation_space, action_space, model_config, - model_entry_point=entry_point, + model_entry_point=model_entry_point, ) def shutdown(self): diff --git a/malib/rollout/inference/env_runner.py b/malib/rollout/inference/env_runner.py index 29c7bdbf..bc28882e 100644 --- a/malib/rollout/inference/env_runner.py +++ b/malib/rollout/inference/env_runner.py @@ -132,7 +132,7 @@ def __init__( max_env_num: int, use_subproc_env: bool = False, agent_groups: Dict[str, Set] = None, - inferenc_client_namespace: str = None, + inference_entry_points: Dict[str, str] = None, ) -> None: super().__init__() @@ -141,7 +141,7 @@ def __init__( self._env_func = env_func self._envs = [] self._agent_groups = agent_groups -
self._infer_client_namespace = inferenc_client_namespace + self._inference_entry_points = inference_entry_points self._inference_clients = None @property @@ -192,7 +192,7 @@ def run( if inference_clients is None: assert ( - self._infer_client_namespace is not None + self._inference_entry_points is not None ), "Inference client namespace should be specified if infer_clients is not given." assert ( self._agent_groups is not None @@ -200,9 +200,8 @@ if self.inference_clients is None: res = {} for rid, _agents in self._agent_groups.items(): - client = ray.get_actor( - f"inference_{rid}", namespace=self._infer_client_namespace - ) + namespace, name = self._inference_entry_points[rid].split(":") + client = ray.get_actor(name=name, namespace=namespace) res.update(dict.fromkeys(_agents, client)) self._inference_clients = res inference_clients = self.inference_clients diff --git a/malib/rollout/inference/manager.py b/malib/rollout/inference/manager.py index bd73e8c1..edac1189 100644 --- a/malib/rollout/inference/manager.py +++ b/malib/rollout/inference/manager.py @@ -1,8 +1,9 @@ -from typing import Dict, Set +from typing import Any, Dict, Set import ray from malib.common.manager import Manager +from malib.rl.config import Algorithm from malib.scenarios import Scenario from malib.rollout.inference.client import InferenceClient @@ -12,32 +13,60 @@ def __init__( self, group_info: Dict[str, Set], ray_actor_namespace: str, - entrypoints: Dict[str, str], - scenario: Scenario, + model_entry_point: Dict[str, str], + algorithm: Algorithm, verbose: bool = False, ): super().__init__(verbose, namespace=ray_actor_namespace) - inference_remote_cls = InferenceClient.as_remote(num_cpus=1).options( - namespace=self.namespace - ) + inference_remote_cls = InferenceClient.as_remote(num_cpus=1) obs_spaces = group_info["observation_space"] act_spaces = group_info["action_space"] agent_groups = group_info["agent_groups"] - self.infer_clients = {} + self._infer_clients = {} + self._inference_entry_points = {} + # FIXME(Ming): for debug only + model_entry_point = model_entry_point or { + rid: None for rid in agent_groups.keys() + } + + infer_client_ready_check = [] for rid, _ in agent_groups.items(): - self.infer_clients[rid] = inference_remote_cls.options( - name=f"inference_{rid}" + actor_name = f"inference_{rid}" + self._infer_clients[rid] = inference_remote_cls.options( + namespace=self.namespace, name=actor_name ).remote( - entry_point=entrypoints[rid], - policy_cls=scenario.algorithms[rid].policy_cls, + model_entry_point=model_entry_point[rid], + policy_cls=algorithm.policy, observation_space=obs_spaces[rid], action_space=act_spaces[rid], - model_config=scenario.training_config["model_config"], + model_config=algorithm.model_config, + ) + infer_client_ready_check.append(self._infer_clients[rid].ready.remote()) + self._inference_entry_points[rid] = "{}:{}".format( + self.namespace, actor_name ) # check ready - tasks = list(self.infer_clients.values()) - while len(tasks): - _, tasks = ray.wait(tasks, num_returns=1, timeout=1) + while len(infer_client_ready_check): + _, infer_client_ready_check = ray.wait( + infer_client_ready_check, num_returns=1, timeout=1 + ) + + def get_inference_client(self, runtime_id: str) -> InferenceClient: + return self.inference_clients[runtime_id] + + @property + def inference_clients(self) -> Dict[str, ray.actor.ActorHandle]: + return self._infer_clients + + @property + def inference_entry_points(self) -> Dict[str, str]: + return self._inference_entry_points + + def submit(self, task: Any, wait: bool
= False): + pass + + def retrive_results(self): + pass diff --git a/malib/rollout/manager.py b/malib/rollout/manager.py index e3a99274..e669a426 100644 --- a/malib/rollout/manager.py +++ b/malib/rollout/manager.py @@ -75,7 +75,6 @@ def validate_strategy_specs(specs: Dict[str, StrategySpec]): class RolloutWorkerManager(Manager): def __init__( self, - experiment_tag: str, stopping_conditions: Dict[str, Any], num_worker: int, group_info: Dict[str, Any], @@ -89,7 +88,6 @@ def __init__( """Construct a manager for multiple rollout workers. Args: - experiment_tag (str): Experiment tag. num_worker (int): Indicates how many rollout workers will be initialized. rollout_config (Dict[str, Any]): Runtime rollout configuration. env_desc (Dict[str, Any]): Environment description. @@ -101,15 +99,14 @@ def __init__( super().__init__(verbose=verbose, namespace=ray_actor_namespace) rollout_worker_cls = PBRolloutWorker - worker_cls = rollout_worker_cls.as_remote(num_cpus=0, num_gpus=0).options( - namespace=self.namespace - ) + worker_cls = rollout_worker_cls.as_remote(num_cpus=0, num_gpus=0).options() workers = [] - + ready_check = [] for i in range(num_worker): workers.append( - worker_cls.options(max_concurrency=100, name=f"actor_{i}").remote( - experiment_tag=experiment_tag, + worker_cls.options( + max_concurrency=100, namespace=self.namespace, name=f"actor_{i}" + ).remote( env_desc=env_desc, agent_groups=group_info["agent_groups"], rollout_config=RolloutConfig.from_raw(rollout_config), @@ -120,12 +117,15 @@ def __init__( verbose=verbose, ) ) + ready_check.append(workers[-1].ready.remote()) + + while len(ready_check): + _, ready_check = ray.wait(ready_check, num_returns=1, timeout=1) - self._workers: List[ray.actor] = workers + self._workers: List[ray.ObjectRef] = workers self._actor_pool = ActorPool(self._workers) self._runtime_ids = tuple(group_info["agent_groups"].keys()) self._group_info = group_info - self.experiment_tag = experiment_tag assert ( "rollout" in stopping_conditions @@ -163,7 +163,9 @@ def workers(self) -> List[RemoteInterface]: return self._workers - def submit(self, task: Union[Dict[str, Any], List[Dict[str, Any]]]): + def submit( + self, task: Union[Dict[str, Any], List[Dict[str, Any]]], wait: bool = False + ) -> Any: """Submit a task to workers Args: @@ -180,6 +182,12 @@ def submit(self, task: Union[Dict[str, Any], List[Dict[str, Any]]]): validate_strategy_specs(_task.strategy_specs) self._actor_pool.submit(lambda actor, _task: actor.rollout.remote(_task)) + if wait: + result_list = self.wait() + return result_list + else: + return None + def retrive_results(self): """Retrieve task results diff --git a/malib/rollout/pb_rolloutworker.py b/malib/rollout/pb_rolloutworker.py index 7433edc2..deb15a83 100644 --- a/malib/rollout/pb_rolloutworker.py +++ b/malib/rollout/pb_rolloutworker.py @@ -49,7 +49,6 @@ def step_rollout( data_entrypoint_mapping=data_entrypoint_mapping, ) ) - # check evaluation info parsed_results = parse_rollout_info(results) Logger.debug(f"parsed results: {parsed_results}") diff --git a/malib/rollout/rollout_config.py b/malib/rollout/rollout_config.py index c569376c..4e462d6b 100644 --- a/malib/rollout/rollout_config.py +++ b/malib/rollout/rollout_config.py @@ -5,8 +5,6 @@ @dataclass class RolloutConfig: - inference_server_type: str - """Inference server type""" num_workers: int = 1 """Defines how many workers will be used for executing one rollout task, default is 1""" @@ -23,6 +21,8 @@ class RolloutConfig: timelimit: int = 256 """Specifying how many time steps will be 
collected for each rollout, default is 256""" + inference_entry_points: Dict[str, str] = field(default_factory=dict) + @classmethod def from_raw( cls, config: Union["RolloutConfig", Dict[str, Any]] diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py index 4284e29f..aa35ac44 100644 --- a/malib/rollout/rolloutworker.py +++ b/malib/rollout/rolloutworker.py @@ -76,49 +76,18 @@ def parse_rollout_info(raw_statistics: List[Dict[str, Any]]) -> Dict[str, Any]: Dict[str, Any]: A merged dict. """ - results = {"total_timesteps": 0, "FPS": 0.0} - evaluation = [] - - for e in raw_statistics: - # when task mode is `simualtion` or `evaluation`, then - # evaluation result is not empty. - if "evaluation" in e: - evaluation.extend(e["evaluation"]) - - for k, v in e.items(): - if k == "total_timesteps": - results[k] += v - elif k == "FPS": - results[k] += v - # else: - # raise ValueError(f"Unknow key: {k} / {v}") - - if len(evaluation) > 0: - raw_eval_results = defaultdict(lambda: []) - for e in evaluation: - for k, v in e.items(): - if isinstance(v, (Tuple, List)): - v = sum(v) - raw_eval_results[k].append(v) - eval_results = {} - for k, v in raw_eval_results.items(): - # convert v to array - eval_results.update( - {f"{k}_max": np.max(v), f"{k}_min": np.min(v), f"{k}_mean": np.mean(v)} - ) - results["evaluation"] = eval_results - return results + return raw_statistics def log(message: str): logger.log(settings.LOG_LEVEL, f"(rollout worker) {message}") -def default_rollout_callback(coordinator: ray.ObjectRef, results: Dict[str, Any]): +def default_rollout_callback(results: Dict[str, Any]): pass -def default_simulate_callback(coordinator: ray.ObjectRef, results: Dict[str, Any]): +def default_simulate_callback(results: Dict[str, Any]): pass @@ -147,9 +116,8 @@ def validate_runtime_configs(configs: Dict[str, Any]): class RolloutWorker(RemoteInterface): def __init__( self, - experiment_tag: str, env_desc: Dict[str, Any], - agent_groups: Dict[str, Set], + agent_groups: Dict[str, Tuple], rollout_config: Union[RolloutConfig, Dict[str, Any]], log_dir: str, rollout_callback: Callable[[ray.ObjectRef, Dict[str, Any]], Any] = None, @@ -166,7 +134,6 @@ def __init__( * `max_step`: int, the maximum step of each episode. * `num_eval_episodes`: int, the number of epsiodes for each evaluation. log_dir (str): Log directory. - experiment_tag (str): Experiment tag, to create a data table. rollout_callback (Callable[[ray.ObjectRef, Dict[str, Any]], Any], optional): Callback function for rollout task, users can determine how \ to cordinate with coordinator here. Defaults by None, indicating no coordination. 
simulate_callback (Callable[[ray.ObjectRef, Dict[str, Any]], Any]): Callback function for simulation task, users can determine \ @@ -186,8 +153,6 @@ def __init__( self.agent_groups = agent_groups self.rollout_config = RolloutConfig.from_raw(rollout_config) - validate_runtime_configs(self.rollout_config) - # create environment runner, handling evaluation or rollout task env_runner_resource_config = resource_config["inference_server"] self.env_runner = self.create_env_runner( @@ -198,7 +163,6 @@ def __init__( self.rollout_callback = rollout_callback or default_rollout_callback self.simulate_callback = simulate_callback or default_simulate_callback self.tb_writer = tensorboard.SummaryWriter(log_dir=log_dir) - self.experiment_tag = experiment_tag self.verbose = verbose def create_env_runner( @@ -225,6 +189,8 @@ def create_env_runner( env_func=lambda: env_desc["creator"](**env_desc["config"]), max_env_num=rollout_config.n_envs_per_worker, use_subproc_env=rollout_config.use_subproc_env, + agent_groups=self.agent_groups, + inference_entry_points=rollout_config.inference_entry_points, ) return env_runner @@ -239,7 +205,6 @@ def rollout(self, task: RolloutTask): """ stopper = get_stopper(task.stopping_conditions) - active_agents = active_agents or self.env_agents total_timesteps = 0 eval_results = {} @@ -258,29 +223,27 @@ def rollout(self, task: RolloutTask): results = self.step_rollout( eval_step, task.strategy_specs, - self.rollout_config, task.data_entrypoint_mapping, ) - total_timesteps += results["total_timesteps"] - eval_results = results.get("evaluation", None) - - performance["rollout_iter_rate"] = (epoch + 1) / (time.time() - start_time) - performance["rollout_FPS"] = results["FPS"] - performance["ave_rollout_FPS"] = ( - performance["ave_rollout_FPS"] * epoch + results["FPS"] - ) / (epoch + 1) - - if eval_results is not None: - if self.verbose: - eval_results["performance"] = performance - formatted_results = pprint.pformat(eval_results) - Logger.info(f"Evaluation at epoch: {epoch}\n{formatted_results}") - write_to_tensorboard( - self.tb_writer, - eval_results, - global_step=total_timesteps, - prefix="Evaluation", - ) + # total_timesteps += results["total_timesteps"] + + # performance["rollout_iter_rate"] = (epoch + 1) / (time.time() - start_time) + # performance["rollout_FPS"] = results["FPS"] + # performance["ave_rollout_FPS"] = ( + # performance["ave_rollout_FPS"] * epoch + results["FPS"] + # ) / (epoch + 1) + + # if self.verbose: + # eval_results["performance"] = performance + # formatted_results = pprint.pformat(eval_results) + # Logger.info(f"Evaluation at epoch: {epoch}\n{formatted_results}") + + # write_to_tensorboard( + # self.tb_writer, + # results, + # global_step=total_timesteps, + # prefix="Evaluation", + # ) write_to_tensorboard( self.tb_writer, performance, global_step=epoch, prefix="Performance" @@ -291,7 +254,8 @@ def rollout(self, task: RolloutTask): break epoch += 1 - self.rollout_callback(self.coordinator, results) + self.rollout_callback(results) + return results @abstractmethod @@ -299,26 +263,12 @@ def step_rollout( self, eval_step: bool, strategy_specs: Dict[AgentID, StrategySpec], - rollout_config: Dict[str, Any], data_entrypoint_mapping: Dict[AgentID, str], ) -> List[Dict[str, Any]]: """The logic function to run rollout. Users must implment this method. Args: eval_step (bool): Indicate evaluation or not. - rollout_config (Dict[str, Any]): Runtime configurations to control the amount of sampled data. 
Keys include: - - `flag`: indicate the task type, the value is rollout. - - `max_step`: indicates the maximum length of an episode. - - `num_episodes`: indicates how many episodes will be collected. - - `policy_distribution`: a dict describes the policy distribution. - - `parameter_desc_dict`: a dict describes the parameter description. - - `trainable_pairs`: a dict describes the trainable policy configuration, it is a mapping from `runtime_ids` \ - to a tuple of policy id and policy configuration. - - `behavior_policies`: a dict maps runtime agents to policy ids, it specifies the behavior policy for available agents, \ - could be a subset of the full agent set. - - `agent_group`: a dict that maps runtime agents to a list of environment agents, which describes the envrionment agents \ - governed by what runtime agent interface. - - `fragment_length`: the maximum of collected data frames. data_entrypoint_mapping: ... Raises: diff --git a/malib/runner.py b/malib/runner.py deleted file mode 100644 index 380efb8c..00000000 --- a/malib/runner.py +++ /dev/null @@ -1,101 +0,0 @@ -# MIT License - -# Copyright (c) 2021 MARL @ SJTU - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import time -import ray - -from malib import settings -from malib.utils.logging import Logger -from malib.scenarios import marl_scenario, psro_scenario -from malib.scenarios.scenario import Scenario -from malib.backend.offline_dataset_server import OfflineDataset -from malib.backend.parameter_server import ParameterServer - - -def start_servers(data_table_capacity: int = 100000): - try: - offline_dataset_server = ( - OfflineDataset.as_remote(num_cpus=0) - .options(name=settings.OFFLINE_DATASET_ACTOR, max_concurrency=100) - .remote(table_capacity=data_table_capacity) - ) - ray.get(offline_dataset_server.start.remote()) - except ValueError: - Logger.warning("detected existing offline dataset server") - offline_dataset_server = ray.get_actor(settings.OFFLINE_DATASET_ACTOR) - - try: - parameter_server = ( - ParameterServer.as_remote(num_cpus=1) - .options(name=settings.PARAMETER_SERVER_ACTOR, max_concurrency=100) - .remote() - ) - ray.get(parameter_server.start.remote()) - except ValueError: - Logger.warning("detected exisitng parameter server") - parameter_server = ray.get_actor(settings.PARAMETER_SERVER_ACTOR) - - return parameter_server, offline_dataset_server - - -def run(scenario: Scenario, cluster_address: str = "auto"): - """Load scenario to the execution plan and lauch a cluster. 
The instance will search an active \ - cluster by default. Users can also determine the specified cluster with given `cluster_address`. - - Args: - scenario (Scenario): Scenario instance. - cluster_address (str, optional): Ray cluster address. Defaults to "auto", which means the \ - training instance will search an active cluster. - - Raises: - TypeError: Unexpected scenario type. - """ - - try: - start_ray_info = ray.init(address="auto", dashboard_port=8265) - except ConnectionError: - Logger.warning("No active cluster deteced, will create a local ray instance.") - start_ray_info = ray.init() - - try: - Logger.info("Ray lauched: {}".format(start_ray_info)) - Logger.info("Ray cluster resources info: {}".format(ray.cluster_resources())) - - parameter_server, offline_dataset_server = start_servers() - scenario.parameter_server = parameter_server - scenario.offline_dataset_server = offline_dataset_server - - experiment_tag = f"malib-{scenario.name}-{time.strftime('%Y-%m-%d-%H%M%S')}" - - if isinstance(scenario, psro_scenario.PSROScenario): - psro_scenario.execution_plan(experiment_tag, scenario) - elif isinstance(scenario, marl_scenario.MARLScenario): - marl_scenario.execution_plan(experiment_tag, scenario) - else: - raise TypeError("Unexpected scenario type: {}".format(scenario)) - except KeyboardInterrupt: - ray.shutdown() - except TypeError as e: - ray.shutdown() - raise e - except Exception as e: - raise e diff --git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py index 808df037..653f6cdc 100644 --- a/malib/scenarios/sarl_scenario.py +++ b/malib/scenarios/sarl_scenario.py @@ -23,14 +23,14 @@ # SOFTWARE. from typing import Dict, Any -from malib.common.task import OptimizationTask, RolloutTask +from malib.common.task import TaskType, OptimizationTask, RolloutTask from malib.scenarios import Scenario - +from malib.utils.stopping_conditions import StoppingCondition, get_stopper from malib.utils.logging import Logger from malib.backend.league import League -from malib.learner.manager import TrainingManager -from malib.rollout.manager import RolloutWorkerManager, TaskType +from malib.learner.manager import LearnerManager +from malib.rollout.manager import RolloutWorkerManager from malib.rollout.inference.manager import InferenceManager @@ -44,8 +44,6 @@ def __init__( training_config: Dict[str, Any], rollout_config: Dict[str, Any], stopping_conditions: Dict[str, Any], - dataset_config: Dict[str, Any], - parameter_server_config: Dict[str, Any], resource_config: Dict[str, Any] = None, ): super().__init__( @@ -57,16 +55,17 @@ def __init__( training_config, rollout_config, stopping_conditions, - dataset_config, - parameter_server_config, ) self.num_policy_each_interface = 1 self.resource_config = resource_config or {"training": None, "rollout": None} + def create_global_stopper(self) -> StoppingCondition: + return get_stopper(self.stopping_conditions) + def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = True): # TODO(ming): simplify the initialization of training and rollout manager with a scenario instance as input - training_manager = TrainingManager( + learner_manager = LearnerManager( experiment_tag=experiment_tag, stopping_conditions=scenario.stopping_conditions, algorithms=scenario.algorithms, @@ -81,8 +80,14 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = verbose=verbose, ) + inference_manager = InferenceManager( + group_info=scenario.group_info, + ray_actor_namespace="inference_{}".format(experiment_tag), + 
model_entry_point=learner_manager.learner_entrypoints, + algorithm=scenario.algorithms, + ) + rollout_manager = RolloutWorkerManager( - experiment_tag=experiment_tag, stopping_conditions=scenario.stopping_conditions, num_worker=scenario.num_worker, group_info=scenario.group_info, @@ -94,32 +99,22 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = verbose=verbose, ) - inference_manager = InferenceManager( - group_info=scenario.group_info, - ray_actor_namespace="inference_{}".format(experiment_tag), - entrypoints=training_manager.get_data_entrypoints(), - scenario=scenario, - ) - - league = League(rollout_manager, training_manager, inference_manager) - - # NOTE(ming): if all agents are active, the strategy specs should not contain any pids - strategy_specs = training_manager.add_policies(n=1) - Logger.info( - f"Training manager was inistialized with a strategy spec:\n{strategy_specs}" + league = League(learner_manager, rollout_manager, inference_manager) optimization_task = OptimizationTask( active_agents=scenario.env_desc["possible_agents"], stop_conditions=scenario.stopping_conditions["training"], ) - training_manager.submit(optimization_task) + + strategy_specs = learner_manager.get_strategy_specs() rollout_task = RolloutTask( task_type=TaskType.ROLLOUT, strategy_specs=strategy_specs, stopping_conditions=scenario.stopping_conditions["rollout"], - data_entrypoint_mapping=training_manager.get_data_entrypoint_mapping(), + data_entrypoint_mapping=learner_manager.data_entrypoints, ) evaluation_task = RolloutTask( @@ -127,8 +122,20 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = strategy_specs=strategy_specs, ) - rollout_manager.submit(rollout_task) - rollout_manager.submit(evaluation_task) + stopper = scenario.create_global_stopper() + epoch_cnt = 0 + + while True: + rollout_results = league.submit(rollout_task, wait=True) + training_results = league.submit(optimization_task, wait=True) + evaluation_results = league.submit(evaluation_task, wait=True) + epoch_cnt += 1 + if stopper.should_stop( + evaluation_results, + training_results=training_results, + rollout_results=rollout_results, + epoch_cnt=epoch_cnt, + ): + break + if epoch_cnt % scenario.save_interval == 0: + league.save_checkpoint(global_step=epoch_cnt) results = league.get_results() league.terminate() diff --git a/malib/scenarios/scenario.py b/malib/scenarios/scenario.py index f009a085..03af7069 100644 --- a/malib/scenarios/scenario.py +++ b/malib/scenarios/scenario.py @@ -22,13 +22,14 @@ from abc import ABC, abstractmethod from types import LambdaType -from typing import Callable, Union, Dict, Any, Set, List +from typing import Dict, Any, Set, Tuple from copy import deepcopy from collections import defaultdict import gym from malib.utils.typing import AgentID +from malib.utils.stopping_conditions import StoppingCondition DEFAULT_STOPPING_CONDITIONS = {} @@ -40,7 +41,7 @@ def validate_spaces(agent_groups: Dict[str, Set[AgentID]], env_desc: Dict[str, A def validate_agent_group( - agent_group: Dict[str, List[AgentID]], + agent_group: Dict[str, Tuple[AgentID]], observation_spaces: Dict[AgentID, gym.Space], action_spaces: Dict[AgentID, gym.Space], ) -> None: @@ -66,6 +67,23 @@ validate_agent_group assert select_act_space.shape == action_spaces[agent].shape +def form_group_info(env_desc, agent_mapping_func): + agent_groups = defaultdict(list) + grouped_obs_space = {} + grouped_act_space = {} + for agent in env_desc["possible_agents"]: + rid = agent_mapping_func(agent) +
agent_groups[rid].append(agent) + grouped_obs_space[rid] = env_desc["observation_spaces"][agent] + grouped_act_space[rid] = env_desc["action_spaces"][agent] + agent_groups = {k: tuple(v) for k, v in agent_groups.items()} + return { + "observation_space": grouped_obs_space, + "action_space": grouped_act_space, + "agent_groups": agent_groups, + } + + class Scenario(ABC): @abstractmethod def __init__( @@ -78,8 +96,6 @@ def __init__( training_config: Dict[str, Any], rollout_config: Dict[str, Any], stopping_conditions: Dict[str, Any], - dataset_config: Dict[str, Any], - parameter_server_config: Dict[str, Any], ): self.name = name self.log_dir = log_dir @@ -87,33 +103,23 @@ def __init__( self.algorithms = algorithms self.agent_mapping_func = agent_mapping_func # then generate grouping information here - agent_groups = defaultdict(lambda: set()) - grouped_obs_space = {} - grouped_act_space = {} - for agent in env_desc["possible_agents"]: - rid = agent_mapping_func(agent) - agent_groups[rid].add(agent) - grouped_obs_space[rid] = env_desc["observation_spaces"][agent] - grouped_act_space[rid] = env_desc["action_spaces"][agent] - self.group_info = { - "observation_space": grouped_obs_space, - "action_space": grouped_act_space, - "agent_groups": agent_groups, - } + self.group_info = form_group_info(env_desc, agent_mapping_func) validate_agent_group( - agent_groups, env_desc["observation_spaces"], env_desc["action_spaces"] + self.group_info["agent_groups"], + env_desc["observation_spaces"], + env_desc["action_spaces"], ) self.training_config = training_config self.rollout_config = rollout_config self.stopping_conditions = stopping_conditions or DEFAULT_STOPPING_CONDITIONS - self.dataset_config = dataset_config or {"table_capacity": 1000} - self.parameter_server_config = parameter_server_config or {} - self.parameter_server = None - self.offline_dataset_server = None def copy(self): return deepcopy(self) + @abstractmethod + def create_global_stopper(self) -> StoppingCondition: + """Create a global stopper.""" + def with_updates(self, **kwargs) -> "Scenario": new_copy = self.copy() for k, v in kwargs.items(): diff --git a/malib/utils/stopping_conditions.py b/malib/utils/stopping_conditions.py index 7a67b329..27584ede 100644 --- a/malib/utils/stopping_conditions.py +++ b/malib/utils/stopping_conditions.py @@ -33,17 +33,17 @@ class StoppingCondition(ABC): @abstractmethod - def should_stop(self, latest_trainer_result: dict, *args, **kwargs) -> bool: + def should_stop(self, results, **kwargs) -> bool: pass class NoStoppingCondition(StoppingCondition): - def should_stop(self, latest_trainer_result: dict, *args, **kwargs) -> bool: + def should_stop(self, results, **kwargs) -> bool: return False class StopImmediately(StoppingCondition): - def should_stop(self, latest_trainer_result: dict, *args, **kwargs) -> bool: + def should_stop(self, results, **kwargs) -> bool: return True @@ -51,10 +51,8 @@ class RewardImprovementStopping(StoppingCondition): def __init__(self, mininum_reward_improvement: float) -> None: self.minium_reward_improvement = mininum_reward_improvement - def should_stop(self, latest_trainer_result: dict, *args, **kwargs) -> bool: - reward_this_iter = latest_trainer_result.get( - "evaluation", {"episode_reward_mean": float("inf")} - )["episode_reward_mean"] + def should_stop(self, results, **kwargs) -> bool: + reward_this_iter = results.get("episode_reward_mean", float("inf")) if reward_this_iter == float("inf"): return False should_stop = False @@ -69,7 +67,7 @@ def __init__( self.max_iteration = 
max_iteration self.n_iteration = 0 - def should_stop(self, latest_trainer_result: dict, *args, **kwargs) -> bool: + def should_stop(self, results, **kwargs) -> bool: self.n_iteration += 1 should_stop = False @@ -87,8 +85,14 @@ def __init__(self, stoppings: List[StoppingCondition]) -> None: super().__init__() self.stoppings = stoppings - def should_stop(self, latest_trainer_result: dict, *args, **kwargs) -> bool: - stops = [e.should_stop(latest_trainer_result) for e in self.stoppings] + def should_stop(self, results, **kwargs) -> bool: + stops = [ + e.should_stop( + results, + **kwargs, + ) + for e in self.stoppings + ] return all(stops) diff --git a/tests/rollout/test_env_runner.py b/tests/rollout/test_env_runner.py index b3c368db..ddddaa3a 100644 --- a/tests/rollout/test_env_runner.py +++ b/tests/rollout/test_env_runner.py @@ -41,7 +41,6 @@ def test_env_runner(env_desc: Dict[str, Any], max_env_num: int): for agent in agents } rollout_config = RolloutConfig( - inference_server_type="ray", num_workers=1, eval_interval=1, n_envs_per_worker=10, @@ -51,7 +50,7 @@ def test_env_runner(env_desc: Dict[str, Any], max_env_num: int): infer_clients = { agent: inference_remote_cls.remote( - entry_point=None, + model_entry_point=None, policy_cls=RandomPolicy, observation_space=observation_spaces[agent], action_space=action_spaces[agent], diff --git a/tests/rollout/test_pb_rollout_worker.py b/tests/rollout/test_pb_rollout_worker.py index a553f315..a8652705 100644 --- a/tests/rollout/test_pb_rollout_worker.py +++ b/tests/rollout/test_pb_rollout_worker.py @@ -25,15 +25,20 @@ from typing import Dict, Any import pytest -import threading -import time import ray from pytest_mock import MockerFixture from gym import spaces -from malib.runner import start_servers -from malib.mocker.mocker_utils import FakeInferenceClient, FakeInferenceServer +from malib.common.task import RolloutTask +from malib.common.strategy_spec import StrategySpec +from malib.rl.random import RandomPolicy +from malib.rl.config import Algorithm +from malib.rollout.envs.random import env_desc_gen +from malib.rollout.rollout_config import RolloutConfig +from malib.rollout.pb_rolloutworker import PBRolloutWorker +from malib.rollout.inference.manager import InferenceManager +from malib.scenarios.scenario import form_group_info def gen_rollout_config(inference_server_type: str): @@ -52,93 +57,69 @@ def gen_rollout_config(inference_server_type: str): } -def create_rollout_worker( - mocker: MockerFixture, env_desc: Dict[str, Any], rollout_config: Dict[str, Any] -): - mocker.patch( - "malib.rollout.rolloutworker.RayInferenceClient", new=FakeInferenceClient - ) - mocker.patch( - "malib.rollout.rolloutworker.RayInferenceServer", new=FakeInferenceServer - ) - from malib.rollout.pb_rolloutworker import PBRolloutWorker - - worker = PBRolloutWorker( - experiment_tag="test_rollout_worker", - env_desc=env_desc, - agent_mapping_func=lambda agent: agent, - rollout_config=rollout_config, - log_dir="./logs", - ) - return worker - - @pytest.mark.parametrize("n_player", [1, 2]) -@pytest.mark.parametrize("inference_server_type", ["local", "ray"]) class TestRolloutWorker: - def test_rollout( - self, mocker: MockerFixture, n_player: int, inference_server_type: str - ): - if not ray.is_initialized(): - ray.init() - - parameter_server, dataset_server = start_servers() - - agents = [f"player_{i}" for i in range(n_player)] - - worker = create_rollout_worker( - mocker, - env_desc={ - "possible_agents": agents, - "observation_spaces": { - agent: spaces.Box(-1, 1.0, 
shape=(2,)) for agent in agents - }, - "action_spaces": { - agent: spaces.Box(-1, 1, shape=(2,)) for agent in agents - }, - }, - rollout_config=gen_rollout_config(inference_server_type), - ) - - data_entrypoints = {agent: agent for agent in agents} - results = worker.rollout( - None, - {"max_iteration": 2}, - data_entrypoints, - None, - ) - print("rollout results:", results) - - ray.kill(parameter_server) - ray.kill(dataset_server) - ray.shutdown() - - def test_simulation( - self, mocker: MockerFixture, n_player: int, inference_server_type: str - ): - if not ray.is_initialized(): - ray.init() - - parameter_server, dataset_server = start_servers() - - agents = [f"player_{i}" for i in range(n_player)] - - worker = create_rollout_worker( - mocker, - env_desc={ - "possible_agents": agents, - "observation_spaces": { - agent: spaces.Box(-1, 1.0, shape=(2,)) for agent in agents - }, - "action_spaces": { - agent: spaces.Box(-1, 1, shape=(2,)) for agent in agents - }, - }, - rollout_config=gen_rollout_config(inference_server_type), - ) - - results = worker.simulate({}) - - ray.kill(parameter_server) - ray.kill(dataset_server) - ray.shutdown() + def test_rollout(self, n_player: int): + with ray.init(local_mode=True): + env_desc = env_desc_gen(num_agents=n_player) + obs_spaces = env_desc["observation_spaces"] + act_spaces = env_desc["action_spaces"] + agents = env_desc["possible_agents"] + log_dir = "./logs" + + algorithm = Algorithm( + policy=RandomPolicy, + trainer=None, + model_config=None, + ) + + rollout_config = RolloutConfig( + num_workers=1, + eval_interval=1, + n_envs_per_worker=10, + use_subproc_env=False, + timelimit=256, + ) + + group_info = form_group_info(env_desc, lambda agent: "default") + + inference_namespace = "test_pb_rolloutworker" + + infer_manager = InferenceManager( + group_info=group_info, + ray_actor_namespace=inference_namespace, + algorithm=algorithm, + model_entry_point=None, + ) + + rollout_config.inference_entry_points = infer_manager.inference_entry_points + + strategy_specs = { + agent: StrategySpec( + policy_cls=algorithm.policy, + observation_space=obs_spaces[agent], + action_space=act_spaces[agent], + identifier=agent, + model_config=algorithm.model_config, + policy_ids=["policy-0"], + ) + for agent in agents + } + + worker = PBRolloutWorker( + env_desc=env_desc, + agent_groups=group_info["agent_groups"], + rollout_config=rollout_config, + log_dir=log_dir, + ) + + task = RolloutTask( + strategy_specs=strategy_specs, + stopping_conditions={"max_iteration": 10}, + data_entrypoint_mapping=None, # no data collect + ) + stats = worker.rollout(task) + + # def test_rollout_with_data_entrypoint(self, mocker: MockerFixture, n_player: int): + # with ray.init(local_mode=True): + # pass
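A convention worth calling out for reviewers: after this change, actor entry points are plain `"namespace:name"` strings, produced by `InferenceManager.inference_entry_points` and `LearnerManager.learner_entrypoints` and resolved by `ModelClient` and `EnvRunner`. A minimal resolution sketch with purely illustrative names (the named actor must already be alive in the cluster):

```python
import ray

# Hypothetical entry point in the "<ray namespace>:<actor name>" form
# emitted by InferenceManager.inference_entry_points.
entry_point = "inference_my_experiment:inference_default"

namespace, name = entry_point.split(":")
client = ray.get_actor(name=name, namespace=namespace)

# The handle behaves like any named actor; e.g., ModelClient's update loop
# pulls weights through the learner's get_state_dict:
# state_dict = ray.get(client.get_state_dict.remote(), timeout=10)
```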