From 682cce6fa32f389891df6a606111cb1300fa64d0 Mon Sep 17 00:00:00 2001
From: Fabrice Normandin
Date: Fri, 11 Oct 2024 14:41:25 -0400
Subject: [PATCH] Add an RL example in Jax (#55)

* Add a Jax+RL example based on rejax.PPO
Signed-off-by: Fabrice Normandin

* Remove some of the unused code
Signed-off-by: Fabrice Normandin

* Move things around a bit
Signed-off-by: Fabrice Normandin

* Update version requirements for jax/torch
Signed-off-by: Fabrice Normandin

* Use xtils for cleaner Jit with annotations
Signed-off-by: Fabrice Normandin

* Save gif every epoch
Signed-off-by: Fabrice Normandin

* Fix rendering of classic-control gymnax envs
Signed-off-by: Fabrice Normandin

* Add a "pure jax" training loop option
Signed-off-by: Fabrice Normandin

* Fused training step in Lightning module
Signed-off-by: Fabrice Normandin

* Works without hash warnings now!
Signed-off-by: Fabrice Normandin

* Reorganize the code a bit
Signed-off-by: Fabrice Normandin

* Use vmap to train multiple agents in parallel
Signed-off-by: Fabrice Normandin

* Add a jax analogue to lightning.Trainer
Signed-off-by: Fabrice Normandin

* Add the equivalent of lightning.Callback for jax
Signed-off-by: Fabrice Normandin

* Log hyper-parameters
Signed-off-by: Fabrice Normandin

* Progress bar almost works
Signed-off-by: Fabrice Normandin

* Managed to get the progress bar to work!
Signed-off-by: Fabrice Normandin

* Move the trainer + callback to a different file
Signed-off-by: Fabrice Normandin

* Make stuff generic (not tied to PPOLearner)
Signed-off-by: Fabrice Normandin

* Update gymnax to improve rendering performance
Signed-off-by: Fabrice Normandin

* Add configs, tweak experiment/main
Signed-off-by: Fabrice Normandin

* wip: fixing issues in experiment.py
Signed-off-by: Fabrice Normandin

* Fix config now that network is optional
Signed-off-by: Fabrice Normandin

* Fix issue with progress bar callback!
Signed-off-by: Fabrice Normandin

* Fix duplicated code in main.py
Signed-off-by: Fabrice Normandin

* Move tests / Lightning wrapper to test file
Signed-off-by: Fabrice Normandin

* Rename things, add docstring to JaxTrainer
Signed-off-by: Fabrice Normandin

* Fix links in docstrings of JaxTrainer / JaxModule
Signed-off-by: Fabrice Normandin

* Tweak the docs of JaxModule/JaxTrainer
Signed-off-by: Fabrice Normandin

* Use regression fixtures in test
Signed-off-by: Fabrice Normandin

* Fix the ref in the JaxTrainer docstring
Signed-off-by: Fabrice Normandin

* Fix small errors that break CI
Signed-off-by: Fabrice Normandin

* Fix bug in test_rejax
Signed-off-by: Fabrice Normandin

* "fix" config schema generation errors
Signed-off-by: Fabrice Normandin

* Fix test_rejax function
Signed-off-by: Fabrice Normandin

* Test the `train` method to replicate rejax.PPO
Signed-off-by: Fabrice Normandin

* Move Jax typing utils to a new module
Signed-off-by: Fabrice Normandin

* Fix default param causing preallocation of GPU mem
Signed-off-by: Fabrice Normandin

* Add comments in conftest.py
Signed-off-by: Fabrice Normandin

* Fix test for rejax, add more todos in conftest.py
Signed-off-by: Fabrice Normandin

* Fix bug in lightning wrapper for rejax.PPO
Signed-off-by: Fabrice Normandin

* Fix issue in test_config from conftest change
Signed-off-by: Fabrice Normandin

* (temp) make the tests run in unit test runs
Signed-off-by: Fabrice Normandin

* Tweaks to the jax typing utils
Signed-off-by: Fabrice Normandin

* Move the JaxTrainer to a new "trainers" dir
Signed-off-by: Fabrice Normandin

* Simplify docs in `jax_trainer.py`
Signed-off-by: Fabrice Normandin

* Move things around, add pytest.mark.slow marks
Signed-off-by: Fabrice Normandin

* Fix bug with config target type inference
Signed-off-by: Fabrice Normandin

* Move things around in jax_rl_example_test.py
Signed-off-by: Fabrice Normandin

* Add some docstrings
Signed-off-by: Fabrice Normandin

* Re-organize tests, update regression files
Signed-off-by: Fabrice Normandin

* Fix the missing indexing in test for equivalence
Signed-off-by: Fabrice Normandin

* Don't use file_regression with gifs
Signed-off-by: Fabrice Normandin

* Fix issue with jax_rl_example_test.test_lightning
Signed-off-by: Fabrice Normandin

---------

Signed-off-by: Fabrice Normandin
---
 .regression_files/.gitignore                  |   3 +
 .../test_lightning/123_Pendulum_v1_15.yaml    |  12 +
 .../test_ours/123_Pendulum_v1.yaml            |  16 +
 .../123_Pendulum_v1.yaml                      |  16 +
 .../test_rejax/123_Pendulum_v1.yaml           |  16 +
 docs/examples/jax_rl_example.md               |  11 +
 docs/generate_reference_docs.py               |   2 +
 project/algorithms/__init__.py                |   2 +
 .../callbacks/samples_per_second.py           |  55 +-
 project/algorithms/jax_example.py             |  36 +-
 project/algorithms/jax_rl_example.py          | 815 ++++++++++++++++++
 project/algorithms/jax_rl_example_test.py     | 724 ++++++++++++++++
 project/configs/algorithm/jax_rl_example.yaml |  35 +
 project/configs/config.py                     |   9 +-
 .../configs/experiment/jax_rl_example.yaml    |  16 +
 .../experiment/local_sweep_example.yaml       |   2 +-
 project/configs/trainer/jax.yaml              |  16 +
 project/conftest.py                           |  46 +-
 project/main.py                               |  11 +-
 project/trainers/__init__.py                  |   8 +
 project/trainers/jax_trainer.py               | 463 ++++++++++
 project/utils/auto_schema.py                  |   5 +-
 project/utils/hydra_config_utils.py           |   7 +-
 .../utils/typing_utils/jax_typing_utils.py    | 141 +++
 pyproject.toml                                |  27 +-
 requirements-dev.lock                         | 256 +++++-
 requirements.lock                             | 248 +++++-
 27 files changed, 2843 insertions(+), 155 deletions(-)
 create mode 100644 .regression_files/.gitignore
 create mode 100644
.regression_files/project/algorithms/jax_rl_example_test/test_lightning/123_Pendulum_v1_15.yaml create mode 100644 .regression_files/project/algorithms/jax_rl_example_test/test_ours/123_Pendulum_v1.yaml create mode 100644 .regression_files/project/algorithms/jax_rl_example_test/test_ours_with_trainer/123_Pendulum_v1.yaml create mode 100644 .regression_files/project/algorithms/jax_rl_example_test/test_rejax/123_Pendulum_v1.yaml create mode 100644 docs/examples/jax_rl_example.md create mode 100644 project/algorithms/jax_rl_example.py create mode 100644 project/algorithms/jax_rl_example_test.py create mode 100644 project/configs/algorithm/jax_rl_example.yaml create mode 100644 project/configs/experiment/jax_rl_example.yaml create mode 100644 project/configs/trainer/jax.yaml create mode 100644 project/trainers/__init__.py create mode 100644 project/trainers/jax_trainer.py create mode 100644 project/utils/typing_utils/jax_typing_utils.py diff --git a/.regression_files/.gitignore b/.regression_files/.gitignore new file mode 100644 index 00000000..6a995205 --- /dev/null +++ b/.regression_files/.gitignore @@ -0,0 +1,3 @@ +*.gif +# Ignore tensor regression files. +*.npz diff --git a/.regression_files/project/algorithms/jax_rl_example_test/test_lightning/123_Pendulum_v1_15.yaml b/.regression_files/project/algorithms/jax_rl_example_test/test_lightning/123_Pendulum_v1_15.yaml new file mode 100644 index 00000000..e70ed343 --- /dev/null +++ b/.regression_files/project/algorithms/jax_rl_example_test/test_lightning/123_Pendulum_v1_15.yaml @@ -0,0 +1,12 @@ +val/episode_lengths: + max: '2.e+02' + mean: '2.e+02' + min: '2.e+02' + shape: [] + sum: '2.e+02' +val/rewards: + max: '-1.222e+03' + mean: '-1.222e+03' + min: '-1.222e+03' + shape: [] + sum: '-1.222e+03' diff --git a/.regression_files/project/algorithms/jax_rl_example_test/test_ours/123_Pendulum_v1.yaml b/.regression_files/project/algorithms/jax_rl_example_test/test_ours/123_Pendulum_v1.yaml new file mode 100644 index 00000000..d83973a5 --- /dev/null +++ b/.regression_files/project/algorithms/jax_rl_example_test/test_ours/123_Pendulum_v1.yaml @@ -0,0 +1,16 @@ +cumulative_reward: + max: '-6.495e+02' + mean: '-1.229e+03' + min: '-1.878e+03' + shape: + - 76 + - 128 + sum: '-1.196e+07' +episode_length: + max: 200 + mean: '2.e+02' + min: 200 + shape: + - 76 + - 128 + sum: 1945600 diff --git a/.regression_files/project/algorithms/jax_rl_example_test/test_ours_with_trainer/123_Pendulum_v1.yaml b/.regression_files/project/algorithms/jax_rl_example_test/test_ours_with_trainer/123_Pendulum_v1.yaml new file mode 100644 index 00000000..d83973a5 --- /dev/null +++ b/.regression_files/project/algorithms/jax_rl_example_test/test_ours_with_trainer/123_Pendulum_v1.yaml @@ -0,0 +1,16 @@ +cumulative_reward: + max: '-6.495e+02' + mean: '-1.229e+03' + min: '-1.878e+03' + shape: + - 76 + - 128 + sum: '-1.196e+07' +episode_length: + max: 200 + mean: '2.e+02' + min: 200 + shape: + - 76 + - 128 + sum: 1945600 diff --git a/.regression_files/project/algorithms/jax_rl_example_test/test_rejax/123_Pendulum_v1.yaml b/.regression_files/project/algorithms/jax_rl_example_test/test_rejax/123_Pendulum_v1.yaml new file mode 100644 index 00000000..8b29ccb9 --- /dev/null +++ b/.regression_files/project/algorithms/jax_rl_example_test/test_rejax/123_Pendulum_v1.yaml @@ -0,0 +1,16 @@ +cumulative_reward: + max: '-4.319e-01' + mean: '-5.755e+02' + min: '-1.872e+03' + shape: + - 76 + - 128 + sum: '-5.599e+06' +episode_length: + max: 200 + mean: '2.e+02' + min: 200 + shape: + - 76 + - 128 + sum: 
1945600 diff --git a/docs/examples/jax_rl_example.md b/docs/examples/jax_rl_example.md new file mode 100644 index 00000000..7d130243 --- /dev/null +++ b/docs/examples/jax_rl_example.md @@ -0,0 +1,11 @@ +--- +additional_python_references: + - project.algorithms.jax_rl_example + - project.trainers.jax_trainer +--- + +# Reinforcement Learning (Jax) + +## JaxTrainer + +The `JaxTrainer` is diff --git a/docs/generate_reference_docs.py b/docs/generate_reference_docs.py index 73d4aa6a..6f6765bf 100644 --- a/docs/generate_reference_docs.py +++ b/docs/generate_reference_docs.py @@ -2,9 +2,11 @@ # based on https://github.com/mkdocstrings/mkdocstrings/blob/5802b1ef5ad9bf6077974f777bd55f32ce2bc219/docs/gen_doc_stubs.py#L25 +import os from logging import getLogger as get_logger from pathlib import Path +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" logger = get_logger(__name__) diff --git a/project/algorithms/__init__.py b/project/algorithms/__init__.py index 0185cb15..c92dca8a 100644 --- a/project/algorithms/__init__.py +++ b/project/algorithms/__init__.py @@ -1,6 +1,7 @@ from .example import ExampleAlgorithm from .hf_example import HFExample from .jax_example import JaxExample +from .jax_rl_example import JaxRLExample from .no_op import NoOp __all__ = [ @@ -8,4 +9,5 @@ "JaxExample", "NoOp", "HFExample", + "JaxRLExample", ] diff --git a/project/algorithms/callbacks/samples_per_second.py b/project/algorithms/callbacks/samples_per_second.py index 853c6b06..523956b3 100644 --- a/project/algorithms/callbacks/samples_per_second.py +++ b/project/algorithms/callbacks/samples_per_second.py @@ -1,5 +1,5 @@ import time -from typing import Literal +from typing import Any, Literal from lightning import LightningModule, Trainer from torch import Tensor @@ -11,11 +11,11 @@ class MeasureSamplesPerSecondCallback(Callback[BatchType, StepOutputType]): - def __init__(self): + def __init__(self, num_optimizers: int | None = None): super().__init__() self.last_step_times: dict[Literal["train", "val", "test"], float] = {} self.last_update_time: dict[int, float | None] = {} - self.num_optimizers: int | None = None + self.num_optimizers: int | None = num_optimizers @override def on_shared_epoch_start( @@ -56,19 +56,44 @@ def on_shared_batch_end( now = time.perf_counter() if phase in self.last_step_times: elapsed = now - self.last_step_times[phase] - if is_sequence_of(batch, Tensor): - batch_size = batch[0].shape[0] - pl_module.log( - f"{phase}/samples_per_second", - batch_size / elapsed, - prog_bar=True, - on_step=True, - on_epoch=True, - sync_dist=True, - ) + batch_size = self.get_num_samples(batch) + self.log( + f"{phase}/samples_per_second", + batch_size / elapsed, + module=pl_module, + trainer=trainer, + prog_bar=True, + on_step=True, + on_epoch=True, + sync_dist=True, + batch_size=batch_size, + ) # todo: support other kinds of batches self.last_step_times[phase] = now + def log( + self, + name: str, + value: Any, + module: LightningModule | Any, + trainer: Trainer | Any, + **kwargs, + ): + # Used to possibly customize how the values are logged (e.g. for non-LightningModules). + # By default, uses the LightningModule.log method. 
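# ---------------------------------------------------------------------------
# Illustration (not part of the patch): what the new `get_num_samples` / `log`
# indirection above buys us. A subclass can redefine how a "sample" is counted
# without touching the timing logic (the `RlThroughputCallback` in
# jax_rl_example_test.py does exactly this for batches of trajectories).
# `TrajectoryThroughputCallback` below is a hypothetical name for this sketch.
import numpy as np


class TrajectoryThroughputCallback(MeasureSamplesPerSecondCallback):
    def get_num_samples(self, batch) -> int:
        # One "sample" per environment transition: num_episodes * num_steps.
        return int(np.prod(batch.trajectories.obs.shape[:2]))
# ---------------------------------------------------------------------------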
+ return module.log( + name, + value, + **kwargs, + ) + + def get_num_samples(self, batch: BatchType) -> int: + if is_sequence_of(batch, Tensor): + return batch[0].shape[0] + raise NotImplementedError( + f"Don't know how many 'samples' there are in batch of type {type(batch)}" + ) + @override def on_before_optimizer_step( self, @@ -89,9 +114,11 @@ def on_before_optimizer_step( key = "ups" else: key = f"optimizer_{opt_idx}/ups" - pl_module.log( + self.log( key, updates_per_second, + module=pl_module, + trainer=trainer, prog_bar=False, on_step=True, ) diff --git a/project/algorithms/jax_example.py b/project/algorithms/jax_example.py index aab9e4fe..6817e4d2 100644 --- a/project/algorithms/jax_example.py +++ b/project/algorithms/jax_example.py @@ -1,9 +1,9 @@ import dataclasses import logging import os -from collections.abc import Callable -from typing import Concatenate, Literal, ParamSpec, TypeVar +from typing import Literal +import chex import flax.linen import jax import rich @@ -21,8 +21,6 @@ from project.datamodules.image_classification.mnist import MNISTDataModule from project.utils.typing_utils.protocols import ClassificationDataModule -os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" - def flatten(x: jax.Array) -> jax.Array: return x.reshape((x.shape[0], -1)) @@ -58,8 +56,8 @@ class JaxFcNet(flax.linen.Module): num_features: int = 256 @flax.linen.compact - def __call__(self, x: jax.Array): - x = flatten(x) + def __call__(self, x: jax.Array, forward_rng: chex.PRNGKey | None = None): + # x = flatten(x) x = flax.linen.Dense(features=self.num_features)(x) x = flax.linen.relu(x) x = flax.linen.Dense(features=self.num_classes)(x) @@ -89,6 +87,8 @@ def __init__( hp: HParams = HParams(), ): super().__init__() + os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + self.datamodule = datamodule self.hp = hp or self.HParams() @@ -193,30 +193,6 @@ def to_channels_last(x: jax.Array) -> jax.Array: return x.transpose(0, 2, 3, 1) -P = ParamSpec("P") -Out = TypeVar("Out") - - -def jit( - fn: Callable[P, Out], -) -> Callable[P, Out]: - """Small type hint fix for jax's `jit` (preserves the signature of the callable).""" - return jax.jit(fn) # type: ignore - - -In = TypeVar("In") -Aux = TypeVar("Aux") - - -def value_and_grad( - fn: Callable[Concatenate[In, P], tuple[Out, Aux]], - argnums: Literal[0] = 0, - has_aux: Literal[True] = True, -) -> Callable[Concatenate[In, P], tuple[tuple[Out, Aux], In]]: - """Small type hint fix for jax's `value_and_grad` (preserves the signature of the callable).""" - return jax.value_and_grad(fn, argnums=argnums, has_aux=has_aux) # type: ignore - - def main(): logging.basicConfig( level=logging.INFO, format="%(message)s", handlers=[rich.logging.RichHandler()] diff --git a/project/algorithms/jax_rl_example.py b/project/algorithms/jax_rl_example.py new file mode 100644 index 00000000..900ffb15 --- /dev/null +++ b/project/algorithms/jax_rl_example.py @@ -0,0 +1,815 @@ +from __future__ import annotations + +import contextlib +import functools +from collections.abc import Callable, Sequence +from logging import getLogger as get_logger +from pathlib import Path +from typing import Any, Generic, TypedDict + +import chex +import flax.core +import flax.linen +import flax.struct +import gymnax +import gymnax.environments.spaces +import gymnax.experimental.rollout +import jax +import jax.experimental +import jax.numpy as jnp +import numpy as np +import optax +from flax.training.train_state import TrainState +from flax.typing import FrozenVariableDict +from 
gymnax.environments.environment import Environment +from gymnax.visualize.visualizer import Visualizer +from matplotlib import pyplot as plt +from rejax.algos.mixins import RMSState +from rejax.evaluate import evaluate +from rejax.networks import DiscretePolicy, GaussianPolicy, VNetwork +from typing_extensions import TypeVar +from xtils.jitpp import Static + +from project.trainers.jax_trainer import JaxCallback, JaxModule, JaxTrainer +from project.utils.typing_utils.jax_typing_utils import field, jit + +logger = get_logger(__name__) +# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + +TEnvParams = TypeVar("TEnvParams", bound=gymnax.EnvParams, default=gymnax.EnvParams) +"""Type variable for the env params (`gymnax.EnvParams`).""" + +TEnvState = TypeVar("TEnvState", bound=gymnax.EnvState, default=gymnax.EnvState) + + +class Trajectory(flax.struct.PyTreeNode): + """A sequence of interactions between an agent and an environment.""" + + obs: jax.Array + action: jax.Array + log_prob: jax.Array + reward: jax.Array + value: jax.Array + done: jax.Array + + +class TrajectoryWithLastObs(flax.struct.PyTreeNode): + """Trajectory with the last observation and whether the last step is the end of an episode.""" + + trajectories: Trajectory + last_done: jax.Array + last_obs: jax.Array + + +class AdvantageMinibatch(flax.struct.PyTreeNode): + """Annotated trajectories with advantages and targets for the critic.""" + + trajectories: Trajectory + advantages: chex.Array + targets: chex.Array + + +class TrajectoryCollectionState(Generic[TEnvState], flax.struct.PyTreeNode): + """Struct containing the state related to the collection of data from the environment.""" + + last_obs: jax.Array + env_state: TEnvState + rms_state: RMSState + last_done: jax.Array + global_step: int + rng: chex.PRNGKey + + +class PPOState(Generic[TEnvState], flax.struct.PyTreeNode): + """Contains all the state of the `JaxRLExample` algorithm.""" + + actor_ts: TrainState + critic_ts: TrainState + rng: chex.PRNGKey + data_collection_state: TrajectoryCollectionState[TEnvState] + + +class PPOHParams(flax.struct.PyTreeNode): + """Hyper-parameters for this PPO example. + + These are taken from `rejax.PPO` algorithm class. + """ + + num_epochs: int = field(default=8) + num_envs: int = field(default=64) # overwrite default + num_steps: int = field(default=64) + num_minibatches: int = field(default=16) + + eval_freq: int = field(default=4_096) + + normalize_observations: bool = field(default=False) + total_timesteps: int = field(default=131_072) + debug: bool = field(default=False) + + learning_rate: chex.Scalar = 0.0003 + gamma: chex.Scalar = 0.99 + max_grad_norm: chex.Scalar = jnp.inf + + gae_lambda: chex.Scalar = 0.95 + clip_eps: chex.Scalar = 0.2 + vf_coef: chex.Scalar = 0.5 + ent_coef: chex.Scalar = 0.01 + + # IDEA: Split up the RNGs for different parts? 
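# ---------------------------------------------------------------------------
# Illustration (not part of the patch): `PPOHParams`, `PPOState` and the other
# structs above are `flax.struct.PyTreeNode`s: frozen dataclasses registered
# as jax pytrees, so they can flow through `jax.jit`, `jax.vmap` and
# `jax.lax.scan`, and are "updated" by building a copy with `.replace(...)`.
# `Counter` below is a toy stand-in, not part of the example:
import flax.struct
import jax
import jax.numpy as jnp


class Counter(flax.struct.PyTreeNode):
    count: jax.Array
    step_size: int = flax.struct.field(pytree_node=False, default=1)


@jax.jit
def tick(c: Counter) -> Counter:
    # `.replace` returns a new struct; `c` itself is unchanged (it's frozen).
    return c.replace(count=c.count + c.step_size)


assert tick(Counter(count=jnp.zeros(()))).count == 1
# ---------------------------------------------------------------------------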
+    # rng: chex.PRNGKey = flax.struct.field(pytree_node=True, default=jax.random.key(0))
+    # networks_rng: chex.PRNGKey = flax.struct.field(pytree_node=True, default=jax.random.key(1))
+    # env_rng: chex.PRNGKey = flax.struct.field(pytree_node=True, default=jax.random.key(2))
+
+
+class _AgentKwargs(TypedDict):
+    activation: str
+    hidden_layer_sizes: Sequence[int]
+
+
+class _NetworkConfig(TypedDict):
+    agent_kwargs: _AgentKwargs
+
+
+class TrainStepMetrics(flax.struct.PyTreeNode):
+    actor_losses: jax.Array
+    critic_losses: jax.Array
+
+
+class EvalMetrics(flax.struct.PyTreeNode):
+    episode_length: jax.Array
+    cumulative_reward: jax.Array
+
+
+class JaxRLExample(
+    flax.struct.PyTreeNode,
+    JaxModule[PPOState[TEnvState], TrajectoryWithLastObs, EvalMetrics],
+    Generic[TEnvState, TEnvParams],
+):
+    """Example of an RL algorithm written in Jax: PPO, based on `rejax.PPO`.
+
+    ## Differences w.r.t. rejax.PPO:
+
+    - The state / hparams are split into different, fully-typed structs:
+        - The algorithm state is in a typed `PPOState` struct (vs an untyped,
+          dynamically-generated struct in rejax).
+        - The hyper-parameters are in a typed `PPOHParams` struct.
+        - The state variables related to the collection of data from the environment are in a
+          `TrajectoryCollectionState`, instead of everything being bunched up together.
+            - This makes it easier to call the `collect_episodes` function with just what it needs.
+    - The seeds for the networks and the environment data collection are separated.
+
+    The logic is exactly the same: the losses / updates are computed in the exact same way.
+    """
+
+    env: Environment[TEnvState, TEnvParams] = flax.struct.field(pytree_node=False)
+    env_params: TEnvParams
+    actor: flax.linen.Module = flax.struct.field(pytree_node=False)
+    critic: flax.linen.Module = flax.struct.field(pytree_node=False)
+    hp: PPOHParams
+
+    @classmethod
+    def create(
+        cls,
+        env_id: str | None = None,
+        env: Environment[TEnvState, TEnvParams] | None = None,
+        env_params: TEnvParams | None = None,
+        hp: PPOHParams | None = None,
+    ) -> JaxRLExample[TEnvState, TEnvParams]:
+        from brax.envs import _envs as brax_envs
+        from rejax.compat.brax2gymnax import create_brax
+
+        # env_params: gymnax.EnvParams
+        if env_id is None:
+            assert env is not None
+            env_params = env_params or env.default_params  # type: ignore
+        elif env_id in brax_envs:
+            env, env_params = create_brax(  # type: ignore
+                env_id,
+                episode_length=1000,
+                action_repeat=1,
+                auto_reset=True,
+                batch_size=None,
+                backend="generalized",
+            )
+        elif isinstance(env_id, str):
+            env, env_params = gymnax.make(env_id=env_id)  # type: ignore
+        else:
+            raise NotImplementedError(env_id)
+
+        assert env is not None
+        assert env_params is not None
+        return cls(
+            env=env,
+            env_params=env_params,
+            actor=cls.create_actor(env, env_params),
+            critic=cls.create_critic(),
+            hp=hp or PPOHParams(),
+        )
+
+    @classmethod
+    def create_networks(
+        cls,
+        env: Environment[gymnax.EnvState, TEnvParams],
+        env_params: TEnvParams,
+        config: _NetworkConfig,
+    ):
+        # Equivalent to:
+        # return rejax.PPO.create_agent(config, env, env_params)
+        return {
+            "actor": cls.create_actor(env, env_params, **config["agent_kwargs"]),
+            "critic": cls.create_critic(**config["agent_kwargs"]),
+        }
+
+    _TEnvParams = TypeVar("_TEnvParams", bound=gymnax.EnvParams, covariant=True)
+    _TEnvState = TypeVar("_TEnvState", bound=gymnax.EnvState, covariant=True)
+
+    @classmethod
+    def create_actor(
+        cls,
+        env: Environment[_TEnvState, _TEnvParams],
+        env_params: _TEnvParams,
+        activation: str | Callable[[jax.Array], jax.Array] = "swish",
+        hidden_layer_sizes: Sequence[int] = (64, 64),
+        **actor_kwargs,
+    ) -> DiscretePolicy | GaussianPolicy:
+        activation_fn: Callable[[jax.Array], jax.Array] = (
+            getattr(flax.linen, activation) if not callable(activation) else activation
+        )
+        hidden_layer_sizes = tuple(hidden_layer_sizes)
+        action_space = env.action_space(env_params)
+
+        if isinstance(action_space, gymnax.environments.spaces.Discrete):
+            return DiscretePolicy(
+                action_space.n,
+                activation=activation_fn,
+                hidden_layer_sizes=hidden_layer_sizes,
+                **actor_kwargs,
+            )
+        assert isinstance(action_space, gymnax.environments.spaces.Box)
+        return GaussianPolicy(
+            np.prod(action_space.shape),
+            (action_space.low, action_space.high),  # type: ignore
+            activation=activation_fn,
+            hidden_layer_sizes=hidden_layer_sizes,
+            **actor_kwargs,
+        )
+
+    @classmethod
+    def create_critic(
+        cls,
+        activation: str | Callable[[jax.Array], jax.Array] = "swish",
+        hidden_layer_sizes: Sequence[int] = (64, 64),
+        **critic_kwargs,
+    ) -> VNetwork:
+        activation_fn: Callable[[jax.Array], jax.Array] = (
+            getattr(flax.linen, activation) if isinstance(activation, str) else activation
+        )
+        hidden_layer_sizes = tuple(hidden_layer_sizes)
+        return VNetwork(
+            hidden_layer_sizes=hidden_layer_sizes, activation=activation_fn, **critic_kwargs
+        )
+
+    def init_train_state(self, rng: chex.PRNGKey) -> PPOState[TEnvState]:
+        rng, networks_rng, env_rng = jax.random.split(rng, 3)
+
+        rng_actor, rng_critic = jax.random.split(networks_rng, 2)
+
+        obs_ph = jnp.empty([1, *self.env.observation_space(self.env_params).shape])
+
+        actor_params = self.actor.init(rng_actor, obs_ph, rng_actor)
+        critic_params = self.critic.init(rng_critic, obs_ph)
+
+        tx = optax.adam(learning_rate=self.hp.learning_rate)
+        # TODO: Why isn't the `apply_fn` set in rejax?
+        actor_ts = TrainState.create(apply_fn=self.actor.apply, params=actor_params, tx=tx)
+        critic_ts = TrainState.create(apply_fn=self.critic.apply, params=critic_params, tx=tx)
+
+        env_rng, reset_rng = jax.random.split(env_rng)
+        obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
+            jax.random.split(reset_rng, self.hp.num_envs), self.env_params
+        )
+
+        collection_state = TrajectoryCollectionState(
+            last_obs=obs,
+            rms_state=RMSState.create(shape=obs_ph.shape),
+            global_step=0,
+            env_state=env_state,
+            last_done=jnp.zeros(self.hp.num_envs, dtype=bool),
+            rng=env_rng,
+        )
+
+        return PPOState(
+            actor_ts=actor_ts,
+            critic_ts=critic_ts,
+            rng=rng,
+            data_collection_state=collection_state,
+        )
+
+    # @jit
+    def training_step(self, batch_idx: int, ts: PPOState[TEnvState], batch: TrajectoryWithLastObs):
+        """Training step in pure jax."""
+        trajectories = batch
+
+        ts, (actor_losses, critic_losses) = jax.lax.scan(
+            functools.partial(self.ppo_update_epoch, trajectories=trajectories),
+            init=ts,
+            xs=jnp.arange(self.hp.num_epochs),  # type: ignore
+            length=self.hp.num_epochs,
+        )
+        # todo: perhaps we could have a callback that updates a progress bar?
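# ---------------------------------------------------------------------------
# Illustration (not part of the patch): the `jax.lax.scan` pattern that
# `training_step` uses above. `scan` threads a carry (here the train state)
# through `num_epochs` iterations inside jit and stacks each iteration's
# output (here the losses), replacing a Python for-loop:
import jax
import jax.numpy as jnp


def _one_epoch(carry: jax.Array, epoch_index: jax.Array):
    new_carry = carry + 1.0  # stand-in for one `ppo_update_epoch`
    loss = 1.0 / (epoch_index + 1.0)  # stand-in for that epoch's loss
    return new_carry, loss


final_carry, losses = jax.lax.scan(_one_epoch, init=jnp.float32(0.0), xs=jnp.arange(8.0))
assert final_carry == 8.0 and losses.shape == (8,)
# ---------------------------------------------------------------------------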
+ # jax.debug.print("actor_losses {}: {}", iteration, actor_losses.mean()) + # jax.debug.print("critic_losses {}: {}", iteration, critic_losses.mean()) + + return ts, TrainStepMetrics(actor_losses=actor_losses, critic_losses=critic_losses) + + # @jit + def ppo_update_epoch( + self, ts: PPOState[TEnvState], epoch_index: int, trajectories: TrajectoryWithLastObs + ): + minibatch_rng = jax.random.fold_in(ts.rng, epoch_index) + + last_val = self.critic.apply(ts.critic_ts.params, ts.data_collection_state.last_obs) + assert isinstance(last_val, jax.Array) + last_val = jnp.where(ts.data_collection_state.last_done, 0, last_val) + advantages, targets = calculate_gae( + trajectories, last_val, gamma=self.hp.gamma, gae_lambda=self.hp.gae_lambda + ) + batch = AdvantageMinibatch(trajectories.trajectories, advantages, targets) + minibatches = shuffle_and_split( + batch, minibatch_rng, num_minibatches=self.hp.num_minibatches + ) + + # shuffle the data and split it into minibatches + + num_steps = self.hp.num_steps + num_envs = self.hp.num_envs + num_minibatches = self.hp.num_minibatches + assert (num_envs * num_steps) % num_minibatches == 0 + minibatches = shuffle_and_split( + batch, + minibatch_rng, + num_minibatches=num_minibatches, + ) + return jax.lax.scan(self.ppo_update, ts, minibatches, length=self.hp.num_minibatches) + + # @jit + def ppo_update(self, ts: PPOState[TEnvState], batch: AdvantageMinibatch): + actor_loss, actor_grads = jax.value_and_grad(actor_loss_fn)( + ts.actor_ts.params, + actor=self.actor, + batch=batch, + clip_eps=self.hp.clip_eps, + ent_coef=self.hp.ent_coef, + ) + assert isinstance(actor_loss, jax.Array) + critic_loss, critic_grads = jax.value_and_grad(critic_loss_fn)( + ts.critic_ts.params, + critic=self.critic, + batch=batch, + clip_eps=self.hp.clip_eps, + vf_coef=self.hp.vf_coef, + ) + assert isinstance(critic_loss, jax.Array) + + # TODO: to log the loss here? 
+ actor_ts = ts.actor_ts.apply_gradients(grads=actor_grads) + critic_ts = ts.critic_ts.apply_gradients(grads=critic_grads) + + return ts.replace(actor_ts=actor_ts, critic_ts=critic_ts), (actor_loss, critic_loss) + + def eval_callback( + self, ts: PPOState[TEnvState], rng: chex.PRNGKey | None = None + ) -> EvalMetrics: + if rng is None: + rng = ts.rng + actor = make_actor(ts=ts, hp=self.hp) + max_steps = self.env_params.max_steps_in_episode + ep_lengths, cum_rewards = evaluate( + actor, ts.rng, self.env, self.env_params, 128, max_steps + ) + return EvalMetrics(episode_length=ep_lengths, cumulative_reward=cum_rewards) + + def get_batch( + self, ts: PPOState[TEnvState], batch_idx: int + ) -> tuple[PPOState[TEnvState], TrajectoryWithLastObs]: + data_collection_state, trajectories = self.collect_trajectories( + ts.data_collection_state, + actor_params=ts.actor_ts.params, + critic_params=ts.critic_ts.params, + ) + ts = ts.replace(data_collection_state=data_collection_state) + return ts, trajectories + + # @jit + def collect_trajectories( + self, + collection_state: TrajectoryCollectionState[TEnvState], + actor_params: FrozenVariableDict, + critic_params: FrozenVariableDict, + ): + env_step_fn = functools.partial( + self.env_step, + # env=self.env, + # env_params=self.env_params, + # actor=self.actor, + # critic=self.critic, + # num_envs=self.hp.num_envs, + actor_params=actor_params, + critic_params=critic_params, + # discrete=self.discrete, + # normalize_observations=self.hp.normalize_observations, + ) + collection_state, trajectories = jax.lax.scan( + env_step_fn, + collection_state, + xs=jnp.arange(self.hp.num_steps), + length=self.hp.num_steps, + ) + trajectories_with_last = TrajectoryWithLastObs( + trajectories=trajectories, + last_done=collection_state.last_done, + last_obs=collection_state.last_obs, + ) + return collection_state, trajectories_with_last + + # @jit + def env_step( + self, + collection_state: TrajectoryCollectionState[TEnvState], + step_index: jax.Array, + actor_params: FrozenVariableDict, + critic_params: FrozenVariableDict, + ): + # Get keys for sampling action and stepping environment + # doing it this way to try to get *exactly* the same rngs as in rejax.PPO. 
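# ---------------------------------------------------------------------------
# Illustration (not part of the patch): the PRNG-key discipline used in
# `env_step` below. Jax keys are split, never reused: the collection state
# keeps one key, and each consumer (the action sampling, and each of the
# `num_envs` vectorized env steps) gets its own fresh subkey, so data
# collection replays exactly for a given seed.
import jax

_rng = jax.random.key(0)
_rng, _rng_action, _rng_steps = jax.random.split(_rng, 3)
_step_keys = jax.random.split(_rng_steps, 64)  # one key per parallel env
assert _step_keys.shape == (64,)
# ---------------------------------------------------------------------------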
+ rng, new_rngs = jax.random.split(collection_state.rng, 2) + rng_steps, rng_action = jax.random.split(new_rngs, 2) + rng_steps = jax.random.split(rng_steps, self.hp.num_envs) + + # Sample action + unclipped_action, log_prob = self.actor.apply( + actor_params, collection_state.last_obs, rng_action, method="action_log_prob" + ) + assert isinstance(log_prob, jax.Array) + value = self.critic.apply(critic_params, collection_state.last_obs) + assert isinstance(value, jax.Array) + + # Clip action + if self.discrete: + action = unclipped_action + else: + low = self.env.action_space(self.env_params).low + high = self.env.action_space(self.env_params).high + action = jnp.clip(unclipped_action, low, high) + + # Step environment + next_obs, env_state, reward, done, _ = jax.vmap(self.env.step, in_axes=(0, 0, 0, None))( + rng_steps, + collection_state.env_state, + action, + self.env_params, + ) + + if self.hp.normalize_observations: + # rms_state, next_obs = learner.update_and_normalize(collection_state.rms_state, next_obs) + rms_state = _update_rms(collection_state.rms_state, obs=next_obs, batched=True) + next_obs = _normalize_obs(rms_state, obs=next_obs) + + collection_state = collection_state.replace(rms_state=rms_state) + + # Return updated runner state and transition + transition = Trajectory( + collection_state.last_obs, unclipped_action, log_prob, reward, value, done + ) + collection_state = collection_state.replace( + env_state=env_state, + last_obs=next_obs, + last_done=done, + global_step=collection_state.global_step + self.hp.num_envs, + rng=rng, + ) + return collection_state, transition + + @property + def discrete(self) -> bool: + return isinstance( + self.env.action_space(self.env_params), gymnax.environments.spaces.Discrete + ) + + def visualize(self, ts: PPOState, gif_path: str | Path, eval_rng: chex.PRNGKey | None = None): + actor = make_actor(ts=ts, hp=self.hp) + render_episode( + actor=actor, + env=self.env, + env_params=self.env_params, + gif_path=Path(gif_path), + rng=eval_rng if eval_rng is not None else ts.rng, + ) + + ## These here aren't currently used. They are here to mirror rejax.PPO where the training loop + # is in the algorithm. + + @functools.partial(jit, static_argnames=["skip_initial_evaluation"]) + def train( + self, + rng: jax.Array, + train_state: PPOState[TEnvState] | None = None, + skip_initial_evaluation: bool = False, + ) -> tuple[PPOState[TEnvState], EvalMetrics]: + """Full training loop in pure jax (a lot faster than when using pytorch-lightning). + + This doesn't get used when using the `JaxTrainer`, since this is the equivalent of the + `JaxTrainer.fit` method. + + Unfolded version of `rejax.PPO.train`. + + Training loop in pure jax (a lot faster than when using pytorch-lightning). 
+ """ + if train_state is None and rng is None: + raise ValueError("Either train_state or rng must be provided") + + ts = train_state if train_state is not None else self.init_train_state(rng) + + initial_evaluation: EvalMetrics | None = None + if not skip_initial_evaluation: + initial_evaluation = self.eval_callback(ts) + + num_evals = np.ceil(self.hp.total_timesteps / self.hp.eval_freq).astype(int) + ts, evaluation = jax.lax.scan( + self.training_epoch, + init=ts, + xs=None, + length=num_evals, + ) + + if not skip_initial_evaluation: + assert initial_evaluation is not None + evaluation = jax.tree.map( + lambda i, ev: jnp.concatenate((jnp.expand_dims(i, 0), ev)), + initial_evaluation, + evaluation, + ) + assert isinstance(evaluation, EvalMetrics) + + return ts, evaluation + + # @jit + def training_epoch( + self, ts: PPOState[TEnvState], epoch: int + ) -> tuple[PPOState[TEnvState], EvalMetrics]: + # Run a few training iterations + iteration_steps = self.hp.num_envs * self.hp.num_steps + num_iterations = np.ceil(self.hp.eval_freq / iteration_steps).astype(int) + ts = jax.lax.fori_loop( + 0, + num_iterations, + # drop metrics for now + lambda i, train_state_i: self.fused_training_step(i, train_state_i)[0], + ts, + ) + # Run evaluation + return ts, self.eval_callback(ts) + + # @jit + def fused_training_step(self, iteration: int, ts: PPOState[TEnvState]): + """Fused training step in jax (joined data collection + training). + + *MUCH* faster than using pytorch-lightning, but you lose the callbacks and such. + """ + + data_collection_state, trajectories = self.collect_trajectories( + # env=self.env, + # env_params=self.env_params, + # actor=self.actor, + # critic=self.critic, + collection_state=ts.data_collection_state, + actor_params=ts.actor_ts.params, + critic_params=ts.critic_ts.params, + # num_envs=self.hp.num_envs, + # num_steps=self.hp.num_steps, + # discrete=discrete, + # normalize_observations=self.hp.normalize_observations, + ) + ts = ts.replace(data_collection_state=data_collection_state) + return self.training_step(iteration, ts, trajectories) + + +def has_discrete_actions( + env: Environment[gymnax.EnvState, TEnvParams], env_params: TEnvParams +) -> bool: + return isinstance(env.action_space(env_params), gymnax.environments.spaces.Discrete) + + +def _update_rms(rms_state: RMSState, obs: jax.Array, batched: bool = True): + batch = obs if batched else jnp.expand_dims(obs, 0) + + batch_count = batch.shape[0] + batch_mean, batch_var = batch.mean(axis=0), batch.var(axis=0) + + delta = batch_mean - rms_state.mean + tot_count = rms_state.count + batch_count + + new_mean = rms_state.mean + delta * batch_count / tot_count + m_a = rms_state.var * rms_state.count + m_b = batch_var * batch_count + M2 = m_a + m_b + delta**2 * rms_state.count * batch_count / tot_count + new_var = M2 / tot_count + new_count = tot_count + + return RMSState(mean=new_mean, var=new_var, count=new_count) + + +def _normalize_obs(rms_state: RMSState, obs: jax.Array): + return (obs - rms_state.mean) / jnp.sqrt(rms_state.var + 1e-8) + + +@functools.partial(jit, static_argnames=["num_minibatches"]) +def shuffle_and_split( + data: AdvantageMinibatch, rng: chex.PRNGKey, num_minibatches: int +) -> AdvantageMinibatch: + assert data.trajectories.obs.shape + iteration_size = data.trajectories.obs.shape[0] * data.trajectories.obs.shape[1] + permutation = jax.random.permutation(rng, iteration_size) + _shuffle_and_split_fn = functools.partial( + _shuffle_and_split, + permutation=permutation, + num_minibatches=num_minibatches, + ) + 
return jax.tree.map(_shuffle_and_split_fn, data) + + +@functools.partial(jit, static_argnames=["num_minibatches"]) +def _shuffle_and_split(x: jax.Array, permutation: jax.Array, num_minibatches: Static[int]): + x = x.reshape((x.shape[0] * x.shape[1], *x.shape[2:])) + x = jnp.take(x, permutation, axis=0) + return x.reshape(num_minibatches, -1, *x.shape[1:]) + + +# @jit +def calculate_gae( + trajectories: TrajectoryWithLastObs, + last_val: jax.Array, + gamma: float, + gae_lambda: float, +) -> tuple[jax.Array, jax.Array]: + get_advantages_fn = functools.partial(get_advantages, gamma=gamma, gae_lambda=gae_lambda) + _, advantages = jax.lax.scan( + get_advantages_fn, + init=(jnp.zeros_like(last_val), last_val), + xs=trajectories, + reverse=True, + ) + return advantages, advantages + trajectories.trajectories.value + + +# @jit +def get_advantages( + advantage_and_next_value: tuple[jax.Array, jax.Array], + transition: TrajectoryWithLastObs, + gamma: float, + gae_lambda: float, +) -> tuple[tuple[jax.Array, jax.Array], jax.Array]: + advantage, next_value = advantage_and_next_value + transition_data = transition.trajectories + assert isinstance(transition_data.reward, jax.Array) + delta = ( + transition_data.reward.squeeze() # For gymnax envs that return shape (1, ) + + gamma * next_value * (1 - transition_data.done) + - transition_data.value + ) + advantage = delta + gamma * gae_lambda * (1 - transition_data.done) * advantage + assert isinstance(transition_data.value, jax.Array) + return (advantage, transition_data.value), advantage + + +@functools.partial(jit, static_argnames=["actor"]) +def actor_loss_fn( + params: FrozenVariableDict, + actor: Static[flax.linen.Module], + batch: AdvantageMinibatch, + clip_eps: float, + ent_coef: float, +) -> jax.Array: + log_prob, entropy = actor.apply( + params, + batch.trajectories.obs, + batch.trajectories.action, + method="log_prob_entropy", + ) + assert isinstance(entropy, jax.Array) + entropy = entropy.mean() + + # Calculate actor loss + ratio = jnp.exp(log_prob - batch.trajectories.log_prob) + advantages = (batch.advantages - batch.advantages.mean()) / (batch.advantages.std() + 1e-8) + clipped_ratio = jnp.clip(ratio, 1 - clip_eps, 1 + clip_eps) + pi_loss1 = ratio * advantages + pi_loss2 = clipped_ratio * advantages + pi_loss = -jnp.minimum(pi_loss1, pi_loss2).mean() + return pi_loss - ent_coef * entropy + + +@functools.partial(jit, static_argnames=["critic"]) +def critic_loss_fn( + params: FrozenVariableDict, + critic: Static[flax.linen.Module], + batch: AdvantageMinibatch, + clip_eps: float, + vf_coef: float, +): + value = critic.apply(params, batch.trajectories.obs) + assert isinstance(value, jax.Array) + value_pred_clipped = batch.trajectories.value + (value - batch.trajectories.value).clip( + -clip_eps, clip_eps + ) + assert isinstance(value_pred_clipped, jax.Array) + value_losses = jnp.square(value - batch.targets) + value_losses_clipped = jnp.square(value_pred_clipped - batch.targets) + value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() + return vf_coef * value_loss + + +def _actor( + obs: jax.Array, + rng: chex.PRNGKey, + actor_ts: TrainState, + rms_state: RMSState | None, + normalize_observations: bool, +): + if normalize_observations: + assert rms_state is not None + obs = _normalize_obs(rms_state, obs) + + obs = jnp.expand_dims(obs, 0) + action = actor_ts.apply_fn(actor_ts.params, obs, rng, method="act") + return jnp.squeeze(action) + + +def make_actor( + ts: PPOState[Any], hp: PPOHParams +) -> Callable[[jax.Array, 
chex.PRNGKey], jax.Array]: + return functools.partial( + _actor, + actor_ts=ts.actor_ts, + rms_state=ts.data_collection_state.rms_state, + normalize_observations=hp.normalize_observations, + ) + + +def render_episode( + actor: Callable[[jax.Array, chex.PRNGKey], jax.Array], + env: Environment[Any, TEnvParams], + env_params: TEnvParams, + gif_path: Path, + rng: chex.PRNGKey, + num_steps: int = 200, +): + state_seq, reward_seq = [], [] + rng, rng_reset = jax.random.split(rng) + obs, env_state = env.reset(rng_reset, env_params) + for step in range(num_steps): + state_seq.append(env_state) + rng, rng_act, rng_step = jax.random.split(rng, 3) + action = actor(obs, rng_act) + next_obs, next_env_state, reward, done, info = env.step( + key=rng_step, state=env_state, action=action, params=env_params + ) + reward_seq.append(reward) + # if done or step >= 500: + # break + obs = next_obs + env_state = next_env_state + + cum_rewards = jnp.cumsum(jnp.array(reward_seq)) + vis = Visualizer(env, env_params, state_seq, cum_rewards) + # gif_path = Path(log_dir) / f"epoch_{current_epoch}.gif" + logger.info(f"Saving gif to {gif_path}") + # print(f"Saving gif to {gif_path}") + # Disable the "ffmpeg moviewriter not available, using Pillow" print to stderr that happens in + # there. + gif_path.parent.mkdir(exist_ok=True, parents=True) + with contextlib.redirect_stderr(None): + vis.animate(str(gif_path)) + plt.close(vis.fig) + + +class RenderEpisodesCallback(JaxCallback): + on_every_epoch: int = False + + def on_fit_start(self, trainer: JaxTrainer, module: JaxRLExample, ts: PPOState): + if not self.on_every_epoch: + return + log_dir = trainer.logger.save_dir if trainer.logger else trainer.default_root_dir + assert log_dir is not None + gif_path = Path(log_dir) / f"step_{ts.data_collection_state.global_step:05}.gif" + module.visualize(ts=ts, gif_path=gif_path) + jax.debug.print("Saved gif to {gif_path}", gif_path=gif_path) + + def on_train_epoch_start(self, trainer: JaxTrainer, module: JaxRLExample, ts: PPOState): + if not self.on_every_epoch: + return + log_dir = trainer.logger.save_dir if trainer.logger else trainer.default_root_dir + assert log_dir is not None + gif_path = Path(log_dir) / f"epoch_{ts.data_collection_state.global_step:05}.gif" + module.visualize(ts=ts, gif_path=gif_path) + jax.debug.print("Saved gif to {gif_path}", gif_path=gif_path) diff --git a/project/algorithms/jax_rl_example_test.py b/project/algorithms/jax_rl_example_test.py new file mode 100644 index 00000000..d94d7eb3 --- /dev/null +++ b/project/algorithms/jax_rl_example_test.py @@ -0,0 +1,724 @@ +from __future__ import annotations + +import dataclasses +import functools +import operator +import time +from collections.abc import Callable, Iterable +from logging import getLogger +from pathlib import Path +from typing import Any + +import chex +import gymnax +import jax +import jax.numpy as jnp +import lightning +import numpy as np +import pytest +import rejax +import scipy.stats +import torch +import torch_jax_interop +from gymnax.environments.environment import Environment +from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBar +from lightning.pytorch.loggers import CSVLogger +from tensor_regression import TensorRegressionFixture +from torch.utils.data import DataLoader +from typing_extensions import override + +from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback +from project.trainers.jax_trainer import JaxTrainer, hparams_to_dict +from project.utils.testutils import 
run_for_all_configs_of_type + +from .jax_rl_example import ( + EvalMetrics, + JaxRLExample, + PPOHParams, + PPOState, + TEnvParams, + TEnvState, + Trajectory, + TrajectoryWithLastObs, + _actor, + render_episode, +) +from .testsuites.algorithm_tests import LearningAlgorithmTests + +logger = getLogger(__name__) + + +@pytest.fixture(scope="session", params=[123]) +def seed(request: pytest.FixtureRequest) -> int: + seed = getattr(request, "param", 123) + return seed + + +@pytest.fixture(scope="session") +def rng(seed: int) -> chex.PRNGKey: + return jax.random.key(seed) + + +@pytest.fixture(scope="session") +def n_agents(request: pytest.FixtureRequest) -> int | None: + return getattr(request, "param", None) + + +@pytest.fixture(scope="session") +def results_ours( + algo: JaxRLExample, + rng: chex.PRNGKey, + n_agents: int | None, +): + train_fn = algo.train + + if n_agents is not None: + train_fn = jax.vmap(train_fn) + rng = jax.random.split(rng, n_agents) + + train_fn = jax.jit(train_fn).lower(rng).compile() + _start = time.perf_counter() + train_states_ours, evals_ours = train_fn(rng) + jax.block_until_ready((train_states_ours, evals_ours)) + print(f"Our tweaked rejax.PPO: {time.perf_counter() - _start:.1f} seconds.") + return train_states_ours, evals_ours + + +@pytest.fixture +def results_ours_with_trainer( + algo: JaxRLExample, + rng: chex.PRNGKey, + n_agents: int, + jax_trainer: JaxTrainer, +): + train_fn = jax_trainer.fit + + if n_agents is not None: + jax_trainer = jax_trainer.replace(callbacks=()) + train_fn = jax_trainer.fit + train_fn = jax.vmap(train_fn, in_axes=(None, 0)) + rng = jax.random.split(rng, n_agents) + + train_fn_with_trainer = jax.jit(train_fn).lower(algo, rng).compile() + _start = time.perf_counter() + _train_states_ours_with_trainer, evals_ours_with_trainer = train_fn_with_trainer(algo, rng) + jax.block_until_ready((_train_states_ours_with_trainer, evals_ours_with_trainer)) + print(f"Our tweaked rejax.PPO with JaxTrainer: {time.perf_counter() - _start:.1f} seconds.") + return _train_states_ours_with_trainer, evals_ours_with_trainer + + +@pytest.fixture +def results_rejax( + algo: JaxRLExample, + rng: chex.PRNGKey, + n_agents: int, +): + # _start = time.perf_counter() + _rejax_ppo, train_states_rejax, evals_rejax = _train_rejax( + env=algo.env, env_params=algo.env_params, hp=algo.hp, rng=rng, n_agents=n_agents + ) + # jax.block_until_ready((train_states_rejax, evals_rejax)) + # print(f"rejax.PPO: {time.perf_counter() - _start:.1f} seconds.") + return _rejax_ppo, train_states_rejax, evals_rejax + + +def test_ours( + algo: JaxRLExample, + results_ours: tuple[PPOState, EvalMetrics], + tensor_regression: TensorRegressionFixture, + seed: int, + rng: chex.PRNGKey, + n_agents: int | None, + original_datadir: Path, +): + evaluations = results_ours[1] + tensor_regression.check(jax.tree.map(lambda v: v.__array__(), dataclasses.asdict(evaluations))) + + eval_rng = rng + if n_agents is None: + gif_path = original_datadir / f"ours_{seed=}.gif" + algo.visualize(results_ours[0], gif_path=gif_path, eval_rng=eval_rng) + else: + gif_path = original_datadir / f"ours_{n_agents=}_{seed=}_first.gif" + fn = functools.partial(jax.tree.map, operator.itemgetter(0)) + algo.visualize(fn(results_ours[0]), gif_path=gif_path, eval_rng=eval_rng) + + +def test_ours_with_trainer( + algo: JaxRLExample, + results_ours_with_trainer: tuple[PPOState, EvalMetrics], + tensor_regression: TensorRegressionFixture, + tmp_path: Path, + seed: int, + rng: chex.PRNGKey, + n_agents: int | None, + original_datadir: Path, 
+):
+    ts, evaluations = results_ours_with_trainer
+    tensor_regression.check(jax.tree.map(lambda v: v.__array__(), dataclasses.asdict(evaluations)))
+
+    eval_rng = rng
+    if n_agents is None:
+        gif_path = original_datadir / f"ours_with_trainer_{seed=}.gif"
+        algo.visualize(ts, gif_path=gif_path, eval_rng=eval_rng)
+    else:
+        gif_path = original_datadir / f"ours_with_trainer_{n_agents=}_{seed=}_first.gif"
+        fn = functools.partial(jax.tree.map, operator.itemgetter(0))
+        algo.visualize(fn(ts), gif_path=gif_path, eval_rng=eval_rng)
+
+
+def test_results_are_same_with_or_without_jax_trainer(
+    results_ours: tuple[PPOState, EvalMetrics],
+    results_ours_with_trainer: tuple[PPOState, EvalMetrics],
+):
+    np.testing.assert_allclose(
+        results_ours[1].cumulative_reward, results_ours_with_trainer[1].cumulative_reward
+    )
+    # jax.tree.map(
+    #     np.testing.assert_allclose,
+    #     jax.tree.leaves(results_ours),
+    #     jax.tree.leaves(results_ours_with_trainer),
+    # )
+
+
+def test_rejax(
+    seed: int,
+    rng: chex.PRNGKey,
+    results_rejax: tuple[rejax.PPO, Any, EvalMetrics],
+    tensor_regression: TensorRegressionFixture,
+    original_datadir: Path,
+    n_agents: int | None,
+):
+    """Train `rejax.PPO` with the same parameters."""
+
+    _algo, ts, evaluations = results_rejax
+    tensor_regression.check(jax.tree.map(lambda v: v.__array__(), dataclasses.asdict(evaluations)))
+    eval_rng = rng
+
+    if n_agents is None:
+        gif_path = original_datadir / f"rejax_{seed=}.gif"
+        _visualize_rejax(rejax_algo=_algo, rejax_ts=ts, eval_rng=rng, gif_path=gif_path)
+    else:
+        fn = functools.partial(jax.tree.map, operator.itemgetter(0))
+        _visualize_rejax(
+            rejax_algo=results_rejax[0],
+            rejax_ts=fn(results_rejax[1]),
+            eval_rng=eval_rng,
+            gif_path=original_datadir / f"rejax_{n_agents=}_{seed=}_first.gif",
+        )
+
+
+def best_index(evals: EvalMetrics) -> int:
+    return jnp.argmax(evals.cumulative_reward[:, -1]).item()
+
+
+def median_index(evals: EvalMetrics) -> int:
+    vals = evals.cumulative_reward[:, -1]
+    return jnp.argsort(vals)[len(vals) // 2].item()
+
+
+def get_slicing_fn(eval: EvalMetrics, get_index_fn: Callable[[EvalMetrics], int]) -> Any:
+    index = get_index_fn(eval)
+    return functools.partial(jax.tree.map, operator.itemgetter(index))
+
+
+@pytest.mark.parametrize("n_agents", [pytest.param(100, marks=pytest.mark.slow)], indirect=True)
+def test_algos_are_equivalent(
+    algo: JaxRLExample,
+    n_agents: int | None,
+    results_ours: tuple[PPOState, EvalMetrics],
+    results_rejax: tuple[rejax.PPO, Any, EvalMetrics],
+):
+    if n_agents is None:
+        _ours_vs_rejax = scipy.stats.mannwhitneyu(
+            results_ours[1].cumulative_reward[-1], results_rejax[2].cumulative_reward[-1]
+        )
+    else:
+        _ours_vs_rejax = scipy.stats.mannwhitneyu(
+            results_ours[1].cumulative_reward[:, -1], results_rejax[2].cumulative_reward[:, -1]
+        )
+    # TODO: interpret these results.
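# ---------------------------------------------------------------------------
# Illustration (not part of the patch): one way to act on the TODO above.
# `scipy.stats.mannwhitneyu` tests whether the two samples of final returns
# come from the same distribution; a large p-value means no detectable
# difference between our PPO and rejax.PPO (it does not prove equivalence).
import scipy.stats


def assert_no_significant_difference(ours, theirs, alpha: float = 0.05) -> None:
    result = scipy.stats.mannwhitneyu(ours, theirs)
    assert result.pvalue >= alpha, (
        f"Final returns differ significantly (p={result.pvalue:.3g} < {alpha})"
    )
# ---------------------------------------------------------------------------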
+ + +def _visualize_rejax(rejax_algo: rejax.PPO, rejax_ts: Any, eval_rng: chex.PRNGKey, gif_path: Path): + # rejax_algo = results_rejax[0] + # ts = results_rejax[1] + actor_ts = rejax_ts.actor_ts.replace(apply_fn=rejax_algo.actor.apply) + actor = functools.partial( + _actor, + actor_ts=actor_ts, + rms_state=rejax_ts.rms_state, + normalize_observations=rejax_algo.normalize_observations, + ) + render_episode( + actor=actor, + env=rejax_algo.env, + env_params=rejax_algo.env_params, + gif_path=gif_path, + rng=eval_rng, + ) + + +def _train_rejax( + env: Environment[gymnax.EnvState, TEnvParams], + env_params: TEnvParams, + hp: PPOHParams, + rng: chex.PRNGKey, + n_agents: int | None = None, +): + print("Rejax") + # todo: Make sure that rejax uses the same number of epochs as us. + algo = rejax.PPO.create( + env=env, + env_params=env_params, + num_envs=hp.num_envs, # =100, + num_steps=hp.num_steps, # =100, + num_epochs=hp.num_epochs, # =10, + num_minibatches=hp.num_minibatches, # =10, + learning_rate=hp.learning_rate, # =0.001, + max_grad_norm=hp.max_grad_norm, # =10, + total_timesteps=hp.total_timesteps, # =150_000, + eval_freq=hp.eval_freq, # =2000, + gamma=hp.gamma, # =0.995, + gae_lambda=hp.gae_lambda, # =0.95, + clip_eps=hp.clip_eps, # =0.2, + ent_coef=hp.ent_coef, # =0.0, + vf_coef=hp.vf_coef, # =0.5, + normalize_observations=hp.normalize_observations, # =True, + ) + print("Compiling...") + start = time.perf_counter() + + train_fn = algo.train + if n_agents: + # Vmap training function over n_agents initial seeds + train_fn = jax.vmap(train_fn) + rng = jax.random.split(rng, n_agents) + + train_fn = jax.jit(train_fn).lower(rng).compile() + print(f"Compiled in {time.perf_counter() - start} seconds.") + print("Training...") + start = time.perf_counter() + ts, eval = train_fn(rng) + jax.block_until_ready((ts, eval)) + print(f"Finished training in {time.perf_counter() - start} seconds.") + return algo, ts, EvalMetrics(eval[0], eval[1]) + + # print(ts) + + +def train_lightning( + algo: JaxRLExample, + rng: chex.PRNGKey, + trainer: lightning.Trainer, +): + # Fit with pytorch-lightning. 
+ print("Lightning") + + module = PPOLightningModule( + learner=algo, + ts=algo.init_train_state(rng), + ) + + start = time.perf_counter() + trainer.fit(module) + print(f"Trained in {time.perf_counter() - start:.1f} seconds.") + + evaluation = trainer.validate(module) + + return module.ts, evaluation + + +@pytest.fixture(scope="session", params=["Pendulum-v1"]) +def env_id(request: pytest.FixtureRequest) -> str: + # env_id = "halfcheetah" + # env_id = "humanoid" + return request.param + + +@pytest.fixture(scope="session") +def env_and_params(env_id: str) -> tuple[Environment[gymnax.EnvState, TEnvParams], TEnvParams]: + from brax.envs import _envs as brax_envs + from rejax.compat.brax2gymnax import create_brax + + env: Environment[gymnax.EnvState, gymnax.EnvParams] + env_params: gymnax.EnvParams + if env_id in brax_envs: + env, env_params = create_brax( # type: ignore + env_id, + episode_length=1000, + action_repeat=1, + auto_reset=True, + batch_size=None, + backend="generalized", + ) + elif isinstance(env_id, str): + env, env_params = gymnax.make(env_id=env_id) # type: ignore + else: + env = env_id() # type: ignore + env_params = env.default_params + return env, env_params # type: ignore + + +@pytest.fixture(scope="session") +def algo( + env_and_params: tuple[Environment[TEnvState, TEnvParams], TEnvParams], +) -> JaxRLExample[TEnvState, TEnvParams]: + env, env_params = env_and_params + algo = JaxRLExample[TEnvState, TEnvParams]( + env=env, + env_params=env_params, + actor=JaxRLExample.create_actor(env, env_params), + critic=JaxRLExample.create_critic(), + hp=PPOHParams( + num_envs=100, + num_steps=100, + num_epochs=10, + num_minibatches=10, + learning_rate=0.001, + max_grad_norm=10, + total_timesteps=150_000, + eval_freq=2000, + gamma=0.995, + gae_lambda=0.95, + clip_eps=0.2, + ent_coef=0.0, + vf_coef=0.5, + normalize_observations=True, + debug=False, + ), + ) + return algo + + +@pytest.fixture(autouse=True, scope="session") +def debug_jit_warnings(): + # Temporarily make this particular warning into an error to help future-proof our jax code. + import jax._src.deprecations + + val_before = jax._src.deprecations._registered_deprecations["tracer-hash"].accelerated + jax._src.deprecations._registered_deprecations["tracer-hash"].accelerated = True + yield + jax._src.deprecations._registered_deprecations["tracer-hash"].accelerated = val_before + + # train_pure_jax(algo, backend="cpu") + # train_rejax(env=algo.env, env_params=algo.env_params, hp=algo.hp, backend="cpu") + # train_lightning(algo, accelerator="cpu") + + +@pytest.fixture +def max_epochs(algo: JaxRLExample, request: pytest.FixtureRequest) -> int: + # This is the usual value: (75) + # return 3 # shorter for tests? + default_max_epochs = int(np.ceil(algo.hp.total_timesteps / algo.hp.eval_freq).astype(int)) + return getattr(request, "param", default_max_epochs) + + +@pytest.fixture +def jax_trainer(algo: JaxRLExample, max_epochs: int, tmp_path: Path): + iteration_steps = algo.hp.num_envs * algo.hp.num_steps + num_iterations = np.ceil(algo.hp.eval_freq / iteration_steps).astype(int) + training_steps_per_epoch: int = num_iterations + + return JaxTrainer( + max_epochs=max_epochs, + training_steps_per_epoch=training_steps_per_epoch, + # todo: make sure that this also works with the wandb logger! + logger=CSVLogger(save_dir=tmp_path, name=None, flush_logs_every_n_steps=1), + default_root_dir=tmp_path, + callbacks=( + # RlThroughputCallback(), # Can't use this callback with `vmap`! 
+            # RenderEpisodesCallback(on_every_epoch=False),
+            RichProgressBar(),
+        ),
+    )
+
+
+## Pytorch-Lightning wrapper around this learner:
+
+# Don't allow tests to run for more than 5 seconds.
+# pytestmark = pytest.mark.timeout(5)
+
+
+class PPOLightningModule(lightning.LightningModule):
+    """Uses the same code as [project.algorithms.jax_rl_example.JaxRLExample][], but the training
+    loop is run with pytorch-lightning.
+
+    This is currently only meant to be used to compare the difference between the fully-jitted
+    training loop and the pytorch-lightning one.
+    """
+
+    def __init__(
+        self,
+        learner: JaxRLExample,
+        ts: PPOState,
+    ):
+        # https://github.com/keraJLi/rejax/blob/a1428ad3d661e31985c5c19460cec70bc95aef6e/configs/gymnax/pendulum.yaml#L1
+
+        super().__init__()
+        self.learner = learner
+        self.ts = ts
+
+        self.save_hyperparameters(hparams_to_dict(learner))
+        self.actor_params = torch.nn.ParameterList(
+            jax.tree.leaves(
+                jax.tree.map(
+                    torch_jax_interop.to_torch.jax_to_torch_tensor,
+                    self.ts.actor_ts.params,
+                )
+            )
+        )
+        self.critic_params = torch.nn.ParameterList(
+            jax.tree.leaves(
+                jax.tree.map(
+                    torch_jax_interop.to_torch.jax_to_torch_tensor,
+                    self.ts.critic_ts.params,
+                )
+            )
+        )
+
+        self.automatic_optimization = False
+
+        iteration_steps = self.learner.hp.num_envs * self.learner.hp.num_steps
+        # number of "iterations" (collecting batches of episodes in the environment) per epoch.
+        self.num_train_iterations = np.ceil(self.learner.hp.eval_freq / iteration_steps).astype(
+            int
+        )
+
+    @override
+    def training_step(self, batch: torch.Tensor, batch_idx: int):
+        start = time.perf_counter()
+        with jax.disable_jit(self.learner.hp.debug):
+            algo_struct = self.learner
+            self.ts, train_metrics = algo_struct.fused_training_step(batch_idx, self.ts)
+
+        duration = time.perf_counter() - start
+        logger.debug(f"Training step took {duration:.1f} seconds.")
+        actor_losses = train_metrics.actor_losses
+        critic_losses = train_metrics.critic_losses
+        self.log("train/actor_loss", actor_losses.mean().item(), logger=True, prog_bar=True)
+        self.log("train/critic_loss", critic_losses.mean().item(), logger=True, prog_bar=True)
+
+        updates_per_second = (
+            self.learner.hp.num_epochs * self.learner.hp.num_minibatches
+        ) / duration
+        self.log("train/updates_per_second", updates_per_second, logger=True, prog_bar=True)
+        minibatch_size = (
+            self.learner.hp.num_envs * self.learner.hp.num_steps
+        ) // self.learner.hp.num_minibatches
+        samples_per_update = minibatch_size
+        self.log(
+            "train/samples_per_second",
+            updates_per_second * samples_per_update,
+            logger=True,
+            prog_bar=True,
+            on_step=True,
+        )
+
+        # for jax_param, torch_param in zip(
+        #     jax.tree.leaves(self.train_state.actor_ts.params), self.actor_params
+        # ):
+        #     torch_param.set_(torch_jax_interop.to_torch.jax_to_torch_tensor(jax_param))
+
+        # for jax_param, torch_param in zip(
+        #     jax.tree.leaves(self.train_state.critic_ts.params), self.critic_params
+        # ):
+        #     torch_param.set_(torch_jax_interop.to_torch.jax_to_torch_tensor(jax_param))
+
+        return
+
+    @override
+    def train_dataloader(self) -> Iterable[Trajectory]:
+        # BUG: what's probably happening is that the dataloader keeps getting batches with the
+        # initial train state!
+        from torch.utils.data import TensorDataset
+
+        dataset = TensorDataset(torch.arange(self.num_train_iterations, device=self.device))
+        return DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False, collate_fn=None)
+
+    def val_dataloader(self) -> Any:
+        # todo: unsure what this should be yielding.
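# ---------------------------------------------------------------------------
# Illustration (not part of the patch): with `batch_size=None`, the
# DataLoaders here do no batching or collation and yield dataset items
# unchanged, so these "dataloaders" are really just counters; the actual
# environment data is collected inside the jitted `fused_training_step`.
import torch
from torch.utils.data import DataLoader, TensorDataset

_loader = DataLoader(TensorDataset(torch.arange(3)), batch_size=None)
assert [int(i) for (i,) in _loader] == [0, 1, 2]
# ---------------------------------------------------------------------------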
+ from torch.utils.data import TensorDataset + + dataset = TensorDataset(torch.arange(1, device=self.device)) + return DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False, collate_fn=None) + + def validation_step(self, batch: int, batch_index: int): + # self.learner.eval_callback() + # return # skip the rest for now while we compare the performance? + eval_metrics = self.learner.eval_callback(ts=self.ts) + episode_lengths = eval_metrics.episode_length + cumulative_rewards = eval_metrics.cumulative_reward + self.log("val/episode_lengths", episode_lengths.mean().item(), batch_size=1) + self.log("val/rewards", cumulative_rewards.mean().item(), batch_size=1) + + @override + def configure_optimizers(self) -> Any: + # todo: Note, this one isn't used atm! + from torch.optim.adam import Adam + + return Adam(self.parameters(), lr=1e-3) + + @override + def transfer_batch_to_device( + self, batch: TrajectoryWithLastObs | int, device: torch.device, dataloader_idx: int + ) -> TrajectoryWithLastObs | int: + if isinstance(batch, int): + # FIXME: valid dataloader currently just yields ints, not trajectories. + return batch + if isinstance(batch, list) and len(batch) == 1: + # FIXME: train dataloader currently just yields ints, not trajectories. + return batch + + _batch_jax_devices = batch.trajectories.obs.devices() + assert len(_batch_jax_devices) == 1 + batch_jax_device = _batch_jax_devices.pop() + torch_self_device = device + if ( + torch_self_device.type == "cuda" + and "cuda" in str(batch_jax_device) + and (torch_self_device.index == -1 or torch_self_device.index == batch_jax_device.id) + ): + # All good, both are on the same GPU. + return batch + + jax_self_device = torch_jax_interop.to_jax.torch_to_jax_device(torch_self_device) + return jax.tree.map(functools.partial(jax.device_put, device=jax_self_device), batch) + + +class RlThroughputCallback(MeasureSamplesPerSecondCallback): + """A callback to measure the throughput of RL algorithms.""" + + def __init__(self, num_optimizers: int | None = 1): + super().__init__(num_optimizers=num_optimizers) + self.total_transitions = 0 + self.total_episodes = 0 + self._start = time.perf_counter() + self._updates = 0 + + @override + def on_fit_start( + self, + trainer: lightning.Trainer, + pl_module: lightning.LightningModule, + ) -> None: + super().on_fit_start(trainer, pl_module) + self.total_transitions = 0 + self.total_episodes = 0 + self._start = time.perf_counter() + + @override + def on_train_batch_end( + self, + trainer: lightning.Trainer, + pl_module: lightning.LightningModule, + outputs: dict[str, torch.Tensor], + batch: TrajectoryWithLastObs, + batch_index: int, + ) -> None: + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_index) + if not isinstance(batch, TrajectoryWithLastObs): + return + episodes = batch.trajectories + assert episodes.obs.shape + num_episodes = episodes.obs.shape[0] + num_transitions = np.prod(episodes.obs.shape[:2]) + self.total_episodes += num_episodes + self.total_transitions += num_transitions + steps_per_second = self.total_transitions / (time.perf_counter() - self._start) + updates_per_second = (self._updates) / (time.perf_counter() - self._start) + episodes_per_second = self.total_episodes / (time.perf_counter() - self._start) + logger.info( + f"Total transitions: {self.total_transitions}, total episodes: {self.total_episodes}" + ) + # print(f"Steps per second: {steps_per_second}") + logger.info(f"Steps per second: {steps_per_second}") + logger.info(f"Episodes per second: 
{episodes_per_second}") + logger.info(f"Updates per second: {updates_per_second}") + + @override + def get_num_samples(self, batch: TrajectoryWithLastObs) -> int: + if isinstance(batch, int): # fixme + return 1 + return int(np.prod(batch.trajectories.obs.shape[:2]).item()) + + @override + def on_fit_end(self, trainer: lightning.Trainer, pl_module: lightning.LightningModule) -> None: + super().on_fit_end(trainer, pl_module) + + def log( + self, + name: str, + value: Any, + module: JaxRLExample, + trainer: lightning.Trainer | JaxTrainer, + **kwargs, + ): + # Used to possibly customize how the values are logged (e.g. for non-LightningModules). + # By default, uses the LightningModule.log method. + # TODO: Somehow log the metrics without an actual trainer? + # Should we create a Trainer / LightningModule "facade" that the callbacks can interact with? + if trainer.logger: + trainer.logger.log_metrics({name: value}, step=trainer.global_step, **kwargs) + + # if trainer.progress_bar_callback: + # trainer.progress_bar_callback.log_metrics({name: value}, step=trainer.global_step, **kwargs) + # return trainer.logger.log_metrics().log( + # name, + # value, + # **kwargs, + # ) + + +# TODO: potentially just use the Lightning adapter for unit tests for now? +@pytest.mark.skip(reason="TODO: ests assume a LightningModule atm (.state_dict()), etc.") +@run_for_all_configs_of_type("algorithm", JaxRLExample) +class TestJaxRLExample(LearningAlgorithmTests[JaxRLExample]): # type: ignore + pass + + +@pytest.fixture +def lightning_trainer(max_epochs: int, tmp_path: Path): + return lightning.Trainer( + max_epochs=max_epochs, + # logger=CSVLogger(save_dir="logs/jax_rl_debug"), + logger=False, + accelerator="auto", + devices=1 if torch.cuda.is_available() else "auto", + default_root_dir=tmp_path, + # barebones=True, + # reload_dataloaders_every_n_epochs=1, # todo: use this if we end up making a generator in train_dataloader + enable_checkpointing=False, + enable_model_summary=True, + num_sanity_val_steps=0, + check_val_every_n_epoch=max_epochs, + # limit_val_batches=0, + fast_dev_run=False, + detect_anomaly=False, + profiler=None, + ) + + +# reducing the max_epochs from 75 down to 10 because it's just wayyy too slow. +@pytest.mark.slow +# @pytest.mark.timeout(80) +@pytest.mark.parametrize("max_epochs", [15], indirect=True) +def test_lightning( + algo: JaxRLExample, + rng: chex.PRNGKey, + lightning_trainer: lightning.Trainer, + tensor_regression: TensorRegressionFixture, + original_datadir: Path, +): + # todo: save a gif and some metrics? + train_state, evaluations = train_lightning( + algo, + rng=rng, + trainer=lightning_trainer, + ) + gif_path = original_datadir / "lightning.gif" + algo.visualize(train_state, gif_path=gif_path) + # file_regression.check(gif_path.read_bytes(), binary=True, extension=".gif") + assert len(evaluations) == 1 + # floats in regression files are saved with full precision, and the last few digits are + # different for some reason. + tensor_regression.check(jax.tree.map(np.asarray, evaluations[0])) diff --git a/project/configs/algorithm/jax_rl_example.yaml b/project/configs/algorithm/jax_rl_example.yaml new file mode 100644 index 00000000..3e210bcc --- /dev/null +++ b/project/configs/algorithm/jax_rl_example.yaml @@ -0,0 +1,35 @@ +# Config for the Jax RL Example (PPO). 
+# To run this, use the following command: +# ``` +# python project/main.py algorithm=jax_rl_example trainer=jax +# ``` + +_target_: project.algorithms.jax_rl_example.JaxRLExample.create +env: + _target_: gymnax.environments.classic_control.pendulum.Pendulum +env_params: + _target_: gymnax.environments.classic_control.pendulum.EnvParams + dt: 0.05000000074505806 + g: 10.0 + l: 1.0 + m: 1.0 + max_speed: 8.0 + max_steps_in_episode: 200 + max_torque: 2.0 +hp: + _target_: project.algorithms.jax_rl_example.PPOHParams + clip_eps: 0.20000000298023224 + debug: false + ent_coef: 0.0 + eval_freq: 2000 + gae_lambda: 0.949999988079071 + gamma: 0.9950000047683716 + learning_rate: 0.0010000000474974513 + max_grad_norm: 10 + normalize_observations: true + num_envs: 100 + num_epochs: 10 + num_minibatches: 10 + num_steps: 100 + total_timesteps: 150000 + vf_coef: 0.5 diff --git a/project/configs/config.py b/project/configs/config.py index f79086aa..e918a07c 100644 --- a/project/configs/config.py +++ b/project/configs/config.py @@ -1,7 +1,7 @@ import random from dataclasses import dataclass, field from logging import getLogger as get_logger -from typing import Any, Literal +from typing import Any, Literal, Optional from omegaconf import OmegaConf @@ -37,6 +37,13 @@ class Config: See the [MNISTDataModule][project.datamodules.image_classification.mnist.MNISTDataModule] for an example. """ + datamodule: Optional[Any] = None # noqa + """Configuration for the datamodule (dataset + transforms + dataloader creation). + + This should normally create a [LightningDataModule][lightning.pytorch.core.datamodule.LightningDataModule]. + See the [MNISTDataModule][project.datamodules.image_classification.mnist.MNISTDataModule] for an example. + """ + trainer: dict = field(default_factory=dict) """Keyword arguments for the Trainer constructor.""" diff --git a/project/configs/experiment/jax_rl_example.yaml b/project/configs/experiment/jax_rl_example.yaml new file mode 100644 index 00000000..da571b8f --- /dev/null +++ b/project/configs/experiment/jax_rl_example.yaml @@ -0,0 +1,16 @@ +# @package _global_ + +defaults: + - override /algorithm: jax_rl_example + - override /trainer: jax + - override /trainer/callbacks: rich_progress_bar + - override /datamodule: null +trainer: + max_epochs: 75 + training_steps_per_epoch: 1 + callbacks: + render_episodes: + _target_: project.algorithms.jax_rl_example.RenderEpisodesCallback + on_every_epoch: false + # progress_bar: + # _target_: lightning.pytorch.callbacks.progress.rich_progress.RichProgressBar diff --git a/project/configs/experiment/local_sweep_example.yaml b/project/configs/experiment/local_sweep_example.yaml index 6fcd4659..4a32746e 100644 --- a/project/configs/experiment/local_sweep_example.yaml +++ b/project/configs/experiment/local_sweep_example.yaml @@ -13,7 +13,7 @@ algorithm: lr: 0.002 trainer: - accelerator: gpu + accelerator: auto devices: 1 max_epochs: 1 logger: diff --git a/project/configs/trainer/jax.yaml b/project/configs/trainer/jax.yaml new file mode 100644 index 00000000..ae68ce4a --- /dev/null +++ b/project/configs/trainer/jax.yaml @@ -0,0 +1,16 @@ +defaults: + - callbacks: rich_progress_bar.yaml +_target_: project.trainers.jax_trainer.JaxTrainer +max_epochs: 75 +training_steps_per_epoch: 1 + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like checkpoints and metrics +default_root_dir: ${hydra:runtime.output_dir} +# callbacks: +# 
render_episodes:
+#     _target_: project.algorithms.jax_rl_example.RenderEpisodesCallback
+#     on_every_epoch: false
+#   progress_bar:
+#     _target_: lightning.pytorch.callbacks.progress.rich_progress.RichProgressBar
diff --git a/project/conftest.py b/project/conftest.py
index 76d3414c..d365fc2b 100644
--- a/project/conftest.py
+++ b/project/conftest.py
@@ -66,6 +66,7 @@
 import sys
 import typing
 from collections import defaultdict
+from collections.abc import Generator
 from contextlib import contextmanager
 from logging import getLogger as get_logger
 from pathlib import Path
@@ -172,8 +173,8 @@ def algorithm_network_config(request: pytest.FixtureRequest) -> str | None:
 @pytest.fixture(scope="session")
 def command_line_arguments(
-    devices: str,
-    accelerator: str,
+    # devices: str,
+    # accelerator: str,
     algorithm_config: str | None,
     datamodule_config: str | None,
     algorithm_network_config: str | None,
@@ -201,14 +202,16 @@ def command_line_arguments(
     if combination >= configs:
         # warnings.warn(f"Applying markers because {combination} contains {configs}")
         # There is a combination of potentially unsupported configs here, e.g. MNIST and ResNets.
-        pytest.skip(reason=f"Combination {combination} contains {configs}.")
+        # BUG: This is supposed to work, but doesn't for some reason!
         # for mark in marks:
         #     request.applymarker(mark)
+        # Skipping the test entirely for now.
+        pytest.skip(reason=f"Combination {combination} contains {configs}.")
     default_overrides = [
         # NOTE: if we were to run the test in a slurm job, this wouldn't make sense.
-        f"trainer.devices={devices}",
-        f"trainer.accelerator={accelerator}",
+        # f"trainer.devices={devices}",
+        # f"trainer.accelerator={accelerator}",
         # TODO: Setting this here, which actually impacts the tests!
         "seed=42",
     ]
@@ -335,6 +338,7 @@ def seed(request: pytest.FixtureRequest, make_torch_deterministic: None):
     yield random_seed
 
 
+# TODO: Remove this.
 @pytest.fixture(scope="session")
 def accelerator(request: pytest.FixtureRequest):
     """Returns the accelerator to use during unit tests.
@@ -342,6 +346,8 @@ def accelerator(request: pytest.FixtureRequest):
     By default, if cuda is available, returns "cuda". If the tests are run with -vvv, then also
     runs CPU.
     """
+    # TODO: Shouldn't we get this from the experiment config instead?
+
     default_accelerator = "gpu" if torch.cuda.is_available() else "cpu"
     accelerator: str = getattr(request, "param", default_accelerator)
@@ -381,8 +387,18 @@ def num_devices_to_use(accelerator: str, request: pytest.FixtureRequest) -> int
 
 
 @pytest.fixture(scope="session")
-def devices(accelerator: str, request: pytest.FixtureRequest) -> list[int] | int | Literal["auto"]:
-    """Fixture that creates the 'devices' argument for the Trainer config."""
+def devices(
+    accelerator: str, request: pytest.FixtureRequest
+) -> Generator[list[int] | int | Literal["auto"], None, None]:
+    """Fixture that creates the 'devices' argument for the Trainer config.
+
+    Splits up the GPUs between pytest-xdist workers when using distributed testing.
+    This isn't currently used in the CI.
+
+    TODO: Design dilemma here: Should we be parametrizing the `devices` command-line override and
+    forcing experiments to run with this value during tests? Or should we be changing things based
+    on this value in the config?
+    """
     # When using pytest-xdist to distribute tests, each worker will use different devices.
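+    # Worked example of the CPU split below (an illustration, not from the repo): with 8 CPUs
+    # and 3 pytest-xdist workers, workers 0 and 1 each get 8 // 3 = 2 CPUs, and the last
+    # worker gets the remaining 8 - 2 * 2 = 4.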
devices = getattr(request, "param", None)
@@ -397,22 +413,30 @@ def devices(accelerator: str, request: pytest.FixtureRequest) -> list[int] | int
         n_cpus = num_cpus_on_node()
         # Split the CPUS as evenly as possible (last worker might get less).
         if num_pytest_workers == 1:
-            return "auto"
+            yield "auto"
+            return
         n_cpus_for_this_worker = (
             n_cpus // num_pytest_workers
             if worker_index != num_pytest_workers - 1
             else n_cpus - n_cpus // num_pytest_workers * (num_pytest_workers - 1)
         )
         assert 1 <= n_cpus_for_this_worker <= n_cpus
-        return n_cpus_for_this_worker
+        yield n_cpus_for_this_worker
+        return
 
     if accelerator == "gpu" or (accelerator == "auto" and torch.cuda.is_available()):
         # Alternate GPUS between workers.
         n_gpus = torch.cuda.device_count()
         first_gpu_to_use = worker_index % n_gpus
         logger.info(f"Using GPU #{first_gpu_to_use}")
-        return [first_gpu_to_use]
-    return 1  # Use only one GPU by default if not distributed.
+        devices_before = os.environ.get("CUDA_VISIBLE_DEVICES")
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(first_gpu_to_use)
+        yield [first_gpu_to_use]
+        if devices_before is not None:
+            os.environ["CUDA_VISIBLE_DEVICES"] = devices_before
+        else:
+            # The variable wasn't set before this fixture ran; unset it again so we don't
+            # leak state into other tests.
+            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+        return
+
+    yield 1  # Use only one GPU by default if not distributed.
 
 
 def _override_param_id(override: Param) -> str:
diff --git a/project/main.py b/project/main.py
index 9be65365..b0fa985f 100644
--- a/project/main.py
+++ b/project/main.py
@@ -82,7 +82,16 @@ def run(experiment: Experiment) -> tuple[str, float | None, dict]:
     datamodule = getattr(experiment.algorithm, "datamodule", experiment.datamodule)
 
     if datamodule is None:
-        experiment.trainer.fit(experiment.algorithm)
+        # todo: missing `rng` argument.
+        from project.trainers.jax_trainer import JaxTrainer
+
+        if isinstance(experiment.trainer, JaxTrainer):
+            import jax.random
+
+            experiment.trainer.fit(experiment.algorithm, rng=jax.random.key(0))
+        else:
+            experiment.trainer.fit(experiment.algorithm)
+
     else:
         assert isinstance(datamodule, LightningDataModule)
         experiment.trainer.fit(
diff --git a/project/trainers/__init__.py b/project/trainers/__init__.py
new file mode 100644
index 00000000..f27ba440
--- /dev/null
+++ b/project/trainers/__init__.py
@@ -0,0 +1,8 @@
+from lightning.pytorch.trainer.trainer import Trainer
+
+from .jax_trainer import JaxTrainer
+
+__all__ = [
+    "JaxTrainer",
+    "Trainer",
+]
diff --git a/project/trainers/jax_trainer.py b/project/trainers/jax_trainer.py
new file mode 100644
index 00000000..fa41a3ab
--- /dev/null
+++ b/project/trainers/jax_trainer.py
@@ -0,0 +1,463 @@
+from __future__ import annotations
+
+import dataclasses
+import functools
+from collections.abc import Sequence
+from pathlib import Path
+from typing import Any, Protocol, runtime_checkable
+
+import chex
+import flax.core
+import flax.linen
+import flax.struct
+import jax
+import jax.experimental
+import jax.numpy as jnp
+import lightning
+import lightning.pytorch.callbacks
+import lightning.pytorch.loggers
+from hydra.core.hydra_config import HydraConfig
+from typing_extensions import TypeVar
+
+from project.utils.typing_utils.jax_typing_utils import jit
+
+Ts = TypeVar("Ts", bound=flax.struct.PyTreeNode, default=flax.struct.PyTreeNode)
+"""Type Variable for the training state."""
+
+_B = TypeVar("_B", bound=flax.struct.PyTreeNode, default=flax.struct.PyTreeNode)
+"""Type Variable for the batches produced (and consumed) by the algorithm."""
+
+_MetricsT = TypeVar(
+    "_MetricsT", bound=flax.struct.PyTreeNode, default=flax.struct.PyTreeNode, covariant=True
+)
+"""Type Variable for the metrics produced by the
algorithm.""" + +__all__ = ["JaxModule", "JaxCallback", "JaxTrainer"] + + +@runtime_checkable +class JaxModule(Protocol[Ts, _B, _MetricsT]): + """A protocol for algorithms that can be trained by the `JaxTrainer`. + + The `JaxRLExample` is an example that follows this structure and can be trained with a + `JaxTrainer`. + """ + + def init_train_state(self, rng: chex.PRNGKey) -> Ts: + """Create the initial training state.""" + raise NotImplementedError + + def get_batch(self, ts: Ts, batch_idx: int) -> tuple[Ts, _B]: + """Produces a batch of data.""" + raise NotImplementedError + + def training_step( + self, batch_idx: int, ts: Ts, batch: _B + ) -> tuple[Ts, flax.struct.PyTreeNode]: + """Update the training state using a "batch" of data.""" + raise NotImplementedError + + def eval_callback(self, ts: Ts) -> _MetricsT: + """Perform evaluation and return metrics.""" + raise NotImplementedError + + +class JaxCallback(flax.struct.PyTreeNode): + def setup(self, trainer: JaxTrainer, module: JaxModule[Ts], stage: str, ts: Ts): ... + def on_fit_start(self, trainer: JaxTrainer, module: JaxModule[Ts], ts: Ts): ... + def on_fit_end(self, trainer: JaxTrainer, module: JaxModule[Ts], ts: Ts): ... + def on_train_start(self, trainer: JaxTrainer, module: JaxModule[Ts], ts: Ts): ... + def on_train_end(self, trainer: JaxTrainer, module: JaxModule[Ts], ts: Ts): ... + def on_train_batch_start( + self, + trainer: JaxTrainer, + pl_module: JaxModule[Ts, _B], + batch: _B, + batch_index: int, + ts: Ts, + ) -> None: ... + def on_train_batch_end( + self, + trainer: JaxTrainer, + module: JaxModule[Ts, _B], + outputs: Any, + batch: _B, + batch_index: int, + ts: Ts, + ) -> None: ... + def on_train_epoch_start(self, trainer: JaxTrainer, module: JaxModule[Ts], ts: Ts): ... + def on_train_epoch_end(self, trainer: JaxTrainer, module: JaxModule[Ts], ts: Ts): ... + def on_validation_epoch_start(self, trainer: JaxTrainer, module: JaxModule[Ts], ts: Ts): ... + def on_validation_epoch_end(self, trainer: JaxTrainer, module: JaxModule[Ts], ts: Ts): ... + def teardown(self, trainer: JaxTrainer, module: JaxModule[Ts], stage: str, ts: Ts): ... + + +class JaxTrainer(flax.struct.PyTreeNode): + """A simplified version of the `lightning.Trainer` with a fully jitted training loop. + + ## Assumptions: + + - The algo object must match the `JaxModule` protocol (in other words, it should implement its + methods). + + ## Training loop + + This is the training loop, which is fully jitted: + + ```python + ts = algo.init_train_state(rng) + + setup("fit") + on_fit_start() + on_train_start() + + eval_metrics = [] + for epoch in range(self.max_epochs): + on_train_epoch_start() + + for step in range(self.training_steps_per_epoch): + + batch = algo.get_batch(ts, step) + + on_train_batch_start() + + ts, metrics = algo.training_step(step, ts, batch) + + on_train_batch_end() + + on_train_epoch_end() + + # Evaluation "loop" + on_validation_epoch_start() + epoch_eval_metrics = self.eval_epoch(ts, epoch, algo) + on_validation_epoch_start() + + eval_metrics.append(epoch_eval_metrics) + + return ts, eval_metrics + ``` + + ## Caveats + + - Some lightning callbacks can be used with this trainer and work well, but not all of them. + - You can either use Regular pytorch-lightning callbacks, or use `jax.vmap` on the `fit` method, + but not both. + - If you want to use [jax.vmap][] on the `fit` method, just remove the callbacks on the + Trainer for now. + + ## TODOs / ideas + + - Add a checkpoint callback with orbax-checkpoint? 
+ """ + + max_epochs: int = flax.struct.field(pytree_node=False) + + training_steps_per_epoch: int = flax.struct.field(pytree_node=False) + + limit_val_batches: int = 0 + limit_test_batches: int = 0 + + # TODO: Getting some errors with the schema generation for lightning.Callback and + # lightning.pytorch.loggers.logger.Logger here if we keep the type annotation. + callbacks: Sequence = dataclasses.field(metadata={"pytree_node": False}, default_factory=tuple) + + logger: Any | None = flax.struct.field(pytree_node=False, default=None) + + # accelerator: str = flax.struct.field(pytree_node=False, default="auto") + # strategy: str = flax.struct.field(pytree_node=False, default="auto") + # devices: int | str = flax.struct.field(pytree_node=False, default="auto") + + # path to output directory, created dynamically by hydra + # path generation pattern is specified in `configs/hydra/default.yaml` + # use it to store all files generated during the run, like checkpoints and metrics + + default_root_dir: str | Path | None = flax.struct.field( + pytree_node=False, + default_factory=lambda: HydraConfig.get().runtime.output_dir, + ) + + # State variables: + # TODO: figure out how to cleanly store / update these. + current_epoch: int = flax.struct.field(pytree_node=True, default=0) + global_step: int = flax.struct.field(pytree_node=True, default=0) + + logged_metrics: dict = flax.struct.field(pytree_node=True, default_factory=dict) + callback_metrics: dict = flax.struct.field(pytree_node=True, default_factory=dict) + # todo: get the metrics from the callbacks? + # lightning.pytorch.loggers.CSVLogger.log_metrics + # TODO: Take a look at this method: + # lightning.pytorch.callbacks.progress.rich_progress.RichProgressBar.get_metrics + # return lightning.Trainer._logger_connector.progress_bar_metrics + progress_bar_metrics: dict = flax.struct.field(pytree_node=True, default_factory=dict) + + verbose: bool = flax.struct.field(pytree_node=False, default=False) + + @functools.partial(jit, static_argnames=["skip_initial_evaluation"]) + def fit( + self, + algo: JaxModule[Ts, _B, _MetricsT], + rng: chex.PRNGKey, + train_state: Ts | None = None, + skip_initial_evaluation: bool = False, + ) -> tuple[Ts, _MetricsT]: + """Full training loop in pure jax (a lot faster than when using pytorch-lightning). + + Unfolded version of `rejax.PPO.train`. + + Training loop in pure jax (a lot faster than when using pytorch-lightning). + """ + + if train_state is None and rng is None: + raise ValueError("Either train_state or rng must be provided") + + train_state = train_state if train_state is not None else algo.init_train_state(rng) + + if self.progress_bar_callback is not None: + if self.verbose: + jax.debug.print("Enabling the progress bar callback.") + jax.experimental.io_callback(self.progress_bar_callback.enable, ()) + + self._callback_hook("setup", self, algo, ts=train_state, partial_kwargs=dict(stage="fit")) + self._callback_hook("on_fit_start", self, algo, ts=train_state) + self._callback_hook("on_train_start", self, algo, ts=train_state) + + if self.logger: + jax.experimental.io_callback( + lambda algo: self.logger and self.logger.log_hyperparams(hparams_to_dict(algo)), + (), + algo, + ordered=True, + ) + + initial_evaluation: _MetricsT | None = None + if not skip_initial_evaluation: + initial_evaluation = algo.eval_callback(train_state) + + # Run the epoch loop `self.max_epoch` times. 
+ train_state, evaluations = jax.lax.scan( + functools.partial(self.epoch_loop, algo=algo), + init=train_state, + xs=jnp.arange(self.max_epochs), # type: ignore + length=self.max_epochs, + ) + + if not skip_initial_evaluation: + assert initial_evaluation is not None + evaluations: _MetricsT = jax.tree.map( + lambda i, ev: jnp.concatenate((jnp.expand_dims(i, 0), ev)), + initial_evaluation, + evaluations, + ) + + if self.logger is not None: + jax.block_until_ready((train_state, evaluations)) + # jax.debug.print("Saving...") + jax.experimental.io_callback( + functools.partial(self.logger.finalize, status="success"), () + ) + + self._callback_hook("on_fit_end", self, algo, ts=train_state) + self._callback_hook("on_train_end", self, algo, ts=train_state) + self._callback_hook( + "teardown", self, algo, ts=train_state, partial_kwargs={"stage": "fit"} + ) + + return train_state, evaluations + + # @jit + def epoch_loop(self, ts: Ts, epoch: int, algo: JaxModule[Ts, _B, _MetricsT]): + # todo: Some lightning callbacks try to get the "trainer.current_epoch". + # FIXME: Hacky: Present a trainer with a different value of `self.current_epoch` to + # the callbacks. + # chex.assert_scalar_in(epoch, 0, self.max_epochs) + # TODO: Can't just set current_epoch to `epoch` as `epoch` is a Traced value. + # todo: need to have the callback take in the actual int value. + # jax.debug.print("Starting epoch {epoch}", epoch=epoch) + + self = self.replace(current_epoch=epoch) # doesn't quite work? + ts = self.training_epoch(ts=ts, epoch=epoch, algo=algo) + eval_metrics = self.eval_epoch(ts=ts, epoch=epoch, algo=algo) + return ts, eval_metrics + + # @jit + def training_epoch(self, ts: Ts, epoch: int, algo: JaxModule[Ts, _B, _MetricsT]): + # Run a few training iterations + self._callback_hook("on_train_epoch_start", self, algo, ts=ts) + + ts = jax.lax.fori_loop( + 0, + self.training_steps_per_epoch, + # drop training metrics for now. + functools.partial(self.training_step, algo=algo), + ts, + ) + + self._callback_hook("on_train_epoch_end", self, algo, ts=ts) + return ts + + # @jit + def eval_epoch(self, ts: Ts, epoch: int, algo: JaxModule[Ts, _B, _MetricsT]): + self._callback_hook("on_validation_epoch_start", self, algo, ts=ts) + + # todo: split up into eval batch and eval step? + eval_metrics = algo.eval_callback(ts=ts) + + self._callback_hook("on_validation_epoch_end", self, algo, ts=ts) + + return eval_metrics + + # @jit + def training_step(self, batch_idx: int, ts: Ts, algo: JaxModule[Ts, _B, _MetricsT]): + """Training step in pure jax (joined data collection + training). + + *MUCH* faster than using pytorch-lightning, but you lose the callbacks and such. + """ + # todo: rename to `get_training_batch`? + ts, batch = algo.get_batch(ts, batch_idx=batch_idx) + + self._callback_hook("on_train_batch_start", self, algo, batch, batch_idx, ts=ts) + + ts, metrics = algo.training_step(batch_idx=batch_idx, ts=ts, batch=batch) + + if self.logger is not None: + # todo: Clean this up. logs metrics. 
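+            # `jax.experimental.io_callback` escapes the jitted computation to run the
+            # side-effecting logger on the host; the metrics pytree is materialized there and
+            # reduced to scalar means before being logged.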
+ jax.experimental.io_callback( + lambda metrics, batch_index: self.logger + and self.logger.log_metrics( + jax.tree.map(lambda v: v.mean(), metrics), batch_index + ), + (), + dataclasses.asdict(metrics) if dataclasses.is_dataclass(metrics) else metrics, + batch_idx, + ) + + self._callback_hook("on_train_batch_end", self, algo, metrics, batch, batch_idx, ts=ts) + + return ts + + ### Hooks to mimic those of lightning.Trainer + + def _callback_hook( + self, + hook_name: str, + /, + *hook_args, + ts: Ts, + partial_kwargs: dict | None = None, + sharding: jax.sharding.SingleDeviceSharding | None = None, + ordered: bool = True, + **hook_kwargs, + ): + """Call a hook on all callbacks.""" + # with jax.disable_jit(): + for i, callback in enumerate(self.callbacks): + # assert hasattr(callback, hook_name) + + method = getattr(callback, hook_name) + if partial_kwargs: + method = functools.partial(method, **partial_kwargs) + if self.verbose: + jax.debug.print( + "Epoch {current_epoch}/{max_epochs}: " + + f"Calling hook {hook_name} on callback {callback}" + + "{i}", + i=i, + current_epoch=self.current_epoch, + ordered=True, + max_epochs=self.max_epochs, + ) + jax.experimental.io_callback( + method, + (), + *hook_args, + **({"ts": ts} if isinstance(callback, JaxCallback) else {}), + **hook_kwargs, + sharding=sharding, + ordered=ordered if not isinstance(callback, JaxCallback) else False, + ) + + # Compat for RichProgressBar + @property + def is_global_zero(self) -> bool: + return True + + @property + def num_training_batches(self) -> int: + return self.training_steps_per_epoch + + @property + def loggers(self) -> list[lightning.pytorch.loggers.Logger]: + if isinstance(self.logger, list | tuple): + return list(self.logger) + if self.logger is not None: + return [self.logger] + return [] + + # @property + # def progress_bar_metrics(self) -> dict[str, float]: + + # return {} + + @property + def progress_bar_callback(self) -> lightning.pytorch.callbacks.ProgressBar | None: + for c in self.callbacks: + if isinstance(c, lightning.pytorch.callbacks.ProgressBar): + return c + return None + + @property + def state(self): + from lightning.pytorch.trainer.states import ( + RunningStage, + TrainerFn, + TrainerState, + TrainerStatus, + ) + + return TrainerState( + fn=TrainerFn.FITTING, + status=TrainerStatus.RUNNING, + stage=RunningStage.TRAINING, + ) + # self._trainer.state.fn != "fit" + # or self._trainer.sanity_checking + # or self._trainer.progress_bar_callback.train_progress_bar_id != task.id + # ): + + @property + def sanity_checking(self) -> bool: + from lightning.pytorch.trainer.states import RunningStage + + return self.state.stage == RunningStage.SANITY_CHECKING + + @property + def training(self) -> bool: + from lightning.pytorch.trainer.states import RunningStage + + return self.state.stage == RunningStage.TRAINING + + @property + def log_dir(self) -> Path | None: + # copied from lightning.Trainer + if len(self.loggers) > 0: + if not isinstance( + self.loggers[0], + lightning.pytorch.loggers.TensorBoardLogger | lightning.pytorch.loggers.CSVLogger, + ): + dirpath = self.loggers[0].save_dir + else: + dirpath = self.loggers[0].log_dir + else: + dirpath = self.default_root_dir + if dirpath: + return Path(dirpath) + return None + + +def hparams_to_dict(algo: flax.struct.PyTreeNode) -> dict: + """Convert the learner struct to a serializable dict.""" + val = dataclasses.asdict( + jax.tree.map(lambda arr: arr.tolist() if isinstance(arr, jnp.ndarray) else arr, algo) + ) + val = jax.tree.map(lambda v: getattr(v, 
"__name__", str(v)) if callable(v) else v, val) + return val diff --git a/project/utils/auto_schema.py b/project/utils/auto_schema.py index b2566c2f..f9a0dffc 100644 --- a/project/utils/auto_schema.py +++ b/project/utils/auto_schema.py @@ -391,6 +391,7 @@ def add_schemas_to_all_hydra_configs( hydra.errors.MissingConfigException, hydra.errors.ConfigCompositionException, omegaconf.errors.InterpolationResolutionError, + Exception, # todo: remove this to harden the code. ) as exc: logger.warning( f"Unable to create a schema for config {pretty_config_file_name}: {exc}" @@ -914,7 +915,7 @@ def _get_schema_from_target(config: dict | DictConfig) -> ObjectSchema | Schema: hydra_defaults=config.get("defaults", None), hydra_recursive=False, hydra_convert="all", - zen_dataclass={"cls_name": target.__qualname__}, + zen_dataclass={"cls_name": target.__qualname__.replace(".", "_")}, # zen_wrappers=pydantic_parser, # unsure if this is how it works? ) @@ -948,7 +949,7 @@ def _get_schema_from_target(config: dict | DictConfig) -> ObjectSchema | Schema: if init_docstring := inspect.getdoc(target_or_base_class.__init__): docs_to_search.append(dp.parse(init_docstring)) else: - assert inspect.isfunction(target) + assert inspect.isfunction(target) or inspect.ismethod(target), target docstring = inspect.getdoc(target) if docstring: docs_to_search = [dp.parse(docstring)] diff --git a/project/utils/hydra_config_utils.py b/project/utils/hydra_config_utils.py index 5feccef5..68b908ed 100644 --- a/project/utils/hydra_config_utils.py +++ b/project/utils/hydra_config_utils.py @@ -147,6 +147,7 @@ def get_all_configs_in_group_of_type( config_name: get_target_of_config(config_group, config_name) for config_name in config_names } + names_to_types: dict[str, type] = {} for name, target in names_to_targets.items(): if inspect.isclass(target): @@ -154,11 +155,13 @@ def get_all_configs_in_group_of_type( continue if ( - inspect.isfunction(target) + (inspect.isfunction(target) or inspect.ismethod(target)) and (annotations := typing.get_type_hints(target)) and (return_type := annotations.get("return")) - and inspect.isclass(return_type) + and (inspect.isclass(return_type) or inspect.isclass(typing.get_origin(return_type))) ): + # Resolve generic aliases if present. + return_type = typing.get_origin(return_type) or return_type logger.info( f"Assuming that the function {target} creates objects of type {return_type} based " f"on its return type annotation." 
diff --git a/project/utils/typing_utils/jax_typing_utils.py b/project/utils/typing_utils/jax_typing_utils.py new file mode 100644 index 00000000..57376765 --- /dev/null +++ b/project/utils/typing_utils/jax_typing_utils.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import dataclasses +from collections.abc import Callable, Iterable, Mapping, Sequence +from typing import Any, Concatenate, Literal, ParamSpec, overload + +import jax +import jax.experimental +from jax._src.sharding_impls import UNSPECIFIED, Device, UnspecifiedValue +from typing_extensions import TypeVar + +P = ParamSpec("P") +Out = TypeVar("Out", covariant=True) + + +# @functools.wraps(jax.jit) +def jit( + fn: Callable[P, Out], + in_shardings: UnspecifiedValue = UNSPECIFIED, + out_shardings: UnspecifiedValue = UNSPECIFIED, + static_argnums: int | Sequence[int] | None = None, + static_argnames: str | Iterable[str] | None = None, + donate_argnums: int | Sequence[int] | None = None, + donate_argnames: str | Iterable[str] | None = None, + keep_unused: bool = False, + device: Device | None = None, + backend: str | None = None, + inline: bool = False, + abstracted_axes: Any | None = None, +) -> Callable[P, Out]: + # Small type hint fix for jax's `jit` (preserves the signature of the callable). + # TODO: Remove once [our PR to Jax](https://github.com/jax-ml/jax/pull/23720) is merged + + return jax.jit( + fn, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + donate_argnames=donate_argnames, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + abstracted_axes=abstracted_axes, + ) + + +In = TypeVar("In") +Aux = TypeVar("Aux") + + +# @functools.wraps(jax.value_and_grad) +def value_and_grad( + fn: Callable[Concatenate[In, P], tuple[Out, Aux]], + argnums: Literal[0] = 0, + has_aux: Literal[True] = True, +) -> Callable[Concatenate[In, P], tuple[tuple[Out, Aux], In]]: + # Small type hint fix for jax's `value_and_grad` (preserves the signature of the callable). + return jax.value_and_grad(fn, argnums=argnums, has_aux=has_aux) # type: ignore + + +_T = TypeVar("_T") + + +# @functools.wraps(flax.struct.field) +@overload # `default` and `default_factory` are optional and mutually exclusive. +def field( + *, + default: _T, + init: bool = True, + repr: bool = True, + hash: bool | None = None, + compare: bool = True, + metadata: Mapping[Any, Any] | None = None, + kw_only: bool = ..., + pytree_node: bool = True, +) -> _T: ... +@overload +def field( + *, + default_factory: Callable[[], _T], + init: bool = True, + repr: bool = True, + hash: bool | None = None, + compare: bool = True, + metadata: Mapping[Any, Any] | None = None, + kw_only: bool = ..., + pytree_node: bool = True, +) -> _T: ... +@overload +def field( + *, + init: bool = True, + repr: bool = True, + hash: bool | None = None, + compare: bool = True, + metadata: Mapping[Any, Any] | None = None, + kw_only: bool = ..., + pytree_node: bool = True, +) -> Any: ... + + +def field( + *, + default=dataclasses.MISSING, + default_factory=dataclasses.MISSING, + init=True, + repr=True, + hash=None, + compare=True, + metadata: Mapping[Any, Any] | None = None, + kw_only=dataclasses.MISSING, + pytree_node: bool | None = None, +): + """Small Typing fix for `flax.struct.field`. + + - Add type annotations so it doesn't drop the signature of the `dataclasses.field` function. 
+    - Give `pytree_node` a default value of `False` for ints and bools, and `True` for
+      everything else.
+    """
+    if pytree_node is None and isinstance(default, int):  # note: also includes `bool`.
+        pytree_node = False
+    if pytree_node is None:
+        pytree_node = True
+    if metadata is None:
+        metadata = {}
+    else:
+        metadata = dict(metadata)
+    metadata.setdefault("pytree_node", pytree_node)
+    return dataclasses.field(
+        default=default,
+        default_factory=default_factory,
+        init=init,
+        repr=repr,
+        hash=hash,
+        compare=compare,
+        metadata=metadata,
+        kw_only=kw_only,
+    )  # type: ignore
diff --git a/pyproject.toml b/pyproject.toml
index 395ef4d6..272aa387 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,8 +7,9 @@ authors = [
     { name = "César Miguel Valdez Córdova", email = "cesar.valdez@mila.quebec" },
 ]
 dependencies = [
-    "torch>=2.4.0",
-    "jax>=0.4.31",
+    "torch==2.4.1",
+    "jax==0.4.33",
+    "jaxlib==0.4.33",
     "hydra-core>=1.3.2",
     "wandb>=0.17.6",
     "lightning>=2.4.0",
@@ -31,6 +32,10 @@ dependencies = [
     # Only pinning this so that we can install the hydra-orion-sweeper.
     "kaleido==0.2.1",
     "hydra-orion-sweeper>=1.6.4",
+    "gymnax @ git+https://www.github.com/lebrice/gymnax@fix-classic-control-rendering",
+    "rejax>=0.1.0",
+    "xtils[jitpp] @ git+https://github.com/jessefarebro/xtils",
+    "gymnasium[classic-control]>=0.29.1",
 ]
 readme = "README.md"
 requires-python = ">= 3.10"
@@ -52,7 +57,22 @@ docs = [
     "mkdocs-section-index>=0.3.9",
     "mkdocs-macros-plugin>=1.0.5",
 ]
-gpu = ["jax[cuda12]>=0.4.31"]
+gpu = [
+    "jax[cuda12]>=0.4.31",
+    "nvidia-cublas-cu12==12.1.3.1",
+    "nvidia-cuda-cupti-cu12==12.1.105",
+    "nvidia-cuda-nvcc-cu12==12.6.68",
+    "nvidia-cuda-nvrtc-cu12==12.1.105",
+    "nvidia-cuda-runtime-cu12==12.1.105",
+    "nvidia-cudnn-cu12==9.1.0.70",
+    "nvidia-cufft-cu12==11.0.2.54",
+    "nvidia-curand-cu12==10.3.2.106",
+    "nvidia-cusolver-cu12==11.4.5.107",
+    "nvidia-cusparse-cu12==12.1.0.106",
+    "nvidia-nccl-cu12==2.20.5",
+    "nvidia-nvjitlink-cu12==12.6.68",
+    "nvidia-nvtx-cu12==12.1.105",
+]
 
 [build-system]
 requires = ["hatchling"]
@@ -62,6 +82,7 @@ build-backend = "hatchling.build"
 [tool.pytest.ini_options]
 testpaths = ["project", "docs"]
 norecursedirs = [".venv"]
+# Required to use torch deterministic mode.
 env = ["CUBLAS_WORKSPACE_CONFIG=:4096:8"]
 addopts = [
     # todo: look into using https://github.com/scientific-python/pytest-doctestplus
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 3b478845..a9b1c450 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -11,9 +11,16 @@
 -e file:.
absl-py==2.1.0 + # via brax # via chex + # via distrax + # via dm-env + # via ml-collections + # via mujoco + # via mujoco-mjx # via optax # via orbax-checkpoint + # via tensorflow-probability aiohappyeyeballs==2.4.0 # via aiohttp aiohttp==3.10.5 @@ -38,8 +45,12 @@ beautifulsoup4==4.12.3 # via gdown black==24.8.0 # via research-project-template +blinker==1.8.2 + # via flask bracex==2.5 # via wcmatch +brax==0.10.5 + # via rejax certifi==2024.8.30 # via requests # via sentry-sdk @@ -48,20 +59,29 @@ cfgv==3.4.0 charset-normalizer==3.3.2 # via requests chex==0.1.86 + # via distrax + # via evosax + # via gymnax # via optax click==8.1.7 # via black + # via flask # via mkdocs # via mkdocstrings # via wandb cloudpickle==3.0.0 + # via gym + # via gymnasium # via orion # via submitit + # via tensorflow-probability colorama==0.4.6 # via griffe # via mkdocs-material colorlog==6.8.2 # via hydra-colorlog +contextlib2==21.6.0 + # via ml-collections contourpy==1.3.0 # via matplotlib coverage==7.6.1 @@ -69,26 +89,42 @@ coverage==7.6.1 # via pytest-testmon cycler==0.12.1 # via matplotlib -datasets==2.21.0 +datasets==3.0.0 # via evaluate # via research-project-template +decorator==5.1.1 + # via tensorflow-probability dill==0.3.8 # via datasets # via evaluate # via multiprocess distlib==0.3.8 # via virtualenv +distrax==0.1.5 + # via rejax +dm-env==1.6 + # via brax +dm-tree==0.1.8 + # via dm-env + # via tensorflow-probability dnspython==2.6.1 # via pymongo docker-pycreds==0.4.0 # via wandb docstring-parser==0.16 # via simple-parsing +dotmap==1.3.30 + # via evosax etils==1.9.4 + # via brax + # via mujoco + # via mujoco-mjx # via optax # via orbax-checkpoint -evaluate==0.4.2 +evaluate==0.4.3 # via research-project-template +evosax==0.1.6 + # via rejax exceptiongroup==1.2.2 # via pytest execnet==2.1.1 @@ -98,7 +134,9 @@ falcon==3.1.3 # via orion falcon-cors==1.1.7 # via orion -filelock==3.15.4 +farama-notifications==0.0.4 + # via gymnasium +filelock==3.16.1 # via datasets # via gdown # via huggingface-hub @@ -107,7 +145,16 @@ filelock==3.15.4 # via transformers # via triton # via virtualenv +flask==3.0.3 + # via brax + # via flask-cors +flask-cors==5.0.0 + # via brax flax==0.8.5 + # via brax + # via evosax + # via gymnax + # via rejax # via torch-jax-interop fonttools==4.53.1 # via matplotlib @@ -122,6 +169,8 @@ fsspec==2024.6.1 # via lightning # via pytorch-lightning # via torch +gast==0.6.0 + # via tensorflow-probability gdown==5.2.0 # via research-project-template ghp-import==2.1.0 @@ -131,11 +180,26 @@ gitdb==4.0.11 gitpython==3.1.43 # via orion # via wandb -griffe==1.2.0 +glfw==2.7.0 + # via mujoco +griffe==1.3.1 # via mkdocstrings-python +grpcio==1.66.1 + # via brax gunicorn==23.0.0 # via orion -huggingface-hub==0.24.6 +gym==0.26.2 + # via brax + # via gymnax +gym-notices==0.0.8 + # via gym +gymnasium==0.29.1 + # via gymnax + # via research-project-template +gymnax @ git+https://www.github.com/lebrice/gymnax@1d4a4b45cfa291de896cd7005fe624420dc6106c + # via rejax + # via research-project-template +huggingface-hub==0.25.0 # via datasets # via evaluate # via tokenizers @@ -156,34 +220,54 @@ hydra-submitit-launcher==1.2.0 # via research-project-template hydra-zen==0.13.0 # via research-project-template -identify==2.6.0 +identify==2.6.1 # via pre-commit -idna==3.8 +idna==3.10 # via requests # via yarl -importlib-resources==6.4.4 +importlib-resources==6.4.5 # via etils iniconfig==2.0.0 # via pytest -jax==0.4.31 +itsdangerous==2.2.0 + # via flask +jax==0.4.33 + # via brax # via chex + # via distrax + # via evosax 
# via flax + # via gymnax + # via jaxopt + # via mujoco-mjx # via optax # via orbax-checkpoint # via pytorch2jax # via research-project-template # via torch-jax-interop -jax-cuda12-pjrt==0.4.31 + # via xtils +jax-cuda12-pjrt==0.4.33 # via jax-cuda12-plugin -jax-cuda12-plugin==0.4.31 +jax-cuda12-plugin==0.4.33 # via jax -jaxlib==0.4.31 +jaxlib==0.4.33 + # via brax # via chex + # via distrax + # via evosax + # via gymnax # via jax + # via jaxopt + # via mujoco-mjx # via optax # via orbax-checkpoint # via pytorch2jax + # via research-project-template +jaxopt==0.8.3 + # via brax jinja2==3.1.4 + # via brax + # via flask # via mkdocs # via mkdocs-macros-plugin # via mkdocs-material @@ -219,8 +303,12 @@ markupsafe==2.1.5 # via mkdocs # via mkdocs-autorefs # via mkdocstrings + # via werkzeug matplotlib==3.9.2 + # via evosax + # via gymnax # via research-project-template + # via seaborn mdurl==0.1.2 # via markdown-it-py mergedeep==1.3.4 @@ -248,9 +336,9 @@ mkdocs-get-deps==0.2.0 # via mkdocs mkdocs-literate-nav==0.6.1 # via research-project-template -mkdocs-macros-plugin==1.0.5 +mkdocs-macros-plugin==1.2.0 # via research-project-template -mkdocs-material==9.5.34 +mkdocs-material==9.5.35 # via research-project-template mkdocs-material-extensions==1.3.1 # via mkdocs-material @@ -258,22 +346,29 @@ mkdocs-section-index==0.3.9 # via research-project-template mkdocs-video==1.5.0 # via research-project-template -mkdocstrings==0.26.0 +mkdocstrings==0.26.1 # via mkdocstrings-python # via research-project-template mkdocstrings-python==1.11.1 # via mkdocstrings mktestdocs==0.2.2 -ml-dtypes==0.4.0 +ml-collections==0.1.1 + # via brax +ml-dtypes==0.5.0 # via jax # via jaxlib # via tensorstore mpmath==1.3.0 # via sympy -msgpack==1.0.8 +msgpack==1.1.0 # via flax # via orbax-checkpoint -multidict==6.0.5 +mujoco==3.2.3 + # via brax + # via mujoco-mjx +mujoco-mjx==3.2.3 + # via brax +multidict==6.1.0 # via aiohttp # via yarl multiprocess==0.70.16 @@ -290,67 +385,93 @@ networkx==3.3 nodeenv==1.9.1 # via pre-commit numpy==1.26.4 + # via brax # via chex # via contourpy # via datasets + # via distrax + # via dm-env # via evaluate + # via evosax # via flax + # via gym + # via gymnasium # via jax # via jaxlib + # via jaxopt # via matplotlib # via ml-dtypes + # via mujoco # via opt-einsum # via optax # via orbax-checkpoint # via orion # via pandas # via pyarrow + # via rejax # via scikit-learn # via scikit-optimize # via scipy + # via seaborn # via tensor-regression + # via tensorboardx + # via tensorflow-probability # via tensorstore # via torchmetrics # via torchvision # via transformers + # via trimesh nvidia-cublas-cu12==12.1.3.1 # via jax-cuda12-plugin # via nvidia-cudnn-cu12 # via nvidia-cusolver-cu12 + # via research-project-template # via torch nvidia-cuda-cupti-cu12==12.1.105 # via jax-cuda12-plugin + # via research-project-template # via torch nvidia-cuda-nvcc-cu12==12.6.68 # via jax-cuda12-plugin + # via research-project-template nvidia-cuda-nvrtc-cu12==12.1.105 + # via research-project-template # via torch nvidia-cuda-runtime-cu12==12.1.105 # via jax-cuda12-plugin + # via research-project-template # via torch nvidia-cudnn-cu12==9.1.0.70 # via jax-cuda12-plugin + # via research-project-template # via torch nvidia-cufft-cu12==11.0.2.54 # via jax-cuda12-plugin + # via research-project-template # via torch nvidia-curand-cu12==10.3.2.106 + # via research-project-template # via torch nvidia-cusolver-cu12==11.4.5.107 # via jax-cuda12-plugin + # via research-project-template # via torch nvidia-cusparse-cu12==12.1.0.106 
# via jax-cuda12-plugin # via nvidia-cusolver-cu12 + # via research-project-template # via torch nvidia-nccl-cu12==2.20.5 # via jax-cuda12-plugin + # via research-project-template # via torch nvidia-nvjitlink-cu12==12.6.68 # via jax-cuda12-plugin # via nvidia-cusolver-cu12 # via nvidia-cusparse-cu12 + # via research-project-template nvidia-nvtx-cu12==12.1.105 + # via research-project-template # via torch omegaconf==2.3.0 # via hydra-core @@ -359,8 +480,11 @@ omegaconf==2.3.0 opt-einsum==3.3.0 # via jax optax==0.2.3 + # via brax # via flax -orbax-checkpoint==0.6.1 + # via rejax +orbax-checkpoint==0.6.4 + # via brax # via flax orion==0.2.7 # via hydra-orion-sweeper @@ -375,10 +499,12 @@ packaging==24.1 # via lightning-utilities # via matplotlib # via mkdocs + # via mkdocs-macros-plugin # via plotly # via pytest # via pytorch-lightning # via scikit-optimize + # via tensorboardx # via torchmetrics # via transformers paginate==0.5.7 @@ -387,25 +513,28 @@ pandas==2.2.2 # via datasets # via evaluate # via orion + # via seaborn pathspec==0.12.1 # via black # via mkdocs pillow==10.4.0 + # via brax # via matplotlib # via torchvision -platformdirs==4.2.2 +platformdirs==4.3.6 # via black # via mkdocs-get-deps # via mkdocstrings # via virtualenv # via wandb -plotly==5.24.0 +plotly==5.24.1 # via orion pluggy==1.5.0 # via pytest pre-commit==3.8.0 -protobuf==5.28.0 +protobuf==5.28.2 # via orbax-checkpoint + # via tensorboardx # via wandb psutil==6.0.0 # via orion @@ -416,23 +545,27 @@ pyaml==24.7.0 # via scikit-optimize pyarrow==17.0.0 # via datasets -pydantic==2.8.2 +pydantic==2.9.2 # via research-project-template -pydantic-core==2.20.1 +pydantic-core==2.23.4 # via pydantic +pygame==2.6.0 + # via gymnasium pygments==2.18.0 # via mkdocs-material # via rich pymdown-extensions==10.9 # via mkdocs-material # via mkdocstrings -pymongo==4.8.0 +pymongo==4.9.1 # via orion +pyopengl==3.1.7 + # via mujoco pyparsing==3.1.4 # via matplotlib pysocks==1.7.1 # via requests -pytest==8.3.2 +pytest==8.3.3 # via orion # via pytest-benchmark # via pytest-cov @@ -447,7 +580,7 @@ pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-datadir==1.5.0 # via pytest-regressions -pytest-env==1.1.3 +pytest-env==1.1.5 pytest-regressions==2.5.0 # via tensor-regression pytest-skip-slow==0.0.5 @@ -459,20 +592,25 @@ python-dateutil==2.9.0.post0 # via matplotlib # via mkdocs-macros-plugin # via pandas +pytinyrenderer==0.0.14 + # via brax pytorch-lightning==2.4.0 # via lightning pytorch2jax==0.1.0 # via torch-jax-interop -pytz==2024.1 +pytz==2024.2 # via pandas pyyaml==6.0.2 # via datasets + # via evosax # via flax + # via gymnax # via huggingface-hub # via lightning # via mkdocs # via mkdocs-get-deps # via mkdocs-macros-plugin + # via ml-collections # via omegaconf # via orbax-checkpoint # via orion @@ -482,13 +620,16 @@ pyyaml==6.0.2 # via pytest-regressions # via pytorch-lightning # via pyyaml-env-tag + # via rejax # via transformers # via wandb pyyaml-env-tag==0.1 # via mkdocs -regex==2024.7.24 +regex==2024.9.11 # via mkdocs-material # via transformers +rejax==0.1.0 + # via research-project-template requests==2.32.3 # via datasets # via evaluate @@ -498,50 +639,61 @@ requests==2.32.3 # via orion # via transformers # via wandb -rich==13.8.0 +rich==13.8.1 # via flax # via research-project-template -ruff==0.6.3 -safetensors==0.4.4 +ruff==0.6.5 +safetensors==0.4.5 # via transformers -scikit-learn==1.5.1 +scikit-learn==1.5.2 # via orion # via research-project-template # via scikit-optimize scikit-optimize==0.10.2 # via orion scipy==1.14.1 + # via 
brax # via jax # via jaxlib + # via jaxopt + # via mujoco-mjx # via orion # via scikit-learn # via scikit-optimize -sentry-sdk==2.13.0 +seaborn==0.13.2 + # via gymnax +sentry-sdk==2.14.0 # via wandb setproctitle==1.3.3 # via wandb -setuptools==74.1.1 +setuptools==75.1.0 # via lightning-utilities # via wandb -simple-parsing==0.1.5 +simple-parsing==0.1.6 # via research-project-template six==1.16.0 # via docker-pycreds + # via ml-collections # via python-dateutil + # via tensorflow-probability smmap==5.0.1 # via gitdb soupsieve==2.6 # via beautifulsoup4 -submitit==1.5.1 +submitit==1.5.2 # via hydra-submitit-launcher -sympy==1.13.2 +sympy==1.13.3 # via torch tabulate==0.9.0 # via orion tenacity==9.0.0 # via plotly tensor-regression==0.0.8 -tensorstore==0.1.64 +tensorboardx==2.6.2.2 + # via brax +tensorflow-probability==0.24.0 + # via distrax +tensorstore==0.1.65 # via flax # via orbax-checkpoint termcolor==2.4.0 @@ -557,7 +709,7 @@ tomli==2.0.1 # via pytest-env toolz==0.12.1 # via chex -torch==2.4.0 +torch==2.4.1 # via lightning # via pytorch-lightning # via pytorch2jax @@ -568,10 +720,10 @@ torch==2.4.0 # via torchvision torch-jax-interop==0.0.7 # via research-project-template -torchmetrics==1.4.1 +torchmetrics==1.4.2 # via lightning # via pytorch-lightning -torchvision==0.19.0 +torchvision==0.19.1 # via research-project-template tqdm==4.66.5 # via datasets @@ -584,18 +736,24 @@ tqdm==4.66.5 # via transformers transformers==4.44.2 # via research-project-template +trimesh==4.4.9 + # via brax + # via mujoco-mjx triton==3.0.0 # via torch typing-extensions==4.12.2 # via black + # via brax # via chex # via etils # via flax + # via gymnasium # via huggingface-hub # via hydra-orion-sweeper # via hydra-zen # via lightning # via lightning-utilities + # via multidict # via orbax-checkpoint # via pydantic # via pydantic-core @@ -605,21 +763,25 @@ typing-extensions==4.12.2 # via torch tzdata==2024.1 # via pandas -urllib3==2.2.2 +urllib3==2.2.3 # via requests # via sentry-sdk -virtualenv==20.26.3 +virtualenv==20.26.5 # via pre-commit -wandb==0.17.8 +wandb==0.18.1 # via research-project-template watchdog==5.0.2 # via mkdocs wcmatch==9.0 # via mkdocs-awesome-pages-plugin +werkzeug==3.0.4 + # via flask +xtils @ git+https://github.com/jessefarebro/xtils@b20807d980b059c598a77ff2ef05075594f49f63 + # via research-project-template xxhash==3.5.0 # via datasets # via evaluate -yarl==1.9.8 +yarl==1.11.1 # via aiohttp -zipp==3.20.1 +zipp==3.20.2 # via etils diff --git a/requirements.lock b/requirements.lock index ca6eac08..389cd098 100644 --- a/requirements.lock +++ b/requirements.lock @@ -11,9 +11,16 @@ -e file:. 
 absl-py==2.1.0
+    # via brax
     # via chex
+    # via distrax
+    # via dm-env
+    # via ml-collections
+    # via mujoco
+    # via mujoco-mjx
     # via optax
     # via orbax-checkpoint
+    # via tensorflow-probability
 aiohappyeyeballs==2.4.0
     # via aiohttp
 aiohttp==3.10.5
@@ -38,50 +45,79 @@ beautifulsoup4==4.12.3
     # via gdown
 black==24.8.0
     # via research-project-template
+blinker==1.8.2
+    # via flask
 bracex==2.5
     # via wcmatch
+brax==0.10.5
+    # via rejax
 certifi==2024.8.30
     # via requests
     # via sentry-sdk
 charset-normalizer==3.3.2
     # via requests
 chex==0.1.86
+    # via distrax
+    # via evosax
+    # via gymnax
     # via optax
 click==8.1.7
     # via black
+    # via flask
     # via mkdocs
     # via mkdocstrings
     # via wandb
 cloudpickle==3.0.0
+    # via gym
+    # via gymnasium
     # via orion
     # via submitit
+    # via tensorflow-probability
 colorama==0.4.6
     # via griffe
     # via mkdocs-material
 colorlog==6.8.2
     # via hydra-colorlog
+contextlib2==21.6.0
+    # via ml-collections
 contourpy==1.3.0
     # via matplotlib
 cycler==0.12.1
     # via matplotlib
-datasets==2.21.0
+datasets==3.0.0
     # via evaluate
     # via research-project-template
+decorator==5.1.1
+    # via tensorflow-probability
 dill==0.3.8
     # via datasets
     # via evaluate
     # via multiprocess
+distrax==0.1.5
+    # via rejax
+dm-env==1.6
+    # via brax
+dm-tree==0.1.8
+    # via dm-env
+    # via tensorflow-probability
 dnspython==2.6.1
     # via pymongo
 docker-pycreds==0.4.0
     # via wandb
 docstring-parser==0.16
     # via simple-parsing
+dotmap==1.3.30
+    # via evosax
 etils==1.9.4
+    # via brax
+    # via mujoco
+    # via mujoco-mjx
     # via optax
     # via orbax-checkpoint
-evaluate==0.4.2
+evaluate==0.4.3
     # via research-project-template
+evosax==0.1.6
+    # via rejax
 exceptiongroup==1.2.2
     # via pytest
 falcon==3.1.3
@@ -89,7 +125,9 @@ falcon==3.1.3
     # via orion
 falcon-cors==1.1.7
     # via orion
-filelock==3.15.4
+farama-notifications==0.0.4
+    # via gymnasium
+filelock==3.16.1
     # via datasets
     # via gdown
     # via huggingface-hub
@@ -97,7 +135,16 @@ filelock==3.15.4
     # via torch
     # via transformers
     # via triton
+flask==3.0.3
+    # via brax
+    # via flask-cors
+flask-cors==5.0.0
+    # via brax
 flax==0.8.5
+    # via brax
+    # via evosax
+    # via gymnax
+    # via rejax
     # via torch-jax-interop
 fonttools==4.53.1
     # via matplotlib
@@ -112,6 +159,8 @@ fsspec==2024.6.1
     # via lightning
     # via pytorch-lightning
     # via torch
+gast==0.6.0
+    # via tensorflow-probability
 gdown==5.2.0
     # via research-project-template
 ghp-import==2.1.0
@@ -121,11 +170,26 @@ gitdb==4.0.11
 gitpython==3.1.43
     # via orion
     # via wandb
-griffe==1.2.0
+glfw==2.7.0
+    # via mujoco
+griffe==1.3.1
     # via mkdocstrings-python
+grpcio==1.66.1
+    # via brax
 gunicorn==23.0.0
     # via orion
-huggingface-hub==0.24.6
+gym==0.26.2
+    # via brax
+    # via gymnax
+gym-notices==0.0.8
+    # via gym
+gymnasium==0.29.1
+    # via gymnax
+    # via research-project-template
+gymnax @ git+https://www.github.com/lebrice/gymnax@1d4a4b45cfa291de896cd7005fe624420dc6106c
+    # via rejax
+    # via research-project-template
+huggingface-hub==0.25.0
     # via datasets
     # via evaluate
     # via tokenizers
@@ -146,32 +210,52 @@ hydra-submitit-launcher==1.2.0
     # via research-project-template
 hydra-zen==0.13.0
     # via research-project-template
-idna==3.8
+idna==3.10
     # via requests
     # via yarl
-importlib-resources==6.4.4
+importlib-resources==6.4.5
     # via etils
 iniconfig==2.0.0
     # via pytest
-jax==0.4.31
+itsdangerous==2.2.0
+    # via flask
+jax==0.4.33
+    # via brax
     # via chex
+    # via distrax
+    # via evosax
     # via flax
+    # via gymnax
+    # via jaxopt
+    # via mujoco-mjx
     # via optax
     # via orbax-checkpoint
     # via pytorch2jax
     # via research-project-template
     # via torch-jax-interop
-jax-cuda12-pjrt==0.4.31
+    # via xtils
+jax-cuda12-pjrt==0.4.33
     # via jax-cuda12-plugin
-jax-cuda12-plugin==0.4.31
+jax-cuda12-plugin==0.4.33
     # via jax
-jaxlib==0.4.31
+jaxlib==0.4.33
+    # via brax
     # via chex
+    # via distrax
+    # via evosax
+    # via gymnax
     # via jax
+    # via jaxopt
+    # via mujoco-mjx
     # via optax
     # via orbax-checkpoint
     # via pytorch2jax
+    # via research-project-template
+jaxopt==0.8.3
+    # via brax
 jinja2==3.1.4
+    # via brax
+    # via flask
     # via mkdocs
     # via mkdocs-macros-plugin
     # via mkdocs-material
@@ -207,8 +291,12 @@ markupsafe==2.1.5
     # via mkdocs
     # via mkdocs-autorefs
     # via mkdocstrings
+    # via werkzeug
 matplotlib==3.9.2
+    # via evosax
+    # via gymnax
     # via research-project-template
+    # via seaborn
 mdurl==0.1.2
     # via markdown-it-py
 mergedeep==1.3.4
@@ -236,9 +324,9 @@ mkdocs-get-deps==0.2.0
     # via mkdocs
 mkdocs-literate-nav==0.6.1
     # via research-project-template
-mkdocs-macros-plugin==1.0.5
+mkdocs-macros-plugin==1.2.0
     # via research-project-template
-mkdocs-material==9.5.34
+mkdocs-material==9.5.35
     # via research-project-template
 mkdocs-material-extensions==1.3.1
     # via mkdocs-material
@@ -246,21 +334,28 @@ mkdocs-section-index==0.3.9
     # via research-project-template
 mkdocs-video==1.5.0
     # via research-project-template
-mkdocstrings==0.26.0
+mkdocstrings==0.26.1
     # via mkdocstrings-python
     # via research-project-template
 mkdocstrings-python==1.11.1
     # via mkdocstrings
-ml-dtypes==0.4.0
+ml-collections==0.1.1
+    # via brax
+ml-dtypes==0.5.0
     # via jax
     # via jaxlib
     # via tensorstore
 mpmath==1.3.0
     # via sympy
-msgpack==1.0.8
+msgpack==1.1.0
     # via flax
     # via orbax-checkpoint
-multidict==6.0.5
+mujoco==3.2.3
+    # via brax
+    # via mujoco-mjx
+mujoco-mjx==3.2.3
+    # via brax
+multidict==6.1.0
     # via aiohttp
     # via yarl
 multiprocess==0.70.16
@@ -275,66 +370,92 @@ nest-asyncio==1.6.0
 networkx==3.3
     # via torch
 numpy==2.1.1
+    # via brax
     # via chex
     # via contourpy
     # via datasets
+    # via distrax
+    # via dm-env
     # via evaluate
+    # via evosax
     # via flax
+    # via gym
+    # via gymnasium
     # via jax
     # via jaxlib
+    # via jaxopt
     # via matplotlib
     # via ml-dtypes
+    # via mujoco
     # via opt-einsum
     # via optax
     # via orbax-checkpoint
     # via orion
     # via pandas
     # via pyarrow
+    # via rejax
     # via scikit-learn
     # via scikit-optimize
     # via scipy
+    # via seaborn
+    # via tensorboardx
+    # via tensorflow-probability
     # via tensorstore
     # via torchmetrics
     # via torchvision
     # via transformers
+    # via trimesh
 nvidia-cublas-cu12==12.1.3.1
     # via jax-cuda12-plugin
     # via nvidia-cudnn-cu12
     # via nvidia-cusolver-cu12
+    # via research-project-template
     # via torch
 nvidia-cuda-cupti-cu12==12.1.105
     # via jax-cuda12-plugin
+    # via research-project-template
     # via torch
 nvidia-cuda-nvcc-cu12==12.6.68
     # via jax-cuda12-plugin
+    # via research-project-template
 nvidia-cuda-nvrtc-cu12==12.1.105
+    # via research-project-template
     # via torch
 nvidia-cuda-runtime-cu12==12.1.105
     # via jax-cuda12-plugin
+    # via research-project-template
     # via torch
 nvidia-cudnn-cu12==9.1.0.70
     # via jax-cuda12-plugin
+    # via research-project-template
     # via torch
 nvidia-cufft-cu12==11.0.2.54
     # via jax-cuda12-plugin
+    # via research-project-template
     # via torch
 nvidia-curand-cu12==10.3.2.106
+    # via research-project-template
     # via torch
 nvidia-cusolver-cu12==11.4.5.107
     # via jax-cuda12-plugin
+    # via research-project-template
     # via torch
 nvidia-cusparse-cu12==12.1.0.106
     # via jax-cuda12-plugin
     # via nvidia-cusolver-cu12
+    # via research-project-template
     # via torch
 nvidia-nccl-cu12==2.20.5
     # via jax-cuda12-plugin
+    # via research-project-template
     # via torch
 nvidia-nvjitlink-cu12==12.6.68
     # via jax-cuda12-plugin
     # via nvidia-cusolver-cu12
     # via nvidia-cusparse-cu12
+    # via research-project-template
 nvidia-nvtx-cu12==12.1.105
+    # via research-project-template
     # via torch
 omegaconf==2.3.0
     # via hydra-core
@@ -343,8 +464,11 @@ omegaconf==2.3.0
 opt-einsum==3.3.0
     # via jax
 optax==0.2.3
+    # via brax
     # via flax
-orbax-checkpoint==0.6.1
+    # via rejax
+orbax-checkpoint==0.6.4
+    # via brax
     # via flax
 orion==0.2.7
     # via hydra-orion-sweeper
@@ -359,10 +483,12 @@ packaging==24.1
     # via lightning-utilities
     # via matplotlib
     # via mkdocs
+    # via mkdocs-macros-plugin
     # via plotly
     # via pytest
     # via pytorch-lightning
     # via scikit-optimize
+    # via tensorboardx
     # via torchmetrics
     # via transformers
 paginate==0.5.7
@@ -371,23 +497,26 @@ pandas==2.2.2
     # via datasets
     # via evaluate
     # via orion
+    # via seaborn
 pathspec==0.12.1
     # via black
     # via mkdocs
 pillow==10.4.0
+    # via brax
     # via matplotlib
     # via torchvision
-platformdirs==4.2.2
+platformdirs==4.3.6
     # via black
     # via mkdocs-get-deps
     # via mkdocstrings
     # via wandb
-plotly==5.24.0
+plotly==5.24.1
     # via orion
 pluggy==1.5.0
     # via pytest
-protobuf==5.28.0
+protobuf==5.28.2
     # via orbax-checkpoint
+    # via tensorboardx
     # via wandb
 psutil==6.0.0
     # via orion
@@ -396,43 +525,52 @@ pyaml==24.7.0
     # via scikit-optimize
 pyarrow==17.0.0
     # via datasets
-pydantic==2.8.2
+pydantic==2.9.2
     # via research-project-template
-pydantic-core==2.20.1
+pydantic-core==2.23.4
     # via pydantic
+pygame==2.6.0
+    # via gymnasium
 pygments==2.18.0
     # via mkdocs-material
     # via rich
 pymdown-extensions==10.9
     # via mkdocs-material
     # via mkdocstrings
-pymongo==4.8.0
+pymongo==4.9.1
     # via orion
+pyopengl==3.1.7
+    # via mujoco
 pyparsing==3.1.4
     # via matplotlib
 pysocks==1.7.1
     # via requests
-pytest==8.3.2
+pytest==8.3.3
     # via orion
 python-dateutil==2.9.0.post0
     # via ghp-import
     # via matplotlib
     # via mkdocs-macros-plugin
     # via pandas
+pytinyrenderer==0.0.14
+    # via brax
 pytorch-lightning==2.4.0
     # via lightning
 pytorch2jax==0.1.0
     # via torch-jax-interop
-pytz==2024.1
+pytz==2024.2
     # via pandas
 pyyaml==6.0.2
     # via datasets
+    # via evosax
     # via flax
+    # via gymnax
     # via huggingface-hub
     # via lightning
     # via mkdocs
     # via mkdocs-get-deps
     # via mkdocs-macros-plugin
+    # via ml-collections
     # via omegaconf
     # via orbax-checkpoint
     # via orion
@@ -440,13 +578,16 @@
     # via pymdown-extensions
     # via pytorch-lightning
     # via pyyaml-env-tag
+    # via rejax
     # via transformers
     # via wandb
 pyyaml-env-tag==0.1
     # via mkdocs
-regex==2024.7.24
+regex==2024.9.11
     # via mkdocs-material
     # via transformers
+rejax==0.1.0
+    # via research-project-template
 requests==2.32.3
     # via datasets
     # via evaluate
@@ -456,48 +597,59 @@ requests==2.32.3
     # via orion
     # via transformers
     # via wandb
-rich==13.8.0
+rich==13.8.1
     # via flax
     # via research-project-template
-safetensors==0.4.4
+safetensors==0.4.5
     # via transformers
-scikit-learn==1.5.1
+scikit-learn==1.5.2
     # via orion
     # via research-project-template
     # via scikit-optimize
 scikit-optimize==0.10.2
     # via orion
 scipy==1.14.1
+    # via brax
     # via jax
     # via jaxlib
+    # via jaxopt
+    # via mujoco-mjx
     # via orion
     # via scikit-learn
     # via scikit-optimize
-sentry-sdk==2.13.0
+seaborn==0.13.2
+    # via gymnax
+sentry-sdk==2.14.0
     # via wandb
 setproctitle==1.3.3
     # via wandb
-setuptools==74.1.1
+setuptools==75.1.0
     # via lightning-utilities
     # via wandb
-simple-parsing==0.1.5
+simple-parsing==0.1.6
     # via research-project-template
 six==1.16.0
     # via docker-pycreds
+    # via ml-collections
     # via python-dateutil
+    # via tensorflow-probability
 smmap==5.0.1
     # via gitdb
 soupsieve==2.6
     # via beautifulsoup4
-submitit==1.5.1
+submitit==1.5.2
     # via hydra-submitit-launcher
-sympy==1.13.2
+sympy==1.13.3
     # via torch
 tabulate==0.9.0
     # via orion
 tenacity==9.0.0
     # via plotly
-tensorstore==0.1.64
+tensorboardx==2.6.2.2
+    # via brax
+tensorflow-probability==0.24.0
+    # via distrax
+tensorstore==0.1.65
     # via flax
     # via orbax-checkpoint
 termcolor==2.4.0
@@ -511,7 +663,7 @@ tomli==2.0.1
     # via pytest
 toolz==0.12.1
     # via chex
-torch==2.4.0
+torch==2.4.1
     # via lightning
     # via pytorch-lightning
     # via pytorch2jax
@@ -521,10 +673,10 @@ torch==2.4.0
     # via torchvision
 torch-jax-interop==0.0.7
     # via research-project-template
-torchmetrics==1.4.1
+torchmetrics==1.4.2
     # via lightning
     # via pytorch-lightning
-torchvision==0.19.0
+torchvision==0.19.1
     # via research-project-template
 tqdm==4.66.5
     # via datasets
@@ -537,18 +689,24 @@ tqdm==4.66.5
     # via transformers
 transformers==4.44.2
     # via research-project-template
+trimesh==4.4.9
+    # via brax
+    # via mujoco-mjx
 triton==3.0.0
     # via torch
 typing-extensions==4.12.2
     # via black
+    # via brax
     # via chex
     # via etils
     # via flax
+    # via gymnasium
     # via huggingface-hub
     # via hydra-orion-sweeper
     # via hydra-zen
     # via lightning
     # via lightning-utilities
+    # via multidict
     # via orbax-checkpoint
     # via pydantic
     # via pydantic-core
@@ -558,19 +716,23 @@ typing-extensions==4.12.2
     # via torch
 tzdata==2024.1
     # via pandas
-urllib3==2.2.2
+urllib3==2.2.3
     # via requests
     # via sentry-sdk
-wandb==0.17.8
+wandb==0.18.1
     # via research-project-template
 watchdog==5.0.2
     # via mkdocs
 wcmatch==9.0
     # via mkdocs-awesome-pages-plugin
+werkzeug==3.0.4
+    # via flask
+xtils @ git+https://github.com/jessefarebro/xtils@b20807d980b059c598a77ff2ef05075594f49f63
+    # via research-project-template
 xxhash==3.5.0
     # via datasets
     # via evaluate
-yarl==1.9.8
+yarl==1.11.1
     # via aiohttp
-zipp==3.20.1
+zipp==3.20.2
     # via etils