Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distrax reversion #125

Merged
merged 3 commits into from
Oct 23, 2022
Merged

Distrax reversion #125

merged 3 commits into from
Oct 23, 2022

Conversation

thomaspinder
Copy link
Collaborator

Undo the NumPyro work to revert back to Distrax. This is to ensure more stable parameter transforms.

Pull request type

  • Bugfix
  • Feature
  • Code style update (formatting, renaming)
  • [ x ] Refactoring (no functional changes, no api changes)
  • Build related changes
  • Documentation content changes
  • Other (please describe):

@codecov
Copy link

codecov bot commented Oct 23, 2022

Codecov Report

Merging #125 (269f670) into v0.5_update (e6f7f8a) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@             Coverage Diff              @@
##           v0.5_update     #125   +/-   ##
============================================
  Coverage        99.23%   99.24%           
============================================
  Files               14       14           
  Lines             1172     1185   +13     
============================================
+ Hits              1163     1176   +13     
  Misses               9        9           
Flag Coverage Δ
unittests 99.24% <100.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
gpjax/config.py 100.00% <100.00%> (ø)
gpjax/gps.py 100.00% <100.00%> (ø)
gpjax/likelihoods.py 100.00% <100.00%> (ø)
gpjax/parameters.py 95.65% <100.00%> (ø)
gpjax/variational_families.py 100.00% <100.00%> (ø)
gpjax/variational_inference.py 97.59% <100.00%> (ø)

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@thomaspinder thomaspinder merged commit f193bc8 into v0.5_update Oct 23, 2022
@thomaspinder thomaspinder deleted the distrax_reversion branch October 23, 2022 20:53
@theorashid
Copy link
Contributor

Wait, did you implement numpyro and then go back to distrax? Could you explain what happened?

@thomaspinder
Copy link
Collaborator Author

Yes - we replaced Distrax/TFP with NumPyro briefly. The main issue was NumPyro's bijectors were giving some odd behaviour. Despite a like-for-like bijector replacement, we were getting nans during optimisation with NumPyro's bijectors. NumPyro also lacked the ability to exploit structured covariance matrices for faster solves/kl-divergences.

The former issue is perplexing and fundamentally we moved back to Distrax. The latter could have been solved through several PRs but we felt that given our plans to support more structured covariances e.g., kronecker matrices and basis function approximations, implementing our own linear operator and corresponding Gaussian distribution was the most flexible option.

I'd still love to integrate with NumPyro though - their inference suite and broader ML modules is amazing. I know you've experience with NumPyro, so if you have any thoughts on this then I'd be all ears!

@theorashid
Copy link
Contributor

Okay interesting. Good to know. I don't have too much experience with the inner working of numpyro for constructing likelihoods etc. When I get nans in my chains, sometimes running numpyro.enable_x64() does the job, but I don't know your issues exactly.

I'm more than happy with your approach I saw dynamax had taken a similar approach to you for state-space model, where they use tfp jax distrivutions and use blackjax for sampling. Is there much difference between tfp distributions and distrax?

The main benefit to numpyro is the ability to have a GP as part of a larger probabilistic model. If it's possible to integrate GPJax (or dynamax) into larger probabilistic models built with tfp.jax, that would be great – even if tfp is so much harder to use as an end user than numpyro.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants