Commit 9d6d8b5

Change network interface from RNNCell to function.

danijar committed Oct 4, 2017
1 parent d8a2881 commit 9d6d8b5

Showing 10 changed files with 123 additions and 182 deletions.
4 changes: 1 addition & 3 deletions README.md
@@ -68,7 +68,7 @@ modifying the code:
| File | Content |
| ---- | ------- |
| `scripts/configs.py` | Experiment configurations specifying the tasks and algorithms. |
| `scripts/networks.py` | Neural network models defined as [TensorFlow RNNCells][tf-rnn-cell]. |
| `scripts/networks.py` | Neural network models. |
| `scripts/train.py` | The executable file containing the training setup. |
| `ppo/algorithm.py` | The TensorFlow graph for the PPO algorithm. |

@@ -80,8 +80,6 @@ python3 -m unittest discover -p "*_test.py"

For further questions, please open an issue on Github.

[tf-rnn-cell]: https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/RNNCell

Implementation
--------------

69 changes: 25 additions & 44 deletions agents/ppo/algorithm.py
@@ -23,6 +23,7 @@
from __future__ import print_function

import collections
import functools

import tensorflow as tf

@@ -31,10 +32,6 @@
from agents.ppo import utility


_NetworkOutput = collections.namedtuple(
'NetworkOutput', 'policy, mean, logstd, value, state')


class PPOAlgorithm(object):
"""A vectorized implementation of the PPO algorithm by John Schulman."""

@@ -70,15 +67,25 @@ def __init__(self, batch_env, step, is_training, should_log, config):
use_gpu = self._config.use_gpu and utility.available_gpus()
with tf.device('/gpu:0' if use_gpu else '/cpu:0'):
# Create network variables for later calls to reuse.
self._network(
action_size = self._batch_env.action.shape[1].value
self._network = tf.make_template(
'network', functools.partial(config.network, config, action_size))
output = self._network(
tf.zeros_like(self._batch_env.observ)[:, None],
tf.ones(len(self._batch_env)), reuse=None)
cell = self._config.network(self._batch_env.action.shape[1].value)
tf.ones(len(self._batch_env)))
with tf.variable_scope('ppo_temporary'):
self._episodes = memory.EpisodeMemory(
template, len(batch_env), config.max_length, 'episodes')
self._last_state = utility.create_nested_vars(
cell.zero_state(len(batch_env), tf.float32))
if output.state is None:
self._last_state = None
else:
# Ensure the batch dimension is set.
tf.contrib.framework.nest.map_structure(
lambda x: x.set_shape([len(batch_env)] + x.shape.as_list()[1:]),
output.state)
self._last_state = tf.contrib.framework.nest.map_structure(
lambda x: tf.Variable(lambda: tf.zeros_like(x), False),
output.state)
self._last_action = tf.Variable(
tf.zeros_like(self._batch_env.action), False, name='last_action')
self._last_mean = tf.Variable(
@@ -102,7 +109,10 @@ def begin_episode(self, agent_indices):
Summary tensor.
"""
with tf.name_scope('begin_episode/'):
reset_state = utility.reinit_nested_vars(self._last_state, agent_indices)
if self._last_state is None:
reset_state = tf.no_op()
else:
reset_state = utility.reinit_nested_vars(self._last_state, agent_indices)
reset_buffer = self._episodes.clear(agent_indices)
with tf.control_dependencies([reset_state, reset_buffer]):
return tf.constant('')
@@ -130,8 +140,12 @@ def perform(self, observ):
tf.summary.histogram('action', action[:, 0]),
tf.summary.histogram('logprob', logprob)]), str)
# Remember current policy to append to memory in the experience callback.
if self._last_state is None:
assign_state = tf.no_op()
else:
assign_state = utility.assign_nested_vars(self._last_state, network.state)
with tf.control_dependencies([
utility.assign_nested_vars(self._last_state, network.state),
assign_state,
self._last_action.assign(action[:, 0]),
self._last_mean.assign(network.mean[:, 0]),
self._last_logstd.assign(network.logstd[:, 0])]):
@@ -523,36 +537,3 @@ def _mask(self, tensor, length):
mask = tf.cast(range_[None, :] < length[:, None], tf.float32)
masked = tensor * mask
return tf.check_numerics(masked, 'masked')

def _network(self, observ, length=None, state=None, reuse=True):
"""Compute the network output for a batched sequence of observations.
Optionally, the initial state can be specified. The weights should be
reused for all calls, except for the first one. Output is a named tuple
containing the policy as a TensorFlow distribution, the policy mean and log
standard deviation, the approximated state value, and the new recurrent
state.
Args:
observ: Sequences of observations.
length: Batch of sequence lengths.
state: Batch of initial recurrent states.
reuse: Python boolean whether to reuse previous variables.
Returns:
NetworkOutput tuple.
"""
with tf.variable_scope('network', reuse=reuse):
observ = tf.convert_to_tensor(observ)
use_gpu = self._config.use_gpu and utility.available_gpus()
with tf.device('/gpu:0' if use_gpu else '/cpu:0'):
observ = tf.check_numerics(observ, 'observ')
cell = self._config.network(self._batch_env.action.shape[1].value)
(mean, logstd, value), state = tf.nn.dynamic_rnn(
cell, observ, length, state, tf.float32, swap_memory=True)
mean = tf.check_numerics(mean, 'mean')
logstd = tf.check_numerics(logstd, 'logstd')
value = tf.check_numerics(value, 'value')
policy = tf.contrib.distributions.MultivariateNormalDiag(
mean, tf.exp(logstd))
return _NetworkOutput(policy, mean, logstd, value, state)
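
For context on the new wiring: the removed _network method handled weight sharing by hand with a tf.variable_scope and a reuse flag, whereas the commit now wraps the network function in tf.make_template, which creates the variables on the first call and reuses them on every later call. A minimal sketch of that pattern, assuming TF 1.x; the _two_layer_net function and the shapes are invented for illustration and are not part of the commit:

import tensorflow as tf

def _two_layer_net(observ):
  # Variables created here belong to the template and are shared across calls.
  hidden = tf.contrib.layers.fully_connected(observ, 32, tf.nn.relu)
  return tf.contrib.layers.fully_connected(hidden, 1, None)

network = tf.make_template('network', _two_layer_net)
first = network(tf.zeros([4, 8]))   # first call creates the 'network/...' variables
second = network(tf.ones([4, 8]))   # later calls reuse the same variables

This is why both the explicit reuse argument and the _network helper could be deleted from PPOAlgorithm.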
14 changes: 0 additions & 14 deletions agents/ppo/utility.py
@@ -26,20 +26,6 @@
from tensorflow.python.client import device_lib


def create_nested_vars(tensors):
"""Create variables matching a nested tuple of tensors.
Args:
tensors: Nested tuple of list of tensors.
Returns:
Nested tuple or list of variables.
"""
if isinstance(tensors, (tuple, list)):
return type(tensors)(create_nested_vars(tensor) for tensor in tensors)
return tf.Variable(tensors, False)


def reinit_nested_vars(variables, indices=None):
"""Reset all variables in a nested tuple to zeros.
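
The deleted create_nested_vars helper is no longer needed because the algorithm now builds matching variables with tf.contrib.framework.nest.map_structure (see the __init__ hunk above). A small sketch of the nest-based replacement; the nested state shown here is invented for illustration:

import tensorflow as tf

nest = tf.contrib.framework.nest

# Example nested recurrent state: a tuple of per-agent tensors.
state = (tf.zeros([8, 100]), tf.zeros([8, 50]))

# One non-trainable variable per tensor in the structure, initialized to zeros
# of the matching shape -- the same pattern the commit uses for _last_state.
state_vars = nest.map_structure(
    lambda x: tf.Variable(lambda: tf.zeros_like(x), trainable=False), state)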
6 changes: 2 additions & 4 deletions agents/scripts/configs.py
@@ -32,11 +32,9 @@ def default():
eval_episodes = 25
use_gpu = False
# Network
network = networks.ForwardGaussianPolicy
network = networks.feed_forward_gaussian
weight_summaries = dict(
all=r'.*',
policy=r'.*/policy/.*',
value=r'.*/value/.*')
all=r'.*', policy=r'.*/policy/.*', value=r'.*/value/.*')
policy_layers = 200, 100
value_layers = 200, 100
init_mean_factor = 0.05
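
Because both network functions share the (config, action_size, observations, length, state) signature, swapping policies remains a one-line config change. A hypothetical variant of the settings above that opts into the recurrent network (values illustrative only):

# Inside a config function, mirroring the defaults above:
network = networks.recurrent_gaussian
policy_layers = 200, 100   # the last entry sizes the GRU cell in recurrent_gaussian
value_layers = 200, 100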
175 changes: 90 additions & 85 deletions agents/scripts/networks.py
@@ -12,109 +12,114 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Networks for the PPO algorithm defined as recurrent cells."""
"""Network definitions for the PPO algorithm."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import operator

import tensorflow as tf


_MEAN_WEIGHTS_INITIALIZER = tf.contrib.layers.variance_scaling_initializer(
factor=0.1)
_LOGSTD_INITIALIZER = tf.random_normal_initializer(-1, 1e-10)
NetworkOutput = collections.namedtuple(
'NetworkOutput', 'policy, mean, logstd, value, state')


class ForwardGaussianPolicy(tf.contrib.rnn.RNNCell):
def feed_forward_gaussian(
config, action_size, observations, length, state=None):
"""Independent feed forward networks for policy and value.
The policy network outputs the mean action, and the log standard deviation
is learned as an independent parameter vector.
"""
def __init__(
self, policy_layers, value_layers, action_size,
mean_weights_initializer=_MEAN_WEIGHTS_INITIALIZER,
logstd_initializer=_LOGSTD_INITIALIZER):
self._policy_layers = policy_layers
self._value_layers = value_layers
self._action_size = action_size
self._mean_weights_initializer = mean_weights_initializer
self._logstd_initializer = logstd_initializer

@property
def state_size(self):
unused_state_size = 1
return unused_state_size

@property
def output_size(self):
return (self._action_size, self._action_size, tf.TensorShape([]))

def __call__(self, observation, state):
with tf.variable_scope('policy'):
x = tf.contrib.layers.flatten(observation)
for size in self._policy_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
mean = tf.contrib.layers.fully_connected(
x, self._action_size, tf.tanh,
weights_initializer=self._mean_weights_initializer)
logstd = tf.get_variable(
'logstd', mean.shape[1:], tf.float32, self._logstd_initializer)
logstd = tf.tile(
logstd[None, ...], [tf.shape(mean)[0]] + [1] * logstd.shape.ndims)
with tf.variable_scope('value'):
x = tf.contrib.layers.flatten(observation)
for size in self._value_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
value = tf.contrib.layers.fully_connected(x, 1, None)[:, 0]
return (mean, logstd, value), state


class RecurrentGaussianPolicy(tf.contrib.rnn.RNNCell):
Args:
config: Configuration object.
action_size: Length of the action vector.
observations: Sequences of observations.
length: Batch of sequence lengths.
state: Batch of initial recurrent states.
Returns:
NetworkOutput tuple.
"""
mean_weights_initializer = tf.contrib.layers.variance_scaling_initializer(
factor=config.init_mean_factor)
logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10)
flat_observations = tf.reshape(observations, [
tf.shape(observations)[0], tf.shape(observations)[1],
functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
with tf.variable_scope('policy'):
x = flat_observations
for size in config.policy_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
mean = tf.contrib.layers.fully_connected(
x, action_size, tf.tanh,
weights_initializer=mean_weights_initializer)
logstd = tf.tile(tf.get_variable(
'logstd', mean.shape[2:], tf.float32, logstd_initializer)[None, None],
[tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
with tf.variable_scope('value'):
x = flat_observations
for size in config.value_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
value = tf.contrib.layers.fully_connected(x, 1, None)[:, 0]
mean = tf.check_numerics(mean, 'mean')
logstd = tf.check_numerics(logstd, 'logstd')
value = tf.check_numerics(value, 'value')
policy = tf.contrib.distributions.MultivariateNormalDiag(
mean, tf.exp(logstd))
return NetworkOutput(policy, mean, logstd, value, state)


def recurrent_gaussian(
config, action_size, observations, length, state=None):
"""Independent recurrent policy and feed forward value networks.
The policy network outputs the mean action, and the log standard deviation
is learned as an independent parameter vector. The last policy layer is
recurrent and uses a GRU cell.
"""
def __init__(
self, policy_layers, value_layers, action_size,
mean_weights_initializer=_MEAN_WEIGHTS_INITIALIZER,
logstd_initializer=_LOGSTD_INITIALIZER):
self._policy_layers = policy_layers
self._value_layers = value_layers
self._action_size = action_size
self._mean_weights_initializer = mean_weights_initializer
self._logstd_initializer = logstd_initializer
self._cell = tf.contrib.rnn.GRUBlockCell(100)

@property
def state_size(self):
return self._cell.state_size

@property
def output_size(self):
return (self._action_size, self._action_size, tf.TensorShape([]))

def __call__(self, observation, state):
with tf.variable_scope('policy'):
x = tf.contrib.layers.flatten(observation)
for size in self._policy_layers[:-1]:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
x, state = self._cell(x, state)
mean = tf.contrib.layers.fully_connected(
x, self._action_size, tf.tanh,
weights_initializer=self._mean_weights_initializer)
logstd = tf.get_variable(
'logstd', mean.shape[1:], tf.float32, self._logstd_initializer)
logstd = tf.tile(
logstd[None, ...], [tf.shape(mean)[0]] + [1] * logstd.shape.ndims)
with tf.variable_scope('value'):
x = tf.contrib.layers.flatten(observation)
for size in self._value_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
value = tf.contrib.layers.fully_connected(x, 1, None)[:, 0]
return (mean, logstd, value), state
Args:
config: Configuration object.
action_size: Length of the action vector.
observations: Sequences of observations.
length: Batch of sequence lengths.
state: Batch of initial recurrent states.
Returns:
NetworkOutput tuple.
"""
mean_weights_initializer = tf.contrib.layers.variance_scaling_initializer(
factor=config.init_mean_factor)
logstd_initializer = tf.random_normal_initializer(config.init_logstd, 1e-10)
cell = tf.contrib.rnn.GRUBlockCell(config.policy_layers[-1])
flat_observations = tf.reshape(observations, [
tf.shape(observations)[0], tf.shape(observations)[1],
functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
with tf.variable_scope('policy'):
x = flat_observations
for size in config.policy_layers[:-1]:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
x, state = tf.nn.dynamic_rnn(cell, x, length, state, tf.float32)
mean = tf.contrib.layers.fully_connected(
x, action_size, tf.tanh,
weights_initializer=mean_weights_initializer)
logstd = tf.tile(tf.get_variable(
'logstd', mean.shape[2:], tf.float32, logstd_initializer)[None, None],
[tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
with tf.variable_scope('value'):
x = flat_observations
for size in config.value_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
value = tf.contrib.layers.fully_connected(x, 1, None)[:, 0]
mean = tf.check_numerics(mean, 'mean')
logstd = tf.check_numerics(logstd, 'logstd')
value = tf.check_numerics(value, 'value')
policy = tf.contrib.distributions.MultivariateNormalDiag(
mean, tf.exp(logstd))
# assert state.shape.as_list()[0] is not None
return NetworkOutput(policy, mean, logstd, value, state)
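
A short usage sketch of the new function interface, mirroring how PPOAlgorithm wraps it with functools.partial and tf.make_template. The Config namedtuple and all concrete values below are stand-ins for the real config object used by the scripts and are meant only for illustration (for example, init_logstd is assumed here and not shown in the hunks above):

import collections
import functools

import tensorflow as tf

from agents.scripts import networks

# Stand-in for the scripts' config object; values are illustrative only.
Config = collections.namedtuple(
    'Config', 'policy_layers value_layers init_mean_factor init_logstd')
config = Config(policy_layers=(200, 100), value_layers=(200, 100),
                init_mean_factor=0.05, init_logstd=-1.0)

action_size = 6
observations = tf.zeros([4, 1, 24])    # batch of 4 agents, 1 step, 24-D observations
length = tf.ones([4], dtype=tf.int32)  # sequence lengths

# Same wrapping as in PPOAlgorithm.__init__: bind config and action size,
# then share variables across calls through a template.
network = tf.make_template(
    'network',
    functools.partial(networks.feed_forward_gaussian, config, action_size))
output = network(observations, length)
# output is a NetworkOutput(policy, mean, logstd, value, state) tuple;
# state is None for the feed forward network.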
2 changes: 0 additions & 2 deletions agents/scripts/train.py
@@ -101,8 +101,6 @@ def train(config, env_processes):
"""
tf.reset_default_graph()
with config.unlocked:
config.network = functools.partial(
utility.define_network, config.network, config)
config.policy_optimizer = getattr(tf.train, config.policy_optimizer)
config.value_optimizer = getattr(tf.train, config.value_optimizer)
if config.update_every % config.num_agents:

1 comment on commit 9d6d8b5

@AdamStelmaszczyk (Contributor)

This is indeed a nice speedup: on my training setup, 100k steps previously took 310 seconds on average, and now they take 240 seconds.
