From 8b0284862a48af3f0bf210984d74f3bafb9b8cce Mon Sep 17 00:00:00 2001
From: Thomas Pinder
Date: Thu, 3 Nov 2022 21:34:14 +0000
Subject: [PATCH] Fix bugs

---
 examples/classification.pct.py |  3 +--
 examples/kernels.pct.py        | 32 +++++++++++---------------------
 2 files changed, 12 insertions(+), 23 deletions(-)

diff --git a/examples/classification.pct.py b/examples/classification.pct.py
index 3ca18d7cb..4c0f59bc7 100644
--- a/examples/classification.pct.py
+++ b/examples/classification.pct.py
@@ -19,10 +19,9 @@
 #
 # In this notebook we demonstrate how to perform inference for Gaussian process models with non-Gaussian likelihoods via maximum a posteriori (MAP) and Markov chain Monte Carlo (MCMC). We focus on a classification task here and use [BlackJax](https://github.com/blackjax-devs/blackjax/) for sampling.
 
+# %%
 import blackjax
 import distrax as dx
-
-# %%
 import jax
 import jax.numpy as jnp
 import jax.random as jr
diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py
index 47bd63ce4..902fcc588 100644
--- a/examples/kernels.pct.py
+++ b/examples/kernels.pct.py
@@ -19,6 +19,8 @@
 #
 # In this guide, we introduce the kernels available in GPJax and demonstrate how to create custom ones.
 
+import distrax as dx
+
 # %%
 import jax.numpy as jnp
 import jax.random as jr
@@ -26,7 +28,7 @@
 from jax import jit
 from jaxtyping import Array, Float
 from optax import adam
-import distrax as dx
+
 import gpjax as gpx
 
 key = jr.PRNGKey(123)
@@ -207,29 +209,15 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict:
 #
 # To define a bijector here we'll make use of the `Lambda` operator given in Distrax. This lets us convert any regular Jax function into a bijection. Given that we require $\tau$ to be strictly greater than $4.$, we'll apply a [softplus transformation](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softplus.html) where the lower bound is shifted by $4$.
 
-# %%
-from gpjax.config import add_parameter, Softplus
 from jax.nn import softplus
 
-# class Softplus4p(dx.Bijector):
-#     def __init__(self):
-#         super().__init__(event_ndims_in=0)
-#         self._shift = 4.
-
-#     def forward_and_log_det(self, x):
-#         y = softplus(x) + self._shift
-#         logdet = -softplus(-x)
-#         return y, logdet
-
-#     def inverse_and_log_det(self, y):
-#         # Optional. Can be omitted if inverse methods are not needed.
-#         y = y - self._shift
-#         x = jnp.log(-jnp.expm1(-y)) + y
-#         logdet = -jnp.log(-jnp.expm1(-y))
-#         return x, logdet
+# %%
+from gpjax.config import Softplus, add_parameter
 
 bij_fn = lambda x: softplus(x + jnp.array(4.0))
-bij = dx.Lambda(forward = bij_fn, inverse = lambda y: -jnp.log(-jnp.expm1(-y - 4.))+y-4.)
+bij = dx.Lambda(
+    forward=bij_fn, inverse=lambda y: -jnp.log(-jnp.expm1(-y - 4.0)) + y - 4.0
+)
 
 add_parameter("tau", bij)
 
@@ -276,7 +264,9 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict:
 # We'll now query the GP's predictive posterior at linearly spaced novel inputs and illustrate the results.
 
 # %%
-posterior_rv = likelihood(circlular_posterior(D, learned_params)(angles), learned_params)
+posterior_rv = likelihood(
+    circlular_posterior(D, learned_params)(angles), learned_params
+)
 mu = posterior_rv.mean()
 one_sigma = posterior_rv.stddev()
 
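
Note (not part of the patch): the kernels hunk above constrains `tau` with a Distrax `Lambda` bijector built from plain JAX functions. The snippet below is a minimal, self-contained sketch of that idea — a shifted-softplus bijector with a mutually consistent forward/inverse pair and a round-trip check. The lower bound of 4.0 comes from the notebook; the shift here is applied after the softplus so the image is strictly (4, inf), and all variable names and tolerances are illustrative rather than taken from the commit.

import distrax as dx
import jax.numpy as jnp
from jax.nn import softplus

lower = jnp.array(4.0)  # lower bound on tau, as in the notebook

# Forward: R -> (lower, inf). Adding the shift after the softplus keeps
# the output strictly above `lower`.
fwd = lambda x: softplus(x) + lower

# Inverse: softplus^{-1}(y - lower) = log(expm1(y - lower)), written in
# the numerically stable form u + log(-expm1(-u)) with u = y - lower.
inv = lambda y: (y - lower) + jnp.log(-jnp.expm1(-(y - lower)))

bij = dx.Lambda(forward=fwd, inverse=inv)

# Round-trip check on a few unconstrained values.
x = jnp.array([-2.0, 0.0, 3.0])
y = bij.forward(x)
assert bool(jnp.all(y > lower))                          # constraint holds
assert bool(jnp.allclose(bij.inverse(y), x, atol=1e-4))  # inverse recovers x

Because `dx.Lambda` is given both the forward and inverse callables, Distrax can derive the log-determinant terms it needs via automatic differentiation, which is what makes the one-line lambda style in the patch workable.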