Commit 8ddc763: pb rollout worker test passed
Ming Zhou committed Nov 10, 2023
1 parent 47cd917
Showing 27 changed files with 525 additions and 572 deletions.
1 change: 0 additions & 1 deletion examples/run_gym.py
@@ -26,7 +26,6 @@
import os
import time

from malib.runner import run
from malib.learner import IndependentAgent
from malib.scenarios.marl_scenario import MARLScenario
from malib.rl.dqn import DQNPolicy, DQNTrainer, DEFAULT_CONFIG
6 changes: 6 additions & 0 deletions malib/backend/dataset_server/data_loader.py
@@ -2,6 +2,7 @@

import threading
import grpc
import socket

from concurrent import futures
from torch.utils.data import DataLoader, Dataset
@@ -33,6 +34,7 @@ def __init__(
max_message_length,
find_free_port(),
)
self.host = socket.gethostbyname(socket.gethostname())

def _start_servicer(
self, max_workers: int, max_message_length: int, grpc_port: int
@@ -52,6 +54,10 @@ def _start_servicer(

return server

@property
def entrypoint(self) -> str:
return f"{self.host}:{self.server._state.port}"

def __len__(self):
return self.feature_handler_caller.block_size
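
The new `entrypoint` property advertises the servicer's reachable address to other components. A minimal sketch of the underlying idiom, with a hypothetical fixed port standing in for the `find_free_port()` result:

import socket

# Resolve this machine's IP from its hostname, mirroring the line added above.
host = socket.gethostbyname(socket.gethostname())
grpc_port = 50051  # hypothetical; the real code binds a freshly found free port
entrypoint = f"{host}:{grpc_port}"
# A learner can then dial this address to stream training data.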

35 changes: 15 additions & 20 deletions malib/backend/league.py
@@ -8,39 +8,34 @@

from malib.utils.logging import Logger
from malib.common.manager import Manager
from malib.common.task import Task, RolloutTask, OptimizationTask


class League:
def __init__(
self,
training_manager: Manager,
learner_manager: Manager,
rollout_manager: Manager,
inference_manager: Manager,
) -> None:
self.training_manager = training_manager
self.learner_manager = learner_manager
self.rollout_manager = rollout_manager
self.inference_manager = inference_manager
self.flight_servers = []
self.rw_lock = rwlock.RWLockFair()
self.event = threading.Event()
self.thread_pool = futures.ThreadPoolExecutor()

def register_flight_server(self, flight_server_address: str):
raise NotImplementedError

def list_flight_servers(self) -> List[str]:
raise NotImplementedError

def _flight_server_check(self):
while not self.event.is_set():
with self.rw_lock.gen_rlock():
for flight_server in self.flight_servers:
if not ray.util.check_connection(flight_server):
self.flight_servers.remove(flight_server)
self.event.wait(10)

def list_learners(self):
return self.training_manager.workers()
return self.learner_manager.workers()

def submit(self, task_desc: Task, wait: bool = False):
if isinstance(task_desc, RolloutTask):
res = self.rollout_manager.submit(task_desc, wait)
elif isinstance(task_desc, OptimizationTask):
res = self.learner_manager.submit(task_desc, wait)
else:
raise ValueError(f"Unexpected task type: {type(task_desc)}")
return res

def list_rollout_workers(self):
return self.rollout_manager.workers()
@@ -62,7 +57,7 @@ def get_results(self) -> Dict[str, Dict[str, Any]]:
while True:
for result in self.rollout_manager.get_results():
rollout_results.append(result)
for result in self.training_manager.get_results():
for result in self.learner_manager.get_results():
training_results.append(result)
except KeyboardInterrupt:
Logger.info("Keyboard interruption was detected, recalling resources ...")
@@ -76,5 +71,5 @@ def get_results(self) -> Dict[str, Dict[str, Any]]:
def terminate(self):
self.event.set()
self.thread_pool.shutdown()
self.training_manager.terminate()
self.learner_manager.terminate()
self.rollout_manager.terminate()
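
With both task classes now deriving from `Task`, `submit` dispatches on the concrete task type. A hedged sketch of the intended call pattern (the manager instances and field values are assumptions, not shown in this commit):

from malib.backend.league import League
from malib.common.task import OptimizationTask, RolloutTask

# Managers are placeholders; the real instances come from MALib's runtime setup.
league = League(learner_manager, rollout_manager, inference_manager)

rollout = RolloutTask(stopping_conditions={"max_iteration": 10})
league.submit(rollout)  # the isinstance check routes this to the rollout manager

optim = OptimizationTask(stop_conditions={"max_iteration": 1000})
league.submit(optim, wait=True)  # routed to the learner manager, blocking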
15 changes: 12 additions & 3 deletions malib/common/manager.py
@@ -20,8 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import traceback
from typing import List, Generator, Any
from typing import List, Any
from abc import abstractmethod, ABC

import ray
@@ -52,8 +51,18 @@ def namespace(self) -> str:
def workers(self) -> List[RemoteInterface]:
raise NotImplementedError

@abstractmethod
def retrive_results(self):
raise NotImplementedError
"""Retrieve execution results."""

@abstractmethod
def submit(self, task: Any, wait: bool = False) -> Any:
"""Submit task to workers.
Args:
task (Any): Task description.
wait (bool, optional): Whether to block until the task completes. Defaults to False.
"""

def wait(self) -> List[Any]:
"""Wait workers to be terminated, and retrieve the executed results.
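
A concrete manager implements the abstract hooks above. A minimal hypothetical sketch, assuming `workers`, `retrive_results`, and `submit` are the required overrides (only part of the interface is visible in this hunk):

from malib.common.manager import Manager

class InlineManager(Manager):
    """Toy manager that handles tasks inline, for illustration only."""

    def workers(self):
        return []

    def retrive_results(self):  # spelling follows the interface as defined
        return []

    def submit(self, task, wait=False):
        # A real manager would dispatch `task` to remote workers here.
        return task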
9 changes: 6 additions & 3 deletions malib/common/task.py
@@ -12,9 +12,12 @@ class TaskType(IntEnum):
OPTIMIZATION = 2


class Task:
pass


@dataclass
class RolloutTask:
task_type: int
class RolloutTask(Task):
strategy_specs: Dict[str, Any] = field(default_factory=dict)
stopping_conditions: Dict[str, Any] = field(default_factory=dict)
data_entrypoint_mapping: Dict[str, Any] = field(default_factory=dict)
@@ -32,7 +35,7 @@ def from_raw(


@dataclass
class OptimizationTask:
class OptimizationTask(Task):
stop_conditions: Dict[str, Any]
"""stopping conditions for optimization task, e.g., max iteration, max time, etc."""

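
Note that `field(default_factory=dict)` must receive the `dict` constructor itself; a `dict()` instance is not callable and raises `TypeError` when the dataclass is instantiated. A minimal self-contained check of the pattern:

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class Example:
    options: Dict[str, Any] = field(default_factory=dict)

a, b = Example(), Example()
a.options["k"] = 1
assert b.options == {}  # each instance gets its own fresh dict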
122 changes: 39 additions & 83 deletions malib/learner/learner.py
@@ -33,7 +33,7 @@
import torch
import ray

from ray.util.queue import Queue
from gym import spaces
from torch.utils import tensorboard
from torch.utils.data import DataLoader

@@ -42,10 +42,11 @@
from malib.utils.tianshou_batch import Batch
from malib.utils.monitor import write_to_tensorboard
from malib.remote.interface import RemoteInterface
from malib.rl.common.trainer import Trainer
from malib.common.task import OptimizationTask
from malib.common.strategy_spec import StrategySpec
from malib.backend.dataset_server.data_loader import DynamicDataset
from malib.rl.common.trainer import Trainer
from malib.rl.common.policy import Policy


class Learner(RemoteInterface, ABC):
@@ -57,8 +58,8 @@ def __init__(
experiment_tag: str,
runtime_id: str,
log_dir: str,
observation_space: gym.Space,
action_space: gym.Space,
observation_space: spaces.Space,
action_space: spaces.Space,
algorithms: Dict[str, Tuple[Type, Type, Dict, Dict]],
agent_mapping_func: Callable[[AgentID], str],
governed_agents: Tuple[AgentID],
@@ -109,24 +110,32 @@ def __init__(
self._strategy_spec = strategy_spec
self._agent_mapping_func = agent_mapping_func
self._custom_config = custom_config
self._policy = strategy_spec.gen_policy(device=device)

self._summary_writer = tensorboard.SummaryWriter(log_dir=log_dir)
self._trainer_config = trainer_config

# load policy for trainer
self._trainer: Trainer = algorithms["default"][1](trainer_config, self._policy)
self._total_step = 0
self._total_epoch = 0
self._trainer: Trainer = algorithms["default"][1](trainer_config)
self._policies = {}

dataset = dataset or self.create_dataset()
self._data_loader = DataLoader(dataset, batch_size=trainer_config["batch_size"])
self._active_tups = deque()

self._verbose = verbose

@property
def verbose(self) -> bool:
return self._verbose

@property
def strategy_spec(self) -> StrategySpec:
return self._strategy_spec

@property
def policy(self) -> Policy:
return self._policy

@property
def data_loader(self) -> DataLoader:
return self._data_loader
@@ -151,9 +160,28 @@ def device(self) -> Union[str, torch.DeviceObjType]:

return self._device

@property
def trainer(self) -> Trainer:
return self._trainer

def get_data_entrypoint(self) -> str:
return self.data_loader.dataset.entrypoint

def get_strategy_spec(self) -> StrategySpec:
return self._strategy_spec

def get_state_dict(self) -> Dict[str, torch.Tensor]:
return self.policy.state_dict(device="cpu")

@abstractmethod
def create_dataset(self) -> DynamicDataset:
raise NotImplementedError
"""Create dataset
Returns:
DynamicDataset: Must be an subinstance of DynamicDataset
"""

@abstractmethod
def add_policies(self, n: int) -> StrategySpec:
"""Construct `n` new policies and return the latest strategy spec.
@@ -164,61 +192,6 @@ def add_policies(self, n: int) -> StrategySpec:
StrategySpec: The latest strategy spec instance.
"""

for _ in range(n):
spec_pid = f"policy-{len(self._strategy_spec)}"
self._strategy_spec.register_policy_id(policy_id=spec_pid)
policy = self._strategy_spec.gen_policy()
policy_id = f"{self._strategy_spec.id}/{spec_pid}"
self._policies[policy_id] = policy
# active tups store the policy info tuples for training;
# the data request relies on them.
self._active_tups.append((self._strategy_spec.id, spec_pid))
self._trainer.reset(policy_instance=policy)

ray.get(self._parameter_server.create_table.remote(self._strategy_spec))
ray.get(
self._parameter_server.set_weights.remote(
spec_id=self._strategy_spec.id,
spec_policy_id=spec_pid,
state_dict=policy.state_dict(),
)
)

return self._strategy_spec

def push(self):
"""Push local weights to remote server"""

pending_tasks = []
for spec_pid in self._strategy_spec.policy_ids:
pid = f"{self._strategy_spec.id}/{spec_pid}"
task = self._parameter_server.set_weights.remote(
spec_id=self._strategy_spec.id,
spec_policy_id=spec_pid,
state_dict=self._policies[pid].state_dict(),
)
pending_tasks.append(task)
while len(pending_tasks) > 0:
dones, pending_tasks = ray.wait(pending_tasks)

def pull(self):
"""Pull remote weights to update local version."""

pending_tasks = []

for spec_pid in self._strategy_spec.policy_ids:
pid = f"{self._strategy_spec.id}/{spec_pid}"
task = self._parameter_server.get_weights.remote(
spec_id=self._strategy_spec.id, spec_policy_id=spec_pid
)
pending_tasks.append(task)

while len(pending_tasks) > 0:
dones, pending_tasks = ray.wait(pending_tasks)
for done in ray.get(dones):
pid = "{}/{}".format(done["spec_id"], done["spec_policy_id"])
self._policies[pid].load_state_dict(done["weights"])

@abstractmethod
def multiagent_post_process(
self,
@@ -246,21 +219,8 @@ def get_interface_state(self) -> Dict[str, Any]:
"total_step": self._total_step,
"total_epoch": self._total_epoch,
"policy_num": len(self._strategy_spec),
"active_tups": list(self._active_tups),
}

def sync_remote_parameters(self):
"""Push latest network parameters of active policies to remote parameter server."""

top_active_tup = self._active_tups[0]
ray.get(
self._parameter_server.set_weights.remote(
spec_id=top_active_tup[0],
spec_policy_id=top_active_tup[1],
state_dict=self._trainer.policy.state_dict(device="cpu"),
)
)

def train(self, task: OptimizationTask) -> Dict[str, Any]:
"""Executes a optimization task and returns the final interface state.
@@ -272,25 +232,21 @@
Dict[str, Any]: A dict that describes the final state.
"""

# XXX(ming): why do we need to reset the state here? I think it is
# unnecessary, as an optimization task should be independent of other tasks.

self.set_running(True)

try:
while self.is_running():
for data in self.data_loader:
batch_info = self.multiagent_post_process(data)
step_info_list = self._trainer(batch_info)
step_info_list = self.trainer(batch_info)
for step_info in step_info_list:
self._total_step += 1
write_to_tensorboard(
self._summary_writer,
info=step_info,
global_step=self._total_step,
prefix=f"Training/{self._runtime_id}",
prefix=f"Learner/{self._runtime_id}",
)
self.sync_remote_parameters()
self._total_epoch += 1
except Exception as e:
Logger.warning(
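
The loop iterates the learner's `DataLoader` until `is_running()` flips false, logging per-step stats under the `Learner/` tensorboard prefix. A hedged driver sketch follows; the `learner` handle and the `stop_conditions` keys are assumptions, and any OptimizationTask fields beyond `stop_conditions` are truncated in this diff:

from malib.common.task import OptimizationTask

# `learner` is an instance of a concrete Learner subclass (construction not shown).
task = OptimizationTask(stop_conditions={"max_iteration": 1000})
state = learner.train(task)
# The returned dict mirrors get_interface_state() above.
print(state["total_step"], state["total_epoch"], state["policy_num"])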