Project main rework (#99)
* Remove `experiment.py`, move to `main.py`

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Use singledispatch for `train` and `evaluate`

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* WIP: Rework `main.py`, fix resulting errors

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Minor fix for doc / readability of main.py

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Use full path to datamodules in configs

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Save datamodule on self in text_classifier.py

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix project/main_test.py::test_help_string

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix test_help_string

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
lebrice authored Dec 19, 2024
1 parent a34689f commit bb1943d
Showing 22 changed files with 438 additions and 410 deletions.
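
The `jax_ppo.py` diff below registers an algorithm-specific override with `@experiment.evaluate.register`, which is the `functools.singledispatch` pattern mentioned in the commit message bullets above. A minimal, hypothetical sketch of that pattern (the class name, argument names and return values here are placeholders, not the project's actual API):

    from functools import singledispatch


    class MyAlgorithm:
        """Hypothetical algorithm class, standing in for one of the project's LightningModules."""


    @singledispatch
    def evaluate(algorithm, /, *, trainer, train_results, config, datamodule=None):
        """Generic evaluation entry point; dispatch happens on the type of `algorithm`."""
        raise NotImplementedError(f"No `evaluate` override registered for {type(algorithm)}")


    @evaluate.register
    def _evaluate_my_algorithm(algorithm: MyAlgorithm, /, **kwargs):
        # Called whenever the first positional argument is a MyAlgorithm instance.
        return "val/loss", 0.123, {"val/loss": 0.123}


    print(evaluate(MyAlgorithm(), trainer=None, train_results=None, config=None))

With this pattern, the main script can call `evaluate(algorithm, ...)` once, and each algorithm type can opt into its own evaluation logic, as `JaxRLExample` does in the diff further down.
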
83 changes: 83 additions & 0 deletions .regression_files/project/main_test/test_help_string.txt
@@ -0,0 +1,83 @@
main is powered by Hydra.

== Configuration groups ==
Compose your configuration from those groups (group=option)

algorithm: image_classifier, jax_image_classifier, jax_ppo, llm_finetuning, no_op, text_classifier
algorithm/lr_scheduler: CosineAnnealingLR, StepLR
algorithm/network: fcnet, jax_cnn, jax_fcnet, resnet18, resnet50
algorithm/optimizer: Adam, SGD, custom_adam
cluster: beluga, cedar, current, mila, narval
datamodule: cifar10, fashion_mnist, glue_cola, imagenet, inaturalist, mnist, vision
experiment: cluster_sweep_example, example, jax_rl_example, llm_finetuning_example, local_sweep_example, profiling, text_classification_example
resources: cpu, gpu
trainer: cpu, debug, default, jax_trainer, overfit_one_batch
trainer/callbacks: default, early_stopping, model_checkpoint, model_summary, no_checkpoints, none, rich_progress_bar
trainer/logger: tensorboard, wandb, wandb_cluster


== Config ==
Override anything in the config (foo.bar=value)

algorithm: ???
datamodule: null
trainer:
  callbacks:
    model_checkpoint:
      _target_: lightning.pytorch.callbacks.ModelCheckpoint
      dirpath: ${hydra:runtime.output_dir}/checkpoints
      filename: epoch_{epoch:03d}
      monitor: val/loss
      verbose: false
      save_last: true
      save_top_k: 1
      mode: min
      auto_insert_metric_name: false
      save_weights_only: false
      every_n_train_steps: null
      train_time_interval: null
      every_n_epochs: null
      save_on_train_epoch_end: null
    early_stopping:
      _target_: lightning.pytorch.callbacks.EarlyStopping
      monitor: val/loss
      min_delta: 0.0
      patience: 5
      verbose: false
      mode: min
      strict: true
      check_finite: true
      stopping_threshold: null
      divergence_threshold: null
      check_on_train_epoch_end: null
    model_summary:
      _target_: lightning.pytorch.callbacks.RichModelSummary
      max_depth: 2
    rich_progress_bar:
      _target_: lightning.pytorch.callbacks.RichProgressBar
    lr_monitor:
      _target_: lightning.pytorch.callbacks.LearningRateMonitor
    device_utilisation:
      _target_: lightning.pytorch.callbacks.DeviceStatsMonitor
    throughput:
      _target_: project.algorithms.callbacks.samples_per_second.MeasureSamplesPerSecondCallback
  _target_: lightning.Trainer
  accelerator: auto
  strategy: auto
  devices: 1
  deterministic: false
  fast_dev_run: false
  min_epochs: 1
  max_epochs: 10
  default_root_dir: ${hydra:runtime.output_dir}
  detect_anomaly: false
log_level: info
seed: 123
name: default
debug: false
verbose: false
ckpt_path: null


Powered by Hydra (https://hydra.cc)
Use --hydra-help to view Hydra specific help
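
The help text above lists the available config groups and the `group=option` / `foo.bar=value` override syntax. As an illustrative invocation only (assuming `project/main.py` is the Hydra entry point, which the renamed `project/main_test.py` suggests), a run could be composed like:

    python project/main.py algorithm=image_classifier datamodule=cifar10 trainer=cpu trainer.max_epochs=5

Here `algorithm`, `datamodule` and `trainer` each select an option from the groups listed above, while `trainer.max_epochs=5` overrides a single value in the composed config.
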
9 changes: 4 additions & 5 deletions docs/profiling_test.py
@@ -14,10 +14,9 @@
    datamodule_config,
    experiment_dictconfig,
)
-from project.experiment import (
+from project.experiment import instantiate_datamodule, instantiate_trainer
+from project.main import (
    instantiate_algorithm,
-    instantiate_datamodule,
-    instantiate_trainer,
    setup_logging,
)
from project.utils.hydra_utils import resolve_dictconfig
@@ -121,8 +120,8 @@ def test_notebook_commands_dont_cause_errors(experiment_dictconfig: DictConfig):
    # _experiment = _setup_experiment(config)
    setup_logging(log_level=config.log_level)
    lightning.seed_everything(config.seed, workers=True)
-    _trainer = instantiate_trainer(config)
+    _trainer = instantiate_trainer(config.trainer)
    datamodule = instantiate_datamodule(config.datamodule)
-    _algorithm = instantiate_algorithm(config.algorithm, datamodule=datamodule)
+    _algorithm = instantiate_algorithm(config, datamodule=datamodule)

# Note: Here we don't actually do anything with the objects.
13 changes: 0 additions & 13 deletions project/algorithms/__init__.py
@@ -1,13 +0,0 @@
from .image_classifier import ImageClassifier
from .jax_image_classifier import JaxImageClassifier
from .jax_ppo import JaxRLExample
from .no_op import NoOp
from .text_classifier import TextClassifier

__all__ = [
"ImageClassifier",
"JaxImageClassifier",
"NoOp",
"TextClassifier",
"JaxRLExample",
]
33 changes: 33 additions & 0 deletions project/algorithms/jax_ppo.py
@@ -7,7 +7,9 @@
from __future__ import annotations

import contextlib
import dataclasses
import functools
import operator
from collections.abc import Callable, Sequence
from logging import getLogger as get_logger
from pathlib import Path
@@ -36,6 +38,8 @@
from typing_extensions import TypeVar
from xtils.jitpp import Static

from project import experiment
from project.configs.config import Config
from project.trainers.jax_trainer import JaxCallback, JaxModule, JaxTrainer
from project.utils.typing_utils.jax_typing_utils import field, jit

@@ -826,3 +830,32 @@ def on_train_epoch_start(self, trainer: JaxTrainer, module: JaxRLExample, ts: PP
        gif_path = Path(log_dir) / f"epoch_{ts.data_collection_state.global_step:05}.gif"
        module.visualize(ts=ts, gif_path=gif_path)
        jax.debug.print("Saved gif to {gif_path}", gif_path=gif_path)


@experiment.evaluate.register
def evaluate_ppo_example(
    algorithm: JaxRLExample,
    /,
    *,
    trainer: JaxTrainer,
    train_results: tuple[PPOState, EvalMetrics],
    config: Config,
    datamodule: None = None,
):
    """Override for the `evaluate` function used by `main.py`, in the case of this algorithm."""
    # todo: there isn't yet a `validate` method on the jax trainer.
    assert isinstance(algorithm, JaxModule)
    assert isinstance(trainer, JaxTrainer)
    assert train_results is not None
    metrics = train_results[1]

    last_epoch_metrics = jax.tree.map(operator.itemgetter(-1), metrics)
    assert isinstance(last_epoch_metrics, EvalMetrics)
    # Average across eval seeds (we're doing evaluation in multiple environments in parallel with
    # vmap).
    last_epoch_average_cumulative_reward = last_epoch_metrics.cumulative_reward.mean().item()
    return (
        "-avg_cumulative_reward",
        -last_epoch_average_cumulative_reward,  # need to return an "error" to minimize for HPO.
        dataclasses.asdict(last_epoch_metrics),
    )
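
The override above suggests the convention the singledispatched `evaluate` appears to follow: it receives the algorithm as the first positional argument and returns a `(metric_name, value_to_minimize, metrics_dict)` triple, so a hyperparameter sweep can minimize the second element. A hypothetical call site (illustrative names, not the actual `main.py` code) would look like:

    # Dispatches to `evaluate_ppo_example` above when `algorithm` is a `JaxRLExample`.
    metric_name, objective, metrics = experiment.evaluate(
        algorithm,
        trainer=trainer,
        train_results=train_results,
        config=config,
        datamodule=None,
    )
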
9 changes: 6 additions & 3 deletions project/algorithms/testsuites/lightning_module_tests.py
@@ -22,7 +22,8 @@

from project.configs.config import Config
from project.conftest import DEFAULT_SEED
-from project.experiment import instantiate_algorithm, instantiate_trainer, setup_logging
+from project.experiment import instantiate_trainer
+from project.main import instantiate_algorithm, setup_logging
from project.trainers.jax_trainer import JaxTrainer
from project.utils.hydra_utils import resolve_dictconfig
from project.utils.typing_utils import PyTree, is_sequence_of
@@ -47,6 +48,8 @@ class LightningModuleTests(Generic[AlgorithmType], ABC):
- Dataset splits: check some basic stats about the train/val/test inputs, are they somewhat similar?
- Define the input as a space, check that the dataset samples are in that space and not too
many samples are statistically OOD?
- Test to monitor distributed traffic out of this process?
- Dummy two-process tests (on CPU) to check before scaling up experiments?
"""

# algorithm_config: ParametrizedFixture[str]
@@ -67,7 +70,7 @@ def trainer(
) -> lightning.Trainer | JaxTrainer:
setup_logging(log_level=experiment_config.log_level)
lightning.seed_everything(experiment_config.seed, workers=True)
-        return instantiate_trainer(experiment_config)
+        return instantiate_trainer(experiment_config.trainer)

@pytest.fixture(scope="class")
def algorithm(
@@ -79,7 +82,7 @@ def algorithm(
):
"""Fixture that creates the "algorithm" (a
[LightningModule][lightning.pytorch.core.module.LightningModule])."""
-        algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule)
+        algorithm = instantiate_algorithm(experiment_config, datamodule=datamodule)
if isinstance(trainer, lightning.Trainer) and isinstance(
algorithm, lightning.LightningModule
):
1 change: 1 addition & 0 deletions project/algorithms/text_classifier.py
@@ -30,6 +30,7 @@ def __init__(
        init_seed: int = 42,
    ):
        super().__init__()
        self.datamodule = datamodule
        self.network_config = network
        self.num_labels = datamodule.num_classes
        self.task_name = datamodule.task_name
2 changes: 1 addition & 1 deletion project/configs/config.py
@@ -27,7 +27,7 @@ class Config:
It is suggested for this class to accept a `datamodule` and `network` as arguments. The
instantiated datamodule and network will be passed to the algorithm's constructor.
-    For more info, see the [instantiate_algorithm][project.experiment.instantiate_algorithm] function.
+    For more info, see the [instantiate_algorithm][project.main.instantiate_algorithm] function.
"""

datamodule: Any | None = None
3 changes: 2 additions & 1 deletion project/configs/config.yaml
@@ -1,7 +1,7 @@
defaults:
- base_config
- _self_
-  - algorithm: ???
+  - algorithm: null
- optional datamodule: null
- trainer: default.yaml
- hydra: default.yaml
@@ -12,4 +12,5 @@ defaults:
# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: null
# This is a good default name to use when you aren't doing a sweep. Otherwise it causes an error.
# name: "${hydra:runtime.choices.algorithm}-${hydra:runtime.choices.network}-${hydra:runtime.choices.datamodule}"
4 changes: 0 additions & 4 deletions project/configs/datamodule/__init__.py
@@ -4,10 +4,6 @@

logger = get_logger(__name__)


-# TODO: Make it possible to extend a structured base via yaml files as well as adding new fields
-# (for example, ImagetNet32DataModule has a new constructor argument which can't be set atm in the
-# config).
datamodule_store = store(group="datamodule")


2 changes: 1 addition & 1 deletion project/configs/datamodule/cifar10.yaml
@@ -1,7 +1,7 @@
defaults:
- vision
- _self_
-_target_: project.datamodules.CIFAR10DataModule
+_target_: project.datamodules.image_classification.cifar10.CIFAR10DataModule
data_dir: ${constant:torchvision_dir,DATA_DIR}
batch_size: 128
train_transforms:
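
The datamodule configs in this commit spell out the full dotted import path in `_target_` ("Use full path to datamodules in configs" above), which presumably avoids relying on re-exports in `project/datamodules/__init__.py`. A generic sketch of how such a config is consumed (standard Hydra API, not the project's own code):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "_target_": "project.datamodules.image_classification.cifar10.CIFAR10DataModule",
            "batch_size": 128,
        }
    )
    # `instantiate` imports the class named by `_target_` and calls it with the
    # remaining keys as keyword arguments, i.e. CIFAR10DataModule(batch_size=128).
    datamodule = instantiate(cfg)
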
2 changes: 1 addition & 1 deletion project/configs/datamodule/fashion_mnist.yaml
@@ -1,4 +1,4 @@
defaults:
- mnist
- _self_
-_target_: project.datamodules.FashionMNISTDataModule
+_target_: project.datamodules.image_classification.fashion_mnist.FashionMNISTDataModule
2 changes: 1 addition & 1 deletion project/configs/datamodule/glue_cola.yaml
@@ -1,4 +1,4 @@
-_target_: project.datamodules.text.TextClassificationDataModule
+_target_: project.datamodules.text.text_classification.TextClassificationDataModule
data_dir: ${oc.env:SCRATCH,.}/data
hf_dataset_path: glue
task_name: cola
2 changes: 1 addition & 1 deletion project/configs/datamodule/imagenet.yaml
@@ -1,5 +1,5 @@
defaults:
- vision
- _self_
-_target_: project.datamodules.ImageNetDataModule
+_target_: project.datamodules.image_classification.imagenet.ImageNetDataModule
# todo: add good configuration options here.
2 changes: 1 addition & 1 deletion project/configs/datamodule/inaturalist.yaml
@@ -1,6 +1,6 @@
defaults:
- vision
- _self_
-_target_: project.datamodules.INaturalistDataModule
+_target_: project.datamodules.image_classification.inaturalist.INaturalistDataModule
version: "2021_train"
target_type: "full"
2 changes: 1 addition & 1 deletion project/configs/datamodule/mnist.yaml
@@ -1,7 +1,7 @@
defaults:
- vision
- _self_
-_target_: project.datamodules.MNISTDataModule
+_target_: project.datamodules.image_classification.mnist.MNISTDataModule
data_dir: ${constant:torchvision_dir,DATA_DIR}
normalize: True
batch_size: 128
2 changes: 1 addition & 1 deletion project/configs/datamodule/vision.yaml
@@ -1,5 +1,5 @@
# todo: This config should not show up as an option on the command-line.
-_target_: project.datamodules.VisionDataModule
+_target_: project.datamodules.vision.VisionDataModule
data_dir: ${constant:DATA_DIR}
num_workers: ${constant:NUM_WORKERS}
val_split: 0.1 # NOTE: reduced from default of 0.2
2 changes: 1 addition & 1 deletion project/configs/experiment/example.yaml
@@ -16,7 +16,7 @@ defaults:
# The parameters below will be merged with parameters from default configurations set above.
# This allows you to overwrite only specified parameters

-# The name of the e
+# The name of the experiment (for logging)
name: example

seed: ${oc.env:SLURM_PROCID,42}
12 changes: 6 additions & 6 deletions project/conftest.py
@@ -93,13 +93,12 @@

from project.configs.config import Config
from project.datamodules.vision import VisionDataModule, num_cpus_on_node
-from project.experiment import (
+from project.experiment import instantiate_datamodule, instantiate_trainer
+from project.main import (
+    PROJECT_NAME,
    instantiate_algorithm,
-    instantiate_datamodule,
-    instantiate_trainer,
    setup_logging,
)
-from project.main import PROJECT_NAME
from project.trainers.jax_trainer import JaxTrainer
from project.utils.env_vars import REPO_ROOTDIR
from project.utils.hydra_utils import resolve_dictconfig
@@ -332,7 +331,7 @@ def algorithm(
):
"""Fixture that creates the "algorithm" (a
[LightningModule][lightning.pytorch.core.module.LightningModule])."""
-    algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule)
+    algorithm = instantiate_algorithm(experiment_config, datamodule=datamodule)
if isinstance(trainer, lightning.Trainer) and isinstance(algorithm, lightning.LightningModule):
with trainer.init_module(), device:
# A bit hacky, but we have to do this because the lightningmodule isn't associated
@@ -347,8 +346,9 @@ def trainer(
experiment_config: Config,
) -> pl.Trainer | JaxTrainer:
setup_logging(log_level=experiment_config.log_level)
# put here to copy what's done in main.py
lightning.seed_everything(experiment_config.seed, workers=True)
-    return instantiate_trainer(experiment_config)
+    return instantiate_trainer(experiment_config.trainer)


@pytest.fixture(scope="session")