Double IQN #69

Closed
wants to merge 3 commits
236 changes: 236 additions & 0 deletions examples/atari/train_double_iqn.py
@@ -0,0 +1,236 @@
import argparse
import json
import os

import numpy as np
import torch
from torch import nn

import pfrl
from pfrl import experiments
from pfrl import explorers
from pfrl import utils
from pfrl import replay_buffers
from pfrl.wrappers import atari_wrappers


def parse_agent(agent):
return {"IQN": pfrl.agents.IQN, "DoubleIQN": pfrl.agents.DoubleIQN}[agent]


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4")
parser.add_argument(
"--outdir",
type=str,
default="results",
help=(
"Directory path to save output files."
" If it does not exist, it will be created."
),
)
parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 31)")
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument("--demo", action="store_true", default=False)
parser.add_argument("--load", type=str, default=None)
parser.add_argument("--final-exploration-frames", type=int, default=10 ** 6)
parser.add_argument("--final-epsilon", type=float, default=0.01)
parser.add_argument("--eval-epsilon", type=float, default=0.001)
parser.add_argument("--steps", type=int, default=5 * 10 ** 7)
parser.add_argument(
"--max-frames",
type=int,
        default=30 * 60 * 60,  # 30 minutes at 60 fps
help="Maximum number of frames for each episode.",
)
parser.add_argument("--replay-start-size", type=int, default=5 * 10 ** 4)
parser.add_argument("--target-update-interval", type=int, default=10 ** 4)
parser.add_argument(
"--agent", type=str, default="IQN", choices=["IQN", "DoubleIQN"]
)
parser.add_argument(
"--prioritized",
action="store_true",
default=False,
help="Flag to use a prioritized replay buffer",
)
parser.add_argument("--num-step-return", type=int, default=1)
parser.add_argument("--eval-interval", type=int, default=250000)
parser.add_argument("--eval-n-steps", type=int, default=125000)
parser.add_argument("--update-interval", type=int, default=4)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument(
"--log-level",
type=int,
default=20,
help="Logging level. 10:DEBUG, 20:INFO etc.",
)
parser.add_argument(
"--render",
action="store_true",
default=False,
help="Render env states in a GUI window.",
)
parser.add_argument(
"--monitor",
action="store_true",
default=False,
help=(
"Monitor env. Videos and additional information are saved as output files."
),
)
parser.add_argument(
"--batch-accumulator", type=str, default="mean", choices=["mean", "sum"]
)
parser.add_argument("--quantile-thresholds-N", type=int, default=64)
parser.add_argument("--quantile-thresholds-N-prime", type=int, default=64)
parser.add_argument("--quantile-thresholds-K", type=int, default=32)
parser.add_argument("--n-best-episodes", type=int, default=200)
args = parser.parse_args()

import logging

logging.basicConfig(level=args.log_level)

# Set a random seed used in PFRL.
utils.set_random_seed(args.seed)

# Set different random seeds for train and test envs.
train_seed = args.seed
test_seed = 2 ** 31 - 1 - args.seed

args.outdir = experiments.prepare_output_dir(args, args.outdir)
print("Output files are saved in {}".format(args.outdir))

def make_env(test):
# Use different random seeds for train and test envs
env_seed = test_seed if test else train_seed
env = atari_wrappers.wrap_deepmind(
atari_wrappers.make_atari(args.env, max_frames=args.max_frames),
episode_life=not test,
clip_rewards=not test,
)
env.seed(int(env_seed))
if test:
# Randomize actions like epsilon-greedy in evaluation as well
env = pfrl.wrappers.RandomizeAction(env, args.eval_epsilon)
if args.monitor:
env = pfrl.wrappers.Monitor(
env, args.outdir, mode="evaluation" if test else "training"
)
if args.render:
env = pfrl.wrappers.Render(env)
return env

env = make_env(test=False)
eval_env = make_env(test=True)
n_actions = env.action_space.n

q_func = pfrl.agents.iqn.ImplicitQuantileQFunction(
psi=nn.Sequential(
nn.Conv2d(4, 32, 8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=1),
nn.ReLU(),
nn.Flatten(),
),
phi=nn.Sequential(pfrl.agents.iqn.CosineBasisLinear(64, 3136), nn.ReLU(),),
f=nn.Sequential(nn.Linear(3136, 512), nn.ReLU(), nn.Linear(512, n_actions),),
)
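    # Note: with the DeepMind Atari preprocessing applied above (4 stacked
    # 84x84 frames), psi maps an observation to a 3136-dimensional feature
    # vector (64 channels x 7 x 7), phi embeds each sampled quantile
    # threshold tau into the same 3136-dimensional space via a cosine basis,
    # and f maps their combination (an elementwise product in IQN) to
    # per-action quantile values.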

    # Use the same optimizer hyperparameters as https://arxiv.org/abs/1710.10044
opt = torch.optim.Adam(q_func.parameters(), lr=5e-5, eps=1e-2 / args.batch_size)

if args.prioritized:
betasteps = args.steps / args.update_interval
rbuf = replay_buffers.PrioritizedReplayBuffer(
10 ** 6,
alpha=0.5,
beta0=0.4,
betasteps=betasteps,
num_steps=args.num_step_return,
)
else:
rbuf = replay_buffers.ReplayBuffer(10 ** 6, num_steps=args.num_step_return,)

explorer = explorers.LinearDecayEpsilonGreedy(
1.0,
args.final_epsilon,
args.final_exploration_frames,
lambda: np.random.randint(n_actions),
)

def phi(x):
        # Feature extractor: scale uint8 pixel frames to float32 in [0, 1]
return np.asarray(x, dtype=np.float32) / 255

Agent = parse_agent(args.agent)
agent = Agent(
q_func,
opt,
rbuf,
gpu=args.gpu,
gamma=0.99,
explorer=explorer,
replay_start_size=args.replay_start_size,
target_update_interval=args.target_update_interval,
update_interval=args.update_interval,
batch_accumulator=args.batch_accumulator,
phi=phi,
quantile_thresholds_N=args.quantile_thresholds_N,
quantile_thresholds_N_prime=args.quantile_thresholds_N_prime,
quantile_thresholds_K=args.quantile_thresholds_K,
)

if args.load:
agent.load(args.load)

if args.demo:
eval_stats = experiments.eval_performance(
env=eval_env, agent=agent, n_steps=args.eval_n_steps, n_episodes=None,
)
print(
"n_steps: {} mean: {} median: {} stdev {}".format(
args.eval_n_steps,
eval_stats["mean"],
eval_stats["median"],
eval_stats["stdev"],
)
)
else:
experiments.train_agent_with_evaluation(
agent=agent,
env=env,
steps=args.steps,
eval_n_steps=args.eval_n_steps,
eval_n_episodes=None,
eval_interval=args.eval_interval,
outdir=args.outdir,
save_best_so_far_agent=True,
eval_env=eval_env,
)

dir_of_best_network = os.path.join(args.outdir, "best")
agent.load(dir_of_best_network)

        # Run args.n_best_episodes evaluation episodes (200 by default),
        # each capped at 30 minutes of play
stats = experiments.evaluator.eval_performance(
env=eval_env,
agent=agent,
n_steps=None,
n_episodes=args.n_best_episodes,
            max_episode_len=args.max_frames // 4,  # in agent steps (frameskip of 4)
logger=None,
)
with open(os.path.join(args.outdir, "bestscores.json"), "w") as f:
json.dump(stats, f)
print("The results of the best scoring network:")
for stat in stats:
            print("{}: {}".format(stat, stats[stat]))


if __name__ == "__main__":
main()
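
# Example invocation (illustrative only; the output directory below is
# arbitrary, and any of the flags defined above can be added):
#   python examples/atari/train_double_iqn.py --env BreakoutNoFrameskip-v4 \
#       --agent DoubleIQN --gpu 0 --outdir results/double_iqn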
12 changes: 12 additions & 0 deletions examples_tests/atari/test_double_iqn.sh
@@ -0,0 +1,12 @@
#!/bin/bash

set -Ceu

outdir=$(mktemp -d)

gpu="$1"

# atari/double_iqn
python examples/atari/train_double_iqn.py --env PongNoFrameskip-v4 --steps 100 --replay-start-size 50 --outdir $outdir/atari/double_iqn --eval-n-steps 200 --eval-interval 50 --n-best-episodes 1 --gpu $gpu
model=$(find $outdir/atari/double_iqn -name "*_finish")
python examples/atari/train_double_iqn.py --env PongNoFrameskip-v4 --demo --load $model --outdir $outdir/temp --eval-n-steps 200 --gpu $gpu
1 change: 1 addition & 0 deletions pfrl/agents/__init__.py
@@ -6,6 +6,7 @@
from pfrl.agents.categorical_dqn import CategoricalDQN # NOQA
from pfrl.agents.ddpg import DDPG # NOQA
from pfrl.agents.double_dqn import DoubleDQN # NOQA
from pfrl.agents.double_iqn import DoubleIQN # NOQA
from pfrl.agents.double_pal import DoublePAL # NOQA
from pfrl.agents.dpp import DPP # NOQA
from pfrl.agents.dqn import DQN # NOQA
70 changes: 70 additions & 0 deletions pfrl/agents/double_iqn.py
@@ -0,0 +1,70 @@
import torch

from pfrl.agents import iqn
from pfrl.utils import evaluating
from pfrl.utils.recurrent import pack_and_forward


class DoubleIQN(iqn.IQN):

"""IQN with DoubleDQN-like target computation.

For computing targets, rather than have the target network
output the Q-value of its highest-valued action, the
target network outputs the Q-value of the primary network’s
highest valued action.
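
    In symbols (a sketch of the intended update; this notation is ours and
    not copied from the IQN paper):

        a* = argmax_a (1/K) * sum_k Z_online(s', a; tau~_k)
        y  = r + gamma * (1 - done) * Z_target(s', a*; tau'_j)

    where tau~_1, ..., tau~_K ~ U(0, 1) are the thresholds used for greedy
    action selection with the online model and tau'_1, ..., tau'_N' ~ U(0, 1)
    are the thresholds at which the target model's quantiles are evaluated.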
"""

def _compute_target_values(self, exp_batch):
"""Compute a batch of target return distributions.

Returns:
torch.Tensor: (batch_size, N_prime).
"""
batch_next_state = exp_batch["next_state"]
batch_size = len(exp_batch["reward"])

taus_tilde = torch.rand(
batch_size,
self.quantile_thresholds_K,
device=self.device,
dtype=torch.float,
)
with evaluating(self.model):
if self.recurrent:
next_tau2av, _ = pack_and_forward(
self.model, batch_next_state, exp_batch["next_recurrent_state"],
)
else:
next_tau2av = self.model(batch_next_state)
greedy_actions = next_tau2av(taus_tilde).greedy_actions

taus_prime = torch.rand(
batch_size,
self.quantile_thresholds_N_prime,
device=self.device,
dtype=torch.float,
)
if self.recurrent:
target_next_tau2av, _ = pack_and_forward(
self.target_model, batch_next_state, exp_batch["next_recurrent_state"],
)
else:
target_next_tau2av = self.target_model(batch_next_state)
target_next_maxz = target_next_tau2av(taus_prime).evaluate_actions_as_quantiles(
greedy_actions
)

batch_rewards = exp_batch["reward"]
batch_terminal = exp_batch["is_state_terminal"]
batch_discount = exp_batch["discount"]
assert batch_rewards.shape == (batch_size,)
assert batch_terminal.shape == (batch_size,)
assert batch_discount.shape == (batch_size,)
batch_rewards = batch_rewards.unsqueeze(-1)
batch_terminal = batch_terminal.unsqueeze(-1)
batch_discount = batch_discount.unsqueeze(-1)

return (
batch_rewards + batch_discount * (1.0 - batch_terminal) * target_next_maxz
)