From bb1943d1e7f78786cae220acc0c97c4c46494ab1 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 19 Dec 2024 10:02:59 -0800 Subject: [PATCH] Project main rework (#99) * Remove `experiment.py`, move to `main.py` Signed-off-by: Fabrice Normandin * Use singledispatch for `train` and `evaluate` Signed-off-by: Fabrice Normandin * WIP: Rework `main.py`, fix resulting errors Signed-off-by: Fabrice Normandin * Minor fix for doc / readability of main.py Signed-off-by: Fabrice Normandin * Use full path to datamodules in configs Signed-off-by: Fabrice Normandin * Save datamodule on self in text_classifier.py Signed-off-by: Fabrice Normandin * Fix project/main_test.py::test_help_string Signed-off-by: Fabrice Normandin * Fix test_help_string Signed-off-by: Fabrice Normandin --------- Signed-off-by: Fabrice Normandin --- .../project/main_test/test_help_string.txt | 83 +++++ docs/profiling_test.py | 9 +- project/algorithms/__init__.py | 13 - project/algorithms/jax_ppo.py | 33 ++ .../testsuites/lightning_module_tests.py | 9 +- project/algorithms/text_classifier.py | 1 + project/configs/config.py | 2 +- project/configs/config.yaml | 3 +- project/configs/datamodule/__init__.py | 4 - project/configs/datamodule/cifar10.yaml | 2 +- project/configs/datamodule/fashion_mnist.yaml | 2 +- project/configs/datamodule/glue_cola.yaml | 2 +- project/configs/datamodule/imagenet.yaml | 2 +- project/configs/datamodule/inaturalist.yaml | 2 +- project/configs/datamodule/mnist.yaml | 2 +- project/configs/datamodule/vision.yaml | 2 +- project/configs/experiment/example.yaml | 2 +- project/conftest.py | 12 +- project/experiment.py | 306 ++++++++++-------- project/main.py | 269 ++++----------- project/main_test.py | 55 ++-- project/trainers/jax_trainer.py | 33 ++ 22 files changed, 438 insertions(+), 410 deletions(-) create mode 100644 .regression_files/project/main_test/test_help_string.txt diff --git a/.regression_files/project/main_test/test_help_string.txt b/.regression_files/project/main_test/test_help_string.txt new file mode 100644 index 00000000..3ff32394 --- /dev/null +++ b/.regression_files/project/main_test/test_help_string.txt @@ -0,0 +1,83 @@ +main is powered by Hydra. + +== Configuration groups == +Compose your configuration from those groups (group=option) + +algorithm: image_classifier, jax_image_classifier, jax_ppo, llm_finetuning, no_op, text_classifier +algorithm/lr_scheduler: CosineAnnealingLR, StepLR +algorithm/network: fcnet, jax_cnn, jax_fcnet, resnet18, resnet50 +algorithm/optimizer: Adam, SGD, custom_adam +cluster: beluga, cedar, current, mila, narval +datamodule: cifar10, fashion_mnist, glue_cola, imagenet, inaturalist, mnist, vision +experiment: cluster_sweep_example, example, jax_rl_example, llm_finetuning_example, local_sweep_example, profiling, text_classification_example +resources: cpu, gpu +trainer: cpu, debug, default, jax_trainer, overfit_one_batch +trainer/callbacks: default, early_stopping, model_checkpoint, model_summary, no_checkpoints, none, rich_progress_bar +trainer/logger: tensorboard, wandb, wandb_cluster + + +== Config == +Override anything in the config (foo.bar=value) + +algorithm: ??? 
+datamodule: null +trainer: + callbacks: + model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${hydra:runtime.output_dir}/checkpoints + filename: epoch_{epoch:03d} + monitor: val/loss + verbose: false + save_last: true + save_top_k: 1 + mode: min + auto_insert_metric_name: false + save_weights_only: false + every_n_train_steps: null + train_time_interval: null + every_n_epochs: null + save_on_train_epoch_end: null + early_stopping: + _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: val/loss + min_delta: 0.0 + patience: 5 + verbose: false + mode: min + strict: true + check_finite: true + stopping_threshold: null + divergence_threshold: null + check_on_train_epoch_end: null + model_summary: + _target_: lightning.pytorch.callbacks.RichModelSummary + max_depth: 2 + rich_progress_bar: + _target_: lightning.pytorch.callbacks.RichProgressBar + lr_monitor: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + device_utilisation: + _target_: lightning.pytorch.callbacks.DeviceStatsMonitor + throughput: + _target_: project.algorithms.callbacks.samples_per_second.MeasureSamplesPerSecondCallback + _target_: lightning.Trainer + accelerator: auto + strategy: auto + devices: 1 + deterministic: false + fast_dev_run: false + min_epochs: 1 + max_epochs: 10 + default_root_dir: ${hydra:runtime.output_dir} + detect_anomaly: false +log_level: info +seed: 123 +name: default +debug: false +verbose: false +ckpt_path: null + + +Powered by Hydra (https://hydra.cc) +Use --hydra-help to view Hydra specific help diff --git a/docs/profiling_test.py b/docs/profiling_test.py index 14d02549..e8127179 100644 --- a/docs/profiling_test.py +++ b/docs/profiling_test.py @@ -14,10 +14,9 @@ datamodule_config, experiment_dictconfig, ) -from project.experiment import ( +from project.experiment import instantiate_datamodule, instantiate_trainer +from project.main import ( instantiate_algorithm, - instantiate_datamodule, - instantiate_trainer, setup_logging, ) from project.utils.hydra_utils import resolve_dictconfig @@ -121,8 +120,8 @@ def test_notebook_commands_dont_cause_errors(experiment_dictconfig: DictConfig): # _experiment = _setup_experiment(config) setup_logging(log_level=config.log_level) lightning.seed_everything(config.seed, workers=True) - _trainer = instantiate_trainer(config) + _trainer = instantiate_trainer(config.trainer) datamodule = instantiate_datamodule(config.datamodule) - _algorithm = instantiate_algorithm(config.algorithm, datamodule=datamodule) + _algorithm = instantiate_algorithm(config, datamodule=datamodule) # Note: Here we don't actually do anything with the objects. 
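Note: the reworked setup sequence exercised by the docs and tests above now splits component creation between `project.experiment` and `project.main`. A minimal sketch of the new call order, assuming a Hydra config has already been composed into `dict_config` (the `setup_from_config` wrapper below is illustrative and not part of this patch):

import lightning
from omegaconf import DictConfig

from project.experiment import instantiate_datamodule, instantiate_trainer
from project.main import instantiate_algorithm, setup_logging
from project.utils.hydra_utils import resolve_dictconfig


def setup_from_config(dict_config: DictConfig):
    """Build the trainer, datamodule and algorithm from a composed Hydra config."""
    config = resolve_dictconfig(dict_config)  # DictConfig -> structured `Config`
    setup_logging(log_level=config.log_level)
    lightning.seed_everything(config.seed, workers=True)

    trainer = instantiate_trainer(config.trainer)  # now takes only the `trainer` sub-config
    datamodule = instantiate_datamodule(config.datamodule)
    # `instantiate_algorithm` now receives the whole `Config` (and optionally the datamodule).
    algorithm = instantiate_algorithm(config, datamodule=datamodule)
    return trainer, datamodule, algorithm
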
diff --git a/project/algorithms/__init__.py b/project/algorithms/__init__.py index cbd55ece..e69de29b 100644 --- a/project/algorithms/__init__.py +++ b/project/algorithms/__init__.py @@ -1,13 +0,0 @@ -from .image_classifier import ImageClassifier -from .jax_image_classifier import JaxImageClassifier -from .jax_ppo import JaxRLExample -from .no_op import NoOp -from .text_classifier import TextClassifier - -__all__ = [ - "ImageClassifier", - "JaxImageClassifier", - "NoOp", - "TextClassifier", - "JaxRLExample", -] diff --git a/project/algorithms/jax_ppo.py b/project/algorithms/jax_ppo.py index cd6527bf..b3c3d4f2 100644 --- a/project/algorithms/jax_ppo.py +++ b/project/algorithms/jax_ppo.py @@ -7,7 +7,9 @@ from __future__ import annotations import contextlib +import dataclasses import functools +import operator from collections.abc import Callable, Sequence from logging import getLogger as get_logger from pathlib import Path @@ -36,6 +38,8 @@ from typing_extensions import TypeVar from xtils.jitpp import Static +from project import experiment +from project.configs.config import Config from project.trainers.jax_trainer import JaxCallback, JaxModule, JaxTrainer from project.utils.typing_utils.jax_typing_utils import field, jit @@ -826,3 +830,32 @@ def on_train_epoch_start(self, trainer: JaxTrainer, module: JaxRLExample, ts: PP gif_path = Path(log_dir) / f"epoch_{ts.data_collection_state.global_step:05}.gif" module.visualize(ts=ts, gif_path=gif_path) jax.debug.print("Saved gif to {gif_path}", gif_path=gif_path) + + +@experiment.evaluate.register +def evaluate_ppo_example( + algorithm: JaxRLExample, + /, + *, + trainer: JaxTrainer, + train_results: tuple[PPOState, EvalMetrics], + config: Config, + datamodule: None = None, +): + """Override for the `evaluate` function used by `main.py`, in the case of this algorithm.""" + # todo: there isn't yet a `validate` method on the jax trainer. + assert isinstance(algorithm, JaxModule) + assert isinstance(trainer, JaxTrainer) + assert train_results is not None + metrics = train_results[1] + + last_epoch_metrics = jax.tree.map(operator.itemgetter(-1), metrics) + assert isinstance(last_epoch_metrics, EvalMetrics) + # Average across eval seeds (we're doing evaluation in multiple environments in parallel with + # vmap). + last_epoch_average_cumulative_reward = last_epoch_metrics.cumulative_reward.mean().item() + return ( + "-avg_cumulative_reward", + -last_epoch_average_cumulative_reward, # need to return an "error" to minimize for HPO. + dataclasses.asdict(last_epoch_metrics), + ) diff --git a/project/algorithms/testsuites/lightning_module_tests.py b/project/algorithms/testsuites/lightning_module_tests.py index eb509a73..91462e82 100644 --- a/project/algorithms/testsuites/lightning_module_tests.py +++ b/project/algorithms/testsuites/lightning_module_tests.py @@ -22,7 +22,8 @@ from project.configs.config import Config from project.conftest import DEFAULT_SEED -from project.experiment import instantiate_algorithm, instantiate_trainer, setup_logging +from project.experiment import instantiate_trainer +from project.main import instantiate_algorithm, setup_logging from project.trainers.jax_trainer import JaxTrainer from project.utils.hydra_utils import resolve_dictconfig from project.utils.typing_utils import PyTree, is_sequence_of @@ -47,6 +48,8 @@ class LightningModuleTests(Generic[AlgorithmType], ABC): - Dataset splits: check some basic stats about the train/val/test inputs, are they somewhat similar? 
- Define the input as a space, check that the dataset samples are in that space and not too many samples are statistically OOD? + - Test to monitor distributed traffic out of this process? + - Dummy two-process tests (on CPU) to check before scaling up experiments? """ # algorithm_config: ParametrizedFixture[str] @@ -67,7 +70,7 @@ def trainer( ) -> lightning.Trainer | JaxTrainer: setup_logging(log_level=experiment_config.log_level) lightning.seed_everything(experiment_config.seed, workers=True) - return instantiate_trainer(experiment_config) + return instantiate_trainer(experiment_config.trainer) @pytest.fixture(scope="class") def algorithm( @@ -79,7 +82,7 @@ def algorithm( ): """Fixture that creates the "algorithm" (a [LightningModule][lightning.pytorch.core.module.LightningModule]).""" - algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule) + algorithm = instantiate_algorithm(experiment_config, datamodule=datamodule) if isinstance(trainer, lightning.Trainer) and isinstance( algorithm, lightning.LightningModule ): diff --git a/project/algorithms/text_classifier.py b/project/algorithms/text_classifier.py index 2ef16b1a..1d9ea4e1 100644 --- a/project/algorithms/text_classifier.py +++ b/project/algorithms/text_classifier.py @@ -30,6 +30,7 @@ def __init__( init_seed: int = 42, ): super().__init__() + self.datamodule = datamodule self.network_config = network self.num_labels = datamodule.num_classes self.task_name = datamodule.task_name diff --git a/project/configs/config.py b/project/configs/config.py index 277b0f6f..2b4482e4 100644 --- a/project/configs/config.py +++ b/project/configs/config.py @@ -27,7 +27,7 @@ class Config: It is suggested for this class to accept a `datamodule` and `network` as arguments. The instantiated datamodule and network will be passed to the algorithm's constructor. - For more info, see the [instantiate_algorithm][project.experiment.instantiate_algorithm] function. + For more info, see the [instantiate_algorithm][project.main.instantiate_algorithm] function. """ datamodule: Any | None = None diff --git a/project/configs/config.yaml b/project/configs/config.yaml index bf77f664..e21ba25c 100644 --- a/project/configs/config.yaml +++ b/project/configs/config.yaml @@ -1,7 +1,7 @@ defaults: - base_config - _self_ - - algorithm: ??? + - algorithm: null - optional datamodule: null - trainer: default.yaml - hydra: default.yaml @@ -12,4 +12,5 @@ defaults: # experiment configs allow for version control of specific hyperparameters # e.g. best hyperparameters for given model and datamodule - experiment: null +# This is a good default name to use when you aren't doing a sweep. Otherwise it causes an error. # name: "${hydra:runtime.choices.algorithm}-${hydra:runtime.choices.network}-${hydra:runtime.choices.datamodule}" diff --git a/project/configs/datamodule/__init__.py b/project/configs/datamodule/__init__.py index d9b68bc5..f30cd17a 100644 --- a/project/configs/datamodule/__init__.py +++ b/project/configs/datamodule/__init__.py @@ -4,10 +4,6 @@ logger = get_logger(__name__) - -# TODO: Make it possible to extend a structured base via yaml files as well as adding new fields -# (for example, ImagetNet32DataModule has a new constructor argument which can't be set atm in the -# config). 
datamodule_store = store(group="datamodule") diff --git a/project/configs/datamodule/cifar10.yaml b/project/configs/datamodule/cifar10.yaml index 2410ef7c..97678a17 100644 --- a/project/configs/datamodule/cifar10.yaml +++ b/project/configs/datamodule/cifar10.yaml @@ -1,7 +1,7 @@ defaults: - vision - _self_ -_target_: project.datamodules.CIFAR10DataModule +_target_: project.datamodules.image_classification.cifar10.CIFAR10DataModule data_dir: ${constant:torchvision_dir,DATA_DIR} batch_size: 128 train_transforms: diff --git a/project/configs/datamodule/fashion_mnist.yaml b/project/configs/datamodule/fashion_mnist.yaml index 472a4d96..a0c99bb6 100644 --- a/project/configs/datamodule/fashion_mnist.yaml +++ b/project/configs/datamodule/fashion_mnist.yaml @@ -1,4 +1,4 @@ defaults: - mnist - _self_ -_target_: project.datamodules.FashionMNISTDataModule +_target_: project.datamodules.image_classification.fashion_mnist.FashionMNISTDataModule diff --git a/project/configs/datamodule/glue_cola.yaml b/project/configs/datamodule/glue_cola.yaml index 078a153d..f0903b27 100644 --- a/project/configs/datamodule/glue_cola.yaml +++ b/project/configs/datamodule/glue_cola.yaml @@ -1,4 +1,4 @@ -_target_: project.datamodules.text.TextClassificationDataModule +_target_: project.datamodules.text.text_classification.TextClassificationDataModule data_dir: ${oc.env:SCRATCH,.}/data hf_dataset_path: glue task_name: cola diff --git a/project/configs/datamodule/imagenet.yaml b/project/configs/datamodule/imagenet.yaml index 23804087..b62a3231 100644 --- a/project/configs/datamodule/imagenet.yaml +++ b/project/configs/datamodule/imagenet.yaml @@ -1,5 +1,5 @@ defaults: - vision - _self_ -_target_: project.datamodules.ImageNetDataModule +_target_: project.datamodules.image_classification.imagenet.ImageNetDataModule # todo: add good configuration options here. diff --git a/project/configs/datamodule/inaturalist.yaml b/project/configs/datamodule/inaturalist.yaml index d3621b0f..dea23915 100644 --- a/project/configs/datamodule/inaturalist.yaml +++ b/project/configs/datamodule/inaturalist.yaml @@ -1,6 +1,6 @@ defaults: - vision - _self_ -_target_: project.datamodules.INaturalistDataModule +_target_: project.datamodules.image_classification.inaturalist.INaturalistDataModule version: "2021_train" target_type: "full" diff --git a/project/configs/datamodule/mnist.yaml b/project/configs/datamodule/mnist.yaml index 625b1ad9..80a7214a 100644 --- a/project/configs/datamodule/mnist.yaml +++ b/project/configs/datamodule/mnist.yaml @@ -1,7 +1,7 @@ defaults: - vision - _self_ -_target_: project.datamodules.MNISTDataModule +_target_: project.datamodules.image_classification.mnist.MNISTDataModule data_dir: ${constant:torchvision_dir,DATA_DIR} normalize: True batch_size: 128 diff --git a/project/configs/datamodule/vision.yaml b/project/configs/datamodule/vision.yaml index 561a36b1..60e62369 100644 --- a/project/configs/datamodule/vision.yaml +++ b/project/configs/datamodule/vision.yaml @@ -1,5 +1,5 @@ # todo: This config should not show up as an option on the command-line. 
-_target_: project.datamodules.VisionDataModule +_target_: project.datamodules.vision.VisionDataModule data_dir: ${constant:DATA_DIR} num_workers: ${constant:NUM_WORKERS} val_split: 0.1 # NOTE: reduced from default of 0.2 diff --git a/project/configs/experiment/example.yaml b/project/configs/experiment/example.yaml index 90d2ca6f..138f4822 100644 --- a/project/configs/experiment/example.yaml +++ b/project/configs/experiment/example.yaml @@ -16,7 +16,7 @@ defaults: # The parameters below will be merged with parameters from default configurations set above. # This allows you to overwrite only specified parameters -# The name of the e +# The name of the experiment (for logging) name: example seed: ${oc.env:SLURM_PROCID,42} diff --git a/project/conftest.py b/project/conftest.py index 8a9f88a2..e420dfeb 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -93,13 +93,12 @@ from project.configs.config import Config from project.datamodules.vision import VisionDataModule, num_cpus_on_node -from project.experiment import ( +from project.experiment import instantiate_datamodule, instantiate_trainer +from project.main import ( + PROJECT_NAME, instantiate_algorithm, - instantiate_datamodule, - instantiate_trainer, setup_logging, ) -from project.main import PROJECT_NAME from project.trainers.jax_trainer import JaxTrainer from project.utils.env_vars import REPO_ROOTDIR from project.utils.hydra_utils import resolve_dictconfig @@ -332,7 +331,7 @@ def algorithm( ): """Fixture that creates the "algorithm" (a [LightningModule][lightning.pytorch.core.module.LightningModule]).""" - algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule) + algorithm = instantiate_algorithm(experiment_config, datamodule=datamodule) if isinstance(trainer, lightning.Trainer) and isinstance(algorithm, lightning.LightningModule): with trainer.init_module(), device: # A bit hacky, but we have to do this because the lightningmodule isn't associated @@ -347,8 +346,9 @@ def trainer( experiment_config: Config, ) -> pl.Trainer | JaxTrainer: setup_logging(log_level=experiment_config.log_level) + # put here to copy what's done in main.py lightning.seed_everything(experiment_config.seed, workers=True) - return instantiate_trainer(experiment_config) + return instantiate_trainer(experiment_config.trainer) @pytest.fixture(scope="session") diff --git a/project/experiment.py b/project/experiment.py index 8b9e4cc8..bd5097a5 100644 --- a/project/experiment.py +++ b/project/experiment.py @@ -1,177 +1,207 @@ -"""Module containing the functions which create experiment components from Hydra configs. - -This is essentially just calling [hydra.utils.instantiate]( -https://hydra.cc/docs/1.3/advanced/instantiate_objects/overview/#internaldocs-banner) -on the -datamodule, network, trainer, and algorithm configs in a certain order. - -This also adds the instance_attr custom resolver, which allows you to retrieve an attribute of -an instantiated object instead of a config. 
-""" - from __future__ import annotations -import copy import functools import logging import typing +import warnings from typing import Any import hydra -import hydra.utils -import hydra_zen -import rich.console -import rich.logging -import rich.traceback +import lightning +import rich +from hydra_zen.typing import Builds +from omegaconf import DictConfig + +from project.configs.config import Config if typing.TYPE_CHECKING: - from hydra_zen.typing import Builds - from lightning import Callback, LightningDataModule, LightningModule, Trainer + import lightning + + from project.trainers.jax_trainer import JaxTrainer - from project.configs.config import Config - from project.trainers.jax_trainer import JaxModule, JaxTrainer logger = logging.getLogger(__name__) -# BUG: Always using the pydantic parser when instantiating things would be nice, but it currently -# causes issues related to pickling: https://github.com/mit-ll-responsible-ai/hydra-zen/issues/717 -# def _use_pydantic[C: Callable](fn: C) -> C: -# return functools.partial(hydra_zen.instantiate, _target_wrapper_=pydantic_parser) # type: ignore -# instantiate = _use_pydantic(hydra_zen.instantiate) - -instantiate = hydra_zen.instantiate - - -def setup_logging(log_level: str, global_log_level: str = "WARNING") -> None: - from project.main import PROJECT_NAME - - logging.basicConfig( - level=global_log_level.upper(), - # format="%(asctime)s - %(levelname)s - %(message)s", - format="%(message)s", - datefmt="[%X]", - force=True, - handlers=[ - rich.logging.RichHandler( - markup=True, - rich_tracebacks=True, - tracebacks_width=100, - tracebacks_show_locals=False, - ) - ], +@functools.singledispatch +def train( + algorithm, + /, + **kwargs, +) -> tuple[Any, Any]: + raise NotImplementedError( + f"There is no registered handler for training algorithm {algorithm} of type " + f"{type(algorithm)}! (kwargs: {kwargs})." + f"Registered handlers: " + + "\n\t".join([f"- {k}: {v.__name__}" for k, v in train.registry.items()]) ) - project_logger = logging.getLogger(PROJECT_NAME) - project_logger.setLevel(log_level.upper()) +@functools.singledispatch +def evaluate(algorithm: Any, /, **kwargs) -> tuple[str, float | None, dict]: + """Evaluates the algorithm. -def instantiate_trainer(experiment_config: Config) -> Trainer | JaxTrainer: + Returns the name of the 'error' metric for this run, its value, and a dict of metrics. + """ + raise NotImplementedError( + f"There is no registered handler for evaluating algorithm {algorithm} of type " + f"{type(algorithm)}! (kwargs: {kwargs})" + ) + + +def instantiate_trainer(trainer_config: dict | DictConfig) -> lightning.Trainer | JaxTrainer: # NOTE: Need to do a bit of sneaky type tricks to convince the outside world that these # fields have the right type. + # Create the Trainer + trainer_config = trainer_config.copy() # Avoid mutating the config. + callbacks: list | None = instantiate_values(trainer_config.pop("callbacks", None)) + logger: list | None = instantiate_values(trainer_config.pop("logger", None)) + trainer = hydra.utils.instantiate(trainer_config, callbacks=callbacks, logger=logger) + return trainer - # instantiate all the callbacks - callback_configs = experiment_config.trainer.pop("callbacks", {}) - callback_configs = {k: v for k, v in callback_configs.items() if v is not None} - callbacks: dict[str, Callback] | None = hydra.utils.instantiate( - callback_configs, _convert_="object" - ) - # Create the loggers, if any. 
- loggers: dict[str, Any] | None = hydra.utils.instantiate( - experiment_config.trainer.pop("logger", {}) - ) - # Create the Trainer. - - # BUG: `hydra.utils.instantiate` doesn't work with override **kwargs when some of them are - # dataclasses (e.g. a callback). - # trainer = hydra.utils.instantiate( - # config, - # callbacks=list(callbacks.values()) if callbacks else None, - # logger=list(loggers.values()) if loggers else None, - # ) - assert isinstance(experiment_config.trainer, dict) - config = copy.deepcopy(experiment_config.trainer) - target = hydra.utils.get_object(config.pop("_target_")) - _callbacks = list(callbacks.values()) if callbacks else None - _loggers = list(loggers.values()) if loggers else None - - trainer = target(**config, callbacks=_callbacks, logger=_loggers) - return trainer +def instantiate_values(config_dict: DictConfig | None) -> list[Any] | None: + """Returns the list of objects at the values in this dict of configs. + This is used for the config of the `trainer/logger` and `trainer/callbacks` fields, where + we can combine multiple config groups by adding entries in a dict. -def instantiate_datamodule( - datamodule_config: Builds[type[LightningDataModule]] | LightningDataModule | None, -) -> LightningDataModule | None: - """Instantiate the datamodule from the configuration dict. + For example, using `trainer/logger=wandb` and `trainer/logger=tensorboard` would result in a + dict with `wandb` and `tensorboard` as keys, and the corresponding config groups as values. - Any interpolations in the config will have already been resolved by the time we get here. + This would then return a list with the instantiated WandbLogger and TensorBoardLogger objects. """ - if not datamodule_config: + if not config_dict: + return None + objects_dict = hydra.utils.instantiate(config_dict, _recursive_=True) + if objects_dict is None: return None - import lightning - - if isinstance(datamodule_config, lightning.LightningDataModule): - logger.info( - f"Datamodule was already instantiated (probably to interpolate a field value). " - f"{datamodule_config=}" - ) - datamodule = datamodule_config - else: - logger.debug(f"Instantiating datamodule from config: {datamodule_config}") - datamodule = instantiate(datamodule_config) - return datamodule + assert isinstance(objects_dict, dict | DictConfig) + return [v for v in objects_dict.values() if v is not None] -def instantiate_algorithm( - algorithm_config: Config, datamodule: LightningDataModule | None -) -> LightningModule | JaxModule: - """Function used to instantiate the algorithm. +MetricName = str - It is suggested that your algorithm (LightningModule) take in the `datamodule` and `network` - as arguments, to make it easier to swap out different networks and datamodules during - experiments. +import lightning # noqa - The instantiated datamodule and network will be passed to the algorithm's constructor. - """ - # TODO: The algorithm is now always instantiated on the CPU, whereas it used to be instantiated - # directly on the default device (GPU). - # Create the algorithm - algo_config = algorithm_config - import lightning - if isinstance(algo_config, lightning.LightningModule): - logger.info( - f"Algorithm was already instantiated (probably to interpolate a field value)." 
- f"{algo_config=}" - ) - return algo_config +@evaluate.register(lightning.LightningModule) +def evaluate_lightningmodule( + algorithm: lightning.LightningModule, + /, + *, + trainer: lightning.Trainer, + datamodule: lightning.LightningDataModule | None = None, + config: Config, + train_results: Any = None, +) -> tuple[MetricName, float | None, dict]: + """Evaluates the algorithm and returns the metrics. - if datamodule: - algo_or_algo_partial = hydra.utils.instantiate(algo_config, datamodule=datamodule) + By default, if validation is to be performed, returns the validation error. Returns the + training error when `trainer.overfit_batches != 0` (e.g. when debugging or testing). Otherwise, + if `trainer.limit_val_batches == 0`, returns the test error. + """ + datamodule = datamodule or getattr(algorithm, "datamodule", None) + + # exp.trainer.logger.log_hyperparams() + # When overfitting on a single batch or only training, we return the train error. + if (trainer.limit_val_batches == trainer.limit_test_batches == 0) or ( + trainer.overfit_batches == 1 # type: ignore + ): + # We want to report the training error. + results_type = "train" + results = [ + { + **trainer.logged_metrics, + **trainer.callback_metrics, + **trainer.progress_bar_metrics, + } + ] + elif trainer.limit_val_batches != 0: + results_type = "val" + results = trainer.validate(model=algorithm, datamodule=datamodule) else: - algo_or_algo_partial = hydra.utils.instantiate(algo_config) - - if isinstance(algo_or_algo_partial, functools.partial): - if datamodule: - algorithm = algo_or_algo_partial(datamodule=datamodule) - else: - algorithm = algo_or_algo_partial() + warnings.warn(RuntimeWarning("About to use the test set for evaluation!")) + results_type = "test" + results = trainer.test(model=algorithm, datamodule=datamodule) + + if results is None: + rich.print("RUN FAILED!") + return "fail", None, {} + + metrics = dict(results[0]) + for key, value in metrics.items(): + rich.print(f"{results_type} {key}: ", value) + + if (accuracy := metrics.get(f"{results_type}/accuracy")) is not None: + # NOTE: This is the value that is used for HParam sweeps. + metric_name = "1-accuracy" + error = 1 - accuracy + + elif (loss := metrics.get(f"{results_type}/loss")) is not None: + logger.info("Assuming that the objective to minimize is the loss metric.") + # If 'accuracy' isn't in the results, assume that the loss is the metric to use. + metric_name = "loss" + error = loss else: - # logger.warning( - # f"Your algorithm config {algo_config} doesn't have '_partial_: true' set, which is " - # f"not recommended (since we can't pass the datamodule to the constructor)." - # ) - algorithm = algo_or_algo_partial - from project.trainers.jax_trainer import JaxModule - - if not isinstance(algorithm, lightning.LightningModule | JaxModule): - logger.warning( - UserWarning( - f"Your algorithm ({algorithm}) is not a LightningModule. Beware that this isn't " - f"explicitly supported at the moment." - ) + raise RuntimeError( + f"Don't know which metric to use to calculate the 'error' of this run.\n" + f"Here are the available metric names:\n" + f"{list(metrics.keys())}" ) - return algorithm + return metric_name, error, metrics + + +def instantiate_datamodule( + datamodule_config: Builds[type[lightning.LightningDataModule]] + | lightning.LightningDataModule + | None, +) -> lightning.LightningDataModule | None: + """Instantiate the datamodule from the configuration dict. + + Any interpolations in the config will have already been resolved by the time we get here. 
+ """ + if not datamodule_config: + return None + import lightning + + if isinstance(datamodule_config, lightning.LightningDataModule): + logger.info( + f"Datamodule was already instantiated (probably to interpolate a field value). " + f"{datamodule_config=}" + ) + return datamodule_config + + logger.debug(f"Instantiating datamodule from config: {datamodule_config}") + return hydra.utils.instantiate(datamodule_config) + + +@train.register +def train_lightningmodule( + algorithm: lightning.LightningModule, + /, + *, + trainer: lightning.Trainer | None, + datamodule: lightning.LightningDataModule | None = None, + config: Config, +): + # Create the Trainer from the config. + if trainer is None: + _trainer = instantiate_trainer(config.trainer) + assert isinstance(_trainer, lightning.Trainer) + trainer = _trainer + + # Train the model using the dataloaders of the datamodule: + # The Algorithm gets to "wrap" the datamodule if it wants to. This could be useful for + # example in RL, where we need to set the actor to use in the environment, as well as + # potentially adding Wrappers on top of the environment, or having a replay buffer, etc. + if datamodule is None: + if hasattr(algorithm, "datamodule"): + datamodule = getattr(algorithm, "datamodule") + elif config.datamodule is not None: + datamodule = instantiate_datamodule(config.datamodule) + trainer.fit(algorithm, datamodule=datamodule, ckpt_path=config.ckpt_path) + train_results = None # todo: get the train results from the trainer. + return algorithm, train_results diff --git a/project/main.py b/project/main.py index 6c715159..5ed42379 100644 --- a/project/main.py +++ b/project/main.py @@ -10,41 +10,34 @@ from __future__ import annotations -import dataclasses +import functools import logging -import operator import os -import warnings +import typing from pathlib import Path -from typing import Any import hydra -import jax import lightning -import lightning.pytorch -import lightning.pytorch.loggers import omegaconf import rich +import rich.logging import wandb from hydra_plugins.auto_schema import auto_schema_plugin from omegaconf import DictConfig -from project.algorithms.jax_ppo import EvalMetrics from project.configs import add_configs_to_hydra_store from project.configs.config import Config -from project.experiment import ( - instantiate_algorithm, - instantiate_datamodule, - setup_logging, -) -from project.trainers.jax_trainer import JaxModule, JaxTrainer, Ts, _MetricsT +from project.experiment import evaluate, instantiate_datamodule, instantiate_trainer, train from project.utils.hydra_utils import resolve_dictconfig from project.utils.utils import print_config +if typing.TYPE_CHECKING: + from project.trainers.jax_trainer import JaxModule + PROJECT_NAME = Path(__file__).parent.name REPO_ROOTDIR = Path(__file__).parent.parent +logger = logging.getLogger(__name__) -setup_logging(log_level="INFO", global_log_level="ERROR") auto_schema_plugin.config = auto_schema_plugin.AutoSchemaPluginConfig( schemas_dir=REPO_ROOTDIR / ".schemas", @@ -77,39 +70,26 @@ def main(dict_config: DictConfig) -> dict: 3. Calls `evaluation` to evaluate the model 4. Returns the evaluation metrics. """ + print_config(dict_config, resolve=False) + assert dict_config["algorithm"] is not None # Resolve all the interpolations in the configs. 
config: Config = resolve_dictconfig(dict_config) - setup_logging( log_level=config.log_level, global_log_level="DEBUG" if config.debug else "INFO" if config.verbose else "WARNING", ) - # seed the random number generators, so the weights that are + # Seed the random number generators, so the weights that are # constructed are deterministic and reproducible. lightning.seed_everything(seed=config.seed, workers=True) - # Create the Trainer - trainer_config = config.trainer.copy() # Avoid mutating the config if possible. - callbacks: list[lightning.Callback] | None = instantiate_values( - trainer_config.pop("callbacks", None) - ) - logger: list[lightning.pytorch.loggers.Logger] | None = instantiate_values( - trainer_config.pop("logger", None) - ) - trainer: lightning.Trainer | JaxTrainer = hydra.utils.instantiate( - trainer_config, callbacks=callbacks, logger=logger - ) + # Create the algo. + algorithm = instantiate_algorithm(config) - # Create the datamodule (if present) - datamodule: lightning.LightningDataModule | None = instantiate_datamodule(config.datamodule) - - # Create the "algorithm" - algorithm: lightning.LightningModule | JaxModule = instantiate_algorithm( - config.algorithm, datamodule=datamodule - ) + # Create the trainer + trainer = instantiate_trainer(config.trainer) if wandb.run: wandb.run.config.update({k: v for k, v in os.environ.items() if k.startswith("SLURM")}) @@ -118,22 +98,19 @@ def main(dict_config: DictConfig) -> dict: ) # Train the algorithm. - train_results = train( - config=config, trainer=trainer, datamodule=datamodule, algorithm=algorithm + algorithm, train_results = train( + algorithm, + trainer=trainer, + config=config, ) # Evaluate the algorithm. - if isinstance(trainer, lightning.Trainer): - assert isinstance(algorithm, lightning.LightningModule) - metric_name, error, _metrics = evaluate_lightningmodule( - algorithm, datamodule=datamodule, trainer=trainer - ) - else: - assert isinstance(trainer, JaxTrainer) - assert isinstance(algorithm, JaxModule) - metric_name, error, _metrics = evaluate_jax_module( - algorithm, trainer=trainer, train_results=train_results - ) + metric_name, error, _metrics = evaluate( + algorithm, + trainer=trainer, + train_results=train_results, + config=config, + ) if wandb.run: wandb.finish() @@ -143,172 +120,60 @@ def main(dict_config: DictConfig) -> dict: return dict(name=metric_name, type="objective", value=error) -def train( - config: Config, - trainer: lightning.Trainer | JaxTrainer, - datamodule: lightning.LightningDataModule | None, - algorithm: lightning.LightningModule | JaxModule, -): - if isinstance(trainer, lightning.Trainer): - assert isinstance(algorithm, lightning.LightningModule) - # Train the model using the dataloaders of the datamodule: - # The Algorithm gets to "wrap" the datamodule if it wants to. This could be useful for - # example in RL, where we need to set the actor to use in the environment, as well as - # potentially adding Wrappers on top of the environment, or having a replay buffer, etc. - datamodule = getattr(algorithm, "datamodule", datamodule) - return trainer.fit( - algorithm, - datamodule=datamodule, - ckpt_path=config.ckpt_path, - ) - - if datamodule is not None: - raise NotImplementedError( - "The JaxTrainer doesn't yet support using a datamodule. For now, you should " - f"return a batch of data from the {JaxModule.get_batch.__name__} method in your " - f"algorithm." 
- ) - - if not isinstance(algorithm, JaxModule): - raise TypeError( - f"The selected algorithm ({algorithm}) doesn't implement the required methods of " - f"a {JaxModule.__name__}, so it can't be used with the `{JaxTrainer.__name__}`. " - f"Try to subclass {JaxModule.__name__} and implement the missing methods." - ) - import jax - - rng = jax.random.key(config.seed) - # TODO: Use ckpt_path argument to load the training state and resume the training run. - assert config.ckpt_path is None - return trainer.fit(algorithm, rng=rng) +def setup_logging(log_level: str, global_log_level: str = "WARNING") -> None: + from project.main import PROJECT_NAME + + logging.basicConfig( + level=global_log_level.upper(), + # format="%(asctime)s - %(levelname)s - %(message)s", + format="%(message)s", + datefmt="[%X]", + force=True, + handlers=[ + rich.logging.RichHandler( + markup=True, + rich_tracebacks=True, + tracebacks_width=100, + tracebacks_show_locals=False, + ) + ], + ) + project_logger = logging.getLogger(PROJECT_NAME) + project_logger.setLevel(log_level.upper()) -def instantiate_values(config_dict: DictConfig | None) -> list[Any] | None: - """Returns the list of objects at the values in this dict of configs. - This is used for the config of the `trainer/logger` and `trainer/callbacks` fields, where - we can combine multiple config groups by adding entries in a dict. +def instantiate_algorithm( + config: Config, datamodule: lightning.LightningDataModule | None = None +) -> lightning.LightningModule | JaxModule: + """Function used to instantiate the algorithm. - For example, using `trainer/logger=wandb` and `trainer/logger=tensorboard` would result in a - dict with `wandb` and `tensorboard` as keys, and the corresponding config groups as values. + It is suggested that your algorithm (LightningModule) take in the `datamodule` and `network` + as arguments, to make it easier to swap out different networks and datamodules during + experiments. - This would then return a list with the instantiated WandbLogger and TensorBoardLogger objects. + The instantiated datamodule and network will be passed to the algorithm's constructor. """ - if not config_dict: - return None - objects_dict = hydra.utils.instantiate(config_dict, _recursive_=True) - if objects_dict is None: - return None - - assert isinstance(objects_dict, dict | DictConfig) - return [v for v in objects_dict.values() if v is not None] - - -MetricName = str + # Create the algorithm + algo_config = config.algorithm -def evaluate_lightningmodule( - algorithm: lightning.LightningModule, - trainer: lightning.Trainer, - datamodule: lightning.LightningDataModule | None, -) -> tuple[MetricName, float | None, dict]: - """Evaluates the algorithm and returns the metrics. + # Create the datamodule (if present) from the config + if datamodule is None and config.datamodule is not None: + datamodule = instantiate_datamodule(config.datamodule) - By default, if validation is to be performed, returns the validation error. Returns the - training error when `trainer.overfit_batches != 0` (e.g. when debugging or testing). Otherwise, - if `trainer.limit_val_batches == 0`, returns the test error. - """ - - # exp.trainer.logger.log_hyperparams() - # When overfitting on a single batch or only training, we return the train error. - if (trainer.limit_val_batches == trainer.limit_test_batches == 0) or ( - trainer.overfit_batches == 1 # type: ignore - ): - # We want to report the training error. 
- results_type = "train" - results = [ - { - **trainer.logged_metrics, - **trainer.callback_metrics, - **trainer.progress_bar_metrics, - } - ] - elif trainer.limit_val_batches != 0: - results_type = "val" - results = trainer.validate(model=algorithm, datamodule=datamodule) + if datamodule: + algo_or_algo_partial = hydra.utils.instantiate(algo_config, datamodule=datamodule) else: - warnings.warn(RuntimeWarning("About to use the test set for evaluation!")) - results_type = "test" - results = trainer.test(model=algorithm, datamodule=datamodule) - - if results is None: - rich.print("RUN FAILED!") - return "fail", None, {} - - metrics = dict(results[0]) - for key, value in metrics.items(): - rich.print(f"{results_type} {key}: ", value) - - logger = logging.getLogger(__name__) - - if (accuracy := metrics.get(f"{results_type}/accuracy")) is not None: - # NOTE: This is the value that is used for HParam sweeps. - metric_name = "1-accuracy" - error = 1 - accuracy - - elif (loss := metrics.get(f"{results_type}/loss")) is not None: - logger.info("Assuming that the objective to minimize is the loss metric.") - # If 'accuracy' isn't in the results, assume that the loss is the metric to use. - metric_name = "loss" - error = loss - else: - raise RuntimeError( - f"Don't know which metric to use to calculate the 'error' of this run.\n" - f"Here are the available metric names:\n" - f"{list(metrics.keys())}" - ) + algo_or_algo_partial = hydra.utils.instantiate(algo_config) - return metric_name, error, metrics + if isinstance(algo_or_algo_partial, functools.partial): + if datamodule: + return algo_or_algo_partial(datamodule=datamodule) + return algo_or_algo_partial() - -def evaluate_jax_module( - algorithm: JaxModule[Ts, Any, _MetricsT], - trainer: JaxTrainer, - train_results: tuple[Ts, _MetricsT] | None = None, -): - # todo: there isn't yet a `validate` method on the jax trainer. - assert isinstance(trainer, JaxTrainer) - assert train_results is not None - metrics = train_results[1] - - return get_error_from_metrics(metrics) - - -# BUG: ULTRA weird bug happens with cloudpickle if we use a singledispatch function here! -# @functools.singledispatch -def get_error_from_metrics(metrics: _MetricsT) -> tuple[str, float, dict]: - """Returns the main metric name, its value, and the full metrics dictionary.""" - if isinstance(metrics, EvalMetrics): - return get_error_from_jax_rl_example_metrics(metrics) - raise NotImplementedError( - f"Don't know how to calculate the error to minimize from metrics {metrics} of type " - f"{type(metrics)}! " - f"You probably need to register a handler for it." - ) - - -# @get_error_from_metrics.register(EvalMetrics) -def get_error_from_jax_rl_example_metrics(metrics: EvalMetrics): - last_epoch_metrics = jax.tree.map(operator.itemgetter(-1), metrics) - assert isinstance(last_epoch_metrics, EvalMetrics) - # Average across eval seeds (we're doing evaluation in multiple environments in parallel with - # vmap). - last_epoch_average_cumulative_reward = last_epoch_metrics.cumulative_reward.mean().item() - return ( - "-avg_cumulative_reward", - -last_epoch_average_cumulative_reward, # need to return an "error" to minimize for HPO. 
- dataclasses.asdict(last_epoch_metrics), - ) + algorithm = algo_or_algo_partial + return algorithm if __name__ == "__main__": diff --git a/project/main_test.py b/project/main_test.py index 3bc8ddbc..32e18e74 100644 --- a/project/main_test.py +++ b/project/main_test.py @@ -1,7 +1,9 @@ # ADAPTED FROM https://github.com/facebookresearch/hydra/blob/main/examples/advanced/hydra_app_example/tests/test_example.py from __future__ import annotations +import shlex import shutil +import subprocess import sys import uuid from unittest.mock import Mock @@ -12,7 +14,9 @@ from _pytest.mark.structures import ParameterSet from hydra.types import RunMode from omegaconf import DictConfig +from pytest_regressions.file_regression import FileRegressionFixture +import project.experiment import project.main from project.conftest import command_line_overrides, skip_on_macOS_in_CI from project.utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID @@ -46,31 +50,20 @@ def test_torch_can_use_the_GPU(): @pytest.fixture def mock_train(monkeypatch: pytest.MonkeyPatch): - mock_train_fn = Mock(spec=project.main.train) + mock_train_fn = Mock(spec=project.main.train, return_value=(None, None)) monkeypatch.setattr(project.main, project.main.train.__name__, mock_train_fn) return mock_train_fn @pytest.fixture -def mock_evaluate_lightningmodule(monkeypatch: pytest.MonkeyPatch): - mock_eval_lightningmodule = Mock( - spec=project.main.evaluate_lightningmodule, return_value=("fake", 0.0, {}) - ) +def mock_evaluate(monkeypatch: pytest.MonkeyPatch): + mock_eval = Mock(spec=project.experiment.evaluate, return_value=("fake", 0.0, {})) monkeypatch.setattr( - project.main, project.main.evaluate_lightningmodule.__name__, mock_eval_lightningmodule + project.main, + project.experiment.evaluate.__name__, + mock_eval, ) - return mock_eval_lightningmodule - - -@pytest.fixture -def mock_evaluate_jax_module(monkeypatch: pytest.MonkeyPatch): - mock_eval_jax_module = Mock( - spec=project.main.evaluate_jax_module, return_value=("fake", 0.0, {}) - ) - monkeypatch.setattr( - project.main, project.main.evaluate_jax_module.__name__, mock_eval_jax_module - ) - return mock_eval_jax_module + return mock_eval experiment_configs = [p.stem for p in (CONFIG_DIR / "experiment").glob("*.yaml")] @@ -93,11 +86,6 @@ def mock_evaluate_jax_module(monkeypatch: pytest.MonkeyPatch): IN_GITHUB_CI, reason="Remote launcher tries to do a git push, doesn't work in github CI.", ), - pytest.mark.xfail( - raises=TypeError, - reason="TODO: Getting a `TypeError: cannot pickle 'weakref.ReferenceType' object` error.", - strict=False, - ), ], ), pytest.param( @@ -152,8 +140,7 @@ def test_experiment_config_is_tested(experiment_config: str): def test_can_load_experiment_configs( experiment_dictconfig: DictConfig, mock_train: Mock, - mock_evaluate_lightningmodule: Mock, - mock_evaluate_jax_module: Mock, + mock_evaluate: Mock, ): # Mock out some part of the `main` function to not actually run anything. if experiment_dictconfig["hydra"]["mode"] == RunMode.MULTIRUN: @@ -168,10 +155,7 @@ def test_can_load_experiment_configs( assert results is not None mock_train.assert_called_once() - # One of them should have been called once. 
- assert (mock_evaluate_lightningmodule.call_count == 1) ^ ( - mock_evaluate_jax_module.call_count == 1 - ) + mock_evaluate.assert_called_once() @pytest.mark.slow @@ -212,6 +196,19 @@ def test_setting_just_algorithm_isnt_enough(experiment_dictconfig: DictConfig) - _ = resolve_dictconfig(experiment_dictconfig) +def test_help_string(file_regression: FileRegressionFixture) -> None: + help_string = subprocess.run( + # Pass a seed so it isn't selected randomly, which would make the regression file change. + shlex.split("python project/main.py seed=123 --help"), + text=True, + capture_output=True, + ).stdout + # Remove trailing whitespace so pre-commit doesn't change the regression file. + # Also remove first or last empty lines (which would also be removed by pre-commit). + help_string = "\n".join([line.rstrip() for line in help_string.splitlines()]).strip() + "\n" + file_regression.check(help_string) + + @pytest.mark.skipif( IN_GITHUB_CI and sys.platform == "darwin", reason="TODO: Getting a 'MPS backend out of memory' error on the Github CI. ", diff --git a/project/trainers/jax_trainer.py b/project/trainers/jax_trainer.py index fa41a3ab..8f9a8f36 100644 --- a/project/trainers/jax_trainer.py +++ b/project/trainers/jax_trainer.py @@ -19,6 +19,8 @@ from hydra.core.hydra_config import HydraConfig from typing_extensions import TypeVar +from project.configs.config import Config +from project.experiment import train from project.utils.typing_utils.jax_typing_utils import jit Ts = TypeVar("Ts", bound=flax.struct.PyTreeNode, default=flax.struct.PyTreeNode) @@ -62,6 +64,37 @@ def eval_callback(self, ts: Ts) -> _MetricsT: raise NotImplementedError +@train.register(JaxModule) +def train_jax_module( + algorithm: JaxModule, + /, + *, + trainer: JaxTrainer, + config: Config, + datamodule: None = None, +): + if datamodule is not None: + raise NotImplementedError( + "The JaxTrainer doesn't yet support using a datamodule. For now, you should " + f"return a batch of data from the {JaxModule.get_batch.__name__} method in your " + f"algorithm." + ) + + if not isinstance(algorithm, JaxModule): + raise TypeError( + f"The selected algorithm ({algorithm}) doesn't implement the required methods of " + f"a {JaxModule.__name__}, so it can't be used with the `{JaxTrainer.__name__}`. " + f"Try to subclass {JaxModule.__name__} and implement the missing methods." + ) + import jax + + rng = jax.random.key(config.seed) + # TODO: Use ckpt_path argument to load the training state and resume the training run. + assert config.ckpt_path is None + ts, train_metrics = trainer.fit(algorithm, rng=rng) + return algorithm, (ts, train_metrics) + + class JaxCallback(flax.struct.PyTreeNode): def setup(self, trainer: JaxTrainer, module: JaxModule[Ts], stage: str, ts: Ts): ... def on_fit_start(self, trainer: JaxTrainer, module: JaxModule[Ts], ts: Ts): ...
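Note: with the diff above, the singledispatch `train` and `evaluate` functions in `project/experiment.py` become the single extension point that `main.py` calls into; the Lightning handlers in `experiment.py`, the PPO `evaluate` override in `jax_ppo.py`, and the `JaxModule` `train` handler in `jax_trainer.py` are the built-in registrations. A minimal sketch of how a new algorithm family could hook into the same mechanism (`MyCustomAlgorithm` and its handler bodies are hypothetical placeholders; only the registration pattern comes from this patch):

from project import experiment
from project.configs.config import Config


class MyCustomAlgorithm:
    """Placeholder for an algorithm that is neither a LightningModule nor a JaxModule."""


@experiment.train.register(MyCustomAlgorithm)
def train_my_custom_algorithm(
    algorithm: MyCustomAlgorithm, /, *, trainer, config: Config, datamodule=None
):
    # Run the custom training loop here; `main.py` expects (algorithm, train_results) back.
    train_results = None
    return algorithm, train_results


@experiment.evaluate.register(MyCustomAlgorithm)
def evaluate_my_custom_algorithm(
    algorithm: MyCustomAlgorithm, /, *, trainer, train_results, config: Config, datamodule=None
):
    # Return (metric_name, error_to_minimize, metrics_dict), as `main.py` expects.
    return "loss", 0.0, {}

Dispatch happens on the type of the first, positional-only argument, which is why `main.py` passes the trainer, config and train results as keyword arguments.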