Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Match performance with stable-baselines (discrete case) #110

Merged
merged 15 commits into from
Aug 3, 2020

Conversation

Miffyli
Copy link
Collaborator

@Miffyli Miffyli commented Jul 16, 2020

This PR will be done when stable-baselines3 agent performance matches stable-baselines in discrete envs. Will be tested on discrete control tasks and Atari environments.

Closes #49
Closes #105

PS: Sorry about the confusing branch-name.

Changes

TODO

  • Match performance of A2C and PPO.

  • A2C Cartpole matches (mostly, see this. Averaged over 10 random seeds for both. Requires the TF-like RMSprop, and even still in the very end SB3 seems more unstable.)

  • A2C Atari matches (mostly, see sb2 and sb3. Original sb3 result here. Three random seeds, each line separate run (ignore legend). Using TF-like RMSprop. Performance and stability mostly matches, except sb2 has sudden spike in performance in Q*Bert. Something to do with stability in distributions?)

  • PPO Cartpole (using rl-zoo parameters, see learning curves, averaged over 20 random seeds)

  • PPO Atari (mostly, see sb2 and sb3 results (shaded curves averaged over two seeds). Q*Bert still seems to have an edge on SB2 for unknown reasons)

  • Check and match performance of DQN. Seems ok. See following learning curves, each curve is an average over three random seeds:
    atari_spaceinvaders.pdf
    atari_qbert.pdf
    atari_breakout.pdf
    atari_pong.pdf

  • Check if "dones" fix can (and should) be moved to computing GAE side.

  • Write docs on how to match A2C and PPO settings to stable-baselines ("moving from stable-baselines"). There are some important quirks to note here. Move this to migration guide PR Migration Guide #123 .

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)

@@ -74,7 +74,7 @@ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch 🙈 Please don't tell me that solve your performance issue.

I know where it comes from ... I shouldn't have copy-pasted from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/model.py#L169

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thinking about that, we need to double check VecFrameStack, even though it is the same as SB2.

Copy link
Collaborator Author

@Miffyli Miffyli Jul 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sadly (luckily? =) ) it did not fix the issues yet. SB3 is still consistently worse in a few of the Atari games I have tested. I am in the process of step-by-step comparisons, will see how that goes.

Edit: Ah yes, stacking on the wrong channels?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or having that kind of issue: ikostrikov/pytorch-a2c-ppo-acktr-gail@84a7582

btw, is it better now with OMP_NUM_THREADS=1 w.r.t. fps? (maybe you should write in the comment the current stand of SB2 vs SB3)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one thing that may change is the optimizer implementation and default parameters, for the initialization, I think (at least I tried) to reproduce what was done in SB2.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my question was more what is the fps we want to reach? (what did you have with SB2?)

Hmm I do not have conclusive numbers just yet because I have been running many experiments on same system and can not guarantee fair comparisons, but I think PyTorch variants are about 10% slower with Atari games and 25% slower on toy environments. The latter required the OMP_NUM_THREADS tuning. This sounds reasonable, given the non-compiled nature of PyTorch and the fact the code has not been optimized much yet.

Yes, the issue was that nminibatches lead to different mini-batchsize depending on the number of environments

Ah alright. I will write big notes about this on the "moving from stable-baselines" docs :)

Copy link
Contributor

@m-rph m-rph Jul 25, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One major change in parameters is the use of batch_size=64 rather than nminibatches=4 in PPO. Using such small batch-size made things very slow FPS-wise, but in some cases sped up the learning. I will focus more on these running-speed things in an another PR.

I would like to add that we may be able to gain a non minuscule speedup by avoiding single data stores but instead storing a whole batch at once.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I started documenting the migration here ;)
#123

I would like to add that we may be able to gain a non minuscule speedup by avoiding single data stores but instead storing a while batch at once.

?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mistyped, I meant that if we store a whole batch at once, we should get a sizeable speedup over storing one transition at a time.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still not sure what you mean...

@araffin
Copy link
Member

araffin commented Jul 20, 2020

Listing what can be different from PyTorch vs Tensorflow:

EDIT: the tf clip norm seems to be here https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/clip_ops.py#L291 (not that easy to read vs pytorch), and the doc: https://www.tensorflow.org/api_docs/python/tf/clip_by_global_norm

@araffin
Copy link
Member

araffin commented Jul 22, 2020

I am wondering: are you using clip_range_vf for PPO?
The behavior of this parameter changed between SB2 and SB3 (i'm currently documenting all thoses changes in a new branch).

@Miffyli
Copy link
Collaborator Author

Miffyli commented Jul 22, 2020

I am wondering: are you using clip_range_vf for PPO?
The behavior of this parameter changed between SB2 and SB3 (i'm currently documenting all thoses changes in a new branch).

I am using the parameters from rl-zoo for Atari PPO runs, where vf clipping is disabled and I set cliprange_vf=-1 for sb2 and clip_range_vf=None for sb3, which I understood is the matching behaviour (disables it).

@Miffyli
Copy link
Collaborator Author

Miffyli commented Jul 23, 2020

Some progression with A2C: With CartPole you get very similar learning curves (below, averaged over 10 random seeds) with rl-zoo parameters, after you update the PyTorch RMSProp to match TF's implementation. Turns out PyTorch RMSProp does things a little bit different, and these are crucial for stable learning like shown. These changes require a new optimizer or changes to PyTorch code, so should we include a modified RMSProp in stable-baselines3 like done here in another repo? We could include this as an additional optimizer and instruct to use it if one wants to replicate sb2 results, but we could also consider making it default RMSProp optimizer because of its (apparent) stability.

cartpole_comparison

@Miffyli
Copy link
Collaborator Author

Miffyli commented Jul 25, 2020

A2C seems to check out mostly (see the original post with plots) with the fixed RMSprop that is now included under sb2_compat. If this approach is ok, I can write notes in docs about this RMSprop with A2C and do same in #123 .

@araffin
Copy link
Member

araffin commented Jul 25, 2020

A2C seems to check out mostly (see the original post with plots)

including Atari games?

If this approach is ok, I can write notes in docs about this RMSprop with A2C and do same in #123 .

Sounds reasonable, I don't see any better solution... The only thing is which default should we use?
(I will re-run a quick continuous control benchmark with the updated RMSProp, depending on that we will see)

@Miffyli
Copy link
Collaborator Author

Miffyli commented Jul 25, 2020

including Atari games?

Yup! See the original post with plots. To me they seem "close enough" (with this limited amount of runs), except for Q*Bert which at end gets a sudden boost in performance in sb2. I will be checking PPO next and see if there is something common to A2C and PPO the is derp.

Sounds reasonable, I don't see any better solution... The only thing is which default should we use?

TF variant seems more stable and pytorch-image-models repo guys also say they have had better success with it. I'd personally go with that one by default.

(I will re-run a quick continuous control benchmark with the updated RMSProp, depending on that we will see)

Remember to set the parameters manually! I forgot this first time around ^^

        policy_kwargs["optimizer_class"] = RMSpropTFLike
        policy_kwargs["optimizer_kwargs"] = dict(alpha=0.99, eps=1e-5, weight_decay=0)

@araffin
Copy link
Member

araffin commented Jul 25, 2020

After a quick run on Bullet envs, th.optim.RMSProp yield better performances for continuous control...
Mean final reward over 3 seeds on HalfCheetahBulletEnv-v0:

tf-rmsprop: 1192
torch-rmsprop: 1912

@Miffyli
Copy link
Collaborator Author

Miffyli commented Jul 25, 2020

How is the stability, though? I noticed tf.optim.RMSprop learns faster with its bigger gradients but does not seem to converge so easily, while the TF-variant learns slower but is more stable (see the plots I have above).

Edit: In the light of these results we could keep the original enabled by default, though, and instruct people to use the TF-variant if they are experiencing unstable learning.

@araffin
Copy link
Member

araffin commented Jul 25, 2020

How is the stability, though?

A bit unstable at the beginning.

Edit: In the light of these results we could keep the original enabled by default, though, and instruct people to use the TF-variant if they are experiencing unstable learning.

Yes, and add the tf-version as default in the zoo for Atari?

see the plots I have above

I only see the plots where the two are similar.

See the original post with plots.

In the original post, I only see ppo plots...

@Miffyli
Copy link
Collaborator Author

Miffyli commented Jul 25, 2020

Yes, and add the tf-version as default in the zoo for Atari?

Works for me 👍

I only see the plots where the two are similar.
In the original post, I only see ppo plots...

Hmm there should be four A2C plots in total under "TODO" heading: A2C cartpole comparisons (with rmsprop fixes), sb2 and sb3 Atari results for A2C and sb3 Atari results without rmsprop fix.

@araffin
Copy link
Member

araffin commented Jul 25, 2020

Hmm there should be four A2C plots in total under "TODO" heading: A2C cartpole comparisons (with rmsprop fixes), sb2 and sb3 Atari results for A2C and sb3 Atari results without rmsprop fix.

🙈 I was looking at the issue, not the PR...

@Miffyli
Copy link
Collaborator Author

Miffyli commented Jul 27, 2020

Ran some more Atari PPO runs and now the performance seems to match (see the original post for plots). SB3 seems to be consistently lower than SB2 but nothing seems horribly broken. Q*Bert has an edge on SB2 for some reason with both PPO and A2C. I will be re-running experiments with more seeds, but that will take time. @araffin could you comment on the learning curves and tell what you think about the results?

@araffin
Copy link
Member

araffin commented Jul 27, 2020

Ran some more Atari PPO runs and now the performance seems to match (see the original post for plots). SB3 seems to be consistently lower than SB2 but nothing seems horribly broken

Do you know if the ADAM implementation is the same for A2C/PPO?

@Miffyli
Copy link
Collaborator Author

Miffyli commented Jul 27, 2020

Do you know if the ADAM implementation is the same for A2C/PPO?

Quick googling and skimming over the codes they seem to match, and also the A2C experiments matched with Adam (equally unstable :D), so I believe that part checks out.

@araffin
Copy link
Member

araffin commented Jul 27, 2020

And how many random seeds did you try?
For me, it looks good ;) I cannot spot dramatic performance drop, and ppo matches ppo2 performance for the continuous case (I will do a check again though before 1.0 release).

@Miffyli
Copy link
Collaborator Author

Miffyli commented Jul 27, 2020

And how many random seeds did you try?

Each of the curves is slightly different setup but, in general, tend to have the same result (see Figure 5 here, where we have five random seeds per curve). I.e. you can treat each curve as separate run with different random seed. But I will run some more for better conclusion.

@araffin
Copy link
Member

araffin commented Jul 30, 2020

Note: I will retry to run DQN with the updated network and maybe with the updated RMSprop
The pytorch version of chainer also uses a custom implementation: https://github.com/pfnet/pfrl/blob/master/pfrl/optimizers/rmsprop_eps_inside_sqrt.py

@Miffyli
Copy link
Collaborator Author

Miffyli commented Jul 30, 2020

Actually DQN uses Adam for optimizing, and it has been using it since stable-baselinse2, while (I think) the original implementation used rmsprop. It might be worth of trying out what happens if you change the optimizer to stabler rmsprop, as Adam made things unstable with PPO.

On sidenote: I ran Pong on sb3 DQN and was not able to get any improvement while sb2 learns it quickly (inside ~2M steps). I thought sb3 DQN was able to learn Pong, tho? Using parameters from rl-zoo, minus prioritized memory etc.

@araffin
Copy link
Member

araffin commented Jul 30, 2020

On sidenote: I ran Pong on sb3 DQN and was not able to get any improvement while sb2 learns it quickly (inside ~2M steps). I thought sb3 DQN was able to learn Pong, tho? Using parameters from rl-zoo, minus prioritized memory etc.

It was but not as good as expected... SB2 DQN has nothing to do with vanilla DQN...

@Miffyli
Copy link
Collaborator Author

Miffyli commented Aug 2, 2020

It was but not as good as expected... SB2 DQN has nothing to do with vanilla DQN...

To clarify to others: araffin referred to the fact how, by default, SB2 DQN has bunch of modifications enabled (Double-Q, Dueling). Those were disabled for those runs.

I ran more experiments with Atari with the recent hotfix #132 . The learning curves are included in the main post and match mostly. While not perfect I can not tell if issue is in lack of random seeds used (three is rather low), and in any case I do not have the compute to run enough training runs to debug deeper if something differs.

@araffin
Copy link
Member

araffin commented Aug 3, 2020

I ran more experiments with Atari with the recent hotfix #132 . The learning curves are included in the main post and match mostly. While not perfect I can not tell if issue is in lack of random seeds used (three is rather low), and in any case I do not have the compute to run enough training runs to debug deeper if something differs.

Looks good, no? SB3 DQN has even slightly better performance on one and I'm pretty sure SB3 DQN is faster than SB2, no?
Btw, which hyperparameters did you use? (Please update the defaults if you used different ones)

@araffin
Copy link
Member

araffin commented Aug 3, 2020

Otherwise, it looks like it is ready to merge, no?
The last_done fix is only 3 lines of code...

@Miffyli
Copy link
Collaborator Author

Miffyli commented Aug 3, 2020

Looks good, no? SB3 DQN has even slightly better performance on one and I'm pretty sure SB3 DQN is faster than SB2, no?
Btw, which hyperparameters did you use? (Please update the defaults if you used different ones)

Preferably I would want to performance match in both good and bad (i.e. not better or worse) just to keep consistent results, but that'd still require a lot of work ^^. I used the hyperparameters from sb2 rl-zoo, plus disabling all the DQN improvements for SB2. I am not quite sure what you mean by "update defaults".

@araffin
Copy link
Member

araffin commented Aug 3, 2020

I used the hyperparameters from sb2 rl-zoo, plus disabling all the DQN improvements for SB2. I am not quite sure what you mean by "update defaults".

I meant updating the default hyperparameters. The current ones are from the DQN nature paper and therefore do no correspond to your benchmark. The main differences are the buffer size and the final value of the exploration rate.
I would update the default with the one you used ;) (but we know they work)

@Miffyli
Copy link
Collaborator Author

Miffyli commented Aug 3, 2020

I meant updating the default hyperparameters. The current ones are from the DQN nature paper and therefore do no correspond to your benchmark. The main differences are the buffer size and the final value of the exploration rate.

Hmm I would those values from the original paper, as this is what users would expect when seeing "DQN". I do not think these parameters I used are the best (do not learn fastest / stablest), but I needed the replay-buffer size at the very least to be able to fit multiple experiments at same time on same machine.

@Miffyli Miffyli marked this pull request as ready for review August 3, 2020 19:59
@Miffyli Miffyli requested a review from araffin August 3, 2020 19:59
Copy link
Member

@araffin araffin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, very impressive and valuable detective work =)

@doylech
Copy link

doylech commented Jan 9, 2021

Thank you for your hard work on this to investigate and align the performance!

This PR is currently referenced in the Atari Results section of the documentation here: https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html

Regarding the learning curves, would you please be able to clarify,

  1. What Atari environment versions these are? e.g. NoFrameskip-v4

  2. If the learning curves apply to the agent reward (after preprocessing normalization) or the agent score? (This was an issue in Performance Check (Discrete actions) #49)

  3. What code would be needed to replicate these results the code referenced? The documentation indicates it would of the form:
    python train.py --algo ppo --env $ENV_ID --eval-episodes 10 --eval-freq 10000

  4. In the PPO learning curves, what settings "Full," "Minimal," and "Multi-discrete" refer to?

Thank you for your help

@Miffyli
Copy link
Collaborator Author

Miffyli commented Jan 11, 2021

@doylech

Thanks for the kind words!

I ran these experiments using a different code base from zoo (one I was most familiar at the time), so replicating results exactly might be bit tricky.

  1. Yes, NoFrameskip-v4, but with all the default Atari wrappers.
  2. If you refer to Use Monitor episode reward/length for evaluate_policy #220, yes, these results were obtained before that fix (i.e. learning curves use modified rewards)
  3. Looking at the code I used, that command should use same hyperparameters and wrappers I used. Note that these results use monitor files for creating learning curves which includes exploration, so DQN results will look different from zoo's. PPO/A2C should be fine.
  4. Those are different action-spaces explored in this work (the code I used to run these experiments). "Minimal" is the default you have with Atari envs, "Full" is where you always have access to all actions despite not used by the game and "Multi-discrete" is "Full" but where button press and joystick movement are separate into two different discrete actions.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
4 participants