Commit

tmp save
Ming Zhou committed Oct 27, 2023
1 parent 19ed9ab commit c01e463
Showing 19 changed files with 257 additions and 868 deletions.
24 changes: 8 additions & 16 deletions malib/agent/manager.py
@@ -57,11 +57,6 @@
 )


-def validate_spaces(agent_groups: Dict[str, Set[AgentID]], env_desc: Dict[str, Any]):
-    # TODO(ming): check whether the agents in the group share the same observation space and action space
-    raise NotImplementedError
-
-
 class TrainingManager(Manager):
     def __init__(
         self,
@@ -70,6 +65,7 @@ def __init__(
         algorithms: Dict[str, Any],
         env_desc: Dict[str, Any],
         agent_mapping_func: Callable[[AgentID], str],
+        group_info: Dict[str, Any],
         training_config: Union[Dict[str, Any], TrainingConfig],
         log_dir: str,
         remote_mode: bool = True,
@@ -98,16 +94,10 @@ def __init__(
         training_config = TrainingConfig.from_raw(training_config)

         # interface config give the agent type used here and the group mapping if needed
-        agent_groups = defaultdict(lambda: set())
-        for agent in env_desc["possible_agents"]:
-            rid = agent_mapping_func(agent)
-            agent_groups[rid].add(agent)
-
-        validate_spaces(agent_groups, env_desc)

         # FIXME(ming): resource configuration is not available now, will open in the next version
         if training_config.trainer_config.get("use_cuda", False):
-            num_gpus = 1 / len(agent_groups)
+            num_gpus = 1 / len(group_info["agent_groups"])
         else:
             num_gpus = 0.0
         if not os.path.exists(log_dir):
@@ -125,7 +115,7 @@ def __init__(
"training" in stopping_conditions
), f"Stopping conditions should contains `training` stoppong conditions: {stopping_conditions}"

for rid, agents in agent_groups.items():
for rid, agents in group_info["agent_groups"].items():
_cls = learner_cls.remote if remote_mode else learner_cls
learners[rid] = _cls(
experiment_tag=experiment_tag,
@@ -145,8 +135,7 @@ def __init__(
         _ = ray.get([x.connect.remote() for x in learners.values()])

         # TODO(ming): collect data entrypoints from learners
-
-        self._agent_groups = agent_groups
+        self._group_info = group_info
         self._runtime_ids = tuple(self._agent_groups.keys())
         self._experiment_tag = experiment_tag
         self._env_description = env_desc
@@ -170,7 +159,7 @@ def agent_groups(self) -> Dict[str, Set[AgentID]]:
             Dict[str, Set[AgentID]]: A dict of agent set.
         """

-        return self._agent_groups
+        return self._group_info["agent_groups"]

     @property
     def get_data_entrypoints(self) -> Dict[str, str]:
@@ -202,6 +191,9 @@ def runtime_ids(self) -> Tuple[str]:

         return self._runtime_ids

+    def get_data_entrypoint_mapping(self) -> Dict[AgentID, str]:
+        raise NotImplementedError
+
     def add_policies(
         self, interface_ids: Sequence[str] = None, n: Union[int, Dict[str, int]] = 1
     ) -> Dict[str, Type[StrategySpec]]:
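The grouping loop deleted above moves out of the manager: callers now hand in a prebuilt group_info dict. A minimal sketch of such a builder, assuming only the "agent_groups" key that this diff actually reads (the helper name build_group_info is hypothetical):

from collections import defaultdict
from typing import Any, Callable, Dict

def build_group_info(env_desc: Dict[str, Any], agent_mapping_func: Callable) -> Dict[str, Any]:
    # Mirrors the grouping loop this commit removes from TrainingManager.__init__.
    agent_groups = defaultdict(set)
    for agent in env_desc["possible_agents"]:
        rid = agent_mapping_func(agent)
        agent_groups[rid].add(agent)
    # Only "agent_groups" is consumed by the code shown in this diff; a real
    # group_info may carry more keys (e.g., shared observation/action spaces).
    return {"agent_groups": dict(agent_groups)}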
Empty file.
17 changes: 13 additions & 4 deletions malib/backend/dataset_server/utils.py
@@ -2,17 +2,26 @@

 import pickle
 import grpc
+import sys
+import os
+
+sys.path.append(os.path.dirname(__file__))

 from . import data_pb2
 from . import data_pb2_grpc


-def send_data(host: str, port: int, data: Any):
+def send_data(data: Any, host: str = None, port: int = None, entrypoint: str = None):
     if not isinstance(data, bytes):
         data = pickle.dumps(data)

-    with grpc.insecure_channel(f"{host}:{port}") as channel:
-        stub = data_pb2_grpc.SendDataStub(channel)
-        reply = stub.Collect(data_pb2.Data(data=data))
+    if host is not None:
+        with grpc.insecure_channel(f"{host}:{port}") as channel:
+            stub = data_pb2_grpc.SendDataStub(channel)
+            reply = stub.Collect(data_pb2.Data(data=data))
+    else:
+        with grpc.insecure_channel(entrypoint) as channel:
+            stub = data_pb2_grpc.SendDataStub(channel)
+            reply = stub.Collect(data_pb2.Data(data=data))

     return reply.message
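With the widened signature, callers choose either the host/port pair or a single entrypoint address; both branches build the same SendData stub. A usage sketch (addresses are illustrative):

from malib.backend.dataset_server.utils import send_data

batch = {"obs": [0.1, 0.2], "action": 1}  # any picklable payload

# host/port form
msg = send_data(batch, host="127.0.0.1", port=50051)

# entrypoint form, e.g. an address taken from a data-entrypoint mapping
msg = send_data(batch, entrypoint="127.0.0.1:50051")

The entrypoint form pairs naturally with the data_entrypoint_mapping fields introduced elsewhere in this commit.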
6 changes: 2 additions & 4 deletions malib/common/task.py
@@ -15,8 +15,9 @@ class TaskType(IntEnum):
 @dataclass
 class RolloutTask:
     task_type: int
-    active_agents: List[AgentID]
     strategy_specs: Dict[str, Any] = field(default_factory=dict)
     stopping_conditions: Dict[str, Any] = field(default_factory=dict)
+    data_entrypoint_mapping: Dict[str, Any] = field(default_factory=dict)

     @classmethod
     def from_raw(
@@ -32,9 +33,6 @@ def from_raw(

 @dataclass
 class OptimizationTask:
-    data_entrypoints: Dict[str, str]
-    """a mapping defines the data request identifier and the data entrypoint."""
-
     stop_conditions: Dict[str, Any]
     """stopping conditions for optimization task, e.g., max iteration, max time, etc."""

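A hypothetical construction of the extended RolloutTask, wiring a runtime id to a gRPC data entrypoint; the TaskType.ROLLOUT member and the address are assumptions:

from malib.common.task import RolloutTask, TaskType

task = RolloutTask(
    task_type=TaskType.ROLLOUT,  # assumed member of the TaskType IntEnum
    strategy_specs={},  # left empty for brevity
    stopping_conditions={"max_iteration": 1000},
    data_entrypoint_mapping={"team_0": "127.0.0.1:50051"},
)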
16 changes: 16 additions & 0 deletions malib/rl/common/policy.py
@@ -28,6 +28,7 @@

 import torch
 import torch.nn as nn
+import gym

 from gym import spaces

@@ -100,6 +101,21 @@ def __init__(
             use_sde=custom_config.get("use_sde", False),
             dist_kwargs=custom_config.get("dist_kwargs", None),
         )
+        if kwargs.get("model_client"):
+            self.model = kwargs["model_client"]
+        else:
+            self.model = self.create_model()
+
+    def create_model(self):
+        raise NotImplementedError
+
+    @property
+    def action_space(self) -> gym.Space:
+        return self._action_space
+
+    @property
+    def observation_space(self) -> gym.Space:
+        return self._observation_space
+
     @property
     def model_config(self):
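Subclasses now implement create_model(), and __init__ adopts an injected model_client instead when one is passed. A minimal hypothetical subclass (the network shape is arbitrary):

import torch.nn as nn

from malib.rl.common.policy import Policy

class TinyDiscretePolicy(Policy):
    def create_model(self) -> nn.Module:
        # Spaces are read from the new properties rather than passed around.
        obs_dim = self.observation_space.shape[0]
        act_dim = self.action_space.n
        return nn.Sequential(
            nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, act_dim)
        )

# When a shared model handle already exists (e.g., in a remote worker), it can
# be injected instead of rebuilt:
# policy = TinyDiscretePolicy(obs_space, act_space, model_config, custom_config,
#                             model_client=shared_model)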
23 changes: 13 additions & 10 deletions malib/rl/pg/policy.py
@@ -70,36 +70,39 @@ def __init__(
             observation_space, action_space, model_config, custom_config, **kwargs
         )

+    def create_model(self):
         # update model preprocess_net config here
         action_shape = (
-            (action_space.n,) if len(action_space.shape) == 0 else action_space.shape
+            (self.action_space.n,)
+            if len(self.action_space.shape) == 0
+            else self.action_space.shape
         )

         preprocess_net: nn.Module = net.make_net(
-            observation_space,
+            self.observation_space,
             self.device,
-            model_config["preprocess_net"].get("net_type", None),
-            **model_config["preprocess_net"]["config"]
+            self.model_config["preprocess_net"].get("net_type", None),
+            **self.model_config["preprocess_net"]["config"]
         )
-        if isinstance(action_space, spaces.Discrete):
+        if isinstance(self.action_space, spaces.Discrete):
             self.actor = discrete.Actor(
                 preprocess_net=preprocess_net,
                 action_shape=action_shape,
-                hidden_sizes=model_config["hidden_sizes"],
+                hidden_sizes=self.model_config["hidden_sizes"],
                 softmax_output=False,
                 device=self.device,
             )
-        elif isinstance(action_space, spaces.Box):
+        elif isinstance(self.action_space, spaces.Box):
             self.actor = continuous.Actor(
                 preprocess_net=preprocess_net,
                 action_shape=action_shape,
-                hidden_sizes=model_config["hidden_sizes"],
-                max_action=custom_config.get("max_action", 1.0),
+                hidden_sizes=self.model_config["hidden_sizes"],
+                max_action=self.custom_config.get("max_action", 1.0),
                 device=self.device,
             )
         else:
             raise TypeError(
-                "Unexpected action space type: {}".format(type(action_space))
+                "Unexpected action space type: {}".format(type(self.action_space))
             )

         self.register_state(self.actor, "actor")
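Since create_model() now reads everything from instance attributes, construction needs nothing beyond the usual four arguments. A hypothetical instantiation (the PGPolicy class name is inferred from the file path, and the inner contents of the preprocess_net config are assumptions):

import gym

from malib.rl.pg.policy import PGPolicy  # class name assumed

policy = PGPolicy(
    observation_space=gym.spaces.Box(low=-1.0, high=1.0, shape=(4,)),
    action_space=gym.spaces.Discrete(2),
    model_config={
        "preprocess_net": {"net_type": None, "config": {"hidden_sizes": [64]}},
        "hidden_sizes": [64],
    },
    custom_config={},
)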
6 changes: 3 additions & 3 deletions malib/rollout/__init__.py
@@ -23,8 +23,8 @@
 # SOFTWARE.

 from .pb_rolloutworker import RolloutWorker
-from .inference.ray.client import RayInferenceClient as InferenceClient
-from .inference.ray.server import RayInferenceWorkerSet as InferenceWorkerSet
+from .inference.env_runner import EnvRunner
+from .inference.client import InferenceClient


-__all__ = ["RolloutWorker", "InferenceClient", "InferenceWorkerSet"]
+__all__ = ["RolloutWorker", "EnvRunner", "InferenceClient"]
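For downstream code, the rename maps roughly as follows (a migration sketch; whether EnvRunner fully replaces the old InferenceWorkerSet role is an assumption):

# Before this commit:
# from malib.rollout import RolloutWorker, InferenceClient, InferenceWorkerSet

# After this commit:
from malib.rollout import RolloutWorker, EnvRunner, InferenceClient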
@@ -31,9 +31,7 @@
 import os

 import pickle as pkl
-import ray
 import gym
-import torch

 from malib import settings
 from malib.remote.interface import RemoteInterface
@@ -42,47 +40,40 @@
 from malib.utils.episode import Episode
 from malib.common.strategy_spec import StrategySpec
 from malib.rl.common.policy import Policy
-from malib.backend.parameter_server import ParameterServer


-ClientHandler = namedtuple("ClientHandler", "sender,recver,runtime_config,rnn_states")
+Connection = namedtuple("Connection", "sender,recver,runtime_config,rnn_states")


-class RayInferenceWorkerSet(RemoteInterface):
+class InferenceClient(RemoteInterface):
     def __init__(
         self,
         agent_id: AgentID,
         observation_space: gym.Space,
         action_space: gym.Space,
-        parameter_server: ParameterServer,
         governed_agents: List[AgentID],
     ) -> None:
         """Create ray-based inference server.

         Args:
             agent_id (AgentID): Runtime agent id, not environment agent id.
             observation_space (gym.Space): Observation space related to the governed environment agents.
             action_space (gym.Space): Action space related to the governed environment agents.
-            parameter_server (ParameterServer): Parameter server.
             governed_agents (List[AgentID]): A list of environment agents.
         """

         self.runtime_agent_id = agent_id
         self.observation_space = observation_space
         self.action_space = action_space
-        self.parameter_server = parameter_server

         self.thread_pool = ThreadPoolExecutor()
         self.governed_agents = governed_agents
         self.policies: Dict[str, Policy] = {}
         self.strategy_spec_dict: Dict[str, StrategySpec] = {}

     def shutdown(self):
         self.thread_pool.shutdown(wait=True)
-        for _handler in self.clients.values():
+        for _handler in self.connections.values():
             _handler.sender.shutdown(True)
             _handler.recver.shutdown(True)
-        self.clients: Dict[int, ClientHandler] = {}
+        self.connections: Dict[int, Connection] = {}

     def save(self, model_dir: str) -> None:
         if not os.path.exists(model_dir):
@@ -100,12 +91,6 @@ def compute_action(
         strategy_specs: Dict[AgentID, StrategySpec] = runtime_config["strategy_specs"]
         return_dataframes: List[DataFrame] = []

-        # check policy
-        self._update_policies(
-            runtime_config["strategy_specs"][self.runtime_agent_id],
-            self.runtime_agent_id,
-        )
-
         assert len(dataframes) > 0

         for dataframe in dataframes:
@@ -132,16 +117,6 @@

             rets = {}

-            with timer.time_avg("policy_update"):
-                info = ray.get(
-                    self.parameter_server.get_weights.remote(
-                        spec_id=spec.id,
-                        spec_policy_id=spec_policy_id,
-                    )
-                )
-                if info["weights"] is not None:
-                    self.policies[policy_id].load_state_dict(info["weights"])
-
             with timer.time_avg("compute_action"):
                 (
                     rets[Episode.ACTION],
@@ -170,19 +145,12 @@ def compute_action(
                     continue
                 else:
                     rets[k] = v.reshape(batch_size, -1)

             return_dataframes.append(
                 DataFrame(identifier=agent_id, data=rets, meta_data=dataframe.meta_data)
             )
-        # print(f"timer information: {timer.todict()}")
         return return_dataframes

-    def _update_policies(self, strategy_spec: StrategySpec, agent_id: AgentID):
-        for strategy_spec_pid in strategy_spec.policy_ids:
-            policy_id = f"{strategy_spec.id}/{strategy_spec_pid}"
-            if policy_id not in self.policies:
-                policy = strategy_spec.gen_policy(device="cpu")
-                self.policies[policy_id] = policy
-
     def _get_initial_states(self, client_id, observation, policy: Policy, identifier):
         if (
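With _update_policies and the parameter-server sync deleted, compute_action now assumes self.policies is populated before it runs. A hypothetical preloading helper that mirrors the removed loop:

from malib.common.strategy_spec import StrategySpec

def preload_policies(client, strategy_spec: StrategySpec) -> None:
    # Same logic as the deleted _update_policies method, run once up front.
    for spec_pid in strategy_spec.policy_ids:
        policy_id = f"{strategy_spec.id}/{spec_pid}"
        if policy_id not in client.policies:
            client.policies[policy_id] = strategy_spec.gen_policy(device="cpu")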