Feat/improved training from ckpt #1501
Conversation
…ode, allows user to change the optimizer, scheduler or trainer and export the ckpt of this fine-tuned model into another folder. fine-tuning cannot be chained using this method (original model ckpt must be reloaded)
…er control over the logger, made the function static
…int is likely to be overwritten if the model is trained with default parameters)
… chain-fine tuning
Supercool! Thanks @madtoinou for making the effort! It looks way better than the "raw" version.
Codecov Report: Base 94.06% // Head 94.03% (decreases project coverage by 0.04%).
@@ Coverage Diff @@
## master #1501 +/- ##
==========================================
- Coverage 94.06% 94.03% -0.04%
==========================================
Files 125 125
Lines 11095 11123 +28
==========================================
+ Hits 10437 10459 +22
- Misses 658 664 +6
…/unit8co/darts into feat/improved-training-from-ckpt
That looks quite good but I'm a bit hesitant about introducing a new method. Could you maybe simply improve load_from_checkpoint() and fix the epoch issue in fit()?
def setup_finetuning(
    old_model_name: str,
    new_model_name: str = None,
    additional_epochs: int = 0,
Do we need this parameter, given that fit() already accepts epochs? I would find it cleaner to rely exclusively on fit()'s parameter. If there's a problem with it, could we maybe fix it there (i.e. handle the trainer correctly in fit() to handle epoch)?
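For illustration, a minimal sketch of relying on fit()'s own epochs parameter, as suggested above (the model class, data, and values here are arbitrary assumptions):

```python
from darts.models import NBEATSModel
from darts.utils.timeseries_generation import sine_timeseries

series = sine_timeseries(length=200)

# initial training run
model = NBEATSModel(input_chunk_length=24, output_chunk_length=12, n_epochs=5)
model.fit(series)

# continue training for additional epochs directly through fit(),
# instead of a dedicated additional_epochs argument
model.fit(series, epochs=10)
```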
There is an issue with the fit() parameter (#1495); I think that @rijkvandermeulen is already working on a fix. I will remove the epochs argument from this method and wait for the patch to be merged.
Took your comments into account.
@madtoinou when I try to use your branch and load a model, I get an error. The context is: I trained the model on a machine with a different home folder name, and now I want to load it on a new machine with a new folder name (which I give in as an argument). I tried to trace through, and one thing is suspicious: are you sure that this line here should be the way it currently is? Am I messing something up? Please advise!
For me, this helped:
Don't know if this is too crude, though!
And also, not wanting to be horrible here, but what about this line? Shouldn't it be written a bit differently? I mean, the ( doesn't make sense to me in this case. But again, I might be missing the point here...
Ok, to be more specific:
Ok, I am starting to become annoying. Sorry! Some suggested refinement: as it is now implemented here, if one does not give any new scheduler, the one stored in the checkpoint is kept. But in my use case, I want to finetune my model, so I definitely want to get rid of the scheduler it was loaded with; passing an explicit value does not let me do that. Suggestion: add the ability to explicitly clear the loaded scheduler (e.g. by passing None). What do you think @madtoinou? Any better solutions?
Additional observation: if I used RAdam during training and I want to switch to SGD after reload, I get a KeyError about momentum. Cause: currently the optimizer's param dict is not overwritten but appended to, so whatever I put in load_from_checkpoint, I cannot get rid of the momentum param. This is the same type of issue as the one in my previous comment. Maybe plainly overwriting things by default would be a better idea?
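To make the issue concrete, a small standalone sketch (plain Python, not darts code) of merging versus overwriting the stored optimizer kwargs:

```python
# kwargs stored in the checkpoint from the original training run
ckpt_optimizer_kwargs = {"lr": 1e-3, "momentum": 0.9}
# kwargs passed by the user when reloading with a different optimizer
user_optimizer_kwargs = {"lr": 1e-2}

# appending/merging keeps stale keys (e.g. "momentum") that the new
# optimizer may not accept, which leads to the error described above
merged = {**ckpt_optimizer_kwargs, **user_optimizer_kwargs}
print(merged)  # {'lr': 0.01, 'momentum': 0.9}

# plainly overwriting discards anything not explicitly provided
overwritten = dict(user_optimizer_kwargs)
print(overwritten)  # {'lr': 0.01}
```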
After discussing offline with some other contributors, I decided to refactor the feature: instead of reloading all the attributes of the model present in the checkpoint (which are relevant only when resuming training, not really for retraining/inference), the user will have to instantiate a new model and then load only the weights from the checkpoint (the method takes care of running some sanity checks and initializing the model without having to call fit()). It's almost done; I am currently writing the tests and making sure that I am not overlooking a corner case. I'll try to update the PR in the next few days.
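For illustration, a minimal sketch of this refactored workflow, assuming the load_weights_from_checkpoint name that this PR ends up introducing (the model class, parameters, and data are arbitrary):

```python
from darts.models import NBEATSModel
from darts.utils.timeseries_generation import sine_timeseries

series = sine_timeseries(length=200)

# instantiate a new model; optimizer, scheduler and trainer settings can differ
# from the ones used for the original training run
model = NBEATSModel(
    input_chunk_length=24,
    output_chunk_length=12,
    n_epochs=5,
    model_name="finetuned_model",
)

# load only the network weights from the pretrained model's checkpoint;
# the method runs sanity checks and initializes the model for fine-tuning/inference
model.load_weights_from_checkpoint(model_name="pretrained_model", best=False)

# fine-tune with the new training configuration
model.fit(series)
```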
Cool! 🎉 I tested the branch a lot, and I think your approach is the right way to go; too many combinations of crazy stuff can happen. Looking forward to the new version! I will test it as soon as I can!
One remark that may inform your design decisions: there would be some benefit in some "setup for fit"-like functionality over and beyond just fit itself.
…tes of an existing model, rather load the weights into a new model (but not the other attributes such as the optimizer, trainer, ...
…will retrieve and copy the original .ckpt file to avoid unexpected behaviors
Since it is also related to the original purpose of this PR, I also included a fix for #1561. Now, if a model is saved directly after loading it from a checkpoint, the original ".ckpt" checkpoint is duplicated so that this model can easily be loaded from the path given to save(). I also added a bit of logic in the load() call so that a warning is raised if no weights can be loaded.
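A short sketch of the behavior described above (method names from the darts API; the model class and paths are assumptions):

```python
from darts.models import NBEATSModel

# load a model trained earlier from its automatic checkpoints
model = NBEATSModel.load_from_checkpoint(model_name="pretrained_model", best=True)

# saving right away also copies the original ".ckpt" file next to the ".pt" file,
# so the weights remain available from the saved path alone
model.save("exported_model.pt")

# the saved model can be reloaded later without touching the original checkpoint folder
loaded = NBEATSModel.load("exported_model.pt")
```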
Thanks again for all the updates and for fixing the other issues 👍
Had some last minor suggestions, and a fix for loading the models with the correct dtype, so that identical weights are loaded and the original and loaded models produce identical forecasts.
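A hedged sketch of the kind of dtype handling referred to here (generic PyTorch, not the PR's exact code):

```python
import torch

def load_state_dict_matching_dtype(module: torch.nn.Module, state_dict: dict) -> torch.nn.Module:
    # cast the module to the checkpoint's floating-point dtype before loading the
    # weights, so the reloaded model produces forecasts identical to the original
    ckpt_dtype = next(iter(state_dict.values())).dtype
    module = module.to(ckpt_dtype)
    module.load_state_dict(state_dict)
    return module
```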
ckpt_hyper_params = ckpt["hyper_parameters"]

# verify that the arguments passed to the constructor match those of the checkpoint
for param_key, param_value in self.model_params.items():
True, but on the other hand this requires all TorchForecastingModels and their corresponding PLForecastingModules to share the same model parameter names, which is not the case as you mention (and might be difficult to enforce in some cases).
So the torch error can still be raised, or maybe I'm missing something :)
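For context, an illustrative standalone sketch (not the PR's exact code) of the kind of compatibility check being discussed, which reports mismatching constructor arguments instead of letting torch fail later with an opaque error:

```python
def check_ckpt_compatibility(model_params: dict, ckpt_hyper_params: dict) -> None:
    """Raise if constructor arguments differ from the checkpoint's hyper-parameters."""
    incompatible = [
        (key, ckpt_hyper_params[key], value)
        for key, value in model_params.items()
        if key in ckpt_hyper_params and ckpt_hyper_params[key] != value
    ]
    if incompatible:
        raise ValueError(
            f"Parameters (name, checkpoint value, new value) differ from the checkpoint: {incompatible}"
        )
```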
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
…rning will be raised in the load() call if no weights can be loaded
…utes and seconds, updated doc accordingly
Really great work @madtoinou, thanks for that!
Just one last change (that I missed earlier, sorry) and then it's ready to be merged!
Parameters
----------
path
Sorry, I just saw this now. I think we should use the path that the user gave when manually saving the model, i.e. model.save("my_model.pt"), rather than the .ckpt path.
Then we just replace ".pt" with ".pt.ckpt" and get the checkpoint from there. Check here that the ckpt exists, similar to how we do it now in TorchForecastingModel.load().
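A small sketch of the suggested path handling (the helper name and error message are assumptions):

```python
import os

def resolve_ptl_ckpt_path(path: str) -> str:
    # derive the PyTorch Lightning checkpoint path from the path passed to model.save(),
    # e.g. "my_model.pt" -> "my_model.pt.ckpt", and verify that it exists
    path_ptl_ckpt = path + ".ckpt"
    if not os.path.exists(path_ptl_ckpt):
        raise FileNotFoundError(f"Could not find the PyTorch Lightning checkpoint {path_ptl_ckpt}.")
    return path_ptl_ckpt
```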
Good catch, it makes the interface much more consistent and intuitive. load_weights() now expects the .pt path, and the .ckpt suffix is added afterward, inside the function.
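A short usage sketch of the resulting interface (assuming the load_weights method added in this PR; the model class and path are arbitrary):

```python
from darts.models import NBEATSModel

# instantiate a new model with the desired (possibly new) training configuration
model = NBEATSModel(input_chunk_length=24, output_chunk_length=12)

# pass the same ".pt" path that was given to save(); the ".ckpt" suffix is handled internally
model.load_weights("my_model.pt")
```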
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
Thanks for this great PR, everything looks good now 🚀
Feel free to merge!
* feat: new function fit_from_checkpoint that loads one chkpt from the model, allows the user to change the optimizer, scheduler or trainer and export the ckpt of this fine-tuned model into another folder. Fine-tuning cannot be chained using this method (the original model ckpt must be reloaded)
* fix: improved the model saving to allow chaining of fine-tuning, better control over the logger, made the function static
* feat: allow saving the checkpoint in the same folder (the loaded checkpoint is likely to be overwritten if the model is trained with default parameters)
* fix: ordered arguments in a more intuitive way
* fix: saving the model after updating all the parameters to facilitate chained fine-tuning
* feat: support for load_from_checkpoint kwargs, support for force_reset argument
* feat: adding test for setup_finetuning
* fix: fused the setup_finetuning and load_from_checkpoint methods, added docstring, updated tests
* fix: changed the API/approach; instead of trying to overwrite attributes of an existing model, rather load the weights into a new model (but not the other attributes such as the optimizer, trainer, ...)
* fix: conversion of hyper-parameters to list when checking compatibility between checkpoint and instantiated model
* fix: skip the None attribute during the hp check
* fix: removed unnecessary attribute initialization
* feat: pl_forecasting_module also saves the train_sample in the checkpoints
* fix: saving only the shape instead of the sample itself
* fix: restore the self.train_sample in TorchForecastingModel
* fix: update fit_called attribute to enable inference without retraining
* fix: the mock train_sample must be converted to a tuple
* fix: tweaked model parameters to improve convergence
* fix: increased number of epochs to improve convergence/test stability
* fix: addressing review comments; added load_weights method and corresponding tests, updated documentation
* fix: changed default checkpoint path name for compatibility with Windows OS
* feat: raise error if the checkpoint being loaded does not contain the train_sample_shape entry, to make the break more transparent to users
* fix: saving the model manually directly after loading it from a checkpoint will retrieve and copy the original .ckpt file to avoid unexpected behaviors
* fix: use random_state to fix randomness in tests
* fix: restore newlines
* fix: casting dtype of the PLModule before loading the weights
* doc: model_name docstring and code were not consistent
* doc: improve phrasing
* Apply suggestions from code review
  Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
* fix: removed the warning in saving about the trainer/ckpt not being found; a warning will be raised in the load() call if no weights can be loaded
* fix: unified the filename convention using '_' to separate hours, minutes and seconds, updated doc accordingly
* fix: removed typo
* Update darts/models/forecasting/torch_forecasting_model.py
  Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
* fix: more consistent use of the path argument during save and load

---------

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
Fixes #1109, #1090, #1495 and #1471.
Summary
Implement setup_finetuning, wrapping the loading of a checkpoint and the setup of the various elements related to the training of a model (trainer, optimizer, lr_scheduler), and returning a model instance which can then be trained directly using the fit method. For scenarios where the training of a model is resumed from a checkpoint after an error/crash, load_from_checkpoint remains the most efficient approach.
Other Information
The number of additional training epochs can be provided either directly, using the additional_epochs argument, or through the trainer_params dict (if both are provided, a sanity check is performed).
The fine-tuned model checkpoints can either be saved in the same folder as the "original" one or in a different folder (recommended, to avoid overwriting the loaded checkpoint).
The fine-tuning can be chained (the _model.ckpt.tar file is created when calling setup_finetuning, after updating the trainer parameters) to give granularity to the user and avoid unexpected behaviors.
A big thank you to @solalatus for providing a gist with all the attributes to update.
I tried to find methods that would allow directly updating the model.model_params attribute, but it seems to be handled by PyTorch Lightning.
Example of use:
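A purely hypothetical sketch of how the originally proposed setup_finetuning might have been called (parameter names taken from the draft signature quoted earlier; the model class, series, and values are assumptions, and the final version of this PR replaced this API with load_weights / load_weights_from_checkpoint):

```python
from darts.models import NBEATSModel
from darts.utils.timeseries_generation import sine_timeseries

series = sine_timeseries(length=200)

# hypothetical: reload the checkpoint of a previously trained model and prepare it
# for fine-tuning, exporting the fine-tuned checkpoints into a new folder
model = NBEATSModel.setup_finetuning(
    old_model_name="pretrained_model",
    new_model_name="finetuned_model",
    additional_epochs=5,
)

# the returned model instance can then be trained directly
model.fit(series)
```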