diff --git a/malib/backend/dataset_server/feature.py b/malib/backend/dataset_server/feature.py
index a759258..27269ce 100644
--- a/malib/backend/dataset_server/feature.py
+++ b/malib/backend/dataset_server/feature.py
@@ -35,7 +35,11 @@ def __init__(
         self.rw_lock = rwlock.RWLockFair()
         self._device = device
         self._spaces = spaces
-        self._block_size = block_size or list(np_memory.values())[0].shape[0]
+        self._block_size = (
+            block_size
+            if block_size is not None
+            else list(np_memory.values())[0].shape[0]
+        )
         self._available_size = 0
         self._flag = 0
         self._shared_memory = {
@@ -59,9 +63,22 @@ def get(self, index: int):
 
     def write(self, data: Dict[str, Any], start: int, end: int):
         for k, v in data.items():
-            self._shared_memory[k][start:end] = torch.as_tensor(v).to(
+            # FIXME(ming): should check the size of v
+            tensor = torch.as_tensor(v).to(
                 self._device, dtype=self._shared_memory[k].dtype
             )
+            split = 0
+            if end > self.block_size:
+                # we now should split the data
+                split = self.block_size - start
+                self._shared_memory[k][start:] = tensor[:split]
+                _start = 0
+                _end = tensor.shape[0] - split
+            else:
+                _start = start
+                _end = end
+
+            self._shared_memory[k][_start:_end] = tensor[split:]
 
     def generate_timestep(self) -> Dict[str, np.ndarray]:
         return {k: space.sample() for k, space in self.spaces.items()}
diff --git a/malib/learner/learner.py b/malib/learner/learner.py
index 23f8cc0..438f5df 100644
--- a/malib/learner/learner.py
+++ b/malib/learner/learner.py
@@ -50,6 +50,9 @@
 from malib.rl.config import Algorithm
 
 
+MAX_MESSAGE_LENGTH = 7309898
+
+
 class Learner(RemoteInterface, ABC):
     """Base class of agent interface, for training"""
 
@@ -121,7 +124,7 @@ def __init__(
         if dataset is None:
             dataset = DynamicDataset(
                 grpc_thread_num_workers=2,
-                max_message_length=1024,
+                max_message_length=MAX_MESSAGE_LENGTH,
                 feature_handler=feature_handler_gen(device),
             )
         else:
diff --git a/malib/learner/manager.py b/malib/learner/manager.py
index 6738c3e..8ed48f0 100644
--- a/malib/learner/manager.py
+++ b/malib/learner/manager.py
@@ -105,10 +105,6 @@ def __init__(
         learner_cls = learner_cls.as_remote(**resource_config)
         learners: Dict[str, ray.ObjectRef] = {}
 
-        # assert (
-        #     "training" in stopping_conditions
-        # ), f"Stopping conditions should contains `training` stoppong conditions: {stopping_conditions}"
-
         ready_check = []
 
         for rid, agents in group_info["agent_groups"].items():
@@ -135,6 +131,7 @@ def __init__(
         while len(ready_check):
             _, ready_check = ray.wait(ready_check, num_returns=1, timeout=1)
 
+        Logger.info("All Learners are ready for accepting new tasks.")
        data_entrypoints = ray.get(
            [x.get_data_entrypoint.remote() for x in learners.values()]
        )
diff --git a/malib/models/model_client.py b/malib/models/model_client.py
index fefb524..02e39ae 100644
--- a/malib/models/model_client.py
+++ b/malib/models/model_client.py
@@ -34,7 +34,6 @@ def __init__(
         """
 
         namespace, name = entry_point.split(":")
-
         self.client = ray.get_actor(name=name, namespace=namespace)
         self.thread_pool = futures.ThreadPoolExecutor(max_workers=10)
 
diff --git a/malib/rollout/inference/client.py b/malib/rollout/inference/client.py
index 2c41e2a..93a6ab5 100644
--- a/malib/rollout/inference/client.py
+++ b/malib/rollout/inference/client.py
@@ -126,7 +126,7 @@ def compute_action(
 
         with torch.inference_mode():
             obs = self.fixed_policy.preprocessor.transform(raw_obs)
-            obs = torch.from_numpy(obs).float()
+            obs = torch.tensor(obs).float()
             # FIXME(ming): act mask and hidden state is set to None,
             # not feasible for cases which require them
             policy_return = policy.compute_action(
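Note on the FeatureHandler.write change above: the new branch treats the shared-memory block as a ring buffer, so a write whose end index runs past block_size wraps the remainder back to the head. Below is a minimal standalone sketch of that wrap-around logic on a plain NumPy array; block_size, start, end and split mirror the names in the diff, while ring_write and the sample values are illustrative only.

```python
# Illustrative sketch only (not part of the patch): the wrap-around write that
# the new FeatureHandler.write branch performs, reproduced on a NumPy buffer.
import numpy as np


def ring_write(buffer: np.ndarray, data: np.ndarray, start: int, end: int) -> None:
    """Write `data` into buffer[start:end], wrapping past the end of the buffer."""
    block_size = buffer.shape[0]
    assert data.shape[0] == end - start, "data must match the requested slice"

    if end > block_size:
        # The slice crosses the buffer boundary: fill the tail first,
        # then wrap the remainder back to the head.
        split = block_size - start
        buffer[start:] = data[:split]
        buffer[: data.shape[0] - split] = data[split:]
    else:
        buffer[start:end] = data


buf = np.zeros(8, dtype=np.float32)
ring_write(buf, np.arange(5, dtype=np.float32), start=6, end=11)
print(buf)  # [2. 3. 4. 0. 0. 0. 0. 1.]
```

The patch keeps a single trailing assignment with split initialized to 0, which is equivalent: when no wrap occurs, tensor[split:] is the whole tensor and only the else branch's slice is written.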
diff --git a/malib/rollout/inference/env_runner.py b/malib/rollout/inference/env_runner.py
index 1048e78..230f427 100644
--- a/malib/rollout/inference/env_runner.py
+++ b/malib/rollout/inference/env_runner.py
@@ -135,7 +135,7 @@ def __init__(
         agent_groups: Dict[str, Set] = None,
         inference_entry_points: Dict[str, str] = None,
     ) -> None:
-        super().__init__()
+        super(RemoteInterface, self).__init__()
 
         self._use_subproc_env = use_subproc_env
         self._max_env_num = max_env_num
@@ -265,8 +265,12 @@ def run(
         # FIXME(ming): send data to remote dataset
         data = agent_manager.merge_episodes()
         data_entrypoints = data_entrypoints or {}
-        for entrypoint in data_entrypoints.values():
-            send_data(data, entrypoint=entrypoint)
+        for k, entrypoint in data_entrypoints.items():
+            # FIXME(ming): a bug, data: list of agent episode
+            agent_episode = data[0]
+            # requires agent group for identification
+            random_data = list(agent_episode.values())[0]
+            send_data(random_data, entrypoint=entrypoint)
 
         stats = {"total_timesteps": total_timestep, **timer.todict()}
         return stats
diff --git a/malib/rollout/inference/manager.py b/malib/rollout/inference/manager.py
index 0bbd3a8..21e07e5 100644
--- a/malib/rollout/inference/manager.py
+++ b/malib/rollout/inference/manager.py
@@ -4,7 +4,7 @@
 
 from malib.common.manager import Manager
 from malib.rl.config import Algorithm
-from malib.scenarios import Scenario
+from malib.utils.logging import Logger
 from malib.rollout.inference.client import InferenceClient
 
 
@@ -12,10 +12,10 @@ class InferenceManager(Manager):
     def __init__(
         self,
         group_info: Dict[str, Set],
-        ray_actor_namespace: str,
         model_entry_point: Dict[str, str],
         algorithm: Algorithm,
         verbose: bool = False,
+        ray_actor_namespace: str = "inference",
     ):
         super().__init__(verbose, namespace=ray_actor_namespace)
 
@@ -26,10 +26,11 @@ def __init__(
         self._infer_clients = {}
         self._inference_entry_points = {}
 
-        # FIXME(Ming): for debug only
-        model_entry_point = model_entry_point or {
-            rid: None for rid in agent_groups.keys()
-        }
+        model_entry_point = (
+            model_entry_point
+            if model_entry_point is not None
+            else {rid: None for rid in agent_groups.keys()}
+        )
 
         infer_client_ready_check = []
         for rid, _ in agent_groups.items():
@@ -54,6 +55,8 @@ def __init__(
                 infer_client_ready_check, num_returns=1, timeout=1
             )
 
+        Logger.info("All inference clients are ready for serving")
+
     def get_inference_client(self, runtime_id: str) -> InferenceClient:
         return self.inference_clients[runtime_id]
 
diff --git a/malib/rollout/rolloutworker.py b/malib/rollout/rolloutworker.py
index a7b93ad..d06bec1 100644
--- a/malib/rollout/rolloutworker.py
+++ b/malib/rollout/rolloutworker.py
@@ -170,6 +170,11 @@ def create_env_runner(
             agent_groups=self.agent_groups,
             inference_entry_points=rollout_config.inference_entry_points,
         )
+        ready_check = [env_runner.ready.remote()]
+
+        # wait for it be ready
+        while len(ready_check):
+            _, ready_check = ray.wait(ready_check, num_returns=1, timeout=1)
 
         return env_runner
 
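The same readiness check now appears in three places (LearnerManager, InferenceManager, and RolloutWorker.create_env_runner). Below is a self-contained sketch of the pattern; the Worker actor and its ready method are hypothetical stand-ins, and only the ray.wait polling loop mirrors the patch.

```python
# Illustrative sketch only: the "block until every actor answers ready()" loop
# that the patch adds in several places. The Worker actor here is hypothetical.
import ray


@ray.remote
class Worker:
    def ready(self) -> bool:
        return True


if __name__ == "__main__":
    ray.init()
    workers = [Worker.remote() for _ in range(4)]

    ready_check = [w.ready.remote() for w in workers]
    while len(ready_check):
        # ray.wait returns (finished, pending); keep polling the pending refs
        # until every ready() call has resolved.
        _, ready_check = ray.wait(ready_check, num_returns=1, timeout=1)

    print("all workers are ready")
    ray.shutdown()
```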
diff --git a/tests/rollout/test_pb_rollout_worker.py b/tests/rollout/test_pb_rollout_worker.py
index 2936a26..9bbd45f 100644
--- a/tests/rollout/test_pb_rollout_worker.py
+++ b/tests/rollout/test_pb_rollout_worker.py
@@ -103,13 +103,13 @@ def f(device):
             Episode.DONE: spaces.Discrete(1),
             Episode.CUR_OBS: env_desc["observation_spaces"][agent_id],
             Episode.ACTION: env_desc["action_spaces"][agent_id],
-            Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32),
+            Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(), dtype=np.float32),
             Episode.NEXT_OBS: env_desc["observation_spaces"][agent_id],
         }
         np_memory = {
-            k: np.zeros((100,) + v.shape, dtype=v.dtype) for k, v in _spaces.items()
+            k: np.zeros((1000,) + v.shape, dtype=v.dtype) for k, v in _spaces.items()
         }
-        return FakeFeatureHandler(_spaces, np_memory, device)
+        return FakeFeatureHandler(_spaces, np_memory, device=device)
 
     return f
 
@@ -165,7 +165,7 @@ class TestRolloutWorker:
         # stats = worker.rollout(task)
 
     def test_rollout_with_data_entrypoint(self, n_player: int):
-        with ray.init(local_mode=True):
+        with ray.init():
             env_desc, algorithm, rollout_config, group_info = gen_common_requirements(
                 n_player
             )
@@ -175,8 +175,6 @@ def test_rollout_with_data_entrypoint(self, n_player: int):
             agents = env_desc["possible_agents"]
             log_dir = "./logs"
 
-            inference_namespace = "test_pb_rolloutworker"
-
             learner_manager = LearnerManager(
                 stopping_conditions={"max_iteration": 10},
                 algorithm=algorithm,
@@ -193,7 +191,6 @@ def test_rollout_with_data_entrypoint(self, n_player: int):
 
             infer_manager = InferenceManager(
                 group_info=group_info,
-                ray_actor_namespace=inference_namespace,
                 algorithm=algorithm,
                 model_entry_point=learner_manager.learner_entrypoints,
             )
@@ -219,6 +216,8 @@ def test_rollout_with_data_entrypoint(self, n_player: int):
                 log_dir=log_dir,
             )
 
+            print("PBRollout worker is ready to work!!!")
+
             task = RolloutTask(
                 strategy_specs=strategy_spaces,
                 stopping_conditions={"max_iteration": 10},
@@ -225,3 +224,4 @@ def test_rollout_with_data_entrypoint(self, n_player: int):
             )
 
             stats = worker.rollout(task)
+            ray.shutdown()
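For context on the shape=() change to Episode.REWARD above: each key in np_memory becomes a (capacity, *feature_shape) block, so a scalar reward space yields a flat (1000,) column instead of (1000, 1). A small sketch follows; the capacity of 1000 and the per-key layout follow the test, while the concrete keys and spaces are made up and gym is assumed as the spaces implementation.

```python
# Illustrative sketch only: how a per-key memory block is laid out from a dict
# of spaces, as the test fixture does. Keys and spaces here are hypothetical.
import numpy as np
from gym import spaces

capacity = 1000
_spaces = {
    "reward": spaces.Box(-np.inf, np.inf, shape=(), dtype=np.float32),
    "done": spaces.Discrete(1),
    "obs": spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32),
}

np_memory = {
    k: np.zeros((capacity,) + v.shape, dtype=v.dtype) for k, v in _spaces.items()
}

for k, block in np_memory.items():
    print(k, block.shape, block.dtype)
# reward (1000,) float32
# done (1000,) int64
# obs (1000, 4) float32
```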