GPJax aims to provide a low-level interface to Gaussian process (GP) models in JAX, structured to give researchers maximum flexibility in extending the code to suit their own needs. The idea is that the code should be as close as possible to the maths we write on paper when working with GP models.
GPJax was founded by Thomas Pinder. Today, the maintenance of GPJax is undertaken by Thomas Pinder and Daniel Dodd.
We would be delighted to receive contributions from interested individuals and groups. To learn how you can get involved, please read our guide for contributing. If you have any questions, we encourage you to open an issue. For broader conversations, such as best GP fitting practices or questions about the mathematics of GPs, we invite you to open a discussion.
Feel free to join our Slack Channel, where we can discuss the development of GPJax and broader support for Gaussian process modelling.
The examples are stored in the examples directory in the double percent (py:percent) format. See the jupytext using-cli documentation for more information.
- To convert example.py to example.ipynb, run:
jupytext --to notebook example.py
- To convert example.ipynb to example.py, run:
jupytext --to py:percent example.ipynb
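For orientation, a file in the py:percent format is an ordinary Python script in which cell boundaries are marked by `# %%` comments, and markdown cells are written as comments under a `# %% [markdown]` marker. The sketch below is a hypothetical example file, not one of the files in the examples directory.

```python
# %% [markdown]
# # A toy example
# Markdown cells are plain comments under a `# %% [markdown]` marker.

# %%
# A bare `# %%` marker starts a code cell.
import math

xs = [i / 10 for i in range(5)]

# %%
# Each further marker starts a new cell when converted to a notebook.
ys = [10 * math.sin(x) for x in xs]
print(ys)
```

Because the markers are comments, the file still runs unchanged as a normal script.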
Let us import some dependencies and simulate a toy dataset.
import gpjax as gpx
from jax import grad, jit
import jax.numpy as jnp
import jax.random as jr
import jaxkern as jk
import optax as ox
key = jr.PRNGKey(123)
f = lambda x: 10 * jnp.sin(x)
n = 50
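# Split the key so that the inputs and the noise use independent random draws.
key, x_key, noise_key = jr.split(key, 3)
# Sort along axis 0; the default axis=-1 is a no-op for an array of shape (n, 1).
x = jnp.sort(jr.uniform(key=x_key, minval=-3.0, maxval=3.0, shape=(n, 1)), axis=0)
y = f(x) + jr.normal(noise_key, shape=(n, 1))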
D = gpx.Dataset(X=x, y=y)
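The function of interest here, f(x) = 10 sin(x), is sinusoidal, but our observations of it are perturbed by Gaussian noise with unit variance.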
We begin by defining a zero-mean Gaussian process prior with a radial basis function kernel and assume the likelihood to be Gaussian.
prior = gpx.Prior(kernel=jk.RBF())
likelihood = gpx.Gaussian(num_datapoints=n)
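To make the kernel choice concrete, the following is a minimal plain-Python sketch of the standard radial basis function (squared exponential) kernel. It does not call jaxkern; the function name `rbf_kernel` and the parameters `lengthscale` and `variance` are illustrative assumptions, and `jk.RBF` may name or initialise its parameters differently.

```python
import math


def rbf_kernel(x, y, lengthscale=1.0, variance=1.0):
    """Standard RBF kernel: variance * exp(-||x - y||^2 / (2 * lengthscale^2)).

    `x` and `y` are sequences of floats of equal length. The parameter names
    are assumptions made for this sketch, not the jaxkern API.
    """
    squared_distance = sum((a - b) ** 2 for a, b in zip(x, y))
    return variance * math.exp(-0.5 * squared_distance / lengthscale**2)


# Identical inputs give the kernel variance; the value decays with distance.
k_same = rbf_kernel([0.0], [0.0])
k_near = rbf_kernel([0.0], [1.0])
k_far = rbf_kernel([0.0], [3.0])
```

A shorter lengthscale makes the kernel decay faster, which corresponds to a prior over more rapidly varying functions.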
As we would write on paper, the posterior is constructed as the product of our prior and our likelihood.
posterior = prior * likelihood
Equipped with the posterior, we seek to learn the model's hyperparameters through gradient-based optimisation of the marginal log-likelihood. We do this below, using JAX's just-in-time (JIT) compilation to accelerate training.
mll = jit(posterior.marginal_log_likelihood(D, negative=True))
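For reference, under a zero-mean GP prior and a Gaussian likelihood the marginal log-likelihood has the standard closed form below. The notation is an assumption made here rather than taken from the GPJax source: K_θ is the kernel matrix evaluated at the n training inputs X with hyperparameters θ, and σ² is the observation noise variance.

```latex
\log p(\mathbf{y} \mid X, \theta)
  = -\tfrac{1}{2}\, \mathbf{y}^{\top} \left(K_{\theta} + \sigma^{2} I_{n}\right)^{-1} \mathbf{y}
    - \tfrac{1}{2} \log \left\lvert K_{\theta} + \sigma^{2} I_{n} \right\rvert
    - \tfrac{n}{2} \log 2\pi
```

Passing negative=True presumably negates this quantity so that it can be minimised by a standard optimiser.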
For purposes of optimisation, we'll use optax's Adam.
opt = ox.adam(learning_rate=1e-3)
We define an initial parameter state through the initialise callable.
parameter_state = gpx.initialise(posterior, key=key)
Finally, we run an optimisation loop using the Adam optimiser via the fit callable.
inference_state = gpx.fit(mll, parameter_state, opt, num_iters=500)
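To give a feel for what an Adam optimisation loop does at each iteration, here is a dependency-free sketch of the Adam update rule applied to a toy one-dimensional objective. It is not the implementation of gpx.fit or of optax; the function `adam_minimise`, its arguments, and the quadratic objective are all hypothetical and chosen only for illustration.

```python
import math


def adam_minimise(grad_fn, x0, learning_rate=0.1, num_iters=500,
                  b1=0.9, b2=0.999, eps=1e-8):
    """Minimise a function of one variable given its gradient, using Adam."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, num_iters + 1):
        g = grad_fn(x)
        # Exponential moving averages of the gradient and its square.
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g * g
        # Bias correction for the zero initialisation of m and v.
        m_hat = m / (1 - b1**t)
        v_hat = v / (1 - b2**t)
        # Parameter update scaled by the running gradient magnitude.
        x -= learning_rate * m_hat / (math.sqrt(v_hat) + eps)
    return x


# Toy objective (x - 3)^2 with gradient 2 * (x - 3); its minimiser is x = 3.
x_opt = adam_minimise(lambda x: 2.0 * (x - 3.0), x0=0.0)
```

In the GP setting the scalar x is replaced by the collection of hyperparameters and the gradient is that of the negative marginal log-likelihood, obtained by automatic differentiation.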
Using our learned hyperparameters, we can obtain the posterior distribution of the latent function at novel test points.
learned_params, _ = inference_state.unpack()
xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)
latent_distribution = posterior(learned_params, D)(xtest)
predictive_distribution = likelihood(learned_params, latent_distribution)
predictive_mean = predictive_distribution.mean()
predictive_cov = predictive_distribution.covariance()
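A common next step is to turn the predictive mean and covariance into pointwise credible intervals. The sketch below shows the arithmetic in plain Python on a small hand-written mean vector and covariance matrix; the numbers are toy stand-ins for predictive_mean and predictive_cov, not output from the model above.

```python
import math

# Toy stand-ins for the predictive mean vector and covariance matrix.
mean = [0.0, 1.0, 2.0]
cov = [
    [0.25, 0.10, 0.00],
    [0.10, 1.00, 0.20],
    [0.00, 0.20, 4.00],
]

# Marginal standard deviations are the square roots of the diagonal entries.
std = [math.sqrt(cov[i][i]) for i in range(len(mean))]

# Approximate 95% interval for a Gaussian: mean +/- 1.96 standard deviations.
z = 1.96
lower = [m - z * s for m, s in zip(mean, std)]
upper = [m + z * s for m, s in zip(mean, std)]
```

With JAX arrays the same computation would typically use the square root of the covariance diagonal, followed by elementwise arithmetic on the mean.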
The latest stable version of GPJax can be installed via pip:
pip install gpjax
Note
We recommend you check your installation version:
python -c 'import gpjax; print(gpjax.__version__)'
Warning
The development version is possibly unstable and may contain bugs.
Clone a copy of the repository to your local machine and run the setup configuration in development mode.
git clone https://github.com/JaxGaussianProcesses/GPJax.git
cd GPJax
python setup.py develop
Note
We advise you to create a virtual environment before installing:
conda create -n gpjax_experimental python=3.10.0
conda activate gpjax_experimental
We also recommend checking that your installation passes the supplied unit tests:
python -m pytest tests/
If you use GPJax in your research, please cite our JOSS paper.
@article{Pinder2022,
doi = {10.21105/joss.04455},
url = {https://doi.org/10.21105/joss.04455},
year = {2022},
publisher = {The Open Journal},
volume = {7},
number = {75},
pages = {4455},
author = {Thomas Pinder and Daniel Dodd},
title = {GPJax: A Gaussian Process Framework in JAX},
journal = {Journal of Open Source Software}
}