Skip to content

Commit

Permalink
Merge pull request #287 from seann999/patch-1
Browse files Browse the repository at this point in the history
Noisy network fixes
  • Loading branch information
muupan authored Jul 23, 2018
2 parents f6c3779 + 9f8f347 commit 7523bb7
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 11 deletions.
9 changes: 7 additions & 2 deletions chainerrl/initializers/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,10 @@ def __init__(self, scale=1.0, dtype=None):
def __call__(self, array):
if self.dtype is not None:
assert array.dtype == self.dtype
fan_in, _ = initializer.get_fans(array.shape)
Constant(self.scale / numpy.sqrt(fan_in))(array)

if len(array.shape) == 1:
Constant(self.scale / numpy.sqrt(array.shape[0]))(array)
else:
fan_in, _ = initializer.get_fans(array.shape)

Constant(self.scale / numpy.sqrt(fan_in))(array)
5 changes: 5 additions & 0 deletions chainerrl/links/noisy_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from chainer.links import Linear

from chainerrl.links.noisy_linear import FactorizedNoisyLinear
from chainerrl.links.sequence import Sequence


def to_factorized_noisy(link, *args, **kwargs):
Expand Down Expand Up @@ -47,3 +48,7 @@ def _map_links(func, link):
# mimic ChainList.add_link
children[i] = new_child
children[i].name = str(i)

if isinstance(link, Sequence):
# assumes i-th layer corresponds with i-th child
link.layers[i] = new_child
17 changes: 10 additions & 7 deletions chainerrl/links/noisy_linear.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import chainer
import chainer.functions as F
from chainer.initializers import Constant
from chainer.initializers import LeCunUniform
import chainer.links as L
import numpy

Expand All @@ -23,15 +23,18 @@ def __init__(self, mu_link, sigma_scale=0.4):

W_data = mu_link.W.data
in_size = None if W_data is None else W_data.shape[1]
device_id = mu_link._device_id

with self.init_scope():
self.mu = mu_link
self.sigma = L.Linear(
in_size=in_size, out_size=self.out_size, nobias=self.nobias,
initialW=VarianceScalingConstant(sigma_scale),
initial_bias=Constant(sigma_scale))
self.mu = L.Linear(in_size, self.out_size, self.nobias,
initialW=LeCunUniform(1 / numpy.sqrt(3)))

self.sigma = L.Linear(in_size, self.out_size, self.nobias,
initialW=VarianceScalingConstant(
sigma_scale),
initial_bias=VarianceScalingConstant(
sigma_scale))

device_id = self.mu._device_id
if device_id is not None:
self.to_gpu(device_id)

Expand Down
2 changes: 1 addition & 1 deletion chainerrl/links/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Sequence(chainer.ChainList, RecurrentChainMixin):
"""Sequential callable Link that consists of other Links."""

def __init__(self, *layers):
self.layers = layers
self.layers = list(layers)
links = [layer for layer in layers if isinstance(layer, chainer.Link)]
# Cache the signatures because it might be slow
self.argnames = [set(signature(layer).parameters)
Expand Down
6 changes: 6 additions & 0 deletions examples/ale/train_dqn_ale.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def main():
type=int, default=10 ** 6)
parser.add_argument('--final-epsilon', type=float, default=0.1)
parser.add_argument('--eval-epsilon', type=float, default=0.05)
parser.add_argument('--noisy-net-sigma', type=float, default=None)
parser.add_argument('--arch', type=str, default='nature',
choices=['nature', 'nips', 'dueling'])
parser.add_argument('--steps', type=int, default=10 ** 7)
Expand Down Expand Up @@ -142,6 +143,11 @@ def make_env(test):
activation = parse_activation(args.activation)
q_func = parse_arch(args.arch, n_actions, activation)

if args.noisy_net_sigma is not None:
links.to_factorized_noisy(q_func)
# Turn off explorer
explorer = explorers.Greedy()

# Draw the computational graph and save it in the output directory.
chainerrl.misc.draw_computational_graph(
[q_func(np.zeros((4, 84, 84), dtype=np.float32)[None])],
Expand Down
2 changes: 1 addition & 1 deletion tests/links_tests/test_noisy_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _test_randomness(self, xp):
if self.nobias:
target *= 2 / numpy.pi
else:
target *= 2 / numpy.pi + numpy.sqrt(2 / numpy.pi)
target *= 2 / numpy.pi + numpy.sqrt(2 / numpy.pi) / y1.shape[1]

self.assertGreater(d, target / 3.)
self.assertLess(d, target * 3.)
Expand Down

0 comments on commit 7523bb7

Please sign in to comment.