Distrax reversion #125
Conversation
Codecov Report
@@            Coverage Diff             @@
##           v0.5_update      #125   +/- ##
============================================
  Coverage       99.23%    99.24%
============================================
  Files              14        14
  Lines            1172      1185     +13
============================================
+ Hits             1163      1176     +13
  Misses              9         9
Flags with carried forward coverage won't be shown.
Wait, did you implement NumPyro and then go back to Distrax? Could you explain what happened?
Yes, we briefly replaced Distrax/TFP with NumPyro. The main issue was that NumPyro's bijectors were giving some odd behaviour: despite a like-for-like bijector replacement, we were getting NaNs during optimisation with NumPyro's bijectors. NumPyro also lacked the ability to exploit structured covariance matrices for faster solves and KL divergences. The former issue remains perplexing and is fundamentally why we moved back to Distrax. The latter could have been solved through several PRs, but given our plans to support more structured covariances, e.g. Kronecker matrices and basis-function approximations, we felt that implementing our own linear operator and corresponding Gaussian distribution was the most flexible option. I'd still love to integrate with NumPyro though: their inference suite and broader ML modules are amazing. I know you have experience with NumPyro, so if you have any thoughts on this then I'd be all ears!
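The NaNs-during-optimisation issue described above often comes down to the numerical stability of a bijector's forward/inverse maps rather than the transform itself. A minimal NumPy sketch (illustrative only; not the actual NumPyro or Distrax implementation) showing how two algebraically equivalent formulations of the inverse softplus, a common positivity transform for GP hyperparameters such as lengthscales, can behave very differently:

```python
import numpy as np

def softplus(x):
    # Stable forward map log(1 + exp(x)), computed via logaddexp.
    return np.logaddexp(0.0, x)

def naive_inv_softplus(y):
    # Textbook inverse log(exp(y) - 1): exp(y) overflows for large y,
    # so the result is non-finite and poisons gradient-based optimisation.
    with np.errstate(over="ignore"):
        return np.log(np.exp(y) - 1.0)

def stable_inv_softplus(y):
    # Algebraically equivalent form y + log(1 - exp(-y)):
    # exp(-y) underflows harmlessly, so the result stays finite.
    return y + np.log1p(-np.exp(-y))

y = 1000.0  # a large positive parameter value
print(np.isfinite(naive_inv_softplus(y)))   # False: overflow
print(np.isfinite(stable_inv_softplus(y)))  # True
```

Whether this particular pattern is what bit the NumPyro bijectors is an open question, but it illustrates how a "like-for-like" swap of transforms can still change numerical behaviour.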
Okay, interesting, good to know. I don't have too much experience with the inner workings of NumPyro for constructing likelihoods etc.; when I get NaNs in my chains, sometimes rerunning helps. I'm more than happy with your approach. I saw dynamax had taken a similar approach to you for state-space models, where they use TFP's JAX distributions and BlackJAX for sampling. Is there much difference between TFP distributions and Distrax? The main benefit to NumPyro is the ability to have a GP as part of a larger probabilistic model. If it's possible to integrate GPJax (or dynamax) into larger probabilistic models built with tfp.jax, that would be great, even if TFP is much harder to use as an end user than NumPyro.
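On the structured-covariance point raised earlier in the thread, a hedged sketch in plain NumPy (hypothetical helper names; not the actual GPJax linear-operator code) of why a dedicated linear-operator abstraction pays off: a generic dense multivariate-normal log-density needs an O(n^3) factorisation, while a diagonal specialisation gets the same answer in O(n).

```python
import numpy as np

def dense_gauss_logpdf(x, mu, cov):
    # Generic multivariate normal log-density: requires an O(n^3)
    # determinant and linear solve against the full covariance.
    n = x.size
    _, logdet = np.linalg.slogdet(cov)
    r = x - mu
    maha = r @ np.linalg.solve(cov, r)
    return -0.5 * (n * np.log(2.0 * np.pi) + logdet + maha)

def diag_gauss_logpdf(x, mu, var):
    # Same density when cov = diag(var): elementwise ops only, O(n).
    r = x - mu
    return -0.5 * (x.size * np.log(2.0 * np.pi)
                   + np.sum(np.log(var))
                   + np.sum(r ** 2 / var))

x = np.array([0.3, -1.2, 0.7])
mu = np.zeros(3)
var = np.array([0.5, 1.0, 2.0])
print(np.allclose(dense_gauss_logpdf(x, mu, np.diag(var)),
                  diag_gauss_logpdf(x, mu, var)))  # True
```

The same idea extends to Kronecker and basis-function structure mentioned above: a linear operator that knows its own structure can expose cheap solves, determinants, and KL divergences without materialising the dense matrix.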
Undo the NumPyro work and revert to Distrax. This ensures more stable parameter transforms.