
Commit 3e69230: pb test passed
Ming Zhou committed Nov 27, 2023 (1 parent: 40c5074)
Showing 9 changed files with 53 additions and 25 deletions.
21 changes: 19 additions & 2 deletions malib/backend/dataset_server/feature.py
@@ -35,7 +35,11 @@ def __init__(
         self.rw_lock = rwlock.RWLockFair()
         self._device = device
         self._spaces = spaces
-        self._block_size = block_size or list(np_memory.values())[0].shape[0]
+        self._block_size = (
+            block_size
+            if block_size is not None
+            else list(np_memory.values())[0].shape[0]
+        )
         self._available_size = 0
         self._flag = 0
         self._shared_memory = {
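Note on the hunk above: the old `or` fallback treats every falsy `block_size` as missing, so an explicit 0 would silently be replaced by the inferred size; the `is not None` form keeps it. A minimal illustration (`fallback` stands in for `list(np_memory.values())[0].shape[0]`):

    block_size = 0
    fallback = 100

    buggy = block_size or fallback                              # -> 100, the explicit 0 is lost
    fixed = block_size if block_size is not None else fallback  # -> 0, as intended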
@@ -59,9 +63,22 @@ def get(self, index: int):

     def write(self, data: Dict[str, Any], start: int, end: int):
         for k, v in data.items():
-            self._shared_memory[k][start:end] = torch.as_tensor(v).to(
+            # FIXME(ming): should check the size of v
+            tensor = torch.as_tensor(v).to(
                 self._device, dtype=self._shared_memory[k].dtype
             )
+            split = 0
+            if end > self.block_size:
+                # we should now split the data
+                split = self.block_size - start
+                self._shared_memory[k][start:] = tensor[:split]
+                _start = 0
+                _end = tensor.shape[0] - split
+            else:
+                _start = start
+                _end = end
+
+            self._shared_memory[k][_start:_end] = tensor[split:]

     def generate_timestep(self) -> Dict[str, np.ndarray]:
         return {k: space.sample() for k, space in self.spaces.items()}
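For reference, the new `write` implements a circular (ring-buffer) write: when the window [start, end) runs past the block boundary, the payload is split and the tail wraps to the front. A standalone single-array sketch, assuming `end - start == len(data)` and at most one wrap per call:

    import numpy as np

    def ring_write(buf: np.ndarray, data: np.ndarray, start: int) -> int:
        block = buf.shape[0]
        end = start + data.shape[0]
        if end > block:
            split = block - start
            buf[start:] = data[:split]         # head fills the tail of the block
            buf[: end - block] = data[split:]  # remainder wraps to the front
        else:
            buf[start:end] = data
        return end % block  # next write cursor

    buf = np.zeros(8)
    cursor = ring_write(buf, np.arange(5.0), start=6)  # fills buf[6:8] and buf[0:3]; cursor == 3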
5 changes: 4 additions & 1 deletion malib/learner/learner.py
@@ -50,6 +50,9 @@
 from malib.rl.config import Algorithm


+MAX_MESSAGE_LENGTH = 7309898
+
+
 class Learner(RemoteInterface, ABC):
     """Base class of agent interface, for training"""

@@ -121,7 +124,7 @@ def __init__(
         if dataset is None:
             dataset = DynamicDataset(
                 grpc_thread_num_workers=2,
-                max_message_length=1024,
+                max_message_length=MAX_MESSAGE_LENGTH,
                 feature_handler=feature_handler_gen(device),
             )
         else:
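Raising the cap from 1024 to roughly 7 MB makes room for whole episode batches. Presumably `DynamicDataset` forwards `max_message_length` into gRPC's message-size options; a sketch of the conventional wiring (the exact plumbing inside `DynamicDataset` is an assumption):

    import grpc
    from concurrent import futures

    MAX_MESSAGE_LENGTH = 7309898  # bytes

    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=2),
        options=[
            ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
            ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH),
        ],
    )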
5 changes: 1 addition & 4 deletions malib/learner/manager.py
@@ -105,10 +105,6 @@ def __init__(
         learner_cls = learner_cls.as_remote(**resource_config)
         learners: Dict[str, ray.ObjectRef] = {}

-        # assert (
-        #     "training" in stopping_conditions
-        # ), f"Stopping conditions should contains `training` stoppong conditions: {stopping_conditions}"
-
         ready_check = []

         for rid, agents in group_info["agent_groups"].items():
@@ -135,6 +131,7 @@ def __init__(
         while len(ready_check):
             _, ready_check = ray.wait(ready_check, num_returns=1, timeout=1)

+        Logger.info("All Learners are ready for accepting new tasks.")
         data_entrypoints = ray.get(
             [x.get_data_entrypoint.remote() for x in learners.values()]
         )
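The `ray.wait` loop above is a readiness barrier: it keeps polling until every learner has answered its no-op `ready` call. The same pattern is added to `rolloutworker.py` and `InferenceManager` in this commit. Distilled into a hypothetical helper:

    import ray

    def block_until_ready(actors, poll_timeout: float = 1.0) -> None:
        # each pass retires at most one ref; the timeout keeps the loop responsive
        pending = [actor.ready.remote() for actor in actors]
        while len(pending):
            _, pending = ray.wait(pending, num_returns=1, timeout=poll_timeout)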
1 change: 0 additions & 1 deletion malib/models/model_client.py
@@ -34,7 +34,6 @@ def __init__(
         """

         namespace, name = entry_point.split(":")
-
         self.client = ray.get_actor(name=name, namespace=namespace)
         self.thread_pool = futures.ThreadPoolExecutor(max_workers=10)

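For context, `entry_point` follows a `namespace:name` convention that resolves to a named Ray actor; a small usage sketch (the actor name here is hypothetical):

    import ray

    entry_point = "learner_ns:policy_server_0"
    namespace, name = entry_point.split(":")
    actor = ray.get_actor(name=name, namespace=namespace)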
2 changes: 1 addition & 1 deletion malib/rollout/inference/client.py
@@ -126,7 +126,7 @@ def compute_action(

         with torch.inference_mode():
             obs = self.fixed_policy.preprocessor.transform(raw_obs)
-            obs = torch.from_numpy(obs).float()
+            obs = torch.tensor(obs).float()
             # FIXME(ming): act mask and hidden state is set to None,
             # not feasible for cases which require them
             policy_return = policy.compute_action(
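The switch from `torch.from_numpy` to `torch.tensor` trades zero-copy for isolation: `from_numpy` returns a tensor that shares memory with the array, while `torch.tensor` always copies. A quick demonstration:

    import numpy as np
    import torch

    arr = np.zeros(3, dtype=np.float32)
    shared = torch.from_numpy(arr)  # view over arr's memory
    copied = torch.tensor(arr)      # independent copy

    arr[0] = 1.0
    print(shared[0].item(), copied[0].item())  # 1.0 0.0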
10 changes: 7 additions & 3 deletions malib/rollout/inference/env_runner.py
@@ -135,7 +135,7 @@ def __init__(
         agent_groups: Dict[str, Set] = None,
         inference_entry_points: Dict[str, str] = None,
     ) -> None:
-        super().__init__()
+        super(RemoteInterface, self).__init__()

         self._use_subproc_env = use_subproc_env
         self._max_env_num = max_env_num
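`super(RemoteInterface, self).__init__()` starts the MRO lookup after `RemoteInterface`, so `RemoteInterface.__init__` itself is skipped. A toy example with hypothetical classes:

    class Base:
        def __init__(self):
            print("Base.__init__")

    class Remote(Base):
        def __init__(self):
            print("Remote.__init__")

    class Runner(Remote):
        def __init__(self):
            # skips Remote.__init__ and dispatches straight to Base.__init__
            super(Remote, self).__init__()

    Runner()  # prints "Base.__init__"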
@@ -265,8 +265,12 @@ def run(
             # FIXME(ming): send data to remote dataset
             data = agent_manager.merge_episodes()
             data_entrypoints = data_entrypoints or {}
-            for entrypoint in data_entrypoints.values():
-                send_data(data, entrypoint=entrypoint)
+            for k, entrypoint in data_entrypoints.items():
+                # FIXME(ming): a bug, data: list of agent episode
+                agent_episode = data[0]
+                # requires agent group for identification
+                random_data = list(agent_episode.values())[0]
+                send_data(random_data, entrypoint=entrypoint)

         stats = {"total_timesteps": total_timestep, **timer.todict()}
         return stats
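As the FIXME notes, `data` is a list of per-agent episode dicts, and sending `data[0]`'s first value to every entrypoint is a stopgap. A speculative sketch of the intended per-group dispatch (keying entrypoints by runtime id against `agent_groups` is an assumption):

    for rid, entrypoint in data_entrypoints.items():
        for episode in data:  # each episode maps agent_id -> trajectory arrays
            for agent_id, agent_episode in episode.items():
                if agent_id in agent_groups.get(rid, set()):
                    send_data(agent_episode, entrypoint=entrypoint)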
15 changes: 9 additions & 6 deletions malib/rollout/inference/manager.py
@@ -4,18 +4,18 @@

 from malib.common.manager import Manager
 from malib.rl.config import Algorithm
-from malib.scenarios import Scenario
+from malib.utils.logging import Logger
 from malib.rollout.inference.client import InferenceClient


 class InferenceManager(Manager):
     def __init__(
         self,
         group_info: Dict[str, Set],
-        ray_actor_namespace: str,
         model_entry_point: Dict[str, str],
         algorithm: Algorithm,
         verbose: bool = False,
+        ray_actor_namespace: str = "inference",
     ):
         super().__init__(verbose, namespace=ray_actor_namespace)

@@ -26,10 +26,11 @@ def __init__(

         self._infer_clients = {}
         self._inference_entry_points = {}
-        # FIXME(Ming): for debug only
-        model_entry_point = model_entry_point or {
-            rid: None for rid in agent_groups.keys()
-        }
+        model_entry_point = (
+            model_entry_point
+            if model_entry_point is not None
+            else {rid: None for rid in agent_groups.keys()}
+        )

         infer_client_ready_check = []
         for rid, _ in agent_groups.items():
@@ -54,6 +55,8 @@ def __init__(
                 infer_client_ready_check, num_returns=1, timeout=1
             )

+        Logger.info("All inference clients are ready for serving")
+
     def get_inference_client(self, runtime_id: str) -> InferenceClient:
         return self.inference_clients[runtime_id]

5 changes: 5 additions & 0 deletions malib/rollout/rolloutworker.py
@@ -170,6 +170,11 @@ def create_env_runner(
             agent_groups=self.agent_groups,
             inference_entry_points=rollout_config.inference_entry_points,
         )
+        ready_check = [env_runner.ready.remote()]
+
+        # wait for it to be ready
+        while len(ready_check):
+            _, ready_check = ray.wait(ready_check, num_returns=1, timeout=1)

         return env_runner

14 changes: 7 additions & 7 deletions tests/rollout/test_pb_rollout_worker.py
@@ -103,13 +103,13 @@ def f(device):
             Episode.DONE: spaces.Discrete(1),
             Episode.CUR_OBS: env_desc["observation_spaces"][agent_id],
             Episode.ACTION: env_desc["action_spaces"][agent_id],
-            Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32),
+            Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(), dtype=np.float32),
             Episode.NEXT_OBS: env_desc["observation_spaces"][agent_id],
         }
         np_memory = {
-            k: np.zeros((100,) + v.shape, dtype=v.dtype) for k, v in _spaces.items()
+            k: np.zeros((1000,) + v.shape, dtype=v.dtype) for k, v in _spaces.items()
         }
-        return FakeFeatureHandler(_spaces, np_memory, device)
+        return FakeFeatureHandler(_spaces, np_memory, device=device)

     return f
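With `shape=()` the reward space samples scalars, so the preallocated block becomes 1-D instead of (N, 1), matching per-step scalar rewards; the buffer also grows from 100 to 1000 rows. For example (using `gym.spaces`; which spaces library the repo pins is an assumption):

    import numpy as np
    from gym import spaces

    reward_space = spaces.Box(-np.inf, np.inf, shape=(), dtype=np.float32)
    buf = np.zeros((1000,) + reward_space.shape, dtype=reward_space.dtype)
    print(reward_space.sample().shape, buf.shape)  # () (1000,)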

@@ -165,7 +165,7 @@ class TestRolloutWorker:
         # stats = worker.rollout(task)

     def test_rollout_with_data_entrypoint(self, n_player: int):
-        with ray.init(local_mode=True):
+        with ray.init():
             env_desc, algorithm, rollout_config, group_info = gen_common_requirements(
                 n_player
             )
@@ -175,8 +175,6 @@ def test_rollout_with_data_entrypoint(self, n_player: int):
             agents = env_desc["possible_agents"]
             log_dir = "./logs"

-            inference_namespace = "test_pb_rolloutworker"
-
             learner_manager = LearnerManager(
                 stopping_conditions={"max_iteration": 10},
                 algorithm=algorithm,
@@ -193,7 +191,6 @@ def test_rollout_with_data_entrypoint(self, n_player: int):

             infer_manager = InferenceManager(
                 group_info=group_info,
-                ray_actor_namespace=inference_namespace,
                 algorithm=algorithm,
                 model_entry_point=learner_manager.learner_entrypoints,
             )
@@ -219,10 +216,13 @@ def test_rollout_with_data_entrypoint(self, n_player: int):
                 log_dir=log_dir,
             )

+            print("PBRollout worker is ready to work!!!")
+
             task = RolloutTask(
                 strategy_specs=strategy_spaces,
                 stopping_conditions={"max_iteration": 10},
                 data_entrypoints=learner_manager.data_entrypoints,
             )

             stats = worker.rollout(task)
+        ray.shutdown()
