feat(gpjax/kernels/base.py): add diagonal #429

Merged
3 changes: 3 additions & 0 deletions gpjax/kernels/base.py
@@ -63,6 +63,9 @@ def cross_covariance(self, x: Num[Array, "N D"], y: Num[Array, "M D"]):
def gram(self, x: Num[Array, "N D"]):
return self.compute_engine.gram(self, x)

def diagonal(self, x: Num[Array, "N D"]):
return self.compute_engine.diagonal(self, x)

def slice_input(self, x: Float[Array, "... D"]) -> Float[Array, "... Q"]:
r"""Slice out the relevant columns of the input matrix.

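A minimal usage sketch of the new method (the RBF kernel, input values, and shapes below are illustrative and mirror the imports used in the test files; they are not prescribed by this diff):

import jax.numpy as jnp
from gpjax.kernels import RBF

kernel = RBF(active_dims=[0, 1])
x = jnp.linspace(0.0, 1.0, 10).reshape(5, 2)

# kernel.diagonal(x) dispatches to the compute engine, as added above, and
# returns a cola Diagonal linear operator whose .diag field holds the
# variance entries k(x_i, x_i).
variances = kernel.diagonal(x).diag  # array of shape (5,)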
15 changes: 15 additions & 0 deletions gpjax/kernels/computations/basis_functions.py
@@ -12,6 +12,7 @@
from cola import PSD
from cola.ops import (
Dense,
Diagonal,
LinearOperator,
)

@@ -58,6 +59,20 @@ def gram(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> LinearOperator:
z1 = self.compute_features(kernel, inputs)
return PSD(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))

def diagonal(self, kernel: Kernel, inputs: Float[Array, "N D"]) -> Diagonal:
r"""For a given kernel, compute the elementwise diagonal of the
NxN gram matrix on an input matrix of shape NxD.

Args:
kernel (AbstractKernel): the kernel function.
inputs (Float[Array, "N D"]): The input matrix.

Returns
-------
Diagonal: The computed diagonal variance entries.
"""
return super().diagonal(kernel.base_kernel, inputs)

def compute_features(
self, kernel: Kernel, x: Float[Array, "N D"]
) -> Float[Array, "N L"]:
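Note the design choice in the override above: diagonal delegates to the parent computation with kernel.base_kernel, so the returned entries come from the base kernel itself rather than from the random-feature product that gram builds. A rough sketch of the distinction, assuming the RFF and RBF import paths used in the tests below:

import jax.numpy as jnp
import jax.random as jr
from gpjax.kernels import RBF, RFF

base_kernel = RBF(active_dims=[0])
approximate = RFF(base_kernel=base_kernel, num_basis_fns=10)
x = jr.uniform(jr.key(0), shape=(20, 1), minval=-3.0, maxval=3.0)

# Diagonal entries delegated to the base kernel (the override above).
exact_diag = approximate.diagonal(x).diag

# Diagonal of the random-feature Gram matrix: a Monte Carlo estimate that
# only approaches exact_diag as num_basis_fns grows.
approx_diag = jnp.diagonal(approximate.gram(x).to_dense())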
31 changes: 30 additions & 1 deletion tests/test_kernels/test_approximations.py
@@ -1,6 +1,9 @@
from typing import Tuple

from cola.ops import Dense
from cola.ops import (
Dense,
Diagonal,
)
import jax
from jax import config
import jax.numpy as jnp
@@ -63,6 +66,32 @@ def test_gram(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: int):
assert jnp.all(evals > 0)


@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52])
@pytest.mark.parametrize("num_basis_fns", [2, 10, 20])
@pytest.mark.parametrize("n_dims", [1, 2, 5])
@pytest.mark.parametrize("n_data", [50, 100])
def test_diagonal(kernel: AbstractKernel, num_basis_fns: int, n_dims: int, n_data: int):
key = jr.key(123)
x = jr.uniform(key, shape=(n_data, 1), minval=-3.0, maxval=3.0).reshape(-1, 1)
if n_dims > 1:
x = jnp.hstack([x] * n_dims)
base_kernel = kernel(active_dims=list(range(n_dims)))
approximate = RFF(base_kernel=base_kernel, num_basis_fns=num_basis_fns)

linop = approximate.diagonal(x)

# Check the return type
assert isinstance(linop, Diagonal)

Kxx = linop.diag + _jitter

# Check that the shape is correct
assert Kxx.shape == (n_data,)

# Check that the diagonal is positive
assert jnp.all(Kxx > 0)


@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52])
@pytest.mark.parametrize("num_basis_fns", [2, 10, 20])
@pytest.mark.parametrize("n_dims", [1, 2, 5])
29 changes: 27 additions & 2 deletions tests/test_kernels/test_nonstationary.py
@@ -16,7 +16,10 @@
from itertools import product
from typing import List

from cola.ops import LinearOperator
from cola.ops import (
Diagonal,
LinearOperator,
)
import jax
from jax import config
import jax.numpy as jnp
@@ -125,9 +128,28 @@ def test_gram(self, dim: int, n: int) -> None:

# Test gram matrix
Kxx = kernel.gram(x)
Kxx_cross = kernel.cross_covariance(x, x)
assert isinstance(Kxx, LinearOperator)
assert Kxx.shape == (n, n)
assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0)
assert jnp.allclose(Kxx_cross, Kxx.to_dense())

@pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}")
@pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}")
def test_diagonal(self, dim: int, n: int) -> None:
# Initialise kernel
kernel: AbstractKernel = self.kernel()

# Inputs
x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)

# Test diagonal
Kxx = kernel.diagonal(x)
Kxx_gram = jnp.diagonal(kernel.gram(x).to_dense())
assert isinstance(Kxx, Diagonal)
assert Kxx.shape == (n, n)
assert jnp.all(Kxx.diag + 1e-6 > 0.0)
assert jnp.allclose(Kxx_gram, Kxx.diag)

@pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}")
@pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}")
@@ -139,11 +161,14 @@ def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None:
# Inputs
a = jnp.linspace(-1.0, 1.0, n_a * dim).reshape(n_a, dim)
b = jnp.linspace(3.0, 4.0, n_b * dim).reshape(n_b, dim)
c = jnp.vstack((a, b))

# Test cross-covariance
Kab = kernel.cross_covariance(a, b)
Kab_gram = kernel.gram(c).to_dense()[:n_a, n_a:]
assert isinstance(Kab, jnp.ndarray)
assert Kab.shape == (n_a, n_b)
assert jnp.allclose(Kab, Kab_gram)


def prod(inp):
@@ -216,4 +241,4 @@ def test_values_by_monte_carlo_in_special_case(self, order: int) -> None:
integrands = H_a * H_b * (weights_a**order) * (weights_b**order)
Kab_approx = 2.0 * jnp.mean(integrands)

assert jnp.max(Kab_approx - Kab_exact) < 1e-4
assert jnp.max(jnp.abs(Kab_approx - Kab_exact)) < 1e-4
27 changes: 26 additions & 1 deletion tests/test_kernels/test_stationary.py
@@ -17,7 +17,10 @@
from dataclasses import is_dataclass
from itertools import product

from cola.ops import LinearOperator
from cola.ops import (
Diagonal,
LinearOperator,
)
import jax
from jax import config
import jax.numpy as jnp
@@ -129,9 +132,28 @@ def test_gram(self, dim: int, n: int) -> None:

# Test gram matrix
Kxx = kernel.gram(x)
Kxx_cross = kernel.cross_covariance(x, x)
assert isinstance(Kxx, LinearOperator)
assert Kxx.shape == (n, n)
assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0)
assert jnp.allclose(Kxx_cross, Kxx.to_dense())

@pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}")
@pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}")
def test_diagonal(self, dim: int, n: int) -> None:
# Initialise kernel
kernel: AbstractKernel = self.kernel()

# Inputs
x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)

# Test diagonal
Kxx = kernel.diagonal(x)
Kxx_gram = jnp.diagonal(kernel.gram(x).to_dense())
assert isinstance(Kxx, Diagonal)
assert Kxx.shape == (n, n)
assert jnp.all(Kxx.diag + 1e-6 > 0.0)
assert jnp.allclose(Kxx_gram, Kxx.diag)

@pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}")
@pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}")
@@ -143,11 +165,14 @@ def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None:
# Inputs
a = jnp.linspace(-1.0, 1.0, n_a * dim).reshape(n_a, dim)
b = jnp.linspace(3.0, 4.0, n_b * dim).reshape(n_b, dim)
c = jnp.vstack((a, b))

# Test cross-covariance
Kab = kernel.cross_covariance(a, b)
Kab_gram = kernel.gram(c).to_dense()[:n_a, n_a:]
assert isinstance(Kab, jnp.ndarray)
assert Kab.shape == (n_a, n_b)
assert jnp.allclose(Kab, Kab_gram)

def test_spectral_density(self):
# Initialise kernel