Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add a Jax+RL example based on rejax.PPO Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Remove some of the unused code Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Move things around a bit Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Update version requirements for jax/torch Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Use xtills for cleaner Jit with annotations Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Save gif every epoch Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix rendering of classic-control gymnax envs Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Add a "pure jax" training loop option Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fused training step in Lightning module Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Works without hash warnings now! Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Reorganize the code a bit Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Use vmap to train multiple agents in parallel Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Add a jax analogue to lightning.Trainer Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Add the equivalent of lightning.Callback for jax Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Log hyper-parameters Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Progress bar almost works Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Managed to get the progress bar to work! Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Move the trainer + callback to a different file Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Make stuff generic (not tied to PPOLearner) Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Update gymnax to improve rendering performance Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Add configs, tweak experiment/main Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * wip: fixing issues in experiment.py Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix config now that network is optional Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix issue with progress bar callback! Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix duplicated code in main.py Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Move tests / Lightning wrapper to test file Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Rename things, add docstring to JaxTrainer Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix links in docstrings of JaxTrainer / JaxModule Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Tweak the docs of JaxModule/JaxTrainer Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Use regression fixtures in test Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix the ref in the JaxTrainer docstring Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix small errors that break CI Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix bug in test_rejax Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * "fix" config schema generation errors Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix test_rejax function Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Test the `train` method to replicate rejax.PPO Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Move Jax typing utils to a new module Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix default param causing preallocation of GPU mem Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Add comments in conftest.py Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix test for rejax, add more todos in conftest.py Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix bug in lightning wrapper for rejax.PPO Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix issue in test_config from conftest change Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * (temp) make the tests run in unit test runs Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Tweaks to the jax typing utils Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Move the JaxTrainer to a new "trainers" dir Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Simplify docs in `jax_trainer.py` Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Move things around, add pytest.mark.slow marks Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix bug with config target type inference Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Move things around in jax_rl_example_test.py Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Add some docstrings Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Re-organize tests, update regression files Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix the missing indexing in test for equivalence Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Don't use file_regression with gifs Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix issue with jax_rl_example_test.test_lightning Signed-off-by: Fabrice Normandin <normandf@mila.quebec> --------- Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
- Loading branch information