Add Jax Inference Methods #503

Open
BradyPlanden opened this issue Sep 16, 2024 · 0 comments
Labels
enhancement (New feature or request)

BradyPlanden commented Sep 16, 2024

Feature description

#481 adds the Jaxified IDAKLU solver as an experimental implementation, with automatic differentiation applied to the cost/likelihood functions. This issue aims to extend that functionality with JAX-based inference methods, such as:

  • NumPyro for MCMC sampling
  • Optax for frequentist/deterministic inference methods
  • GPJax for Gaussian Processes
  • BlackJax for sampling

Motivation

JAX provides just-in-time compilation (via XLA) that can lower parameter-optimisation code to CPU, GPU and TPU backends. This could improve the performance of PyBOP's methods and remove the need to define gradients of the cost/likelihood functions manually.
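A minimal sketch of the gradient point, assuming a hypothetical one-parameter exponential decay model in place of a PyBOP simulation: `manual_grad` is the kind of hand-derived gradient a cost/likelihood would otherwise need, while `jax.grad` produces the same value automatically and `jax.jit` compiles it for whichever backend JAX is running on. All names here are illustrative only.

```python
import jax
import jax.numpy as jnp

# Hypothetical observations from y(t) = exp(-k t) with true k = 0.8.
t = jnp.linspace(0.0, 5.0, 100)
y_obs = jnp.exp(-0.8 * t)


def cost(k):
    """Sum-of-squares cost for the toy exponential decay model."""
    residual = jnp.exp(-k * t) - y_obs
    return jnp.sum(residual**2)


def manual_grad(k):
    """Hand-derived dC/dk = sum(2 * r * (-t * exp(-k t)))."""
    residual = jnp.exp(-k * t) - y_obs
    return jnp.sum(2.0 * residual * (-t * jnp.exp(-k * t)))


# jax.grad differentiates cost automatically; jax.jit compiles the result
# with XLA for the default device (CPU, GPU or TPU).
auto_grad = jax.jit(jax.grad(cost))

k = 0.5
print(auto_grad(k), manual_grad(k))  # the two values should agree closely
```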

Possible implementation

The design needs to be outlined and discussed first, so that these methods integrate cleanly with PyBOP's existing architecture.

Additional context

No response

BradyPlanden added the enhancement (New feature or request) label Sep 16, 2024
Projects
Status: Todo