Feat/improved training from ckpt #1501

Merged 47 commits from feat/improved-training-from-ckpt into master on Feb 21, 2023.

Changes from 29 commits

Commits (47)
e7e92fa
feat: new function fit_from_checkpoint that load one chkpt from the m…
madtoinou Jan 19, 2023
22828d1
fix: improved the model saving to allow chaining of fine-tuning, bett…
madtoinou Jan 20, 2023
c4f4370
feat: allow to save the checkpoint in the same folder (loaded checkpo…
madtoinou Jan 20, 2023
75acd53
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 20, 2023
c6eddc1
fix: ordered arguments in a more intuitive way
madtoinou Jan 20, 2023
4b38347
fix: saving model after updating all the parameters to facilitate the…
madtoinou Jan 20, 2023
30603ca
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 20, 2023
1abcb96
feat: support for load_from_checkpoint kwargs, support for force_rese…
madtoinou Jan 20, 2023
bd4f035
feat: adding test for setup_finetuning
madtoinou Jan 20, 2023
0e71805
Merge branch 'feat/improved-training-from-ckpt' of https://github.com…
madtoinou Jan 20, 2023
5ec58bc
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 20, 2023
a7be96f
fix: fused the setup_finetuning and load_from_checkpoint methods, add…
madtoinou Jan 23, 2023
07ac34a
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 23, 2023
206aa40
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 23, 2023
247b570
fix: changed the API/approach, instead of trying to overwrite attribu…
madtoinou Jan 30, 2023
83211be
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 30, 2023
5a39edd
fix: conversion of hyper-parameters to list when checking compatibili…
madtoinou Jan 30, 2023
4d2b77c
Merge branch 'feat/improved-training-from-ckpt' of https://github.com…
madtoinou Jan 30, 2023
44a3fa4
fix: skip the None attribute during the hp check
madtoinou Jan 30, 2023
ee00b89
fix: removed unnecessary attribute initialization
madtoinou Jan 30, 2023
9cc0ac8
feat: pl_forecasting_module also save the train_sample in the checkpo…
madtoinou Feb 5, 2023
8c93454
fix: saving only shape instead of the sample itself
madtoinou Feb 5, 2023
77447b2
fix: restore the self.train_sample in TorchForecastingModel
madtoinou Feb 6, 2023
17f9c3d
fix: update fit_called attribute to enable inference without retraining
madtoinou Feb 6, 2023
8e2462f
fix: the mock train_sample must be converted to tuple
madtoinou Feb 6, 2023
ce35e8a
fix: tweaked model parameters to improve convergence
madtoinou Feb 6, 2023
167498a
fix: increased number of epochs to improve convergence/test stability
madtoinou Feb 6, 2023
4a18301
fix: addressing review comments; added load_weights method and corres…
madtoinou Feb 13, 2023
192a423
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 13, 2023
0c6a461
fix: changed default checkpoint path name for compatibility with Wind…
madtoinou Feb 14, 2023
e309390
feat: raise error if the checkpoint being loaded does not contain the…
madtoinou Feb 14, 2023
d13f4a7
fix: saving model manually directly after loading it from checkpoint …
madtoinou Feb 16, 2023
96812d8
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 16, 2023
4304cf1
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 17, 2023
867ad35
fix: use random_state to fix randomness in tests
madtoinou Feb 19, 2023
b42d6e1
fix: restore newlines
madtoinou Feb 19, 2023
6b0de3e
fix: casting dtype of PLModule before loading the weights
madtoinou Feb 19, 2023
845f96e
doc: model_name docstring and code were not consistent
madtoinou Feb 19, 2023
497420f
doc: improve phrasing
madtoinou Feb 19, 2023
72486f8
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 19, 2023
39ba739
Apply suggestions from code review
madtoinou Feb 19, 2023
edab120
fix: removed warning in saving about trainer/ckpt not being found, wa…
madtoinou Feb 19, 2023
c002f3e
fix: uniformised filename convention using '_' to separate hours, min…
madtoinou Feb 19, 2023
aa735de
fix: removed typo
madtoinou Feb 19, 2023
3328835
Update darts/models/forecasting/torch_forecasting_model.py
madtoinou Feb 19, 2023
9d13eaf
fix: more consistent use of the path argument during save and load
madtoinou Feb 19, 2023
b60c9f2
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 21, 2023
9 changes: 9 additions & 0 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -30,6 +30,7 @@ def __init__(
self,
input_chunk_length: int,
output_chunk_length: int,
train_sample_shape: Optional[Tuple] = None,
loss_fn: nn.modules.loss._Loss = nn.MSELoss(),
torch_metrics: Optional[
Union[torchmetrics.Metric, torchmetrics.MetricCollection]
@@ -59,6 +60,9 @@
Number of input past time steps per chunk.
output_chunk_length
Number of output time steps per chunk.
train_sample_shape
Shape of the model's input, used to instantiate the model without calling ``fit_from_dataset`` and
to perform sanity checks on new training/inference datasets used for re-training or prediction.
loss_fn
PyTorch loss function used for training.
This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
@@ -101,6 +105,9 @@
# by default models are deterministic (i.e. not probabilistic)
self.likelihood = likelihood

# saved in checkpoint to be able to instantiate a model without calling fit_from_dataset
self.train_sample_shape = train_sample_shape

# persist optimiser and LR scheduler parameters
self.optimizer_cls = optimizer_cls
self.optimizer_kwargs = dict() if optimizer_kwargs is None else optimizer_kwargs
@@ -370,6 +377,8 @@ def _produce_predict_output(self, x: Tuple):
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# we must save the dtype for correct parameter precision at loading time
checkpoint["model_dtype"] = self.dtype
# we must save the shape of the input to be able to instantiate the model without calling fit_from_dataset
checkpoint["train_sample_shape"] = self.train_sample_shape

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# by default our models are initialized as float32. For other dtypes, we need to cast to the correct precision
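To make these hooks concrete, here is a minimal, self-contained sketch of the pattern: a toy LightningModule that persists its dtype and input shape in ``on_save_checkpoint`` and restores parameter precision in ``on_load_checkpoint``. ToyModule and its fields are illustrative assumptions, not the actual darts PLForecastingModule.

from typing import Any, Dict, Optional, Tuple

import pytorch_lightning as pl
import torch.nn as nn

class ToyModule(pl.LightningModule):
    # illustrative stand-in, not the actual darts PLForecastingModule
    def __init__(self, train_sample_shape: Optional[Tuple] = None):
        super().__init__()
        self.train_sample_shape = train_sample_shape
        self.layer = nn.Linear(8, 1)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # persist dtype and input shape so the model can later be rebuilt
        # without calling fit_from_dataset
        checkpoint["model_dtype"] = self.dtype
        checkpoint["train_sample_shape"] = self.train_sample_shape

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # cast the parameters back to the precision used at training time
        self.to(checkpoint["model_dtype"])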
165 changes: 147 additions & 18 deletions darts/models/forecasting/torch_forecasting_model.py
@@ -415,6 +415,10 @@ def _init_model(self, trainer: Optional[pl.Trainer] = None) -> None:
"calling `super.__init__(...)`. Do this with `self._extract_pl_module_params(**self.model_params).`",
)

self.pl_module_params["train_sample_shape"] = [
variate.shape if variate is not None else None
for variate in self.train_sample
]
# the tensors have shape (chunk_length, nr_dimensions)
self.model = self._create_model(self.train_sample)
self._module_name = self.model.__class__.__name__
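As a quick illustration (hypothetical sample, not taken from the PR): a train sample is a tuple of per-variate arrays, some of which may be None, and only their shapes are persisted:

import numpy as np

# hypothetical train sample: (past_target, past_covariates, future_target)
train_sample = (np.zeros((24, 2)), None, np.zeros((12, 2)))
train_sample_shape = [v.shape if v is not None else None for v in train_sample]
# -> [(24, 2), None, (12, 2)]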
@@ -860,7 +864,7 @@ def fit_from_dataset(
same_dims,
"The dimensionality of the series in the training set do not match the dimensionality"
" of the series the model has previously been trained on. "
"Model input/output dimensions = {}, provided input/ouptput dimensions = {}".format(
"Model input/output dimensions = {}, provided input/output dimensions = {}".format(
tuple(
s.shape[1] if s is not None else None for s in self.train_sample
),
@@ -1261,8 +1265,10 @@ def save(self, path: Optional[str] = None) -> None:
Parameters
----------
path
Path under which to save the model at its current state. If no path is specified, the model is automatically
saved under ``"{ModelClass}_{YYYY-mm-dd_HH:MM:SS}.pt"``. E.g., ``"RNNModel_2020-01-01_12:00:00.pt"``.
Path under which to save the model at its current state. Please avoid paths starting with "last-" or
"best-" to avoid collisions with PyTorch Lightning checkpoints. If no path is specified, the model
is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH:MM:SS}.pt"``.
E.g., ``"RNNModel_2020-01-01_12:00:00.pt"``.
"""
if path is None:
# default path
@@ -1338,37 +1344,25 @@ def load_from_checkpoint(
"""
Load the model from automatically saved checkpoints under '{work_dir}/darts_logs/{model_name}/checkpoints/'.
This method is used for models that were created with ``save_checkpoints=True``.

If you manually saved your model, consider using :meth:`load() <TorchForecastingModel.load()>`.

Example for loading a :class:`RNNModel` from checkpoint (``model_name`` is the ``model_name`` used at model
creation):

.. highlight:: python
.. code-block:: python

from darts.models import RNNModel

model_loaded = RNNModel.load_from_checkpoint(model_name, best=True)
..

If ``file_name`` is given, returns the model saved under
'{work_dir}/darts_logs/{model_name}/checkpoints/{file_name}'.

If ``file_name`` is not given, will try to restore the best checkpoint (if ``best`` is ``True``) or the most
recent checkpoint (if ``best`` is ``False``) from '{work_dir}/darts_logs/{model_name}/checkpoints/'.

Example for loading an :class:`RNNModel` checkpoint to CPU that was saved on GPU:

.. highlight:: python
.. code-block:: python

from darts.models import RNNModel

model_loaded = RNNModel.load_from_checkpoint(model_name, best=True, map_location="cpu")
model_loaded.to_cpu()
..

Parameters
----------
model_name
@@ -1385,8 +1379,6 @@ def load_from_checkpoint(
such as ``map_location`` to load the model onto a different device than the one from which it was saved.
For more information, read the `official documentation <https://pytorch-lightning.readthedocs.io/en/stable/
common/lightning_module.html#load-from-checkpoint>`_.


Returns
-------
TorchForecastingModel
@@ -1409,7 +1401,7 @@
model = TorchForecastingModel.load(base_model_path, **kwargs)

# load PyTorch LightningModule from checkpoint
# if file_name is None, find most recent file in savepath that is a checkpoint
# if file_name is None, find the path of the best or most recent checkpoint in savepath
if file_name is None:
file_name = _get_checkpoint_fname(work_dir, model_name, best=best)

@@ -1429,6 +1421,143 @@ def _load_from_checkpoint(self, file_path, **kwargs):
pl_module_cls = getattr(sys.modules[self._module_path], self._module_name)
return pl_module_cls.load_from_checkpoint(file_path, **kwargs)

def load_weights_from_checkpoint(
self,
model_name: str = None,
work_dir: str = None,
file_name: str = None,
best: bool = True,
strict: bool = True,
**kwargs,
):
"""
Load only the weights from automatically saved checkpoints under '{work_dir}/darts_logs/{model_name}/
checkpoints/'. This method is used for models that were created with ``save_checkpoints=True`` and
that need to be re-trained or fine-tuned with a different optimizer or learning rate scheduler. However,
it can also be used to load weights for inference.

To resume an interrupted training, please consider using :meth:`load_from_checkpoint()
<TorchForecastingModel.load_from_checkpoint()>`, which also reloads the trainer, optimizer and
learning rate scheduler states.

For manually saved models, consider using :meth:`load() <TorchForecastingModel.load()>` or
:meth:`load_weights() <TorchForecastingModel.load_weights()>` instead.

Parameters
----------
model_name
The name of the model (used to retrieve the checkpoints folder's name). Default: ``self.model_name``.
work_dir
Working directory (containing the checkpoints folder). Defaults to current working directory.
file_name
The name of the checkpoint file. If not specified, use the most recent one.
best
If set, will retrieve the best model (according to validation loss) instead of the most recent one.
Ignored when ``file_name`` is given. Default: ``True``.
strict
If set, strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict().
Default: ``True``.
For more information, read the `official documentation <https://pytorch.org/docs/stable/generated/torch.
nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict>`_.
**kwargs
Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model onto a
different device than the one from which it was saved.
For more information, read the `official documentation <https://pytorch.org/docs/stable/generated/
torch.load.html>`_.
"""
raise_if(
"weights_only" in kwargs.keys() and kwargs["weights_only"],
"Passing `weights_only=True` to `torch.load` will disrupt this"
" method sanity checks.",
logger,
)

# if not provided, use the name of the model into which the weights are being loaded
if model_name is None:
model_name = self.model_name

if work_dir is None:
work_dir = os.path.join(os.getcwd(), DEFAULT_DARTS_FOLDER)

# load PyTorch LightningModule from checkpoint
# if file_name is None, find the path of the best or most recent checkpoint in savepath
if file_name is None:
file_name = _get_checkpoint_fname(work_dir, model_name, best=best)

# checkpoints generated by PL, prefix is defined in TorchForecastingModel __init__()
if file_name[:5] == "last-" or file_name[:5] == "best-":
checkpoint_dir = _get_checkpoint_folder(work_dir, model_name)
# manual save
else:
checkpoint_dir = os.getcwd()

ckpt_path = os.path.join(checkpoint_dir, file_name)
ckpt = torch.load(ckpt_path, **kwargs)
ckpt_hyper_params = ckpt["hyper_parameters"]

# verify that the arguments passed to the constructor match those of the checkpoint
for param_key, param_value in self.model_params.items():
Review discussion (inline, on the hyper-parameter check):

Collaborator: If we hard set strict=True we could let PyTorch handle any discrepancies later on when calling

    self.model.load_state_dict(ckpt["state_dict"], strict=True)

Collaborator Author (madtoinou): This is what I did initially, but the error messages are not informative at all, making it quite difficult for the user to realize that the problem comes from the definition of the model into which the weights are loaded.

Collaborator: True, but on the other hand this requires all TorchForecastingModels and their corresponding PLForecastingModules to share the same model parameter names, which is not the case as you mention (and might be difficult to enforce in some cases). So the torch error can still be raised, or maybe I'm missing something :)

Collaborator Author (madtoinou): Indeed, the torch error can still be raised if the discrepancy is in one of the parameters that do not have the same name in the two objects. We could eventually have a dict in each model that tries to map the parameter names so that this sanity check can run thoroughly. Or we could catch the torch error if load_state_dict fails and raise a meaningful message, indicating that the weights mismatch can be caused by invalid parameters (or by a change of convention in torch...).

# TODO: there are discrepancies between the param names, for ex num_layer/n_rnn_layers
if param_key in ckpt_hyper_params.keys() and param_value is not None:
# some parameters must be converted
if isinstance(ckpt_hyper_params[param_key], list) and not isinstance(
param_value, list
):
param_value = [param_value] * len(ckpt_hyper_params[param_key])
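# illustrative example with a hypothetical parameter name: if the checkpoint
# stored kernel_sizes=[25, 25, 25] while the new model was built with
# kernel_sizes=25, the scalar is expanded to [25, 25, 25] before the
# comparison below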

raise_if(
param_value != ckpt_hyper_params[param_key],
f"The values of the hyper parameter {param_key} should be identical between "
f"the instantiated model ({param_value}) and the loaded checkpoint "
f"({ckpt_hyper_params[param_key]}). Please adjust the model accordingly.",
logger,
)

# the PL forecasting module saves the train_sample shape; recreate a mock train sample from it
mock_train_sample = [
np.zeros(sample_shape) if sample_shape else None
for sample_shape in ckpt["train_sample_shape"]
]
self.train_sample = tuple(mock_train_sample)

# instantiate the model without having to call `fit_from_dataset`
self._init_model()
madtoinou marked this conversation as resolved.
Show resolved Hide resolved

# load only the weights from the state dict
self.model.load_state_dict(ckpt["state_dict"], strict=strict)
# update the fit_called attribute to allow for direct inference
self._fit_called = True
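For context, a hedged sketch of the fine-tuning workflow this method enables; the model name, series, and hyper-parameter values below are placeholders, not taken from the PR:

from darts.models import NBEATSModel

# assumes a model was previously trained with save_checkpoints=True under
# model_name="my_model"; `series` is a placeholder TimeSeries
model_finetuned = NBEATSModel(
    input_chunk_length=24,   # architecture hyper-parameters must match
    output_chunk_length=12,  # the checkpoint (see the check above)
    model_name="my_model_finetuned",
    optimizer_kwargs={"lr": 1e-4},  # new optimizer settings for fine-tuning
)
model_finetuned.load_weights_from_checkpoint(model_name="my_model", best=True)
model_finetuned.fit(series, epochs=5)  # training starts from the loaded weights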

def load_weights(self, path: str, **kwargs):
"""
Loads the model's weights from a PyTorch Lightning checkpoint file ('.ckpt').

This method is used for the checkpoints created with ``model.save()``.

Parameters
----------
path
Review discussion (inline, on the path parameter):

Collaborator: Sorry, I just saw this now. I think we should use the path that the user gave when manually saving the model, i.e. model.save("my_model.pt"), rather than the .ckpt path. Then we just replace ".pt" with ".pt.ckpt" and get the checkpoint from there. Check here that the ckpt exists, similar to how we do it now in TorchForecastingModel.load().

Collaborator Author (madtoinou, Feb 19, 2023): Good catch, it makes the interface much more consistent and intuitive. load_weights() now expects the .pt path and the .ckpt suffix is added afterward, inside the function.

Path from which to load the model's weights. If no path was specified when saving the model, the
automatically generated path ending with ".ckpt" has to be provided.
**kwargs
Additional kwargs for PyTorch's :func:`load` method, such as ``map_location`` to load the model onto a
different device than the one from which it was saved.
For more information, read the `official documentation <https://pytorch.org/docs/stable/generated/
torch.load.html>`_.

"""
raise_if_not(
path.endswith(".ckpt"),
"The file path passed to this method should end with '.ckpt' "
"(Pytorch LightningModule checkpoints extension).",
logger,
)

self.load_weights_from_checkpoint(
file_name=path,
**kwargs,
)
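A minimal usage sketch for manually saved models (file names are placeholders; per the check above, this revision expects the '.ckpt' path, while the review discussion above moves this to the '.pt' path in a later commit):

# manual save: darts also writes a companion "my_model.pt.ckpt" checkpoint
model.save("my_model.pt")

# later, on a freshly created model with identical hyper-parameters
model_new.load_weights("my_model.pt.ckpt", map_location="cpu")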

def to_cpu(self):
"""Updates the PyTorch Lightning Trainer parameters to move the model to CPU the next time :fun:`fit()` or
:func:`predict()` is called.