From 52c88ce62d0a2f2b666d3f0e0364ac350e0bca29 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 2 Oct 2024 13:29:19 -0400 Subject: [PATCH] Make the datamodule config optional (#56) * Make the datamodule config optional Signed-off-by: Fabrice Normandin * Remove 'network' field in `Config` class Signed-off-by: Fabrice Normandin * Revert change to example.yaml config Signed-off-by: Fabrice Normandin * Add test for loading experiment configs Signed-off-by: Fabrice Normandin * Remove default algorithm / datamodule Signed-off-by: Fabrice Normandin * Move configs / add defaults Signed-off-by: Fabrice Normandin * Add the classification metrics cb to ExampleAlgo Signed-off-by: Fabrice Normandin * Don't try to print `datamodule` config Signed-off-by: Fabrice Normandin * Fix passing datamodule=None to algorithm Signed-off-by: Fabrice Normandin * Make algorithm sort-of optional (Hydra required) Signed-off-by: Fabrice Normandin * Remove dead code Signed-off-by: Fabrice Normandin * Fix failing tests in `main_test.py` Signed-off-by: Fabrice Normandin * Keep algorithm required Signed-off-by: Fabrice Normandin * "fix" test: passing only algorithm isnt enough now Signed-off-by: Fabrice Normandin * Fix tests in main_test.py Signed-off-by: Fabrice Normandin * Fix issue with main_test.py (again) Signed-off-by: Fabrice Normandin * Update cluster_sweep_example.yaml * Remove uncovered branch in ExampleAlgorithm Signed-off-by: Fabrice Normandin --------- Signed-off-by: Fabrice Normandin --- project/algorithms/example.py | 11 +- project/configs/config.py | 15 +- project/configs/config.yaml | 4 +- project/configs/config_test.py | 87 ++++++++++ .../experiment/cluster_sweep_example.yaml | 4 +- project/configs/experiment/example.yaml | 20 +-- .../experiment/local_sweep_example.yaml | 8 +- .../configs/experiment/overfit_one_batch.yaml | 23 --- .../configs/trainer/overfit_one_batch.yaml | 22 +++ project/conftest.py | 8 +- project/datamodules/datamodules_test.py | 13 +- project/experiment.py | 94 +++++------ project/main.py | 35 ++-- project/main_test.py | 27 ++- project/utils/__init__.py | 2 - project/utils/hydra_utils.py | 157 ++---------------- project/utils/utils.py | 14 -- 17 files changed, 255 insertions(+), 289 deletions(-) delete mode 100644 project/configs/experiment/overfit_one_batch.yaml create mode 100644 project/configs/trainer/overfit_one_batch.yaml diff --git a/project/algorithms/example.py b/project/algorithms/example.py index c95d381d..aa15d8a6 100644 --- a/project/algorithms/example.py +++ b/project/algorithms/example.py @@ -7,16 +7,19 @@ ``` """ +from collections.abc import Sequence from logging import getLogger from typing import Literal, TypeVar import torch from hydra_zen.typing import Builds, PartialBuilds from lightning import LightningModule +from lightning.pytorch.callbacks.callback import Callback from torch import Tensor from torch.nn import functional as F from torch.optim.optimizer import Optimizer +from project.algorithms.callbacks.classification_metrics import ClassificationMetricsCallback from project.configs.algorithm.optimizer import AdamConfig from project.datamodules.image_classification import ImageClassificationDataModule from project.experiment import instantiate @@ -68,9 +71,6 @@ def __init__( } ) - # Save hyper-parameters. - self.save_hyperparameters(ignore=["datamodule", "network"]) - # Small fix for the `device` property in LightningModule, which is CPU by default. 
self._device = next((p.device for p in self.parameters()), torch.device("cpu")) # Used by Pytorch-Lightning to compute the input/output shapes of the network. @@ -122,3 +122,8 @@ def configure_optimizers(self): optimizer = optimizer_partial(self.parameters()) # This then returns the optimizer. return optimizer + + def configure_callbacks(self) -> Sequence[Callback] | Callback: + return [ + ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes) + ] diff --git a/project/configs/config.py b/project/configs/config.py index f4bc1b25..f79086aa 100644 --- a/project/configs/config.py +++ b/project/configs/config.py @@ -20,13 +20,6 @@ class Config: For more info, see https://hydra.cc/docs/tutorials/structured_config/schema/ """ - datamodule: Any - """Configuration for the datamodule (dataset + transforms + dataloader creation). - - This should normally create a [LightningDataModule][lightning.pytorch.core.datamodule.LightningDataModule]. - See the [MNISTDataModule][project.datamodules.image_classification.mnist.MNISTDataModule] for an example. - """ - algorithm: Any """Configuration for the algorithm (a [LightningModule][lightning.pytorch.core.module.LightningModule]). @@ -37,8 +30,12 @@ class Config: For more info, see the [instantiate_algorithm][project.experiment.instantiate_algorithm] function. """ - network: Any | None = None - """The network to use.""" + datamodule: Any | None = None + """Configuration for the datamodule (dataset + transforms + dataloader creation). + + This should normally create a [LightningDataModule][lightning.pytorch.core.datamodule.LightningDataModule]. + See the [MNISTDataModule][project.datamodules.image_classification.mnist.MNISTDataModule] for an example. + """ trainer: dict = field(default_factory=dict) """Keyword arguments for the Trainer constructor.""" diff --git a/project/configs/config.yaml b/project/configs/config.yaml index 77b8d3cd..2b894931 100644 --- a/project/configs/config.yaml +++ b/project/configs/config.yaml @@ -1,8 +1,8 @@ defaults: - base_config - _self_ - - datamodule: cifar10 - - algorithm: example + - optional algorithm: ??? 
+ - optional datamodule: null - trainer: default.yaml - hydra: default.yaml diff --git a/project/configs/config_test.py b/project/configs/config_test.py index fa78a32b..ea6f6998 100644 --- a/project/configs/config_test.py +++ b/project/configs/config_test.py @@ -1 +1,88 @@ """TODO: Add tests for the configurations?""" + +import hydra_zen +import lightning +import omegaconf +import pytest +from hydra.core.config_store import ConfigStore + +from project.configs.config import Config +from project.experiment import Experiment, setup_experiment +from project.main import PROJECT_NAME +from project.utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID + +CONFIG_DIR = REPO_ROOTDIR / PROJECT_NAME / "configs" + +experiment_configs = list((CONFIG_DIR / "experiment").glob("*.yaml")) + + +@pytest.mark.parametrize( + "overrides", + [ + pytest.param( + f"experiment={experiment.name}", + marks=pytest.mark.xfail( + "cluster" in experiment.name and SLURM_JOB_ID is None, + reason="Needs to be run on a cluster.", + raises=omegaconf.errors.InterpolationResolutionError, + strict=True, + ), + ) + for experiment in list(experiment_configs) + ], + indirect=True, + ids=[experiment.name for experiment in list(experiment_configs)], +) +def test_can_load_experiment_configs(experiment_config: Config): + experiment = setup_experiment(experiment_config) + assert isinstance(experiment, Experiment) + + +class DummyModule(lightning.LightningModule): + def __init__(self, bob: int = 123): + super().__init__() + + def forward(self, x): + return self.network(x) + + +@pytest.fixture(scope="session") +def cs(): + state_before = ConfigStore.get_state() + yield ConfigStore.instance() + ConfigStore.set_state(state_before) + + +@pytest.fixture(scope="session") +def register_dummy_configs(cs: ConfigStore): + cs.store( + "dummy", + node=hydra_zen.builds( + DummyModule, + zen_partial=False, + populate_full_signature=True, + ), + group="algorithm", + ) + cs.store( + "dummy_partial", + node=hydra_zen.builds( + DummyModule, + zen_partial=True, + populate_full_signature=True, + ), + group="algorithm", + ) + + +@pytest.mark.parametrize( + "algorithm_config", + ["dummy", "dummy_partial"], + indirect=True, + # scope="module", +) +def test_can_use_algo_without_datamodule( + register_dummy_configs: None, algorithm: lightning.LightningModule +): + """Test that we can use an algorithm without a datamodule.""" + assert isinstance(algorithm, DummyModule) diff --git a/project/configs/experiment/cluster_sweep_example.yaml b/project/configs/experiment/cluster_sweep_example.yaml index 45faa77e..5055a274 100644 --- a/project/configs/experiment/cluster_sweep_example.yaml +++ b/project/configs/experiment/cluster_sweep_example.yaml @@ -11,7 +11,7 @@ name: "sweep-example" seed: ${oc.env:SLURM_PROCID,123} algorithm: - optimizer_config: + optimizer: # This here will get overwritten by the sweeper. 
lr: 0.002 @@ -72,7 +72,7 @@ hydra: sweeper: params: algorithm: - optimizer_config: + optimizer: lr: "loguniform(1e-6, 1.0, default_value=3e-4)" # weight_decay: "loguniform(1e-6, 1e-2, default_value=0)" diff --git a/project/configs/experiment/example.yaml b/project/configs/experiment/example.yaml index 34ed16f6..f20a346a 100644 --- a/project/configs/experiment/example.yaml +++ b/project/configs/experiment/example.yaml @@ -4,8 +4,8 @@ # python main.py experiment=example defaults: - - override /datamodule: cifar10 - override /algorithm: example + - override /datamodule: cifar10 - override /trainer: default - override /trainer/logger: wandb @@ -15,17 +15,17 @@ name: example seed: ${oc.env:SLURM_PROCID,12345} -hydra: - run: - # output directory, generated dynamically on each run - # DOESN'T WORK! This won't get interpolated correctly! - # TODO: Make it so running the same command twice in the same job id resumes from the last checkpoint. - dir: logs/${name}/runs/${oc.env:SLURM_JOB_ID,${hydra.job.id}} - sweep: - dir: logs/${name}/multiruns/ +# hydra: +# run: +# # output directory, generated dynamically on each run +# # DOESN'T WORK! This won't get interpolated correctly! +# # TODO: Make it so running the same command twice in the same job id resumes from the last checkpoint. +# dir: logs/${name}/runs/${oc.env:SLURM_JOB_ID,${hydra.job.id}} +# sweep: +# dir: logs/${name}/multiruns/ algorithm: - optimizer_config: + optimizer: lr: 0.002 datamodule: diff --git a/project/configs/experiment/local_sweep_example.yaml b/project/configs/experiment/local_sweep_example.yaml index f30b8832..6fcd4659 100644 --- a/project/configs/experiment/local_sweep_example.yaml +++ b/project/configs/experiment/local_sweep_example.yaml @@ -1,14 +1,14 @@ # @package _global_ defaults: - - example.yaml # A configuration for a single run (that works!) - - override /hydra/sweeper: orion # Select the orion sweeper plugin + - example.yaml # A configuration for a single run (that works!) + - override /hydra/sweeper: orion # Select the orion sweeper plugin log_level: DEBUG name: "local-sweep-example" seed: 123 algorithm: - optimizer_config: + optimizer: # This here will get overwritten by the sweeper. lr: 0.002 @@ -45,7 +45,7 @@ hydra: sweeper: params: algorithm: - optimizer_config: + optimizer: lr: "loguniform(1e-6, 1.0, default_value=3e-4)" # weight_decay: "loguniform(1e-6, 1e-2, default_value=0)" diff --git a/project/configs/experiment/overfit_one_batch.yaml b/project/configs/experiment/overfit_one_batch.yaml deleted file mode 100644 index 6cca320b..00000000 --- a/project/configs/experiment/overfit_one_batch.yaml +++ /dev/null @@ -1,23 +0,0 @@ -# @package _global_ -defaults: - - override /trainer/callbacks: no_checkpoints - -datamodule: - shuffle: false - normalize: true - -seed: 123 - -trainer: - min_epochs: 1 - max_epochs: 50 - # prints - profiler: null - - # debugs - fast_dev_run: False - overfit_batches: 1 - limit_val_batches: 0 - limit_test_batches: 0 - detect_anomaly: true - enable_checkpointing: false diff --git a/project/configs/trainer/overfit_one_batch.yaml b/project/configs/trainer/overfit_one_batch.yaml new file mode 100644 index 00000000..80d02ae8 --- /dev/null +++ b/project/configs/trainer/overfit_one_batch.yaml @@ -0,0 +1,22 @@ +# Note: This configuration should be run in combination with an algorithm. 
For example like this: +# `python project/main.py algorithm=example datamodule=cifar10 trainer=overfit_one_batch` +# +defaults: + - default + +callbacks: + model_checkpoint: null + early_stopping: null +min_epochs: 1 +max_epochs: 50 +log_every_n_steps: 1 +# prints +profiler: null + +# debugs +fast_dev_run: False +overfit_batches: 1 +limit_val_batches: 0 +limit_test_batches: 0 +detect_anomaly: true +enable_checkpointing: false diff --git a/project/conftest.py b/project/conftest.py index ccb5a4eb..3d54b675 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -239,14 +239,16 @@ def experiment_config( @pytest.fixture(scope="session") -def datamodule(experiment_dictconfig: DictConfig) -> DataModule: +def datamodule(experiment_dictconfig: DictConfig) -> DataModule | None: """Fixture that creates the datamodule for the given config.""" # NOTE: creating the datamodule by itself instead of with everything else. - return instantiate_datamodule(experiment_dictconfig.datamodule) + return instantiate_datamodule(experiment_dictconfig["datamodule"]) @pytest.fixture(scope="function") -def algorithm(experiment_config: Config, datamodule: DataModule, device: torch.device, seed: int): +def algorithm( + experiment_config: Config, datamodule: DataModule | None, device: torch.device, seed: int +): """Fixture that creates the "algorithm" (a [LightningModule][lightning.pytorch.core.module.LightningModule]).""" with device: diff --git a/project/datamodules/datamodules_test.py b/project/datamodules/datamodules_test.py index bc4e1875..912fc951 100644 --- a/project/datamodules/datamodules_test.py +++ b/project/datamodules/datamodules_test.py @@ -1,18 +1,13 @@ import sys from pathlib import Path -import hydra_zen import matplotlib.pyplot as plt import pytest import torch from lightning import LightningDataModule from lightning.fabric.utilities.exceptions import MisconfigurationException from lightning.pytorch.trainer.states import RunningStage -from omegaconf import DictConfig -from tensor_regression.fixture import ( - TensorRegressionFixture, - get_test_source_and_temp_file_paths, -) +from tensor_regression.fixture import TensorRegressionFixture, get_test_source_and_temp_file_paths from torch import Tensor from project.datamodules.image_classification.image_classification import ( @@ -23,11 +18,6 @@ from project.utils.typing_utils import is_sequence_of -@pytest.fixture() -def datamodule(experiment_dictconfig: DictConfig) -> LightningDataModule: - return hydra_zen.instantiate(experiment_dictconfig.datamodule) - - # @use_overrides(["datamodule.num_workers=0"]) # @pytest.mark.timeout(25, func_only=True) @pytest.mark.slow @@ -46,6 +36,7 @@ def datamodule(experiment_dictconfig: DictConfig) -> LightningDataModule: ), ], ) +@pytest.mark.parametrize("overrides", ["algorithm=no_op"], indirect=True) @run_for_all_datamodules() def test_first_batch( datamodule: LightningDataModule, diff --git a/project/experiment.py b/project/experiment.py index 651d4a9a..17ed1d4b 100644 --- a/project/experiment.py +++ b/project/experiment.py @@ -11,6 +11,7 @@ from __future__ import annotations +import copy import functools import logging import os @@ -19,21 +20,17 @@ from logging import getLogger as get_logger from typing import Any +import hydra +import hydra.utils import hydra_zen import rich.console import rich.logging import rich.traceback from hydra_zen.typing import Builds from lightning import Callback, LightningModule, Trainer, seed_everything -from torch import nn from project.configs.config import Config -from 
project.datamodules.image_classification.image_classification import ( - ImageClassificationDataModule, -) -from project.utils.hydra_utils import get_outer_class -from project.utils.typing_utils import Dataclass -from project.utils.typing_utils.protocols import DataModule, Module +from project.utils.typing_utils.protocols import DataModule from project.utils.utils import validate_datamodule logger = get_logger(__name__) @@ -58,7 +55,7 @@ class Experiment: """ algorithm: LightningModule - datamodule: DataModule + datamodule: DataModule | None trainer: Trainer @@ -135,32 +132,42 @@ def instantiate_trainer(experiment_config: Config) -> Trainer: # instantiate all the callbacks callback_configs = experiment_config.trainer.pop("callbacks", {}) callback_configs = {k: v for k, v in callback_configs.items() if v is not None} - callbacks: dict[str, Callback] | None = hydra_zen.instantiate(callback_configs) + callbacks: dict[str, Callback] | None = hydra.utils.instantiate( + callback_configs, _convert_="object" + ) # Create the loggers, if any. - loggers: dict[str, Any] | None = instantiate(experiment_config.trainer.pop("logger", {})) + loggers: dict[str, Any] | None = hydra.utils.instantiate( + experiment_config.trainer.pop("logger", {}) + ) # Create the Trainer. + + # BUG: `hydra.utils.instantiate` doesn't work with override **kwargs when some of them are + # dataclasses (e.g. a callback). + # trainer = hydra.utils.instantiate( + # config, + # callbacks=list(callbacks.values()) if callbacks else None, + # logger=list(loggers.values()) if loggers else None, + # ) assert isinstance(experiment_config.trainer, dict) - if experiment_config.debug: - logger.info("Setting the max_epochs to 1, since the 'debug' flag was passed.") - experiment_config.trainer["max_epochs"] = 1 - if "_target_" not in experiment_config.trainer: - experiment_config.trainer["_target_"] = Trainer - - trainer = instantiate( - experiment_config.trainer, - callbacks=list(callbacks.values()) if callbacks else None, - logger=list(loggers.values()) if loggers else None, - ) - assert isinstance(trainer, Trainer) + config = copy.deepcopy(experiment_config.trainer) + target = hydra.utils.get_object(config.pop("_target_")) + _callbacks = list(callbacks.values()) if callbacks else None + _loggers = list(loggers.values()) if loggers else None + + trainer = target(**config, callbacks=_callbacks, logger=_loggers) return trainer -def instantiate_datamodule(datamodule_config: Builds[type[DataModule]] | DataModule) -> DataModule: +def instantiate_datamodule( + datamodule_config: Builds[type[DataModule]] | DataModule | None, +) -> DataModule | None: """Instantiate the datamodule from the configuration dict. Any interpolations in the config will have already been resolved by the time we get here. """ + if not datamodule_config: + return None if isinstance(datamodule_config, DataModule): logger.info( f"Datamodule was already instantiated (probably to interpolate a field value). 
" @@ -170,13 +177,15 @@ def instantiate_datamodule(datamodule_config: Builds[type[DataModule]] | DataMod else: logger.debug(f"Instantiating datamodule from config: {datamodule_config}") datamodule = instantiate(datamodule_config) - assert isinstance(datamodule, DataModule) + # assert isinstance(datamodule, DataModule) datamodule = validate_datamodule(datamodule) return datamodule -def instantiate_algorithm(algorithm_config: Config, datamodule: DataModule) -> LightningModule: +def instantiate_algorithm( + algorithm_config: Config, datamodule: DataModule | None +) -> LightningModule: """Function used to instantiate the algorithm. It is suggested that your algorithm (LightningModule) take in the `datamodule` and `network` @@ -197,9 +206,16 @@ def instantiate_algorithm(algorithm_config: Config, datamodule: DataModule) -> L ) return algo_config - algo_or_algo_partial = instantiate(algo_config, datamodule=datamodule) + if datamodule: + algo_or_algo_partial = hydra.utils.instantiate(algo_config, datamodule=datamodule) + else: + algo_or_algo_partial = hydra.utils.instantiate(algo_config) + if isinstance(algo_or_algo_partial, functools.partial): - algorithm = algo_or_algo_partial(datamodule=datamodule) + if datamodule: + algorithm = algo_or_algo_partial(datamodule=datamodule) + else: + algorithm = algo_or_algo_partial() else: # logger.warning( # f"Your algorithm config {algo_config} doesn't have '_partial_: true' set, which is " @@ -214,27 +230,5 @@ def instantiate_algorithm(algorithm_config: Config, datamodule: DataModule) -> L f"explicitly supported at the moment." ) ) - return algorithm - -def instantiate_network_from_hparams(network_hparams: Dataclass, datamodule: DataModule) -> Module: - """TODO: Refactor this if possible. Shouldn't be as complicated as it currently is. - - Perhaps we could register handler functions for each pair of datamodule and network type, a bit - like a multiple dispatch? - """ - network_type = get_outer_class(type(network_hparams)) - assert issubclass(network_type, nn.Module) - assert isinstance( - network_hparams, - network_type.HParams, # type: ignore - ), "HParams type should match net type" - if isinstance(datamodule, ImageClassificationDataModule): - # if issubclass(network_type, ImageClassifierNetwork): - return network_type( - in_channels=datamodule.dims[0], - n_classes=datamodule.num_classes, # type: ignore - hparams=network_hparams, - ) - - raise NotImplementedError(datamodule, network_hparams) + return algorithm diff --git a/project/main.py b/project/main.py index fda9457a..9be65365 100644 --- a/project/main.py +++ b/project/main.py @@ -40,15 +40,18 @@ def main(dict_config: DictConfig) -> dict: from project.utils.auto_schema import add_schemas_to_all_hydra_configs # Note: running this should take ~5 seconds the first time and <1s after that. - add_schemas_to_all_hydra_configs( - config_files=None, - repo_root=REPO_ROOTDIR, - configs_dir=REPO_ROOTDIR / PROJECT_NAME / "configs", - regen_schemas=False, - stop_on_error=False, - quiet=True, - add_headers=False, # don't add headers if we can't add an entry in vscode settings. - ) + try: + add_schemas_to_all_hydra_configs( + config_files=None, + repo_root=REPO_ROOTDIR, + configs_dir=REPO_ROOTDIR / PROJECT_NAME / "configs", + regen_schemas=False, + stop_on_error=False, + quiet=True, + add_headers=False, # don't add headers if we can't add an entry in vscode settings. 
+ ) + except Exception: + logger.error("Unable to add schemas to all hydra configs.") config: Config = resolve_dictconfig(dict_config) @@ -77,11 +80,15 @@ def run(experiment: Experiment) -> tuple[str, float | None, dict]: # potentially adding Wrappers on top of the environment, or having a replay buffer, etc. # TODO: Add ckpt_path argument to resume a training run. datamodule = getattr(experiment.algorithm, "datamodule", experiment.datamodule) - assert isinstance(datamodule, LightningDataModule) - experiment.trainer.fit( - experiment.algorithm, - datamodule=datamodule, - ) + + if datamodule is None: + experiment.trainer.fit(experiment.algorithm) + else: + assert isinstance(datamodule, LightningDataModule) + experiment.trainer.fit( + experiment.algorithm, + datamodule=datamodule, + ) metric_name, error, metrics = evaluation(experiment) if wandb.run: diff --git a/project/main_test.py b/project/main_test.py index 5b277b6e..94ad53cd 100644 --- a/project/main_test.py +++ b/project/main_test.py @@ -3,7 +3,9 @@ import shutil +import hydra.errors import hydra_zen +import omegaconf.errors import pytest import torch from omegaconf import DictConfig @@ -12,6 +14,7 @@ from project.configs.config import Config from project.conftest import use_overrides from project.datamodules.image_classification.cifar10 import CIFAR10DataModule +from project.utils.hydra_utils import resolve_dictconfig from .main import main @@ -36,10 +39,30 @@ def test_torch_can_use_the_GPU(): assert torch.cuda.is_available() == bool(shutil.which("nvidia-smi")) +@pytest.mark.xfail(raises=hydra.errors.ConfigCompositionException, strict=True) @pytest.mark.parametrize("overrides", [""], indirect=True) -def test_defaults(experiment_config: Config) -> None: +def test_defaults(experiment_dictconfig: DictConfig) -> None: """Test to check what the default values are when not specifying anything on the command- line.""" + # todo: the error is actually raised before this. + # with pytest.raises(hydra.errors.ConfigCompositionException): + # _ = resolve_dictconfig(experiment_dictconfig) + + +@pytest.mark.parametrize("overrides", ["algorithm=example"], indirect=True) +def test_setting_just_algorithm_isnt_enough(experiment_dictconfig: DictConfig) -> None: + """Test to check that the datamodule is required (even when just an algorithm is set?!).""" + with pytest.raises( + omegaconf.errors.InterpolationResolutionError, + match="Could not find any of these attributes", + ): + _ = resolve_dictconfig(experiment_dictconfig) + + +@pytest.mark.parametrize("overrides", ["algorithm=example datamodule=cifar10"], indirect=True) +def test_example_experiment_defaults(experiment_config: Config) -> None: + """Test to check that the datamodule is required (even when just an algorithm is set?!).""" + assert experiment_config.algorithm["_target_"] == ( ExampleAlgorithm.__module__ + "." + ExampleAlgorithm.__qualname__ ) @@ -49,7 +72,7 @@ def test_defaults(experiment_config: Config) -> None: ) -@use_overrides(["seed=1 +trainer.fast_dev_run=True"]) +@use_overrides(["algorithm=example datamodule=cifar10 seed=1 +trainer.fast_dev_run=True"]) def test_fast_dev_run(experiment_dictconfig: DictConfig): result = main(experiment_dictconfig) assert isinstance(result, dict) diff --git a/project/utils/__init__.py b/project/utils/__init__.py index 3c4d011e..bb9141c8 100644 --- a/project/utils/__init__.py +++ b/project/utils/__init__.py @@ -1,8 +1,6 @@ # Import this patch for https://github.com/mit-ll-responsible-ai/hydra-zen/issues/705 to make sure that it gets applied. 
-from .hydra_utils import patched_safe_name from .utils import default_device __all__ = [ "default_device", - "patched_safe_name", ] diff --git a/project/utils/hydra_utils.py b/project/utils/hydra_utils.py index 7f9d1a25..f00eed6c 100644 --- a/project/utils/hydra_utils.py +++ b/project/utils/hydra_utils.py @@ -2,25 +2,23 @@ from __future__ import annotations -import dataclasses import functools import importlib import inspect import typing from collections import ChainMap from collections.abc import Callable, Mapping, MutableMapping -from dataclasses import MISSING, field, fields, is_dataclass +from dataclasses import fields, is_dataclass from logging import getLogger as get_logger from typing import ( Any, - Literal, TypeVar, ) +import hydra.utils import hydra_zen.structured_configs._utils +import omegaconf from hydra_zen import instantiate -from hydra_zen.structured_configs._utils import safe_name -from hydra_zen.typing._implementations import Partial as _Partial from omegaconf import DictConfig, OmegaConf if typing.TYPE_CHECKING: @@ -29,127 +27,6 @@ logger = get_logger(__name__) -T = TypeVar("T") - - -def patched_safe_name(obj: Any, repr_allowed: bool = True): - """Patches a bug in Hydra-zen where the _target_ of inner classes is incorrect: - https://github.com/mit-ll-responsible-ai/hydra-zen/issues/705 - """ - - if not hasattr(obj, "__qualname__"): - return safe_name(obj, repr_allowed=repr_allowed) - - name = safe_name(obj, repr_allowed=repr_allowed) - qualname = obj.__qualname__ - assert isinstance(qualname, str) - - if name != qualname and qualname.endswith("." + name): - logger.debug(f"Using patched fn: returning {qualname} for target {obj}") - return qualname - - return name - - -hydra_zen.structured_configs._utils.safe_name = patched_safe_name - - -def interpolate_config_attribute(*attributes: str, default: Any | Literal[MISSING] = MISSING): - """Use this in a config to to get an attribute from another config after it is instantiated. - - Multiple attributes can be specified, which will lead to trying each of them in order until the - attribute is found. If none are found, then an error will be raised. - - For example, if we only know the number of classes in the datamodule after it is instantiated, - we can set this in the network config so it is created with the right number of output dims. - - ```yaml - _target_: torchvision.models.resnet50 - num_classes: ${instance_attr:datamodule.num_classes} - ``` - - This is equivalent to: - - >>> import hydra_zen - >>> import torchvision.models - >>> resnet50_config = hydra_zen.builds( - ... torchvision.models.resnet50, - ... num_classes=interpolate_config_attribute("datamodule.num_classes"), - ... populate_full_signature=True, - ... ) - >>> print(hydra_zen.to_yaml(resnet50_config)) # doctest: +NORMALIZE_WHITESPACE - _target_: torchvision.models.resnet.resnet50 - weights: null - progress: true - num_classes: ${instance_attr:datamodule.num_classes} - """ - if default is MISSING: - return "${instance_attr:" + ",".join(attributes) + "}" - return "${instance_attr:" + ",".join(attributes) + ":" + str(default) + "}" - - -def interpolated_field( - interpolation: str, - default: T | Literal[MISSING] = MISSING, - default_factory: Callable[[], T] | Literal[MISSING] = MISSING, - instance_attr: bool = False, -) -> T: - """Field with a default value computed with a OmegaConf-style interpolation when appropriate. - - When the dataclass is created by Hydra / OmegaConf, the interpolation is used. 
- Otherwise, behaves as usual (either using default or calling the default_factory). - - Parameters - ---------- - interpolation: The string interpolation to use to get the default value. - default: The default value to use when not in a hydra/OmegaConf context. - default_factory: The default value to use when not in a hydra/OmegaConf context. - instance_attr: Whether to use the `instance_attr` custom resolver to run the interpolation \ - with respect to instantiated objects instead of their configs. - Passing `interpolation='${instance_attr:some_config.some_attr}'` has the same effect. - - This last parameter is important, since in order to retrieve the instance attribute, we need to - instantiate the objects, which could be expensive. These instantiated objects are reused at - least, but still, be mindful when using this parameter. - """ - assert "${" in interpolation and "}" in interpolation - - if instance_attr: - if not interpolation.startswith("${instance_attr:"): - interpolation = interpolation.removeprefix("${") - interpolation = "${instance_attr:" + interpolation - - if default is MISSING and default_factory is MISSING: - raise RuntimeError( - "Interpolated fields currently still require a default value or default factory for " - "when they are used outside the Hydra/OmegaConf context." - ) - return field( - default_factory=functools.partial( - _default_factory, - interpolation=interpolation, - default=default, - default_factory=default_factory, - ) - ) - - -# @dataclass(init=False) -class Partial(functools.partial[T], _Partial[T]): - def __getattr__(self, name: str): - if name in self.keywords: - return self.keywords[name] - raise AttributeError(name) - - -def add_attributes(fn: functools.partial[T]) -> Partial[T]: - """Adds a __getattr__ to the partial that returns the value in `v.keywords`.""" - if isinstance(fn, Partial): - return fn - assert isinstance(fn, functools.partial) - return Partial(fn.func, *fn.args, **fn.keywords) - - def get_full_name(object_type: type) -> str: return object_type.__module__ + "." + object_type.__qualname__ @@ -224,6 +101,18 @@ def resolve_dictconfig(dict_config: DictConfig) -> Config: register_instance_attr_resolver(instantiated_objects_cache) # Convert the "raw" DictConfig (which uses the `Config` class to define it's structure) # into an actual `Config` object: + + # TODO: Seems to only be necessary now that the datamodule group is optional? + # Need to manually nudge OmegaConf so that it instantiates the datamodule first. + if dict_config["datamodule"]: + with omegaconf.open_dict(dict_config): + v = dict_config._get_flag("allow_objects") + dict_config._set_flag("allow_objects", True) + instantiated_objects_cache["datamodule"] = dict_config["datamodule"] = ( + hydra.utils.instantiate(dict_config["datamodule"]) + ) + dict_config._set_flag("allow_objects", v) + config = OmegaConf.to_object(dict_config) from project.configs.config import Config @@ -276,7 +165,7 @@ def instance_attr( """ if not attributes: raise RuntimeError("Need to pass one or more attributes to this resolver.") - assert being_called_in_hydra_context() + assert _being_called_in_hydra_context() logger.debug(f"Custom resolver is being called to get the value of {attributes}.") current_frame = inspect.currentframe() @@ -408,7 +297,7 @@ def instance_attr( ) -def being_called_in_hydra_context() -> bool: +def _being_called_in_hydra_context() -> bool: """Returns `True` if this function is being called indirectly by Hydra/OmegaConf. 
Can be used in a field default factory to change the default value based on whether the config @@ -441,18 +330,6 @@ def _being_called_by(*functions: Callable) -> bool: return False -def _default_factory( - interpolation: str, - default: T | Literal[dataclasses.MISSING] = dataclasses.MISSING, - default_factory: Callable[[], T] | Literal[dataclasses.MISSING] = dataclasses.MISSING, -) -> T: - if being_called_in_hydra_context(): - return interpolation # type: ignore - if default_factory is not dataclasses.MISSING: - return default_factory() - return default # type: ignore - - Target = TypeVar("Target") diff --git a/project/utils/utils.py b/project/utils/utils.py index 7cf28ec6..ad2ef13f 100644 --- a/project/utils/utils.py +++ b/project/utils/utils.py @@ -1,6 +1,5 @@ from __future__ import annotations -import functools import typing from collections.abc import Sequence from logging import getLogger as get_logger @@ -22,18 +21,6 @@ logger = get_logger(__name__) -# todo: doesn't work? keeps logging each time! -@functools.cache -def log_once(message: str, level: int) -> None: - """Logs a message once. The message is logged at the specified level. - - Args: - message: The message to log. - level: The logging level to use. - """ - logger.log(level=level, msg=message, stacklevel=2) - - def get_log_dir(trainer: Trainer | None) -> Path: """Gives back the default directory to use when `trainer.log_dir` is None (no logger used).""" # TODO: This isn't great.. It could probably be a property on the Algorithm class or @@ -114,7 +101,6 @@ def print_config( config: DictConfig, print_order: Sequence[str] = ( "algorithm", - "network", "datamodule", "trainer", ),
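
---

Note (not part of the patch): after this change the `datamodule` config group defaults to `null` and the `algorithm` group stays required, so runs are launched with explicit groups, e.g. `python project/main.py algorithm=example datamodule=cifar10 trainer=overfit_one_batch` (from the new trainer config) or `algorithm=example datamodule=cifar10 seed=1 +trainer.fast_dev_run=True` (from `main_test.py`). The sketch below is a rough summary, not code from the patch, of how the optional datamodule is threaded through instantiation and training; the helper name `fit_with_optional_datamodule` and its signature are made up for illustration, while the branching mirrors `instantiate_algorithm` and `run` above.

```python
import functools

import hydra.utils
from lightning import LightningDataModule, LightningModule, Trainer


def fit_with_optional_datamodule(
    trainer: Trainer,
    algo_config,
    datamodule: LightningDataModule | None,
) -> LightningModule:
    """Hypothetical helper showing the `datamodule is None` handling added in this PR."""
    # Only pass `datamodule=` to the algorithm when one was actually configured.
    if datamodule is not None:
        algo = hydra.utils.instantiate(algo_config, datamodule=datamodule)
    else:
        algo = hydra.utils.instantiate(algo_config)

    # Configs with `_partial_: true` give back a functools.partial that is finished here.
    if isinstance(algo, functools.partial):
        algo = algo(datamodule=datamodule) if datamodule is not None else algo()

    # Same branching as `run()` in project/main.py: no datamodule means a plain fit.
    if datamodule is None:
        trainer.fit(algo)
    else:
        trainer.fit(algo, datamodule=datamodule)
    return algo
```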