[RLlib] Issue 18418: SAC w/ dict space not working. #19101

Merged
10 changes: 8 additions & 2 deletions rllib/agents/sac/sac_tf_model.py
@@ -1,6 +1,7 @@
 import gym
 from gym.spaces import Box, Discrete
 import numpy as np
+import tree  # pip install dm_tree
 from typing import Dict, List, Optional
 
 from ray.rllib.models.catalog import ModelCatalog
@@ -267,13 +268,18 @@ def get_policy_output(self, model_out: TensorType) -> TensorType:
         Returns:
             TensorType: Distribution inputs for sampling actions.
         """
-        # Model outs may come as original Tuple observations, concat them
+        # Model outs may come as original Tuple/Dict observations, concat them
         # here if this is the case.
         if isinstance(self.action_model.obs_space, Box):
             if isinstance(model_out, (list, tuple)):
                 model_out = tf.concat(model_out, axis=-1)
             elif isinstance(model_out, dict):
-                model_out = tf.concat(list(model_out.values()), axis=-1)
+                model_out = tf.concat(
+                    [
+                        tf.expand_dims(val, 1) if len(val.shape) == 1 else val
+                        for val in tree.flatten(model_out.values())
+                    ],
+                    axis=-1)
         out, _ = self.action_model({"obs": model_out}, [], None)
         return out

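To make the TF change above concrete, here is a small standalone sketch (key names and shapes are made up; this is not RLlib code) of why rank-1 entries are expanded to (batch, 1) before concatenation: tf.concat along the last axis requires all inputs to have the same rank.

import tensorflow as tf
import tree  # pip install dm_tree

batch = 4
# Hypothetical per-key model outputs for a Dict observation:
model_out = {
    "a": tf.ones((batch, 3)),    # flattened Box(3,) features
    "b": tf.ones((batch,)),      # one scalar feature per sample (rank 1)
    "c": tf.ones((batch, 256)),  # e.g. an image embedding
}

# Mirror of the new concat logic: rank-1 tensors become (batch, 1) so that
# everything can be concatenated along the last axis.
flat = tf.concat(
    [
        tf.expand_dims(val, 1) if len(val.shape) == 1 else val
        for val in tree.flatten(model_out)
    ],
    axis=-1)
print(flat.shape)  # (4, 260)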
8 changes: 2 additions & 6 deletions rllib/agents/sac/sac_tf_policy.py
@@ -6,7 +6,6 @@
 from gym.spaces import Box, Discrete
 from functools import partial
 import logging
-import numpy as np
 from typing import Dict, List, Optional, Tuple, Type, Union
 
 import ray
@@ -53,9 +52,6 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space,
             target model will be created in this function and assigned to
             `policy.target_model`.
     """
-    # With separate state-preprocessor (before obs+action concat).
-    num_outputs = int(np.product(obs_space.shape))
-
     # Force-ignore any additionally provided hidden layer sizes.
     # Everything should be configured using SAC's "Q_model" and "policy_model"
     # settings.
@@ -70,7 +66,7 @@
     model = ModelCatalog.get_model_v2(
         obs_space=obs_space,
         action_space=action_space,
-        num_outputs=num_outputs,
+        num_outputs=None,
         model_config=config["model"],
         framework=config["framework"],
         default_model=default_model_cls,
@@ -90,7 +86,7 @@
     policy.target_model = ModelCatalog.get_model_v2(
         obs_space=obs_space,
         action_space=action_space,
-        num_outputs=num_outputs,
+        num_outputs=None,
         model_config=config["model"],
         framework=config["framework"],
         default_model=default_model_cls,
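A brief aside on the num_outputs=None change (my reading of the diff, not text from the PR): the removed int(np.product(obs_space.shape)) only makes sense for a flat Box space, while Dict and Tuple spaces have no single shape, so passing None appears to leave output sizing to the model/catalog rather than forcing one flattened size. A rough, standalone illustration with assumed sub-spaces:

from gym.spaces import Box, Dict
import numpy as np

simple_space = Box(-1.0, 1.0, shape=(3, ))
image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
dict_space = Dict({"a": simple_space, "c": image_space})

print(dict_space.shape)  # None -> the old int(np.product(...)) cannot work here.

# A flattened size would have to be summed over sub-spaces instead:
flat_size = sum(int(np.product(s.shape)) for s in dict_space.spaces.values())
print(flat_size)  # 3 + 84 * 84 * 3 = 21171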
8 changes: 7 additions & 1 deletion rllib/agents/sac/sac_torch_model.py
@@ -1,6 +1,7 @@
 import gym
 from gym.spaces import Box, Discrete
 import numpy as np
+import tree  # pip install dm_tree
 from typing import Dict, List, Optional
 
 from ray.rllib.models.catalog import ModelCatalog
@@ -281,7 +282,12 @@ def get_policy_output(self, model_out: TensorType) -> TensorType:
             if isinstance(model_out, (list, tuple)):
                 model_out = torch.cat(model_out, dim=-1)
             elif isinstance(model_out, dict):
-                model_out = torch.cat(list(model_out.values()), dim=-1)
+                model_out = torch.cat(
+                    [
+                        torch.unsqueeze(val, 1) if len(val.shape) == 1 else val
+                        for val in tree.flatten(model_out.values())
+                    ],
+                    dim=-1)
         out, _ = self.action_model({"obs": model_out}, [], None)
         return out

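The Torch version mirrors the TF change. One extra thing tree.flatten buys over the old list(model_out.values()) is support for nested observation structures; a tiny sketch with made-up shapes (not RLlib code):

import torch
import tree  # pip install dm_tree

batch = 4
# Hypothetical model outputs, including a nested dict and a rank-1 entry:
model_out = {
    "a": torch.ones(batch, 3),
    "b": {"inner": torch.ones(batch)},
}

flat = torch.cat(
    [
        torch.unsqueeze(val, 1) if len(val.shape) == 1 else val
        for val in tree.flatten(model_out)
    ],
    dim=-1)
print(flat.shape)  # torch.Size([4, 4])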
27 changes: 17 additions & 10 deletions rllib/agents/sac/tests/test_sac.py
@@ -1,5 +1,5 @@
 from gym import Env
-from gym.spaces import Box, Discrete, Tuple
+from gym.spaces import Box, Dict, Discrete, Tuple
 import numpy as np
 import re
 import unittest
@@ -23,6 +23,7 @@
 from ray.rllib.utils.test_utils import check, check_compute_single_action, \
     check_train_results, framework_iterator
 from ray.rllib.utils.torch_ops import convert_to_torch_tensor
+from ray import tune
 
 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()
@@ -90,22 +91,28 @@ def test_sac_compilation(self):
         image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
         simple_space = Box(-1.0, 1.0, shape=(3, ))
 
+        tune.register_env(
+            "random_dict_env", lambda _: RandomEnv({
+                "observation_space": Dict({
+                    "a": simple_space,
+                    "b": Discrete(2),
+                    "c": image_space, }),
+                "action_space": Box(-1.0, 1.0, shape=(1, )), }))
+        tune.register_env(
+            "random_tuple_env", lambda _: RandomEnv({
+                "observation_space": Tuple([
+                    simple_space, Discrete(2), image_space]),
+                "action_space": Box(-1.0, 1.0, shape=(1, )), }))
+
         for fw in framework_iterator(config):
             # Test for different env types (discrete w/ and w/o image, + cont).
             for env in [
-                    RandomEnv,
+                    "random_dict_env",
+                    "random_tuple_env",
                     "MsPacmanNoFrameskip-v4",
                     "CartPole-v0",
             ]:
                 print("Env={}".format(env))
-                if env == RandomEnv:
-                    config["env_config"] = {
-                        "observation_space": Tuple((simple_space, Discrete(2),
-                                                    image_space)),
-                        "action_space": Box(-1.0, 1.0, shape=(1, )),
-                    }
-                else:
-                    config["env_config"] = {}
                 # Test making the Q-model a custom one for CartPole, otherwise,
                 # use the default model.
                 config["Q_model"]["custom_model"] = "batch_norm{}".format(
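For completeness, a rough usage sketch of what the new test exercises (env name and config below are made up, and this assumes RLlib at roughly the version of this PR): register a Dict-observation RandomEnv and run SAC on it for one training iteration.

from gym.spaces import Box, Dict
from ray import tune
from ray.rllib.agents.sac import SACTrainer
from ray.rllib.examples.env.random_env import RandomEnv

# Hypothetical env name; any Dict observation space should work after this fix.
tune.register_env(
    "my_random_dict_env", lambda _: RandomEnv({
        "observation_space": Dict({
            "a": Box(-1.0, 1.0, shape=(3, )),
            "b": Box(-1.0, 1.0, shape=(5, )),
        }),
        "action_space": Box(-1.0, 1.0, shape=(1, )),
    }))

trainer = SACTrainer(env="my_random_dict_env", config={"framework": "tf"})
results = trainer.train()
print(results["timesteps_total"])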