From 453a4f809f9af4cffd11a5989cecacf959209a48 Mon Sep 17 00:00:00 2001 From: Kaito Suzuki Date: Mon, 7 Sep 2020 19:35:57 +0900 Subject: [PATCH 1/2] add Double IQN, training example, and test based on Double IQN in ChainerRL --- examples/atari/train_double_iqn.py | 249 ++++++++++++++++++++++++ examples_tests/atari/test_double_iqn.sh | 12 ++ pfrl/agents/__init__.py | 1 + pfrl/agents/double_iqn.py | 74 +++++++ tests/agents_tests/test_double_iqn.py | 91 +++++++++ 5 files changed, 427 insertions(+) create mode 100644 examples/atari/train_double_iqn.py create mode 100644 examples_tests/atari/test_double_iqn.sh create mode 100644 pfrl/agents/double_iqn.py create mode 100644 tests/agents_tests/test_double_iqn.py diff --git a/examples/atari/train_double_iqn.py b/examples/atari/train_double_iqn.py new file mode 100644 index 000000000..ff36d7656 --- /dev/null +++ b/examples/atari/train_double_iqn.py @@ -0,0 +1,249 @@ +import argparse +import json +import os + +import numpy as np +import torch +from torch import nn + +import pfrl +from pfrl import experiments +from pfrl import explorers +from pfrl import utils +from pfrl import replay_buffers +from pfrl.wrappers import atari_wrappers + + +def parse_agent(agent): + return {"IQN": pfrl.agents.IQN, "DoubleIQN": pfrl.agents.DoubleIQN}[agent] + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") + parser.add_argument( + "--outdir", + type=str, + default="results", + help=( + "Directory path to save output files." + " If it does not exist, it will be created." + ), + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 31)") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--demo", action="store_true", default=False) + parser.add_argument("--load", type=str, default=None) + parser.add_argument("--final-exploration-frames", type=int, default=10 ** 6) + parser.add_argument("--final-epsilon", type=float, default=0.01) + parser.add_argument("--eval-epsilon", type=float, default=0.001) + parser.add_argument("--steps", type=int, default=5 * 10 ** 7) + parser.add_argument( + "--max-frames", + type=int, + default=30 * 60 * 60, # 30 minutes with 60 fps + help="Maximum number of frames for each episode.", + ) + parser.add_argument("--replay-start-size", type=int, default=5 * 10 ** 4) + parser.add_argument("--target-update-interval", type=int, default=10 ** 4) + parser.add_argument( + "--agent", type=str, default="IQN", choices=["IQN", "DoubleIQN"] + ) + parser.add_argument( + "--prioritized", + action="store_true", + default=False, + help="Flag to use a prioritized replay buffer", + ) + parser.add_argument("--num-step-return", type=int, default=1) + parser.add_argument("--eval-interval", type=int, default=250000) + parser.add_argument("--eval-n-steps", type=int, default=125000) + parser.add_argument("--update-interval", type=int, default=4) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument( + "--log-level", + type=int, + default=20, + help="Logging level. 10:DEBUG, 20:INFO etc.", + ) + parser.add_argument( + "--render", + action="store_true", + default=False, + help="Render env states in a GUI window.", + ) + parser.add_argument( + "--monitor", + action="store_true", + default=False, + help=( + "Monitor env. Videos and additional information are saved as output files." 
+ ), + ) + parser.add_argument( + "--batch-accumulator", type=str, default="mean", choices=["mean", "sum"] + ) + parser.add_argument("--quantile-thresholds-N", type=int, default=64) + parser.add_argument("--quantile-thresholds-N-prime", type=int, default=64) + parser.add_argument("--quantile-thresholds-K", type=int, default=32) + parser.add_argument("--n-best-episodes", type=int, default=200) + args = parser.parse_args() + + import logging + + logging.basicConfig(level=args.log_level) + + # Set a random seed used in PFRL. + utils.set_random_seed(args.seed) + + # Set different random seeds for train and test envs. + train_seed = args.seed + test_seed = 2 ** 31 - 1 - args.seed + + args.outdir = experiments.prepare_output_dir(args, args.outdir) + print("Output files are saved in {}".format(args.outdir)) + + def make_env(test): + # Use different random seeds for train and test envs + env_seed = test_seed if test else train_seed + env = atari_wrappers.wrap_deepmind( + atari_wrappers.make_atari(args.env, max_frames=args.max_frames), + episode_life=not test, + clip_rewards=not test, + ) + env.seed(int(env_seed)) + if test: + # Randomize actions like epsilon-greedy in evaluation as well + env = pfrl.wrappers.RandomizeAction(env, args.eval_epsilon) + if args.monitor: + env = pfrl.wrappers.Monitor( + env, args.outdir, mode="evaluation" if test else "training" + ) + if args.render: + env = pfrl.wrappers.Render(env) + return env + + env = make_env(test=False) + eval_env = make_env(test=True) + n_actions = env.action_space.n + + q_func = pfrl.agents.iqn.ImplicitQuantileQFunction( + psi=nn.Sequential( + nn.Conv2d(4, 32, 8, stride=4), + nn.ReLU(), + nn.Conv2d(32, 64, 4, stride=2), + nn.ReLU(), + nn.Conv2d(64, 64, 3, stride=1), + nn.ReLU(), + nn.Flatten(), + ), + phi=nn.Sequential( + pfrl.agents.iqn.CosineBasisLinear(64, 3136), + nn.ReLU(), + ), + f=nn.Sequential( + nn.Linear(3136, 512), + nn.ReLU(), + nn.Linear(512, n_actions), + ), + ) + + # Use the same hyper parameters as https://arxiv.org/abs/1710.10044 + opt = torch.optim.Adam(q_func.parameters(), lr=5e-5, eps=1e-2 / args.batch_size) + + if args.prioritized: + betasteps = args.steps / args.update_interval + rbuf = replay_buffers.PrioritizedReplayBuffer( + 10 ** 6, + alpha=0.5, + beta0=0.4, + betasteps=betasteps, + num_steps=args.num_step_return, + ) + else: + rbuf = replay_buffers.ReplayBuffer( + 10 ** 6, + num_steps=args.num_step_return, + ) + + explorer = explorers.LinearDecayEpsilonGreedy( + 1.0, + args.final_epsilon, + args.final_exploration_frames, + lambda: np.random.randint(n_actions), + ) + + def phi(x): + # Feature extractor + return np.asarray(x, dtype=np.float32) / 255 + + Agent = parse_agent(args.agent) + agent = Agent( + q_func, + opt, + rbuf, + gpu=args.gpu, + gamma=0.99, + explorer=explorer, + replay_start_size=args.replay_start_size, + target_update_interval=args.target_update_interval, + update_interval=args.update_interval, + batch_accumulator=args.batch_accumulator, + phi=phi, + quantile_thresholds_N=args.quantile_thresholds_N, + quantile_thresholds_N_prime=args.quantile_thresholds_N_prime, + quantile_thresholds_K=args.quantile_thresholds_K, + ) + + if args.load: + agent.load(args.load) + + if args.demo: + eval_stats = experiments.eval_performance( + env=eval_env, + agent=agent, + n_steps=args.eval_n_steps, + n_episodes=None, + ) + print( + "n_steps: {} mean: {} median: {} stdev {}".format( + args.eval_n_steps, + eval_stats["mean"], + eval_stats["median"], + eval_stats["stdev"], + ) + ) + else: + 
experiments.train_agent_with_evaluation( + agent=agent, + env=env, + steps=args.steps, + eval_n_steps=args.eval_n_steps, + eval_n_episodes=None, + eval_interval=args.eval_interval, + outdir=args.outdir, + save_best_so_far_agent=True, + eval_env=eval_env, + ) + + dir_of_best_network = os.path.join(args.outdir, "best") + agent.load(dir_of_best_network) + + # run 200 evaluation episodes, each capped at 30 mins of play + stats = experiments.evaluator.eval_performance( + env=eval_env, + agent=agent, + n_steps=None, + n_episodes=args.n_best_episodes, + max_episode_len=args.max_frames / 4, + logger=None, + ) + with open(os.path.join(args.outdir, "bestscores.json"), "w") as f: + json.dump(stats, f) + print("The results of the best scoring network:") + for stat in stats: + print(str(stat) + ":" + str(stats[stat])) + + +if __name__ == "__main__": + main() diff --git a/examples_tests/atari/test_double_iqn.sh b/examples_tests/atari/test_double_iqn.sh new file mode 100644 index 000000000..326d72540 --- /dev/null +++ b/examples_tests/atari/test_double_iqn.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +set -Ceu + +outdir=$(mktemp -d) + +gpu="$1" + +# atari/double_iqn +python examples/atari/train_double_iqn.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/double_iqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu +model=$(find $outdir/atari/double_iqn -name "*_finish") +python examples/atari/train_double_iqn.py --env PongNoFrameskip-v4 --demo --load $model --outdir $outdir/temp --eval-n-steps 200 --gpu $gpu diff --git a/pfrl/agents/__init__.py b/pfrl/agents/__init__.py index b8cd64644..7c40dfcb0 100644 --- a/pfrl/agents/__init__.py +++ b/pfrl/agents/__init__.py @@ -6,6 +6,7 @@ from pfrl.agents.categorical_dqn import CategoricalDQN # NOQA from pfrl.agents.ddpg import DDPG # NOQA from pfrl.agents.double_dqn import DoubleDQN # NOQA +from pfrl.agents.double_iqn import DoubleIQN # NOQA from pfrl.agents.double_pal import DoublePAL # NOQA from pfrl.agents.dpp import DPP # NOQA from pfrl.agents.dqn import DQN # NOQA diff --git a/pfrl/agents/double_iqn.py b/pfrl/agents/double_iqn.py new file mode 100644 index 000000000..073d4d6b1 --- /dev/null +++ b/pfrl/agents/double_iqn.py @@ -0,0 +1,74 @@ +import torch + +from pfrl.agents import iqn +from pfrl.utils import evaluating +from pfrl.utils.recurrent import pack_and_forward + + +class DoubleIQN(iqn.IQN): + + """IQN with DoubleDQN-like target computation. + + For computing targets, rather than have the target network + output the Q-value of its highest-valued action, the + target network outputs the Q-value of the primary network’s + highest valued action. + """ + + def _compute_target_values(self, exp_batch): + """Compute a batch of target return distributions. + + Returns: + torch.Tensor: (batch_size, N_prime). 
+ """ + batch_next_state = exp_batch["next_state"] + batch_size = len(exp_batch["reward"]) + + taus_tilde = torch.rand( + batch_size, + self.quantile_thresholds_K, + device=self.device, + dtype=torch.float, + ) + with evaluating(self.model): + if self.recurrent: + next_tau2av, _ = pack_and_forward( + self.model, + batch_next_state, + exp_batch["next_recurrent_state"], + ) + else: + next_tau2av = self.model(batch_next_state) + greedy_actions = next_tau2av(taus_tilde).greedy_actions + + taus_prime = torch.rand( + batch_size, + self.quantile_thresholds_N_prime, + device=self.device, + dtype=torch.float, + ) + if self.recurrent: + target_next_tau2av, _ = pack_and_forward( + self.target_model, + batch_next_state, + exp_batch["next_recurrent_state"], + ) + else: + target_next_tau2av = self.target_model(batch_next_state) + target_next_maxz = target_next_tau2av(taus_prime).evaluate_actions_as_quantiles( + greedy_actions + ) + + batch_rewards = exp_batch["reward"] + batch_terminal = exp_batch["is_state_terminal"] + batch_discount = exp_batch["discount"] + assert batch_rewards.shape == (batch_size,) + assert batch_terminal.shape == (batch_size,) + assert batch_discount.shape == (batch_size,) + batch_rewards = batch_rewards.unsqueeze(-1) + batch_terminal = batch_terminal.unsqueeze(-1) + batch_discount = batch_discount.unsqueeze(-1) + + return ( + batch_rewards + batch_discount * (1.0 - batch_terminal) * target_next_maxz + ) diff --git a/tests/agents_tests/test_double_iqn.py b/tests/agents_tests/test_double_iqn.py new file mode 100644 index 000000000..66d8f7e1c --- /dev/null +++ b/tests/agents_tests/test_double_iqn.py @@ -0,0 +1,91 @@ +from torch import nn +import pytest + +import basetest_dqn_like as base + +from basetest_training import _TestBatchTrainingMixin +import pfrl +from pfrl.agents import double_iqn, iqn + + +@pytest.mark.parametrize("quantile_thresholds_N", [1, 5]) +@pytest.mark.parametrize("quantile_thresholds_N_prime", [1, 7]) +class TestDoubleIQNOnDiscreteABC( + _TestBatchTrainingMixin, + base._TestDQNOnDiscreteABC, +): + @pytest.fixture(autouse=True) + def set_iqn_params(self, quantile_thresholds_N, quantile_thresholds_N_prime): + self.quantile_thresholds_N = quantile_thresholds_N + self.quantile_thresholds_N_prime = quantile_thresholds_N_prime + + def make_q_func(self, env): + obs_size = env.observation_space.low.size + hidden_size = 64 + return iqn.ImplicitQuantileQFunction( + psi=nn.Sequential( + nn.Linear(obs_size, hidden_size), + nn.ReLU(), + ), + phi=nn.Sequential( + pfrl.agents.iqn.CosineBasisLinear(32, hidden_size), + nn.ReLU(), + ), + f=nn.Linear(hidden_size, env.action_space.n), + ) + + def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): + return double_iqn.DoubleIQN( + q_func, + opt, + rbuf, + gpu=gpu, + gamma=0.9, + explorer=explorer, + replay_start_size=100, + target_update_interval=100, + quantile_thresholds_N=self.quantile_thresholds_N, + quantile_thresholds_N_prime=self.quantile_thresholds_N_prime, + act_deterministically=True, + ) + + +class TestDoubleIQNOnDiscretePOABC( + _TestBatchTrainingMixin, + base._TestDQNOnDiscretePOABC, +): + def make_q_func(self, env): + obs_size = env.observation_space.low.size + hidden_size = 64 + return iqn.RecurrentImplicitQuantileQFunction( + psi=pfrl.nn.RecurrentSequential( + nn.Linear(obs_size, hidden_size), + nn.ReLU(), + nn.RNN( + num_layers=1, + input_size=hidden_size, + hidden_size=hidden_size, + ), + ), + phi=nn.Sequential( + pfrl.agents.iqn.CosineBasisLinear(32, hidden_size), + nn.ReLU(), + ), + 
f=nn.Linear(hidden_size, env.action_space.n), + ) + + def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): + return double_iqn.DoubleIQN( + q_func, + opt, + rbuf, + gpu=gpu, + gamma=0.9, + explorer=explorer, + replay_start_size=100, + target_update_interval=100, + quantile_thresholds_N=32, + quantile_thresholds_N_prime=32, + recurrent=True, + act_deterministically=True, + ) From a98d6f58dfbf679d044190e0958db528ef96bb53 Mon Sep 17 00:00:00 2001 From: Kaito Suzuki Date: Thu, 24 Sep 2020 19:43:48 +0900 Subject: [PATCH 2/2] Reformat some code by black --- examples/atari/train_double_iqn.py | 21 ++++----------------- pfrl/agents/double_iqn.py | 8 ++------ tests/agents_tests/test_double_iqn.py | 23 ++++++----------------- 3 files changed, 12 insertions(+), 40 deletions(-) diff --git a/examples/atari/train_double_iqn.py b/examples/atari/train_double_iqn.py index ff36d7656..733c8174c 100644 --- a/examples/atari/train_double_iqn.py +++ b/examples/atari/train_double_iqn.py @@ -137,15 +137,8 @@ def make_env(test): nn.ReLU(), nn.Flatten(), ), - phi=nn.Sequential( - pfrl.agents.iqn.CosineBasisLinear(64, 3136), - nn.ReLU(), - ), - f=nn.Sequential( - nn.Linear(3136, 512), - nn.ReLU(), - nn.Linear(512, n_actions), - ), + phi=nn.Sequential(pfrl.agents.iqn.CosineBasisLinear(64, 3136), nn.ReLU(),), + f=nn.Sequential(nn.Linear(3136, 512), nn.ReLU(), nn.Linear(512, n_actions),), ) # Use the same hyper parameters as https://arxiv.org/abs/1710.10044 @@ -161,10 +154,7 @@ def make_env(test): num_steps=args.num_step_return, ) else: - rbuf = replay_buffers.ReplayBuffer( - 10 ** 6, - num_steps=args.num_step_return, - ) + rbuf = replay_buffers.ReplayBuffer(10 ** 6, num_steps=args.num_step_return,) explorer = explorers.LinearDecayEpsilonGreedy( 1.0, @@ -200,10 +190,7 @@ def phi(x): if args.demo: eval_stats = experiments.eval_performance( - env=eval_env, - agent=agent, - n_steps=args.eval_n_steps, - n_episodes=None, + env=eval_env, agent=agent, n_steps=args.eval_n_steps, n_episodes=None, ) print( "n_steps: {} mean: {} median: {} stdev {}".format( diff --git a/pfrl/agents/double_iqn.py b/pfrl/agents/double_iqn.py index 073d4d6b1..ffb781e18 100644 --- a/pfrl/agents/double_iqn.py +++ b/pfrl/agents/double_iqn.py @@ -33,9 +33,7 @@ def _compute_target_values(self, exp_batch): with evaluating(self.model): if self.recurrent: next_tau2av, _ = pack_and_forward( - self.model, - batch_next_state, - exp_batch["next_recurrent_state"], + self.model, batch_next_state, exp_batch["next_recurrent_state"], ) else: next_tau2av = self.model(batch_next_state) @@ -49,9 +47,7 @@ def _compute_target_values(self, exp_batch): ) if self.recurrent: target_next_tau2av, _ = pack_and_forward( - self.target_model, - batch_next_state, - exp_batch["next_recurrent_state"], + self.target_model, batch_next_state, exp_batch["next_recurrent_state"], ) else: target_next_tau2av = self.target_model(batch_next_state) diff --git a/tests/agents_tests/test_double_iqn.py b/tests/agents_tests/test_double_iqn.py index 66d8f7e1c..a3edbfd41 100644 --- a/tests/agents_tests/test_double_iqn.py +++ b/tests/agents_tests/test_double_iqn.py @@ -11,8 +11,7 @@ @pytest.mark.parametrize("quantile_thresholds_N", [1, 5]) @pytest.mark.parametrize("quantile_thresholds_N_prime", [1, 7]) class TestDoubleIQNOnDiscreteABC( - _TestBatchTrainingMixin, - base._TestDQNOnDiscreteABC, + _TestBatchTrainingMixin, base._TestDQNOnDiscreteABC, ): @pytest.fixture(autouse=True) def set_iqn_params(self, quantile_thresholds_N, quantile_thresholds_N_prime): @@ -23,13 +22,9 @@ def 
make_q_func(self, env): obs_size = env.observation_space.low.size hidden_size = 64 return iqn.ImplicitQuantileQFunction( - psi=nn.Sequential( - nn.Linear(obs_size, hidden_size), - nn.ReLU(), - ), + psi=nn.Sequential(nn.Linear(obs_size, hidden_size), nn.ReLU(),), phi=nn.Sequential( - pfrl.agents.iqn.CosineBasisLinear(32, hidden_size), - nn.ReLU(), + pfrl.agents.iqn.CosineBasisLinear(32, hidden_size), nn.ReLU(), ), f=nn.Linear(hidden_size, env.action_space.n), ) @@ -51,8 +46,7 @@ def make_dqn_agent(self, env, q_func, opt, explorer, rbuf, gpu): class TestDoubleIQNOnDiscretePOABC( - _TestBatchTrainingMixin, - base._TestDQNOnDiscretePOABC, + _TestBatchTrainingMixin, base._TestDQNOnDiscretePOABC, ): def make_q_func(self, env): obs_size = env.observation_space.low.size @@ -61,15 +55,10 @@ def make_q_func(self, env): psi=pfrl.nn.RecurrentSequential( nn.Linear(obs_size, hidden_size), nn.ReLU(), - nn.RNN( - num_layers=1, - input_size=hidden_size, - hidden_size=hidden_size, - ), + nn.RNN(num_layers=1, input_size=hidden_size, hidden_size=hidden_size,), ), phi=nn.Sequential( - pfrl.agents.iqn.CosineBasisLinear(32, hidden_size), - nn.ReLU(), + pfrl.agents.iqn.CosineBasisLinear(32, hidden_size), nn.ReLU(), ), f=nn.Linear(hidden_size, env.action_space.n), )
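
For reference (not part of the patch), below is a minimal PyTorch sketch of the Double-DQN-style target selection that DoubleIQN._compute_target_values performs: the greedy action is chosen from the online network's output, while the quantile values plugged into the Bellman target come from the target network. The tensors here are illustrative stand-ins for the outputs of the two ImplicitQuantileQFunction heads, not actual PFRL API calls; shapes and values are assumptions chosen only to make the shapes concrete.

import torch

# Illustrative shapes: batch of 2 transitions, 3 actions, N' = 4 target quantiles.
online_q = torch.tensor([[1.0, 2.0, 0.5],
                         [0.3, 0.1, 0.9]])      # (batch, n_actions): online net's Q estimates
target_quantiles = torch.rand(2, 3, 4)          # (batch, n_actions, N'): target net's quantiles

# Double-DQN-style selection: pick greedy actions with the ONLINE network ...
greedy_actions = online_q.argmax(dim=1)         # (batch,)

# ... but evaluate those actions with the TARGET network's quantiles.
idx = greedy_actions.view(-1, 1, 1).expand(-1, 1, target_quantiles.size(-1))
target_next_maxz = target_quantiles.gather(1, idx).squeeze(1)   # (batch, N')

# Bellman backup applied per quantile; rewards/terminals broadcast over N'.
rewards = torch.tensor([1.0, 0.0]).unsqueeze(-1)    # (batch, 1)
terminal = torch.tensor([0.0, 1.0]).unsqueeze(-1)   # (batch, 1)
discount = 0.99
target = rewards + discount * (1.0 - terminal) * target_next_maxz   # (batch, N')

As the docstring in pfrl/agents/double_iqn.py states, the only intended difference from vanilla IQN's target computation is the source of greedy_actions: plain IQN would select them from the target network's own output rather than from the online model.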