
Fix/Robuster parameters check when loading weights #1952

Merged
12 commits merged into master on Aug 15, 2023

Conversation

madtoinou (Collaborator)

Fixes #1946.

Summary

Instead of comparing saved_model.model_params with the hyper-parameters contained in the .ckpt, they are now compared with loaded_model.model_params. Parameters whose names differ between the TorchForecastingModel and the underlying PL module (e.g. n_rnn_layers/num_layers for RNNModel) are now properly checked.

The ignored parameters remain the same, so the user can still change various parameters for retraining.
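The described check can be sketched roughly as follows (illustrative names only, not darts' actual code; the set of ignored parameters is a hypothetical example):

```python
# Illustrative sketch of the check described above: compare the loading
# model's constructor parameters against the saved model's, skipping
# parameters the user is allowed to change for retraining.

# Hypothetical set of ignored (user-changeable) parameters.
IGNORED_PARAMS = {"batch_size", "n_epochs", "optimizer_kwargs", "random_state"}

def check_params_consistency(loaded_params: dict, saved_params: dict) -> None:
    """Raise if a non-ignored parameter differs between the two models."""
    for key, saved_value in saved_params.items():
        if key in IGNORED_PARAMS:
            continue
        if key not in loaded_params or loaded_params[key] != saved_value:
            raise ValueError(
                f"Parameter `{key}` differs between the loading model "
                f"({loaded_params.get(key)!r}) and the saved model "
                f"({saved_value!r}); they must match to load the weights."
            )
```

Comparing the two model_params dicts directly (rather than the raw .ckpt hyper-parameters) is what makes renamed parameters such as n_rnn_layers/num_layers line up.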

Other Information

Also added a new parameter, skip_checks, in case a user wants to load the weights from the .ckpt file without depending on the .pt file. If set, the encoders cannot be loaded (potentially preventing direct inference), and weight-loading errors (missing keys, shape mismatches, ...) will be raised by torch.load().
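A minimal sketch of what skip_checks changes, using plain dicts as stand-ins for the model, the .ckpt weights, and the .pt metadata (NOT darts' actual implementation; _check_params is a hypothetical helper):

```python
# Sketch of the two loading paths described above, under the assumption that
# the .pt metadata carries both the saved hyper-parameters and the encoders.

def _check_params(loaded_params, saved_params):
    # Hypothetical helper: darts-level hyper-parameter consistency check.
    if loaded_params != saved_params:
        raise ValueError(
            f"hyper-parameter mismatch: {loaded_params} != {saved_params}"
        )

def load_weights(model, ckpt_state_dict, pt_metadata=None, skip_checks=False):
    if not skip_checks:
        if pt_metadata is None:
            raise FileNotFoundError(
                "the .pt file is required unless skip_checks=True"
            )
        _check_params(model["params"], pt_metadata["params"])
        # Encoders can only be restored from the .pt metadata; with
        # skip_checks=True they stay unset, potentially preventing
        # direct inference after loading.
        model["encoders"] = pt_metadata["encoders"]
    # With skip_checks=True, any inconsistency (missing keys, shape
    # mismatch, ...) would instead surface from the underlying torch
    # loading machinery rather than from a darts-level check.
    model["weights"] = ckpt_state_dict
    return model
```

The trade-off: skip_checks=True removes the dependency on the .pt file, but gives up the friendlier darts-level error messages and the restored encoders.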

…ing models) to make it more robust. this check can be skipped (not recommended).
@madtoinou madtoinou changed the title Fix/More robust parameters check when loading weights Fix/Robuster parameters check when loading weights Aug 10, 2023
codecov-commenter commented Aug 10, 2023

Codecov Report

Patch coverage: 94.44% and project coverage change: -0.01% ⚠️

Comparison is base (b3463ea) 93.87% compared to head (b748689) 93.87%.


Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1952      +/-   ##
==========================================
- Coverage   93.87%   93.87%   -0.01%     
==========================================
  Files         132      132              
  Lines       12677    12688      +11     
==========================================
+ Hits        11901    11911      +10     
- Misses        776      777       +1     
Files Changed Coverage Δ
...arts/models/forecasting/torch_forecasting_model.py 90.85% <93.33%> (+0.08%) ⬆️
darts/utils/likelihood_models.py 95.60% <100.00%> (+0.04%) ⬆️

... and 5 files with indirect coverage changes


@dennisbader (Collaborator) left a comment:

Looks good thanks @madtoinou 🚀
Had a small comment for first going through all bad params and then giving all of them in the error message.
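The pattern the review suggests (collect every mismatching parameter first, then report them all in a single error) can be sketched as (illustrative code, not the actual diff):

```python
# Accumulate all mismatching parameters before raising, so the user sees
# every problem in one error message instead of fixing them one at a time.

def find_incorrect_params(loaded_params: dict, saved_params: dict) -> list:
    return [
        key
        for key, saved_value in saved_params.items()
        if key not in loaded_params or loaded_params[key] != saved_value
    ]

def check_params(loaded_params: dict, saved_params: dict) -> None:
    incorrect = find_incorrect_params(loaded_params, saved_params)
    if incorrect:
        raise ValueError(
            "The hyper-parameters of the model and the loaded checkpoint "
            f"should be identical; incorrect parameters: {incorrect}"
        )
```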

darts/models/forecasting/torch_forecasting_model.py (comment outdated; resolved)
@dennisbader (Collaborator) left a comment:

Looks great 🚀
Just had some last minor comments, after which we can merge

```python
)
# param was different at loading model creation
elif self.model_params[param_key] != tfm_save.model_params[param_key]:
    # NOTE: for TFTModel, default is None but converted to `QuantileRegression()`
```
Collaborator:

is this still relevant?

Suggested change: remove the line `# NOTE: for TFTModel, default is None but converted to QuantileRegression()`.

madtoinou (Collaborator, Author):

Yes. It means that if the user manually uses QuantileRegression() as the likelihood of the loading model while the initial model was created with likelihood=None, an error will be raised even though the resulting model is acceptable. It's probably a corner case that won't occur often, but we might want to prevent it?
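The corner case can be illustrated with a small stand-in (hypothetical code; the class name mirrors darts' QuantileRegression but this is not its actual implementation): normalizing both sides before comparing treats None and the explicit default as the same configuration.

```python
# Stand-in for the likelihood class; instances of the same class are
# considered equal for the purpose of this sketch.
class QuantileRegression:
    def __eq__(self, other):
        return isinstance(other, QuantileRegression)

# Assumed internal behavior: likelihood=None is converted to this default.
DEFAULT_LIKELIHOOD = QuantileRegression()

def likelihoods_match(saved, loaded):
    # Normalize None to the default on both sides, so a model saved with
    # likelihood=None matches a loading model created with an explicit
    # QuantileRegression() instead of spuriously failing the check.
    saved = DEFAULT_LIKELIHOOD if saved is None else saved
    loaded = DEFAULT_LIKELIHOOD if loaded is None else loaded
    return saved == loaded
```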

darts/utils/likelihood_models.py (resolved)
madtoinou and others added 3 commits August 14, 2023 11:07
@dennisbader dennisbader merged commit b69b8ca into master Aug 15, 2023
9 checks passed
@dennisbader dennisbader deleted the fix/load_tft_from_ckpt branch August 15, 2023 07:48
Development

Successfully merging this pull request may close these issues.

[BUG] TFT model load failed from saved model.
3 participants