From c01e4634036a31c5326d347e87b26a1aa9cc5b7c Mon Sep 17 00:00:00 2001
From: Ming Zhou
Date: Fri, 27 Oct 2023 19:39:12 +0800
Subject: [PATCH] tmp save

---
 malib/agent/manager.py                       |  24 +-
 malib/backend/dataset_server/__init__.py     |   0
 malib/backend/dataset_server/utils.py        |  17 +-
 malib/common/task.py                         |   6 +-
 malib/rl/common/policy.py                    |  16 +
 malib/rl/pg/policy.py                        |  23 +-
 malib/rollout/__init__.py                    |   6 +-
 .../inference/{ray/server.py => client.py}   |  42 +-
 .../{ray/client.py => env_runner.py}         |  88 ++--
 .../{model_server.py => model_client.py}     |   9 +-
 malib/rollout/inference/ray/__init__.py      |  26 --
 malib/rollout/inference/utils.py             |   6 +-
 malib/rollout/manager.py                     |  54 +--
 malib/rollout/pb_rolloutworker.py            |  40 +-
 malib/rollout/rolloutworker.py               | 260 ++----------
 malib/scenarios/sarl_scenario.py             |  13 +-
 malib/scenarios/scenario.py                  |  56 ++-
 tests/rollout/test_env_runner.py             |  42 ++
 tests/rollout/test_ray_inference.py          | 397 ------------------
 19 files changed, 257 insertions(+), 868 deletions(-)
 delete mode 100644 malib/backend/dataset_server/__init__.py
 rename malib/rollout/inference/{ray/server.py => client.py} (78%)
 rename malib/rollout/inference/{ray/client.py => env_runner.py} (78%)
 rename malib/rollout/inference/{model_server.py => model_client.py} (88%)
 delete mode 100644 malib/rollout/inference/ray/__init__.py
 create mode 100644 tests/rollout/test_env_runner.py
 delete mode 100644 tests/rollout/test_ray_inference.py

diff --git a/malib/agent/manager.py b/malib/agent/manager.py
index 4fd211c3..7eca9f3a 100644
--- a/malib/agent/manager.py
+++ b/malib/agent/manager.py
@@ -57,11 +57,6 @@
 )
 
 
-def validate_spaces(agent_groups: Dict[str, Set[AgentID]], env_desc: Dict[str, Any]):
-    # TODO(ming): check whether the agents in the group share the same observation space and action space
-    raise NotImplementedError
-
-
 class TrainingManager(Manager):
     def __init__(
         self,
@@ -70,6 +65,7 @@ def __init__(
         algorithms: Dict[str, Any],
         env_desc: Dict[str, Any],
         agent_mapping_func: Callable[[AgentID], str],
+        group_info: Dict[str, Any],
         training_config: Union[Dict[str, Any], TrainingConfig],
         log_dir: str,
         remote_mode: bool = True,
@@ -98,16 +94,10 @@ def __init__(
         training_config = TrainingConfig.from_raw(training_config)
 
         # interface config give the agent type used here and the group mapping if needed
-        agent_groups = defaultdict(lambda: set())
-        for agent in env_desc["possible_agents"]:
-            rid = agent_mapping_func(agent)
-            agent_groups[rid].add(agent)
-
-        validate_spaces(agent_groups, env_desc)
 
         # FIXME(ming): resource configuration is not available now, will open in the next version
         if training_config.trainer_config.get("use_cuda", False):
-            num_gpus = 1 / len(agent_groups)
+            num_gpus = 1 / len(group_info["agent_groups"])
         else:
             num_gpus = 0.0
         if not os.path.exists(log_dir):
@@ -125,7 +115,7 @@ def __init__(
             "training" in stopping_conditions
         ), f"Stopping conditions should contain a `training` stopping condition: {stopping_conditions}"
 
-        for rid, agents in agent_groups.items():
+        for rid, agents in group_info["agent_groups"].items():
             _cls = learner_cls.remote if remote_mode else learner_cls
             learners[rid] = _cls(
                 experiment_tag=experiment_tag,
@@ -145,8 +135,7 @@ def __init__(
             _ = ray.get([x.connect.remote() for x in learners.values()])
 
         # TODO(ming): collect data entrypoints from learners
-
-        self._agent_groups = agent_groups
-        self._runtime_ids = tuple(self._agent_groups.keys())
+        self._group_info = group_info
+        self._runtime_ids = tuple(group_info["agent_groups"].keys())
         self._experiment_tag = experiment_tag
         self._env_description = env_desc
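The managers now receive their agent grouping through `group_info` instead of rebuilding it from `agent_mapping_func`. A minimal sketch of the structure they expect, following the keys assembled in `Scenario.__init__` later in this patch (`agent_groups`, `observation_space`, `action_space`); the agent names and spaces are hypothetical:

    # two environment agents mapped onto one shared runtime group
    group_info = {
        "agent_groups": {"shared": {"agent_0", "agent_1"}},
        "observation_space": {"shared": obs_space},  # one space per runtime id
        "action_space": {"shared": act_space},
    }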
@@ -170,7 +159,7 @@ def agent_groups(self) -> Dict[str, Set[AgentID]]:
             Dict[str, Set[AgentID]]: A dict of agent set.
         """
 
-        return self._agent_groups
+        return self._group_info["agent_groups"]
 
     @property
     def get_data_entrypoints(self) -> Dict[str, str]:
@@ -202,6 +191,9 @@ def runtime_ids(self) -> Tuple[str]:
 
         return self._runtime_ids
 
+    def get_data_entrypoint_mapping(self) -> Dict[AgentID, str]:
+        raise NotImplementedError
+
     def add_policies(
         self, interface_ids: Sequence[str] = None, n: Union[int, Dict[str, int]] = 1
     ) -> Dict[str, Type[StrategySpec]]:
diff --git a/malib/backend/dataset_server/__init__.py b/malib/backend/dataset_server/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/malib/backend/dataset_server/utils.py b/malib/backend/dataset_server/utils.py
index 190d2360..a78e6517 100644
--- a/malib/backend/dataset_server/utils.py
+++ b/malib/backend/dataset_server/utils.py
@@ -2,17 +2,26 @@
 import pickle
 
 import grpc
+import sys
+import os
+
+# make the generated gRPC modules (data_pb2, data_pb2_grpc) importable
+sys.path.append(os.path.dirname(__file__))
 
 from . import data_pb2
 from . import data_pb2_grpc
 
 
-def send_data(host: str, port: int, data: Any):
+def send_data(data: Any, host: str = None, port: int = None, entrypoint: str = None):
     if not isinstance(data, bytes):
         data = pickle.dumps(data)
 
-    with grpc.insecure_channel(f"{host}:{port}") as channel:
-        stub = data_pb2_grpc.SendDataStub(channel)
-        reply = stub.Collect(data_pb2.Data(data=data))
+    # accept either a (host, port) pair or a pre-joined entrypoint address
+    target = f"{host}:{port}" if host is not None else entrypoint
+    assert target is not None, "either (host, port) or entrypoint must be given"
+    with grpc.insecure_channel(target) as channel:
+        stub = data_pb2_grpc.SendDataStub(channel)
+        reply = stub.Collect(data_pb2.Data(data=data))
 
     return reply.message
diff --git a/malib/common/task.py b/malib/common/task.py
index 007b6fb2..e59df234 100644
--- a/malib/common/task.py
+++ b/malib/common/task.py
@@ -15,8 +15,9 @@ class TaskType(IntEnum):
 @dataclass
 class RolloutTask:
     task_type: int
-    active_agents: List[AgentID]
-    strategy_specs: Dict[str, Any] = field(default_factory=dict())
+    strategy_specs: Dict[str, Any] = field(default_factory=dict)
+    stopping_conditions: Dict[str, Any] = field(default_factory=dict)
+    data_entrypoint_mapping: Dict[str, Any] = field(default_factory=dict)
 
     @classmethod
     def from_raw(
@@ -32,9 +33,6 @@ def from_raw(
 
 @dataclass
 class OptimizationTask:
-    data_entrypoints: Dict[str, str]
-    """a mapping defines the data request identifier and the data entrypoint."""
-
     stop_conditions: Dict[str, Any]
     """stopping conditions for optimization task, e.g., max iteration, max time, etc."""
 
diff --git a/malib/rl/common/policy.py b/malib/rl/common/policy.py
index 0f673f9a..f3525308 100644
--- a/malib/rl/common/policy.py
+++ b/malib/rl/common/policy.py
@@ -28,6 +28,7 @@
 
 import torch
 import torch.nn as nn
+import gym
 
 from gym import spaces
 
@@ -100,6 +101,21 @@ def __init__(
             use_sde=custom_config.get("use_sde", False),
             dist_kwargs=custom_config.get("dist_kwargs", None),
         )
+        if kwargs.get("model_client"):
+            self.model = kwargs["model_client"]
+        else:
+            self.model = self.create_model()
+
+    def create_model(self):
+        raise NotImplementedError
+
+    @property
+    def action_space(self) -> gym.Space:
+        return self._action_space
+
+    @property
+    def observation_space(self) -> gym.Space:
+        return self._observation_space
 
     @property
     def model_config(self):
diff --git a/malib/rl/pg/policy.py b/malib/rl/pg/policy.py
index 583eccbd..8903ff73 100644
--- a/malib/rl/pg/policy.py
+++ b/malib/rl/pg/policy.py
@@ -70,36 +70,39 @@ def __init__(
             observation_space, action_space, model_config, custom_config, **kwargs
         )
 
+    def create_model(self):
         # update model preprocess_net config here
         action_shape = (
-            (action_space.n,) if len(action_space.shape) == 0 else action_space.shape
+            (self.action_space.n,)
+            if len(self.action_space.shape) == 0
+            else self.action_space.shape
         )
         preprocess_net: nn.Module = net.make_net(
-            observation_space,
+            self.observation_space,
             self.device,
-            model_config["preprocess_net"].get("net_type", None),
-            **model_config["preprocess_net"]["config"]
+            self.model_config["preprocess_net"].get("net_type", None),
+            **self.model_config["preprocess_net"]["config"]
         )
-        if isinstance(action_space, spaces.Discrete):
+        if isinstance(self.action_space, spaces.Discrete):
             self.actor = discrete.Actor(
                 preprocess_net=preprocess_net,
                 action_shape=action_shape,
-                hidden_sizes=model_config["hidden_sizes"],
+                hidden_sizes=self.model_config["hidden_sizes"],
                 softmax_output=False,
                 device=self.device,
             )
-        elif isinstance(action_space, spaces.Box):
+        elif isinstance(self.action_space, spaces.Box):
             self.actor = continuous.Actor(
                 preprocess_net=preprocess_net,
                 action_shape=action_shape,
-                hidden_sizes=model_config["hidden_sizes"],
-                max_action=custom_config.get("max_action", 1.0),
+                hidden_sizes=self.model_config["hidden_sizes"],
+                max_action=self.custom_config.get("max_action", 1.0),
                 device=self.device,
             )
         else:
             raise TypeError(
-                "Unexpected action space type: {}".format(type(action_space))
+                "Unexpected action space type: {}".format(type(self.action_space))
             )
 
         self.register_state(self.actor, "actor")
diff --git a/malib/rollout/__init__.py b/malib/rollout/__init__.py
index 8d802638..0e02ff91 100644
--- a/malib/rollout/__init__.py
+++ b/malib/rollout/__init__.py
@@ -23,8 +23,8 @@
 # SOFTWARE.
 
 from .pb_rolloutworker import RolloutWorker
-from .inference.ray.client import RayInferenceClient as InferenceClient
-from .inference.ray.server import RayInferenceWorkerSet as InferenceWorkerSet
+from .inference.env_runner import EnvRunner
+from .inference.client import InferenceClient
 
-__all__ = ["RolloutWorker", "InferenceClient", "InferenceWorkerSet"]
+__all__ = ["RolloutWorker", "EnvRunner", "InferenceClient"]
diff --git a/malib/rollout/inference/ray/server.py b/malib/rollout/inference/client.py
similarity index 78%
rename from malib/rollout/inference/ray/server.py
rename to malib/rollout/inference/client.py
index bc9a2487..fa28c7f5 100644
--- a/malib/rollout/inference/ray/server.py
+++ b/malib/rollout/inference/client.py
@@ -31,9 +31,7 @@
 import os
 import pickle as pkl
 
-import ray
 import gym
-import torch
 
 from malib import settings
 from malib.remote.interface import RemoteInterface
@@ -42,20 +40,17 @@
 from malib.utils.episode import Episode
 from malib.common.strategy_spec import StrategySpec
 from malib.rl.common.policy import Policy
-from malib.backend.parameter_server import ParameterServer
 
 
-ClientHandler = namedtuple("ClientHandler", "sender,recver,runtime_config,rnn_states")
+Connection = namedtuple("Connection", "sender,recver,runtime_config,rnn_states")
 
 
-class RayInferenceWorkerSet(RemoteInterface):
+class InferenceClient(RemoteInterface):
     def __init__(
         self,
         agent_id: AgentID,
         observation_space: gym.Space,
         action_space: gym.Space,
-        parameter_server: ParameterServer,
-        governed_agents: List[AgentID],
     ) -> None:
-        """Create ray-based inference server.
+        """Create an inference client.
 
         Args:
             agent_id (AgentID): Runtime agent id, not environment agent id.
             observation_space (gym.Space): Observation space related to the governed environment agents.
             action_space (gym.Space): Action space related to the governed environment agents.
-            parameter_server (ParameterServer): Parameter server.
-            governed_agents (List[AgentID]): A list of environment agents.
         """
 
         self.runtime_agent_id = agent_id
         self.observation_space = observation_space
         self.action_space = action_space
-        self.parameter_server = parameter_server
 
         self.thread_pool = ThreadPoolExecutor()
-        self.governed_agents = governed_agents
         self.policies: Dict[str, Policy] = {}
         self.strategy_spec_dict: Dict[str, StrategySpec] = {}
+        # initialize here so that a first call to shutdown() is safe
+        self.connections: Dict[int, Connection] = {}
 
     def shutdown(self):
         self.thread_pool.shutdown(wait=True)
-        for _handler in self.clients.values():
+        for _handler in self.connections.values():
             _handler.sender.shutdown(True)
             _handler.recver.shutdown(True)
-        self.clients: Dict[int, ClientHandler] = {}
+        self.connections: Dict[int, Connection] = {}
 
     def save(self, model_dir: str) -> None:
         if not os.path.exists(model_dir):
@@ -100,12 +91,6 @@ def compute_action(
         strategy_specs: Dict[AgentID, StrategySpec] = runtime_config["strategy_specs"]
         return_dataframes: List[DataFrame] = []
 
-        # check policy
-        self._update_policies(
-            runtime_config["strategy_specs"][self.runtime_agent_id],
-            self.runtime_agent_id,
-        )
-
         assert len(dataframes) > 0
 
         for dataframe in dataframes:
@@ -132,16 +117,6 @@ def compute_action(
 
             rets = {}
 
-            with timer.time_avg("policy_update"):
-                info = ray.get(
-                    self.parameter_server.get_weights.remote(
-                        spec_id=spec.id,
-                        spec_policy_id=spec_policy_id,
-                    )
-                )
-                if info["weights"] is not None:
-                    self.policies[policy_id].load_state_dict(info["weights"])
-
             with timer.time_avg("compute_action"):
                 (
                     rets[Episode.ACTION],
@@ -170,19 +145,12 @@ def compute_action(
                     continue
                 else:
                     rets[k] = v.reshape(batch_size, -1)
+
             return_dataframes.append(
                 DataFrame(identifier=agent_id, data=rets, meta_data=dataframe.meta_data)
             )
 
-        # print(f"timer information: {timer.todict()}")
         return return_dataframes
 
-    def _update_policies(self, strategy_spec: StrategySpec, agent_id: AgentID):
-        for strategy_spec_pid in strategy_spec.policy_ids:
-            policy_id = f"{strategy_spec.id}/{strategy_spec_pid}"
-            if policy_id not in self.policies:
-                policy = strategy_spec.gen_policy(device="cpu")
-                self.policies[policy_id] = policy
-
     def _get_initial_states(self, client_id, observation, policy: Policy, identifier):
         if (
diff --git a/malib/rollout/inference/ray/client.py b/malib/rollout/inference/env_runner.py
similarity index 78%
rename from malib/rollout/inference/ray/client.py
rename to malib/rollout/inference/env_runner.py
index aa143498..1b23fa02 100644
--- a/malib/rollout/inference/ray/client.py
+++ b/malib/rollout/inference/env_runner.py
@@ -23,7 +23,7 @@
 # SOFTWARE.
 
 from argparse import Namespace
-from typing import Any, List, Dict, Tuple
+from typing import Any, List, Dict, Tuple, Set
 from types import LambdaType
 from collections import defaultdict
 
@@ -31,36 +31,38 @@
 import time
 import traceback
+import pickle
 
 import ray
-from ray.util.queue import Queue
 from ray.actor import ActorHandle
 
-from malib.utils.logging import Logger
-
 from malib.utils.typing import AgentID, DataFrame, BehaviorMode
-from malib.utils.episode import Episode, NewEpisodeDict, NewEpisodeList
+from malib.utils.episode import NewEpisodeList
 from malib.utils.preprocessor import Preprocessor, get_preprocessor
 from malib.utils.timing import Timing
 from malib.remote.interface import RemoteInterface
 from malib.rollout.envs.vector_env import VectorEnv, SubprocVecEnv
-from malib.rollout.inference.ray.server import RayInferenceWorkerSet
+from malib.rollout.inference.client import InferenceClient
 from malib.rollout.inference.utils import process_env_rets, process_policy_outputs
+from malib.backend.dataset_server.utils import send_data
+
+
+class EnvRunner(RemoteInterface):
+    def __repr__(self) -> str:
+        return f"<EnvRunner max_env_num={self.max_env_num}>"
 
-class RayInferenceClient(RemoteInterface):
     def __init__(
         self,
         env_desc: Dict[str, Any],
-        dataset_server: ray.ObjectRef,
         max_env_num: int,
+        agent_groups: Dict[str, Set],
         use_subproc_env: bool = False,
         batch_mode: str = "time_step",
         postprocessor_types: Dict = None,
         training_agent_mapping: LambdaType = None,
         custom_config: Dict[str, Any] = {},
     ):
-        """Construct an inference client.
+        """Construct an environment runner, one per rollout thread.
 
         Args:
             env_desc (Dict[str, Any]): Environment description
@@ -73,7 +75,6 @@ def __init__(
             custom_config (Dict[str, Any], optional): Custom configuration. Defaults to an empty dict.
         """
 
-        self.dataset_server = dataset_server
         self.use_subproc_env = use_subproc_env
         self.batch_mode = batch_mode
         self.postprocessor_types = postprocessor_types or ["defaults"]
@@ -82,15 +83,8 @@ def __init__(
         self.training_agent_mapping = training_agent_mapping or (lambda agent: agent)
         self.max_env_num = max_env_num
         self.custom_configs = custom_config
-
-        agent_group = defaultdict(lambda: [])
-        runtime_agent_ids = []
-        for agent in env_desc["possible_agents"]:
-            runtime_id = training_agent_mapping(agent)
-            agent_group[runtime_id].append(agent)
-            runtime_agent_ids.append(runtime_id)
-        self.runtime_agent_ids = set(runtime_agent_ids)
-        self.agent_group = dict(agent_group)
+        self.runtime_agent_ids = list(agent_groups.keys())
+        self.agent_groups = agent_groups
 
         obs_spaces = env_desc["observation_spaces"]
         act_spaces = env_desc["action_spaces"]
@@ -121,9 +115,9 @@ def close(self):
 
     def run(
         self,
-        agent_interfaces: Dict[AgentID, RayInferenceWorkerSet],
+        inference_clients: Dict[AgentID, InferenceClient],
         rollout_config: Dict[str, Any],
-        dataset_writer_info_dict: Dict[str, Tuple[str, Queue]] = None,
+        data_entrypoint_mapping: Dict[AgentID, str] = None,
     ) -> Dict[str, Any]:
         """Executes environment runner to collect training data or run purely simulation/evaluation.
 
         Note:
             Only simulation/evaluation tasks return evaluation information.
 
         Args:
-            agent_interfaces (Dict[AgentID, InferenceWorkerSet]): A dict of agent interface servers.
+            inference_clients (Dict[AgentID, InferenceClient]): A dict of inference clients, keyed by runtime id.
             rollout_config (Dict[str, Any]): Rollout configuration.
-            dataset_writer_info_dict (Dict[str, Tuple[str, Queue]], optional): Dataset writer info dict. Defaults to None.
+            data_entrypoint_mapping (Dict[AgentID, str], optional): Mapping from runtime ids to data entrypoint addresses. Defaults to None.
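For orientation, a sketch of how `run` might be invoked, assuming the `rollout_config` keys used elsewhere in this patch (`flag`, `strategy_specs`, `behavior_mode`); the gRPC address is a placeholder:

    rollout_config = {
        "flag": "rollout",  # or "evaluation"/"simulation"
        "strategy_specs": strategy_specs,  # runtime id -> StrategySpec
        "behavior_mode": BehaviorMode.EXPLORATION,
    }
    results = env_runner.run(
        inference_clients,
        rollout_config,
        data_entrypoint_mapping={"default": "127.0.0.1:50051"},  # placeholder address
    )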
@@ -148,20 +142,13 @@ def run( "strategy_specs": rollout_config["strategy_specs"], } - if task_type == "rollout": - assert ( - dataset_writer_info_dict is not None - ), "rollout task has no available dataset writer" - server_runtime_config["behavior_mode"] = BehaviorMode.EXPLORATION - elif task_type in ["evaluation", "simulation"]: - server_runtime_config["behavior_mode"] = BehaviorMode.EXPLOITATION - eval_results, performance = env_runner( self, - agent_interfaces, + inference_clients, + self.preprocessor, rollout_config, server_runtime_config, - dwriter_info_dict=dataset_writer_info_dict, + data_entrypoint_mapping, ) res = performance.copy() @@ -171,11 +158,12 @@ def run( def env_runner( - client: RayInferenceClient, - servers: Dict[str, RayInferenceWorkerSet], + client: InferenceClient, + agents: Dict[str, InferenceClient], + preprocessors: Dict[str, Preprocessor], rollout_config: Dict[str, Any], server_runtime_config: Dict[str, Any], - dwriter_info_dict: Dict[str, Tuple[str, Queue]] = None, + data_entrypoint_mapping: Dict[AgentID, str], ) -> Tuple[List[Dict[str, Any]], Dict[str, float]]: """The main logic of environment stepping, also for data collections. @@ -198,11 +186,11 @@ def env_runner( """ # check whether remote server or not - evaluate_on = server_runtime_config["behavior_mode"] == BehaviorMode.EXPLOITATION - remote_actor = isinstance(list(servers.values())[0], ActorHandle) + evaluate_on = rollout_config["behavior_mode"] == BehaviorMode.EXPLOITATION + remote_actor = isinstance(list(agents.values())[0], ActorHandle) try: - if dwriter_info_dict is not None: + if data_entrypoint_mapping is not None: episodes = NewEpisodeList( num=client.env.num_envs, agents=client.env.possible_agents ) @@ -217,9 +205,10 @@ def env_runner( env_dones, processed_env_ret, dataframes = process_env_rets( env_rets=env_rets, - preprocessor=server_runtime_config["preprocessor"], + preprocessors=preprocessors, preset_meta_data={"evaluate": evaluate_on}, ) + # env ret is key first, not agent first: state, obs if episodes is not None: episodes.record( @@ -240,20 +229,20 @@ def env_runner( if remote_actor: policy_outputs: Dict[str, List[DataFrame]] = { rid: ray.get( - server.compute_action.remote( + agent.compute_action.remote( grouped_data_frames[rid], runtime_config=server_runtime_config, ) ) - for rid, server in servers.items() + for rid, agent in agents.items() } else: policy_outputs: Dict[str, List[DataFrame]] = { - rid: server.compute_action( + rid: agent.compute_action( grouped_data_frames[rid], runtime_config=server_runtime_config, ) - for rid, server in servers.items() + for rid, agent in agents.items() } with client.timer.time_avg("process_policy_output"): @@ -284,18 +273,11 @@ def env_runner( cnt += 1 - if dwriter_info_dict is not None: + if data_entrypoint_mapping is not None: # episode_id: agent_id: dict_data episodes = episodes.to_numpy() - for rid, writer_info in dwriter_info_dict.items(): - # get agents from agent group - agents = client.agent_group[rid] - batches = [] - # FIXME(ming): multi-agent is wrong! 
-                for episode in episodes:
-                    agent_buffer = [episode[aid] for aid in agents]
-                    batches.append(agent_buffer)
-                writer_info[-1].put_nowait_batch(batches)
+            for entrypoint in data_entrypoint_mapping.values():
+                # pass the address through the `entrypoint` keyword; the second
+                # positional slot of send_data is `host`, not an entrypoint
+                send_data(pickle.dumps(episodes), entrypoint=entrypoint)
             end = time.time()
             rollout_info = client.env.collect_info()
         except Exception as e:
diff --git a/malib/rollout/inference/model_server.py b/malib/rollout/inference/model_client.py
similarity index 88%
rename from malib/rollout/inference/model_server.py
rename to malib/rollout/inference/model_client.py
index f5fbb274..5391e0a1 100644
--- a/malib/rollout/inference/model_server.py
+++ b/malib/rollout/inference/model_client.py
@@ -9,6 +9,8 @@
 import torch
 import ray
 
+from malib.utils.typing import AgentID, DataFrame
+
 
 def load_state_dict(client, timeout=10):
     if isinstance(client, ray.ObjectRef):
@@ -21,7 +23,6 @@ class ModelClient:
     def __init__(
         self, entry_point: str, model_cls: nn.Module, model_args: Dict[str, Any]
     ):
-        # TODO(ming): init server from entry point
         cluster_type, name_or_address = entry_point.split(":")
 
         if "ray" in cluster_type:
@@ -42,6 +43,12 @@ def __call__(self, *args: Any, **kwds: Any) -> Any:
         with torch.inference_mode():
             return self.model(*args, **kwds)
 
+    def actor(self, *args, **kwargs):
+        return self.model.actor(*args, **kwargs)
+
+    def critic(self, *args, **kwargs):
+        return self.model.critic(*args, **kwargs)
+
     def _model_update(self, event: threading.Event):
         while not event.is_set():
             # TODO(ming): update model from remote server
diff --git a/malib/rollout/inference/ray/__init__.py b/malib/rollout/inference/ray/__init__.py
deleted file mode 100644
index 26ac2ced..00000000
--- a/malib/rollout/inference/ray/__init__.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# MIT License
-
-# Copyright (c) 2021 MARL @ SJTU
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-from .client import RayInferenceClient
-from .server import RayInferenceWorkerSet
-
-__all__ = ["RayInferenceClient", "RayInferenceWorkerSet"]
diff --git a/malib/rollout/inference/utils.py b/malib/rollout/inference/utils.py
index 588f0538..c6aa1ed3 100644
--- a/malib/rollout/inference/utils.py
+++ b/malib/rollout/inference/utils.py
@@ -37,7 +37,7 @@
 
 def process_env_rets(
     env_rets: List[Tuple["states", "observations", "rewards", "dones", "infos"]],
-    preprocessor: Dict[AgentID, Preprocessor],
+    preprocessors: Dict[AgentID, Preprocessor],
     preset_meta_data: Dict[str, Any],
 ):
     """Process environment returns, generally, for the observation transformation.
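The reworked `send_data` (see malib/backend/dataset_server/utils.py above) accepts either a `(host, port)` pair or a pre-joined entrypoint string; a small usage sketch with placeholder addresses:

    from malib.backend.dataset_server.utils import send_data

    # explicit host and port ...
    reply = send_data(batch, host="127.0.0.1", port=50051)
    # ... or a single entrypoint address, the form the env runner now uses
    reply = send_data(batch, entrypoint="127.0.0.1:50051")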
@@ -70,7 +70,7 @@ def process_env_rets( agents = list(ret[1].keys()) processed_obs = { - agent: preprocessor[agent].transform(raw_obs) + agent: preprocessors[agent].transform(raw_obs) for agent, raw_obs in ret[1].items() } @@ -85,7 +85,7 @@ def process_env_rets( agent_state_list[agent].append(_state) env_rets_to_save[Episode.CUR_STATE] = ret[0] - original_obs_space = list(preprocessor.values())[0].original_space + original_obs_space = list(preprocessors.values())[0].original_space if ( isinstance(original_obs_space, spaces.Dict) and "action_mask" in original_obs_space diff --git a/malib/rollout/manager.py b/malib/rollout/manager.py index e6b39471..1a4ddbd7 100644 --- a/malib/rollout/manager.py +++ b/malib/rollout/manager.py @@ -27,7 +27,6 @@ """ from typing import Dict, Tuple, Any, Callable, Set, List, Union -from collections import defaultdict import traceback import ray @@ -36,7 +35,7 @@ from ray.util import ActorPool from malib.utils.logging import Logger -from malib.common.task import TaskType, RolloutTask +from malib.common.task import RolloutTask from malib.common.manager import Manager from malib.remote.interface import RemoteInterface from malib.common.strategy_spec import StrategySpec @@ -79,6 +78,7 @@ def __init__( stopping_conditions: Dict[str, Any], num_worker: int, agent_mapping_func: Callable, + group_info: Dict[str, Any], rollout_config: Dict[str, Any], env_desc: Dict[str, Any], log_dir: str, @@ -110,6 +110,7 @@ def __init__( experiment_tag=experiment_tag, env_desc=env_desc, agent_mapping_func=agent_mapping_func, + agent_groups=group_info["agent_groups"], rollout_config=rollout_config, log_dir=log_dir, rollout_callback=None, @@ -121,14 +122,8 @@ def __init__( self._workers: List[ray.actor] = workers self._actor_pool = ActorPool(self._workers) - - agent_groups = defaultdict(lambda: set()) - for agent in env_desc["possible_agents"]: - rid = agent_mapping_func(agent) - agent_groups[rid].add(agent) - - self._runtime_ids = tuple(agent_groups.keys()) - self._agent_groups = dict(agent_groups) + self._runtime_ids = tuple(group_info["agent_groups"].keys()) + self._group_info = group_info self.experiment_tag = experiment_tag assert ( @@ -155,7 +150,7 @@ def agent_groups(self) -> Dict[str, Set]: Dict[str, Set]: A dict of set. """ - return self._agent_groups + return self._group_info["agent_groups"] @property def workers(self) -> List[RemoteInterface]: @@ -167,7 +162,7 @@ def workers(self) -> List[RemoteInterface]: return self._workers - def submit(self, task: Union[Dict[str, Any], List[Dict[str, Any]]], task_type: Any): + def submit(self, task: Union[Dict[str, Any], List[Dict[str, Any]]]): """Submit a task to workers Args: @@ -182,40 +177,7 @@ def submit(self, task: Union[Dict[str, Any], List[Dict[str, Any]]], task_type: A for _task in task: validate_strategy_specs(_task.strategy_specs) - self._actor_pool.submit( - lambda actor, _task: actor.rollout.remote(_task, stopping_conditions) - ) - - def _rollout(self, task_list: List[Dict[str, Any]]) -> None: - """Start rollout task without blocking. - - Args: - task_list (List[Dict[str, Any]]): A list of task dict, keys include: - - `strategy_specs`: a dict of strategy specs, mapping from runtime ids to specs. - - `trainable_agents`: a list of trainable agents. 
-
-        """
-
-        # validate all strategy specs here
-        for task in task_list:
-            validate_strategy_specs(task["strategy_specs"])
-
-        while self._actor_pool.has_next():
-            try:
-                self._actor_pool.get_next(timeout=0)
-            except TimeoutError:
-                pass
-
-        for task in task_list:
-            self._actor_pool.submit(
-                lambda actor, task: actor.rollout.remote(
-                    runtime_strategy_specs=task["strategy_specs"],
-                    stopping_conditions=self.stopping_conditions["rollout"],
-                    trainable_agents=task["trainable_agents"],
-                    data_entrypoints=task["data_entrypoints"],
-                ),
-                task,
-            )
+            # ActorPool.submit requires the task value as its second argument
+            self._actor_pool.submit(
+                lambda actor, _task: actor.rollout.remote(_task), _task
+            )
 
     def retrive_results(self):
         """Retrieve task results
diff --git a/malib/rollout/pb_rolloutworker.py b/malib/rollout/pb_rolloutworker.py
index c37426ef..852dc68d 100644
--- a/malib/rollout/pb_rolloutworker.py
+++ b/malib/rollout/pb_rolloutworker.py
@@ -25,7 +25,6 @@
 from typing import Dict, Any
 
 from malib.rollout.rolloutworker import RolloutWorker, parse_rollout_info
-from malib.common.strategy_spec import StrategySpec
 from malib.utils.logging import Logger
 
 
@@ -36,7 +35,7 @@ def step_rollout(
         self,
         eval_step: bool,
         rollout_config: Dict[str, Any],
-        dataset_writer_info_dict: Dict[str, Any],
+        data_entrypoint_mapping: Dict[str, Any],
     ):
         tasks = [rollout_config for _ in range(self.rollout_config["num_threads"])]
 
@@ -53,11 +52,11 @@ def step_rollout(
 
         rets = [
             x
-            for x in self.actor_pool.map(
+            for x in self.env_runner_pool.map(
                 lambda a, task: a.run.remote(
-                    agent_interfaces=self.agent_interfaces,
+                    inference_clients=self.inference_clients,
                     rollout_config=task,
-                    dataset_writer_info_dict=dataset_writer_info_dict,
+                    data_entrypoint_mapping=data_entrypoint_mapping,
                 ),
                 tasks,
             )
@@ -67,34 +66,3 @@ def step_rollout(
         parsed_results = parse_rollout_info(rets)
         Logger.debug(f"parsed results: {parsed_results}")
         return parsed_results
-
-    def step_simulation(
-        self,
-        runtime_strategy_specs: Dict[str, StrategySpec],
-        runtime_config_template: Dict[str, Any],
-    ) -> Dict[str, Any]:
-        """Step simulation task with a given list of strategy spec dicts.
-
-        Args:
-            runtime_strategy_specs (Dict[str, StrategySpec]): A strategy spec dicts.
-            runtime_config_template (Dict[str, Any]): Runtime configuration template.
-
-        Returns:
-            Dict[str, Any]: Evaluation results, a dict.
-        """
-
-        task = runtime_config_template.copy()
-        task["strategy_specs"] = runtime_strategy_specs
-
-        # we should keep dimension as tasks.
-        rets = [
-            parse_rollout_info([x])
-            for x in self.actor_pool.map(
-                lambda a, task: a.run.remote(
-                    agent_interfaces=self.agent_interfaces, rollout_config=task
-                ),
-                [task],
-            )
-        ][0]
-
-        return rets
diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py
index 8c406827..6daff6e9 100644
--- a/malib/rollout/rolloutworker.py
+++ b/malib/rollout/rolloutworker.py
@@ -22,7 +22,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-from typing import Dict, Any, List, Callable, Sequence, Tuple +from typing import Dict, Any, List, Callable, Sequence, Tuple, Set from abc import abstractmethod from collections import defaultdict @@ -37,21 +37,18 @@ import numpy as np from ray.util import ActorPool -from ray.util.queue import Queue from torch.utils import tensorboard from malib import settings -from malib.utils.typing import AgentID +from malib.utils.typing import AgentID, BehaviorMode from malib.utils.logging import Logger from malib.utils.stopping_conditions import get_stopper from malib.utils.monitor import write_to_tensorboard from malib.common.strategy_spec import StrategySpec -from malib.common.task import RolloutTask, TaskType +from malib.common.task import RolloutTask from malib.remote.interface import RemoteInterface -from malib.rollout.inference.ray.server import ( - RayInferenceWorkerSet as RayInferenceServer, -) -from malib.rollout.inference.ray.client import RayInferenceClient +from malib.rollout.inference.client import InferenceClient +from malib.rollout.inference.env_runner import EnvRunner PARAMETER_GET_TIMEOUT = 3 @@ -116,34 +113,6 @@ def log(message: str): logger.log(settings.LOG_LEVEL, f"(rollout worker) {message}") -def validate_agent_group( - agent_group: Dict[str, List[AgentID]], - full_keys: List[AgentID], - observation_spaces: Dict[AgentID, gym.Space], - action_spaces: Dict[AgentID, gym.Space], -) -> None: - """Validate agent group, check spaces. - - Args: - agent_group (Dict[str, List[AgentID]]): A dict, mapping from runtime ids to lists of agent ids. - full_keys (List[AgentID]): A list of original environment agent ids. - observation_spaces (Dict[AgentID, gym.Space]): Agent observation space dict. - action_spaces (Dict[AgentID, gym.Space]): Agent action space dict. - - Raises: - RuntimeError: Agents in a same group should share the same observation space and action space. - NotImplementedError: _description_ - """ - for agents in agent_group.values(): - select_obs_space = observation_spaces[agents[0]] - select_act_space = action_spaces[agents[0]] - for agent in agents[1:]: - assert type(select_obs_space) == type(observation_spaces[agent]) - assert select_obs_space.shape == observation_spaces[agent].shape - assert type(select_act_space) == type(action_spaces[agent]) - assert select_act_space.shape == action_spaces[agent].shape - - def default_rollout_callback(coordinator: ray.ObjectRef, results: Dict[str, Any]): pass @@ -180,6 +149,7 @@ def __init__( experiment_tag: str, env_desc: Dict[str, Any], agent_mapping_func: Callable, + agent_groups: Dict[str, Set], rollout_config: Dict[str, Any], log_dir: str, rollout_callback: Callable[[ray.ObjectRef, Dict[str, Any]], Any] = None, @@ -187,9 +157,7 @@ def __init__( resource_config: Dict[str, Any] = None, verbose: bool = True, ): - """Create a instance for simulations, rollout and evaluation. This base class initializes \ - all necessary servers and workers for rollouts. Including remote agent interfaces, \ - workers for simultaions. + """Construct a rollout worker, consuming rollout/evaluation tasks. Args: env_desc (Dict[str, Any]): The environment description. 
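Rollout workers are now driven by a single `RolloutTask` (see malib/common/task.py above) rather than loose keyword arguments. A construction sketch under that dataclass; the entrypoint address is a placeholder:

    from malib.common.task import RolloutTask, TaskType

    task = RolloutTask(
        task_type=TaskType.ROLLOUT,
        strategy_specs=strategy_specs,  # runtime id -> StrategySpec
        stopping_conditions={"max_iteration": 1000},
        data_entrypoint_mapping={"default": "127.0.0.1:50051"},  # placeholder
    )
    worker.rollout(task)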
@@ -212,62 +180,29 @@ def __init__( self.worker_indentifier = f"rolloutworker_{os.getpid()}" # map agents - agent_group = defaultdict(lambda: []) - runtime_agent_ids = [] - for agent in env_desc["possible_agents"]: - runtime_id = agent_mapping_func(agent) - agent_group[runtime_id].append(agent) - runtime_agent_ids.append(runtime_id) - runtime_agent_ids = set(runtime_agent_ids) - agent_group = dict(agent_group) resource_config = resource_config or DEFAULT_RESOURCE_CONFIG - # valid agent group - validate_agent_group( - agent_group=agent_group, - full_keys=env_desc["possible_agents"], - observation_spaces=env_desc["observation_spaces"], - action_spaces=env_desc["action_spaces"], - ) - self.env_description = env_desc self.env_agents = env_desc["possible_agents"] - self.runtime_agent_ids = runtime_agent_ids - self.agent_group = agent_group + self.runtime_agent_ids = list(agent_groups.keys()) + self.agent_groups = agent_groups self.rollout_config: Dict[str, Any] = rollout_config validate_runtime_configs(self.rollout_config) - self.coordinator = None - self.dataset_server = None - self.parameter_server = None - - self.init_servers() - - if rollout_config["inference_server"] == "local": - self.inference_server_cls = None - self.inference_client_cls = RayInferenceClient.as_remote( - **resource_config["inference_client"] - ) - elif rollout_config["inference_server"] == "ray": - self.inference_client_cls = RayInferenceClient.as_remote( - **resource_config["inference_client"] - ) - self.inference_server_cls = RayInferenceServer.as_remote( - **resource_config["inference_server"] - ).options(max_concurrency=100) - - else: - raise ValueError( - "unexpected inference server type: {}".format( - rollout_config["inference_server"] - ) - ) + self.inference_client_cls = InferenceClient.as_remote( + **resource_config["inference_client"] + ) + self.env_runner_cls = EnvRunner.as_remote( + **resource_config["inference_server"] + ).options(max_concurrency=100) - self.agent_interfaces = self.init_agent_interfaces(env_desc, runtime_agent_ids) - self.actor_pool: ActorPool = self.init_actor_pool( + self.env_runner_pool: ActorPool = self.init_env_runner_pool( env_desc, rollout_config, agent_mapping_func ) + self.inference_clients: Dict[ + AgentID, ray.ObjectRef + ] = self.create_inference_clients() self.log_dir = log_dir self.rollout_callback = rollout_callback or default_rollout_callback @@ -276,48 +211,10 @@ def __init__( self.experiment_tag = experiment_tag self.verbose = verbose - def init_agent_interfaces( - self, env_desc: Dict[str, Any], runtime_ids: Sequence[AgentID] - ) -> Dict[AgentID, Any]: - """Initialize agent interfaces which is a dict of `InterfaceWorkerSet`. The keys in the \ - dict is generated from the given agent mapping function. - - Args: - env_desc (Dict[str, Any]): Environment description. - runtime_ids (Sequence[AgentID]): Available runtime ids, generated with agent mapping function. 
- - Returns: - Dict[AgentID, Any]: A dict of `InferenceWorkerSet`, mapping from `runtime_ids` to `ray.ObjectRef(s)` - """ - - # interact with environment - if self.inference_server_cls is None: - return None + def create_inference_clients(self) -> Dict[AgentID, ray.ObjectRef]: + raise NotImplementedError - obs_spaces = env_desc["observation_spaces"] - act_spaces = env_desc["action_spaces"] - - runtime_obs_spaces = {} - runtime_act_spaces = {} - - for rid, agents in self.agent_group.items(): - runtime_obs_spaces[rid] = obs_spaces[agents[0]] - runtime_act_spaces[rid] = act_spaces[agents[0]] - - agent_interfaces = { - runtime_id: self.inference_server_cls.remote( - agent_id=runtime_id, - observation_space=runtime_obs_spaces[runtime_id], - action_space=runtime_act_spaces[runtime_id], - parameter_server=self.parameter_server, - governed_agents=self.agent_group[runtime_id], - ) - for runtime_id in runtime_ids - } - - return agent_interfaces - - def init_actor_pool( + def init_env_runner_pool( self, env_desc: Dict[str, Any], rollout_config: Dict[str, Any], @@ -344,12 +241,12 @@ def init_actor_pool( num_env_per_thread = rollout_config["num_env_per_thread"] num_eval_threads = rollout_config["num_eval_threads"] - actor_pool = ActorPool( + env_runner_pool = ActorPool( [ - self.inference_client_cls.remote( + self.env_runner_cls.remote( env_desc, - ray.get_actor(settings.OFFLINE_DATASET_ACTOR), max_env_num=num_env_per_thread, + agent_groups=self.agent_groups, use_subproc_env=rollout_config["use_subproc_env"], batch_mode=rollout_config["batch_mode"], postprocessor_types=rollout_config["postprocessor_types"], @@ -358,77 +255,28 @@ def init_actor_pool( for _ in range(num_threads + num_eval_threads) ] ) - return actor_pool - - def init_servers(self): - """Connect to data servers. + return env_runner_pool - Raises: - RuntimeError: Runtime errors. - """ - - retries = 100 - while True: - try: - if self.parameter_server is None: - self.parameter_server = ray.get_actor( - settings.PARAMETER_SERVER_ACTOR - ) - - if self.dataset_server is None: - self.dataset_server = ray.get_actor(settings.OFFLINE_DATASET_ACTOR) - break - except Exception as e: - retries -= 1 - if retries == 0: - raise RuntimeError(traceback.format_exc()) - else: - logger.log( - logging.WARNING, - f"waiting for coordinator server initialization ... {self.worker_indentifier}", - ) - time.sleep(1) - - def rollout( - self, - runtime_strategy_specs: Dict[str, StrategySpec], - stopping_conditions: Dict[str, Any], - data_entrypoints: Dict[str, str] = None, - active_agents: List[AgentID] = None, - ): + def rollout(self, task: RolloutTask): """Rollout, collecting training data when `data_entrypoints` is given, until meets the stopping conditions. The `active_agents` should be None or a none-empty list to specify active agents if rollout is not serve for evaluation. NOTE: the data collection will be triggered only for active agents. Args: - runtime_strategy_specs (Dict[str, StrategySpec]): A dict of strategy spec, mapping from runtime id to `StrategySpec`. - stopping_conditions (Dict[str, Any]): A dict of stopping conditions. - data_entrypoints (Dict[str, str], optional): Mapping from runtimeids to dataentrypoint names. None for evaluation. - active_agents (List[AgentID], optional): A list of environment agent id. Defaults to None, which means all environment agents will be trainable. Empty list for evaluation mode. 
+            task (RolloutTask): The task to execute, carrying strategy specs,
+                stopping conditions and an optional data entrypoint mapping.
         """
 
-        stopper = get_stopper(stopping_conditions)
-        active_agents = active_agents or self.env_agents
+        stopper = get_stopper(task.stopping_conditions)
+        runtime_strategy_specs = task.strategy_specs
+        data_entrypoint_mapping = task.data_entrypoint_mapping
 
         rollout_config = self.rollout_config.copy()
         rollout_config.update(
             {
                 "flag": "rollout",
                 "strategy_specs": runtime_strategy_specs,
-                "active_agents": active_agents,
-                "agent_group": self.agent_group,
+                "behavior_mode": BehaviorMode.EXPLORATION,
             }
         )
         total_timesteps = 0
@@ -443,10 +291,11 @@ def rollout(
 
         self.set_running(True)
         start_time = time.time()
-        # TODO(ming): share the stopping conditions here
         while self.is_running():
             eval_step = (epoch + 1) % self.rollout_config["eval_interval"] == 0
-            results = self.step_rollout(eval_step, rollout_config, queue_info_dict)
+            results = self.step_rollout(
+                eval_step, rollout_config, data_entrypoint_mapping
+            )
             total_timesteps += results["total_timesteps"]
 
             eval_results = results.get("evaluation", None)
@@ -480,29 +329,12 @@ def rollout(
         self.rollout_callback(self.coordinator, results)
         return results
 
-    def simulate(self, runtime_strategy_specs: Dict[str, StrategySpec]):
-        """Handling simulation task."""
-
-        runtime_config_template = self.rollout_config.copy()
-        runtime_config_template.update(
-            {
-                "flag": "simulation",
-            }
-        )
-
-        results: Dict[str, Any] = self.step_simulation(
-            runtime_strategy_specs, runtime_config_template
-        )
-
-        self.simulate_callback(self.coordinator, results)
-        return results
-
     @abstractmethod
     def step_rollout(
         self,
         eval_step: bool,
         rollout_config: Dict[str, Any],
-        dataset_writer_info_dict: Dict[str, Any],
+        data_entrypoint_mapping: Dict[AgentID, str],
     ) -> List[Dict[str, Any]]:
         """The logic function to run rollout. Users must implement this method.
 
@@ -521,6 +353,7 @@ def step_rollout(
             - `agent_group`: a dict that maps runtime agents to a list of environment agents, which describes the environment agents \
                 governed by what runtime agent interface.
             - `fragment_length`: the maximum of collected data frames.
+            data_entrypoint_mapping: A mapping from runtime ids to data entrypoint addresses; None disables data collection.
 
         Raises:
             NotImplementedError: Not implemented.
 
         Returns:
             List[Dict[str, Any]]: Evaluation results, could be empty.
         """
 
-    @abstractmethod
-    def step_simulation(
-        self,
-        runtime_strategy_specs: Dict[str, StrategySpec],
-        rollout_config: Dict[str, Any],
-    ) -> Dict[str, Any]:
-        """Logic function for running simulation of a list of strategy spec dict.
-
-        Args:
-            runtime_strategy_specs (Dict[str, StrategySpec]): A strategy spec dict.
-            rollout_config (Dict[str, Any]): Runtime configuration template.
-
-        Raises:
-            NotImplementedError: Not implemented error.
-
-        Returns:
-            Dict[str, Any]: Evaluation results.
- """ - def assign_episode_id(self): return f"eps-{self.worker_indentifier}-{time.time()}" diff --git a/malib/scenarios/sarl_scenario.py b/malib/scenarios/sarl_scenario.py index d5f0ebd7..64a6da1f 100644 --- a/malib/scenarios/sarl_scenario.py +++ b/malib/scenarios/sarl_scenario.py @@ -71,6 +71,7 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = algorithms=scenario.algorithms, env_desc=scenario.env_desc, agent_mapping_func=scenario.agent_mapping_func, + group_info=scenario.group_info, training_config=scenario.training_config, log_dir=scenario.log_dir, remote_mode=True, @@ -83,6 +84,7 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = stopping_conditions=scenario.stopping_conditions, num_worker=scenario.num_worker, agent_mapping_func=scenario.agent_mapping_func, + group_info=scenario.group_info, rollout_config=scenario.rollout_config, env_desc=scenario.env_desc, log_dir=scenario.log_dir, @@ -97,13 +99,8 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = f"Training manager was inistialized with a strategy spec:\n{strategy_specs}" ) - data_entrypoints = {rid: rid for rid in training_manager.runtime_ids} - - assert len(data_entrypoints) == 1, "Support single agent only!" - optimization_task = OptimizationTask( active_agents=None, - data_entrypoints=data_entrypoints, stop_conditions=scenario.stopping_conditions["training"], ) training_manager.submit(optimization_task) @@ -111,7 +108,7 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = rollout_task = { "num_workers": 1, "runtime_strategy_specs": strategy_specs, - "data_entrypoints": None, + "data_entrypoints": training_manager.get_data_entrypoint_mapping(), "rollout_config": scenario.rollout_config, "active_agents": None, } @@ -123,8 +120,8 @@ def execution_plan(experiment_tag: str, scenario: SARLScenario, verbose: bool = ), } - rollout_manager.submit(rollout_task, task_type=TaskType.ROLLOUT) - rollout_manager.submit(evaluation_task, task_type=TaskType.EVALUATION) + rollout_manager.submit(rollout_task) + rollout_manager.submit(evaluation_task) results = league.get_results() diff --git a/malib/scenarios/scenario.py b/malib/scenarios/scenario.py index 6a25c256..f009a085 100644 --- a/malib/scenarios/scenario.py +++ b/malib/scenarios/scenario.py @@ -22,13 +22,50 @@ from abc import ABC, abstractmethod from types import LambdaType -from typing import Callable, Union, Dict, Any +from typing import Callable, Union, Dict, Any, Set, List from copy import deepcopy +from collections import defaultdict + +import gym + +from malib.utils.typing import AgentID DEFAULT_STOPPING_CONDITIONS = {} +def validate_spaces(agent_groups: Dict[str, Set[AgentID]], env_desc: Dict[str, Any]): + # TODO(ming): check whether the agents in the group share the same observation space and action space + raise NotImplementedError + + +def validate_agent_group( + agent_group: Dict[str, List[AgentID]], + observation_spaces: Dict[AgentID, gym.Space], + action_spaces: Dict[AgentID, gym.Space], +) -> None: + """Validate agent group, check spaces. + + Args: + agent_group (Dict[str, List[AgentID]]): A dict, mapping from runtime ids to lists of agent ids. + full_keys (List[AgentID]): A list of original environment agent ids. + observation_spaces (Dict[AgentID, gym.Space]): Agent observation space dict. + action_spaces (Dict[AgentID, gym.Space]): Agent action space dict. 
+
+    Raises:
+        RuntimeError: Agents in the same group must share the same observation and action spaces.
+    """
+    for agents in agent_group.values():
+        # accept both sequences and sets of agent ids
+        agents = list(agents)
+        select_obs_space = observation_spaces[agents[0]]
+        select_act_space = action_spaces[agents[0]]
+        for agent in agents[1:]:
+            assert type(select_obs_space) == type(observation_spaces[agent])
+            assert select_obs_space.shape == observation_spaces[agent].shape
+            assert type(select_act_space) == type(action_spaces[agent])
+            assert select_act_space.shape == action_spaces[agent].shape
+
+
 class Scenario(ABC):
     @abstractmethod
     def __init__(
@@ -49,6 +86,23 @@ def __init__(
         self.env_desc = env_desc
         self.algorithms = algorithms
         self.agent_mapping_func = agent_mapping_func
+        # then generate grouping information here
+        agent_groups = defaultdict(lambda: set())
+        grouped_obs_space = {}
+        grouped_act_space = {}
+        for agent in env_desc["possible_agents"]:
+            rid = agent_mapping_func(agent)
+            agent_groups[rid].add(agent)
+            grouped_obs_space[rid] = env_desc["observation_spaces"][agent]
+            grouped_act_space[rid] = env_desc["action_spaces"][agent]
+        self.group_info = {
+            "observation_space": grouped_obs_space,
+            "action_space": grouped_act_space,
+            "agent_groups": agent_groups,
+        }
+        validate_agent_group(
+            agent_groups, env_desc["observation_spaces"], env_desc["action_spaces"]
+        )
         self.training_config = training_config
         self.rollout_config = rollout_config
         self.stopping_conditions = stopping_conditions or DEFAULT_STOPPING_CONDITIONS
diff --git a/tests/rollout/test_env_runner.py b/tests/rollout/test_env_runner.py
new file mode 100644
index 00000000..7085bd40
--- /dev/null
+++ b/tests/rollout/test_env_runner.py
@@ -0,0 +1,42 @@
+from typing import Dict, Any
+
+import pytest
+
+from malib.utils.typing import BehaviorMode
+from malib.common.strategy_spec import StrategySpec
+from malib.rollout.inference import env_runner
+from malib.rollout.inference.client import InferenceClient
+from malib.rollout.envs import mdp
+
+
+@pytest.mark.parametrize(
+    "env_desc,max_env_num",
+    [
+        [mdp.env_desc_gen(env_id="multi_round_nmdp"), 1],
+    ],
+)
+def test_env_runner(env_desc: Dict[str, Any], max_env_num: int):
+    # one runtime group per environment agent
+    agent_groups = {agent: {agent} for agent in env_desc["possible_agents"]}
+    runner = env_runner.EnvRunner(env_desc, max_env_num, agent_groups)
+
+    agents = env_desc["possible_agents"]
+    observation_spaces = env_desc["observation_spaces"]
+    action_spaces = env_desc["action_spaces"]
+
+    inference_remote_cls = InferenceClient.as_remote(num_cpus=1)
+    rollout_config = {
+        "flag": "evaluation",
+        "strategy_specs": {
+            agent: StrategySpec(agent, ["policy-0"], meta_data={}) for agent in agents
+        },
+        "behavior_mode": BehaviorMode.EXPLOITATION,
+    }
+
+    infer_clients = {
+        agent: inference_remote_cls.remote(
+            agent, observation_spaces[agent], action_spaces[agent]
+        )
+        for agent in agents
+    }
+
+    runner.run(infer_clients, rollout_config)
diff --git a/tests/rollout/test_ray_inference.py b/tests/rollout/test_ray_inference.py
deleted file mode 100644
index 478b6144..00000000
--- a/tests/rollout/test_ray_inference.py
+++ /dev/null
@@ -1,397 +0,0 @@
-# MIT License
-
-# Copyright (c) 2021 MARL @ SJTU
-
-# Author: Ming Zhou
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or
sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from typing import Callable, Dict, Any, List, Tuple -from argparse import Namespace -from collections import defaultdict - -import pytest -import ray - -from malib.agent.agent_interface import AgentInterface -from malib.agent.manager import TrainingManager -from malib.backend.parameter_server import ParameterServer - -# from malib.rollout.envs.dummy_env import env_desc_gen -from malib.runner import start_servers -from malib.rollout.envs.gym import env_desc_gen as gym_env_desc_gen -from malib.rollout.envs.open_spiel import env_desc_gen as open_spiel_env_desc_gen -from malib.rollout.envs.vector_env import VectorEnv -from malib.rollout.inference.utils import process_policy_outputs -from malib.rollout.rolloutworker import parse_rollout_info -from malib.utils.episode import Episode, NewEpisodeDict -from malib.utils.typing import AgentID, PolicyID -from malib.agent.indepdent_agent import IndependentAgent -from malib.common.strategy_spec import StrategySpec -from malib.scenarios.marl_scenario import MARLScenario -from malib.rollout.inference.ray.server import RayInferenceWorkerSet -from malib.rollout.inference.ray.client import env_runner, RayInferenceClient -from malib.utils.typing import BehaviorMode - - -def dqn(): - from malib.rl.dqn import DQNPolicy, DQNTrainer, DEFAULT_CONFIG - - algorithms = { - "default": ( - DQNPolicy, - DQNTrainer, - # model configuration, None for default - { - "net_type": "general_net", - "config": {"hidden_sizes": [64, 64]}, - }, - {}, - ) - } - trainer_config = DEFAULT_CONFIG["training_config"].copy() - return [algorithms, trainer_config] - - -def build_marl_scenario( - algorithms: Dict[str, Dict], - env_description: Dict[str, Any], - learner_cls, - trainer_config: Dict[str, Any], - agent_mapping_func: Callable, - runtime_logdir: str, -) -> MARLScenario: - training_config = { - "type": learner_cls, - "trainer_config": trainer_config, - "custom_config": {}, - } - rollout_config = { - "fragment_length": 200, # every thread - "max_step": 20, - "num_eval_episodes": 10, - "num_threads": 2, - "num_env_per_thread": 2, - "num_eval_threads": 1, - "use_subproc_env": False, - "batch_mode": "time_step", - "postprocessor_types": ["defaults"], - # every # rollout epoch run evaluation. 
- "eval_interval": 1, - "inference_server": "ray", # three kinds of inference server: `local`, `pipe` and `ray` - } - scenario = MARLScenario( - name="test_ray_inference", - log_dir=runtime_logdir, - algorithms=algorithms, - env_description=env_description, - training_config=training_config, - rollout_config=rollout_config, - agent_mapping_func=agent_mapping_func, - stopping_conditions={ - "training": {"max_iteration": int(1e10)}, - "rollout": {"max_iteration": 1000, "minimum_reward_improvement": 1.0}, - }, - ) - return scenario - - -def push_policy_to_parameter_server( - scenario: MARLScenario, parameter_server: ParameterServer -) -> Dict[AgentID, StrategySpec]: - """Generate a dict of strategy spec, generate policies and push them to the remote parameter server. - - Args: - scenario (MARLScenario): Scenario instance. - agents (List[AgentID]): A list of enviornment agents. - parameter_server (ParameterServer): Remote parameter server. - - Returns: - Dict[AgentID, StrategySpec]: A dict of strategy specs. - """ - - res = dict() - for agent in scenario.env_desc["possible_agents"]: - sid = scenario.agent_mapping_func(agent) - if sid in res: - continue - spec_pid = f"policy-0" - strategy_spec = StrategySpec( - identifier=sid, - policy_ids=[spec_pid], - meta_data={ - "policy_cls": scenario.algorithms["default"][0], - "experiment_tag": "test_ray_inference", - "kwargs": { - "observation_space": scenario.env_desc["observation_spaces"][agent], - "action_space": scenario.env_desc["action_spaces"][agent], - "model_config": scenario.algorithms["default"][2], - "custom_config": scenario.algorithms["default"][3], - "kwargs": {}, - }, - }, - ) - policy = strategy_spec.gen_policy() - ray.get(parameter_server.create_table.remote(strategy_spec)) - ray.get( - parameter_server.set_weights.remote( - spec_id=strategy_spec.id, - spec_policy_id=spec_pid, - state_dict=policy.state_dict(), - ) - ) - res[sid] = strategy_spec - return res - - -def generate_cs( - scenario: MARLScenario, dataset_server, parameter_server -) -> Tuple[RayInferenceClient, Dict[str, RayInferenceWorkerSet]]: - env_desc = scenario.env_desc - observation_spaces = env_desc["observation_spaces"] - action_spaces = env_desc["action_spaces"] - servers = dict.fromkeys(env_desc["possible_agents"], None) - agent_group = defaultdict(list) - for agent in env_desc["possible_agents"]: - sid = scenario.agent_mapping_func(agent) - agent_group[sid].append(agent) - - client = RayInferenceClient( - env_desc=scenario.env_desc, - dataset_server=dataset_server, - max_env_num=scenario.rollout_config["num_env_per_thread"], - use_subproc_env=scenario.rollout_config["use_subproc_env"], - batch_mode=scenario.rollout_config["batch_mode"], - postprocessor_types=scenario.rollout_config["postprocessor_types"], - training_agent_mapping=scenario.agent_mapping_func, - ) - - for sid, agents in agent_group.items(): - servers[sid] = RayInferenceWorkerSet( - agent_id=sid, - observation_space=observation_spaces[agent], - action_space=action_spaces[agent], - parameter_server=parameter_server, - governed_agents=agents.copy(), - ) - - return client, servers - - -from malib.rollout.inference.ray.client import process_env_rets - - -def rollout_func( - episode_dict: NewEpisodeDict, - client: RayInferenceClient, - servers: Dict[str, RayInferenceWorkerSet], - rollout_config, - server_runtime_config, - evaluate, -): - env_rets = client.env.reset( - fragment_length=rollout_config["fragment_length"], - max_step=rollout_config["max_step"], - ) - processed_env_ret, dataframes = 
process_env_rets( - env_rets, - preprocessor=server_runtime_config["preprocessor"], - preset_meta_data={"evaluate": evaluate}, - ) - if episode_dict is not None: - episode_dict.record(processed_env_ret, agent_first=False) - - cnt = 0 - while not client.env.is_terminated(): - grouped_dataframes = defaultdict(list) - for agent, dataframe in dataframes.items(): - runtime_id = client.training_agent_mapping(agent) - grouped_dataframes[runtime_id].append(dataframe) - - policy_outputs = { - rid: server.compute_action( - grouped_dataframes[rid], runtime_config=server_runtime_config - ) - for rid, server in servers.items() - } - - env_actions, processed_policy_outputs = process_policy_outputs( - policy_outputs, client.env - ) - - assert len(env_actions) > 0, "inference server may be stucked." - - if episode_dict is not None: - episode_dict.record(processed_policy_outputs, agent_first=True) - - env_rets = client.env.step(env_actions) - if len(env_rets) < 1: - dataframes = {} - continue - - processed_env_ret, dataframes = process_env_rets( - env_rets, - preprocessor=server_runtime_config["preprocessor"], - preset_meta_data={"evaluate": evaluate}, - ) - - if episode_dict is not None: - episode_dict.record(processed_env_ret, agent_first=False) - - cnt += 1 - - -def data_servers(): - if not ray.is_initialized(): - ray.init() - - parameter_server, offline_dataset_server = start_servers() - return parameter_server, offline_dataset_server - - -@pytest.mark.parametrize( - "env_desc", - [ - gym_env_desc_gen(env_id="CartPole-v1"), - # open_spiel_env_desc_gen(env_id="kuhn_poker"), - # mdp_env_desc_gen(env_id="two_round_dmdp"), - ], -) -@pytest.mark.parametrize("learner_cls", [IndependentAgent]) -@pytest.mark.parametrize("algorithms,trainer_config", [dqn()]) -def test_inference_mechanism(env_desc, learner_cls, algorithms, trainer_config): - parameter_server, dataset_server = data_servers() - scenario: MARLScenario = build_marl_scenario( - algorithms, - env_desc, - learner_cls, - trainer_config, - agent_mapping_func=lambda agent: agent, - runtime_logdir="./logs", - ) - client, servers = generate_cs(scenario, dataset_server, parameter_server) - training_manager = TrainingManager( - experiment_tag=scenario.name, - stopping_conditions=scenario.stopping_conditions, - algorithms=scenario.algorithms, - env_desc=scenario.env_desc, - agent_mapping_func=scenario.agent_mapping_func, - training_config=scenario.training_config, - log_dir=scenario.log_dir, - remote_mode=True, - resource_config=scenario.resource_config["training"], - verbose=True, - ) - data_entrypoints = {k: k for k in training_manager.agent_groups.keys()} - - # add policies and start training - strategy_specs = training_manager.add_policies(n=scenario.num_policy_each_interface) - strategy_specs = strategy_specs - data_entrypoints = data_entrypoints - - rollout_config = scenario.rollout_config.copy() - rollout_config["flag"] = "rollout" - - server_runtime_config = { - "strategy_specs": strategy_specs, - "behavior_mode": BehaviorMode.EXPLOITATION, - "preprocessor": client.preprocessor, - } - - dwriter_info_dict = dict.fromkeys(data_entrypoints.keys(), None) - - for rid, identifier in data_entrypoints.items(): - queue_id, queue = ray.get( - dataset_server.start_producer_pipe.remote(name=identifier) - ) - dwriter_info_dict[rid] = (queue_id, queue) - - eval_results, performance_results = env_runner( - client, - servers, - rollout_config, - server_runtime_config, - dwriter_info_dict, - ) - eval_results = parse_rollout_info([{"evaluation": eval_results}]) - 
print(eval_results["evaluation"]) - print(performance_results) - - for rid, identifier in data_entrypoints.items(): - ray.get(dataset_server.end_producer_pipe.remote(identifier)) - - ray.kill(parameter_server) - ray.kill(dataset_server) - ray.shutdown() - - -# def test_inference_pipeline(self): -# """This function tests the inference pipeline without using default env runner""" - -# training_manager.run(data_entrypoints) - -# rollout_config = scenario.rollout_config.copy() -# rollout_config["flag"] = "rollout" -# server_runtime_config = { -# "strategy_specs": strategy_specs, -# "behavior_mode": BehaviorMode.EXPLOITATION, -# "preprocessor": client.preprocessor, -# } - -# dwriter_info_dict = dict.fromkeys(data_entrypoints.keys(), None) - -# for rid, identifier in data_entrypoints.items(): -# queue_id, queue = ray.get( -# dataset_server.start_producer_pipe.remote(name=identifier) -# ) -# dwriter_info_dict[rid] = (queue_id, queue) - -# # collect episodes and run training -# rollout_env = client.env -# for n_epoch in range(2): -# episode_dict = NewEpisodeDict( -# lambda: Episode(agents=scenario.env_desc["possible_agents"]) -# ) -# rollout_func( -# episode_dict, -# client, -# servers, -# rollout_config, -# server_runtime_config, -# False, -# ) - -# episodes = episode_dict.to_numpy() -# for rid, writer_info in dwriter_info_dict.items(): -# agents = client.agent_group[rid] -# batches = [] -# for episode in episodes.values(): -# agent_buffer = [episode[aid] for aid in agents] -# batches.append(agent_buffer) -# writer_info[-1].put_nowait_batch(batches) -# rollout_info = client.env.collect_info() -# eval_results = list(rollout_info.values()) -# rollout_res = parse_rollout_info([{"evaluation": eval_results}]) - -# print("epoch: {}\nrollout_res: {}\n".format(n_epoch, rollout_res)) - -# client.env = rollout_env - -# training_manager.cancel_pending_tasks() -# # training_manager.terminate()