Merge pull request #9 from ziatdinovmax/v0.0.3
V0.0.3
ziatdinovmax authored Mar 18, 2022
2 parents a07ebab + 3e3826f commit 4e51d0d
Showing 16 changed files with 1,557 additions and 995 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -77,14 +77,17 @@ import gpax

# Get random number generator keys for training and prediction
rng_key, rng_key_predict = gpax.utils.get_keys()

# Obtain/update DKL posterior; input data dimensions are (n, h*w*c)
dkl = gpax.viDKL(input_dim=X.shape[-1], z_dim=2, kernel='RBF')
dkl.fit(rng_key, X_train, y_train, num_steps=100, step_size=0.05)

# Compute UCB acquisition function
obj = gpax.acquisition.UCB(rng_key_predict, dkl, X_unmeasured, maximize=True)
# Select next point to measure (assuming grid data)
next_point_idx = obj.argmax()
# Perform measurement, update training data, etc.

# Perform measurement at next_point_idx, update training data, etc.
```
The full example is available [here](https://colab.research.google.com/github/ziatdinovmax/gpax/blob/main/examples/gpax_viDKL_plasmons.ipynb). Note that viDKL uses a simple MLP as the default feature extractor. However, you can easily write a custom DNN using [haiku](https://github.com/deepmind/dm-haiku) and pass it to the viDKL initializer.
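For illustration, a minimal sketch of such a custom feature extractor (hypothetical, not part of this commit; it assumes viDKL accepts a haiku module via its `nn` argument, matching the `nn` parameter introduced in dkl.py below):
```python3
import haiku as hk
import jax

# Hypothetical custom feature extractor written with haiku.
# The final layer width must match the z_dim passed to viDKL.
class CustomMLP(hk.Module):
    def __call__(self, x):
        x = jax.nn.tanh(hk.Linear(128)(x))
        x = jax.nn.tanh(hk.Linear(64)(x))
        return hk.Linear(2)(x)

# Assumption: the custom module is passed via the `nn` keyword
dkl = gpax.viDKL(input_dim=X.shape[-1], z_dim=2, kernel='RBF', nn=CustomMLP)
```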
497 changes: 266 additions & 231 deletions examples/GP_sGP.ipynb

Large diffs are not rendered by default.

896 changes: 366 additions & 530 deletions examples/gpax_viDKL_plasmons.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion gpax/__init__.py
@@ -1,6 +1,7 @@
from . import utils, kernels, acquisition
from .gp import ExactGP
from .vgp import vExactGP
from .dkl import DKL
from .vidkl import viDKL

__all__ = ["utils", "kernels", "acquisition", "ExactGP", "DKL"]
__all__ = ["utils", "kernels", "acquisition", "ExactGP", "vExactGP", "DKL", "viDKL"]
2 changes: 1 addition & 1 deletion gpax/__version__.py
@@ -1 +1 @@
version = '0.0.2'
version = '0.0.3'
99 changes: 53 additions & 46 deletions gpax/acquisition.py
@@ -1,8 +1,9 @@
from typing import Type, Tuple
from typing import Type, Tuple, Optional

import jax.numpy as jnp
import jax.random as jra
import numpy as onp
import numpyro
import numpyro.distributions as dist

from .gp import ExactGP
@@ -17,12 +18,12 @@ def EI(rng_key: jnp.ndarray, model: Type[ExactGP],
"""
if model.mcmc is not None:
y_mean, y_sampled = model.predict(rng_key, X, n=n)
if n > 1:
y_sampled = y_sampled.reshape(n * y_sampled.shape[0], -1)
y_sampled = y_sampled.reshape(n * y_sampled.shape[0], -1)
mean, sigma = y_sampled.mean(0), y_sampled.std(0)
u = (mean - y_mean.max() - xi) / sigma
else:
mean, sigma = vi_mean_and_var(model, X, compute_std=True)
mean, var = model.predict(rng_key, X)
sigma = jnp.sqrt(var)
u = (mean - mean.max() - xi) / sigma
u = -u if not maximize else u
normal = dist.Normal(jnp.zeros_like(u), jnp.ones_like(u))
@@ -39,11 +40,10 @@ def UCB(rng_key: jnp.ndarray, model: Type[ExactGP],
"""
if model.mcmc is not None:
_, y_sampled = model.predict(rng_key, X, n=n)
if n > 1:
y_sampled = y_sampled.reshape(n * y_sampled.shape[0], -1)
y_sampled = y_sampled.reshape(n * y_sampled.shape[0], -1)
mean, var = y_sampled.mean(0), y_sampled.var(0)
else:
mean, var = vi_mean_and_var(model, X)
mean, var = model.predict(rng_key, X)
delta = jnp.sqrt(beta * var)
if maximize:
return mean + delta
@@ -56,11 +56,10 @@ def UE(rng_key: jnp.ndarray,
"""Uncertainty-based exploration (aka kriging)"""
if model.mcmc is not None:
_, y_sampled = model.predict(rng_key, X, n=n)
if n > 1:
y_sampled = y_sampled.mean(1)
y_sampled = y_sampled.mean(1)
var = y_sampled.var(0)
else:
_, var = vi_mean_and_var(model, X)
_, var = model.predict(rng_key, X)
return var


@@ -76,65 +75,73 @@ def Thompson(rng_key: jnp.ndarray,
if n > 1:
tsample = tsample.mean(1).squeeze()
else:
_, tsample = model.predict(rng_key, X, n=1)
_, tsample = model.sample_from_posterior(rng_key, X, n=1)
return tsample


def bUCB(rng_key: jnp.ndarray, model: Type[ExactGP],
X: jnp.ndarray, batch_size: int = 4,
beta: float = .25,
maximize: bool = False,
n: int = 100,
n_restarts: int = 20) -> jnp.ndarray:
X: jnp.ndarray, indices: Optional[jnp.ndarray] = None,
batch_size: int = 4, alpha: float = 1.0, beta: float = .25,
maximize: bool = True, n: int = 500,
n_restarts: int = 20, **kwargs) -> jnp.ndarray:
"""
Batch mode for the upper confidence bound
The acquisition function, defined as alpha * mu + sqrt(beta) * sigma,
can output a "batch" of next points to evaluate. It takes advantage of
the fact that in MCMC-based GP or DKL we obtain a separate multivariate
normal posterior for each set of sampled kernel hyperparameters.
Args:
rng_key: random number generator key
model: ExactGP or DKL type of model
X: input array
indices: indices of data points in X array. For example, if
each data point is an image patch, the indices should
correspond to their (x, y) coordinates in the original image.
batch_size: desired number of sampled points (default: 4)
alpha: coefficient before mean prediction term (default: 1.0)
beta: coefficient before variance term (default: 0.25)
maximize: sign of variance term (+/- if True/False)
n: number of draws from each multivariate normal posterior
n_restarts: number of restarts to find a batch of maximally
separated points to evaluate next
Returns:
Computed acquisition function with batch x features
or task x batch x features dimensions
"""
if model.mcmc is None:
raise NotImplementedError(
"Currently supports only ExactGP with MCMC inference")
"Currently supports only ExactGP and DKL with MCMC inference")
dist_all, obj_all = [], []
for i in range(n_restarts):
y_sampled = obtain_samples(rng_key, model, X, batch_size, n)
X_ = jnp.array(indices) if indices is not None else X
for _ in range(n_restarts):
y_sampled = obtain_samples(rng_key, model, X, batch_size, n, **kwargs)
mean, var = y_sampled.mean(1), y_sampled.var(1)
delta = jnp.sqrt(beta * var)
if maximize:
obj = mean + delta
points = X[obj.argmax(1)]
obj = alpha * mean + delta
points = X_[obj.argmax(-1)]
else:
obj = mean - delta
points = X[obj.argmin(1)]
d = get_distance(points)
obj = alpha * mean - delta
points = X_[obj.argmin(-1)]
d = jnp.linalg.norm(points, axis=-1).mean(0)
dist_all.append(d)
obj_all.append(obj)
idx = jnp.array(dist_all).argmax()
idx = jnp.array(dist_all).argmax(0)
if idx.ndim > 0:
obj_all = jnp.array(obj_all)
return jnp.array([obj_all[j,:,i] for i, j in enumerate(idx)])
return obj_all[idx]


def obtain_samples(rng_key: jnp.ndarray, model: Type[ExactGP],
X: jnp.ndarray, batch_size: int = 4,
n: int = 500) -> jnp.ndarray:
n: int = 500, **kwargs) -> jnp.ndarray:
posterior_samples = model.get_samples()
idx = onp.arange(0, len(posterior_samples["k_length"]))
onp.random.shuffle(idx)
idx = idx[:batch_size]
samples = {k: v[idx] for (k, v) in posterior_samples.items()}
_, y_sampled = model.predict(rng_key, X, samples, n)
_, y_sampled = model.predict_in_batches(
rng_key, X, kwargs.get("xbatch_size", 500), samples, n)
return y_sampled


def get_distance(points: jnp.ndarray) -> float:
d = []
for p1 in points:
for p2 in points:
d.append(jnp.linalg.norm(p1-p2))
return jnp.array(d).mean().item()


def vi_mean_and_var(model: Type[viDKL], X: jnp.ndarray,
compute_std: bool = False
) -> Tuple[jnp.ndarray]:
mean, cov = model.get_mvn_posterior(X)
var = cov.diagonal()
if compute_std:
return mean, jnp.sqrt(var)
return mean, var
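
For illustration, a minimal usage sketch of the updated batch-mode acquisition (variable names `gp_model` and `X_unmeasured` are hypothetical; bUCB requires a model trained with MCMC, per the NotImplementedError above):
```python3
# Sketch only: gp_model is an MCMC-trained ExactGP/DKL, X_unmeasured a grid
rng_key, rng_key_predict = gpax.utils.get_keys()
obj = gpax.acquisition.bUCB(
    rng_key_predict, gp_model, X_unmeasured,
    batch_size=4, alpha=1.0, beta=0.25, maximize=True)
# obj has batch x features dimensions (one acquisition curve per selected
# set of posterior samples), so argmax along the last axis yields a
# batch of next points to measure
next_points_idx = obj.argmax(-1)
```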
124 changes: 72 additions & 52 deletions gpax/dkl.py
@@ -1,17 +1,17 @@
from functools import partial
from typing import Callable, Dict, Optional, Tuple
from typing import Callable, Dict, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import jit

from .gp import ExactGP
from .vgp import vExactGP
from .kernels import get_kernel


class DKL(ExactGP):
class DKL(vExactGP):
"""
Fully Bayesian implementation of deep kernel learning
@@ -20,45 +20,44 @@ class DKL(ExactGP):
z_dim: latent space dimensionality
kernel: type of kernel ('RBF', 'Matern', 'Periodic')
kernel_prior: optional priors over kernel hyperparameters (uses LogNormal(0,1) by default)
bnn_fn: Custom MLP
bnn_fn_prior: Bayesian priors over the weights and biases in bnn_fn
nn: Custom MLP
nn_prior: Bayesian priors over the weights and biases in 'nn'
latent_prior: Optional prior over the latent space (BNN embedding)
"""

def __init__(self, input_dim: int, z_dim: int = 2, kernel: str = 'RBF',
kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
bnn_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
bnn_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
nn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
nn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None
) -> None:
super(DKL, self).__init__(input_dim, kernel, kernel_prior)
self.bnn = bnn_fn if bnn_fn else bnn
self.bnn_prior = bnn_fn_prior if bnn_fn_prior else bnn_prior(input_dim, z_dim)
self.nn = nn if nn else mlp
self.nn_prior = nn_prior if nn_prior else mlp_prior(input_dim, z_dim)
self.kernel_dim = z_dim
self.latent_prior = latent_prior

def model(self, X: jnp.ndarray, y: jnp.ndarray) -> None:
"""DKL probabilistic model"""
task_dim = X.shape[0]
# BNN part
bnn_params = self.bnn_prior()
z = self.bnn(X, bnn_params)
bnn_params = self.nn_prior(task_dim)
z = jax.jit(jax.vmap(self.nn))(X, bnn_params)
if self.latent_prior: # Sample latent variable
z = self.latent_prior(z)
# Sample GP kernel parameters
if self.kernel_prior:
kernel_params = self.kernel_prior()
else:
kernel_params = self._sample_kernel_params()
kernel_params = self._sample_kernel_params(task_dim)
# Sample noise
noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0))
with numpyro.plate('obs_noise', task_dim):
noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0))
# GP's mean function
f_loc = jnp.zeros(z.shape[0])
# compute kernel
k = get_kernel(self.kernel)(
z, z,
kernel_params,
noise
)
f_loc = jnp.zeros(z.shape[:2])
# compute kernel(s)
k_args = (z, z, kernel_params, noise)
k = jax.vmap(get_kernel(self.kernel))(*k_args)
# sample y according to the standard Gaussian process formula
numpyro.sample(
"y",
@@ -67,67 +66,88 @@ def model(self, X: jnp.ndarray, y: jnp.ndarray) -> None:
)

@partial(jit, static_argnames='self')
def get_mvn_posterior(self,
X_new: jnp.ndarray, params: Dict[str, jnp.ndarray]
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Returns parameters (mean and cov) of multivariate normal posterior
for a single sample of DKL hyperparameters
"""
def _get_mvn_posterior(self,
X_train: jnp.ndarray, y_train: jnp.ndarray,
X_new: jnp.ndarray, params: Dict[str, jnp.ndarray]
) -> Tuple[jnp.ndarray, jnp.ndarray]:
noise = params["noise"]
# embed data into the latent space
z_train = self.bnn(self.X_train, params)
z_test = self.bnn(X_new, params)
# compute kernel matrices for train and test data
k_pp = get_kernel(self.kernel)(z_test, z_test, params, noise)
k_pX = get_kernel(self.kernel)(z_test, z_train, params, jitter=0.0)
# embed data into the latent space
z_train = self.nn(X_train, params)
z_new = self.nn(X_new, params)
# compute kernel matrices for train and new ('test') data
k_pp = get_kernel(self.kernel)(z_new, z_new, params, noise)
k_pX = get_kernel(self.kernel)(z_new, z_train, params, jitter=0.0)
k_XX = get_kernel(self.kernel)(z_train, z_train, params, noise)
# compute the predictive covariance and mean
K_xx_inv = jnp.linalg.inv(k_XX)
cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, self.y_train))
mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_train))
return mean, cov

@partial(jit, static_argnames='self')
def embed(self, X_new: jnp.ndarray) -> jnp.ndarray:
"""
Embeds data into the latent space using the inferred weights
of the DKL's Bayesian neural network
"""
samples = self.get_samples(chain_dim=False)
predictive = jax.vmap(lambda params: self.bnn(X_new, params))
predictive = jax.vmap(lambda params: self.nn(X_new, params))
z = predictive(samples)
return z

def _set_data(self,
X: jnp.ndarray,
y: Optional[jnp.ndarray] = None
) -> Union[Tuple[jnp.ndarray], jnp.ndarray]:
X = X[None] if X.ndim == 2 else X # add task pseudo-dimension
if y is not None:
y = y[None] if y.ndim == 1 else y # add task pseudo-dimension
return X, y
return X

def _print_summary(self):
list_of_keys = ["k_scale", "k_length", "noise", "period"]
samples = self.get_samples(1)
numpyro.diagnostics.print_summary(
{k: v for (k, v) in samples.items() if k in list_of_keys})


def sample_weights(name: str, in_channels: int, out_channels: int) -> jnp.ndarray:
def sample_weights(name: str, in_channels: int, out_channels: int, task_dim: int) -> jnp.ndarray:
"""Sampling weights matrix"""
return numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((in_channels, out_channels)),
scale=jnp.ones((in_channels, out_channels))))
with numpyro.plate("batch_dim", task_dim, dim=-3):
w = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((in_channels, out_channels)),
scale=jnp.ones((in_channels, out_channels))))
return w


def sample_biases(name: str, channels: int) -> jnp.ndarray:
def sample_biases(name: str, channels: int, task_dim: int) -> jnp.ndarray:
"""Sampling bias vector"""
return numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((channels)), scale=jnp.ones((channels))))
with numpyro.plate("batch_dim", task_dim, dim=-3):
b = numpyro.sample(name=name, fn=dist.Normal(
loc=jnp.zeros((channels)), scale=jnp.ones((channels))))
return b


def bnn(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""Simple Bayesian MLP"""
def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
"""Simple MLP for a single MCMC sample of weights and biases"""
h1 = jnp.tanh(jnp.matmul(X, params["w1"]) + params["b1"])
h2 = jnp.tanh(jnp.matmul(h1, params["w2"]) + params["b2"])
z = jnp.matmul(h2, params["w3"]) + params["b3"]
return z


def bnn_prior(input_dim: int, zdim: int = 2) -> Dict[str, jnp.array]:
def mlp_prior(input_dim: int, zdim: int = 2) -> Dict[str, jnp.array]:
"""Priors over weights and biases in the default Bayesian MLP"""
hdim = [64, 32]

def _bnn_prior():
w1 = sample_weights("w1", input_dim, hdim[0])
b1 = sample_biases("b1", hdim[0])
w2 = sample_weights("w2", hdim[0], hdim[1])
b2 = sample_biases("b2", hdim[1])
w3 = sample_weights("w3", hdim[1], zdim)
b3 = sample_biases("b3", zdim)
def _bnn_prior(task_dim: int):
w1 = sample_weights("w1", input_dim, hdim[0], task_dim)
b1 = sample_biases("b1", hdim[0], task_dim)
w2 = sample_weights("w2", hdim[0], hdim[1], task_dim)
b2 = sample_biases("b2", hdim[1], task_dim)
w3 = sample_weights("w3", hdim[1], zdim, task_dim)
b3 = sample_biases("b3", zdim, task_dim)
return {"w1": w1, "b1": b1, "w2": w2, "b2": b2, "w3": w3, "b3": b3}

return _bnn_prior
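
Finally, a minimal usage sketch of the updated fully Bayesian DKL (the data arrays are hypothetical; per `_set_data` above, single-task inputs of shape (n, d) get a pseudo task dimension added internally):
```python3
# Sketch only: X_train (n, d), y_train (n,), X_new (m, d) are assumptions
import gpax

rng_key, rng_key_predict = gpax.utils.get_keys()
dkl = gpax.DKL(input_dim=X_train.shape[-1], z_dim=2, kernel='RBF')
dkl.fit(rng_key, X_train, y_train)  # MCMC over NN weights and kernel params
y_mean, y_sampled = dkl.predict(rng_key_predict, X_new)
# Embed new data into the latent space with the inferred BNN weights
z = dkl.embed(X_new)
```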