Move the JaxTrainer to a new "trainers" dir
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
lebrice committed Oct 7, 2024
1 parent 774f3c0 commit 36633bf
Showing 6 changed files with 29 additions and 20 deletions.
2 changes: 1 addition & 1 deletion project/algorithms/jax_rl_example.py
@@ -30,7 +30,7 @@
from typing_extensions import TypeVar
from xtils.jitpp import Static

-from project.algorithms.jax_trainer import JaxCallback, JaxModule, JaxTrainer
+from project.trainers.jax_trainer import JaxCallback, JaxModule, JaxTrainer
from project.utils.typing_utils.jax_typing_utils import field, jit

logger = get_logger(__name__)
2 changes: 1 addition & 1 deletion project/algorithms/jax_rl_example_test.py
@@ -40,7 +40,7 @@
from typing_extensions import override

from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback
-from project.algorithms.jax_trainer import JaxTrainer, hparams_to_dict
+from project.trainers.jax_trainer import JaxTrainer, hparams_to_dict

from .jax_rl_example import (
EvalMetrics,
2 changes: 1 addition & 1 deletion project/configs/trainer/jax.yaml
@@ -1,6 +1,6 @@
defaults:
- callbacks: rich_progress_bar.yaml
-_target_: project.algorithms.jax_trainer.JaxTrainer
+_target_: project.trainers.jax_trainer.JaxTrainer
max_epochs: 75
training_steps_per_epoch: 1

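For context, a hedged sketch of how Hydra resolves a trainer config like the one above: `_target_` is looked up and the remaining keys become keyword arguments. Everything except max_epochs and training_steps_per_epoch below is an assumed value for illustration, not taken from this commit.

# Illustrative sketch only. The callbacks group (rich_progress_bar.yaml) is
# omitted for brevity. Outside of a Hydra run, default_root_dir has to be
# passed explicitly because its new default reads HydraConfig at runtime.
from hydra.utils import instantiate

trainer = instantiate(
    {
        "_target_": "project.trainers.jax_trainer.JaxTrainer",
        "max_epochs": 75,
        "training_steps_per_epoch": 1,
        "default_root_dir": "logs/debug",  # assumed value for the sketch
    }
)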
2 changes: 1 addition & 1 deletion project/main.py
@@ -83,7 +83,7 @@ def run(experiment: Experiment) -> tuple[str, float | None, dict]:

if datamodule is None:
# todo: missing `rng` argument.
-from project.algorithms.jax_trainer import JaxTrainer
+from project.trainers.jax_trainer import JaxTrainer

if isinstance(experiment.trainer, JaxTrainer):
import jax.random
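A hypothetical completion of the branch above, only to illustrate the `rng` mentioned in the TODO; the fixed seed and how the key is then handed to the trainer are assumptions, not shown in this diff.

import jax.random

# Hypothetical: build the PRNG key that the JaxTrainer's jitted fit loop needs.
rng = jax.random.PRNGKey(0)  # fixed seed for the sketch; real code would derive one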
8 changes: 8 additions & 0 deletions project/trainers/__init__.py
@@ -0,0 +1,8 @@
+from lightning.pytorch.trainer.trainer import Trainer
+
+from .jax_trainer import JaxTrainer
+
+__all__ = [
+    "JaxTrainer",
+    "Trainer",
+]
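With the new package in place, both trainers can be imported from project.trainers. A minimal usage sketch; the numbers mirror configs/trainer/jax.yaml, and default_root_dir is an assumed value since its real default is resolved from HydraConfig at runtime.

from project.trainers import JaxTrainer, Trainer

# JaxTrainer is a flax.struct.PyTreeNode, so it is constructed like a dataclass.
jax_trainer = JaxTrainer(
    max_epochs=75,
    training_steps_per_epoch=1,
    default_root_dir="logs/debug",  # assumed; normally filled in by Hydra
)

# The re-exported Lightning Trainer is available from the same package.
lightning_trainer = Trainer(max_epochs=75)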
project/{algorithms → trainers}/jax_trainer.py
@@ -16,7 +16,7 @@
import lightning
import lightning.pytorch.callbacks
import lightning.pytorch.loggers
import torch # noqa
from hydra.core.hydra_config import HydraConfig
from typing_extensions import TypeVar

from project.utils.typing_utils.jax_typing_utils import jit
@@ -91,8 +91,7 @@ def teardown(self, trainer: JaxTrainer, module: JaxModule[Ts], stage: str, ts: T


class JaxTrainer(flax.struct.PyTreeNode):
"""A simplified version of the `[lightning.Trainer][lightning.pytorch.trainer.Trainer]` with a
fully jitted training loop.
"""A simplified version of the `lightning.Trainer` with a fully jitted training loop.
## Assumptions:
@@ -143,17 +142,21 @@ class JaxTrainer(flax.struct.PyTreeNode):
but not both.
- If you want to use [jax.vmap][] on the `fit` method, just remove the callbacks on the
Trainer for now.
+## TODOs / ideas
+- Add a checkpoint callback with orbax-checkpoint?
"""

# num_epochs = np.ceil(algo.hp.total_timesteps / algo.hp.eval_freq).astype(int)
max_epochs: int = flax.struct.field(pytree_node=False)

# iteration_steps = algo.hp.num_envs * algo.hp.num_steps
# num_iterations = np.ceil(algo.hp.eval_freq / iteration_steps).astype(int)
-training_steps_per_epoch: int
+training_steps_per_epoch: int = flax.struct.field(pytree_node=False)

+limit_val_batches: int = 0
+limit_test_batches: int = 0

# TODO: Getting some errors with the schema generation for lightning.Callback and
-# lightning.pytorch.loggers.logger.Logger here.
+# lightning.pytorch.loggers.logger.Logger here if we keep the type annotation.
callbacks: Sequence = dataclasses.field(metadata={"pytree_node": False}, default_factory=tuple)

logger: Any | None = flax.struct.field(pytree_node=False, default=None)
@@ -162,22 +165,20 @@ class JaxTrainer(flax.struct.PyTreeNode):
# strategy: str = flax.struct.field(pytree_node=False, default="auto")
# devices: int | str = flax.struct.field(pytree_node=False, default="auto")

# min_epochs: int

# path to output directory, created dynamically by hydra
# path generation pattern is specified in `configs/hydra/default.yaml`
# use it to store all files generated during the run, like checkpoints and metrics
-default_root_dir: str | Path | None = flax.struct.field(pytree_node=False, default="")

+default_root_dir: str | Path | None = flax.struct.field(
+    pytree_node=False,
+    default_factory=lambda: HydraConfig.get().runtime.output_dir,
+)

# State variables:
# TODO: Figure out how to efficiently present these even when jit is turned off (currently
# replacing self entirely).
# TODO: figure out how to cleanly store / update these.
current_epoch: int = flax.struct.field(pytree_node=True, default=0)
global_step: int = flax.struct.field(pytree_node=True, default=0)

-# TODO: Add a checkpoint callback with orbax-checkpoint?
-limit_val_batches: int = 0
-limit_test_batches: int = 0
logged_metrics: dict = flax.struct.field(pytree_node=True, default_factory=dict)
callback_metrics: dict = flax.struct.field(pytree_node=True, default_factory=dict)
# todo: get the metrics from the callbacks?
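A side note on the pytree_node changes above, as a hedged sketch: fields marked pytree_node=False (max_epochs, training_steps_per_epoch, callbacks, logger, default_root_dir) are stored as static metadata of the pytree, while the counters, batch limits and metric dicts remain traceable leaves, so jit re-traces when a static setting changes instead of tracing it as an array. The constructor values below are assumptions for illustration.

import jax

from project.trainers import JaxTrainer

trainer = JaxTrainer(
    max_epochs=75,
    training_steps_per_epoch=1,
    default_root_dir="logs/debug",  # assumed; normally resolved via HydraConfig
)

# Only the dynamic state shows up as leaves (limit_val_batches, limit_test_batches,
# current_epoch, global_step, plus the contents of the metric dicts).
print(jax.tree_util.tree_leaves(trainer))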
