Stabilize covariance learning with FillScaleTriL and update config behaviour #163
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master     #163      +/-   ##
==========================================
- Coverage   96.89%   96.85%   -0.04%
==========================================
  Files          15       15
  Lines        1385     1400      +15
==========================================
+ Hits         1342     1356      +14
- Misses         43       44       +1
Thanks, @patel-zeel! This PR looks great. Will start my review soon.
Thanks @patel-zeel, this is an excellent PR. Just a few minor things (please see my comments), and let's see whether O1 in #127 might be resolved by removing the default jitter and defining the jitter inside objects/functions, akin to gps.py / variational_families.py. It would be nice to get O1 done here if we can! :)
Thank you for the review, @daniel-dodd. I have resolved O1 in #127 by moving the jitter inside the function. I have also addressed the other comments.
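To illustrate the shape of that change, here is a minimal sketch, assuming an illustrative helper name (`stable_cholesky` is hypothetical, not an actual GPJax function): the jitter becomes a local argument of the function that needs it instead of a value read from the global config at import time.

```python
import jax.numpy as jnp

# Hypothetical helper (not the actual GPJax implementation): the jitter is a local
# default of the function rather than a value pulled from the global config.
def stable_cholesky(K: jnp.ndarray, jitter: float = 1e-6) -> jnp.ndarray:
    """Cholesky factorisation of K with `jitter` added to the diagonal for stability."""
    return jnp.linalg.cholesky(K + jitter * jnp.eye(K.shape[0]))
```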
@daniel-dodd Since we have addressed O1 in #127, we may need to ensure it does not get violated in the future. I think it is hard to detect such things via unit testing (due to the shared global scope among all tests). However, I have added a relevant test (…).
Hi @patel-zeel, thank you for addressing my comments.
In agreement with this. However, the latest test workflow failed on …
Yes, that looks like a way to go :)
@patel-zeel Nice! Running the test workflow now; will merge as soon as the tests pass. :)
Pull request type
Please check the type of change your PR introduces:
What is the current vs. new behavior?
Issue Number: #127
As discussed in #127, the following is the summary of changes:
- Replace `tfb.FillTriangular` with `tfb.FillScaleTriL` in `config.py` (a minimal contrast of the two bijectors is sketched below).
- Update `examples/uncollapsed_vi.pct.py` to illustrate the use of a custom transformation for lower-triangular Cholesky parameters.
- Replace `get_defaults()` with `get_global_config()`. Add `reset_global_config()` and `get_default_config()` methods.
- Add a "Conversion between `.ipynb` and `.py`" section in README.md.
"Custom transformations" section in `examples/uncollapsed_vi.pct.py`

I am showing how to use a custom `triangular_transform` in `gpjax` (the method also generalizes to any transform). I have added a point that the `Square` bijector may lead to faster convergence but can be unstable compared to the `Softplus` bijector.
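As a rough sketch of what such a custom transformation could look like (the `Square`/`diag_shift` settings and the commented registration path are assumptions based on this PR's description, not a verbatim excerpt from the notebook):

```python
from tensorflow_probability.substrates import jax as tfp

tfb = tfp.bijectors

# A custom lower-triangular transform: Square on the diagonal instead of Softplus.
# Square can converge faster but is less numerically stable.
custom_triangular_transform = tfb.FillScaleTriL(
    diag_bijector=tfb.Square(), diag_shift=None
)

# Hypothetical registration step (attribute path assumed, not verified):
# config = gpx.config.get_global_config()
# config.transformations.triangular_transform = custom_triangular_transform
```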
Config in GPJAX

How is it done currently?

The `get_defaults` method returns the global config. If the config is unavailable, it creates one.

What is the new behavior?

- `get_global_config` returns the global config (see the usage sketch below). If the config is unavailable, it creates one. If the JAX precision changes, it makes the appropriate changes to the current global config, e.g. updates `FillScaleTriL`.
- `get_default_config` returns the default config and does not update any global config.
- `reset_global_config` resets the global config to the default config.
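A hedged usage sketch of the three functions, assuming they are exposed under `gpjax.config` (the exact import path may differ):

```python
import jax
import gpjax as gpx

# Fetch (and lazily create) the global config.
config = gpx.config.get_global_config()

# After a JAX precision change, fetching again refreshes precision-dependent
# entries such as the FillScaleTriL bijector.
jax.config.update("jax_enable_x64", True)
config = gpx.config.get_global_config()

# A pristine default config, independent of the global state.
defaults = gpx.config.get_default_config()

# Reset the global config back to the defaults.
gpx.config.reset_global_config()
```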
Conversion between `.ipynb` and `.py`

The following quick commands are introduced in README.md to convert between `.ipynb` and `.py`.
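The exact commands live in README.md; as a hedged illustration only (not the README commands), the same conversion can be driven from Python with jupytext, which is commonly used for such percent-format `.py` files:

```python
import jupytext

# .py (jupytext percent format) -> .ipynb
nb = jupytext.read("examples/uncollapsed_vi.pct.py")
jupytext.write(nb, "examples/uncollapsed_vi.ipynb")

# .ipynb -> .py (percent format)
nb = jupytext.read("examples/uncollapsed_vi.ipynb")
jupytext.write(nb, "examples/uncollapsed_vi.pct.py", fmt="py:percent")
```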
What is not done?

I have not ensured yet that `get_global_config` does not get invoked at `import gpjax as gpx` (O1 in #127). Execution of `GPJax/gpjax/__init__.py` (line 16 in c50d34d) invokes `get_global_config`. Any suggestions on how to solve this?
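One possible guard, sketched under assumptions (the cached attribute name `__config` is a guess, not a documented GPJax internal): run the check in a fresh interpreter so the shared global scope of the other tests cannot interfere.

```python
import subprocess
import sys

def test_import_has_no_config_side_effect():
    # A fresh interpreter avoids the shared global scope of the other tests.
    code = (
        "import gpjax, gpjax.config as cfg; "
        "assert getattr(cfg, '__config', None) is None, "
        "'importing gpjax should not create the global config'"
    )
    subprocess.run([sys.executable, "-c", code], check=True)
```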