Fix/Robuster parameters check when loading weights #1952
Conversation
…ing models) to make it more robust. This check can be skipped (not recommended).
Codecov Report
Patch coverage:

@@ Coverage Diff @@
##           master    #1952      +/-   ##
==========================================
- Coverage   93.87%   93.87%   -0.01%
==========================================
  Files         132      132
  Lines       12677    12688      +11
==========================================
+ Hits        11901    11911      +10
- Misses        776      777       +1

☔ View full report in Codecov by Sentry.
Looks good, thanks @madtoinou 🚀
Had a small comment about first going through all the bad parameters and then including all of them in the error message.
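A minimal sketch of that pattern (the helper name and signature are illustrative, not the actual darts code):

```python
def check_hyperparams(loaded_params: dict, saved_params: dict) -> None:
    # Hypothetical helper: collect every mismatched hyper-parameter
    # first, then raise a single error listing all of them.
    mismatched = [
        (key, loaded_params[key], saved_params.get(key))
        for key in loaded_params
        if key not in saved_params or loaded_params[key] != saved_params[key]
    ]
    if mismatched:
        raise ValueError(
            "Mismatched hyper-parameters (name, current, saved): "
            f"{mismatched}"
        )
```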
Looks great 🚀
Just had some last minor comments, after which we can merge.
)
# param was different at loading model creation
elif self.model_params[param_key] != tfm_save.model_params[param_key]:
    # NOTE: for TFTModel, default is None but converted to `QuantileRegression()`
Is this still relevant?
# NOTE: for TFTModel, default is None but converted to `QuantileRegression()`
Yes, it means that if the user manually uses `QuantileRegression()` as `likelihood` for the loading model while the initial model was created with `likelihood=None`, an error will be raised despite the resulting model being acceptable. It's probably a corner case that will not occur much, but we might want to prevent it?
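To illustrate the corner case (a sketch; only the `QuantileRegression` import is real darts code, the parameter dictionaries are hypothetical):

```python
from darts.utils.likelihood_models import QuantileRegression

# The initial model was created with likelihood=None, which TFTModel
# converts to QuantileRegression() internally; the user then passes
# QuantileRegression() explicitly when creating the loading model.
saved_params = {"likelihood": None}
loaded_params = {"likelihood": QuantileRegression()}

# A strict equality check flags a mismatch here even though both
# configurations yield the same model.
print(loaded_params["likelihood"] == saved_params["likelihood"])  # False
```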
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
Fixes #1946.
Summary
Instead of comparing the `saved_model.model_params` with the hyper-parameters contained in the `.ckpt`, compare them with the `loaded_model.model_params`. Parameters with inconsistent names between the TFM and the PL module are now properly checked (e.g. `n_rnn_layers`/`num_layers` for `RNNModel`). The ignored parameters remain the same, to allow the user to change various parameters during retraining.
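For illustration, a hedged sketch of how the stricter check surfaces to users (the file name and hyper-parameter values are made up):

```python
from darts.models import RNNModel

# Suppose a model trained earlier with n_rnn_layers=2 was saved to
# "model.pt" (with its accompanying checkpoint file).
model = RNNModel(input_chunk_length=24, model="LSTM", n_rnn_layers=3)

# Loading the saved weights into a model created with n_rnn_layers=3
# now raises an informative error about the mismatched parameter, even
# though the PL module stores it under the name num_layers.
model.load_weights("model.pt")
```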
Other Information
Also added a new parameter `skip_checks` in case a user wants to load the weights from the `.ckpt` file without depending on the `.pt` file. If set, the encoders cannot be loaded (potentially preventing direct inference) and weight-loading errors (missing keys, shape mismatch, ...) will be raised by `torch.load()`.
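A usage sketch of the new escape hatch (assuming `skip_checks` is exposed on `load_weights_from_checkpoint` as described; the model name is hypothetical):

```python
from darts.models import TFTModel

# Load weights straight from a .ckpt file, without the companion .pt
# file. Parameter checks are skipped, encoders cannot be restored, and
# any incompatibility (missing keys, shape mismatch, ...) surfaces from
# torch when the state dict is applied.
model = TFTModel(input_chunk_length=24, output_chunk_length=12)
model.load_weights_from_checkpoint(
    model_name="my_model",  # hypothetical model name
    best=False,
    skip_checks=True,
)
```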