From 36633bf2c67a75bb643f3e44fc860028c0066d02 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 7 Oct 2024 14:47:31 +0000 Subject: [PATCH] Move the JaxTrainer to a new "trainers" dir Signed-off-by: Fabrice Normandin --- project/algorithms/jax_rl_example.py | 2 +- project/algorithms/jax_rl_example_test.py | 2 +- project/configs/trainer/jax.yaml | 2 +- project/main.py | 2 +- project/trainers/__init__.py | 8 +++++ .../{algorithms => trainers}/jax_trainer.py | 33 ++++++++++--------- 6 files changed, 29 insertions(+), 20 deletions(-) create mode 100644 project/trainers/__init__.py rename project/{algorithms => trainers}/jax_trainer.py (95%) diff --git a/project/algorithms/jax_rl_example.py b/project/algorithms/jax_rl_example.py index e4c8be71..698a1621 100644 --- a/project/algorithms/jax_rl_example.py +++ b/project/algorithms/jax_rl_example.py @@ -30,7 +30,7 @@ from typing_extensions import TypeVar from xtils.jitpp import Static -from project.algorithms.jax_trainer import JaxCallback, JaxModule, JaxTrainer +from project.trainers.jax_trainer import JaxCallback, JaxModule, JaxTrainer from project.utils.typing_utils.jax_typing_utils import field, jit logger = get_logger(__name__) diff --git a/project/algorithms/jax_rl_example_test.py b/project/algorithms/jax_rl_example_test.py index 24be5462..519d5aa2 100644 --- a/project/algorithms/jax_rl_example_test.py +++ b/project/algorithms/jax_rl_example_test.py @@ -40,7 +40,7 @@ from typing_extensions import override from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback -from project.algorithms.jax_trainer import JaxTrainer, hparams_to_dict +from project.trainers.jax_trainer import JaxTrainer, hparams_to_dict from .jax_rl_example import ( EvalMetrics, diff --git a/project/configs/trainer/jax.yaml b/project/configs/trainer/jax.yaml index 534652f3..ae68ce4a 100644 --- a/project/configs/trainer/jax.yaml +++ b/project/configs/trainer/jax.yaml @@ -1,6 +1,6 @@ defaults: - callbacks: rich_progress_bar.yaml -_target_: project.algorithms.jax_trainer.JaxTrainer +_target_: project.trainers.jax_trainer.JaxTrainer max_epochs: 75 training_steps_per_epoch: 1 diff --git a/project/main.py b/project/main.py index f0c29a3f..b0fa985f 100644 --- a/project/main.py +++ b/project/main.py @@ -83,7 +83,7 @@ def run(experiment: Experiment) -> tuple[str, float | None, dict]: if datamodule is None: # todo: missing `rng` argument. 
- from project.algorithms.jax_trainer import JaxTrainer + from project.trainers.jax_trainer import JaxTrainer if isinstance(experiment.trainer, JaxTrainer): import jax.random diff --git a/project/trainers/__init__.py b/project/trainers/__init__.py new file mode 100644 index 00000000..f27ba440 --- /dev/null +++ b/project/trainers/__init__.py @@ -0,0 +1,8 @@ +from lightning.pytorch.trainer.trainer import Trainer + +from .jax_trainer import JaxTrainer + +__all__ = [ + "JaxTrainer", + "Trainer", +] diff --git a/project/algorithms/jax_trainer.py b/project/trainers/jax_trainer.py similarity index 95% rename from project/algorithms/jax_trainer.py rename to project/trainers/jax_trainer.py index da085353..c41da079 100644 --- a/project/algorithms/jax_trainer.py +++ b/project/trainers/jax_trainer.py @@ -16,7 +16,7 @@ import lightning import lightning.pytorch.callbacks import lightning.pytorch.loggers -import torch # noqa +from hydra.core.hydra_config import HydraConfig from typing_extensions import TypeVar from project.utils.typing_utils.jax_typing_utils import jit @@ -91,8 +91,7 @@ def teardown(self, trainer: JaxTrainer, module: JaxModule[Ts], stage: str, ts: T class JaxTrainer(flax.struct.PyTreeNode): - """A simplified version of the `[lightning.Trainer][lightning.pytorch.trainer.Trainer]` with a - fully jitted training loop. + """A simplified version of the `lightning.Trainer` with a fully jitted training loop. ## Assumptions: @@ -143,17 +142,21 @@ class JaxTrainer(flax.struct.PyTreeNode): but not both. - If you want to use [jax.vmap][] on the `fit` method, just remove the callbacks on the Trainer for now. + + ## TODOs / ideas + + - Add a checkpoint callback with orbax-checkpoint? """ - # num_epochs = np.ceil(algo.hp.total_timesteps / algo.hp.eval_freq).astype(int) max_epochs: int = flax.struct.field(pytree_node=False) - # iteration_steps = algo.hp.num_envs * algo.hp.num_steps - # num_iterations = np.ceil(algo.hp.eval_freq / iteration_steps).astype(int) - training_steps_per_epoch: int + training_steps_per_epoch: int = flax.struct.field(pytree_node=False) + + limit_val_batches: int = 0 + limit_test_batches: int = 0 # TODO: Getting some errors with the schema generation for lightning.Callback and - # lightning.pytorch.loggers.logger.Logger here. + # lightning.pytorch.loggers.logger.Logger here if we keep the type annotation. callbacks: Sequence = dataclasses.field(metadata={"pytree_node": False}, default_factory=tuple) logger: Any | None = flax.struct.field(pytree_node=False, default=None) @@ -162,22 +165,20 @@ class JaxTrainer(flax.struct.PyTreeNode): # strategy: str = flax.struct.field(pytree_node=False, default="auto") # devices: int | str = flax.struct.field(pytree_node=False, default="auto") - # min_epochs: int - # path to output directory, created dynamically by hydra # path generation pattern is specified in `configs/hydra/default.yaml` # use it to store all files generated during the run, like checkpoints and metrics - default_root_dir: str | Path | None = flax.struct.field(pytree_node=False, default="") + + default_root_dir: str | Path | None = flax.struct.field( + pytree_node=False, + default_factory=lambda: HydraConfig.get().runtime.output_dir, + ) # State variables: - # TODO: Figure out how to efficiently present these even when jit is turned off (currently - # replacing self entirely). + # TODO: figure out how to cleanly store / update these. 
current_epoch: int = flax.struct.field(pytree_node=True, default=0) global_step: int = flax.struct.field(pytree_node=True, default=0) - # TODO: Add a checkpoint callback with orbax-checkpoint? - limit_val_batches: int = 0 - limit_test_batches: int = 0 logged_metrics: dict = flax.struct.field(pytree_node=True, default_factory=dict) callback_metrics: dict = flax.struct.field(pytree_node=True, default_factory=dict) # todo: get the metrics from the callbacks?
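
After this move, the new `project/trainers/__init__.py` re-exports both `JaxTrainer` and Lightning's `Trainer`, so downstream code can import from the package root. Below is a minimal construction sketch (not part of the patch), using only field names visible in this diff and values from `project/configs/trainer/jax.yaml`. One consequence of the diff worth noting: the new `default_factory` for `default_root_dir` calls `HydraConfig.get()`, which raises when no Hydra app is running, so the sketch passes a directory explicitly. The `algo` object and the `fit` call are hypothetical here, since the `JaxModule` protocol is not shown in this hunk.

```python
# Minimal sketch (not part of the patch): constructing the relocated trainer.
# Field values mirror project/configs/trainer/jax.yaml; everything else is an
# assumption for illustration.
import jax.random

from project.trainers import JaxTrainer  # re-exported by the new package

trainer = JaxTrainer(
    max_epochs=75,                # as in configs/trainer/jax.yaml
    training_steps_per_epoch=1,   # as in configs/trainer/jax.yaml
    callbacks=(),                 # keep empty if you want to jax.vmap over `fit`
    logger=None,
    # The new default_factory calls HydraConfig.get(), which raises outside an
    # active Hydra run, so pass a directory explicitly in scripts and tests:
    default_root_dir="logs/debug",
)

rng = jax.random.PRNGKey(0)
# `algo` would be an object implementing the JaxModule protocol (not shown in
# this diff), e.g.:
# ts, metrics = trainer.fit(algo, rng=rng)
```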
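The "TODO: figure out how to cleanly store / update these" note concerns the state fields (`current_epoch`, `global_step`, `logged_metrics`, `callback_metrics`), which are pytree nodes on an immutable `flax.struct.PyTreeNode`. A sketch of the update pattern the removed comment alludes to ("currently replacing self entirely"), using the `.replace()` method that flax generates for `PyTreeNode` subclasses; the helper name is hypothetical:

```python
# Sketch of the state-update pattern hinted at above: JaxTrainer is a
# flax.struct.PyTreeNode, so its state can only change by building a new
# instance with .replace(); this keeps the training loop pure and jit-friendly.
from project.trainers import JaxTrainer


def advance_epoch(trainer: JaxTrainer, steps_this_epoch: int) -> JaxTrainer:
    """Hypothetical helper: return a copy with the epoch/step counters bumped."""
    return trainer.replace(
        current_epoch=trainer.current_epoch + 1,
        global_step=trainer.global_step + steps_this_epoch,
    )
```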