Feat/improved training from ckpt #1501
Changes from 29 commits
If we hard-set strict=True, we could let PyTorch handle any discrepancies later on when calling
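A minimal sketch of the behavior being discussed: with strict=True (PyTorch's default), load_state_dict raises a RuntimeError on any key mismatch. The two toy modules below are hypothetical stand-ins for the Darts models; they only serve to show how uninformative the raw torch error can be.

```python
import torch

# Two modules whose parameters have different names:
# Linear exposes "weight"/"bias", Sequential prefixes them as "0.weight"/"0.bias".
src = torch.nn.Linear(4, 2)
dst = torch.nn.Sequential(torch.nn.Linear(4, 2))

try:
    # strict=True makes PyTorch raise on missing/unexpected keys
    dst.load_state_dict(src.state_dict(), strict=True)
except RuntimeError as e:
    # The message lists "Missing key(s)" / "Unexpected key(s)" but gives no
    # hint that the model *definition* (rather than the checkpoint) is at fault.
    print(type(e).__name__)
```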
This is what I did initially, but the error messages are not informative at all, making it quite difficult for the user to realize that the problem comes from the definition of the model into which the weights are loaded.
True, but on the other hand this requires all TorchForecastingModels and their corresponding PLForecastingModules to share the same model parameter names, which is not the case as you mention (and might be difficult to enforce in some cases).
So the torch error can still be raised, or maybe I'm missing something :)
Indeed, the torch error can still be raised if the discrepancy is in one of the parameters that do not have the same name in these two objects.
We could eventually have a dict in each model that tries to map the parameter names, so that this sanity check can be run thoroughly? Or try to catch the torch error if load_state_dict fails and raise a meaningful message to the user, indicating that the weights mismatch can be caused by invalid parameters (or by a change of convention in torch...)?
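The catch-and-rewrap idea could look like the following sketch. The helper name load_weights_checked and the exact error wording are assumptions for illustration; only the load_state_dict call and its RuntimeError are actual PyTorch behavior.

```python
import torch

def load_weights_checked(module: torch.nn.Module, state_dict: dict) -> None:
    """Hypothetical helper: let torch detect the mismatch,
    then surface an actionable message to the user."""
    try:
        module.load_state_dict(state_dict, strict=True)
    except RuntimeError as e:
        # Re-raise with a hint that the model definition is the likely culprit,
        # while preserving the original torch error for debugging.
        raise ValueError(
            "Loading weights failed. This is often caused by creating the model "
            "with parameters that differ from those of the checkpointed model "
            "(or by a change of naming convention in torch). "
            f"Original torch error: {e}"
        ) from e
```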
Sorry, I just saw this now. I think we should use the path that the user gave when manually saving the model, i.e. model.save("my_model.pt"), rather than the .ckpt path. Then we just replace ".pt" with ".pt.ckpt" and get the checkpoint from there. Check here that the ckpt exists, similar to how we do it now in TorchForecastingModel.load().
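The proposed path convention can be sketched as below. The function name resolve_checkpoint_path is hypothetical; it only illustrates appending the ".ckpt" suffix to the user-supplied ".pt" path and verifying that the checkpoint exists.

```python
import os

def resolve_checkpoint_path(save_path: str) -> str:
    """Hypothetical sketch: the user passes the ".pt" path used in
    model.save(), and the ".ckpt" suffix is appended internally."""
    ckpt_path = save_path + ".ckpt"  # "my_model.pt" -> "my_model.pt.ckpt"
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(
            f"Could not find checkpoint {ckpt_path}; make sure the model "
            f"was previously saved to {save_path}."
        )
    return ckpt_path
```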
Good catch, it makes the interface much more consistent and intuitive. load_weights() now expects the .pt path, and the .ckpt suffix is added afterward, inside the function.