Fix natgrads nb to numpyro.
daniel-dodd committed Oct 17, 2022
1 parent 388119f commit 189939a
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions examples/natgrads.pct.py
@@ -20,7 +20,7 @@
# %% [markdown]
# In this notebook, we show how to use natural gradients. Ordinary gradient descent algorithms are undesirable for variational inference because we are minimising the KL divergence between distributions rather than directly minimising over a set of parameters. Natural gradients, on the other hand, account for the curvature induced by the KL divergence, which can considerably improve performance (see e.g., <strong data-cite="salimbeni2018">Salimbeni et al. (2018)</strong> for further details).
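For a quick reminder of what "natural" means here (a standard identity, not taken from this file): writing $\mathcal{L}(\xi)$ for the ELBO and $q_{\xi}$ for the variational distribution, the natural gradient preconditions the ordinary gradient with the inverse Fisher information,

$$
\tilde{\nabla}_{\xi}\,\mathcal{L}(\xi) \;=\; F(\xi)^{-1}\,\nabla_{\xi}\,\mathcal{L}(\xi),
\qquad
F(\xi) \;=\; \mathbb{E}_{q_{\xi}}\!\big[\nabla_{\xi}\log q_{\xi}(u)\,\nabla_{\xi}\log q_{\xi}(u)^{\top}\big],
$$

so the update $\xi \leftarrow \xi + \beta\,\tilde{\nabla}_{\xi}\mathcal{L}(\xi)$ follows the steepest-ascent direction measured by KL divergence between distributions rather than by Euclidean distance between parameter vectors.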

-# %%
+# %% vscode={"languageId": "python"}
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
@@ -41,7 +41,7 @@
#
# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs for later.
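The data-generating code in this cell is collapsed by the diff viewer. Purely as an illustrative sketch under assumed names (the seed and the toy latent function below are inventions, not the notebook's actual values), a GPJax `Dataset` is typically constructed like so:

```python
import jax.numpy as jnp
import jax.random as jr
import gpjax as gpx

key = jr.PRNGKey(123)  # assumed seed
n, noise = 5000, 0.2

# Draw inputs, evaluate a toy latent function, and corrupt it with Gaussian noise.
x = jr.uniform(key, shape=(n, 1), minval=-5.0, maxval=5.0).sort(axis=0)
f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)  # assumed toy signal
y = f(x) + noise * jr.normal(key, shape=x.shape)

# Package the training set and create test inputs for later plotting.
D = gpx.Dataset(X=x, y=y)
xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)
```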

-# %%
+# %% vscode={"languageId": "python"}
n = 5000
noise = 0.2

@@ -56,7 +56,7 @@
# %% [markdown]
# Initialise inducing points:

-# %%
+# %% vscode={"languageId": "python"}
z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1)

fig, ax = plt.subplots(figsize=(12, 5))
@@ -71,7 +71,7 @@
# %% [markdown]
# We begin by defining our model, variational family and variational inference strategy:

-# %%
+# %% vscode={"languageId": "python"}
likelihood = gpx.Gaussian(num_datapoints=n)
kernel = gpx.RBF()
prior = gpx.Prior(kernel=kernel)
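The rest of this cell is collapsed by the diff viewer. As a rough, assumed sketch only (the `StochasticVI` and `initialise` calls reflect my reading of the GPJax API of this era, and the names `posterior`, `natural_q`, `natural_svgp`, and `parameter_state` are inferred from other hunks of this diff), the variational family and inference strategy might be assembled along these lines:

```python
import jax.random as jr

posterior = prior * likelihood  # exact posterior from prior and likelihood

# Variational family in the natural parameterisation over the inducing inputs z.
natural_q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)

# Stochastic VI objective and an initial parameter state to optimise.
natural_svgp = gpx.StochasticVI(posterior=posterior, variational_family=natural_q)
parameter_state = gpx.initialise(natural_svgp, jr.PRNGKey(123))
```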
@@ -86,7 +86,7 @@
# %% [markdown]
# Next, we can run natural gradient descent as follows:

-# %%
+# %% vscode={"languageId": "python"}
inference_state = gpx.fit_natgrads(
natural_svgp,
parameter_state=parameter_state,
@@ -103,12 +103,12 @@
# %% [markdown]
# Here is the fitted model:

-# %%
+# %% vscode={"languageId": "python"}
latent_dist = natural_q(learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)

-meanf = predictive_dist.mean()
-sigma = predictive_dist.stddev()
+meanf = predictive_dist.mean
+sigma = jnp.sqrt(predictive_dist.variance)

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(x, y, "o", alpha=0.15, label="Training Data", color="tab:gray")
@@ -126,7 +126,7 @@
# %% [markdown]
# As mentioned in <strong data-cite="hensman2013gaussian">Hensman et al. (2013)</strong>, in the case of a Gaussian likelihood, a single natural gradient step of unit length on the full dataset recovers the same solution as <strong data-cite="titsias2009">Titsias (2009)</strong>. We now illustrate this.
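In sketch form (a standard result, restated here for context rather than taken from the notebook): with a Gaussian likelihood the model is conjugate, and the natural gradient of the ELBO with respect to the natural parameters $\theta$ of $q(u)$ reduces to the gap between the optimal and current parameters,

$$
\tilde{\nabla}_{\theta}\,\mathcal{L}(\theta) \;=\; \theta^{\star} - \theta,
$$

so a single full-batch step of unit length, $\theta \leftarrow \theta + 1\cdot(\theta^{\star} - \theta) = \theta^{\star}$, lands exactly on the optimal variational distribution, which coincides with the collapsed solution of Titsias (2009).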

-# %%
+# %% vscode={"languageId": "python"}
n = 1000
noise = 0.2

@@ -139,7 +139,7 @@

xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)

-# %%
+# %% vscode={"languageId": "python"}
z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1)

fig, ax = plt.subplots(figsize=(12, 5))
@@ -148,7 +148,7 @@
[ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1) for z_i in z]
plt.show()

-# %%
+# %% vscode={"languageId": "python"}
likelihood = gpx.Gaussian(num_datapoints=n)
kernel = gpx.RBF()
prior = gpx.Prior(kernel=kernel)
@@ -157,7 +157,7 @@
# %% [markdown]
# We begin with natgrads:

-# %%
+# %% vscode={"languageId": "python"}
from gpjax.natural_gradients import natural_gradients

q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)
@@ -186,7 +186,7 @@
# %% [markdown]
# Let us now repeat the experiment with SGPR:

-# %%
+# %% vscode={"languageId": "python"}
q = gpx.CollapsedVariationalGaussian(
prior=prior, likelihood=likelihood, inducing_inputs=z
)
@@ -206,6 +206,6 @@
# %% [markdown]
# ## System configuration

-# %%
+# %% vscode={"languageId": "python"}
# %reload_ext watermark
# %watermark -n -u -v -iv -w -a 'Daniel Dodd'
