Returned resume support #1349

Merged
merged 17 commits on Nov 10, 2021
1 change: 1 addition & 0 deletions catalyst/__init__.py
@@ -4,6 +4,7 @@
warnings.filterwarnings("ignore", message="numpy.dtype size changed", append=True)
warnings.filterwarnings("ignore", module="tqdm", append=True)
warnings.filterwarnings("once", append=True)
warnings.filterwarnings("ignore", message="This overload of add_ is deprecated", append=True)

from catalyst.__version__ import __version__
from catalyst.settings import SETTINGS
186 changes: 70 additions & 116 deletions catalyst/callbacks/checkpoint.py
@@ -1,4 +1,4 @@
from typing import Dict, Union
from typing import Dict, Iterable, Union
from collections import OrderedDict
import os
from pathlib import Path
@@ -9,6 +9,8 @@
from catalyst.tools.metric_handler import MetricHandler
from catalyst.utils.config import save_config

_default_states = {"best", "best_full", "last", "last_full"}


def _save_checkpoint(
checkpoint: Dict,
@@ -123,7 +125,6 @@ def _get_required_files(logdir: str, load_map: Dict[str, str]) -> Dict[str, str]
if load_map is None:
return OrderedDict()

default_states = {"best", "best_full", "last", "last_full"}
required_full_checkpoint = ["criterion", "optimizer", "scheduler"]
steps = ["global_epoch_step", "global_batch_step", "global_sample_step"]
experiment_parts = ["model"] + required_full_checkpoint + steps
@@ -141,7 +142,7 @@ def _get_required_files(logdir: str, load_map: Dict[str, str]) -> Dict[str, str]
fname = load_map[part]
required_full = fname.endswith("_full")
# specified default state
if fname in default_states:
if fname in _default_states:
if part in required_full_checkpoint and not required_full:
fname = fname + "_full"
fname = f"{logdir}/{fname}.pth"
@@ -197,24 +198,46 @@ def _load_states_from_file_map(


def _load_runner(
logdir: str, runner: "IRunner", mapping: Union[str, Dict[str, str]], load_full: bool = False
logdir: str,
runner: "IRunner",
mapping: Union[str, Dict[str, str]],
not_required_states: Iterable[str] = None,
) -> None:
"""
Selects a loading method based on type of mapping.
Checks if the files used in the mapping exist and selects a loading method
based on the type of mapping.

Args:
logdir: logdir with checkpoints
runner: current runner
mapping: mapping to use for loading
load_full: load a full model, used only when mapping type is string
not_required_states: iterable of default state names (e.g. ``last``, ``last_full``) that should not be loaded

Raises:
FileNotFoundError: if files given in mapping are missing.
"""
if not_required_states is None:
not_required_states = []
possible_states = _default_states.difference(not_required_states)
file_exists = False
if isinstance(mapping, str):
if mapping in {"best", "best_full", "last", "last_full"}:
load_full = "full" in mapping
if mapping in possible_states:
checkpoint = f"{logdir}/{mapping}.pth"
else:
checkpoint = mapping
file_exists = os.path.isfile(checkpoint)
if not file_exists:
raise FileNotFoundError(f"Missing file '{checkpoint}'!") # noqa: F821
_load_checkpoint(filename=checkpoint, runner=runner, load_full=load_full)
elif isinstance(mapping, dict):
mapping = {k: v for k, v in mapping.items() if v not in not_required_states}
required_files = _get_required_files(logdir, mapping).keys()
file_exists = True
for use_file in required_files:
if not os.path.isfile(use_file):
file_exists = False
raise FileNotFoundError(f"Missing file '{use_file}'!")
_load_states_from_file_map(logdir=logdir, runner=runner, load_map=mapping)
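For reference, a minimal calling sketch of the refactored helper with the signature shown in this diff. The logdir path and the runner are placeholders; inside the callback the active experiment runner is passed in, so this is only an illustration of the two mapping shapes, not a standalone recipe.

```python
import os

from catalyst.callbacks.checkpoint import _load_runner
from catalyst.runners import SupervisedRunner

logdir = "./logs/checkpoints"   # hypothetical checkpoint directory
runner = SupervisedRunner()     # stand-in; a real call uses the active runner

# string mapping: one of the default states ("best"/"best_full"/"last"/"last_full")
# or a direct path to a checkpoint file; "*_full" states also restore criterion,
# optimizer and scheduler
if os.path.isfile(os.path.join(logdir, "best_full.pth")):
    _load_runner(logdir=logdir, runner=runner, mapping="best_full")

# dict mapping: load only the listed parts; entries whose value appears in
# ``not_required_states`` are filtered out (this is how ``on_stage_end`` uses it)
if os.path.isdir(logdir):
    _load_runner(
        logdir=logdir,
        runner=runner,
        mapping={"model": "best", "criterion": "best_full"},
        not_required_states=["last", "last_full"],
    )
```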


@@ -228,7 +251,7 @@ class CheckpointCallback(ICheckpointCallback):
"""Checkpoint callback to save/restore your model/criterion/optimizer/scheduler.

Args:
logdir: directory to store chekpoints
logdir: directory to store checkpoints
loader_key: loader key for best model selection (based on metric score over the dataset)
metric_key: metric key for best model selection (based on metric score over the dataset)
minimize: boolean flag to minimize the required metric
@@ -264,12 +287,8 @@ class CheckpointCallback(ICheckpointCallback):
``"optimizer"`` and ``"scheduler"`` will be ignored.

If ``None`` or an empty dict (or dict without mentioned
above keys) then no action is required at stage start and:

- Config API - will be used best state of model
- Notebook API - no action will be performed (will be used the last state)

**NOTE:** Loading will be performed on all stages except first.
above keys) then no action is required at stage start and
no action will be performed (the last state will be used).

**NOTE:** Criterion, optimizer and scheduler are optional keys
and should be loaded from full checkpoint.
@@ -296,6 +315,19 @@ class CheckpointCallback(ICheckpointCallback):
and will be used the last runner.

**NOTE:** Loading will be performed always at stage end.
resume (str or Dict[str, str]): load specified
state/model for experiment resuming.

If a **string** is passed, initialization will be performed from the
specified state (``best``/``best_full``/``last``/``last_full``)
or from a checkpoint file.

If a **dict** is passed, initialization will be performed only
for the specified parts - model, criterion, optimizer, scheduler.
The dict logic is the same as for ``load_on_stage_start``.

If ``None`` or an empty dict (or a dict without the keys
mentioned above), then no action is performed at stage start.
metrics_filename: filename to save metrics
in checkpoint folder.
Must end with ``.json`` or ``.yml``
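A minimal construction sketch for the new ``resume`` argument documented above; the logdir path, loader key and metric key are placeholders.

```python
from catalyst.callbacks.checkpoint import CheckpointCallback

# resume the whole experiment from a named state (or a direct path to a .pth file)
checkpoint_cb = CheckpointCallback(
    logdir="./logs/checkpoints",
    loader_key="valid",
    metric_key="loss",
    minimize=True,
    resume="last_full",
)

# resume only selected parts; the dict semantics match ``load_on_stage_start``
checkpoint_cb = CheckpointCallback(
    logdir="./logs/checkpoints",
    loader_key="valid",
    metric_key="loss",
    minimize=True,
    resume={"model": "best", "optimizer": "last_full"},
)
```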
@@ -382,7 +414,7 @@ def __init__(
# loading info
load_on_stage_start: Union[str, Dict[str, str]] = None,
load_on_stage_end: Union[str, Dict[str, str]] = None,
# resume: str = None,
resume: Union[str, Dict[str, str]] = None,
# resume_dir: str = None,
# checkpointer info
metrics_filename: str = "_metrics.json",
@@ -392,22 +424,14 @@
):
"""Init."""
super().__init__(order=CallbackOrder.external, node=CallbackNode.all)
possible_states = {
None,
"best",
"last",
"best_full",
"last_full",
}
possible_states = _default_states.union([None])
assert save_n_best >= 0
if save_n_best == 0:
assert load_on_stage_end in (None, "last", "last_full")
if isinstance(load_on_stage_start, str):
assert load_on_stage_start in possible_states
if isinstance(load_on_stage_end, str):
assert load_on_stage_end in possible_states
# if resume_dir is not None:
# assert resume is not None

if loader_key is not None or metric_key is not None:
assert loader_key is not None and metric_key is not None, (
@@ -448,8 +472,7 @@ def __init__(
# loading info
self.load_on_stage_start = load_on_stage_start
self.load_on_stage_end = load_on_stage_end
# self.resume = resume
# self.resume_dir = resume_dir
self.resume = resume

def _pack_checkpoint(self, runner: "IRunner"):
checkpoint = runner.engine.pack_checkpoint(
@@ -558,16 +581,12 @@ def on_stage_start(self, runner: "IRunner") -> None:

.. note::

If CheckpointCallback initialized with
``resume`` (as path to checkpoint file)
or ``resume`` (as filename)
and ``resume_dir`` (as directory with file)
If CheckpointCallback is initialized with ``resume`` or ``load_on_stage_start``:
- as a path to a checkpoint file or a filename (``resume`` only)
- as a specified state (``best``/``best_full``/``last``/``last_full``)
- as a dict with the specified parts (model, criterion, optimizer, etc.)
then the checkpoint will be loaded.

Raises:
FileNotFoundError: if specified load_on_stage_start
but checkpoint file is missing.

Args:
runner: current runner
"""
@@ -588,70 +607,19 @@ def on_stage_start(self, runner: "IRunner") -> None:
# Use a barrier() to make sure that all processes have finished reading the checkpoint
# dist.barrier()

is_first_stage = list(runner.stages).index(runner.stage_key) == 0
if self.load_on_stage_start is not None and not is_first_stage:
need_full = False
file_exists = False
if isinstance(self.load_on_stage_start, str):
need_full = self.load_on_stage_start.endswith("full")
use_file = os.path.join(self.logdir, f"{self.load_on_stage_start}.pth")
file_exists = os.path.isfile(use_file)
if not file_exists:
raise FileNotFoundError(f"Missing file '{use_file}'!") # noqa: F821
elif isinstance(self.load_on_stage_start, dict):
required_files = _get_required_files(self.logdir, self.load_on_stage_start).keys()
file_exists = True
for use_file in required_files:
if not os.path.isfile(use_file):
file_exists = False
raise FileNotFoundError(f"Missing file '{use_file}'!")

if self.load_on_stage_start is not None and file_exists:
_load_runner(
logdir=self.logdir,
runner=runner,
mapping=self.load_on_stage_start,
load_full=need_full,
)

# if getattr(runner, "resume", None) is not None:
# self.resume = runner.resume
# runner.resume = None
# elif getattr(runner, "autoresume", None) is not None:
# self.resume_dir = runner.logdir / "checkpoints"
# self.resume = f"{runner.autoresume}_full.pth"
# runner.autoresume = None
#
# for key in self._keys_from_runner:
# value = getattr(runner, key, None)
# if value is not None:
# setattr(self, key, value)
#
# if self.resume_dir is not None:
# self.resume = str(self.resume_dir) + "/" + str(self.resume)
#
# if self.resume is not None:
# _load_runner(logdir=self.logdir, runner=runner, mapping=self.resume, load_full=True)
# self.resume = None
# else:
# checkpoint_exists = False
# need_load_full = False
# if isinstance(self.load_on_stage_start, str):
# checkpoint_exists =
# os.path.isfile(f"{self.logdir}/{self.load_on_stage_start}.pth")
# need_load_full = self.load_on_stage_start.endswith("full")
# elif isinstance(self.load_on_stage_start, dict):
# required_files =
# _get_required_files(self.logdir, self.load_on_stage_start).keys()
# checkpoint_exists = all(os.path.isfile(file) for file in required_files)
#
# if self.load_on_stage_start is not None and checkpoint_exists:
# _load_runner(
# logdir=self.logdir,
# runner=runner,
# mapping=self.load_on_stage_start,
# load_full=need_load_full,
# )
if getattr(runner, "_resume", None) is not None:
self.resume = runner._resume
runner._resume = None

if self.resume is not None:
_load_runner(logdir=self.logdir, runner=runner, mapping=self.resume)
self.resume = None
elif self.load_on_stage_start is not None:
_load_runner(
logdir=self.logdir,
runner=runner,
mapping=self.load_on_stage_start,
)
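A short sketch of the resulting precedence, with placeholder values: an explicit ``resume`` (or ``runner._resume`` set by the Config API) is loaded first and then cleared, while ``load_on_stage_start`` is only used when no resume mapping is pending.

```python
from catalyst.callbacks.checkpoint import CheckpointCallback

checkpoint_cb = CheckpointCallback(
    logdir="./logs/checkpoints",   # placeholder path
    resume="last_full",            # consumed by on_stage_start, then reset to None
    load_on_stage_start="best",    # used on later stages, once resume is consumed
)
```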

def on_epoch_end(self, runner: "IRunner") -> None:
"""
@@ -741,29 +709,15 @@ def on_stage_end(self, runner: "IRunner") -> None:

# let's load runner state (model, criterion, optimizer, scheduler) if required
not_required_load_states = {"last", "last_full"}
if (
isinstance(self.load_on_stage_end, str)
and self.load_on_stage_end not in not_required_load_states
and self.save_n_best > 0
):
need_load_full = (
self.load_on_stage_end.endswith("full")
if isinstance(self.load_on_stage_end, str)
else False
)
if self.save_n_best == 0:
return
if self.load_on_stage_end is not None:
_load_runner(
logdir=self.logdir,
runner=runner,
mapping=self.load_on_stage_end,
load_full=need_load_full,
not_required_states=not_required_load_states,
)
elif isinstance(self.load_on_stage_end, dict) and self.save_n_best > 0:
to_load = {
k: v
for k, v in self.load_on_stage_end.items()
if v not in not_required_load_states
}
_load_runner(logdir=self.logdir, runner=runner, mapping=to_load)

if runner.engine.is_ddp and runner.engine.is_master_process:
# worker sync
28 changes: 15 additions & 13 deletions catalyst/dl/scripts/run.py
@@ -31,19 +31,21 @@ def build_args(parser: ArgumentParser):
parser.add_argument("--expdir", type=str, default=None)
parser.add_argument("--logdir", type=str, default=None)
parser.add_argument("--baselogdir", type=str, default=None)
# parser.add_argument(
# "--resume", default=None, type=str, metavar="PATH", help="path to latest checkpoint",
# )
# parser.add_argument(
# "--autoresume",
# type=str,
# help=(
# "try automatically resume from logdir//{best,last}_full.pth " "if --resume is empty"
# ),
# required=False,
# choices=["best", "last"],
# default=None,
# )
parser.add_argument(
"--resume",
default=None,
type=str,
metavar="PATH",
help="path to latest checkpoint",
)
parser.add_argument(
"--autoresume",
type=str,
help=("try automatically resume from logdir/{best,last}_full.pth " "if --resume is empty"),
required=False,
choices=["best", "last"],
default=None,
)
parser.add_argument("--seed", type=int, default=42)
boolean_flag(
parser,
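As a sanity check, a standalone argparse sketch mirroring the two restored flags. It does not import catalyst; real runs go through ``catalyst-dl run`` with these flags on the command line, and all paths below are placeholders.

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument(
    "--resume", default=None, type=str, metavar="PATH", help="path to latest checkpoint"
)
parser.add_argument(
    "--autoresume",
    type=str,
    choices=["best", "last"],
    default=None,
    help="try automatically resume from logdir/{best,last}_full.pth if --resume is empty",
)

args = parser.parse_args(["--resume", "./logs/checkpoints/last_full.pth"])
print(args.resume)      # ./logs/checkpoints/last_full.pth
print(args.autoresume)  # None unless --autoresume=best/last is passed
```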
7 changes: 5 additions & 2 deletions catalyst/runners/config.py
@@ -105,6 +105,7 @@ def __init__(self, config: Dict):
self._timeit: bool = get_by_keys(self._config, "args", "timeit", default=False)
self._check: bool = get_by_keys(self._config, "args", "check", default=False)
self._overfit: bool = get_by_keys(self._config, "args", "overfit", default=False)
self._resume: str = get_by_keys(self._config, "args", "resume")

self._name: str = self._get_run_name()
self._logdir: str = self._get_run_logdir()
@@ -292,7 +293,9 @@ def get_model(self, stage: str) -> RunnerModel:
"""Returns the model for a given stage."""
assert "model" in self._config, "config must contain 'model' key"
model_params: Dict = self._config["model"]
model: RunnerModel = self._get_model_from_params(**model_params)
model: RunnerModel = (
self._get_model_from_params(**model_params) if self.model is None else self.model
)
return model

def get_criterion(self, stage: str) -> RunnerCriterion:
@@ -414,7 +417,7 @@ def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":

if self._logdir is not None and not is_callback_exists(ICheckpointCallback):
callbacks["_checkpoint"] = CheckpointCallback(
logdir=os.path.join(self._logdir, "checkpoints")
logdir=os.path.join(self._logdir, "checkpoints"), resume=self._resume
)

return callbacks
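To illustrate the new wiring, a minimal sketch of how ``args: resume`` flows from the experiment config into the callback. ``get_by_keys`` below is a simplified stand-in for the helper used in ``config.py``, and the config dict is a placeholder fragment rather than a full Config API file.

```python
def get_by_keys(data, *keys, default=None):
    """Walk nested dicts by the given keys, returning ``default`` when a key is missing."""
    for key in keys:
        if not isinstance(data, dict) or key not in data:
            return default
        data = data[key]
    return data


config = {"args": {"logdir": "./logs", "resume": "./logs/checkpoints/last_full.pth"}}

resume = get_by_keys(config, "args", "resume")
# ConfigRunner stores this as ``self._resume`` and ``get_callbacks`` forwards it:
# CheckpointCallback(logdir=os.path.join(self._logdir, "checkpoints"), resume=self._resume)
print(resume)  # ./logs/checkpoints/last_full.pth
```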