diff --git a/docs/userguide/forecasting_overview.md b/docs/userguide/forecasting_overview.md index 8a364eb175..57394bbf83 100644 --- a/docs/userguide/forecasting_overview.md +++ b/docs/userguide/forecasting_overview.md @@ -40,6 +40,44 @@ Furthermore, we define the following types of time series consumed by the models * **Target series:** the series that we are interested in forecasting. * **Covariate series:** some other series that we are not interested in forecasting, but that can provide valuable inputs to the forecasting model. +## Saving and Loading Models + +If you wish to save a particular model and use it elsewhere or at a later point in time, darts can achieve that. It leverages pickle and in the case of Torch models relies on saving PyTorch Lightning trainer checkpoints. +All forecasting models support saving the model on the filesystem, by calling the `save()` function, which saves that particular `ForecastingModel` object instance. When the model is to be used again, the method `load()` can be used. Please note that the methods `save_model()` and `load_model()` are deprecated. + +**Example:** +```python +from darts.models import RegressionModel + +model = RegressionModel(lags=4) + +model.save("my_model.pkl") +model_loaded = RegressionModel.load("my_model.pkl") +``` + +The parameter `path` specifies a path or file handle under which to save the model at its current state. If no `path` is specified, the model is automatically +saved under ``"{ModelClass}_{YYYY-mm-dd_HH:MM:SS}.pkl"``. E.g., ``"RegressionModel_2020-01-01_12:00:00.pkl"``. +Optionally there is also pickle specific keyword arguments `protocol`, `fix_imports` and `buffer_callback`. +More info: [pickle.dump()](https://docs.python.org/3/library/pickle.html?highlight=dump#pickle.dump) + +With torch models, the model parameters and the training state are saved. We use the trainer to save the model checkpoint. + +**Example:** +```python +from darts.models import NBEATSModel + +model = NBEATSModel(input_chunk_length=24, + output_chunk_length=12) + +model.save("my_model.pt") +model_loaded = NBEATSModel.load("my_model.pt") +``` + +Private methods for torch models used under the hood: + +* **save_checkpoint:** In addition, we need to use PTL save_checkpoint() to properly save the trainer and model. It is used to be able to save a snapshot of the model mid-training, and then be able to retrieve the model later. +* **load_from_checkpoint:** returns a model checkpointed during training (by default the one with lowest validation loss). + ## Support for multivariate series