
[RLlib] SAC crashes on the env having the dict observation space #18418

Closed

RuofanKong opened this issue Sep 8, 2021 · 11 comments · Fixed by #19101
Labels: bug (Something that is supposed to be working; but isn't), P1 (Issue that should be fixed within a few weeks), rllib (RLlib related issues)


RuofanKong commented Sep 8, 2021

Issue Description

I was using SAC to train an agent on an environment with a dictionary observation space, but training crashed.

System Info

  • Ray version: 1.3.0 (the issue also occurs in the latest version 1.6.0)
  • Python version: 3.8.5
  • Tensorflow version: 2.3.0
  • OS: MacOS Catalina 10.15.6

Repro Steps

Run the following code with the system setup above to reproduce the issue.

import csv
from typing import Any, Dict, Tuple

import gym
import numpy as np
import ray
import ray.rllib.agents.sac.sac as sac
from gym.envs.classic_control.pendulum import PendulumEnv
from gym.spaces import Box, Dict as GymDict
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env


class DictPendulumEnv(PendulumEnv):
    def __init__(self, g: float = 10.0):
        super().__init__(g=g)
        high = np.array([1.0, 1.0, self.max_speed], dtype=np.float32)
        self.observation_space = GymDict(
            {
                "cos_theta": Box(low=-high[0], high=high[0], shape=()),
                "sin_theta": Box(low=-high[1], high=high[1], shape=()),
                "theta_dot": Box(low=-high[2], high=high[2], shape=()),
            }
        )

    def reset(self) -> Dict[str, np.ndarray]:
        obs = super().reset()
        return {
            "cos_theta": obs[0],
            "sin_theta": obs[1],
            "theta_dot": obs[2],
        }

    def step(
        self, action: int
    ) -> Tuple[Dict[str, np.ndarray], float, bool, Dict[str, Any]]:
        obs, reward, done, info = super().step(action)
        return (
            {
                "cos_theta": obs[0],
                "sin_theta": obs[1],
                "theta_dot": obs[2],
            },
            reward,
            done,
            info,
        )


def env_creator(env_config: Dict[str, Any]) -> gym.Env:
    return DictPendulumEnv()


def train_dict_pendulum() -> None:
    ray.init()
    register_env("dict_pendulum", env_creator)
    config = sac.DEFAULT_CONFIG.copy()
    config["num_gpus"] = 0
    config["num_workers"] = 3
    trainer = sac.SACTrainer(config=config, env="dict_pendulum")

    writer = csv.writer(open("./dict_pendulum_sac.csv", "w"))
    for _ in range(30):
        # Perform one iteration of training the policy with SAC
        result = trainer.train()
        print(pretty_print(result))
        writer.writerow(
            [
                result["training_iteration"],
                result["timesteps_total"],
                result["episode_reward_mean"],
            ]
        )


train_dict_pendulum()

Running the code produces the following crash:

(pid=78844) 2021-09-07 18:48:51,904	ERROR worker.py:428 -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=78844, ip=192.168.4.29)
(pid=78844) TypeError: get_distribution_inputs_and_class() missing 1 required positional argument: 'obs_batch'
(pid=78844) 
(pid=78844) During handling of the above exception, another exception occurred:
(pid=78844) 
(pid=78844) ray::RolloutWorker.__init__() (pid=78844, ip=192.168.4.29)
(pid=78844)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 580, in __init__
(pid=78844)     self._build_policy_map(
(pid=78844)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1375, in _build_policy_map
(pid=78844)     self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
(pid=78844)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 126, in create_policy
(pid=78844)     self[policy_id] = class_(
(pid=78844)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/policy/tf_policy_template.py", line 237, in __init__
(pid=78844)     DynamicTFPolicy.__init__(
(pid=78844)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/policy/dynamic_tf_policy.py", line 294, in __init__
(pid=78844)     action_distribution_fn(
(pid=78844)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/agents/sac/sac_tf_policy.py", line 214, in get_distribution_inputs_and_class
(pid=78844)     distribution_inputs = model.get_policy_output(forward_out)
(pid=78844)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/agents/sac/sac_tf_model.py", line 277, in get_policy_output
(pid=78844)     out, _ = self.action_model({"obs": model_out}, [], None)
(pid=78844)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 219, in __call__
(pid=78844)     restored["obs"] = restore_original_dimensions(
(pid=78844)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 374, in restore_original_dimensions
(pid=78844)     return _unpack_obs(obs, original_space, tensorlib=tensorlib)
(pid=78844)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 408, in _unpack_obs
(pid=78844)     raise ValueError(
(pid=78844) ValueError: Expected flattened obs shape of [..., 3], got (?,)
(pid=78843) 2021-09-07 18:48:51,905	ERROR worker.py:428 -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=78843, ip=192.168.4.29)
(pid=78843) TypeError: get_distribution_inputs_and_class() missing 1 required positional argument: 'obs_batch'
(pid=78843) 
(pid=78843) During handling of the above exception, another exception occurred:
(pid=78843) 
(pid=78843) ray::RolloutWorker.__init__() (pid=78843, ip=192.168.4.29)
(pid=78843)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 580, in __init__
(pid=78843)     self._build_policy_map(
(pid=78843)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1375, in _build_policy_map
(pid=78843)     self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
(pid=78843)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 126, in create_policy
(pid=78843)     self[policy_id] = class_(
(pid=78843)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/policy/tf_policy_template.py", line 237, in __init__
(pid=78843)     DynamicTFPolicy.__init__(
(pid=78843)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/policy/dynamic_tf_policy.py", line 294, in __init__
(pid=78843)     action_distribution_fn(
(pid=78843)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/agents/sac/sac_tf_policy.py", line 214, in get_distribution_inputs_and_class
(pid=78843)     distribution_inputs = model.get_policy_output(forward_out)
(pid=78843)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/agents/sac/sac_tf_model.py", line 277, in get_policy_output
(pid=78843)     out, _ = self.action_model({"obs": model_out}, [], None)
(pid=78843)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 219, in __call__
(pid=78843)     restored["obs"] = restore_original_dimensions(
(pid=78843)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 374, in restore_original_dimensions
(pid=78843)     return _unpack_obs(obs, original_space, tensorlib=tensorlib)
(pid=78843)   File "/Users/rukon/ray130/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 408, in _unpack_obs
(pid=78843)     raise ValueError(
(pid=78843) ValueError: Expected flattened obs shape of [..., 3], got (?,)

NOTE: The "dictionary observation based" pendulum above is only a mock; the issue occurs on any environment with a dictionary observation space.

RuofanKong added the bug and triage labels Sep 8, 2021
gjoliver self-assigned this Sep 8, 2021
gjoliver (Member) commented Sep 8, 2021

The problem seems to be this line:
https://github.com/ray-project/ray/blob/master/rllib/agents/sac/sac_tf_model.py#L276
Instead of turning model_out into (?, 3), it converts it into (?,).

@RuofanKong, can I ask when was the last time you successfully used our SAC agent with a Dict obs space?
I'm asking because this code hasn't been updated in the last 7 months, so I want to understand the change/problem better.

RuofanKong (Author) commented:

@gjoliver the earliest version that I tried was 0.8.6, and it crashed with a different error:

Traceback (most recent call last):
  File "train_dict_pendulum.py", line 76, in <module>
    train_dict_pendulum()
  File "train_dict_pendulum.py", line 60, in train_dict_pendulum
    trainer = sac.SACTrainer(config=config, env="dict_pendulum")
  File "/Users/rukon/ray086/lib/python3.8/site-packages/ray/rllib/agents/trainer_template.py", line 90, in __init__
    Trainer.__init__(self, config, env, logger_creator)
  File "/Users/rukon/ray086/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 450, in __init__
    super().__init__(config, logger_creator)
  File "/Users/rukon/ray086/lib/python3.8/site-packages/ray/tune/trainable.py", line 175, in __init__
    self._setup(copy.deepcopy(self.config))
  File "/Users/rukon/ray086/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 623, in _setup
    self._init(self.config, self.env_creator)
  File "/Users/rukon/ray086/lib/python3.8/site-packages/ray/rllib/agents/trainer_template.py", line 113, in _init
    self.workers = self._make_workers(env_creator, self._policy,
  File "/Users/rukon/ray086/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 691, in _make_workers
    return WorkerSet(
  File "/Users/rukon/ray086/lib/python3.8/site-packages/ray/rllib/evaluation/worker_set.py", line 58, in __init__
    self._local_worker = self._make_worker(
  File "/Users/rukon/ray086/lib/python3.8/site-packages/ray/rllib/evaluation/worker_set.py", line 245, in _make_worker
    worker = cls(
  File "/Users/rukon/ray086/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 379, in __init__
    self._build_policy_map(policy_dict, policy_config)
  File "/Users/rukon/ray086/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 937, in _build_policy_map
    policy_map[name] = cls(obs_space, act_space, merged_conf)
  File "/Users/rukon/ray086/lib/python3.8/site-packages/ray/rllib/policy/tf_policy_template.py", line 128, in __init__
    DynamicTFPolicy.__init__(
  File "/Users/rukon/ray086/lib/python3.8/site-packages/ray/rllib/policy/dynamic_tf_policy.py", line 215, in __init__
    action_dist = dist_class(dist_inputs, self.model)
  File "/Users/rukon/ray086/lib/python3.8/site-packages/ray/rllib/models/tf/tf_action_dist.py", line 277, in __init__
    assert tfp is not None
AssertionError

So far I have not been able to get SAC working with a dict obs space.

gjoliver (Member) commented Sep 8, 2021

Ah, ok, just wanted to confirm that it's not a recent regression.

That error actually just means you need to pip install tensorflow-probability. I don't know why it's not in requirements.txt.

RuofanKong (Author) commented:

@gjoliver I actually do have tensorflow-probability installed but still get the error. And yes, not all dependencies are included in requirements.txt; I had to install several of them manually, one by one (e.g. opencv-python), and I'm not sure it's worth filing a feature request for that. But in any case, dictionary-based obs has never worked for me with SAC.

sven1977 (Contributor) commented Sep 8, 2021

Hmm, there is a SAC "compilation" test case in agents/sac/tests/test_sac.py that uses the RandomEnv with a Dict obs space. Maybe we can also work with that and reproduce?
Yes, the tfp error is a pain, and we need to improve that error message. It only means that you have a version mismatch between tf and tfp (even if both are installed properly): importing tfp raises an error, but RLlib suppresses it because we wrap the import with rllib/utils/framework.py::try_import_tfp().
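
A quick way to surface the suppressed import error yourself (just a sanity-check sketch, not an RLlib API; run it in the same Python environment that RLlib uses):

import tensorflow as tf
import tensorflow_probability as tfp  # fails here if the installed tf/tfp versions are incompatible

# If the import succeeds, print both versions so they can be checked against
# the tfp release notes for the installed TF version.
print("tf :", tf.__version__)
print("tfp:", tfp.__version__)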

gjoliver (Member) commented Sep 9, 2021

It turns out this has nothing to do with Dict spaces specifically; a Tuple space would cause the same problem.
Our nested space packing/unpacking logic doesn't handle 0-dimensional observations very well.
For example, the following DictPendulumEnv (with each scalar Box given shape (1,)) runs just fine:

class DictPendulumEnv(PendulumEnv):
    def __init__(self, g: float = 10.0):
        super().__init__(g=g)
        high = np.array([1.0, 1.0, self.max_speed], dtype=np.float32)
        self.observation_space = GymDict(
            {
                "cos_theta": Box(low=-high[0], high=high[0], shape=(1,)),
                "sin_theta": Box(low=-high[1], high=high[1], shape=(1,)),
                "theta_dot": Box(low=-high[2], high=high[2], shape=(1,)),
            }
        )

    def reset(self) -> Dict[str, np.ndarray]:
        obs = super().reset()
        return {
            "cos_theta": [obs[0]],
            "sin_theta": [obs[1]],
            "theta_dot": [obs[2]],
        }

    def step(
        self, action: int
    ) -> Tuple[Dict[str, np.ndarray], float, bool, Dict[str, Any]]:
        obs, reward, done, info = super().step(action)
        return (
            {
                "cos_theta": [obs[0]],
                "sin_theta": [obs[1]],
                "theta_dot": [obs[2]],
            },
            reward,
            done,
            info,
        )

I can probably make a fix for this.

gjoliver (Member) commented Sep 9, 2021

I had a quick chat with Sven. He is actually cleaning up our codebase to get rid of the preprocessor stuff, and this problem should go away with that bigger cleanup effort.
So instead of creating a one-off fix here, let's just wait for this PR to land:

#18468

In the meantime, Dict spaces should run with the workaround above.
Thanks for the report, Ruofan.
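
For anyone hitting this before the fix lands, here is a rough sketch of that workaround as a generic gym.ObservationWrapper (not part of RLlib; the class name here is made up for illustration). It expands every 0-d Box entry of a Dict observation space to shape (1,):

import gym
import numpy as np
from gym.spaces import Box, Dict as GymDict


class ExpandScalarDictObs(gym.ObservationWrapper):
    # Illustrative only: turns 0-d Box entries of a Dict obs space into (1,) Boxes.

    def __init__(self, env: gym.Env):
        super().__init__(env)
        assert isinstance(env.observation_space, GymDict)
        self._scalar_keys = {
            k for k, s in env.observation_space.spaces.items()
            if isinstance(s, Box) and s.shape == ()
        }
        self.observation_space = GymDict({
            k: (Box(low=np.reshape(s.low, (1,)),
                    high=np.reshape(s.high, (1,)),
                    dtype=s.dtype)
                if k in self._scalar_keys else s)
            for k, s in env.observation_space.spaces.items()
        })

    def observation(self, obs):
        # Wrap each scalar entry into a length-1 array; leave other entries untouched.
        return {
            k: (np.reshape(v, (1,)) if k in self._scalar_keys else v)
            for k, v in obs.items()
        }

With this, the unmodified DictPendulumEnv from the repro could be registered as lambda cfg: ExpandScalarDictObs(DictPendulumEnv()) instead of editing the env itself.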

gjoliver assigned sven1977 and unassigned gjoliver Sep 9, 2021
sven1977 added the P1 and rllib labels and removed the triage label Oct 5, 2021
sven1977 (Contributor) commented Oct 5, 2021

Hey @RuofanKong , sorry for the long delay. I think the root cause here is a different one, namely:

Dict(a=Box((?, )), b=Box((?, ))) gets flattened into Box((?, )); however, it should get flattened into Box((?, 2)).
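
A minimal sketch of the expected behavior, using gym's own flattening utilities rather than RLlib's preprocessor (the space below just mirrors the a/b example above):

from gym.spaces import Box, Dict as GymDict
from gym.spaces.utils import flatdim, flatten

space = GymDict({
    "a": Box(low=-1.0, high=1.0, shape=()),
    "b": Box(low=-1.0, high=1.0, shape=()),
})

print(flatdim(space))  # 2: each scalar Box contributes one dimension
print(flatten(space, space.sample()).shape)  # (2,), i.e. Box((?, 2)) once batched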

sven1977 (Contributor) commented Oct 5, 2021

I'm prepping a PR that fixes this.

sven1977 (Contributor) commented Oct 5, 2021

Here is a PR that will fix this problem. We will also gradually roll out support for the no-preprocessing option across all algos; due to their specific model constraints, SAC and DQN currently don't support this experimental flag.

#19101

sven1977 changed the title from "SAC crashes on the env having the dict observation space" to "[RLlib] SAC crashes on the env having the dict observation space" Oct 5, 2021
jjyyxx (Contributor) commented Mar 2, 2022

Not sure if this is the correct place to ask, but #19101 only partly fixes the problem. With an observation space like

Tuple([Box(...), Repeated(Box(...), max_len=4)])

and a custom model to handle the Repeated observation, the flatten logic does not work as expected.

My final hacky workaround is:

class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.obs_space = obs_space.original_space
        ...
  • If I pass obs_space.original_space to super().__init__ or set _disable_preprocessor_api to True, Policy._get_dummy_batch_from_view_requirements raises an exception about "Repeated.shape=None is not iteratable" (not limited to SACPolicy).

  • If I pass obs_space to super().__init__ and do not add the self.obs_space = obs_space.original_space line, the #19101 fix raises an exception when concatenating a Tensor and a RepeatedValue.

If it's not the correct place to ask, I could open another issue.
