Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] No Preprocessors; preparatory PR #1 #18367

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2245,9 +2245,9 @@ py_test(
name = "examples/custom_observation_filters",
main = "examples/custom_observation_filters.py",
tags = ["team:ml", "examples", "examples_C"],
size = "small",
size = "medium",
srcs = ["examples/custom_observation_filters.py"],
args = ["--stop-iters=2"]
args = ["--stop-iters=3"]
)

py_test(
Expand Down
2 changes: 2 additions & 0 deletions rllib/agents/dqn/dqn_tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,8 @@ def postprocess_nstep_and_prio(policy: Policy,
batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
batch[SampleBatch.DONES])

# Create dummy prio-weights (1.0) in case we don't have any in
# the batch.
if PRIO_WEIGHTS not in batch:
batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])

Expand Down
16 changes: 12 additions & 4 deletions rllib/agents/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,17 @@
# Tuple[value1, value2]: Clip at value1 and value2.
"clip_rewards": None,
# If True, RLlib will learn entirely inside a normalized action space
# (0.0 centered with small stddev; only affecting Box components) and
# only unsquash actions (and clip just in case) to the bounds of
# env's action space before sending actions back to the env.
# (0.0 centered with small stddev; only affecting Box components).
# We will unsquash actions (and clip, just in case) to the bounds of
# the env's action space before sending actions back to the env.
"normalize_actions": True,
# If True, RLlib will clip actions according to the env's bounds
# before sending them back to the env.
# TODO: (sven) This option should be obsoleted and always be False.
"clip_actions": False,
# Whether to use "rllib" or "deepmind" preprocessors by default.
# Set to None for using no preprocessor. In this case, the model will have
# to handle possibly complex observations from the environment.
"preprocessor_pref": "deepmind",

# === Debug Settings ===
Expand Down Expand Up @@ -1003,7 +1005,7 @@ def compute_single_action(

# Check the preprocessor and preprocess, if necessary.
pp = local_worker.preprocessors[policy_id]
if type(pp).__name__ != "NoPreprocessor":
if pp and type(pp).__name__ != "NoPreprocessor":
observation = pp.transform(observation)
filtered_observation = local_worker.filters[policy_id](
observation, update=False)
Expand Down Expand Up @@ -1467,6 +1469,12 @@ def _validate_config(config: PartialTrainerConfigDict,
config["input_evaluation"]))

# Check model config.
# If no preprocessing, propagate into model's config as well
# (so model will know, whether inputs are preprocessed or not).
if config["preprocessor_pref"] is None:
model_config["_no_preprocessor"] = True

# Prev_a/r settings.
prev_a_r = model_config.get("lstm_use_prev_action_reward",
DEPRECATED_VALUE)
if prev_a_r != DEPRECATED_VALUE:
Expand Down
2 changes: 1 addition & 1 deletion rllib/env/multi_agent_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def make_multi_agent(env_name_or_creator):
Returns:
Type[MultiAgentEnv]: New MultiAgentEnv class to be used as env.
The constructor takes a config dict with `num_agents` key
(default=1). The reset of the config dict will be passed on to the
(default=1). The rest of the config dict will be passed on to the
underlying single-agent env's constructor.

Examples:
Expand Down
13 changes: 8 additions & 5 deletions rllib/evaluation/rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __init__(
count_steps_by: str = "env_steps",
batch_mode: str = "truncate_episodes",
episode_horizon: int = None,
preprocessor_pref: str = "deepmind",
preprocessor_pref: Optional[str] = "deepmind",
sample_async: bool = False,
compress_observations: bool = False,
num_envs: int = 1,
Expand Down Expand Up @@ -256,8 +256,9 @@ class to use.
until the episode completes, and hence batches may contain
significant amounts of off-policy data.
episode_horizon (int): Whether to stop episodes at this horizon.
preprocessor_pref (str): Whether to prefer RLlib preprocessors
("rllib") or deepmind ("deepmind") when applicable.
preprocessor_pref (Optional[str]): Whether to use no preprocessor
(None), RLlib preprocessors ("rllib") or deepmind ("deepmind"),
when applicable.
sample_async (bool): Whether to compute samples asynchronously in
the background, which improves throughput but can cause samples
to be slightly off-policy.
Expand Down Expand Up @@ -414,7 +415,8 @@ def gen_rollouts():
self.count_steps_by: str = count_steps_by
self.batch_mode: str = batch_mode
self.compress_observations: bool = compress_observations
self.preprocessing_enabled: bool = True
self.preprocessing_enabled: bool = False \
if preprocessor_pref is None else True
self.observation_filter = observation_filter
self.last_batch: SampleBatchType = None
self.global_vars: dict = None
Expand Down Expand Up @@ -1358,7 +1360,8 @@ def _build_policy_map(
preprocessor = ModelCatalog.get_preprocessor_for_space(
obs_space, merged_conf.get("model"))
self.preprocessors[name] = preprocessor
obs_space = preprocessor.observation_space
if preprocessor is not None:
obs_space = preprocessor.observation_space
else:
self.preprocessors[name] = NoPreprocessor(obs_space)

Expand Down
20 changes: 14 additions & 6 deletions rllib/evaluation/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,10 +812,13 @@ def _process_observations(

policy_id: PolicyID = episode.policy_for(agent_id)

prep_obs: EnvObsType = _get_or_raise(worker.preprocessors,
policy_id).transform(raw_obs)
if log_once("prep_obs"):
logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
preprocessor = _get_or_raise(worker.preprocessors, policy_id)
prep_obs: EnvObsType = raw_obs
if preprocessor is not None:
prep_obs = preprocessor.transform(raw_obs)
if log_once("prep_obs"):
logger.info("Preprocessed obs: {}".format(
summarize(prep_obs)))
filtered_obs: EnvObsType = _get_or_raise(worker.filters,
policy_id)(prep_obs)
if log_once("filtered_obs"):
Expand Down Expand Up @@ -955,10 +958,15 @@ def _process_observations(
# types: AgentID, EnvObsType
for agent_id, raw_obs in resetted_obs.items():
policy_id: PolicyID = new_episode.policy_for(agent_id)
prep_obs: EnvObsType = _get_or_raise(
worker.preprocessors, policy_id).transform(raw_obs)
preproccessor = _get_or_raise(worker.preprocessors,
policy_id)

prep_obs: EnvObsType = raw_obs
if preproccessor is not None:
prep_obs = preproccessor.transform(raw_obs)
filtered_obs: EnvObsType = _get_or_raise(
worker.filters, policy_id)(prep_obs)
new_episode._set_last_raw_obs(agent_id, raw_obs)
new_episode._set_last_observation(agent_id, filtered_obs)

# Add initial obs to buffer.
Expand Down
5 changes: 1 addition & 4 deletions rllib/examples/custom_observation_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ def __repr__(self):
}

results = tune.run(
"PG",
args.run,
config=config,
stop={"training_iteration": args.stop_iters})
args.run, config=config, stop={"training_iteration": args.stop_iters})

ray.shutdown()
36 changes: 26 additions & 10 deletions rllib/models/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
TorchDeterministic, TorchDiagGaussian, \
TorchMultiActionDistribution, TorchMultiCategorical
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, \
deprecation_warning
from ray.rllib.utils.error import UnsupportedSpaceException
Expand All @@ -44,6 +44,11 @@
# 2) fully connected and CNN default networks as well as
# auto-wrapped LSTM- and attention nets.
"_use_default_native_models": False,
# Experimental flag.
# If True, user specified no preprocessor to be created
# (via config.preprocessor_pref=None). If True, observations will arrive
# in model as they are returned by the env.
"_no_preprocessing": False,

# === Built-in options ===
# FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py
Expand Down Expand Up @@ -693,12 +698,13 @@ def get_preprocessor_for_space(observation_space: gym.Space,
cls = get_preprocessor(observation_space)
prep = cls(observation_space, options)

logger.debug("Created preprocessor {}: {} -> {}".format(
prep, observation_space, prep.shape))
if prep is not None:
logger.debug("Created preprocessor {}: {} -> {}".format(
prep, observation_space, prep.shape))
return prep

@staticmethod
@PublicAPI
@Deprecated(error=False)
def register_custom_preprocessor(preprocessor_name: str,
preprocessor_class: type) -> None:
"""Register a custom preprocessor class by name.
Expand Down Expand Up @@ -796,14 +802,15 @@ def _get_v2_model_class(input_space: gym.Space,
"framework={} not supported in `ModelCatalog._get_v2_model_"
"class`!".format(framework))

# Tuple space, where at least one sub-space is image.
# -> Complex input model.
# Complex space, where at least one sub-space is image.
# -> Complex input model (which auto-flattens everything, but correctly
# processes image components with default CNN stacks).
space_to_check = input_space if not hasattr(
input_space, "original_space") else input_space.original_space
if isinstance(input_space,
Tuple) or (isinstance(space_to_check, Tuple) and any(
isinstance(s, Box) and len(s.shape) >= 2
for s in space_to_check.spaces)):
if isinstance(input_space, (Dict, Tuple)) or (isinstance(
space_to_check, (Dict, Tuple)) and any(
isinstance(s, Box) and len(s.shape) >= 2
for s in tree.flatten(space_to_check.spaces))):
return ComplexNet

# Single, flattenable/one-hot-able space -> Simple FCNet.
Expand Down Expand Up @@ -860,6 +867,15 @@ def _validate_config(config: ModelConfigDict, framework: str) -> None:
Raises:
ValueError: If something is wrong with the given config.
"""
# Soft-deprecate custom preprocessors.
if config.get("custom_preprocessor") is not None:
deprecation_warning(
old="model.custom_preprocessor",
new="gym.ObservationWrapper around your env or handle complex "
"inputs inside your Model",
error=False,
)

if config.get("use_attention") and config.get("use_lstm"):
raise ValueError("Only one of `use_lstm` or `use_attention` may "
"be set to True!")
Expand Down
32 changes: 23 additions & 9 deletions rllib/models/modelv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,18 +216,32 @@ def __call__(
input_dict["is_training"] = input_dict.is_training
else:
restored = input_dict.copy()
restored["obs"] = restore_original_dimensions(
input_dict["obs"], self.obs_space, self.framework)
try:
if len(input_dict["obs"].shape) > 2:
restored["obs_flat"] = flatten(input_dict["obs"],
self.framework)
else:
restored["obs_flat"] = input_dict["obs"]
except AttributeError:

# No Preprocessor used: `config.preprocessor_pref`=None.
# TODO: This is unnecessary when no preprocessor is used.
# Obs are not flat then anymore. However, we'll keep this
# here for backward-compatibility until Preprocessors have
# been fully deprecated.
if self.model_config.get("_no_preprocessing"):
restored["obs_flat"] = input_dict["obs"]
# Input to this Model went through a Preprocessor.
# Generate extra keys: "obs_flat" (vs "obs", which will hold the
# original obs).
else:
restored["obs"] = restore_original_dimensions(
input_dict["obs"], self.obs_space, self.framework)
try:
if len(input_dict["obs"].shape) > 2:
restored["obs_flat"] = flatten(input_dict["obs"],
self.framework)
else:
restored["obs_flat"] = input_dict["obs"]
except AttributeError:
restored["obs_flat"] = input_dict["obs"]

with self.context():
res = self.forward(restored, state or [], seq_lens)

if ((not isinstance(res, list) and not isinstance(res, tuple))
or len(res) != 2):
raise ValueError(
Expand Down
18 changes: 14 additions & 4 deletions rllib/models/preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,14 @@ def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
for i in range(len(self._obs_space.spaces)):
space = self._obs_space.spaces[i]
logger.debug("Creating sub-preprocessor for {}".format(space))
preprocessor = get_preprocessor(space)(space, self._options)
preprocessor_class = get_preprocessor(space)
if preprocessor_class is not None:
preprocessor = preprocessor_class(space, self._options)
size += preprocessor.size
else:
preprocessor = None
size += int(np.product(space.shape))
self.preprocessors.append(preprocessor)
size += preprocessor.size
return (size, )

@override(Preprocessor)
Expand Down Expand Up @@ -247,9 +252,14 @@ def _init_shape(self, obs_space: gym.Space, options: dict) -> List[int]:
self.preprocessors = []
for space in self._obs_space.spaces.values():
logger.debug("Creating sub-preprocessor for {}".format(space))
preprocessor = get_preprocessor(space)(space, self._options)
preprocessor_class = get_preprocessor(space)
if preprocessor_class is not None:
preprocessor = preprocessor_class(space, self._options)
size += preprocessor.size
else:
preprocessor = None
size += int(np.product(space.shape))
self.preprocessors.append(preprocessor)
size += preprocessor.size
return (size, )

@override(Preprocessor)
Expand Down
29 changes: 19 additions & 10 deletions rllib/models/tf/complex_input_net.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from gym.spaces import Box, Discrete, Tuple
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import numpy as np
import tree # pip install dm_tree

from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions
Expand All @@ -9,6 +10,7 @@
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import flatten_space
from ray.rllib.utils.tf_ops import one_hot

tf1, tf, tfv = try_import_tf()
Expand All @@ -31,21 +33,22 @@ class ComplexInputNetwork(TFModelV2):

def __init__(self, obs_space, action_space, num_outputs, model_config,
name):
# TODO: (sven) Support Dicts as well.
self.original_space = obs_space.original_space if \
hasattr(obs_space, "original_space") else obs_space
assert isinstance(self.original_space, (Tuple)), \
"`obs_space.original_space` must be Tuple!"
assert isinstance(self.original_space, (Dict, Tuple)), \
"`obs_space.original_space` must be [Dict|Tuple]!"

super().__init__(self.original_space, action_space, num_outputs,
model_config, name)

self.flattened_input_space = flatten_space(self.original_space)

# Build the CNN(s) given obs_space's image components.
self.cnns = {}
self.one_hot = {}
self.flatten = {}
concat_size = 0
for i, component in enumerate(self.original_space):
for i, component in enumerate(self.flattened_input_space):
# Image space.
if len(component.shape) == 3:
config = {
Expand All @@ -64,11 +67,13 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,
name="cnn_{}".format(i))
concat_size += cnn.num_outputs
self.cnns[i] = cnn
# Discrete inputs -> One-hot encode.
# Discrete|MultiDiscrete inputs -> One-hot encode.
elif isinstance(component, Discrete):
self.one_hot[i] = True
concat_size += component.n
# TODO: (sven) Multidiscrete (see e.g. our auto-LSTM wrappers).
elif isinstance(component, MultiDiscrete):
self.one_hot[i] = True
concat_size += sum(component.nvec)
# Everything else (1D Box).
else:
self.flatten[i] = int(np.product(component.shape))
Expand Down Expand Up @@ -123,18 +128,22 @@ def forward(self, input_dict, state, seq_lens):
self.obs_space, "tf")
# Push image observations through our CNNs.
outs = []
for i, component in enumerate(orig_obs):
for i, component in enumerate(tree.flatten(orig_obs)):
if i in self.cnns:
cnn_out, _ = self.cnns[i]({SampleBatch.OBS: component})
outs.append(cnn_out)
elif i in self.one_hot:
if component.dtype in [tf.int32, tf.int64, tf.uint8]:
outs.append(
one_hot(component, self.original_space.spaces[i]))
one_hot(component, self.flattened_input_space[i]))
else:
outs.append(component)
else:
outs.append(tf.reshape(component, [-1, self.flatten[i]]))
outs.append(
tf.cast(
tf.reshape(component, [-1, self.flatten[i]]),
dtype=tf.float32,
))
# Concat all outputs and the non-image inputs.
out = tf.concat(outs, axis=1)
# Push through (optional) FC-stack (this may be an empty stack).
Expand Down
Loading