Fix dimensionality of log_std in GaussianMLPModule (#2266)
krzentner authored May 8, 2021
1 parent fa0cd08 · commit 5bc155b
Showing 2 changed files with 135 additions and 116 deletions.
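Summary of the change: before this commit, the module stored its initial log standard deviation as a single-element tensor and broadcast it to the action dimension on every forward pass, so a learned std was one scalar shared across all output dimensions. After this commit, log_std is a parameter/buffer of shape (output_dim,), so each output dimension gets its own (optionally learned) std. A quick check of the new behavior, as a sketch (assumes garage's public import path; _log_std is the private attribute introduced below):

import torch
from garage.torch.modules import GaussianMLPModule

module = GaussianMLPModule(input_dim=4, output_dim=2, learn_std=True)

# After this commit there is one log-std entry per output dimension,
# so each dimension's std can be learned independently.
assert module._log_std.shape == (2,)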
141 changes: 70 additions & 71 deletions src/garage/torch/modules/gaussian_mlp_module.py
@@ -75,6 +75,7 @@ def __init__(self,
                  input_dim,
                  output_dim,
                  hidden_sizes=(32, 32),
+                 *,
                  hidden_nonlinearity=torch.tanh,
                  hidden_w_init=nn.init.xavier_uniform_,
                  hidden_b_init=nn.init.zeros_,
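The bare * added to each constructor in this commit makes every argument after hidden_sizes keyword-only. A hedged illustration (hypothetical calls, not from the diff):

import torch
from garage.torch.modules import GaussianMLPModule

# Fine: options after hidden_sizes are passed by keyword.
GaussianMLPModule(input_dim=4, output_dim=2, hidden_sizes=(32, 32),
                  hidden_nonlinearity=torch.tanh)

# TypeError after this change: positional argument past the bare '*'.
# GaussianMLPModule(4, 2, (32, 32), torch.tanh)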
@@ -121,12 +122,13 @@ def __init__(self,
         if self._std_parameterization not in ('exp', 'softplus'):
             raise NotImplementedError
 
-        init_std_param = torch.Tensor([init_std]).log()
+        self._init_std = torch.Tensor([init_std]).log()
+        log_std = torch.Tensor([init_std] * output_dim).log()
         if self._learn_std:
-            self._init_std = torch.nn.Parameter(init_std_param)
+            self._log_std = torch.nn.Parameter(log_std)
         else:
-            self._init_std = init_std_param
-            self.register_buffer('init_std', self._init_std)
+            self._log_std = log_std
+            self.register_buffer('log_std', self._log_std)
 
         self._min_std_param = self._max_std_param = None
         if min_std is not None:
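The shape change at the heart of the fix, sketched in isolation:

import torch

init_std, output_dim = 1.0, 4

old_log_std = torch.Tensor([init_std]).log()               # shape (1,): one shared std
new_log_std = torch.Tensor([init_std] * output_dim).log()  # shape (4,): one std per dim

assert old_log_std.shape == (1,)
assert new_log_std.shape == (output_dim,)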
@@ -146,8 +148,8 @@ def to(self, *args, **kwargs):
         """
         super().to(*args, **kwargs)
         buffers = dict(self.named_buffers())
-        if not isinstance(self._init_std, torch.nn.Parameter):
-            self._init_std = buffers['init_std']
+        if not isinstance(self._log_std, torch.nn.Parameter):
+            self._log_std = buffers['log_std']
         self._min_std_param = buffers['min_std_param']
         self._max_std_param = buffers['max_std_param']
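Context for the renamed lookup: Module.to() replaces registered buffers, but a plain attribute that aliases a buffer keeps pointing at the old tensor, which is why this override re-reads the moved tensor out of named_buffers(). A standalone sketch of that behavior (hypothetical class M, standard nn.Module semantics):

import torch
from torch import nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self._log_std = torch.zeros(4)
        self.register_buffer('log_std', self._log_std)

m = M()
m.to(torch.float64)
assert m.log_std.dtype == torch.float64   # the registered buffer was replaced
assert m._log_std.dtype == torch.float32  # the alias was not, hence the rebinding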

@@ -242,6 +244,7 @@ def __init__(self,
                  input_dim,
                  output_dim,
                  hidden_sizes=(32, 32),
+                 *,
                  hidden_nonlinearity=torch.tanh,
                  hidden_w_init=nn.init.xavier_uniform_,
                  hidden_b_init=nn.init.zeros_,
@@ -255,23 +258,22 @@ def __init__(self,
                  std_parameterization='exp',
                  layer_normalization=False,
                  normal_distribution_cls=Normal):
-        super(GaussianMLPModule,
-              self).__init__(input_dim=input_dim,
-                             output_dim=output_dim,
-                             hidden_sizes=hidden_sizes,
-                             hidden_nonlinearity=hidden_nonlinearity,
-                             hidden_w_init=hidden_w_init,
-                             hidden_b_init=hidden_b_init,
-                             output_nonlinearity=output_nonlinearity,
-                             output_w_init=output_w_init,
-                             output_b_init=output_b_init,
-                             learn_std=learn_std,
-                             init_std=init_std,
-                             min_std=min_std,
-                             max_std=max_std,
-                             std_parameterization=std_parameterization,
-                             layer_normalization=layer_normalization,
-                             normal_distribution_cls=normal_distribution_cls)
+        super().__init__(input_dim=input_dim,
+                         output_dim=output_dim,
+                         hidden_sizes=hidden_sizes,
+                         hidden_nonlinearity=hidden_nonlinearity,
+                         hidden_w_init=hidden_w_init,
+                         hidden_b_init=hidden_b_init,
+                         output_nonlinearity=output_nonlinearity,
+                         output_w_init=output_w_init,
+                         output_b_init=output_b_init,
+                         learn_std=learn_std,
+                         init_std=init_std,
+                         min_std=min_std,
+                         max_std=max_std,
+                         std_parameterization=std_parameterization,
+                         layer_normalization=layer_normalization,
+                         normal_distribution_cls=normal_distribution_cls)
 
         self._mean_module = MLPModule(
             input_dim=self._input_dim,
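The super(...) rewrites throughout this commit are behavior-preserving modernization: in Python 3 the zero-argument form resolves the same MRO. For instance:

class Base:
    def __init__(self, x):
        self.x = x

class Child(Base):
    def __init__(self):
        # Equivalent to super(Child, self).__init__(x=1), just less noisy.
        super().__init__(x=1)

assert Child().x == 1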
@@ -285,24 +287,21 @@ def __init__(self,
             output_b_init=self._output_b_init,
             layer_normalization=self._layer_normalization)
 
-    def _get_mean_and_log_std(self, *inputs):
+    # pylint: disable=arguments-differ
+    def _get_mean_and_log_std(self, x):
         """Get mean and std of Gaussian distribution given inputs.
 
         Args:
-            *inputs: Input to the module.
+            x: Input to the module.
 
         Returns:
             torch.Tensor: The mean of Gaussian distribution.
             torch.Tensor: The variance of Gaussian distribution.
 
         """
-        assert len(inputs) == 1
-        mean = self._mean_module(*inputs)
-
-        broadcast_shape = list(inputs[0].shape[:-1]) + [self._action_dim]
-        uncentered_log_std = torch.zeros(*broadcast_shape) + self._init_std
+        mean = self._mean_module(x)
 
-        return mean, uncentered_log_std
+        return mean, self._log_std
 
 
 class GaussianMLPIndependentStdModule(GaussianMLPBaseModule):
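The broadcast_shape logic deleted above is unnecessary because torch distributions broadcast loc and scale, so returning the (output_dim,)-shaped log_std directly yields the same batch shape. (Note that both before and after the change, the second return value is a log std, despite the docstring saying "variance".) A sketch of the broadcasting this relies on:

import torch
from torch.distributions import Normal

mean = torch.zeros(8, 4)  # (batch, output_dim), as produced by the mean MLP
log_std = torch.zeros(4)  # (output_dim,), the new buffer shape

dist = Normal(mean, log_std.exp())  # scale broadcasts against loc
assert dist.sample().shape == (8, 4)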
@@ -369,6 +368,7 @@ def __init__(self,
                  input_dim,
                  output_dim,
                  hidden_sizes=(32, 32),
+                 *,
                  hidden_nonlinearity=torch.tanh,
                  hidden_w_init=nn.init.xavier_uniform_,
                  hidden_b_init=nn.init.zeros_,
@@ -388,29 +388,28 @@ def __init__(self,
                  std_parameterization='exp',
                  layer_normalization=False,
                  normal_distribution_cls=Normal):
-        super(GaussianMLPIndependentStdModule,
-              self).__init__(input_dim=input_dim,
-                             output_dim=output_dim,
-                             hidden_sizes=hidden_sizes,
-                             hidden_nonlinearity=hidden_nonlinearity,
-                             hidden_w_init=hidden_w_init,
-                             hidden_b_init=hidden_b_init,
-                             output_nonlinearity=output_nonlinearity,
-                             output_w_init=output_w_init,
-                             output_b_init=output_b_init,
-                             learn_std=learn_std,
-                             init_std=init_std,
-                             min_std=min_std,
-                             max_std=max_std,
-                             std_hidden_sizes=std_hidden_sizes,
-                             std_hidden_nonlinearity=std_hidden_nonlinearity,
-                             std_hidden_w_init=std_hidden_w_init,
-                             std_hidden_b_init=std_hidden_b_init,
-                             std_output_nonlinearity=std_output_nonlinearity,
-                             std_output_w_init=std_output_w_init,
-                             std_parameterization=std_parameterization,
-                             layer_normalization=layer_normalization,
-                             normal_distribution_cls=normal_distribution_cls)
+        super().__init__(input_dim=input_dim,
+                         output_dim=output_dim,
+                         hidden_sizes=hidden_sizes,
+                         hidden_nonlinearity=hidden_nonlinearity,
+                         hidden_w_init=hidden_w_init,
+                         hidden_b_init=hidden_b_init,
+                         output_nonlinearity=output_nonlinearity,
+                         output_w_init=output_w_init,
+                         output_b_init=output_b_init,
+                         learn_std=learn_std,
+                         init_std=init_std,
+                         min_std=min_std,
+                         max_std=max_std,
+                         std_hidden_sizes=std_hidden_sizes,
+                         std_hidden_nonlinearity=std_hidden_nonlinearity,
+                         std_hidden_w_init=std_hidden_w_init,
+                         std_hidden_b_init=std_hidden_b_init,
+                         std_output_nonlinearity=std_output_nonlinearity,
+                         std_output_w_init=std_output_w_init,
+                         std_parameterization=std_parameterization,
+                         layer_normalization=layer_normalization,
+                         normal_distribution_cls=normal_distribution_cls)
 
         self._mean_module = MLPModule(
             input_dim=self._input_dim,
@@ -512,6 +511,7 @@ def __init__(self,
                  input_dim,
                  output_dim,
                  hidden_sizes=(32, 32),
+                 *,
                  hidden_nonlinearity=torch.tanh,
                  hidden_w_init=nn.init.xavier_uniform_,
                  hidden_b_init=nn.init.zeros_,
@@ -525,23 +525,22 @@ def __init__(self,
                  std_parameterization='exp',
                  layer_normalization=False,
                  normal_distribution_cls=Normal):
-        super(GaussianMLPTwoHeadedModule,
-              self).__init__(input_dim=input_dim,
-                             output_dim=output_dim,
-                             hidden_sizes=hidden_sizes,
-                             hidden_nonlinearity=hidden_nonlinearity,
-                             hidden_w_init=hidden_w_init,
-                             hidden_b_init=hidden_b_init,
-                             output_nonlinearity=output_nonlinearity,
-                             output_w_init=output_w_init,
-                             output_b_init=output_b_init,
-                             learn_std=learn_std,
-                             init_std=init_std,
-                             min_std=min_std,
-                             max_std=max_std,
-                             std_parameterization=std_parameterization,
-                             layer_normalization=layer_normalization,
-                             normal_distribution_cls=normal_distribution_cls)
+        super().__init__(input_dim=input_dim,
+                         output_dim=output_dim,
+                         hidden_sizes=hidden_sizes,
+                         hidden_nonlinearity=hidden_nonlinearity,
+                         hidden_w_init=hidden_w_init,
+                         hidden_b_init=hidden_b_init,
+                         output_nonlinearity=output_nonlinearity,
+                         output_w_init=output_w_init,
+                         output_b_init=output_b_init,
+                         learn_std=learn_std,
+                         init_std=init_std,
+                         min_std=min_std,
+                         max_std=max_std,
+                         std_parameterization=std_parameterization,
+                         layer_normalization=layer_normalization,
+                         normal_distribution_cls=normal_distribution_cls)
 
         self._shared_mean_log_std_network = MultiHeadedMLPModule(
             n_heads=2,
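For context, the two-headed variant shares one trunk and reads mean and log-std from separate output heads. A minimal standalone sketch of that design (a hypothetical stand-in, not garage's MultiHeadedMLPModule):

import torch
from torch import nn

class TwoHeadedGaussianSketch(nn.Module):
    def __init__(self, input_dim, output_dim, hidden=32):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(input_dim, hidden), nn.Tanh())
        self.mean_head = nn.Linear(hidden, output_dim)
        self.log_std_head = nn.Linear(hidden, output_dim)

    def forward(self, x):
        h = self.trunk(x)
        return self.mean_head(h), self.log_std_head(h)

mean, log_std = TwoHeadedGaussianSketch(4, 2)(torch.zeros(8, 4))
assert mean.shape == log_std.shape == (8, 2)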