A3C Example for reproducing paper results. #433

Merged · 12 commits · Jun 5, 2019
131 changes: 131 additions & 0 deletions examples/atari/a3c/README.md
@@ -0,0 +1,131 @@
# A3C
This example trains an Asynchronous Advantage Actor Critic (A3C) agent, from the following paper: [Asynchronous Methods for Deep Reinforcement Learning](https://arxiv.org/abs/1602.01783).

## Requirements

- atari_py>=0.1.1
- opencv-python

## Running the Example

```
python train_a3c.py [options]
```

### Useful Options
- `--processes`. Specifies the number of asynchronous actor-learner processes (default: 16). This example runs on CPU only, so there is no `--gpu` option.
- `--env`. Specifies the Atari environment (default: `BreakoutNoFrameskip-v4`).
- `--render`. Add this option to render the states in a GUI window.
- `--seed`. Specifies the random seed used.
- `--outdir`. Specifies the output directory to which the results are written. It is created if it does not exist.

To view the full list of options, either view the code or run the example with the `--help` option.

## Results
These results reflect ChainerRL `v0.6.0`. Each ChainerRL score currently comes from a single run. The reported results are compared against the scores from the [Noisy Networks Paper](https://arxiv.org/abs/1706.10295), since the original paper does not report scores for the no-op evaluation protocol.

**NOTE: The benchmark scores below come from running `train_a3c.py` and evaluating every 1 million timesteps, rather than every 250K timesteps. New benchmark results will come soon.**

| Results Summary ||
| ------------- |:-------------:|
| Number of seeds | 1 |
| Number of common domains | 52 |
| Number of domains where paper scores higher | 25 |
| Number of domains where ChainerRL scores higher | 24 |
| Number of ties between paper and ChainerRL | 3 |


| Game | ChainerRL Score | Original Reported Scores |
| ------------- |:-------------:|:-------------:|
| AirRaid | 4625.9| N/A|
| Alien | 1397.2| **2027**|
| Amidar | **1110.8**| 904|
| Assault | **5821.6**| 2879|
| Asterix | 6820.7| **6822**|
| Asteroids | 2428.8| **2544**|
| Atlantis | **732425.0**| 422700|
| BankHeist | **1308.9**| 1296|
| BattleZone | 5421.1| **16411**|
| BeamRider | 8493.4| **9214**|
| Berzerk | **1594.2**| 1022|
| Bowling | 31.7| **37**|
| Boxing | **98.1**| 91|
| Breakout | **533.6**| 496|
| Carnival | 5132.9| N/A|
| Centipede | 4849.9| **5350**|
| ChopperCommand | 4881.0| **5285**|
| CrazyClimber | 124400.0| **134783**|
| Defender | N/A| 52917.0|
| DemonAttack | **108832.5**| 37085|
| DoubleDunk | 1.5| **3**|
| Enduro | **0.0**| **0**|
| FishingDerby | **36.3**| -7|
| Freeway | **0.0**| **0**|
| Frostbite | **313.6**| 288|
| Gopher | **8746.5**| 7992|
| Gravitar | 228.0| **379**|
| Hero | **36892.5**| 30791|
| IceHockey | -4.6| **-2**|
| JamesBond | N/A| 509.0|
| Jamesbond | 370.1| N/A|
| JourneyEscape | -871.2| N/A|
| Kangaroo | 115.8| **1166**|
| Krull | **10601.4**| 9422|
| KungFuMaster | **40970.4**| 37422|
| MontezumaRevenge | 1.9| **14**|
| MsPacman | **2498.0**| 2436|
| NameThisGame | 6597.0| **7168**|
| Phoenix | **42654.5**| 9476|
| Pitfall | -10.8| N/A|
| Pitfall! | N/A| 0.0|
| Pong | **20.9**| 7|
| Pooyan | 4067.9| N/A|
| PrivateEye | 376.1| **3781**|
| Qbert | 15610.6| **18586**|
| Riverraid | 13223.3| N/A|
| RoadRunner | 39897.8| **45315**|
| Robotank | 2.9| **6**|
| Seaquest | **1786.5**| 1744|
| Skiing | -16090.5| **-12972**|
| Solaris | 3157.8| **12380**|
| SpaceInvaders | **1630.6**| 1034|
| StarGunner | **57943.2**| 49156|
| Surround | N/A| -8.0|
| Tennis | **-0.3**| -6|
| TimePilot | 3850.6| **10294**|
| Tutankham | **331.4**| 213|
| UpNDown | 17952.0| **89067**|
| Venture | **0.0**| **0**|
| VideoPinball | **407331.2**| 229402|
| WizardOfWor | 2800.0| **8953**|
| YarsRevenge | **25175.5**| 21596|
| Zaxxon | 80.7| **16544**|


## Evaluation Protocol

Our evaluation protocol is designed to mirror the evaluation protocol of the [Noisy Networks Paper](https://arxiv.org/abs/1706.10295) as closely as possible, since the original A3C paper does not report reproducible results (it evaluates with human-starts trajectories, which are not publicly available). The reported baseline scores are taken from Table 3 of the Noisy Networks Paper.

Matching this protocol allows a fair comparison of the quality of our example implementation. Specifically, the details of our evaluation (which can also be found in the code) are the following:

- **Evaluation Frequency**: The agent is evaluated every 1 million frames (250K timesteps). Over the 320 million training frames this results in a total of 320 "intermediate" evaluations.
- **Evaluation Phase**: The agent is evaluated for 500K frames (125K timesteps) in each intermediate evaluation.
- **Output**: The output of an intermediate evaluation phase is a score representing the mean score of all completed evaluation episodes within the 125K timesteps. If there is any unfinished episode by the time the 125K timestep evaluation phase is finished, that episode is discarded.
- **Intermediate Evaluation Episode**:
- Each intermediate evaluation episode is capped in length at 27K timesteps or 108K frames.
- Each evaluation episode begins with a random number of no-ops (up to 30), where this number is chosen uniformly at random.
- **Reporting**: For each run of our A3C example, we report the highest score among the intermediate evaluation phases (see the sketch after this list). This differs from the original A3C paper, which states: "We additionally used the final network weights for evaluation". We instead follow the [Noisy Networks Paper](https://arxiv.org/abs/1706.10295), which states that "Per-game maximum scores are computed by taking the maximum raw scores of the agent and then averaging over three seeds".
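
As a rough illustration, the reported per-run score can be recovered from the training logs like this (a minimal sketch: the path `results/scores.txt` and the column name `mean` are assumptions about ChainerRL's tab-separated score log, not guarantees):

```python
import csv

# Hypothetical path: ChainerRL writes a scores.txt into the output directory
# during training, with one row per intermediate evaluation.
with open('results/scores.txt') as f:
    rows = list(csv.DictReader(f, delimiter='\t'))

# The reported score is the maximum of the per-evaluation mean returns.
reported_score = max(float(row['mean']) for row in rows)
print('Reported score:', reported_score)
```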


## Training times

We trained each game with 16 asynchronous processes (as per the A3C paper) on a machine with 17 CPUs and no GPU.


| Statistic | Domain | Time (in days) |
| ------------- |:-------------:|:-------------:|
| Mean time across all domains | - | 1.08299383309 |
| Fastest domain | DemonAttack | 0.736027011088 |
| Slowest domain | UpNDown | 1.25626688715 |


177 changes: 177 additions & 0 deletions examples/atari/a3c/train_a3c.py
@@ -0,0 +1,177 @@
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases() # NOQA
import argparse
import os

# Prevent numpy from using multiple threads
os.environ['OMP_NUM_THREADS'] = '1' # NOQA

import chainer
import gym
import gym.wrappers
import numpy as np

import chainerrl
from chainerrl.agents import a3c
from chainerrl import experiments
from chainerrl import links
from chainerrl import misc
from chainerrl.optimizers.nonbias_weight_decay import NonbiasWeightDecay
from chainerrl.optimizers import rmsprop_async
from chainerrl import policy
from chainerrl import v_function

from chainerrl.wrappers import atari_wrappers


class A3CFF(chainer.ChainList, a3c.A3CModel):
**Review comment (PR author):** I'm assuming this follows the specification of the A3C paper:

"The agents used the network architecture
from (Mnih et al., 2013). The network used a convolutional layer with 16 filters of size 8 × 8 with stride
4, followed by a convolutional layer with with 32 filters of size 4 × 4 with stride 2, followed by a fully
connected layer with 256 hidden units. All three hidden layers were followed by a rectifier nonlinearity. The
value-based methods had a single linear output unit for each action representing the action-value. The model
used by actor-critic agents had two set of outputs – a softmax output with one entry per action representing the
probability of selecting the action, and a single linear output representing the value function." (Source: A3C Paper)

**Review comment (PR author):**

Source: Noisy Nets paper - "In each case, we used the neural network architecture from the
corresponding original papers for both the baseline and NoisyNet variant"
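
For reference, here is a minimal Chainer sketch of the head architecture quoted above (illustrative only: `NIPSDQNHeadSketch` is a made-up name for this example; the code below actually uses ChainerRL's `links.NIPSDQNHead`):

```python
import chainer
import chainer.functions as F
import chainer.links as L


class NIPSDQNHeadSketch(chainer.Chain):
    """Head from Mnih et al. (2013): conv 16@8x8/4, conv 32@4x4/2, FC 256."""

    def __init__(self, n_input_channels=4, n_output_channels=256):
        super().__init__()
        self.n_output_channels = n_output_channels
        with self.init_scope():
            self.conv1 = L.Convolution2D(n_input_channels, 16, ksize=8, stride=4)
            self.conv2 = L.Convolution2D(16, 32, ksize=4, stride=2)
            self.fc = L.Linear(None, n_output_channels)

    def __call__(self, x):
        # A rectifier nonlinearity follows each of the three hidden layers;
        # the softmax policy and linear value heads are attached on top.
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        return F.relu(self.fc(h))
```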


def __init__(self, n_actions):
self.head = links.NIPSDQNHead()
self.pi = policy.FCSoftmaxPolicy(
self.head.n_output_channels, n_actions)
self.v = v_function.FCVFunction(self.head.n_output_channels)
super().__init__(self.head, self.pi, self.v)

def pi_and_v(self, state):
out = self.head(state)
return self.pi(out), self.v(out)


def main():

parser = argparse.ArgumentParser()
parser.add_argument('--processes', type=int, default=16)
parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0,
help='Random seed [0, 2 ** 31)')
parser.add_argument('--outdir', type=str, default='results',
help='Directory path to save output files.'
' If it does not exist, it will be created.')
parser.add_argument('--t-max', type=int, default=5)
**Review comment (PR author):**
Source: A3C paper, Appendix 8: "All methods performed updates after every 5 actions (tmax = 5 and
IUpdate = 5) and shared RMSProp was used for optimization"
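
As an aside, here is a minimal sketch of how a rollout of at most `t_max` rewards is turned into discounted n-step returns (illustrative only; this is not ChainerRL's `a3c.A3C` update code, and the function name is made up):

```python
def n_step_returns(rewards, bootstrap_value, gamma=0.99):
    """Discounted returns for a rollout of up to t_max rewards.

    bootstrap_value is the critic's estimate V(s_{t + t_max}), or 0 if the
    episode terminated inside the rollout.
    """
    R = bootstrap_value
    returns = []
    for r in reversed(rewards):  # walk backwards through the rollout
        R = r + gamma * R
        returns.append(R)
    return list(reversed(returns))


# Example: a 5-step rollout (t_max=5) with a critic estimate of 1.0 at its end.
print(n_step_returns([0.0, 0.0, 1.0, 0.0, 0.0], bootstrap_value=1.0))
```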

parser.add_argument('--beta', type=float, default=1e-2)
parser.add_argument('--profile', action='store_true')
parser.add_argument('--steps', type=int, default=8 * 10 ** 7)
**Review comment (PR author):**
Source: Noisy nets paper - "The DQN and A3C agents were training for 200M and 320M frames, respectively". With the standard Atari frame skip of 4, 320M frames corresponds to 8 * 10 ** 7 agent steps, hence the default.

    parser.add_argument('--max-frames', type=int,
                        default=30 * 60 * 60,  # 30 minutes with 60 fps
                        help='Maximum number of frames for each episode.')

**Review comment (PR author):** Source: Noisy Networks paper - "Episodes are truncated at 108K frames (or 30 minutes of simulated play) (van Hasselt et al., 2016)." However, it's unclear in the context whether this refers to training or testing. Given the nature of other Deep RL papers, I'm assuming the truncation applies to both training and evaluation.

parser.add_argument('--lr', type=float, default=7e-4)
parser.add_argument('--eval-interval', type=int, default=250000)
parser.add_argument('--eval-n-steps', type=int, default=125000)
parser.add_argument('--weight-decay', type=float, default=0.0)
parser.add_argument('--demo', action='store_true', default=False)
parser.add_argument('--load', type=str, default='')
parser.add_argument('--logging-level', type=int, default=20,
help='Logging level. 10:DEBUG, 20:INFO etc.')
parser.add_argument('--render', action='store_true', default=False,
help='Render env states in a GUI window.')
parser.add_argument('--monitor', action='store_true', default=False,
help='Monitor env. Videos and additional information'
' are saved as output files.')
args = parser.parse_args()

import logging
logging.basicConfig(level=args.logging_level)

# Set a random seed used in ChainerRL.
# If you use more than one processes, the results will be no longer
# deterministic even with the same random seed.
misc.set_random_seed(args.seed)

# Set different random seeds for different subprocesses.
# If seed=0 and processes=4, subprocess seeds are [0, 1, 2, 3].
# If seed=1 and processes=4, subprocess seeds are [4, 5, 6, 7].
process_seeds = np.arange(args.processes) + args.seed * args.processes
assert process_seeds.max() < 2 ** 31

args.outdir = experiments.prepare_output_dir(args, args.outdir)
print('Output files are saved in {}'.format(args.outdir))

n_actions = gym.make(args.env).action_space.n

model = A3CFF(n_actions)

# Draw the computational graph and save it in the output directory.
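    # The fake observation mirrors the real preprocessed observation shape:
    # 4 stacked 84x84 frames, with [None] adding a batch axis of size 1.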
fake_obs = chainer.Variable(
np.zeros((4, 84, 84), dtype=np.float32)[None],
name='observation')
with chainerrl.recurrent.state_reset(model):
# The state of the model is reset again after drawing the graph
chainerrl.misc.draw_computational_graph(
[model(fake_obs)],
os.path.join(args.outdir, 'model'))

    opt = rmsprop_async.RMSpropAsync(lr=args.lr, eps=1e-1, alpha=0.99)
**Review comment (PR author):**
Source: A3C Paper - "and an RMSProp decay factor of α = 0.99"
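
For intuition, here is a rough sketch of the RMSProp update that these hyperparameters control (decay `alpha=0.99` and a relatively large `eps=1e-1` that stabilizes the denominator). The exact form, including where `eps` enters and how the statistics are shared across processes, is defined by ChainerRL's `RMSpropAsync`, not by this sketch:

```python
import numpy as np


def rmsprop_step(theta, grad, mean_square, lr=7e-4, alpha=0.99, eps=1e-1):
    """One illustrative RMSProp update on a parameter vector.

    mean_square is the running average of squared gradients; in asynchronous
    A3C it is shared between all worker processes (shared RMSProp).
    """
    mean_square = alpha * mean_square + (1 - alpha) * grad ** 2
    theta = theta - lr * grad / np.sqrt(mean_square + eps)
    return theta, mean_square
```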

opt.setup(model)
opt.add_hook(chainer.optimizer.GradientClipping(40))
if args.weight_decay > 0:
opt.add_hook(NonbiasWeightDecay(args.weight_decay))

    def phi(x):
        # Feature extractor: scale uint8 pixel observations in [0, 255]
        # to float32 in [0, 1].
        return np.asarray(x, dtype=np.float32) / 255

    agent = a3c.A3C(model, opt, t_max=args.t_max, gamma=0.99,
                    beta=args.beta, phi=phi)

**Review comment (PR author):** Source: A3C Paper - "All experiments used a discount of γ = 0.99"

if args.load:
agent.load(args.load)

def make_env(process_idx, test):
# Use different random seeds for train and test envs
process_seed = process_seeds[process_idx]
env_seed = 2 ** 31 - 1 - process_seed if test else process_seed
env = atari_wrappers.wrap_deepmind(
atari_wrappers.make_atari(args.env, max_frames=args.max_frames),
episode_life=not test,
clip_rewards=not test)
env.seed(int(env_seed))
if args.monitor:
env = gym.wrappers.Monitor(
env, args.outdir,
mode='evaluation' if test else 'training')
if args.render:
env = chainerrl.wrappers.Render(env)
return env

    if args.demo:
        env = make_env(0, True)
        eval_stats = experiments.eval_performance(
            env=env,
            agent=agent,
            n_steps=args.eval_n_steps,
            n_episodes=None)
        print('n_steps: {} mean: {} median: {} stdev: {}'.format(
            args.eval_n_steps, eval_stats['mean'], eval_stats['median'],
            eval_stats['stdev']))
else:

        # Linearly decay the learning rate to zero: LinearInterpolationHook
        # interpolates between args.lr (at step 0) and 0 (at step args.steps),
        # i.e. lr(t) = args.lr * (1 - t / args.steps), and applies each value
        # via lr_setter.
        def lr_setter(env, agent, value):
            agent.optimizer.lr = value

        lr_decay_hook = experiments.LinearInterpolationHook(
            args.steps, args.lr, 0, lr_setter)

experiments.train_agent_async(
agent=agent,
outdir=args.outdir,
processes=args.processes,
make_env=make_env,
profile=args.profile,
steps=args.steps,
eval_n_steps=args.eval_n_steps,
eval_n_episodes=None,
eval_interval=args.eval_interval,
global_step_hooks=[lr_decay_hook],
save_best_so_far_agent=False,
)


if __name__ == '__main__':
main()