-
Notifications
You must be signed in to change notification settings - Fork 881
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/document model saving loading #1210
Feat/document model saving loading #1210
Conversation
Codecov ReportBase: 94.02% // Head: 94.01% // Decreases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## master #1210 +/- ##
==========================================
- Coverage 94.02% 94.01% -0.01%
==========================================
Files 73 73
Lines 8215 8203 -12
==========================================
- Hits 7724 7712 -12
Misses 491 491
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
…jkocbek/darts into feat/document-model-saving-loading
@@ -40,6 +40,54 @@ Furthermore, we define the following types of time series consumed by the models | |||
* **Target series:** the series that we are interested in forecasting. | |||
* **Covariate series:** some other series that we are not interested in forecasting, but that can provide valuable inputs to the forecasting model. | |||
|
|||
## Reproducibility | |||
|
|||
If the user wishes to save a particular model and use it elsewhere or at a later point in time, darts leverages pickle to achieve that. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpicking] You shouldn't talk about the user in third person. Either don't address the reader, or say "you".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Noted and fixed.
@@ -40,6 +40,54 @@ Furthermore, we define the following types of time series consumed by the models | |||
* **Target series:** the series that we are interested in forecasting. | |||
* **Covariate series:** some other series that we are not interested in forecasting, but that can provide valuable inputs to the forecasting model. | |||
|
|||
## Reproducibility | |||
|
|||
If the user wishes to save a particular model and use it elsewhere or at a later point in time, darts leverages pickle to achieve that. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Darts does not only leverage pickle. Torch models rely on saving PyTorch Lightning trainer checkpoints.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lightning trainer checkpoints are now mentioned in the description.
model.save("my_model.pkl") | ||
model_loaded = RegressionModel.load("my_model.pkl") | ||
|
||
from darts.models import NaiveSeasonal |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this second example? I'd say maybe the one above is enough. Also the one below seems to be saving to a .py
file?
|
||
Special class methods for torch models: | ||
|
||
* **save_checkpoint** In addition, we need to use PTL save_checkpoint() to properly save the trainer and model. It is used to be able to save a snapshot of the model mid-training, and then be able to retrieve the model later. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we offer this as a method?
refine load from checkpoint explanation Co-authored-by: Julien Herzen <j.herzen@gmail.com>
refine what gets saved description Co-authored-by: Julien Herzen <j.herzen@gmail.com>
refine general description Co-authored-by: Julien Herzen <j.herzen@gmail.com>
…jkocbek/darts into feat/document-model-saving-loading
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just made a couple of suggestions.
model = NBEATSModel(input_chunk_length=24, | ||
output_chunk_length=12) | ||
|
||
model.fit([series1, series2]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we don't even need this line. We can save a model that's not been fitted yet too.
Update title Co-authored-by: Julien Herzen <j.herzen@gmail.com>
Update saved model file name Co-authored-by: Julien Herzen <j.herzen@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Fixes #.
Summary
Add Reproducibility section in the Forecasting Overview part of the User Guide, describing model saving and loading.
Other Information
Still WIP, as more detail and care is needed for torch models.