diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bd725acb5..1e826826b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,25 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - documentation search error (21.10 only) ([#1346](https://github.com/catalyst-team/catalyst/pull/1346)) +## [21.11] - 2021-11-30 + +### Added + +- Returned `resume` support - resolved [#1193](https://github.com/catalyst-team/catalyst/issues/1193) ([#1349](https://github.com/catalyst-team/catalyst/pull/1349)) + +### Changed + +- + +### Removed + +- + +### Fixed + +- + + ## [21.10] - 2021-10-30 ### Added diff --git a/catalyst/__init__.py b/catalyst/__init__.py index bd691f7698..df5c351f2a 100644 --- a/catalyst/__init__.py +++ b/catalyst/__init__.py @@ -4,6 +4,7 @@ warnings.filterwarnings("ignore", message="numpy.dtype size changed", append=True) warnings.filterwarnings("ignore", module="tqdm", append=True) warnings.filterwarnings("once", append=True) +warnings.filterwarnings("ignore", message="This overload of add_ is deprecated", append=True) from catalyst.__version__ import __version__ from catalyst.settings import SETTINGS diff --git a/catalyst/callbacks/checkpoint.py b/catalyst/callbacks/checkpoint.py index e364a32660..c2ac35af65 100644 --- a/catalyst/callbacks/checkpoint.py +++ b/catalyst/callbacks/checkpoint.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Dict, Iterable, Union from collections import OrderedDict import os from pathlib import Path @@ -9,6 +9,8 @@ from catalyst.extras.metric_handler import MetricHandler from catalyst.utils.config import save_config +_default_states = {"best", "best_full", "last", "last_full"} + def _save_checkpoint( checkpoint: Dict, @@ -123,7 +125,6 @@ def _get_required_files(logdir: str, load_map: Dict[str, str]) -> Dict[str, str] if load_map is None: return OrderedDict() - default_states = {"best", "best_full", "last", "last_full"} required_full_checkpoint = ["criterion", "optimizer", "scheduler"] steps = ["global_epoch_step", "global_batch_step", "global_sample_step"] experiment_parts = ["model"] + required_full_checkpoint + steps @@ -141,7 +142,7 @@ def _get_required_files(logdir: str, load_map: Dict[str, str]) -> Dict[str, str] fname = load_map[part] required_full = fname.endswith("_full") # specified default state - if fname in default_states: + if fname in _default_states: if part in required_full_checkpoint and not required_full: fname = fname + "_full" fname = f"{logdir}/{fname}.pth" @@ -197,24 +198,46 @@ def _load_states_from_file_map( def _load_runner( - logdir: str, runner: "IRunner", mapping: Union[str, Dict[str, str]], load_full: bool = False + logdir: str, + runner: "IRunner", + mapping: Union[str, Dict[str, str]], + not_required_states: Iterable[str] = None, ) -> None: """ - Selects a loading method based on type of mapping. + Checks if the files used in mapping exist and selects a loading method + based on type of mapping. Args: logdir: logdir with checkpoints runner: current runner mapping: mapping to use for loading - load_full: load a full model, used only when mapping type is string + not_required_states: states to skip when loading checkpoints + + Raises: + FileNotFoundError: if files given in mapping are missing. 
""" + if not_required_states is None: + not_required_states = [] + possible_states = _default_states.difference(not_required_states) + file_exists = False if isinstance(mapping, str): - if mapping in {"best", "best_full", "last", "last_full"}: + load_full = "full" in mapping + if mapping in possible_states: checkpoint = f"{logdir}/{mapping}.pth" else: checkpoint = mapping + file_exists = os.path.isfile(checkpoint) + if not file_exists: + raise FileNotFoundError(f"Missing file '{checkpoint}'!") # noqa: F821 _load_checkpoint(filename=checkpoint, runner=runner, load_full=load_full) elif isinstance(mapping, dict): + mapping = {k: v for k, v in mapping.items() if v not in not_required_states} + required_files = _get_required_files(logdir, mapping).keys() + file_exists = True + for use_file in required_files: + if not os.path.isfile(use_file): + file_exists = False + raise FileNotFoundError(f"Missing file '{use_file}'!") _load_states_from_file_map(logdir=logdir, runner=runner, load_map=mapping) @@ -228,7 +251,7 @@ class CheckpointCallback(ICheckpointCallback): """Checkpoint callback to save/restore your model/criterion/optimizer/scheduler. Args: - logdir: directory to store chekpoints + logdir: directory to store checkpoints loader_key: loader key for best model selection (based on metric score over the dataset) metric_key: metric key for best model selection (based on metric score over the dataset) minimize: boolean flag to minimize the required metric @@ -264,12 +287,8 @@ class CheckpointCallback(ICheckpointCallback): ``"optimizer"`` and ``"scheduler"`` will be ignored. If ``None`` or an empty dict (or dict without mentioned - above keys) then no action is required at stage start and: - - - Config API - will be used best state of model - - Notebook API - no action will be performed (will be used the last state) - - **NOTE:** Loading will be performed on all stages except first. + above keys) then no action is required at stage start and + no action will be performed (will be used the last state). **NOTE:** Criterion, optimizer and scheduler are optional keys and should be loaded from full checkpoint. @@ -296,6 +315,19 @@ class CheckpointCallback(ICheckpointCallback): and will be used the last runner. **NOTE:** Loading will be performed always at stage end. + resume (str or Dict[str, str]): load specified + state/model for experiment resuming. + + If passed **string** then will be performed initialization from + specified state (``best``/``best_full``/``last``/``last_full``) + or checkpoint file. + + If passed **dict** then will be performed initialization only + for specified parts - model, criterion, optimizer, scheduler. + Logic for dict is the same as for ``load_on_stage_start``. + + If ``None`` or an empty dict (or dict without mentioned + above keys) then no action is required at stage start and: metrics_filename: filename to save metrics in checkpoint folder. 
Must ends on ``.json`` or ``.yml`` @@ -382,7 +414,7 @@ def __init__( # loading info load_on_stage_start: Union[str, Dict[str, str]] = None, load_on_stage_end: Union[str, Dict[str, str]] = None, - # resume: str = None, + resume: Union[str, Dict[str, str]] = None, # resume_dir: str = None, # checkpointer info metrics_filename: str = "_metrics.json", @@ -392,13 +424,7 @@ def __init__( ): """Init.""" super().__init__(order=CallbackOrder.external, node=CallbackNode.all) - possible_states = { - None, - "best", - "last", - "best_full", - "last_full", - } + possible_states = _default_states.union([None]) assert save_n_best >= 0 if save_n_best == 0: assert load_on_stage_end in (None, "last", "last_full") @@ -406,8 +432,6 @@ def __init__( assert load_on_stage_start in possible_states if isinstance(load_on_stage_end, str): assert load_on_stage_end in possible_states - # if resume_dir is not None: - # assert resume is not None if loader_key is not None or metric_key is not None: assert loader_key is not None and metric_key is not None, ( @@ -448,8 +472,7 @@ def __init__( # loading info self.load_on_stage_start = load_on_stage_start self.load_on_stage_end = load_on_stage_end - # self.resume = resume - # self.resume_dir = resume_dir + self.resume = resume def _pack_checkpoint(self, runner: "IRunner"): checkpoint = runner.engine.pack_checkpoint( @@ -558,16 +581,12 @@ def on_stage_start(self, runner: "IRunner") -> None: .. note:: - If CheckpointCallback initialized with - ``resume`` (as path to checkpoint file) - or ``resume`` (as filename) - and ``resume_dir`` (as directory with file) + If CheckpointCallback initialized with ``resume`` or ``load_on_stage_start``: + - as path to checkpoint file or filename (``for resume only``) + - as specified state (``best``/``best_full``/``last``/``last_full``) + - as dict with specified parts (model, criterion, optimizer, etc.) then will be performed loading checkpoint. - Raises: - FileNotFoundError: if specified load_on_stage_start - but checkpoint file is missing. 
- Args: runner: current runner """ @@ -588,70 +607,19 @@ def on_stage_start(self, runner: "IRunner") -> None: # Use a barrier() to make sure that all processes have finished reading the checkpoint # dist.barrier() - is_first_stage = list(runner.stages).index(runner.stage_key) == 0 - if self.load_on_stage_start is not None and not is_first_stage: - need_full = False - file_exists = False - if isinstance(self.load_on_stage_start, str): - need_full = self.load_on_stage_start.endswith("full") - use_file = os.path.join(self.logdir, f"{self.load_on_stage_start}.pth") - file_exists = os.path.isfile(use_file) - if not file_exists: - raise FileNotFoundError(f"Missing file '{use_file}'!") # noqa: F821 - elif isinstance(self.load_on_stage_start, dict): - required_files = _get_required_files(self.logdir, self.load_on_stage_start).keys() - file_exists = True - for use_file in required_files: - if not os.path.isfile(use_file): - file_exists = False - raise FileNotFoundError(f"Missing file '{use_file}'!") - - if self.load_on_stage_start is not None and file_exists: - _load_runner( - logdir=self.logdir, - runner=runner, - mapping=self.load_on_stage_start, - load_full=need_full, - ) - - # if getattr(runner, "resume", None) is not None: - # self.resume = runner.resume - # runner.resume = None - # elif getattr(runner, "autoresume", None) is not None: - # self.resume_dir = runner.logdir / "checkpoints" - # self.resume = f"{runner.autoresume}_full.pth" - # runner.autoresume = None - # - # for key in self._keys_from_runner: - # value = getattr(runner, key, None) - # if value is not None: - # setattr(self, key, value) - # - # if self.resume_dir is not None: - # self.resume = str(self.resume_dir) + "/" + str(self.resume) - # - # if self.resume is not None: - # _load_runner(logdir=self.logdir, runner=runner, mapping=self.resume, load_full=True) - # self.resume = None - # else: - # checkpoint_exists = False - # need_load_full = False - # if isinstance(self.load_on_stage_start, str): - # checkpoint_exists = - # os.path.isfile(f"{self.logdir}/{self.load_on_stage_start}.pth") - # need_load_full = self.load_on_stage_start.endswith("full") - # elif isinstance(self.load_on_stage_start, dict): - # required_files = - # _get_required_files(self.logdir, self.load_on_stage_start).keys() - # checkpoint_exists = all(os.path.isfile(file) for file in required_files) - # - # if self.load_on_stage_start is not None and checkpoint_exists: - # _load_runner( - # logdir=self.logdir, - # runner=runner, - # mapping=self.load_on_stage_start, - # load_full=need_load_full, - # ) + if getattr(runner, "_resume", None) is not None: + self.resume = runner._resume + runner._resume = None + + if self.resume is not None: + _load_runner(logdir=self.logdir, runner=runner, mapping=self.resume) + self.resume = None + elif self.load_on_stage_start is not None: + _load_runner( + logdir=self.logdir, + runner=runner, + mapping=self.load_on_stage_start, + ) def on_epoch_end(self, runner: "IRunner") -> None: """ @@ -741,29 +709,15 @@ def on_stage_end(self, runner: "IRunner") -> None: # let's load runner state (model, criterion, optimizer, scheduler) if required not_required_load_states = {"last", "last_full"} - if ( - isinstance(self.load_on_stage_end, str) - and self.load_on_stage_end not in not_required_load_states - and self.save_n_best > 0 - ): - need_load_full = ( - self.load_on_stage_end.endswith("full") - if isinstance(self.load_on_stage_end, str) - else False - ) + if self.save_n_best == 0: + return + if self.load_on_stage_end is not None: 
_load_runner( logdir=self.logdir, runner=runner, mapping=self.load_on_stage_end, - load_full=need_load_full, + not_required_states=not_required_load_states, ) - elif isinstance(self.load_on_stage_end, dict) and self.save_n_best > 0: - to_load = { - k: v - for k, v in self.load_on_stage_end.items() - if v not in not_required_load_states - } - _load_runner(logdir=self.logdir, runner=runner, mapping=to_load) if runner.engine.is_ddp and runner.engine.is_master_process: # worker sync diff --git a/catalyst/dl/scripts/run.py b/catalyst/dl/scripts/run.py index 4fc73cd51c..a69610eeaa 100755 --- a/catalyst/dl/scripts/run.py +++ b/catalyst/dl/scripts/run.py @@ -31,19 +31,21 @@ def build_args(parser: ArgumentParser): parser.add_argument("--expdir", type=str, default=None) parser.add_argument("--logdir", type=str, default=None) parser.add_argument("--baselogdir", type=str, default=None) - # parser.add_argument( - # "--resume", default=None, type=str, metavar="PATH", help="path to latest checkpoint", - # ) - # parser.add_argument( - # "--autoresume", - # type=str, - # help=( - # "try automatically resume from logdir//{best,last}_full.pth " "if --resume is empty" - # ), - # required=False, - # choices=["best", "last"], - # default=None, - # ) + parser.add_argument( + "--resume", + default=None, + type=str, + metavar="PATH", + help="path to latest checkpoint", + ) + parser.add_argument( + "--autoresume", + type=str, + help=("try automatically resume from logdir/{best,last}_full.pth " "if --resume is empty"), + required=False, + choices=["best", "last"], + default=None, + ) parser.add_argument("--seed", type=int, default=42) boolean_flag( parser, diff --git a/catalyst/runners/config.py b/catalyst/runners/config.py index 65b93e5447..746860280c 100644 --- a/catalyst/runners/config.py +++ b/catalyst/runners/config.py @@ -108,6 +108,7 @@ def __init__(self, config: Dict): self._timeit: bool = get_by_keys(self._config, "args", "timeit", default=False) self._check: bool = get_by_keys(self._config, "args", "check", default=False) self._overfit: bool = get_by_keys(self._config, "args", "overfit", default=False) + self._resume: str = get_by_keys(self._config, "args", "resume") self._name: str = self._get_run_name() self._logdir: str = self._get_run_logdir() @@ -295,7 +296,9 @@ def get_model(self, stage: str) -> RunnerModel: """Returns the model for a given stage.""" assert "model" in self._config, "config must contain 'model' key" model_params: Dict = self._config["model"] - model: RunnerModel = self._get_model_from_params(**model_params) + model: RunnerModel = ( + self._get_model_from_params(**model_params) if self.model is None else self.model + ) return model def get_criterion(self, stage: str) -> RunnerCriterion: @@ -417,7 +420,7 @@ def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]": if self._logdir is not None and not is_callback_exists(ICheckpointCallback): callbacks["_checkpoint"] = CheckpointCallback( - logdir=os.path.join(self._logdir, "checkpoints") + logdir=os.path.join(self._logdir, "checkpoints"), resume=self._resume ) return callbacks diff --git a/catalyst/runners/hydra.py b/catalyst/runners/hydra.py index 975a96e1b9..82ba29cf34 100644 --- a/catalyst/runners/hydra.py +++ b/catalyst/runners/hydra.py @@ -76,6 +76,7 @@ def __init__(self, cfg: "DictConfig"): self._name: str = self._get_run_name() self._logdir: str = self._get_run_logdir() + self._resume: str = self._get_resume() # @TODO: hack for catalyst-dl tune, could be done better self._trial = None @@ -107,6 +108,15 @@ def 
_get_run_logdir(self) -> str: output = f"{baselogdir}/{logdir}" return output + def _get_resume(self) -> str: + autoresume = self._config.args.autoresume + logdir = self._config.args.logdir + resume = self._config.args.resume + if autoresume is not None and logdir is not None and resume is None: + checkpoint_filename = f"{logdir}/checkpoints/{autoresume}_full.pth" + return checkpoint_filename + return resume + @property def logdir(self) -> str: """Experiment's logdir for artefacts and logging.""" @@ -286,7 +296,9 @@ def get_model(self, stage: str) -> RunnerModel: """Returns the model for a given stage.""" assert "model" in self._config, "config must contain 'model' key" model_params: "DictConfig" = self._config.model - model: RunnerModel = self._get_model_from_params(model_params) + model: RunnerModel = ( + self._get_model_from_params(model_params) if self.model is None else self.model + ) return model @staticmethod @@ -440,7 +452,7 @@ def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]": if self._logdir is not None and not is_callback_exists(ICheckpointCallback): callbacks["_checkpoint"] = CheckpointCallback( - logdir=os.path.join(self._logdir, "checkpoints") + logdir=os.path.join(self._logdir, "checkpoints"), resume=self._resume ) return callbacks diff --git a/catalyst/runners/runner.py b/catalyst/runners/runner.py index a79fb711e0..6e33c273e4 100644 --- a/catalyst/runners/runner.py +++ b/catalyst/runners/runner.py @@ -187,6 +187,7 @@ def __init__(self, *args, **kwargs): self._valid_metric = None self._minimize_valid_metric = None # extras + self._resume: str = None self._verbose = False self._timeit = False self._check = False @@ -310,6 +311,7 @@ def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]": loader_key=self._valid_loader, metric_key=self._valid_metric, minimize=self._minimize_valid_metric, + resume=self._resume, ) return callbacks @@ -337,6 +339,7 @@ def train( num_epochs: int = 1, # extra info (callbacks info) logdir: str = None, + resume: str = None, valid_loader: str = None, valid_metric: str = None, minimize_valid_metric: bool = True, @@ -369,6 +372,7 @@ def train( hparams: hyperparameters for the run num_epochs: number of training epochs logdir: path to output directory + resume: path to checkpoint for model valid_loader: loader name used to calculate the metrics and save the checkpoints. 
For example, you can pass `train` and then @@ -503,6 +507,7 @@ def on_loader_end(self, runner): self._hparams = hparams self._num_epochs = num_epochs self._logdir = logdir + self._resume = resume self._valid_loader = valid_loader self._valid_metric = valid_metric self._minimize_valid_metric = minimize_valid_metric @@ -540,6 +545,8 @@ def predict_loader( model: Model = None, engine: Union["IEngine", str] = None, seed: int = 42, + # extra info + resume: str = None, # engine extra params, fp16: bool = False, amp: bool = False, @@ -555,6 +562,7 @@ def predict_loader( model: model to use for prediction engine: engine to use for prediction seed: random seed to use before prediction + resume: path to checkpoint for model fp16: boolean flag to use half-precision training (AMP > APEX) amp: boolean flag to use amp half-precision apex: boolean flag to use apex half-precision @@ -653,6 +661,9 @@ def on_loader_end(self, runner): # model inference for logits in runner.predict_loader(loader=loaders["valid"]): assert logits.detach().cpu().numpy().shape[-1] == 10 + # model inference from checkpoint + for logits in runner.predict_loader(loader=loaders["valid"], resume="./logs/best.pth"): + assert logits.detach().cpu().numpy().shape[-1] == 10 """ self.engine = engine or get_available_engine(fp16=fp16, ddp=ddp, amp=amp, apex=apex) @@ -660,9 +671,9 @@ def on_loader_end(self, runner): self.model = model assert self.model is not None - # if resume is not None: - # checkpoint = load_checkpoint(resume) - # unpack_checkpoint(checkpoint, model=self.model) + if resume is not None: + checkpoint = self.engine.load_checkpoint(resume) + self.engine.unpack_checkpoint(checkpoint, model=self.model) self.model = self.engine.sync_device(self.model) maybe_recursive_call(self.model, "train", mode=False) diff --git a/tests/catalyst/callbacks/test_checkpoint.py b/tests/catalyst/callbacks/test_checkpoint.py index 0ca7c73280..2c90ec4a66 100644 --- a/tests/catalyst/callbacks/test_checkpoint.py +++ b/tests/catalyst/callbacks/test_checkpoint.py @@ -1,12 +1,9 @@ # flake8: noqa +# TODO: add test for `save_n_best=0`` from collections import OrderedDict -from io import StringIO import os import re -import shutil -import sys -from tempfile import TemporaryDirectory import pytest @@ -69,7 +66,7 @@ def on_stage_start(self, runner): runner.engine.save_checkpoint(checkpoint, checkpoint_file) def on_batch_start(self, runner): - if not (runner.stage_key == self.stage and runner.stage_batch_step == 0): + if not (runner.stage_key == self.stage and runner.stage_batch_step == 1): return # check if model loaded right checkpoint model = runner.model @@ -99,21 +96,31 @@ def get_engine(self): return self._engine def get_callbacks(self, stage: str): - return { + callbacks = { "criterion": dl.CriterionCallback( metric_key="loss", input_key="logits", target_key="targets" ), "optimizer": dl.OptimizerCallback(metric_key="loss"), - "checkpoint": dl.CheckpointCallback( + "test_model_load": CheckModelStateLoadAfterStages("second", self._logdir, "best.pth"), + } + if stage == "first": + callbacks["checkpoint"] = dl.CheckpointCallback( + self._logdir, + loader_key="valid", + metric_key="loss", + minimize=True, + save_n_best=3, + ) + elif stage == "second": + callbacks["checkpoint"] = dl.CheckpointCallback( self._logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=3, load_on_stage_start="best", - ), - "test_model_load": CheckModelStateLoadAfterStages("second", self._logdir, "best.pth"), - } + ) + return callbacks @property def stages(self) -> 
"Iterable[str]": @@ -143,7 +150,7 @@ def get_trial(self): return None def get_loggers(self): - return {"console": dl.ConsoleLogger(), "csv": dl.CSVLogger(logdir=self._logdir)} + return {} def handle_batch(self, batch): x, y = batch @@ -152,52 +159,42 @@ def handle_batch(self, batch): self.batch = {"features": x, "targets": y, "logits": logits} -def test_device_load_on_stage_start(): - to_check_devices = ["cpu"] - for device in to_check_devices: - with TemporaryDirectory() as logdir: - runner = CustomRunner(logdir, DeviceEngine(device)) - runner.run() - - -@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason="CUDA is not available") -def test_device_load_on_stage_start(): - to_check_devices = [f"cuda:{i}" for i in range(NUM_CUDA_DEVICES)] - for device in to_check_devices: - with TemporaryDirectory() as logdir: - runner = CustomRunner(logdir, DeviceEngine(device)) - runner.run() +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda", marks=pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason="CUDA is not available") + ), + ], +) +def test_device_load_on_stage_end(device, tmpdir): + logdir = tmpdir + runner = CustomRunner(logdir, DeviceEngine(device)) + runner.run() @pytest.mark.skipif( not (IS_CUDA_AVAILABLE and NUM_CUDA_DEVICES >= 2), reason="Number of CUDA devices is less than 2", ) -def test_dp_load_on_stage_start(): - with TemporaryDirectory() as logdir: - runner = CustomRunner(logdir, DataParallelEngine()) - runner.run() +def test_dp_load_on_stage_end(tmpdir): + logdir = tmpdir + runner = CustomRunner(logdir, DataParallelEngine()) + runner.run() @pytest.mark.skipif( not (IS_CUDA_AVAILABLE and NUM_CUDA_DEVICES >= 2), reason="Number of CUDA devices is less than 2", ) -def test_ddp_load_on_stage_start(): - with TemporaryDirectory() as logdir: - runner = CustomRunner(logdir, DistributedDataParallelEngine()) - runner.run() +def test_ddp_load_on_stage_start(tmpdir): + logdir = tmpdir + runner = CustomRunner(logdir, DistributedDataParallelEngine()) + runner.run() -def test_load_best_on_stage_end(): - old_stdout = sys.stdout - sys.stdout = str_stdout = StringIO() - - # experiment_setup - logdir = "./logs/checkpoint_callback" - checkpoint = logdir # + "/checkpoints" - logfile = checkpoint + "/_metrics.json" - +def train_runner(logdir, n_epochs, callbacks): # data num_samples, num_features = int(1e4), int(1e1) X = torch.rand(num_samples, num_features) @@ -211,8 +208,8 @@ def test_load_best_on_stage_end(): criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters()) runner = dl.SupervisedRunner() + runner.loggers = {} - n_epochs = 5 # first stage runner.train( model=model, @@ -225,275 +222,135 @@ def test_load_best_on_stage_end(): valid_loader="valid", valid_metric="loss", minimize_valid_metric=True, - callbacks=[ - dl.CheckpointCallback( - logdir=logdir, - loader_key="valid", - metric_key="loss", - minimize=True, - save_n_best=2, - load_on_stage_end="best", - ), - dl.CheckRunCallback(num_epoch_steps=n_epochs), - ], + callbacks=callbacks, ) + return runner - sys.stdout = old_stdout - exp_output = str_stdout.getvalue() - assert len(re.findall(r"=> Loading", exp_output)) == 1 - assert len(re.findall(r"=> Loading .*best\.pth", exp_output)) == 1 +def test_files_existence(tmpdir): + logfile = tmpdir + "/_metrics.json" + n_epochs = 5 + callbacks = [ + dl.CheckpointCallback( + logdir=tmpdir, + loader_key="valid", + metric_key="loss", + minimize=True, + save_n_best=2, + ), + dl.CheckRunCallback(num_epoch_steps=n_epochs), + ] + train_runner(tmpdir, n_epochs, 
callbacks) assert os.path.isfile(logfile) - assert os.path.isfile(checkpoint + "/train.4.pth") - assert os.path.isfile(checkpoint + "/train.4_full.pth") - assert os.path.isfile(checkpoint + "/train.5.pth") - assert os.path.isfile(checkpoint + "/train.5_full.pth") - assert os.path.isfile(checkpoint + "/best.pth") - assert os.path.isfile(checkpoint + "/best_full.pth") - assert os.path.isfile(checkpoint + "/last.pth") - assert os.path.isfile(checkpoint + "/last_full.pth") - - shutil.rmtree(logdir, ignore_errors=True) - - -# @pytest.mark.skip(reason="disabled") -# def test_multiple_stages_and_different_checkpoints_to_load(): -# old_stdout = sys.stdout -# sys.stdout = str_stdout = StringIO() -# -# # experiment_setup -# logdir = "./logs/checkpoint_callback" -# checkpoint = logdir # + "/checkpoints" -# logfile = checkpoint + "/_metrics.json" -# num_epochs = 5 -# -# # data -# num_samples, num_features = int(1e4), int(1e1) -# X = torch.rand(num_samples, num_features) -# y = torch.randint(0, 5, size=[num_samples]) -# dataset = TensorDataset(X, y) -# loader = DataLoader(dataset, batch_size=32, num_workers=1) -# loaders = {"train": loader, "valid": loader} -# -# # model, criterion, optimizer, scheduler -# model = torch.nn.Linear(num_features, 5) -# criterion = torch.nn.CrossEntropyLoss() -# optimizer = torch.optim.Adam(model.parameters()) -# runner = dl.SupervisedRunner() -# -# # first stage -# runner.train( -# model=model, -# criterion=criterion, -# optimizer=optimizer, -# loaders=loaders, -# logdir=logdir, -# num_epochs=num_epochs, -# verbose=False, -# valid_loader="valid", -# valid_metric="loss", -# minimize_valid_metric=True, -# callbacks=[ -# dl.CheckpointCallback( -# logdir=logdir, -# loader_key="valid", -# metric_key="loss", -# minimize=True, -# save_n_best=2, -# load_on_stage_end={"model": "best", "criterion": "best", "optimizer": "last"}, -# ), -# dl.CheckRunCallback(num_epoch_steps=num_epochs), -# ], -# ) -# # second stage -# runner.train( -# model=model, -# criterion=criterion, -# optimizer=optimizer, -# loaders=loaders, -# logdir=logdir, -# num_epochs=num_epochs, -# verbose=False, -# valid_loader="valid", -# valid_metric="loss", -# minimize_valid_metric=True, -# callbacks=[ -# dl.CheckpointCallback( -# logdir=logdir, -# loader_key="valid", -# metric_key="loss", -# minimize=True, -# save_n_best=3, -# load_on_stage_start={"model": "last", "criterion": "last", "optimizer": "best"}, -# ), -# dl.CheckRunCallback(num_epoch_steps=num_epochs), -# ], -# ) -# -# sys.stdout = old_stdout -# exp_output = str_stdout.getvalue() -# -# assert len(re.findall(r"=> Loading", exp_output)) == 3 -# assert len(re.findall(r"=> Loading .*best_full\.pth", exp_output)) == 2 -# assert len(re.findall(r"=> Loading .*last_full\.pth", exp_output)) == 1 -# -# assert os.path.isfile(logfile) -# assert os.path.isfile(checkpoint + "/train.3.pth") -# assert os.path.isfile(checkpoint + "/train.3_full.pth") -# assert os.path.isfile(checkpoint + "/train.4.pth") -# assert os.path.isfile(checkpoint + "/train.4_full.pth") -# assert os.path.isfile(checkpoint + "/train.5.pth") -# assert os.path.isfile(checkpoint + "/train.5_full.pth") -# assert os.path.isfile(checkpoint + "/best.pth") -# assert os.path.isfile(checkpoint + "/best_full.pth") -# assert os.path.isfile(checkpoint + "/last.pth") -# assert os.path.isfile(checkpoint + "/last_full.pth") -# -# shutil.rmtree(logdir, ignore_errors=True) -# -# -# @pytest.mark.skip(reason="disabled") -# def test_resume_with_missing_file(): -# old_stdout = sys.stdout -# sys.stdout = str_stdout = 
StringIO() -# -# # experiment_setup -# logdir = "./logs/checkpoint_callback" -# checkpoint = logdir + "/checkpoints" -# logfile = checkpoint + "/_metrics.json" -# num_epochs = 5 -# -# # data -# num_samples, num_features = int(1e4), int(1e1) -# X = torch.rand(num_samples, num_features) -# y = torch.randint(0, 5, size=[num_samples]) -# dataset = TensorDataset(X, y) -# loader = DataLoader(dataset, batch_size=32, num_workers=1) -# loaders = {"train": loader, "valid": loader} -# -# # model, criterion, optimizer, scheduler -# model = torch.nn.Linear(num_features, 5) -# criterion = torch.nn.CrossEntropyLoss() -# optimizer = torch.optim.Adam(model.parameters()) -# runner = dl.SupervisedRunner() -# -# with pytest.raises(FileNotFoundError): -# runner.train( -# model=model, -# criterion=criterion, -# optimizer=optimizer, -# loaders=loaders, -# logdir=logdir, -# num_epochs=num_epochs, -# verbose=False, -# valid_loader="valid", -# valid_metric="loss", -# minimize_valid_metric=True, -# callbacks=[ -# dl.CheckpointCallback( -# logdir=logdir, -# loader_key="valid", -# metric_key="loss", -# minimize=True, -# save_n_best=2, -# load_on_stage_end={"model": "best", "criterion": "best", "optimizer": "last"}, -# resume="not_existing_file.pth", -# ), -# dl.CheckRunCallback(num_epoch_steps=num_epochs), -# ], -# ) -# -# sys.stdout = old_stdout -# exp_output = str_stdout.getvalue() -# -# shutil.rmtree(logdir, ignore_errors=True) -# -# -# @pytest.mark.skip(reason="disabled") -# def test_load_on_stage_start_with_empty_dict(): -# old_stdout = sys.stdout -# sys.stdout = str_stdout = StringIO() -# -# # experiment_setup -# logdir = "./logs/checkpoint_callback" -# checkpoint = logdir # + "/checkpoints" -# logfile = checkpoint + "/_metrics.json" -# num_epochs = 5 -# -# # data -# num_samples, num_features = int(1e4), int(1e1) -# X = torch.rand(num_samples, num_features) -# y = torch.randint(0, 5, size=[num_samples]) -# dataset = TensorDataset(X, y) -# loader = DataLoader(dataset, batch_size=32, num_workers=1) -# loaders = {"train": loader, "valid": loader} -# -# # model, criterion, optimizer, scheduler -# model = torch.nn.Linear(num_features, 5) -# criterion = torch.nn.CrossEntropyLoss() -# optimizer = torch.optim.Adam(model.parameters()) -# runner = dl.SupervisedRunner() -# -# # first stage -# runner.train( -# model=model, -# criterion=criterion, -# optimizer=optimizer, -# loaders=loaders, -# logdir=logdir, -# num_epochs=num_epochs, -# verbose=False, -# valid_loader="valid", -# valid_metric="loss", -# minimize_valid_metric=True, -# callbacks=[ -# dl.CheckpointCallback( -# logdir=logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=2 -# ), -# dl.CheckRunCallback(num_epoch_steps=num_epochs), -# ], -# ) -# # second stage -# runner.train( -# model=model, -# criterion=criterion, -# optimizer=optimizer, -# loaders=loaders, -# logdir=logdir, -# num_epochs=num_epochs, -# verbose=False, -# valid_loader="valid", -# valid_metric="loss", -# minimize_valid_metric=True, -# callbacks=[ -# dl.CheckpointCallback( -# logdir=logdir, -# loader_key="valid", -# metric_key="loss", -# minimize=True, -# save_n_best=3, -# load_on_stage_start={}, -# ), -# dl.CheckRunCallback(num_epoch_steps=num_epochs), -# ], -# ) -# -# sys.stdout = old_stdout -# exp_output = str_stdout.getvalue() -# -# assert len(re.findall(r"=> Loading", exp_output)) == 0 -# -# assert os.path.isfile(logfile) -# assert os.path.isfile(checkpoint + "/train.3.pth") -# assert os.path.isfile(checkpoint + "/train.3_full.pth") -# assert os.path.isfile(checkpoint + 
"/train.4.pth") -# assert os.path.isfile(checkpoint + "/train.4_full.pth") -# assert os.path.isfile(checkpoint + "/train.5.pth") -# assert os.path.isfile(checkpoint + "/train.5_full.pth") -# assert os.path.isfile(checkpoint + "/best.pth") -# assert os.path.isfile(checkpoint + "/best_full.pth") -# assert os.path.isfile(checkpoint + "/last.pth") -# assert os.path.isfile(checkpoint + "/last_full.pth") -# -# shutil.rmtree(logdir, ignore_errors=True) + assert os.path.isfile(tmpdir + "/train.4.pth") + assert os.path.isfile(tmpdir + "/train.4_full.pth") + assert os.path.isfile(tmpdir + "/train.5.pth") + assert os.path.isfile(tmpdir + "/train.5_full.pth") + assert os.path.isfile(tmpdir + "/best.pth") + assert os.path.isfile(tmpdir + "/best_full.pth") + assert os.path.isfile(tmpdir + "/last.pth") + assert os.path.isfile(tmpdir + "/last_full.pth") + + +@pytest.mark.parametrize(("to_load", "exp_loaded"), [("best", "model"), ("best_full", "full")]) +def test_load_str_on_stage_end(to_load, exp_loaded, capsys, tmpdir): + # experiment_setup + n_epochs = 5 + callbacks = [ + dl.CheckpointCallback( + logdir=tmpdir, + loader_key="valid", + metric_key="loss", + minimize=True, + save_n_best=2, + load_on_stage_end=to_load, + ), + dl.CheckRunCallback(num_epoch_steps=n_epochs), + ] + + train_runner(tmpdir, n_epochs, callbacks) + exp_output = capsys.readouterr().out + + assert len(re.findall(r"=> Loading", exp_output)) == 1 + assert len(re.findall(r"=> Loading .*{}\.pth".format(to_load), exp_output)) == 1 + assert len(re.findall(r"{} checkpoint".format(exp_loaded), exp_output)) == 1 + + +@pytest.mark.parametrize( + ("to_load", "exp_loaded"), + [ + ({"model": "best", "criterion": "best", "optimizer": "last"}, "model, criterion"), + ( + {"model": "best", "criterion": "best", "optimizer": "best"}, + "model, criterion, optimizer", + ), + ], +) +def test_load_dict_on_stage_end(to_load, exp_loaded, capsys, tmpdir): + # experiment_setup + n_epochs = 5 + callbacks = [ + dl.CheckpointCallback( + logdir=tmpdir, + loader_key="valid", + metric_key="loss", + minimize=True, + save_n_best=2, + load_on_stage_end=to_load, + ), + dl.CheckRunCallback(num_epoch_steps=n_epochs), + ] + + train_runner(tmpdir, n_epochs, callbacks) + exp_output = capsys.readouterr().out + + assert len(re.findall(r"=> Loading", exp_output)) == 1 + assert len(re.findall(r"loaded: {}".format(exp_loaded), exp_output)) == 1 + + +@pytest.mark.parametrize("to_load", [{}, None]) +def test_load_empty(to_load, capsys, tmpdir): + # experiment_setup + n_epochs = 5 + callbacks = [ + dl.CheckpointCallback( + logdir=tmpdir, + loader_key="valid", + metric_key="loss", + minimize=True, + save_n_best=2, + load_on_stage_start=to_load, + load_on_stage_end=to_load, + resume=to_load, + ), + dl.CheckRunCallback(num_epoch_steps=n_epochs), + ] + + train_runner(tmpdir, n_epochs, callbacks) + exp_output = capsys.readouterr().out + + assert len(re.findall(r"=> Loading", exp_output)) == 0 + + +@pytest.mark.parametrize( + "to_load", ["best", {"model": "not_existing_file.pth", "criterion": "not_existing_file.pth"}] +) +def test_resume_with_missing_file(to_load, tmpdir): + n_epochs = 5 + callbacks = [ + dl.CheckpointCallback( + logdir=tmpdir, + loader_key="valid", + metric_key="loss", + minimize=True, + save_n_best=2, + load_on_stage_start=to_load, + load_on_stage_end=to_load, + resume="best", + ), + dl.CheckRunCallback(num_epoch_steps=n_epochs), + ] + + with pytest.raises(FileNotFoundError): + train_runner(tmpdir, n_epochs, callbacks)