
Commit: Fix bugs

thomaspinder committed Nov 3, 2022
1 parent e5db07b commit 8b02848
Showing 2 changed files with 12 additions and 23 deletions.
3 changes: 1 addition & 2 deletions examples/classification.pct.py
@@ -19,10 +19,9 @@
#
# In this notebook we demonstrate how to perform inference for Gaussian process models with non-Gaussian likelihoods via maximum a posteriori (MAP) and Markov chain Monte Carlo (MCMC). We focus on a classification task here and use [BlackJax](https://github.com/blackjax-devs/blackjax/) for sampling.

# %%
import blackjax
import distrax as dx

# %%
import jax
import jax.numpy as jnp
import jax.random as jr
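
For context, the sampling loop that the classification notebook builds with BlackJax has roughly the shape below. This is a sketch, not part of the commit: `logprob_fn` is a stand-in Gaussian target, and the step size and inverse mass matrix are illustrative values.

import blackjax
import jax
import jax.numpy as jnp

# Stand-in log-density; the notebook targets the GP posterior instead.
def logprob_fn(theta):
    return -0.5 * jnp.sum(theta**2)

# Build a NUTS kernel with fixed (illustrative) tuning values.
nuts = blackjax.nuts(logprob_fn, step_size=1e-2, inverse_mass_matrix=jnp.ones(2))
state = nuts.init(jnp.zeros(2))
step = jax.jit(nuts.step)

# Draw samples by repeatedly stepping the kernel with fresh PRNG keys.
key = jax.random.PRNGKey(0)
positions = []
for _ in range(500):
    key, subkey = jax.random.split(key)
    state, _ = step(subkey, state)
    positions.append(state.position)
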
32 changes: 11 additions & 21 deletions examples/kernels.pct.py
@@ -19,14 +19,16 @@
#
# In this guide, we introduce the kernels available in GPJax and demonstrate how to create custom ones.

import distrax as dx

# %%
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from jax import jit
from jaxtyping import Array, Float
from optax import adam
import distrax as dx

import gpjax as gpx

key = jr.PRNGKey(123)
@@ -207,29 +209,15 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict:
#
# To define a bijector here we'll make use of the `Lambda` operator given in Distrax. This lets us convert any regular JAX function into a bijection. Given that we require $\tau$ to be strictly greater than $4$, we'll apply a [softplus transformation](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softplus.html) where the lower bound is shifted by $4$.

# %%
from gpjax.config import add_parameter, Softplus
from jax.nn import softplus

# class Softplus4p(dx.Bijector):
# def __init__(self):
# super().__init__(event_ndims_in=0)
# self._shift = 4.

# def forward_and_log_det(self, x):
# y = softplus(x) + self._shift
# logdet = -softplus(-x)
# return y, logdet

# def inverse_and_log_det(self, y):
# # Optional. Can be omitted if inverse methods are not needed.
# y = y - self._shift
# x = jnp.log(-jnp.expm1(-y)) + y
# logdet = -jnp.log(-jnp.expm1(-y))
# return x, logdet
# %%
from gpjax.config import Softplus, add_parameter

bij_fn = lambda x: softplus(x + jnp.array(4.0))
bij = dx.Lambda(forward = bij_fn, inverse = lambda y: -jnp.log(-jnp.expm1(-y - 4.))+y-4.)
bij = dx.Lambda(
forward=bij_fn, inverse=lambda y: -jnp.log(-jnp.expm1(-y - 4.0)) + y - 4.0
)

add_parameter("tau", bij)
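
For reference, the shifted softplus described in the markdown above, mapping onto $(4, \infty)$, admits a closed-form inverse. The pair below is a sketch of one internally consistent formulation, not code from this diff:

import distrax as dx
import jax.numpy as jnp
from jax.nn import softplus

shift = 4.0
# Forward: y = softplus(x) + 4, so y > 4 for every real x.
fwd = lambda x: softplus(x) + shift
# Inverse: x = log(exp(y - 4) - 1), written stably via expm1.
inv = lambda y: (y - shift) + jnp.log(-jnp.expm1(-(y - shift)))
bij4 = dx.Lambda(forward=fwd, inverse=inv)

x = jnp.array(2.0)
y = bij4.forward(x)       # ~6.127, above the lower bound
x_back = bij4.inverse(y)  # ~2.0, the round trip recovers x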

@@ -276,7 +264,9 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict:
# We'll now query the GP's predictive posterior at linearly spaced novel inputs and illustrate the results.

# %%
posterior_rv = likelihood(circlular_posterior(D, learned_params)(angles), learned_params)
posterior_rv = likelihood(
circlular_posterior(D, learned_params)(angles), learned_params
)
mu = posterior_rv.mean()
one_sigma = posterior_rv.stddev()
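
To illustrate the results, a minimal plotting cell might look like the following sketch. It assumes `angles`, `mu`, and `one_sigma` as computed above; the styling choices are illustrative, not from the commit:

fig, ax = plt.subplots()
ax.plot(angles.squeeze(), mu.squeeze(), label="predictive mean")
ax.fill_between(
    angles.squeeze(),
    (mu - one_sigma).squeeze(),
    (mu + one_sigma).squeeze(),
    alpha=0.3,
    label="one standard deviation",
)
ax.legend()
plt.show()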

