Commit

Add an RL example in Jax (#55)
* Add a Jax+RL example based on rejax.PPO
* Remove some of the unused code
* Move things around a bit
* Update version requirements for jax/torch
* Use xtills for cleaner Jit with annotations
* Save gif every epoch
* Fix rendering of classic-control gymnax envs
* Add a "pure jax" training loop option
* Fused training step in Lightning module
* Works without hash warnings now!
* Reorganize the code a bit
* Use vmap to train multiple agents in parallel (a sketch of this pattern follows the change summary below)
* Add a jax analogue to lightning.Trainer
* Add the equivalent of lightning.Callback for jax
* Log hyper-parameters
* Progress bar almost works
* Managed to get the progress bar to work!
* Move the trainer + callback to a different file
* Make stuff generic (not tied to PPOLearner)
* Update gymnax to improve rendering performance
* Add configs, tweak experiment/main
* wip: fixing issues in experiment.py
* Fix config now that network is optional
* Fix issue with progress bar callback!
* Fix duplicated code in main.py
* Move tests / Lightning wrapper to test file
* Rename things, add docstring to JaxTrainer
* Fix links in docstrings of JaxTrainer / JaxModule
* Tweak the docs of JaxModule/JaxTrainer
* Use regression fixtures in test
* Fix the ref in the JaxTrainer docstring
* Fix small errors that break CI
* Fix bug in test_rejax
* "fix" config schema generation errors
* Fix test_rejax function
* Test the `train` method to replicate rejax.PPO
* Move Jax typing utils to a new module
* Fix default param causing preallocation of GPU mem
* Add comments in conftest.py
* Fix test for rejax, add more todos in conftest.py
* Fix bug in lightning wrapper for rejax.PPO
* Fix issue in test_config from conftest change
* (temp) make the tests run in unit test runs
* Tweaks to the jax typing utils
* Move the JaxTrainer to a new "trainers" dir
* Simplify docs in `jax_trainer.py`
* Move things around, add pytest.mark.slow marks
* Fix bug with config target type inference
* Move things around in jax_rl_example_test.py
* Add some docstrings
* Re-organize tests, update regression files
* Fix the missing indexing in test for equivalence
* Don't use file_regression with gifs
* Fix issue with jax_rl_example_test.test_lightning

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
lebrice authored Oct 11, 2024
1 parent a5acd0b commit 682cce6
Showing 27 changed files with 2,843 additions and 155 deletions.
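
One of the commits above trains multiple agents in parallel with jax.vmap. The repository's code is not reproduced here; the following is a minimal, self-contained sketch of that general pattern, assuming only a jit-compiled training function that takes a single PRNG key (the function body and agent count are illustrative, not the repo's API):

import jax
import jax.numpy as jnp


def train_single_agent(rng: jax.Array) -> jax.Array:
    # Stand-in for a full training run; returns a toy "final reward".
    # A real implementation (e.g. rejax.PPO's train function) would return
    # trained parameters and evaluation metrics instead.
    params = jax.random.normal(rng, (8,))
    return -jnp.sum(params**2)


# One PRNG key per agent; vmap + jit runs the 128 training runs in parallel.
rngs = jax.random.split(jax.random.PRNGKey(0), 128)
final_rewards = jax.jit(jax.vmap(train_single_agent))(rngs)
print(final_rewards.shape)  # (128,)
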
3 changes: 3 additions & 0 deletions .regression_files/.gitignore
@@ -0,0 +1,3 @@
*.gif
# Ignore tensor regression files.
*.npz
[new file — name not captured]
@@ -0,0 +1,12 @@
val/episode_lengths:
max: '2.e+02'
mean: '2.e+02'
min: '2.e+02'
shape: []
sum: '2.e+02'
val/rewards:
max: '-1.222e+03'
mean: '-1.222e+03'
min: '-1.222e+03'
shape: []
sum: '-1.222e+03'
[new file — name not captured]
@@ -0,0 +1,16 @@
cumulative_reward:
max: '-6.495e+02'
mean: '-1.229e+03'
min: '-1.878e+03'
shape:
- 76
- 128
sum: '-1.196e+07'
episode_length:
max: 200
mean: '2.e+02'
min: 200
shape:
- 76
- 128
sum: 1945600
[new file — name not captured]
@@ -0,0 +1,16 @@
cumulative_reward:
max: '-6.495e+02'
mean: '-1.229e+03'
min: '-1.878e+03'
shape:
- 76
- 128
sum: '-1.196e+07'
episode_length:
max: 200
mean: '2.e+02'
min: 200
shape:
- 76
- 128
sum: 1945600
[new file — name not captured]
@@ -0,0 +1,16 @@
cumulative_reward:
max: '-4.319e-01'
mean: '-5.755e+02'
min: '-1.872e+03'
shape:
- 76
- 128
sum: '-5.599e+06'
episode_length:
max: 200
mean: '2.e+02'
min: 200
shape:
- 76
- 128
sum: 1945600
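
The regression files above store per-tensor summary statistics (max, mean, min, shape, sum) rather than full arrays. A minimal sketch of computing the same summary for an array (illustrative only; the repository presumably writes these files through its regression-fixture tooling):

import jax.numpy as jnp


def summarize(x: jnp.ndarray) -> dict:
    # Same fields as the regression files above.
    return {
        "max": float(x.max()),
        "mean": float(x.mean()),
        "min": float(x.min()),
        "shape": list(x.shape),
        "sum": float(x.sum()),
    }


stats = {"cumulative_reward": summarize(jnp.zeros((76, 128)))}
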
11 changes: 11 additions & 0 deletions docs/examples/jax_rl_example.md
@@ -0,0 +1,11 @@
---
additional_python_references:
- project.algorithms.jax_rl_example
- project.trainers.jax_trainer
---

# Reinforcement Learning (Jax)

## JaxTrainer

The `JaxTrainer` is
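
The added documentation page ends mid-sentence. Per the commit message, JaxTrainer is a jax analogue of lightning.Trainer with callback support; the sketch below only illustrates what such a trainer/module pairing commonly looks like, and every name and signature in it is hypothetical rather than the repository's actual API:

from typing import Any, Protocol

import jax


class JaxModule(Protocol):
    # Hypothetical counterpart of lightning.LightningModule for pure-jax training.
    def init_train_state(self, rng: jax.Array) -> Any: ...

    def training_step(self, state: Any, batch: Any) -> Any: ...


class JaxTrainer:
    # Hypothetical counterpart of lightning.Trainer: owns the training loop.
    def __init__(self, max_epochs: int, callbacks: list | None = None):
        self.max_epochs = max_epochs
        self.callbacks = callbacks or []

    def fit(self, module: JaxModule, rng: jax.Array, batches: list) -> Any:
        state = module.init_train_state(rng)
        for _ in range(self.max_epochs):
            for batch in batches:
                state = module.training_step(state, batch)
        return state
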
2 changes: 2 additions & 0 deletions docs/generate_reference_docs.py
@@ -2,9 +2,11 @@
# based on https://github.com/mkdocstrings/mkdocstrings/blob/5802b1ef5ad9bf6077974f777bd55f32ce2bc219/docs/gen_doc_stubs.py#L25


import os
from logging import getLogger as get_logger
from pathlib import Path

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
logger = get_logger(__name__)


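
This change and the __init__ change in jax_example.py further down both set XLA_PYTHON_CLIENT_PREALLOCATE to "false". By default, JAX's GPU backend preallocates a large chunk of device memory when it is first initialized, so the variable only has an effect if it is set before that happens; a minimal standalone sketch of the ordering that matters (not the repository's code):

import os

# Must be set before jax initializes its backend, i.e. in practice before the
# first jax import or the first jax operation in the process.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # noqa: E402  (imported after setting the env var on purpose)

print(jax.devices())  # allocations now grow on demand instead of being preallocated
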
2 changes: 2 additions & 0 deletions project/algorithms/__init__.py
@@ -1,11 +1,13 @@
from .example import ExampleAlgorithm
from .hf_example import HFExample
from .jax_example import JaxExample
from .jax_rl_example import JaxRLExample
from .no_op import NoOp

__all__ = [
"ExampleAlgorithm",
"JaxExample",
"NoOp",
"HFExample",
"JaxRLExample",
]
55 changes: 41 additions & 14 deletions project/algorithms/callbacks/samples_per_second.py
@@ -1,5 +1,5 @@
import time
from typing import Literal
from typing import Any, Literal

from lightning import LightningModule, Trainer
from torch import Tensor
@@ -11,11 +11,11 @@


class MeasureSamplesPerSecondCallback(Callback[BatchType, StepOutputType]):
def __init__(self):
def __init__(self, num_optimizers: int | None = None):
super().__init__()
self.last_step_times: dict[Literal["train", "val", "test"], float] = {}
self.last_update_time: dict[int, float | None] = {}
self.num_optimizers: int | None = None
self.num_optimizers: int | None = num_optimizers

@override
def on_shared_epoch_start(
@@ -56,19 +56,44 @@ def on_shared_batch_end(
now = time.perf_counter()
if phase in self.last_step_times:
elapsed = now - self.last_step_times[phase]
if is_sequence_of(batch, Tensor):
batch_size = batch[0].shape[0]
pl_module.log(
f"{phase}/samples_per_second",
batch_size / elapsed,
prog_bar=True,
on_step=True,
on_epoch=True,
sync_dist=True,
)
batch_size = self.get_num_samples(batch)
self.log(
f"{phase}/samples_per_second",
batch_size / elapsed,
module=pl_module,
trainer=trainer,
prog_bar=True,
on_step=True,
on_epoch=True,
sync_dist=True,
batch_size=batch_size,
)
# todo: support other kinds of batches
self.last_step_times[phase] = now

def log(
self,
name: str,
value: Any,
module: LightningModule | Any,
trainer: Trainer | Any,
**kwargs,
):
# Used to possibly customize how the values are logged (e.g. for non-LightningModules).
# By default, uses the LightningModule.log method.
return module.log(
name,
value,
**kwargs,
)

def get_num_samples(self, batch: BatchType) -> int:
if is_sequence_of(batch, Tensor):
return batch[0].shape[0]
raise NotImplementedError(
f"Don't know how many 'samples' there are in batch of type {type(batch)}"
)

@override
def on_before_optimizer_step(
self,
@@ -89,9 +114,11 @@ def on_before_optimizer_step(
key = "ups"
else:
key = f"optimizer_{opt_idx}/ups"
pl_module.log(
self.log(
key,
updates_per_second,
module=pl_module,
trainer=trainer,
prog_bar=False,
on_step=True,
)
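
The diff above replaces the hard-coded tensor-batch logic with two overridable hooks, get_num_samples and log. A minimal sketch of how a subclass could support a batch type the base class does not know about (the dict-shaped batch here is an assumed example, not a format used by the repository):

from torch import Tensor

from project.algorithms.callbacks.samples_per_second import (
    MeasureSamplesPerSecondCallback,
)


class DictBatchSamplesPerSecondCallback(MeasureSamplesPerSecondCallback):
    # Counts samples in batches shaped like {"x": Tensor, "y": Tensor}.
    def get_num_samples(self, batch) -> int:
        if isinstance(batch, dict) and isinstance(batch.get("x"), Tensor):
            return batch["x"].shape[0]
        # Fall back to the base implementation (sequences of tensors).
        return super().get_num_samples(batch)
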
36 changes: 6 additions & 30 deletions project/algorithms/jax_example.py
@@ -1,9 +1,9 @@
import dataclasses
import logging
import os
from collections.abc import Callable
from typing import Concatenate, Literal, ParamSpec, TypeVar
from typing import Literal

import chex
import flax.linen
import jax
import rich
@@ -21,8 +21,6 @@
from project.datamodules.image_classification.mnist import MNISTDataModule
from project.utils.typing_utils.protocols import ClassificationDataModule

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


def flatten(x: jax.Array) -> jax.Array:
return x.reshape((x.shape[0], -1))
@@ -58,8 +56,8 @@ class JaxFcNet(flax.linen.Module):
num_features: int = 256

@flax.linen.compact
def __call__(self, x: jax.Array):
x = flatten(x)
def __call__(self, x: jax.Array, forward_rng: chex.PRNGKey | None = None):
# x = flatten(x)
x = flax.linen.Dense(features=self.num_features)(x)
x = flax.linen.relu(x)
x = flax.linen.Dense(features=self.num_classes)(x)
@@ -89,6 +87,8 @@ def __init__(
hp: HParams = HParams(),
):
super().__init__()
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

self.datamodule = datamodule
self.hp = hp or self.HParams()

@@ -193,30 +193,6 @@ def to_channels_last(x: jax.Array) -> jax.Array:
return x.transpose(0, 2, 3, 1)


P = ParamSpec("P")
Out = TypeVar("Out")


def jit(
fn: Callable[P, Out],
) -> Callable[P, Out]:
"""Small type hint fix for jax's `jit` (preserves the signature of the callable)."""
return jax.jit(fn) # type: ignore


In = TypeVar("In")
Aux = TypeVar("Aux")


def value_and_grad(
fn: Callable[Concatenate[In, P], tuple[Out, Aux]],
argnums: Literal[0] = 0,
has_aux: Literal[True] = True,
) -> Callable[Concatenate[In, P], tuple[tuple[Out, Aux], In]]:
"""Small type hint fix for jax's `value_and_grad` (preserves the signature of the callable)."""
return jax.value_and_grad(fn, argnums=argnums, has_aux=has_aux) # type: ignore


def main():
logging.basicConfig(
level=logging.INFO, format="%(message)s", handlers=[rich.logging.RichHandler()]
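
The jit and value_and_grad wrappers deleted above only exist so that static type checkers keep seeing the wrapped function's signature; per the commit message they were moved to a shared Jax typing-utilities module rather than dropped. A small self-contained sketch of the same pattern and why it helps (not the repository's module):

from collections.abc import Callable
from typing import ParamSpec, TypeVar

import jax
import jax.numpy as jnp

P = ParamSpec("P")
Out = TypeVar("Out")


def jit(fn: Callable[P, Out]) -> Callable[P, Out]:
    # Typed wrapper around jax.jit that preserves the callable's signature.
    return jax.jit(fn)  # type: ignore[return-value]


@jit
def scaled_sum(x: jax.Array, scale: float = 2.0) -> jax.Array:
    return scale * x.sum()


# Type checkers still see (x: jax.Array, scale: float = 2.0) -> jax.Array here,
# rather than jax.jit's much looser signature.
out = scaled_sum(jnp.arange(4.0))
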
[remaining changed file diffs not shown]
