Feat/improved training from ckpt #1501

Merged on Feb 21, 2023 (47 commits)
Changes from 44 commits
Commits
e7e92fa
feat: new function fit_from_checkpoint that load one chkpt from the m…
madtoinou Jan 19, 2023
22828d1
fix: improved the model saving to allow chaining of fine-tuning, bett…
madtoinou Jan 20, 2023
c4f4370
feat: allow to save the checkpoint in the same folder (loaded checkpo…
madtoinou Jan 20, 2023
75acd53
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 20, 2023
c6eddc1
fix: ordered arguments in a more intuitive way
madtoinou Jan 20, 2023
4b38347
fix: saving model after updating all the parameters to facilitate the…
madtoinou Jan 20, 2023
30603ca
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 20, 2023
1abcb96
feat: support for load_from_checkpoint kwargs, support for force_rese…
madtoinou Jan 20, 2023
bd4f035
feat: adding test for setup_finetuning
madtoinou Jan 20, 2023
0e71805
Merge branch 'feat/improved-training-from-ckpt' of https://github.com…
madtoinou Jan 20, 2023
5ec58bc
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 20, 2023
a7be96f
fix: fused the setup_finetuning and load_from_checkpoint methods, add…
madtoinou Jan 23, 2023
07ac34a
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 23, 2023
206aa40
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 23, 2023
247b570
fix: changed the API/approach, instead of trying to overwrite attribu…
madtoinou Jan 30, 2023
83211be
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Jan 30, 2023
5a39edd
fix: convertion of hyper-parameters to list when checking compatibili…
madtoinou Jan 30, 2023
4d2b77c
Merge branch 'feat/improved-training-from-ckpt' of https://github.com…
madtoinou Jan 30, 2023
44a3fa4
fix: skip the None attribute during the hp check
madtoinou Jan 30, 2023
ee00b89
fix: removed unecessary attribute initialization
madtoinou Jan 30, 2023
9cc0ac8
feat: pl_forecasting_module also save the train_sample in the checkpo…
madtoinou Feb 5, 2023
8c93454
fix: saving only shape instead of the sample itself
madtoinou Feb 5, 2023
77447b2
fix: restore the self.train_sample in TorchForecastingModel
madtoinou Feb 6, 2023
17f9c3d
fix: update fit_called attribute to enable inference without retraining
madtoinou Feb 6, 2023
8e2462f
fix: the mock train_sample must be converted to tuple
madtoinou Feb 6, 2023
ce35e8a
fix: tweaked model parameters to improve convergence
madtoinou Feb 6, 2023
167498a
fix: increased number of epochs to improve convergence/test stability
madtoinou Feb 6, 2023
4a18301
fix: addressing review comments; added load_weights method and corres…
madtoinou Feb 13, 2023
192a423
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 13, 2023
0c6a461
fix: changed default checkpoint path name for compatibility with Wind…
madtoinou Feb 14, 2023
e309390
feat: raise error if the checkpoint being loaded does not contain the…
madtoinou Feb 14, 2023
d13f4a7
fix: saving model manually directly after laoding it from checkpoint …
madtoinou Feb 16, 2023
96812d8
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 16, 2023
4304cf1
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 17, 2023
867ad35
fix: use random_state to fix randomness in tests
madtoinou Feb 19, 2023
b42d6e1
fix: restore newlines
madtoinou Feb 19, 2023
6b0de3e
fix: casting dtype of PLModule before loading the weights
madtoinou Feb 19, 2023
845f96e
doc: model_name docstring and code were not consistent
madtoinou Feb 19, 2023
497420f
doc: improve phrasing
madtoinou Feb 19, 2023
72486f8
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 19, 2023
39ba739
Apply suggestions from code review
madtoinou Feb 19, 2023
edab120
fix: removed warning in saving about trainer/ckpt not being found, wa…
madtoinou Feb 19, 2023
c002f3e
fix: uniformised filename convention using '_' to separate hours, min…
madtoinou Feb 19, 2023
aa735de
fix: removed typo
madtoinou Feb 19, 2023
3328835
Update darts/models/forecasting/torch_forecasting_model.py
madtoinou Feb 19, 2023
9d13eaf
fix: more consistent use of the path argument during save and load
madtoinou Feb 19, 2023
b60c9f2
Merge branch 'master' into feat/improved-training-from-ckpt
madtoinou Feb 21, 2023
4 changes: 2 additions & 2 deletions darts/models/forecasting/block_rnn_model.py
@@ -202,10 +202,10 @@ def __init__(
             Number of epochs over which to train the model. Default: ``100``.
         model_name
             Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
-            defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
+            defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
             of the name is formatted with the local date and time, while PID is the processed ID (preventing models
             spawned at the same time by different processes to share the same model_name). E.g.,
-            ``"2021-06-14_09:53:32_torch_model_run_44607"``.
+            ``"2021-06-14_09_53_32_torch_model_run_44607"``.
         work_dir
             Path of the working directory, where to save checkpoints and Tensorboard summaries.
             Default: current working directory.
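The default `model_name` format documented in this hunk can be sketched as follows. This is a minimal illustration and not the darts implementation: the helper `default_model_name` and its `now`/`pid` parameters are hypothetical, and only the documented layout (timestamp with `_` separators, the literal `_torch_model_run_`, then the process ID) is taken from the diff.

```python
import datetime
import os


def default_model_name(now=None, pid=None):
    """Build a name shaped like YYYY-mm-dd_HH_MM_SS_torch_model_run_PID (hypothetical helper)."""
    now = datetime.datetime.now() if now is None else now
    pid = os.getpid() if pid is None else pid
    # '_' separates hours, minutes and seconds instead of ':'
    return f"{now.strftime('%Y-%m-%d_%H_%M_%S')}_torch_model_run_{pid}"


# Reproduce the example from the docstring with fixed inputs.
name = default_model_name(datetime.datetime(2021, 6, 14, 9, 53, 32), pid=44607)
print(name)  # 2021-06-14_09_53_32_torch_model_run_44607
```

Called without arguments, the helper uses the current local time and the PID of the running process, so two processes started in the same second still get different names.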
4 changes: 2 additions & 2 deletions darts/models/forecasting/dlinear.py
@@ -289,10 +289,10 @@ def __init__(
             Number of epochs over which to train the model. Default: ``100``.
         model_name
             Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
-            defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
+            defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
             of the name is formatted with the local date and time, while PID is the processed ID (preventing models
             spawned at the same time by different processes to share the same model_name). E.g.,
-            ``"2021-06-14_09:53:32_torch_model_run_44607"``.
+            ``"2021-06-14_09_53_32_torch_model_run_44607"``.
         work_dir
             Path of the working directory, where to save checkpoints and Tensorboard summaries.
             Default: current working directory.
6 changes: 3 additions & 3 deletions darts/models/forecasting/forecasting_model.py
@@ -1532,7 +1532,7 @@ def model_params(self) -> dict:
 
     @classmethod
     def _default_save_path(cls) -> str:
-        return f"{cls.__name__}_{datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
+        return f"{cls.__name__}_{datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')}"
 
     def save(self, path: Optional[Union[str, BinaryIO]] = None, **pkl_kwargs) -> None:
         """
@@ -1555,8 +1555,8 @@ def save(self, path: Optional[Union[str, BinaryIO]] = None, **pkl_kwargs) -> Non
         ----------
         path
             Path or file handle under which to save the model at its current state. If no path is specified, the model
-            is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH:MM:SS}.pkl"``.
-            E.g., ``"RegressionModel_2020-01-01_12:00:00.pkl"``.
+            is automatically saved under ``"{ModelClass}_{YYYY-mm-dd_HH_MM_SS}.pkl"``.
+            E.g., ``"RegressionModel_2020-01-01_12_00_00.pkl"``.
         pkl_kwargs
             Keyword arguments passed to `pickle.dump()`
         """
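The effect of the new `strftime` pattern in `_default_save_path` can be checked with a small sketch. A colon is not allowed in file names on Windows, and the old `%H:%M:%S` pattern produced one. The class below is a stand-in that copies only the naming logic from the diff; the `now` parameter and the `WINDOWS_RESERVED` constant are assumptions added to make the example reproducible.

```python
import datetime

# Characters that Windows rejects in file names.
WINDOWS_RESERVED = set('<>:"/\\|?*')


class RegressionModel:
    """Stand-in class; only the naming logic mirrors the diff."""

    @classmethod
    def _default_save_path(cls, now=None):
        # `now` is an extra parameter added here so the output is reproducible
        now = datetime.datetime.now() if now is None else now
        return f"{cls.__name__}_{now.strftime('%Y-%m-%d_%H_%M_%S')}"


fixed = datetime.datetime(2020, 1, 1, 12, 0, 0)
new_path = RegressionModel._default_save_path(fixed) + ".pkl"
old_path = f"RegressionModel_{fixed.strftime('%Y-%m-%d_%H:%M:%S')}.pkl"

print(new_path)                                 # RegressionModel_2020-01-01_12_00_00.pkl
print(WINDOWS_RESERVED.intersection(new_path))  # set()
print(WINDOWS_RESERVED.intersection(old_path))  # {':'}
```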
4 changes: 2 additions & 2 deletions darts/models/forecasting/nbeats.py
@@ -622,10 +622,10 @@ def __init__(
             Number of epochs over which to train the model. Default: ``100``.
         model_name
             Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
-            defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
+            defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
             of the name is formatted with the local date and time, while PID is the processed ID (preventing models
             spawned at the same time by different processes to share the same model_name). E.g.,
-            ``"2021-06-14_09:53:32_torch_model_run_44607"``.
+            ``"2021-06-14_09_53_32_torch_model_run_44607"``.
         work_dir
             Path of the working directory, where to save checkpoints and Tensorboard summaries.
             Default: current working directory.
4 changes: 2 additions & 2 deletions darts/models/forecasting/nhits.py
@@ -558,10 +558,10 @@ def __init__(
             Number of epochs over which to train the model. Default: ``100``.
         model_name
             Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
-            defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
+            defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
             of the name is formatted with the local date and time, while PID is the processed ID (preventing models
             spawned at the same time by different processes to share the same model_name). E.g.,
-            ``"2021-06-14_09:53:32_torch_model_run_44607"``.
+            ``"2021-06-14_09_53_32_torch_model_run_44607"``.
         work_dir
             Path of the working directory, where to save checkpoints and Tensorboard summaries.
             Default: current working directory.
4 changes: 2 additions & 2 deletions darts/models/forecasting/nlinear.py
@@ -248,10 +248,10 @@ def __init__(
             Number of epochs over which to train the model. Default: ``100``.
         model_name
             Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
-            defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
+            defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
             of the name is formatted with the local date and time, while PID is the processed ID (preventing models
             spawned at the same time by different processes to share the same model_name). E.g.,
-            ``"2021-06-14_09:53:32_torch_model_run_44607"``.
+            ``"2021-06-14_09_53_32_torch_model_run_44607"``.
         work_dir
             Path of the working directory, where to save checkpoints and Tensorboard summaries.
             Default: current working directory.
13 changes: 13 additions & 0 deletions darts/models/forecasting/pl_forecasting_module.py
@@ -30,6 +30,7 @@ def __init__(
         self,
         input_chunk_length: int,
         output_chunk_length: int,
+        train_sample_shape: Optional[Tuple] = None,
         loss_fn: nn.modules.loss._Loss = nn.MSELoss(),
         torch_metrics: Optional[
             Union[torchmetrics.Metric, torchmetrics.MetricCollection]
@@ -59,6 +60,9 @@ def __init__(
             Number of input past time steps per chunk.
         output_chunk_length
             Number of output time steps per chunk.
+        train_sample_shape
+            Shape of the model's input, used to instantiate model without calling ``fit_from_dataset`` and
+            perform sanity check on new training/inference datasets used for re-training or prediction.
         loss_fn
             PyTorch loss function used for training.
             This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
@@ -101,6 +105,9 @@ def __init__(
         # by default models are deterministic (i.e. not probabilistic)
         self.likelihood = likelihood
 
+        # saved in checkpoint to be able to instantiate a model without calling fit_from_dataset
+        self.train_sample_shape = train_sample_shape
+
         # persist optimiser and LR scheduler parameters
         self.optimizer_cls = optimizer_cls
         self.optimizer_kwargs = dict() if optimizer_kwargs is None else optimizer_kwargs
@@ -370,11 +377,17 @@ def _produce_predict_output(self, x: Tuple):
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # we must save the dtype for correct parameter precision at loading time
         checkpoint["model_dtype"] = self.dtype
+        # we must save the shape of the input to be able to instanciate the model without calling fit_from_dataset
+        checkpoint["train_sample_shape"] = self.train_sample_shape
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # by default our models are initialized as float32. For other dtypes, we need to cast to the correct precision
         # before parameters are loaded by PyTorch-Lightning
         dtype = checkpoint["model_dtype"]
+        self.to_dtype(dtype)
+
+    def to_dtype(self, dtype):
+        """Cast module precision (float32 by default) to another precision."""
         if dtype == torch.float16:
             self.half()
         if dtype == torch.float32:
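The save/load round trip added in this file can be illustrated without PyTorch. The sketch below is a torch-free stand-in, not the darts module: `TinyModule`, `shapes_match`, the string dtypes and the example shape tuple are all hypothetical. Only the hook names and the two checkpoint keys (`"model_dtype"`, `"train_sample_shape"`) come from the diff.

```python
from typing import Any, Dict, Optional, Tuple


class TinyModule:
    """Torch-free stand-in for the Lightning module (hypothetical, for illustration)."""

    def __init__(self, train_sample_shape: Optional[Tuple] = None):
        self.dtype = "float32"  # plain string instead of a torch.dtype
        self.train_sample_shape = train_sample_shape

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # store precision and input shape next to the weights
        checkpoint["model_dtype"] = self.dtype
        checkpoint["train_sample_shape"] = self.train_sample_shape

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # cast to the saved precision before the weights would be restored
        self.to_dtype(checkpoint["model_dtype"])

    def to_dtype(self, dtype: str) -> None:
        self.dtype = dtype


def shapes_match(stored: Optional[Tuple], new_shapes: Tuple) -> bool:
    """Hypothetical sanity check: compare the shapes of a new sample with the stored ones."""
    return stored == tuple(new_shapes)


# Save: the hook writes both entries into the checkpoint dictionary.
trained = TinyModule(train_sample_shape=((24, 1), None, (12, 1)))
trained.to_dtype("float64")
ckpt: Dict[str, Any] = {}
trained.on_save_checkpoint(ckpt)

# Load: rebuild the module from the stored shape, without any training data.
restored = TinyModule(train_sample_shape=ckpt["train_sample_shape"])
restored.on_load_checkpoint(ckpt)

print(restored.dtype)                                                       # float64
print(shapes_match(restored.train_sample_shape, ((24, 1), None, (12, 1))))  # True
print(shapes_match(restored.train_sample_shape, ((24, 2), None, (12, 1))))  # False
```

Storing only the shape keeps the checkpoint small while still allowing the module to be instantiated and new datasets to be checked against the dimensions seen during training.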
4 changes: 2 additions & 2 deletions darts/models/forecasting/rnn_model.py
@@ -281,10 +281,10 @@ def __init__(
             Number of epochs over which to train the model. Default: ``100``.
         model_name
             Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
-            defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
+            defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
             of the name is formatted with the local date and time, while PID is the processed ID (preventing models
             spawned at the same time by different processes to share the same model_name). E.g.,
-            ``"2021-06-14_09:53:32_torch_model_run_44607"``.
+            ``"2021-06-14_09_53_32_torch_model_run_44607"``.
         work_dir
             Path of the working directory, where to save checkpoints and Tensorboard summaries.
             Default: current working directory.
4 changes: 2 additions & 2 deletions darts/models/forecasting/tcn_model.py
@@ -323,10 +323,10 @@ def __init__(
             Number of epochs over which to train the model. Default: ``100``.
         model_name
             Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
-            defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
+            defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
             of the name is formatted with the local date and time, while PID is the processed ID (preventing models
             spawned at the same time by different processes to share the same model_name). E.g.,
-            ``"2021-06-14_09:53:32_torch_model_run_44607"``.
+            ``"2021-06-14_09_53_32_torch_model_run_44607"``.
         work_dir
             Path of the working directory, where to save checkpoints and Tensorboard summaries.
             Default: current working directory.
4 changes: 2 additions & 2 deletions darts/models/forecasting/tft_model.py
@@ -763,10 +763,10 @@ def __init__(
             Number of epochs over which to train the model. Default: ``100``.
         model_name
             Name of the model. Used for creating checkpoints and saving tensorboard data. If not specified,
-            defaults to the following string ``"YYYY-mm-dd_HH:MM:SS_torch_model_run_PID"``, where the initial part
+            defaults to the following string ``"YYYY-mm-dd_HH_MM_SS_torch_model_run_PID"``, where the initial part
             of the name is formatted with the local date and time, while PID is the processed ID (preventing models
             spawned at the same time by different processes to share the same model_name). E.g.,
-            ``"2021-06-14_09:53:32_torch_model_run_44607"``.
+            ``"2021-06-14_09_53_32_torch_model_run_44607"``.
         work_dir
             Path of the working directory, where to save checkpoints and Tensorboard summaries.
             Default: current working directory.