Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/document model saving loading #1210

Merged
merged 21 commits into from
Sep 28, 2022
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions docs/userguide/forecasting_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,44 @@ Furthermore, we define the following types of time series consumed by the models
* **Target series:** the series that we are interested in forecasting.
* **Covariate series:** some other series that we are not interested in forecasting, but that can provide valuable inputs to the forecasting model.

## Saving and Loading Models

If you wish to save a particular model and use it elsewhere or at a later point in time, darts can achieve that. It leverages pickle and in the case of Torch models relies on saving PyTorch Lightning trainer checkpoints.
All forecasting models support saving the model on the filesystem, by calling the `save()` function, which saves that particular `ForecastingModel` object instance. When the model is to be used again, the method `load()` can be used. Please note that the methods `save_model()` and `load_model()` are deprecated.

**Example:**
```python
from darts.models import RegressionModel

model = RegressionModel(lags=4)

model.save("my_model.pkl")
model_loaded = RegressionModel.load("my_model.pkl")
```

The parameter `path` specifies a path or file handle under which to save the model at its current state. If no `path` is specified, the model is automatically
saved under ``"{ModelClass}_{YYYY-mm-dd_HH:MM:SS}.pkl"``. E.g., ``"RegressionModel_2020-01-01_12:00:00.pkl"``.
Optionally there is also pickle specific keyword arguments `protocol`, `fix_imports` and `buffer_callback`.
More info: [pickle.dump()](https://docs.python.org/3/library/pickle.html?highlight=dump#pickle.dump)

With torch models, the model parameters and the training state are saved. We use the trainer to save the model checkpoint.

**Example:**
```python
from darts.models import NBEATSModel

model = NBEATSModel(input_chunk_length=24,
output_chunk_length=12)

model.save("my_model.pt")
model_loaded = NBEATSModel.load("my_model.pt")
```

Private methods for torch models used under the hood:

* **save_checkpoint:** In addition, we need to use PTL save_checkpoint() to properly save the trainer and model. It is used to be able to save a snapshot of the model mid-training, and then be able to retrieve the model later.
* **load_from_checkpoint:** returns a model checkpointed during training (by default the one with lowest validation loss).


## Support for multivariate series

Expand Down