From 28b1bd18e5f964b94098e269b34969548e8e215e Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Sat, 6 Jul 2019 01:36:10 -0500 Subject: [PATCH 1/8] adds double IQN, an example, tests, and addresses flakes --- chainerrl/agents/__init__.py | 1 + chainerrl/agents/double_iqn.py | 55 +++++ examples/atari/train_iqn_ale.py | 214 ++++++++++++++++++++ tests/agents_tests/test_double_iqn.py | 48 +++++ tests/agents_tests/test_ppo.py | 4 +- tests/wrappers_tests/test_atari_wrappers.py | 3 +- 6 files changed, 322 insertions(+), 3 deletions(-) create mode 100644 chainerrl/agents/double_iqn.py create mode 100644 examples/atari/train_iqn_ale.py create mode 100644 tests/agents_tests/test_double_iqn.py diff --git a/chainerrl/agents/__init__.py b/chainerrl/agents/__init__.py index ab05f67fd..2deb667ef 100644 --- a/chainerrl/agents/__init__.py +++ b/chainerrl/agents/__init__.py @@ -6,6 +6,7 @@ from chainerrl.agents.categorical_dqn import CategoricalDQN # NOQA from chainerrl.agents.ddpg import DDPG # NOQA from chainerrl.agents.double_dqn import DoubleDQN # NOQA +from chainerrl.agents.double_iqn import DoubleIQN # NOQA from chainerrl.agents.double_pal import DoublePAL # NOQA from chainerrl.agents.dpp import DPP # NOQA from chainerrl.agents.dqn import DQN # NOQA diff --git a/chainerrl/agents/double_iqn.py b/chainerrl/agents/double_iqn.py new file mode 100644 index 000000000..e2864b5f5 --- /dev/null +++ b/chainerrl/agents/double_iqn.py @@ -0,0 +1,55 @@ +from __future__ import unicode_literals +from __future__ import print_function +from __future__ import division +from __future__ import absolute_import +from future import standard_library +standard_library.install_aliases() # NOQA + +import chainer +import chainer.functions as F + +from chainerrl.agents import iqn +from chainerrl.agents.categorical_dqn import _apply_categorical_projection +from chainerrl.recurrent import state_kept + + +class DoubleIQN(iqn.IQN): + """Double IQN. + + """ + + def _compute_target_values(self, exp_batch): + """Compute a batch of target return distributions. + + Returns: + chainer.Variable: (batch_size, N_prime). 
+ """ + batch_next_state = exp_batch['next_state'] + batch_rewards = exp_batch['reward'] + batch_terminal = exp_batch['is_state_terminal'] + batch_size = len(exp_batch['reward']) + taus_tilde = self.xp.random.uniform( + 0, 1, size=(batch_size, self.quantile_thresholds_K)).astype('f') + + next_tau2av = self.model(batch_next_state) + greedy_actions = next_tau2av(taus_tilde).greedy_actions + taus_prime = self.xp.random.uniform( + 0, 1, + size=(batch_size, self.quantile_thresholds_N_prime)).astype('f') + target_next_tau2av = self.target_model(batch_next_state) + target_next_maxz = target_next_tau2av( + taus_prime).evaluate_actions_as_quantiles(greedy_actions) + + batch_discount = exp_batch['discount'] + assert batch_rewards.shape == (batch_size,) + assert batch_terminal.shape == (batch_size,) + assert batch_discount.shape == (batch_size,) + batch_rewards = F.broadcast_to( + batch_rewards[..., None], target_next_maxz.shape) + batch_terminal = F.broadcast_to( + batch_terminal[..., None], target_next_maxz.shape) + batch_discount = F.broadcast_to( + batch_discount[..., None], target_next_maxz.shape) + + return (batch_rewards + + batch_discount * (1.0 - batch_terminal) * target_next_maxz) diff --git a/examples/atari/train_iqn_ale.py b/examples/atari/train_iqn_ale.py new file mode 100644 index 000000000..a71134b1d --- /dev/null +++ b/examples/atari/train_iqn_ale.py @@ -0,0 +1,214 @@ +from __future__ import print_function +from __future__ import division +from __future__ import unicode_literals +from __future__ import absolute_import +from builtins import * # NOQA +from future import standard_library +standard_library.install_aliases() # NOQA +import argparse +import functools +import json +import os + +import chainer +import chainer.functions as F +import chainer.links as L +import gym +import numpy as np + +import chainerrl +from chainerrl import experiments +from chainerrl import explorers +from chainerrl import misc +from chainerrl import replay_buffer +from chainerrl.wrappers import atari_wrappers + + +def parse_agent(agent): + return {'IQN': chainerrl.agents.IQN, + 'DoubleIQN': chainerrl.agents.DoubleIQN}[agent] + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4') + parser.add_argument('--outdir', type=str, default='results', + help='Directory path to save output files.' 
+ ' If it does not exist, it will be created.') + parser.add_argument('--seed', type=int, default=0, + help='Random seed [0, 2 ** 31)') + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--demo', action='store_true', default=False) + parser.add_argument('--load', type=str, default=None) + parser.add_argument('--final-exploration-frames', + type=int, default=10 ** 6) + parser.add_argument('--final-epsilon', type=float, default=0.01) + parser.add_argument('--eval-epsilon', type=float, default=0.001) + parser.add_argument('--steps', type=int, default=5 * 10 ** 7) + parser.add_argument('--max-frames', type=int, + default=30 * 60 * 60, # 30 minutes with 60 fps + help='Maximum number of frames for each episode.') + parser.add_argument('--replay-start-size', type=int, default=5 * 10 ** 4) + parser.add_argument('--target-update-interval', + type=int, default=10 ** 4) + parser.add_argument('--agent', type=str, default='IQN', + choices=['IQN', 'DoubleIQN']) + parser.add_argument('--eval-interval', type=int, default=250000) + parser.add_argument('--eval-n-steps', type=int, default=125000) + parser.add_argument('--update-interval', type=int, default=4) + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--logging-level', type=int, default=20, + help='Logging level. 10:DEBUG, 20:INFO etc.') + parser.add_argument('--render', action='store_true', default=False, + help='Render env states in a GUI window.') + parser.add_argument('--monitor', action='store_true', default=False, + help='Monitor env. Videos and additional information' + ' are saved as output files.') + parser.add_argument('--batch-accumulator', type=str, default='mean', + choices=['mean', 'sum']) + parser.add_argument('--quantile-thresholds-N', type=int, default=64) + parser.add_argument('--quantile-thresholds-N-prime', type=int, default=64) + parser.add_argument('--quantile-thresholds-K', type=int, default=32) + parser.add_argument('--n-best-episodes', type=int, default=200) + args = parser.parse_args() + + import logging + logging.basicConfig(level=args.logging_level) + + # Set a random seed used in ChainerRL. + misc.set_random_seed(args.seed, gpus=(args.gpu,)) + + # Set different random seeds for train and test envs. 
+ train_seed = args.seed + test_seed = 2 ** 31 - 1 - args.seed + + args.outdir = experiments.prepare_output_dir(args, args.outdir) + print('Output files are saved in {}'.format(args.outdir)) + + def make_env(test): + # Use different random seeds for train and test envs + env_seed = test_seed if test else train_seed + env = atari_wrappers.wrap_deepmind( + atari_wrappers.make_atari(args.env, max_frames=args.max_frames), + episode_life=not test, + clip_rewards=not test) + env.seed(int(env_seed)) + if test: + # Randomize actions like epsilon-greedy in evaluation as well + env = chainerrl.wrappers.RandomizeAction(env, args.eval_epsilon) + if args.monitor: + env = gym.wrappers.Monitor( + env, args.outdir, + mode='evaluation' if test else 'training') + if args.render: + env = chainerrl.wrappers.Render(env) + return env + + env = make_env(test=False) + eval_env = make_env(test=True) + n_actions = env.action_space.n + + q_func = chainerrl.agents.iqn.ImplicitQuantileQFunction( + psi=chainerrl.links.Sequence( + L.Convolution2D(None, 32, 8, stride=4), + F.relu, + L.Convolution2D(None, 64, 4, stride=2), + F.relu, + L.Convolution2D(None, 64, 3, stride=1), + F.relu, + functools.partial(F.reshape, shape=(-1, 3136)), + ), + phi=chainerrl.links.Sequence( + chainerrl.agents.iqn.CosineBasisLinear(64, 3136), + F.relu, + ), + f=chainerrl.links.Sequence( + L.Linear(None, 512), + F.relu, + L.Linear(None, n_actions), + ), + ) + + # Draw the computational graph and save it in the output directory. + fake_obss = np.zeros((4, 84, 84), dtype=np.float32)[None] + fake_taus = np.zeros(32, dtype=np.float32)[None] + chainerrl.misc.draw_computational_graph( + [q_func(fake_obss)(fake_taus)], + os.path.join(args.outdir, 'model')) + + # Use the same hyper parameters as https://arxiv.org/abs/1710.10044 + opt = chainer.optimizers.Adam(5e-5, eps=1e-2 / args.batch_size) + opt.setup(q_func) + + rbuf = replay_buffer.ReplayBuffer(10 ** 6) + + explorer = explorers.LinearDecayEpsilonGreedy( + 1.0, args.final_epsilon, + args.final_exploration_frames, + lambda: np.random.randint(n_actions)) + + def phi(x): + # Feature extractor + return np.asarray(x, dtype=np.float32) / 255 + + Agent = parse_agent(args.agent) + agent = Agent( + q_func, opt, rbuf, gpu=args.gpu, gamma=0.99, + explorer=explorer, replay_start_size=args.replay_start_size, + target_update_interval=args.target_update_interval, + update_interval=args.update_interval, + batch_accumulator=args.batch_accumulator, + phi=phi, + quantile_thresholds_N=args.quantile_thresholds_N, + quantile_thresholds_N_prime=args.quantile_thresholds_N_prime, + quantile_thresholds_K=args.quantile_thresholds_K, + ) + + if args.load: + agent.load(args.load) + + if args.demo: + eval_stats = experiments.eval_performance( + env=eval_env, + agent=agent, + n_steps=args.eval_n_steps, + n_episodes=None, + ) + print('n_steps: {} mean: {} median: {} stdev {}'.format( + args.eval_n_steps, eval_stats['mean'], eval_stats['median'], + eval_stats['stdev'])) + else: + experiments.train_agent_with_evaluation( + agent=agent, + env=env, + steps=args.steps, + eval_n_steps=args.eval_n_steps, + eval_n_episodes=None, + eval_interval=args.eval_interval, + outdir=args.outdir, + save_best_so_far_agent=True, + eval_env=eval_env, + ) + + dir_of_best_network = os.path.join(args.outdir, "best") + agent.load(dir_of_best_network) + + # run 200 evaluation episodes, each capped at 30 mins of play + stats = experiments.evaluator.eval_performance( + env=eval_env, + agent=agent, + n_steps=None, + n_episodes=args.n_best_episodes, + 
max_episode_len=args.max_frames / 4, + logger=None) + with open(os.path.join(args.outdir, 'bestscores.json'), 'w') as f: + # temporary hack to handle python 2/3 support issues. + # json dumps does not support non-string literal dict keys + json_stats = json.dumps(stats) + print(str(json_stats), file=f) + print("The results of the best scoring network:") + for stat in stats: + print(str(stat) + ":" + str(stats[stat])) + + +if __name__ == '__main__': + main() diff --git a/tests/agents_tests/test_double_iqn.py b/tests/agents_tests/test_double_iqn.py new file mode 100644 index 000000000..8afc713d8 --- /dev/null +++ b/tests/agents_tests/test_double_iqn.py @@ -0,0 +1,48 @@ +from __future__ import unicode_literals +from __future__ import print_function +from __future__ import division +from __future__ import absolute_import +from future import standard_library +from builtins import * # NOQA +standard_library.install_aliases() # NOQA + +import chainer.functions as F +import chainer.links as L +from chainer import testing + +import basetest_dqn_like as base +from basetest_training import _TestBatchTrainingMixin +import chainerrl +from chainerrl.agents import double_iqn +from chainerrl.agents import iqn + + +@testing.parameterize(*testing.product({ + 'quantile_thresholds_N': [1, 5], + 'quantile_thresholds_N_prime': [1, 7], +})) +class TestIQNOnDiscreteABC( + _TestBatchTrainingMixin, base._TestDQNOnDiscreteABC): + + def make_q_func(self, env): + obs_size = env.observation_space.low.size + hidden_size = 64 + return iqn.ImplicitQuantileQFunction( + psi=chainerrl.links.Sequence( + L.Linear(obs_size, hidden_size), + F.relu, + ), + phi=chainerrl.links.Sequence( + iqn.CosineBasisLinear(32, hidden_size), + F.relu, + ), + f=L.Linear(hidden_size, env.action_space.n), + ) + + def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): + return double_iqn.DoubleIQN( + q_func, opt, rbuf, gpu=gpu, gamma=0.9, explorer=explorer, + replay_start_size=100, target_update_interval=100, + quantile_thresholds_N=self.quantile_thresholds_N, + quantile_thresholds_N_prime=self.quantile_thresholds_N_prime, + ) diff --git a/tests/agents_tests/test_ppo.py b/tests/agents_tests/test_ppo.py index 180acbc1a..68e6dc9c9 100644 --- a/tests/agents_tests/test_ppo.py +++ b/tests/agents_tests/test_ppo.py @@ -21,8 +21,8 @@ import chainerrl from chainerrl.agents.a3c import A3CSeparateModel -from chainerrl.agents.ppo import PPO from chainerrl.agents import ppo +from chainerrl.agents.ppo import PPO from chainerrl.envs.abc import ABC from chainerrl.experiments.evaluator import batch_run_evaluation_episodes from chainerrl.experiments.evaluator import run_evaluation_episodes @@ -31,8 +31,8 @@ from chainerrl.misc.batch_states import batch_states from chainerrl import policies -from chainerrl.links import StatelessRecurrentSequential from chainerrl.links import StatelessRecurrentBranched +from chainerrl.links import StatelessRecurrentSequential def make_random_episodes(n_episodes=10, obs_size=2, n_actions=3): diff --git a/tests/wrappers_tests/test_atari_wrappers.py b/tests/wrappers_tests/test_atari_wrappers.py index ee3c857a8..de72228db 100644 --- a/tests/wrappers_tests/test_atari_wrappers.py +++ b/tests/wrappers_tests/test_atari_wrappers.py @@ -17,7 +17,8 @@ import gym.spaces import numpy as np -from chainerrl.wrappers.atari_wrappers import FrameStack, LazyFrames +from chainerrl.wrappers.atari_wrappers import FrameStack +from chainerrl.wrappers.atari_wrappers import LazyFrames from chainerrl.wrappers.atari_wrappers import ScaledFloatFrame 
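Note on PATCH 1 before the follow-up patches below: the heart of the change is
DoubleIQN._compute_target_values, which mirrors DoubleDQN's decoupling of
action selection from action evaluation. The online model draws K fresh
quantile thresholds (taus_tilde) to pick the greedy next action, while the
target model draws N' thresholds (taus_prime) to evaluate that action's return
quantiles. A minimal NumPy sketch of just that selection/evaluation split,
using made-up stand-in arrays rather than the real ChainerRL models
(illustrative only, not the patch's API):

import numpy as np

rng = np.random.RandomState(0)
batch_size, n_actions, K, N_prime = 2, 4, 32, 64
gamma = 0.99

# Stand-ins for network outputs: quantiles[b, i, a] is the i-th sampled
# return quantile of action a in next state s'_b.
online_quantiles = rng.randn(batch_size, K, n_actions)        # via taus_tilde
target_quantiles = rng.randn(batch_size, N_prime, n_actions)  # via taus_prime
rewards = rng.rand(batch_size)
terminal = np.zeros(batch_size)

# Selection: Q(s', a) is approximated by the mean over the K online-model
# quantiles; argmax gives the greedy action (what .greedy_actions yields).
greedy_actions = online_quantiles.mean(axis=1).argmax(axis=1)

# Evaluation: the *target* model's N' quantiles of those greedy actions
# (the role of evaluate_actions_as_quantiles in the patch).
target_next_maxz = target_quantiles[np.arange(batch_size), :, greedy_actions]

# Broadcast exactly as in _compute_target_values: shape (batch_size, N_prime).
targets = rewards[:, None] + gamma * (1.0 - terminal[:, None]) * target_next_maxz
assert targets.shape == (batch_size, N_prime)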
From c3129af33f4044c42f2fbda134b05f44600365b9 Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Sat, 6 Jul 2019 01:41:13 -0500 Subject: [PATCH 2/8] addresses more flakes --- chainerrl/agents/double_iqn.py | 3 --- chainerrl/wrappers/vector_frame_stack.py | 2 +- examples/atari/train_iqn_ale.py | 1 + 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/chainerrl/agents/double_iqn.py b/chainerrl/agents/double_iqn.py index e2864b5f5..be796b314 100644 --- a/chainerrl/agents/double_iqn.py +++ b/chainerrl/agents/double_iqn.py @@ -5,12 +5,9 @@ from future import standard_library standard_library.install_aliases() # NOQA -import chainer import chainer.functions as F from chainerrl.agents import iqn -from chainerrl.agents.categorical_dqn import _apply_categorical_projection -from chainerrl.recurrent import state_kept class DoubleIQN(iqn.IQN): diff --git a/chainerrl/wrappers/vector_frame_stack.py b/chainerrl/wrappers/vector_frame_stack.py index bd8c82b48..00763b6c5 100644 --- a/chainerrl/wrappers/vector_frame_stack.py +++ b/chainerrl/wrappers/vector_frame_stack.py @@ -10,8 +10,8 @@ from gym import spaces import numpy as np -from chainerrl.wrappers.atari_wrappers import LazyFrames from chainerrl.env import VectorEnv +from chainerrl.wrappers.atari_wrappers import LazyFrames class VectorEnvWrapper(VectorEnv): diff --git a/examples/atari/train_iqn_ale.py b/examples/atari/train_iqn_ale.py index a71134b1d..f9d225b3a 100644 --- a/examples/atari/train_iqn_ale.py +++ b/examples/atari/train_iqn_ale.py @@ -28,6 +28,7 @@ def parse_agent(agent): return {'IQN': chainerrl.agents.IQN, 'DoubleIQN': chainerrl.agents.DoubleIQN}[agent] + def main(): parser = argparse.ArgumentParser() parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4') From 56627f6418f7b7bd2022f2154eef17d96796cbe6 Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Thu, 1 Aug 2019 04:49:34 -0500 Subject: [PATCH 3/8] adds an IQN test script --- examples_tests/atari/test_iqn.sh | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100755 examples_tests/atari/test_iqn.sh diff --git a/examples_tests/atari/test_iqn.sh b/examples_tests/atari/test_iqn.sh new file mode 100755 index 000000000..d189cd55c --- /dev/null +++ b/examples_tests/atari/test_iqn.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +set -Ceu + +outdir=$(mktemp -d) + +gpu="$1" + +# atari/iqn +python examples/atari/train_iqn_ale.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/iqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu +model=$(find $outdir/atari/iqn -name "*_finish") +python examples/atari/train_iqn_ale.py --env PongNoFrameskip-v4 --demo --load $model --outdir $outdir/temp --eval-n-steps 200 --gpu $gpu From c11b190fabd1422c24f3177142e6fa8c919a37d5 Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Fri, 9 Aug 2019 07:19:16 -0500 Subject: [PATCH 4/8] renames example scripts and adds additional details to double IQN description --- chainerrl/agents/double_iqn.py | 4 +++- .../atari/{train_iqn_ale.py => train_double_iqn.py} | 0 examples_tests/atari/test_double_iqn.sh | 12 ++++++++++++ examples_tests/atari/test_iqn.sh | 12 ------------ 4 files changed, 15 insertions(+), 13 deletions(-) rename examples/atari/{train_iqn_ale.py => train_double_iqn.py} (100%) create mode 100755 examples_tests/atari/test_double_iqn.sh delete mode 100755 examples_tests/atari/test_iqn.sh diff --git a/chainerrl/agents/double_iqn.py b/chainerrl/agents/double_iqn.py index be796b314..2c1c9a8b4 100644 --- 
a/chainerrl/agents/double_iqn.py +++ b/chainerrl/agents/double_iqn.py @@ -11,8 +11,10 @@ class DoubleIQN(iqn.IQN): - """Double IQN. + """Double IQN. Using the primary Q-network's greedy/max action + to compute the target value rather than use the target network's + max action. """ def _compute_target_values(self, exp_batch): diff --git a/examples/atari/train_iqn_ale.py b/examples/atari/train_double_iqn.py similarity index 100% rename from examples/atari/train_iqn_ale.py rename to examples/atari/train_double_iqn.py diff --git a/examples_tests/atari/test_double_iqn.sh b/examples_tests/atari/test_double_iqn.sh new file mode 100755 index 000000000..326d72540 --- /dev/null +++ b/examples_tests/atari/test_double_iqn.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +set -Ceu + +outdir=$(mktemp -d) + +gpu="$1" + +# atari/double_iqn +python examples/atari/train_double_iqn.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/double_iqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu +model=$(find $outdir/atari/double_iqn -name "*_finish") +python examples/atari/train_double_iqn.py --env PongNoFrameskip-v4 --demo --load $model --outdir $outdir/temp --eval-n-steps 200 --gpu $gpu diff --git a/examples_tests/atari/test_iqn.sh b/examples_tests/atari/test_iqn.sh deleted file mode 100755 index d189cd55c..000000000 --- a/examples_tests/atari/test_iqn.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash - -set -Ceu - -outdir=$(mktemp -d) - -gpu="$1" - -# atari/iqn -python examples/atari/train_iqn_ale.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/iqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu -model=$(find $outdir/atari/iqn -name "*_finish") -python examples/atari/train_iqn_ale.py --env PongNoFrameskip-v4 --demo --load $model --outdir $outdir/temp --eval-n-steps 200 --gpu $gpu From 3cd68074d18ba93cc438d9303c8c9c40aa0fc60c Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Tue, 13 Aug 2019 08:01:17 -0500 Subject: [PATCH 5/8] adds support for recurrent double IQN --- chainerrl/agents/double_iqn.py | 20 ++++++++++++++--- tests/agents_tests/test_double_iqn.py | 31 ++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/chainerrl/agents/double_iqn.py b/chainerrl/agents/double_iqn.py index 2c1c9a8b4..96b5f31fe 100644 --- a/chainerrl/agents/double_iqn.py +++ b/chainerrl/agents/double_iqn.py @@ -29,13 +29,27 @@ def _compute_target_values(self, exp_batch): batch_size = len(exp_batch['reward']) taus_tilde = self.xp.random.uniform( 0, 1, size=(batch_size, self.quantile_thresholds_K)).astype('f') - - next_tau2av = self.model(batch_next_state) + with chainer.using_config('train', False): + if self.recurrent: + next_tau2av, _ = self.model.n_step_forward( + batch_next_state, + exp_batch['next_recurrent_state'], + output_mode='concat', + ) + else: + next_tau2av = self.model(batch_next_state) greedy_actions = next_tau2av(taus_tilde).greedy_actions taus_prime = self.xp.random.uniform( 0, 1, size=(batch_size, self.quantile_thresholds_N_prime)).astype('f') - target_next_tau2av = self.target_model(batch_next_state) + if self.recurrent: + target_next_tau2av, _ = self.target_model.n_step_forward( + batch_next_state, + exp_batch['next_recurrent_state'], + output_mode='concat', + ) + else: + target_next_tau2av = self.target_model(batch_next_state) target_next_maxz = target_next_tau2av( taus_prime).evaluate_actions_as_quantiles(greedy_actions) diff --git a/tests/agents_tests/test_double_iqn.py 
b/tests/agents_tests/test_double_iqn.py index 8afc713d8..9dcd5b5fd 100644 --- a/tests/agents_tests/test_double_iqn.py +++ b/tests/agents_tests/test_double_iqn.py @@ -21,7 +21,7 @@ 'quantile_thresholds_N': [1, 5], 'quantile_thresholds_N_prime': [1, 7], })) -class TestIQNOnDiscreteABC( +class TestDoubleIQNOnDiscreteABC( _TestBatchTrainingMixin, base._TestDQNOnDiscreteABC): def make_q_func(self, env): @@ -46,3 +46,32 @@ def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): quantile_thresholds_N=self.quantile_thresholds_N, quantile_thresholds_N_prime=self.quantile_thresholds_N_prime, ) + +class TestDoubleIQNOnDiscretePOABC( + _TestBatchTrainingMixin, base._TestDQNOnDiscretePOABC): + + def make_q_func(self, env): + obs_size = env.observation_space.low.size + hidden_size = 64 + return iqn.StatelessRecurrentImplicitQuantileQFunction( + psi=chainerrl.links.StatelessRecurrentSequential( + L.Linear(obs_size, hidden_size), + F.relu, + L.NStepRNNTanh(1, hidden_size, hidden_size, 0), + ), + phi=chainerrl.links.Sequence( + chainerrl.agents.iqn.CosineBasisLinear(32, hidden_size), + F.relu, + ), + f=L.Linear(hidden_size, env.action_space.n, + initialW=chainer.initializers.LeCunNormal(1e-1)), + ) + + def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): + return double_iqn.DoubleIQN( + q_func, opt, rbuf, gpu=gpu, gamma=0.9, explorer=explorer, + replay_start_size=100, target_update_interval=100, + quantile_thresholds_N=32, + quantile_thresholds_N_prime=32, + recurrent=True, + ) From 5d72eef4371952040e6d0c08e45e706c849df6c4 Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Wed, 14 Aug 2019 01:16:48 -0500 Subject: [PATCH 6/8] addresses flakes --- chainerrl/agents/double_iqn.py | 1 + tests/agents_tests/test_double_iqn.py | 2 ++ tests/wrappers_tests/test_monitor.py | 8 ++++---- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/chainerrl/agents/double_iqn.py b/chainerrl/agents/double_iqn.py index 96b5f31fe..37c67b567 100644 --- a/chainerrl/agents/double_iqn.py +++ b/chainerrl/agents/double_iqn.py @@ -5,6 +5,7 @@ from future import standard_library standard_library.install_aliases() # NOQA +import chainer import chainer.functions as F from chainerrl.agents import iqn diff --git a/tests/agents_tests/test_double_iqn.py b/tests/agents_tests/test_double_iqn.py index 9dcd5b5fd..fa515c831 100644 --- a/tests/agents_tests/test_double_iqn.py +++ b/tests/agents_tests/test_double_iqn.py @@ -6,6 +6,7 @@ from builtins import * # NOQA standard_library.install_aliases() # NOQA +import chainer import chainer.functions as F import chainer.links as L from chainer import testing @@ -47,6 +48,7 @@ def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): quantile_thresholds_N_prime=self.quantile_thresholds_N_prime, ) + class TestDoubleIQNOnDiscretePOABC( _TestBatchTrainingMixin, base._TestDQNOnDiscretePOABC): diff --git a/tests/wrappers_tests/test_monitor.py b/tests/wrappers_tests/test_monitor.py index 1e5162dfd..61747eb08 100644 --- a/tests/wrappers_tests/test_monitor.py +++ b/tests/wrappers_tests/test_monitor.py @@ -6,10 +6,10 @@ from future import standard_library standard_library.install_aliases() # NOQA -import unittest -import tempfile -import shutil import os +import shutil +import tempfile +import unittest from chainer import testing import gym @@ -50,7 +50,7 @@ def test(self): if done or info.get('needs_reset', False) or t == steps: if episode_idx + 1 == self.n_episodes or t == steps: break - _ = env.reset() + env.reset() episode_idx += 1 episode_len = 0 # `env.close()` is 
called when `env` is garbage-collected

From 4c3b231c631ae1a6615dc8ac6b7c4186fc8b7ae3 Mon Sep 17 00:00:00 2001
From: Prabhat Nagarajan
Date: Tue, 20 Aug 2019 03:36:43 -0500
Subject: [PATCH 7/8] modifies docstring

---
 chainerrl/agents/double_iqn.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/chainerrl/agents/double_iqn.py b/chainerrl/agents/double_iqn.py
index 37c67b567..cb025f993 100644
--- a/chainerrl/agents/double_iqn.py
+++ b/chainerrl/agents/double_iqn.py
@@ -12,11 +12,7 @@ class DoubleIQN(iqn.IQN):
-    """Double IQN. Using the primary Q-network's greedy/max action
-
-    to compute the target value rather than use the target network's
-    max action.
-    """
+    """Double IQN - Use primary network for target computation."""

     def _compute_target_values(self, exp_batch):

From 8eb275d0dd1bae82d8fb4c5327ab4bb4f38545dc Mon Sep 17 00:00:00 2001
From: Prabhat
Date: Tue, 20 Aug 2019 20:15:43 +0900
Subject: [PATCH 8/8] rephrases docstring and adds a longer description

---
 chainerrl/agents/double_iqn.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/chainerrl/agents/double_iqn.py b/chainerrl/agents/double_iqn.py
index cb025f993..6d1cb569f 100644
--- a/chainerrl/agents/double_iqn.py
+++ b/chainerrl/agents/double_iqn.py
@@ -12,7 +12,13 @@ class DoubleIQN(iqn.IQN):
-    """Double IQN - Use primary network for target computation."""
+    """IQN with DoubleDQN-like target computation.
+
+    For computing targets, rather than have the target network
+    output the Q-value of its highest-valued action, the
+    target network outputs the Q-value of the primary network's
+    highest-valued action.
+    """

     def _compute_target_values(self, exp_batch):
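
As the final docstring says, DoubleIQN is a drop-in replacement for iqn.IQN:
it accepts the same constructor arguments and only overrides
_compute_target_values. A minimal construction sketch, modeled on the tests
above (the tiny network and hyperparameter values are illustrative
placeholders, not tuned settings):

import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L

import chainerrl
from chainerrl.agents import double_iqn, iqn
from chainerrl import explorers, replay_buffer

obs_size, n_actions, hidden_size = 4, 2, 64

# The same three-part IQN architecture as in the tests: psi embeds
# observations, phi embeds quantile thresholds, f maps their combination
# to per-action quantile values.
q_func = iqn.ImplicitQuantileQFunction(
    psi=chainerrl.links.Sequence(L.Linear(obs_size, hidden_size), F.relu),
    phi=chainerrl.links.Sequence(
        iqn.CosineBasisLinear(32, hidden_size), F.relu),
    f=L.Linear(hidden_size, n_actions),
)
opt = chainer.optimizers.Adam(1e-3)
opt.setup(q_func)

agent = double_iqn.DoubleIQN(
    q_func, opt, replay_buffer.ReplayBuffer(10 ** 4),
    gpu=-1, gamma=0.99,
    explorer=explorers.ConstantEpsilonGreedy(
        0.1, lambda: np.random.randint(n_actions)),
    replay_start_size=100, target_update_interval=100,
    quantile_thresholds_N=64, quantile_thresholds_N_prime=64,
    quantile_thresholds_K=32,
)

From there the agent is trained exactly like IQN, e.g. via agent.act_and_train
in a loop or experiments.train_agent_with_evaluation as in the example script.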