add rin to all torch models #1969
Conversation
@dennisbader, thanks for implementing this since I didn't have time. I don't know what dataset or training parameters you used to test the performance change; however, RIN isn't guaranteed to improve performance. In the TiDE paper they showed that it didn't always help (Table 8, https://arxiv.org/pdf/2304.08424.pdf), so that could be the case for the Transformer and RNNs too.
Hi @alexcolpitts96, I actually just applied the models to your notebook example comparing TiDE against NHiTS.
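For context, here is a minimal sketch of the idea behind RIN (reversible instance normalization, RevIN): each input window is normalized with its own statistics before entering the model, and the output is de-normalized with the same stored statistics afterwards. The class name, tensor shapes, and API below are illustrative only, not darts' actual internals:

```python
import torch
import torch.nn as nn


class RINorm(nn.Module):
    """Minimal reversible instance normalization (RevIN-style) sketch."""

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # learnable per-feature affine parameters
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, features) -- statistics are computed per instance
        self.mean = x.mean(dim=1, keepdim=True).detach()
        self.std = (x.var(dim=1, keepdim=True, unbiased=False) + self.eps).sqrt().detach()
        x = (x - self.mean) / self.std
        return x * self.weight + self.bias

    def inverse(self, y: torch.Tensor) -> torch.Tensor:
        # undo the affine step, then restore the stored per-instance statistics
        y = (y - self.bias) / (self.weight + self.eps)
        return y * self.std + self.mean
```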
Thank you for adding this!
It looks great, good job @dennisbader 🚀
@dennisbader I did some digging into what it might take to improve performance for Transformers and RNNs using RIN. I agree with @madtoinou that there seems to be some underlying recurrent architectural problem. I was able to get Transformers to perform less badly; however, it was still pretty awful. Section 4.2 of this paper makes a similar observation.
I have found this in my own experience as well: attention only improves performance when applied to some alternate representation (like Seq2Seq context vectors), and even then the improvement was nearly negligible. As for the recurrent problem, I will see if I can find an explanation, but I think pushing out RIN globally without understanding the RNN problem should be fine.

I was also wondering what your thoughts were about moving to a weekly (or at least regular) release schedule? Patch releases (0.25.x) could help push out builds with features and bug fixes sooner.
I was mostly concerned about TCNModel and RNNModel (Transformers performed at least no worse than the vanilla models). After another test, it seems that TCNModel performs okay when predicting with n <= output_chunk_length (see the sketch below), so we can keep support for it. We will ignore user supplied …
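To make the horizon distinction above concrete, here is a hedged sketch against the public darts API; the dataset and hyperparameters are arbitrary, and use_reversible_instance_norm is the parameter introduced by this PR:

```python
from darts.datasets import AirPassengersDataset
from darts.models import TCNModel

series = AirPassengersDataset().load()

model = TCNModel(
    input_chunk_length=24,
    output_chunk_length=12,
    use_reversible_instance_norm=True,  # parameter added in this PR
    n_epochs=5,
)
model.fit(series)

# n <= output_chunk_length: one forward pass per window, where TCNModel + RIN was fine
ok = model.predict(n=12)

# n > output_chunk_length: auto-regressive prediction, where degraded results were observed
degraded = model.predict(n=36)
```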
@alexcolpitts96, we'll try to release more frequently in the future. The next release is planned for next week.
Looking good to me 🚀, can't wait to see the performance gain across all the DL models!
Fixes #1121
Summary
- Added RINorm (reversible instance normalization) to all TorchForecastingModels via the new model creation parameter use_reversible_instance_norm (usage sketch below).
- All models except RNN and Transformer show better performance with RIN; we still have to investigate what's going on for these two models.
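A hedged usage sketch of the new parameter, assuming it is accepted uniformly by all TorchForecastingModel subclasses as the summary states; the model choices, dataset, and hyperparameters are illustrative:

```python
from darts.datasets import AirPassengersDataset
from darts.models import NHiTSModel, TiDEModel

series = AirPassengersDataset().load()

# the flag is a model creation parameter shared by all TorchForecastingModels
for model_cls in (TiDEModel, NHiTSModel):
    model = model_cls(
        input_chunk_length=24,
        output_chunk_length=12,
        use_reversible_instance_norm=True,  # parameter added in this PR
        n_epochs=10,
    )
    model.fit(series)
    print(model_cls.__name__, model.predict(n=12))
```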