Simplify main.py / experiment.py and tests (#80)
* [Ugly] simplify main.py / experiment.py

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix tiny bug in test_load_configs

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove duplicated test code

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add tests to run experiment configs

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove unused function

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Rename test

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Simplify test for remote launcher

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix the marks on the remote launcher tests

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add leftover changes

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Tweak failing test

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Simplify (and split up) the `evaluation` function

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add xfail on test for cluster sweep example test

- Getting a `TypeError: cannot pickle 'weakref.ReferenceType' object`.
  This is a bit weird. I don't know what's suddenly causing this.

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
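For context, a minimal sketch of what a non-strict `xfail` marker for this flaky failure could look like (the test name and body here are hypothetical, not taken from this commit):

```python
import pytest


@pytest.mark.xfail(
    raises=TypeError,  # "cannot pickle 'weakref.ReferenceType' object"
    reason="Flaky pickling error in the cluster sweep example.",
    strict=False,  # non-strict: an unexpected pass does not fail the suite
)
def test_cluster_sweep_example():  # hypothetical test name
    ...
```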

* Add comment in `main_test.py`

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix issue in `experiment_dictconfig` fixture

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove unused code in `config_test.py`

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add a test command to cover the `no_op` algorithm

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove unused variables in config_test.py

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix error with cluster_sweep_example + resources

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Fix bug with no_op algo and test command

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Make the xfail not strict (?)

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
lebrice authored Nov 4, 2024
1 parent 2724bb8 commit 727fa67
Showing 15 changed files with 370 additions and 249 deletions.
16 changes: 14 additions & 2 deletions docs/profiling_test.py
@@ -1,6 +1,7 @@
import shutil

import hydra.errors
import lightning
import pytest
from omegaconf import DictConfig

@@ -13,7 +14,12 @@
datamodule_config,
experiment_dictconfig,
)
from project.experiment import setup_experiment
from project.experiment import (
instantiate_algorithm,
instantiate_datamodule,
instantiate_trainer,
setup_logging,
)
from project.utils.hydra_utils import resolve_dictconfig


@@ -111,5 +117,11 @@ def test_notebook_commands_dont_cause_errors(experiment_dictconfig: DictConfig):
# check for any errors related to OmegaConf interpolations and such
config = resolve_dictconfig(experiment_dictconfig)
# check for any errors when actually instantiating the components.
_experiment = setup_experiment(config)
# _experiment = _setup_experiment(config)
setup_logging(config)
lightning.seed_everything(config.seed, workers=True)
_trainer = instantiate_trainer(config)
datamodule = instantiate_datamodule(config.datamodule)
_algorithm = instantiate_algorithm(config.algorithm, datamodule=datamodule)

# Note: Here we don't actually do anything with the objects.
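Taken together, the new calls above are the piece-by-piece equivalent of the old `setup_experiment(config)` helper. A sketch assembled only from the calls shown in this diff (the wrapper function name is hypothetical; real call sites may differ):

```python
import lightning

from project.experiment import (
    instantiate_algorithm,
    instantiate_datamodule,
    instantiate_trainer,
    setup_logging,
)


def setup_experiment_parts(config):  # hypothetical wrapper, for illustration
    """Instantiate each component explicitly instead of via `setup_experiment`."""
    setup_logging(config)
    lightning.seed_everything(config.seed, workers=True)
    trainer = instantiate_trainer(config)
    datamodule = instantiate_datamodule(config.datamodule)
    algorithm = instantiate_algorithm(config.algorithm, datamodule=datamodule)
    return trainer, datamodule, algorithm
```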
12 changes: 10 additions & 2 deletions project/algorithms/jax_rl_example.py
@@ -106,6 +106,9 @@ class PPOHParams(flax.struct.PyTreeNode):
num_steps: int = field(default=64)
num_minibatches: int = field(default=16)

# ADDED:
num_seeds_per_eval: int = field(default=128)

eval_freq: int = field(default=4_096)

normalize_observations: bool = field(default=False)
@@ -115,6 +118,7 @@
learning_rate: chex.Scalar = 0.0003
gamma: chex.Scalar = 0.99
max_grad_norm: chex.Scalar = jnp.inf
# todo: this `jnp.inf` is causing issues in the yaml schema because it becomes `Infinity`.

gae_lambda: chex.Scalar = 0.95
clip_eps: chex.Scalar = 0.2
@@ -389,9 +393,13 @@ def eval_callback(
if rng is None:
rng = ts.rng
actor = make_actor(ts=ts, hp=self.hp)
max_steps = self.env_params.max_steps_in_episode
ep_lengths, cum_rewards = evaluate(
actor, ts.rng, self.env, self.env_params, 128, max_steps
actor,
ts.rng,
self.env,
self.env_params,
num_seeds=self.hp.num_seeds_per_eval,
max_steps_in_episode=self.env_params.max_steps_in_episode,
)
return EvalMetrics(episode_length=ep_lengths, cumulative_reward=cum_rewards)

10 changes: 4 additions & 6 deletions project/algorithms/jax_rl_example_test.py
@@ -29,7 +29,6 @@

from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback
from project.trainers.jax_trainer import JaxTrainer, hparams_to_dict
from project.utils.testutils import run_for_all_configs_of_type

from .jax_rl_example import (
EvalMetrics,
@@ -43,7 +42,6 @@
_actor,
render_episode,
)
from .testsuites.algorithm_tests import LearningAlgorithmTests

logger = getLogger(__name__)

@@ -671,10 +669,10 @@ def log(


# TODO: potentially just use the Lightning adapter for unit tests for now?
@pytest.mark.skip(reason="TODO: ests assume a LightningModule atm (.state_dict()), etc.")
@run_for_all_configs_of_type("algorithm", JaxRLExample)
class TestJaxRLExample(LearningAlgorithmTests[JaxRLExample]): # type: ignore
pass
# @pytest.mark.skip(reason="TODO: ests assume a LightningModule atm (.state_dict()), etc.")
# @run_for_all_configs_of_type("algorithm", JaxRLExample)
# class TestJaxRLExample(LearningAlgorithmTests[JaxRLExample]): # type: ignore
# pass


@pytest.fixture
1 change: 1 addition & 0 deletions project/algorithms/no_op.py
@@ -15,6 +15,7 @@ def __init__(self, datamodule: DataModule):
self.datamodule = datamodule
# Set this so PyTorch-Lightning doesn't try to train the model using our 'loss'
self.automatic_optimization = False
self.p = torch.nn.Parameter(torch.tensor(0.0)) # unused.

def training_step(self, batch: Any, batch_index: int):
return self.shared_step(batch, batch_index, "train")
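The extra parameter is not cosmetic: torch optimizers raise "optimizer got an empty parameter list" when a module has no parameters at all. A minimal sketch of the pattern, assuming a stripped-down version of the `NoOp` module (details beyond this diff are assumptions):

```python
import lightning
import torch


class NoOpSketch(lightning.LightningModule):  # hypothetical, simplified
    def __init__(self):
        super().__init__()
        # torch optimizers reject a module with no parameters, so register
        # one unused scalar parameter.
        self.p = torch.nn.Parameter(torch.tensor(0.0))
        # Set this so PyTorch-Lightning doesn't try to train using our "loss".
        self.automatic_optimization = False

    def training_step(self, batch, batch_index: int):
        return torch.zeros(1, requires_grad=True)  # placeholder "loss"

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.0)
```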
67 changes: 1 addition & 66 deletions project/configs/config_test.py
@@ -1,76 +1,11 @@
"""TODO: Add tests for the configurations?"""

import copy
from unittest.mock import Mock

import hydra_zen
import lightning
import omegaconf
import pytest
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig

import project
import project.main
from project.conftest import algorithm_config, command_line_overrides
from project.main import PROJECT_NAME
from project.utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID

CONFIG_DIR = REPO_ROOTDIR / PROJECT_NAME / "configs"

experiment_configs = list((CONFIG_DIR / "experiment").glob("*.yaml"))


@pytest.fixture
def mock_train(monkeypatch: pytest.MonkeyPatch):
mock_train_fn = Mock(spec=project.main.train)
monkeypatch.setattr(project.main, project.main.train.__name__, mock_train_fn)
return mock_train_fn


@pytest.fixture
def mock_evaluate(monkeypatch: pytest.MonkeyPatch):
mock_eval_fn = Mock(spec=project.main.evaluation, return_value=("fake", 0.0, {}))
monkeypatch.setattr(project.main, project.main.evaluation.__name__, mock_eval_fn)
return mock_eval_fn

# The problem is that not all experiment configs
# are to be used in the same way. For example,
# the cluster_sweep_config.yaml needs an
# additional `cluster` argument. Also, the
# example config uses wandb by default, which is
# probably bad, since it might be creating empty
# jobs in wandb during tests (since the logger is
# instantiated in main, even if the train fn is
# mocked.

@pytest.mark.skip(reason="TODO: test is too general")
@pytest.mark.parametrize(
command_line_overrides.__name__,
[
pytest.param(
f"experiment={experiment.name}",
marks=pytest.mark.xfail(
"cluster" in experiment.name and SLURM_JOB_ID is None,
reason="Needs to be run on a cluster.",
raises=omegaconf.errors.InterpolationResolutionError,
strict=True,
),
)
for experiment in list(experiment_configs)
],
indirect=True,
ids=[experiment.name for experiment in list(experiment_configs)],
)
def test_can_load_experiment_configs(
experiment_dictconfig: DictConfig, mock_train: Mock, mock_evaluate: Mock
):
# Mock out some part of the `main` function to not actually run anything.

results = project.main.main(copy.deepcopy(experiment_dictconfig))
assert results is not None
mock_train.assert_called_once()
mock_evaluate.assert_called_once()
from project.conftest import algorithm_config


class DummyModule(lightning.LightningModule):
16 changes: 11 additions & 5 deletions project/configs/experiment/cluster_sweep_example.yaml
@@ -1,4 +1,7 @@
# @package _global_

# This is an "experiment" config, that groups together other configs into a ready-to-run example.

defaults:
- example.yaml # A configuration for a single run (that works!)
- override /trainer/logger: wandb
@@ -8,6 +11,8 @@ defaults:

log_level: DEBUG
name: "sweep-example"

# Set the seed to be the SLURM_PROCID, so that if we run more than one task per GPU, each task gets a different seed.
# TODO: This should technically be something like the "run_id", which would be different than SLURM_PROCID when using >1 gpus per "run".
seed: ${oc.env:SLURM_PROCID,123}

@@ -44,7 +49,7 @@ hydra:
sweep:
dir: logs/${name}/multiruns/
# subdir: ${hydra.job.num}
subdir: ${hydra.job.id}/task${oc.env:SLURM_PROCID}
subdir: ${hydra.job.id}/task${oc.env:SLURM_PROCID,0}

launcher:
# todo: bump this up.
@@ -54,16 +59,17 @@ hydra:
# TODO: Pack more than one job on a single GPU, and support this with both a
# patched submitit launcher as well as our remote submitit launcher, as well as by patching the
# orion sweeper to not drop these other results.
ntasks_per_gpu: 1
# ntasks_per_gpu: 1
sweeper:
params:
algorithm:
optimizer:
lr: "loguniform(1e-6, 1.0, default_value=3e-4)"
# weight_decay: "loguniform(1e-6, 1e-2, default_value=0)"
trainer:
# Let the HPO algorithm allocate more epochs to more promising HP configurations.
max_epochs: "fidelity(1, 10, default_value=1)"
# todo: setup a fidelity parameter. Seems to not be working right now.
# trainer:
# # Let the HPO algorithm allocate more epochs to more promising HP configurations.
# max_epochs: "fidelity(1, 10, default_value=1)"

parametrization: null
experiment:
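Both the seed and the new sweep `subdir` rely on OmegaConf's `oc.env` resolver with a fallback value. A quick sketch of how that resolver behaves:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"seed": "${oc.env:SLURM_PROCID,123}"})
# Accessing the value resolves the interpolation: if $SLURM_PROCID is set
# (e.g. inside a SLURM job), its value is returned as a string; otherwise
# the default 123 is returned as-is.
print(cfg.seed)
```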
27 changes: 12 additions & 15 deletions project/configs/experiment/example.yaml
@@ -1,28 +1,25 @@
# @package _global_

# to execute this experiment run:
# python main.py experiment=example
# This is an "experiment" config, that groups together other configs into a ready-to-run example.

# To execute this experiment, use:
# python project/main.py experiment=example

defaults:
- override /algorithm: example
- override /algorithm/network: resnet18
- override /datamodule: cifar10
- override /trainer: default
- override /trainer/logger: wandb
- override /trainer/logger: tensorboard
- override /trainer/callbacks: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
name: example
# The parameters below will be merged with parameters from default configurations set above.
# This allows you to overwrite only the specified parameters.

seed: ${oc.env:SLURM_PROCID,12345}
# The name of the experiment.
name: example

# hydra:
# run:
# # output directory, generated dynamically on each run
# # DOESN'T WORK! This won't get interpolated correctly!
# # TODO: Make it so running the same command twice in the same job id resumes from the last checkpoint.
# dir: logs/${name}/runs/${oc.env:SLURM_JOB_ID,${hydra.job.id}}
# sweep:
# dir: logs/${name}/multiruns/
seed: ${oc.env:SLURM_PROCID,42}

algorithm:
optimizer:
2 changes: 2 additions & 0 deletions project/configs/experiment/jax_rl_example.yaml
@@ -5,7 +5,9 @@ defaults:
- override /trainer: jax
- override /trainer/callbacks: rich_progress_bar
- override /datamodule: null
# - /trainer/logger: tensorboard
trainer:
_convert_: object
max_epochs: 75
training_steps_per_epoch: 1
callbacks:
2 changes: 1 addition & 1 deletion project/configs/resources/cpu.yaml
@@ -10,7 +10,7 @@ hydra:
launcher:
nodes: 1
tasks_per_node: 1
cpus_per_task: 8
cpus_per_task: 4
mem_gb: 16
array_parallelism: 16 # max num of jobs to run in parallel
# Other things to pass to `sbatch`:
1 change: 1 addition & 0 deletions project/configs/trainer/jax.yaml
@@ -1,5 +1,6 @@
defaults:
- callbacks: rich_progress_bar.yaml
- logger: null
_target_: project.trainers.jax_trainer.JaxTrainer
max_epochs: 75
training_steps_per_epoch: 1
21 changes: 12 additions & 9 deletions project/conftest.py
@@ -68,6 +68,7 @@
from typing import Literal

import jax
import lightning
import lightning.pytorch as pl
import pytest
import tensor_regression.stats
@@ -88,7 +89,6 @@
instantiate_algorithm,
instantiate_datamodule,
instantiate_trainer,
seed_rng,
setup_logging,
)
from project.main import PROJECT_NAME
@@ -186,7 +186,10 @@ def command_line_arguments(
# If we manually overwrite the command-line arguments with indirect parametrization,
# then ignore the rest of the stuff here and just use the provided command-line args.
# Split the string into a list of command-line arguments if needed.
return shlex.split(param) if isinstance(param, str) else param
if isinstance(param, str):
return tuple(shlex.split(param))
assert isinstance(param, list | tuple)
return tuple(param)

combination = set([datamodule_config, algorithm_network_config, algorithm_config])
for configs, marks in default_marks_for_config_combinations.items():
@@ -221,7 +224,7 @@

@pytest.fixture(scope="session")
def experiment_dictconfig(
command_line_arguments: list[str], tmp_path_factory: pytest.TempPathFactory
command_line_arguments: tuple[str, ...], tmp_path_factory: pytest.TempPathFactory
) -> DictConfig:
"""The `omegaconf.DictConfig` that is created by Hydra from the command-line arguments.
@@ -237,12 +240,12 @@

tmp_path = tmp_path_factory.mktemp("test")
if not any("trainer.default_root_dir" in override for override in command_line_arguments):
command_line_arguments = command_line_arguments + [
f"++trainer.default_root_dir={tmp_path}"
]
command_line_arguments = tuple(command_line_arguments) + (
f"++trainer.default_root_dir={tmp_path}",
)

with _setup_hydra_for_tests_and_compose(
all_overrides=command_line_arguments,
all_overrides=list(command_line_arguments),
tmp_path_factory=tmp_path_factory,
) as dict_config:
return dict_config
@@ -287,7 +290,7 @@ def trainer(
experiment_config: Config,
) -> pl.Trainer:
setup_logging(experiment_config)
seed_rng(experiment_config)
lightning.seed_everything(experiment_config.seed, workers=True)
return instantiate_trainer(experiment_config)


@@ -432,7 +435,7 @@ def _override_param_id(override: Param) -> str:


@pytest.fixture(scope="session", ids=_override_param_id)
def command_line_overrides(request: pytest.FixtureRequest):
def command_line_overrides(request: pytest.FixtureRequest) -> tuple[str, ...]:
"""Fixture that makes it possible to specify command-line overrides to use in a given test.
Tests that require running an experiment should use the `experiment_config` fixture below.
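The list-to-tuple changes in `conftest.py` all follow the same pattern: the overrides are normalized to an immutable tuple so that session-scoped fixtures cannot mutate shared state between tests. A standalone sketch of that normalization (the helper name is hypothetical):

```python
import shlex


def normalize_overrides(param: str | list[str] | tuple[str, ...]) -> tuple[str, ...]:
    """Turn an override spec into an immutable tuple of CLI-style tokens."""
    if isinstance(param, str):
        # Split e.g. "seed=42 trainer.max_epochs=1" into ("seed=42", "trainer.max_epochs=1").
        return tuple(shlex.split(param))
    return tuple(param)
```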
[Diffs for the remaining changed files are not shown.]
