Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dependencies #442

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ import jax.numpy as jnp
import jax.random as jr
import optax as ox

key = jr.PRNGKey(123)
key = jr.key(123)

f = lambda x: 10 * jnp.sin(x)

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Kernels:
params = [[10, 100, 500, 1000, 2000], [1, 2, 5]]

def setup(self, n_datapoints: int, n_dims: int):
key = jr.PRNGKey(123)
key = jr.key(123)
self.X = jr.uniform(
key=key, minval=-3.0, maxval=3.0, shape=(n_datapoints, n_dims)
)
Expand Down
32 changes: 0 additions & 32 deletions benchmarks/linops.py

This file was deleted.

12 changes: 6 additions & 6 deletions benchmarks/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ class Gaussian:
params = [[10, 100, 200, 500, 1000], [1, 2, 5]]

def setup(self, n_datapoints: int, n_dims: int):
key = jr.PRNGKey(123)
key = jr.key(123)
self.X = jr.normal(key=key, shape=(n_datapoints, n_dims))
self.y = jnp.sin(self.X[:, :1])
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.objective = gpx.ConjugateMLL()
self.objective = gpx.objectives.ConjugateMLL()
self.posterior = self.prior * self.likelihood

def time_eval(self, n_datapoints: int, n_dims: int):
Expand All @@ -42,15 +42,15 @@ class Bernoulli:
params = [[10, 100, 200, 500, 1000], [1, 2, 5]]

def setup(self, n_datapoints: int, n_dims: int):
key = jr.PRNGKey(123)
key = jr.key(123)
self.X = jr.normal(key=key, shape=(n_datapoints, n_dims))
self.y = jnp.where(jnp.sin(self.X[:, :1]) > 0, 1, 0)
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.objective = gpx.LogPosteriorDensity()
self.objective = gpx.objectives.LogPosteriorDensity()
self.posterior = self.prior * self.likelihood

def time_eval(self, n_datapoints: int, n_dims: int):
Expand All @@ -68,7 +68,7 @@ class Poisson:
params = [[10, 100, 200, 500, 1000], [1, 2, 5]]

def setup(self, n_datapoints: int, n_dims: int):
key = jr.PRNGKey(123)
key = jr.key(123)
self.X = jr.normal(key=key, shape=(n_datapoints, n_dims))
f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x # latent function
self.y = jr.poisson(key, jnp.exp(f(self.X)))
Expand All @@ -77,7 +77,7 @@ def setup(self, n_datapoints: int, n_dims: int):
meanf = gpx.mean_functions.Constant()
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Poisson(num_datapoints=self.data.n)
self.objective = gpx.LogPosteriorDensity()
self.objective = gpx.objectives.LogPosteriorDensity()
self.posterior = self.prior * self.likelihood

def time_eval(self, n_datapoints: int, n_dims: int):
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Gaussian:
params = [[100, 200, 500, 1000, 2000, 3000], [1, 2, 5]]

def setup(self, n_test: int, n_dims: int):
key = jr.PRNGKey(123)
key = jr.key(123)
self.X = jr.normal(key=key, shape=(100, n_dims))
self.y = jnp.sin(self.X[:, :1])
self.data = gpx.Dataset(X=self.X, y=self.y)
Expand All @@ -39,7 +39,7 @@ class Bernoulli:
params = [[100, 200, 500, 1000, 2000, 3000], [1, 2, 5]]

def setup(self, n_test: int, n_dims: int):
key = jr.PRNGKey(123)
key = jr.key(123)
self.X = jr.normal(key=key, shape=(100, n_dims))
self.y = jnp.sin(self.X[:, :1])
self.y = jnp.array(jnp.where(self.y > 0, 1, 0), dtype=jnp.float64)
Expand All @@ -64,7 +64,7 @@ class Poisson:
params = [[100, 200, 500, 1000, 2000, 3000], [1, 2, 5]]

def setup(self, n_test: int, n_dims: int):
key = jr.PRNGKey(123)
key = jr.key(123)
self.X = jr.normal(key=key, shape=(100, n_dims))
f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x # latent function
self.y = jnp.array(jr.poisson(key, jnp.exp(f(self.X))), dtype=jnp.float64)
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Sparse:
params = [[2000, 5000, 10000, 20000], [10, 20, 50, 100, 200]]

def setup(self, n_datapoints: int, n_inducing: int):
key = jr.PRNGKey(123)
key = jr.key(123)
self.X = jr.normal(key=key, shape=(n_datapoints, 1))
self.y = jnp.sin(self.X[:, :1])
self.data = gpx.Dataset(X=self.X, y=self.y)
Expand All @@ -24,10 +24,10 @@ def setup(self, n_datapoints: int, n_inducing: int):
self.posterior = self.prior * self.likelihood

Z = jnp.linspace(self.X.min(), self.X.max(), n_inducing).reshape(-1, 1)
self.q = gpx.CollapsedVariationalGaussian(
self.q = gpx.variational_families.CollapsedVariationalGaussian(
posterior=self.posterior, inducing_inputs=Z
)
self.objective = gpx.CollapsedELBO(negative=True)
self.objective = gpx.objectives.CollapsedELBO(negative=True)

def time_eval(self, n_datapoints: int, n_dims: int):
self.objective(self.q, self.data)
Expand Down
12 changes: 7 additions & 5 deletions benchmarks/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Sparse:
params = [[10000, 20000, 50000], [10, 20, 50, 100, 200], [32, 64, 128, 256]]

def setup(self, n_datapoints: int, n_inducing: int, batch_size: int):
key = jr.PRNGKey(123)
key = jr.key(123)
self.X = jr.normal(key=key, shape=(n_datapoints, 1))
self.y = jnp.sin(self.X[:, :1])
self.data = gpx.Dataset(X=self.X, y=self.y)
Expand All @@ -25,15 +25,17 @@ def setup(self, n_datapoints: int, n_inducing: int, batch_size: int):
self.posterior = self.prior * self.likelihood

Z = jnp.linspace(self.X.min(), self.X.max(), n_inducing).reshape(-1, 1)
self.q = gpx.VariationalGaussian(posterior=self.posterior, inducing_inputs=Z)
self.objective = gpx.ELBO(negative=True)
self.q = gpx.variational_families.VariationalGaussian(
posterior=self.posterior, inducing_inputs=Z
)
self.objective = gpx.objectives.ELBO(negative=True)

def time_eval(self, n_datapoints: int, n_dims: int, batch_size: int):
key = jr.PRNGKey(123)
key = jr.key(123)
batch = get_batch(train_data=self.data, batch_size=batch_size, key=key)
self.objective(self.q, batch)

def time_grad(self, n_datapoints: int, n_dims: int, batch_size: int):
key = jr.PRNGKey(123)
key = jr.key(123)
batch = get_batch(train_data=self.data, batch_size=batch_size, key=key)
jax.grad(self.objective)(self.q, batch)
2 changes: 1 addition & 1 deletion docs/_static/jaxkern/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# import gpjax.kernels as jk

# key = jr.PRNGKey(123)
# key = jr.key(123)


# def set_font(font_path):
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/barycentres.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import gpjax as gpx


key = jr.PRNGKey(123)
key = jr.key(123)
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/bayesian_optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from gpjax.typing import Array, FunctionalSample, ScalarFloat
from jaxopt import ScipyBoundedMinimize

key = jr.PRNGKey(42)
key = jr.key(42)
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

tfd = tfp.distributions
identity_matrix = jnp.eye
key = jr.PRNGKey(123)
key = jr.key(123)
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/collapsed_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
with install_import_hook("gpjax", "beartype.beartype"):
import gpjax as gpx

key = jr.PRNGKey(123)
key = jr.key(123)
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/constructing_new_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import gpjax as gpx
from gpjax.base.param import param_field

key = jr.PRNGKey(123)
key = jr.key(123)
tfb = tfp.bijectors
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/decision_making.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
Float,
)

key = jr.PRNGKey(42)
key = jr.key(42)
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/deep_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from gpjax.kernels.base import AbstractKernel
from gpjax.kernels.computations import AbstractKernelComputation

key = jr.PRNGKey(123)
key = jr.key(123)
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
Expand Down Expand Up @@ -103,7 +103,7 @@ class DeepKernelFunction(AbstractKernel):
base_kernel: AbstractKernel = None
network: nn.Module = static_field(None)
dummy_x: jax.Array = static_field(None)
key: jr.PRNGKeyArray = static_field(jr.PRNGKey(123))
key: jax.Array = static_field(jr.key(123))
nn_params: Any = field(init=False, repr=False)

def __post_init__(self):
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/graph_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
with install_import_hook("gpjax", "beartype.beartype"):
import gpjax as gpx

key = jr.PRNGKey(123)
key = jr.key(123)
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/intro_to_gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@
# determines the correlation of the multivariate Gaussian.

# %%
key = jr.PRNGKey(123)
key = jr.key(123)

d1 = tfd.MultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=jnp.ones(2))
d2 = tfd.MultivariateNormalTriL(
Expand Down
13 changes: 10 additions & 3 deletions docs/examples/intro_to_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from gpjax.typing import Array
from sklearn.preprocessing import StandardScaler

key = jr.PRNGKey(42)
key = jr.key(42)
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
Expand Down Expand Up @@ -249,17 +249,24 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# with the optimised hyperparameters, and compare them to the predictions made using the
# posterior with the default hyperparameters:


# %%
def plot_ribbon(ax, x, dist, color):
mean = dist.mean()
std = dist.stddev()
ax.plot(x, mean, label="Predictive mean", color=color)
ax.fill_between(x.squeeze(), mean - 2 * std, mean + 2 * std, alpha=0.2, label="Two sigma", color=color)
ax.fill_between(
x.squeeze(),
mean - 2 * std,
mean + 2 * std,
alpha=0.2,
label="Two sigma",
color=color,
)
ax.plot(x, mean - 2 * std, linestyle="--", linewidth=1, color=color)
ax.plot(x, mean + 2 * std, linestyle="--", linewidth=1, color=color)



# %%
opt_latent_dist = opt_posterior.predict(test_x, train_data=D)
opt_predictive_dist = opt_posterior.likelihood(opt_latent_dist)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/likelihoods_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = plt.rcParams["axes.prop_cycle"].by_key()["color"]
key = jr.PRNGKey(123)
key = jr.key(123)


n = 50
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/oceanmodelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import gpjax as gpx

# Enable Float64 for more stable matrix inversions.
key = jr.PRNGKey(123)
key = jr.key(123)
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
tfd = tfp.distributions
key = jr.PRNGKey(123)
key = jr.key(123)
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
Expand Down
6 changes: 3 additions & 3 deletions docs/examples/pytrees.md
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,9 @@ class RBF(Module):
init=False, bijector=tfb.Softplus(), trainable=True
)
variance: float = param_field(init=False, bijector=tfb.Softplus(), trainable=True)
key: jr.KeyArray = field(default_factory = lambda: jr.PRNGKey(42))
key: jax.Array = field(default_factory = lambda: jr.key(42))
# Note, for Python <3.11 you may use the following:
# key: jr.KeyArray = jr.PRNGKey(42)
# key: jax.Array = jr.key(42)

def __post_init__(self):
# Split key into two keys
Expand Down Expand Up @@ -444,7 +444,7 @@ class RBF(Module):
init=False, bijector=tfb.Softplus(), trainable=True
)
variance: float = param_field(init=False, bijector=tfb.Softplus(), trainable=True)
key: jr.KeyArray = static_field(default_factory=lambda: jr.PRNGKey(42))
key: jax.Array = static_field(default_factory=lambda: jr.key(42))

def __post_init__(self):
# Split key into two keys
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
with install_import_hook("gpjax", "beartype.beartype"):
import gpjax as gpx

key = jr.PRNGKey(123)
key = jr.key(123)
plt.style.use(
"https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
Expand Down
Loading
Loading