From 3a29a98286807ecfb0a0235854ec8ce5ac44e6d5 Mon Sep 17 00:00:00 2001
From: muupan
Date: Fri, 28 Sep 2018 14:35:52 +0900
Subject: [PATCH] Fix the error caused by inexact delta_z

---
 chainerrl/agents/categorical_dqn.py        |  2 ++
 tests/agents_tests/test_categorical_dqn.py | 28 ++++++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/chainerrl/agents/categorical_dqn.py b/chainerrl/agents/categorical_dqn.py
index 2f338a35e..01e0b9d47 100644
--- a/chainerrl/agents/categorical_dqn.py
+++ b/chainerrl/agents/categorical_dqn.py
@@ -41,6 +41,8 @@ def _apply_categorical_projection(y, y_probs, z):
     # bj: (batch_size, n_atoms)
     bj = (y - v_min) / delta_z
     assert bj.shape == (batch_size, n_atoms)
+    # Avoid the error caused by inexact delta_z
+    bj = xp.clip(bj, 0, n_atoms - 1)
 
     # l, u: (batch_size, n_atoms)
     l, u = xp.floor(bj), xp.ceil(bj)
diff --git a/tests/agents_tests/test_categorical_dqn.py b/tests/agents_tests/test_categorical_dqn.py
index d6a9a31e6..0370350ac 100644
--- a/tests/agents_tests/test_categorical_dqn.py
+++ b/tests/agents_tests/test_categorical_dqn.py
@@ -131,6 +131,27 @@ def _test(self, xp):
         proj = categorical_dqn._apply_categorical_projection(y, y_probs, z)
         xp.testing.assert_allclose(proj, proj_gt, atol=1e-5)
 
+    def _test_inexact_delta_z(self, xp):
+        v_min, v_max = (-1, 1)
+        n_atoms = 4
+        # delta_z=2/3=0.66666... is not exact
+        z = xp.linspace(v_min, v_max, num=n_atoms, dtype=np.float32)
+        y = xp.asarray([
+            [-1, -1, 1, 1],
+            [-1, 0, 1, 1],
+        ], dtype=np.float32)
+        y_probs = xp.asarray([
+            [0.5, 0.1, 0.1, 0.3],
+            [0.5, 0.2, 0.0, 0.3],
+        ], dtype=np.float32)
+        proj_gt = xp.asarray([
+            [0.6, 0.0, 0.0, 0.4],
+            [0.5, 0.1, 0.1, 0.3],
+        ], dtype=np.float32)
+
+        proj = categorical_dqn._apply_categorical_projection(y, y_probs, z)
+        xp.testing.assert_allclose(proj, proj_gt, atol=1e-5)
+
     def test_cpu(self):
         self._test(np)
 
@@ -138,6 +159,13 @@ def test_cpu(self):
     def test_gpu(self):
         self._test(chainer.cuda.cupy)
 
+    def test_inexact_delta_z_cpu(self):
+        self._test_inexact_delta_z(np)
+
+    @testing.attr.gpu
+    def test_inexact_delta_z_gpu(self):
+        self._test_inexact_delta_z(chainer.cuda.cupy)
+
 
 def make_distrib_ff_q_func(env):
     n_atoms = 51