diff --git a/docs/profiling_test.py b/docs/profiling_test.py
index 24b9c39c..0fb9f659 100644
--- a/docs/profiling_test.py
+++ b/docs/profiling_test.py
@@ -1,6 +1,7 @@
 import shutil

 import hydra.errors
+import lightning
 import pytest
 from omegaconf import DictConfig
@@ -13,7 +14,12 @@
     datamodule_config,
     experiment_dictconfig,
 )
-from project.experiment import setup_experiment
+from project.experiment import (
+    instantiate_algorithm,
+    instantiate_datamodule,
+    instantiate_trainer,
+    setup_logging,
+)
 from project.utils.hydra_utils import resolve_dictconfig
@@ -111,5 +117,11 @@ def test_notebook_commands_dont_cause_errors(experiment_dictconfig: DictConfig):
     # check for any errors related to OmegaConf interpolations and such
     config = resolve_dictconfig(experiment_dictconfig)
     # check for any errors when actually instantiating the components.
-    _experiment = setup_experiment(config)
+    # _experiment = _setup_experiment(config)
+    setup_logging(config)
+    lightning.seed_everything(config.seed, workers=True)
+    _trainer = instantiate_trainer(config)
+    datamodule = instantiate_datamodule(config.datamodule)
+    _algorithm = instantiate_algorithm(config.algorithm, datamodule=datamodule)
+    # Note: Here we don't actually do anything with the objects.
diff --git a/project/algorithms/jax_rl_example.py b/project/algorithms/jax_rl_example.py
index 62764a94..8cbfedc5 100644
--- a/project/algorithms/jax_rl_example.py
+++ b/project/algorithms/jax_rl_example.py
@@ -106,6 +106,9 @@ class PPOHParams(flax.struct.PyTreeNode):
     num_steps: int = field(default=64)
     num_minibatches: int = field(default=16)

+    # ADDED:
+    num_seeds_per_eval: int = field(default=128)
+    eval_freq: int = field(default=4_096)
     normalize_observations: bool = field(default=False)
@@ -115,6 +118,7 @@ class PPOHParams(flax.struct.PyTreeNode):
     learning_rate: chex.Scalar = 0.0003
     gamma: chex.Scalar = 0.99
     max_grad_norm: chex.Scalar = jnp.inf
+    # todo: this `jnp.inf` is causing issues in the yaml schema because it becomes `Infinity`.
     gae_lambda: chex.Scalar = 0.95
     clip_eps: chex.Scalar = 0.2
@@ -389,9 +393,13 @@ def eval_callback(
         if rng is None:
             rng = ts.rng
         actor = make_actor(ts=ts, hp=self.hp)
-        max_steps = self.env_params.max_steps_in_episode
         ep_lengths, cum_rewards = evaluate(
-            actor, ts.rng, self.env, self.env_params, 128, max_steps
+            actor,
+            ts.rng,
+            self.env,
+            self.env_params,
+            num_seeds=self.hp.num_seeds_per_eval,
+            max_steps_in_episode=self.env_params.max_steps_in_episode,
         )
         return EvalMetrics(episode_length=ep_lengths, cumulative_reward=cum_rewards)
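Note: `num_seeds_per_eval` above sets how many evaluation seeds are rolled out in parallel (the metrics handler added in `project/main.py` below averages the cumulative reward across these seeds, which are evaluated with vmap). The following is only a rough, hypothetical sketch of that kind of vmapped evaluation — it is not the `evaluate` helper used in `jax_rl_example.py` — and it assumes a gymnax-style `env.reset`/`env.step` API, an `actor(obs, rng) -> action` callable, and a static integer `max_steps_in_episode`:

    import jax
    import jax.numpy as jnp

    def evaluate_sketch(actor, rng, env, env_params, num_seeds, max_steps_in_episode):
        """Conceptual sketch only: run one episode per seed, in parallel with jax.vmap."""

        def run_episode(rng):
            rng, reset_rng = jax.random.split(rng)
            obs, env_state = env.reset(reset_rng, env_params)

            def step(carry, _):
                rng, obs, env_state, done, ep_len, ep_return = carry
                rng, act_rng, step_rng = jax.random.split(rng, 3)
                action = actor(obs, act_rng)
                obs, env_state, reward, step_done, _info = env.step(step_rng, env_state, action, env_params)
                # Stop accumulating length/reward once the episode is done.
                ep_len = ep_len + jnp.where(done, 0, 1)
                ep_return = ep_return + jnp.where(done, 0.0, reward)
                done = jnp.logical_or(done, step_done)
                return (rng, obs, env_state, done, ep_len, ep_return), None

            init = (rng, obs, env_state, jnp.array(False), jnp.array(0), jnp.array(0.0))
            (_, _, _, _, ep_len, ep_return), _ = jax.lax.scan(step, init, None, length=max_steps_in_episode)
            return ep_len, ep_return

        # One independent rollout per seed, batched over the leading axis of the RNG keys.
        return jax.vmap(run_episode)(jax.random.split(rng, num_seeds))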
-@pytest.mark.skip(reason="TODO: ests assume a LightningModule atm (.state_dict()), etc.") -@run_for_all_configs_of_type("algorithm", JaxRLExample) -class TestJaxRLExample(LearningAlgorithmTests[JaxRLExample]): # type: ignore - pass +# @pytest.mark.skip(reason="TODO: ests assume a LightningModule atm (.state_dict()), etc.") +# @run_for_all_configs_of_type("algorithm", JaxRLExample) +# class TestJaxRLExample(LearningAlgorithmTests[JaxRLExample]): # type: ignore +# pass @pytest.fixture diff --git a/project/algorithms/no_op.py b/project/algorithms/no_op.py index b5c1ec59..f4c35909 100644 --- a/project/algorithms/no_op.py +++ b/project/algorithms/no_op.py @@ -15,6 +15,7 @@ def __init__(self, datamodule: DataModule): self.datamodule = datamodule # Set this so PyTorch-Lightning doesn't try to train the model using our 'loss' self.automatic_optimization = False + self.p = torch.nn.Parameter(torch.tensor(0.0)) # unused. def training_step(self, batch: Any, batch_index: int): return self.shared_step(batch, batch_index, "train") diff --git a/project/configs/config_test.py b/project/configs/config_test.py index 6fecbf71..20fe2f5c 100644 --- a/project/configs/config_test.py +++ b/project/configs/config_test.py @@ -1,76 +1,11 @@ """TODO: Add tests for the configurations?""" -import copy -from unittest.mock import Mock - import hydra_zen import lightning -import omegaconf import pytest from hydra.core.config_store import ConfigStore -from omegaconf import DictConfig - -import project -import project.main -from project.conftest import algorithm_config, command_line_overrides -from project.main import PROJECT_NAME -from project.utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID - -CONFIG_DIR = REPO_ROOTDIR / PROJECT_NAME / "configs" - -experiment_configs = list((CONFIG_DIR / "experiment").glob("*.yaml")) - - -@pytest.fixture -def mock_train(monkeypatch: pytest.MonkeyPatch): - mock_train_fn = Mock(spec=project.main.train) - monkeypatch.setattr(project.main, project.main.train.__name__, mock_train_fn) - return mock_train_fn - - -@pytest.fixture -def mock_evaluate(monkeypatch: pytest.MonkeyPatch): - mock_eval_fn = Mock(spec=project.main.evaluation, return_value=("fake", 0.0, {})) - monkeypatch.setattr(project.main, project.main.evaluation.__name__, mock_eval_fn) - return mock_eval_fn - -# The problem is that not all experiment configs -# are to be used in the same way. For example, -# the cluster_sweep_config.yaml needs an -# additional `cluster` argument. Also, the -# example config uses wandb by default, which is -# probably bad, since it might be creating empty -# jobs in wandb during tests (since the logger is -# instantiated in main, even if the train fn is -# mocked. - -@pytest.mark.skip(reason="TODO: test is too general") -@pytest.mark.parametrize( - command_line_overrides.__name__, - [ - pytest.param( - f"experiment={experiment.name}", - marks=pytest.mark.xfail( - "cluster" in experiment.name and SLURM_JOB_ID is None, - reason="Needs to be run on a cluster.", - raises=omegaconf.errors.InterpolationResolutionError, - strict=True, - ), - ) - for experiment in list(experiment_configs) - ], - indirect=True, - ids=[experiment.name for experiment in list(experiment_configs)], -) -def test_can_load_experiment_configs( - experiment_dictconfig: DictConfig, mock_train: Mock, mock_evaluate: Mock -): - # Mock out some part of the `main` function to not actually run anything. 
-    results = project.main.main(copy.deepcopy(experiment_dictconfig))
-    assert results is not None
-    mock_train.assert_called_once()
-    mock_evaluate.assert_called_once()
+from project.conftest import algorithm_config


 class DummyModule(lightning.LightningModule):
diff --git a/project/configs/experiment/cluster_sweep_example.yaml b/project/configs/experiment/cluster_sweep_example.yaml
index 3bc0c37e..a6e8c157 100644
--- a/project/configs/experiment/cluster_sweep_example.yaml
+++ b/project/configs/experiment/cluster_sweep_example.yaml
@@ -1,4 +1,7 @@
 # @package _global_
+
+# This is an "experiment" config that groups together other configs into a ready-to-run example.
+
 defaults:
   - example.yaml # A configuration for a single run (that works!)
   - override /trainer/logger: wandb
@@ -8,6 +11,8 @@ defaults:
 log_level: DEBUG
 name: "sweep-example"
+
+# Set the seed to be the SLURM_PROCID, so that if we run more than one task per GPU, we get a different seed for each task.
 # TODO: This should technically be something like the "run_id", which would be different than SLURM_PROCID when using >1 gpus per "run".
 seed: ${oc.env:SLURM_PROCID,123}
@@ -44,7 +49,7 @@ hydra:
   sweep:
     dir: logs/${name}/multiruns/
     # subdir: ${hydra.job.num}
-    subdir: ${hydra.job.id}/task${oc.env:SLURM_PROCID}
+    subdir: ${hydra.job.id}/task${oc.env:SLURM_PROCID,0}

   launcher:
     # todo: bump this up.
@@ -54,16 +59,17 @@ hydra:
     # TODO: Pack more than one job on a single GPU, and support this with both a
     # patched submitit launcher as well as our remote submitit launcher, as well as by patching the
     # orion sweeper to not drop these other results.
-    ntasks_per_gpu: 1
+    # ntasks_per_gpu: 1
   sweeper:
     params:
       algorithm:
         optimizer:
           lr: "loguniform(1e-6, 1.0, default_value=3e-4)"
           # weight_decay: "loguniform(1e-6, 1e-2, default_value=0)"
-      trainer:
-        # Let the HPO algorithm allocate more epochs to more promising HP configurations.
-        max_epochs: "fidelity(1, 10, default_value=1)"
+      # todo: setup a fidelity parameter. Seems to not be working right now.
+      # trainer:
+      #   # Let the HPO algorithm allocate more epochs to more promising HP configurations.
+      #   max_epochs: "fidelity(1, 10, default_value=1)"
     parametrization: null
     experiment:
diff --git a/project/configs/experiment/example.yaml b/project/configs/experiment/example.yaml
index f20a346a..4d1a97c1 100644
--- a/project/configs/experiment/example.yaml
+++ b/project/configs/experiment/example.yaml
@@ -1,28 +1,25 @@
 # @package _global_

-# to execute this experiment run:
-# python main.py experiment=example
+# This is an "experiment" config that groups together other configs into a ready-to-run example.
+
+# To execute this experiment, use:
+# python project/main.py experiment=example

 defaults:
   - override /algorithm: example
+  - override /algorithm/network: resnet18
   - override /datamodule: cifar10
   - override /trainer: default
-  - override /trainer/logger: wandb
+  - override /trainer/logger: tensorboard
+  - override /trainer/callbacks: default

-# all parameters below will be merged with parameters from default configurations set above
-# this allows you to overwrite only specified parameters
-name: example
+# The parameters below will be merged with parameters from default configurations set above.
+# This allows you to overwrite only specified parameters.

-seed: ${oc.env:SLURM_PROCID,12345}
+# The name of the experiment.
+name: example

-# hydra:
-#   run:
-#     # output directory, generated dynamically on each run
-#     # DOESN'T WORK! This won't get interpolated correctly!
-#     # TODO: Make it so running the same command twice in the same job id resumes from the last checkpoint.
-#     dir: logs/${name}/runs/${oc.env:SLURM_JOB_ID,${hydra.job.id}}
-#   sweep:
-#     dir: logs/${name}/multiruns/
+seed: ${oc.env:SLURM_PROCID,42}

 algorithm:
   optimizer:
diff --git a/project/configs/experiment/jax_rl_example.yaml b/project/configs/experiment/jax_rl_example.yaml
index da571b8f..41cdc2fa 100644
--- a/project/configs/experiment/jax_rl_example.yaml
+++ b/project/configs/experiment/jax_rl_example.yaml
@@ -5,7 +5,9 @@ defaults:
   - override /trainer: jax
   - override /trainer/callbacks: rich_progress_bar
   - override /datamodule: null
+  # - /trainer/logger: tensorboard

 trainer:
+  _convert_: object
   max_epochs: 75
   training_steps_per_epoch: 1
   callbacks:
diff --git a/project/configs/resources/cpu.yaml b/project/configs/resources/cpu.yaml
index 328da608..8c64ee65 100644
--- a/project/configs/resources/cpu.yaml
+++ b/project/configs/resources/cpu.yaml
@@ -10,7 +10,7 @@ hydra:
   launcher:
     nodes: 1
     tasks_per_node: 1
-    cpus_per_task: 8
+    cpus_per_task: 4
     mem_gb: 16
     array_parallelism: 16 # max num of jobs to run in parallel
     # Other things to pass to `sbatch`:
diff --git a/project/configs/trainer/jax.yaml b/project/configs/trainer/jax.yaml
index ae68ce4a..10238873 100644
--- a/project/configs/trainer/jax.yaml
+++ b/project/configs/trainer/jax.yaml
@@ -1,5 +1,6 @@
 defaults:
   - callbacks: rich_progress_bar.yaml
+  - logger: null
 _target_: project.trainers.jax_trainer.JaxTrainer
 max_epochs: 75
 training_steps_per_epoch: 1
diff --git a/project/conftest.py b/project/conftest.py
index a58fc923..3769b066 100644
--- a/project/conftest.py
+++ b/project/conftest.py
@@ -68,6 +68,7 @@
 from typing import Literal

 import jax
+import lightning
 import lightning.pytorch as pl
 import pytest
 import tensor_regression.stats
@@ -88,7 +89,6 @@
     instantiate_algorithm,
     instantiate_datamodule,
     instantiate_trainer,
-    seed_rng,
     setup_logging,
 )
 from project.main import PROJECT_NAME
@@ -186,7 +186,10 @@ def command_line_arguments(
         # If we manually overwrite the command-line arguments with indirect parametrization,
         # then ignore the rest of the stuff here and just use the provided command-line args.
         # Split the string into a list of command-line arguments if needed.
-        return shlex.split(param) if isinstance(param, str) else param
+        if isinstance(param, str):
+            return tuple(shlex.split(param))
+        assert isinstance(param, list | tuple)
+        return tuple(param)

     combination = set([datamodule_config, algorithm_network_config, algorithm_config])
     for configs, marks in default_marks_for_config_combinations.items():
@@ -221,7 +224,7 @@ def command_line_arguments(

 @pytest.fixture(scope="session")
 def experiment_dictconfig(
-    command_line_arguments: list[str], tmp_path_factory: pytest.TempPathFactory
+    command_line_arguments: tuple[str, ...], tmp_path_factory: pytest.TempPathFactory
 ) -> DictConfig:
     """The `omegaconf.DictConfig` that is created by Hydra from the command-line arguments.
@@ -237,12 +240,12 @@ def experiment_dictconfig(
     tmp_path = tmp_path_factory.mktemp("test")
     if not any("trainer.default_root_dir" in override for override in command_line_arguments):
-        command_line_arguments = command_line_arguments + [
-            f"++trainer.default_root_dir={tmp_path}"
-        ]
+        command_line_arguments = tuple(command_line_arguments) + (
+            f"++trainer.default_root_dir={tmp_path}",
+        )

     with _setup_hydra_for_tests_and_compose(
-        all_overrides=command_line_arguments,
+        all_overrides=list(command_line_arguments),
         tmp_path_factory=tmp_path_factory,
     ) as dict_config:
         return dict_config
@@ -287,7 +290,7 @@ def trainer(
     experiment_config: Config,
 ) -> pl.Trainer:
     setup_logging(experiment_config)
-    seed_rng(experiment_config)
+    lightning.seed_everything(experiment_config.seed, workers=True)
     return instantiate_trainer(experiment_config)
@@ -432,7 +435,7 @@ def _override_param_id(override: Param) -> str:

 @pytest.fixture(scope="session", ids=_override_param_id)
-def command_line_overrides(request: pytest.FixtureRequest):
+def command_line_overrides(request: pytest.FixtureRequest) -> tuple[str, ...]:
     """Fixture that makes it possible to specify command-line overrides to use in a given test.

     Tests that require running an experiment should use the `experiment_config` fixture below.
diff --git a/project/experiment.py b/project/experiment.py
index bdbee664..aef75c67 100644
--- a/project/experiment.py
+++ b/project/experiment.py
@@ -15,8 +15,6 @@
 import functools
 import logging
 import os
-import random
-from dataclasses import dataclass
 from logging import getLogger as get_logger
 from typing import Any
@@ -27,7 +25,7 @@
 import rich.logging
 import rich.traceback
 from hydra_zen.typing import Builds
-from lightning import Callback, LightningDataModule, LightningModule, Trainer, seed_everything
+from lightning import Callback, LightningDataModule, LightningModule, Trainer

 from project.configs.config import Config
 from project.trainers.jax_trainer import JaxModule
@@ -46,50 +44,6 @@
 instantiate = hydra_zen.instantiate


-@dataclass
-class Experiment:
-    """Dataclass containing everything used in an experiment.
-
-    This gets created from the config that are parsed from Hydra. Can be used to run the experiment
-    by calling `run(experiment)`. Could also be serialized to a file or saved to disk, which might
-    come in handy with `submitit` later on.
-    """
-
-    algorithm: LightningModule
-    datamodule: DataModule | None
-    trainer: Trainer
-
-
-def setup_experiment(experiment_config: Config) -> Experiment:
-    """Instantiate the experiment components from the Hydra configuration.
-
-    All the interpolations in the configs have already been resolved by
-    [project.utils.hydra_utils.resolve_dictconfig][]. Now we only need to instantiate the components
-    from their configs.
-
-    Do all the postprocessing necessary (e.g., create the network, datamodule, callbacks,
-    Trainer, Algorithm, etc) to go from the options that come from Hydra, into all required
-    components for the experiment, which is stored as a dataclass called `Experiment`.
-
-    NOTE: This also has the effect of seeding the random number generators, so the weights that are
-    constructed are deterministic and reproducible.
- """ - setup_logging(experiment_config) - seed_rng(experiment_config) - trainer = instantiate_trainer(experiment_config) - - datamodule = instantiate_datamodule(experiment_config.datamodule) - - algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule) - - return Experiment( - trainer=trainer, - algorithm=algorithm, - # network=network, - datamodule=datamodule, - ) - - def setup_logging(experiment_config: Config) -> None: LOGLEVEL = os.environ.get("LOGLEVEL", "info").upper() logging.basicConfig( @@ -116,16 +70,6 @@ def setup_logging(experiment_config: Config) -> None: root_logger.setLevel(logging.DEBUG) -def seed_rng(experiment_config: Config): - if experiment_config.seed is not None: - seed = experiment_config.seed - print(f"seed manually set to {experiment_config.seed}") - else: - seed = random.randint(0, int(1e5)) - print(f"Randomly selected seed: {seed}") - seed_everything(seed=seed, workers=True) - - def instantiate_trainer(experiment_config: Config) -> Trainer: # NOTE: Need to do a bit of sneaky type tricks to convince the outside world that these # fields have the right type. diff --git a/project/main.py b/project/main.py index 126a292d..282dda3e 100644 --- a/project/main.py +++ b/project/main.py @@ -10,6 +10,9 @@ from __future__ import annotations +import dataclasses +import functools +import operator import os import warnings from logging import getLogger as get_logger @@ -22,19 +25,19 @@ import omegaconf import rich from hydra_plugins.auto_schema import auto_schema_plugin -from lightning import Callback, LightningDataModule +from lightning import Callback from lightning.pytorch.loggers import Logger from omegaconf import DictConfig +from project.algorithms.jax_rl_example import EvalMetrics from project.configs import add_configs_to_hydra_store from project.configs.config import Config from project.experiment import ( instantiate_algorithm, instantiate_datamodule, - seed_rng, setup_logging, ) -from project.trainers.jax_trainer import JaxModule, JaxTrainer +from project.trainers.jax_trainer import JaxModule, JaxTrainer, Ts, _MetricsT from project.utils.env_vars import REPO_ROOTDIR from project.utils.hydra_utils import resolve_dictconfig from project.utils.utils import print_config @@ -64,29 +67,42 @@ def main(dict_config: DictConfig) -> dict: This does roughly the same thing as https://github.com/ashleve/lightning-hydra-template/blob/main/src/train.py + + 1. Instantiates the experiment components from the Hydra configuration: + - trainer + - algorithm + - datamodule (optional) + 2. Calls `train` to train the algorithm + 3. Calls `evaluation` to evaluate the model + 4. Returns the evaluation metrics. """ print_config(dict_config, resolve=False) + # Resolve all the interpolations in the configs. config: Config = resolve_dictconfig(dict_config) - experiment_config = config - setup_logging(experiment_config) - seed_rng(experiment_config) + # Now we instantiate the components. + + # seed the random number generators, so the weights that are + # constructed are deterministic and reproducible. + + setup_logging(config) + lightning.seed_everything(seed=config.seed, workers=True) - trainer_config = config.trainer.copy() # Avoid mutating the input config, if passed. + # Create the Trainer + trainer_config = config.trainer.copy() # Avoid mutating the config if possible. 
diff --git a/project/main.py b/project/main.py
index 126a292d..282dda3e 100644
--- a/project/main.py
+++ b/project/main.py
@@ -10,6 +10,9 @@
 from __future__ import annotations

+import dataclasses
+import functools
+import operator
 import os
 import warnings
 from logging import getLogger as get_logger
@@ -22,19 +25,19 @@
 import omegaconf
 import rich
 from hydra_plugins.auto_schema import auto_schema_plugin
-from lightning import Callback, LightningDataModule
+from lightning import Callback
 from lightning.pytorch.loggers import Logger
 from omegaconf import DictConfig

+from project.algorithms.jax_rl_example import EvalMetrics
 from project.configs import add_configs_to_hydra_store
 from project.configs.config import Config
 from project.experiment import (
     instantiate_algorithm,
     instantiate_datamodule,
-    seed_rng,
     setup_logging,
 )
-from project.trainers.jax_trainer import JaxModule, JaxTrainer
+from project.trainers.jax_trainer import JaxModule, JaxTrainer, Ts, _MetricsT
 from project.utils.env_vars import REPO_ROOTDIR
 from project.utils.hydra_utils import resolve_dictconfig
 from project.utils.utils import print_config
@@ -64,29 +67,42 @@ def main(dict_config: DictConfig) -> dict:

     This does roughly the same thing as
     https://github.com/ashleve/lightning-hydra-template/blob/main/src/train.py
+
+    1. Instantiates the experiment components from the Hydra configuration:
+        - trainer
+        - algorithm
+        - datamodule (optional)
+    2. Calls `train` to train the algorithm
+    3. Calls `evaluate_lightningmodule` or `evaluate_jax_module` to evaluate the model
+    4. Returns the evaluation metrics.
     """
     print_config(dict_config, resolve=False)

+    # Resolve all the interpolations in the configs.
     config: Config = resolve_dictconfig(dict_config)
-    experiment_config = config
-    setup_logging(experiment_config)
-    seed_rng(experiment_config)

+    # Now we instantiate the components.
+
+    # Seed the random number generators, so the weights that are
+    # constructed are deterministic and reproducible.
+
+    setup_logging(config)
+    lightning.seed_everything(seed=config.seed, workers=True)

-    trainer_config = config.trainer.copy()  # Avoid mutating the input config, if passed.
+    # Create the Trainer
+    trainer_config = config.trainer.copy()  # Avoid mutating the config if possible.
     callbacks: list[Callback] | None = instantiate_values(trainer_config.pop("callbacks", None))
     logger: list[Logger] | None = instantiate_values(trainer_config.pop("logger", None))
-
     trainer: lightning.Trainer | JaxTrainer = hydra.utils.instantiate(
         trainer_config, callbacks=callbacks, logger=logger
     )

-    datamodule: lightning.LightningDataModule | None = instantiate_datamodule(
-        experiment_config.datamodule
-    )
+    # Create the datamodule (if present)
+    datamodule: lightning.LightningDataModule | None = instantiate_datamodule(config.datamodule)

+    # Create the "algorithm"
     algorithm: lightning.LightningModule | JaxModule = instantiate_algorithm(
-        experiment_config.algorithm, datamodule=datamodule
+        config.algorithm, datamodule=datamodule
     )

     import wandb
@@ -96,17 +112,28 @@ def main(dict_config: DictConfig) -> dict:
         wandb.run.config.update(
             omegaconf.OmegaConf.to_container(dict_config, resolve=False, throw_on_missing=True)
         )
-
-    train(config=config, trainer=trainer, datamodule=datamodule, algorithm=algorithm)
-
-    metric_name, error, _metrics = evaluation(
-        trainer=trainer, datamodule=datamodule, algorithm=algorithm
+    # Train the algorithm.
+    train_results = train(
+        config=config, trainer=trainer, datamodule=datamodule, algorithm=algorithm
     )

+    # Evaluate the algorithm.
+    if isinstance(algorithm, JaxModule):
+        assert isinstance(trainer, JaxTrainer)
+        metric_name, error, _metrics = evaluate_jax_module(
+            algorithm, trainer=trainer, train_results=train_results
+        )
+    else:
+        assert isinstance(trainer, lightning.Trainer)
+        metric_name, error, _metrics = evaluate_lightningmodule(
+            algorithm, datamodule=datamodule, trainer=trainer
+        )
+
     if wandb.run:
         wandb.finish()

     assert error is not None
+    # Results are returned like this so that the Orion sweeper can parse the results correctly.
     return dict(name=metric_name, type="objective", value=error)
@@ -123,12 +150,11 @@ def train(
         # example in RL, where we need to set the actor to use in the environment, as well as
         # potentially adding Wrappers on top of the environment, or having a replay buffer, etc.
         datamodule = getattr(algorithm, "datamodule", datamodule)
-        trainer.fit(
+        return trainer.fit(
             algorithm,
             datamodule=datamodule,
             ckpt_path=config.ckpt_path,
         )
-        return

     if datamodule is not None:
         raise NotImplementedError(
@@ -146,7 +172,7 @@ def train(
     rng = jax.random.key(config.seed)
     # TODO: Use ckpt_path argument to load the training state and resume the training run.
     assert config.ckpt_path is None
-    trainer.fit(algorithm, rng=rng)
+    return trainer.fit(algorithm, rng=rng)


 def instantiate_values(config_dict: DictConfig | None) -> list[Any] | None:
@@ -161,7 +187,7 @@ def instantiate_values(config_dict: DictConfig | None) -> list[Any] | None:
     This would then return a list with the instantiated WandbLogger and TensorBoardLogger objects.
     """
     if not config_dict:
-        return []
+        return None
     objects_dict = hydra.utils.instantiate(config_dict, _recursive_=True)
     if objects_dict is None:
         return None
@@ -172,10 +198,10 @@ def instantiate_values(config_dict: DictConfig | None) -> list[Any] | None:
 MetricName = str


-def evaluation(
-    trainer: JaxTrainer | lightning.Trainer,
+def evaluate_lightningmodule(
+    algorithm: lightning.LightningModule,
+    trainer: lightning.Trainer,
     datamodule: lightning.LightningDataModule | None,
-    algorithm,
 ) -> tuple[MetricName, float | None, dict]:
     """Evaluates the algorithm and returns the metrics.

@@ -183,73 +209,92 @@
     training error when `trainer.overfit_batches != 0` (e.g. when debugging or testing).
     Otherwise, if `trainer.limit_val_batches == 0`, returns the test error.
     """
-    # TODO Probably log the hydra config with something like this:
+    # exp.trainer.logger.log_hyperparams()
     # When overfitting on a single batch or only training, we return the train error.
     if (trainer.limit_val_batches == trainer.limit_test_batches == 0) or (
         trainer.overfit_batches == 1  # type: ignore
     ):
         # We want to report the training error.
-        metrics = {
-            **trainer.logged_metrics,
-            **trainer.callback_metrics,
-            **trainer.progress_bar_metrics,
-        }
-        rich.print(metrics)
-        if "train/accuracy" in metrics:
-            train_acc: float = metrics["train/accuracy"]
-            train_error = 1 - train_acc
-            return "1-accuracy", train_error, metrics
-        elif "train/avg_episode_reward" in metrics:
-            average_episode_rewards: float = metrics["train/avg_episode_reward"]
-            train_error = -average_episode_rewards
-            return "-avg_episode_reward", train_error, metrics
-        elif "train/loss" in metrics:
-            return "loss", metrics["train/loss"], metrics
-        else:
-            raise RuntimeError(
-                f"Don't know which metric to use to calculate the 'error' of this run.\n"
-                f"Here are the available metric names:\n"
-                f"{list(metrics.keys())}"
-            )
-    assert isinstance(datamodule, LightningDataModule)
-
-    if trainer.limit_val_batches != 0:
-        results = trainer.validate(model=algorithm, datamodule=datamodule)
+        results_type = "train"
+        results = [
+            {
+                **trainer.logged_metrics,
+                **trainer.callback_metrics,
+                **trainer.progress_bar_metrics,
+            }
+        ]
+    elif trainer.limit_val_batches != 0:
         results_type = "val"
+        results = trainer.validate(model=algorithm, datamodule=datamodule)
     else:
         warnings.warn(RuntimeWarning("About to use the test set for evaluation!"))
-        results = trainer.test(model=algorithm, datamodule=datamodule)
         results_type = "test"
+        results = trainer.test(model=algorithm, datamodule=datamodule)

     if results is None:
         rich.print("RUN FAILED!")
         return "fail", None, {}

-    returned_results_dict = dict(results[0])
-    results_dict = dict(results[0]).copy()
-
-    loss = results_dict.pop(f"{results_type}/loss")
-
-    if f"{results_type}/accuracy" in results_dict:
-        accuracy: float = results_dict[f"{results_type}/accuracy"]
-        rich.print(f"{results_type} accuracy: {accuracy:.1%}")
+    metrics = dict(results[0])
+    for key, value in metrics.items():
+        rich.print(f"{results_type} {key}: ", value)

-        if top5_accuracy := results_dict.get(f"{results_type}/top5_accuracy") is not None:
-            rich.print(f"{results_type} top5 accuracy: {top5_accuracy:.1%}")
+    if (accuracy := metrics.get(f"{results_type}/accuracy")) is not None:
         # NOTE: This is the value that is used for HParam sweeps.
-        error = 1 - accuracy
         metric_name = "1-accuracy"
-    else:
-        logger.warning("Assuming that the objective to minimize is the loss metric.")
+        error = 1 - accuracy
+
+    elif (loss := metrics.get(f"{results_type}/loss")) is not None:
+        logger.info("Assuming that the objective to minimize is the loss metric.")
         # If 'accuracy' isn't in the results, assume that the loss is the metric to use.
metric_name = "loss" error = loss + else: + raise RuntimeError( + f"Don't know which metric to use to calculate the 'error' of this run.\n" + f"Here are the available metric names:\n" + f"{list(metrics.keys())}" + ) - for key, value in results_dict.items(): - rich.print(f"{results_type} {key}: ", value) + return metric_name, error, metrics - return metric_name, error, returned_results_dict + +def evaluate_jax_module( + algorithm: JaxModule[Ts, Any, _MetricsT], + trainer: JaxTrainer, + train_results: tuple[Ts, _MetricsT] | None = None, +): + # todo: there isn't yet a `validate` method on the jax trainer. + assert isinstance(trainer, JaxTrainer) + assert train_results is not None + metrics = train_results[1] + + return get_error_from_metrics(metrics) + + +@functools.singledispatch +def get_error_from_metrics(metrics: _MetricsT) -> tuple[MetricName, float, dict]: + """Returns the main metric name, its value, and the full metrics dictionary.""" + raise NotImplementedError( + f"Don't know how to calculate the error to minimize from metrics {metrics} of type " + f"{type(metrics)}! " + f"You probably need to register a handler for it." + ) + + +@get_error_from_metrics.register(EvalMetrics) +def get_error_from_jax_rl_example_metrics(metrics: EvalMetrics): + last_epoch_metrics = jax.tree.map(operator.itemgetter(-1), metrics) + assert isinstance(last_epoch_metrics, EvalMetrics) + # Average across eval seeds (we're doing evaluation in multiple environments in parallel with + # vmap). + last_epoch_average_cumulative_reward = last_epoch_metrics.cumulative_reward.mean().item() + return ( + "-avg_cumulative_reward", + -last_epoch_average_cumulative_reward, # need to return an "error" to minimize for HPO. + dataclasses.asdict(last_epoch_metrics), + ) if __name__ == "__main__": diff --git a/project/main_test.py b/project/main_test.py index c7854f8d..dab992c9 100644 --- a/project/main_test.py +++ b/project/main_test.py @@ -2,21 +2,30 @@ from __future__ import annotations import shutil +import sys +import uuid +from unittest.mock import Mock import hydra_zen import omegaconf.errors import pytest import torch +from _pytest.mark.structures import ParameterSet +from hydra.types import RunMode from omegaconf import DictConfig +import project.main from project.algorithms.example import ExampleAlgorithm from project.configs.config import Config -from project.configs.config_test import CONFIG_DIR from project.conftest import command_line_overrides from project.datamodules.image_classification.cifar10 import CIFAR10DataModule +from project.utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID from project.utils.hydra_utils import resolve_dictconfig +from project.utils.testutils import IN_GITHUB_CI -from .main import main +from .main import PROJECT_NAME, main + +CONFIG_DIR = REPO_ROOTDIR / PROJECT_NAME / "configs" def test_jax_can_use_the_GPU(): @@ -39,6 +48,152 @@ def test_torch_can_use_the_GPU(): assert torch.cuda.is_available() == bool(shutil.which("nvidia-smi")) +@pytest.fixture +def mock_train(monkeypatch: pytest.MonkeyPatch): + mock_train_fn = Mock(spec=project.main.train) + monkeypatch.setattr(project.main, project.main.train.__name__, mock_train_fn) + return mock_train_fn + + +@pytest.fixture +def mock_evaluate_lightningmodule(monkeypatch: pytest.MonkeyPatch): + mock_eval_lightningmodule = Mock( + spec=project.main.evaluate_lightningmodule, return_value=("fake", 0.0, {}) + ) + monkeypatch.setattr( + project.main, project.main.evaluate_lightningmodule.__name__, mock_eval_lightningmodule + ) + return 
diff --git a/project/main_test.py b/project/main_test.py
index c7854f8d..dab992c9 100644
--- a/project/main_test.py
+++ b/project/main_test.py
@@ -2,21 +2,30 @@
 from __future__ import annotations

 import shutil
+import sys
+import uuid
+from unittest.mock import Mock

 import hydra_zen
 import omegaconf.errors
 import pytest
 import torch
+from _pytest.mark.structures import ParameterSet
+from hydra.types import RunMode
 from omegaconf import DictConfig

+import project.main
 from project.algorithms.example import ExampleAlgorithm
 from project.configs.config import Config
-from project.configs.config_test import CONFIG_DIR
 from project.conftest import command_line_overrides
 from project.datamodules.image_classification.cifar10 import CIFAR10DataModule
+from project.utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID
 from project.utils.hydra_utils import resolve_dictconfig
+from project.utils.testutils import IN_GITHUB_CI

-from .main import main
+from .main import PROJECT_NAME, main
+
+CONFIG_DIR = REPO_ROOTDIR / PROJECT_NAME / "configs"


 def test_jax_can_use_the_GPU():
@@ -39,6 +48,152 @@ def test_torch_can_use_the_GPU():
     assert torch.cuda.is_available() == bool(shutil.which("nvidia-smi"))


+@pytest.fixture
+def mock_train(monkeypatch: pytest.MonkeyPatch):
+    mock_train_fn = Mock(spec=project.main.train)
+    monkeypatch.setattr(project.main, project.main.train.__name__, mock_train_fn)
+    return mock_train_fn
+
+
+@pytest.fixture
+def mock_evaluate_lightningmodule(monkeypatch: pytest.MonkeyPatch):
+    mock_eval_lightningmodule = Mock(
+        spec=project.main.evaluate_lightningmodule, return_value=("fake", 0.0, {})
+    )
+    monkeypatch.setattr(
+        project.main, project.main.evaluate_lightningmodule.__name__, mock_eval_lightningmodule
+    )
+    return mock_eval_lightningmodule
+
+
+@pytest.fixture
+def mock_evaluate_jax_module(monkeypatch: pytest.MonkeyPatch):
+    mock_eval_jax_module = Mock(
+        spec=project.main.evaluate_jax_module, return_value=("fake", 0.0, {})
+    )
+    monkeypatch.setattr(
+        project.main, project.main.evaluate_jax_module.__name__, mock_eval_jax_module
+    )
+    return mock_eval_jax_module
+
+
+experiment_configs = [p.stem for p in (CONFIG_DIR / "experiment").glob("*.yaml")]
+
+experiment_commands_to_test = [
+    "experiment=example trainer.fast_dev_run=True",
+    "experiment=hf_example trainer.fast_dev_run=True",
+    # "experiment=jax_example trainer.fast_dev_run=True",
+    "experiment=jax_rl_example trainer.max_epochs=1",
+    pytest.param(
+        f"experiment=cluster_sweep_example "
+        f"trainer/logger=[] "  # disable logging.
+        f"trainer.fast_dev_run=True "  # make each job quicker to run
+        f"hydra.sweeper.worker.max_trials=1 "  # limit the number of jobs that get launched.
+        f"resources=gpu "
+        f"cluster={'current' if SLURM_JOB_ID else 'mila'} ",
+        marks=[
+            pytest.mark.slow,
+            pytest.mark.skipif(
+                IN_GITHUB_CI,
+                reason="Remote launcher tries to do a git push, doesn't work in github CI.",
+            ),
+            pytest.mark.xfail(
+                raises=TypeError,
+                reason="TODO: Getting a `TypeError: cannot pickle 'weakref.ReferenceType' object` error.",
+                strict=False,
+            ),
+        ],
+    ),
+    pytest.param(
+        "experiment=local_sweep_example "
+        "trainer/logger=[] "  # disable logging.
+        "trainer.fast_dev_run=True "  # make each job quicker to run
+        "hydra.sweeper.worker.max_trials=2 ",  # Run a small number of trials.
+        marks=pytest.mark.slow,
+    ),
+    pytest.param(
+        "experiment=profiling "
+        "datamodule=cifar10 "  # Run a small dataset instead of ImageNet (would take ~6min to process on a compute node..)
+        "trainer/logger=tensorboard "  # Use Tensorboard logger because DeviceStatsMonitor requires a logger being used.
+        "trainer.fast_dev_run=True ",  # make each job quicker to run
+        marks=pytest.mark.slow,
+    ),
+    pytest.param(
+        "experiment=profiling "
+        "algorithm=no_op "
+        "datamodule=cifar10 "  # Run a small dataset instead of ImageNet (would take ~6min to process on a compute node..)
+        "trainer/logger=tensorboard "  # Use Tensorboard logger because DeviceStatsMonitor requires a logger being used.
+        "trainer.fast_dev_run=True "  # make each job quicker to run
+    ),
+]
+
+
+@pytest.mark.parametrize("experiment_config", experiment_configs)
+def test_experiment_config_is_tested(experiment_config: str):
+    select_experiment_command = f"experiment={experiment_config}"
+
+    for test_command in experiment_commands_to_test:
+        if isinstance(test_command, ParameterSet):
+            assert len(test_command.values) == 1
+            assert isinstance(test_command.values[0], str), test_command.values
+            test_command = test_command.values[0]
+        if select_experiment_command in test_command:
+            return  # success.
+
+    raise RuntimeError(f"{experiment_config=} is not tested by any of the test commands!")
+
+
+@pytest.mark.parametrize(
+    command_line_overrides.__name__,
+    experiment_commands_to_test,
+    indirect=True,
+)
+def test_can_load_experiment_configs(
+    experiment_dictconfig: DictConfig,
+    mock_train: Mock,
+    mock_evaluate_lightningmodule: Mock,
+    mock_evaluate_jax_module: Mock,
+):
+    # Mock out some part of the `main` function to not actually run anything.
+    if experiment_dictconfig["hydra"]["mode"] == RunMode.MULTIRUN:
+        # NOTE: Can't pass a dictconfig to `main` function when doing a multirun (seems to just do
+        # a single run). If we try to call `main` without arguments and with the right arguments on
+        # the command-line, with the right functions mocked out, those might not get used at all
+        # since `main` seems to create the launcher which pickles stuff and uses subprocesses.
+        # Pretty gnarly stuff.
+        pytest.skip(reason="Config is a multi-run config (e.g. a sweep). ")
+    else:
+        results = project.main.main(experiment_dictconfig)
+        assert results is not None
+
+    mock_train.assert_called_once()
+    # One of them should have been called once.
+    assert (mock_evaluate_lightningmodule.call_count == 1) ^ (
+        mock_evaluate_jax_module.call_count == 1
+    )
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+    command_line_overrides.__name__,
+    experiment_commands_to_test,
+    indirect=True,
+)
+def test_can_run_experiment(
+    command_line_overrides: tuple[str, ...],
+    request: pytest.FixtureRequest,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    # Mock out some part of the `main` function to not actually run anything.
+    # Get a unique hash id:
+    # todo: Set a unique name to avoid collisions between tests and reusing previous results.
+    name = f"{request.function.__name__}_{uuid.uuid4().hex}"
+    command_line_args = ["project/main.py"] + list(command_line_overrides) + [f"name={name}"]
+    print(command_line_args)
+    monkeypatch.setattr(sys, "argv", command_line_args)
+    project.main.main()
+
+
 @pytest.mark.parametrize(command_line_overrides.__name__, ["algorithm=example"], indirect=True)
 def test_setting_just_algorithm_isnt_enough(experiment_dictconfig: DictConfig) -> None:
     """Test to check that the datamodule is required (even when just the example algorithm is set).
diff --git a/project/utils/remote_launcher_plugin_test.py b/project/utils/remote_launcher_plugin_test.py
index 3679f032..bc821ff4 100644
--- a/project/utils/remote_launcher_plugin_test.py
+++ b/project/utils/remote_launcher_plugin_test.py
@@ -2,6 +2,7 @@
 # Use monkeypatch.setattr(project.utils.remote_launcher_plugin, ..., that_mock)
 # Assert That the mock launcher plugin was instantiated
 import os
+import shlex
 import sys
 from pathlib import Path
 from unittest.mock import Mock
@@ -17,11 +18,12 @@
 import project.main
 import project.utils.remote_launcher_plugin
-from project.configs.config_test import CONFIG_DIR
-from project.conftest import command_line_overrides
 from project.main import PROJECT_NAME, main
+from project.main_test import CONFIG_DIR
 from project.utils import remote_launcher_plugin
+from project.utils.env_vars import SLURM_JOB_ID
 from project.utils.remote_launcher_plugin import RemoteSlurmLauncher
+from project.utils.testutils import IN_GITHUB_CI, IN_SELF_HOSTED_GITHUB_CI


 def _yaml_files_in(directory: str | Path, recursive: bool = False):
@@ -30,27 +32,37 @@ def _yaml_files_in(directory: str | Path, recursive: bool = False):
     return list(glob("*.yml")) + list(glob("*.yaml"))


-cluster_configs = _yaml_files_in(CONFIG_DIR / "cluster")
-resource_configs = _yaml_files_in(CONFIG_DIR / "resources")
+cluster_configs = [p.stem for p in _yaml_files_in(CONFIG_DIR / "cluster")]
+resource_configs = [p.stem for p in _yaml_files_in(CONFIG_DIR / "resources")]


-@pytest.mark.skipif("SLURM_JOB_ID" in os.environ, reason="Can't be run on the cluster just yet.")
 @pytest.mark.parametrize(
-    command_line_overrides.__name__,
+    "command_line_args",
     [
         pytest.param(
-            f"algorithm=example datamodule=cifar10 cluster={cluster.stem} resources={resources.stem}",
-            marks=pytest.mark.skipif(
-                cluster != "mila" and not is_already_logged_in(cluster.stem),
-                reason="Logging in would go through 2FA!",
-            ),
+            f"algorithm=example datamodule=cifar10 trainer.fast_dev_run=True cluster={cluster} resources={resources}",
+            marks=[
+                pytest.mark.skipif(
+                    SLURM_JOB_ID is None and cluster == "current",
+                    reason="Can only be run on a slurm cluster.",
+                ),
+                pytest.mark.skipif(
+                    IN_SELF_HOSTED_GITHUB_CI
+                    and cluster != "mila"
+                    and not is_already_logged_in(cluster),
+                    reason="Can only use remote clusters from the self-hosted runner if a connection already exists (2FA).",
+                ),
+                pytest.mark.skipif(
+                    IN_GITHUB_CI and not IN_SELF_HOSTED_GITHUB_CI,
+                    reason="Can't connect to clusters from the GitHub cloud CI runner.",
+                ),
+            ],
         )
         for cluster in cluster_configs
         for resources in resource_configs
     ],
-    indirect=True,
 )
-def test_can_load_configs(command_line_arguments: list[str]):
+def test_can_load_configs(command_line_args: str):
     """Test that the cluster and resource configs can be loaded without errors."""

     with initialize_config_module(
@@ -58,9 +70,10 @@ def test_can_load_configs(command_line_arguments: list[str]):
         job_name="test",
         version_base="1.2",
     ):
+        overrides = shlex.split(command_line_args)
         _config = hydra.compose(
             config_name="config",
-            overrides=command_line_arguments,
+            overrides=overrides,
             return_hydra_config=True,
         )
@@ -74,8 +87,9 @@ def test_can_load_configs(command_line_arguments: list[str]):
     if launcher_config["_target_"] == remote_launcher_plugin.RemoteSlurmQueueConf._target_:
         with omegaconf.open_dict(launcher_config):
             launcher_config["executor"]["_synced"] = True  # avoid syncing the code here.
-        launcher = hydra.utils.instantiate(launcher_config)
-        assert isinstance(launcher, remote_launcher_plugin.RemoteSlurmLauncher)
+        # TODO: This still tries to `git push`, which fails on the CI.
+        # launcher = hydra.utils.instantiate(launcher_config)
+        # assert isinstance(launcher, remote_launcher_plugin.RemoteSlurmLauncher)
     else:
         launcher = hydra.utils.instantiate(launcher_config)
         assert isinstance(launcher, SlurmLauncher)