diff --git a/examples/natgrads.pct.py b/examples/natgrads.pct.py index c14074a35..45c7dc7d0 100644 --- a/examples/natgrads.pct.py +++ b/examples/natgrads.pct.py @@ -25,7 +25,6 @@ import jax.random as jr import matplotlib.pyplot as plt import optax as ox -from jax import jit, lax from jax.config import config import gpjax as gpx @@ -97,7 +96,7 @@ n_iters=5000, batch_size=256, key=jr.PRNGKey(42), - moment_optim=ox.sgd(0.1), + moment_optim=ox.sgd(0.01), hyper_optim=ox.adam(1e-3), ) diff --git a/gpjax/covariance_operator.py b/gpjax/covariance_operator.py deleted file mode 100644 index fca0815ba..000000000 --- a/gpjax/covariance_operator.py +++ /dev/null @@ -1,426 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import abc -from typing import Callable, Optional, Tuple, Union - -import jax.numpy as jnp -import jax.scipy as jsp -from chex import dataclass -from jax import lax -from jaxtyping import Array, Float - - -@dataclass -class CovarianceOperator: - """Multivariate Gaussian covariance operator base class. - - Inspired by TensorFlows' LinearOperator class. - """ - - name: Optional[str] = None - - def __sub__(self, other: "CovarianceOperator") -> "CovarianceOperator": - """Subtract two covariance operators. - - Args: - other (CovarianceOperator): Other covariance operator. - - Returns: - CovarianceOperator: Difference of the two covariance operators. - """ - - return self + (other * -1) - - def __rsub__(self, other: "CovarianceOperator") -> "CovarianceOperator": - """Reimplimentation of subtracting two covariance operators. - - Args: - other (CovarianceOperator): Other covariance operator. - - Returns: - CovarianceOperator: Difference of the two covariance operators. - """ - return (self * -1) + other - - def __add__( - self, other: Union["CovarianceOperator", Float[Array, "N N"]] - ) -> "CovarianceOperator": - """Add diagonal to another covariance operator. - - Args: - other (Union["CovarianceOperator", Float[Array, "N N"]]): Other - covariance operator. Dimension of both operators must match. - If the other covariance operator is not a - DiagonalCovarianceOperator, dense matrix addition is used. - - Returns: - CovarianceOperator: Covariance operator plus the diagonal covariance operator. - """ - - # Check shapes: - if not (other.shape == self.shape): - raise ValueError( - f"Shape mismatch: {self.shape} and {other.shape} are not equal." - ) - - # If other is a JAX array, we convert it to a DenseCovarianceOperator - if isinstance(other, jnp.ndarray): - other = DenseCovarianceOperator(matrix=other) - - # Matix addition: - if isinstance(other, DiagonalCovarianceOperator): - return self._add_diagonal(other) - - if isinstance(self, DiagonalCovarianceOperator): - return other._add_diagonal(self) - - elif isinstance(other, CovarianceOperator): - - return DenseCovarianceOperator(matrix=self.to_dense() + other.to_dense()) - - else: - raise NotImplementedError - - def __radd__( - self, other: Union["CovarianceOperator", Float[Array, "N N"]] - ) -> "CovarianceOperator": - return self.__add__(other) - - def __mul__(self, other: float) -> "CovarianceOperator": - """Multiply covariance operator by scalar. - - Args: - other (CovarianceOperator): Scalar. - - Returns: - CovarianceOperator: Covariance operator multiplied by scalar. - """ - - raise NotImplementedError - - def __rmul__(self, other: float) -> "CovarianceOperator": - return self.__mul__(other) - - @abc.abstractmethod - def _add_diagonal( - self, other: "DiagonalCovarianceOperator" - ) -> "CovarianceOperator": - """ - Add diagonal matrix to a linear operator, useful for computing, Kxx + Iσ². - """ - return NotImplementedError - - @abc.abstractmethod - def __matmul__(self, x: Float[Array, "N M"]) -> Float[Array, "N M"]: - """Matrix multiplication. - - Args: - x (Float[Array, "N M"]): Matrix to multiply with. - - Returns: - Float[Array, "N M"]: Result of matrix multiplication. - """ - raise NotImplementedError - - @property - @abc.abstractmethod - def shape(self) -> Tuple[int, int]: - """Covaraince matrix shape. - - Returns: - Tuple[int, int]: shape of the covariance operator. - """ - raise NotImplementedError - - @abc.abstractmethod - def to_dense(self) -> Float[Array, "N N"]: - """Construct dense Covaraince matrix from the covariance operator. - - Returns: - Float[Array, "N N"]: Dense covariance matrix. - """ - raise NotImplementedError - - @abc.abstractmethod - def diagonal(self) -> Float[Array, "N"]: - """Construct covaraince matrix diagonal from the covariance operator. - - Returns: - Float[Array, "N"]: Covariance matrix diagonal. - """ - raise NotImplementedError - - @abc.abstractmethod - def triangular_lower(self) -> Float[Array, "N N"]: - """Compute lower triangular. - - Returns: - Float[Array, "N N"]: Lower triangular of the covariance matrix. - """ - raise NotImplementedError - - def log_det(self) -> Float[Array, "1"]: - """Log determinant of the covariance matrix. - - Returns: - Float[Array, "1"]: Log determinant of the covariance matrix. - """ - - return 2.0 * jnp.sum(jnp.log(jnp.diag(self.triangular_lower()))) - - def solve(self, rhs: Float[Array, "N M"]) -> Float[Array, "N M"]: - """Solve linear system. - - Args: - rhs (Float[Array, "N M"]): Right hand side of the linear system. - - Returns: - Float[Array, "N M"]: Solution of the linear system. - """ - return jsp.linalg.cho_solve((self.triangular_lower(), True), rhs) - - def trace(self) -> Float[Array, "1"]: - """Trace of the covariance matrix. - - Returns: - Float[Array, "1"]: Trace of the covariance matrix. - """ - return jnp.sum(self.diagonal()) - - -@dataclass -class _DenseMatrix: - matrix: Float[Array, "N N"] - - -@dataclass -class DenseCovarianceOperator(CovarianceOperator, _DenseMatrix): - """Dense covariance operator.""" - - name: Optional[str] = "Dense covariance operator" - - def __mul__(self, other: float) -> "CovarianceOperator": - """Multiply covariance operator by scalar. - - Args: - other (CovarianceOperator): Scalar. - - Returns: - CovarianceOperator: Covariance operator multiplied by a scalar. - """ - - return DenseCovarianceOperator(matrix=self.matrix * other) - - def _add_diagonal( - self, other: "DiagonalCovarianceOperator" - ) -> "CovarianceOperator": - """Add diagonal to the covariance operator, useful for - computing, :math:`\\mathbf{K}_{xx} + \\mathbf{I}\\sigma^2`. - - Args: - other (DiagonalCovarianceOperator): Diagonal covariance - operator to add to the covariance operator. - - Returns: - CovarianceOperator: Sum of the two covariance operators. - """ - - n = self.shape[0] - diag_indices = jnp.diag_indices(n) - new_matrix = self.matrix.at[diag_indices].add(other.diagonal()) - - return DenseCovarianceOperator(matrix=new_matrix) - - @property - def shape(self) -> Tuple[int, int]: - """Covaraince matrix shape. - - Returns: - Tuple[int, int]: shape of the covariance operator. - """ - return self.matrix.shape - - def to_dense(self) -> Float[Array, "N N"]: - """Construct dense Covaraince matrix from the covariance operator. - - Returns: - Float[Array, "N N"]: Dense covariance matrix. - """ - return self.matrix - - def diagonal(self) -> Float[Array, "N"]: - """ - Diagonal of the covariance operator. - - Returns: - Float[Array, "N"]: The diagonal of the covariance operator. - """ - - return jnp.diag(self.matrix) - - def __matmul__(self, x: Float[Array, "N M"]) -> Float[Array, "N M"]: - """Matrix multiplication. - - Args: - x (Float[Array, "N M"]): Matrix to multiply with. - - Returns: - Float[Array, "N M"]: Result of matrix multiplication. - """ - - return jnp.matmul(self.matrix, x) - - def triangular_lower(self) -> Float[Array, "N N"]: - """Compute lower triangular. - - Returns: - Float[Array, "N N"]: Lower triangular of the covariance matrix. - """ - return jnp.linalg.cholesky(self.matrix) - - -@dataclass -class _DiagonalMatrix: - diag: Float[Array, "N"] - - -@dataclass -class DiagonalCovarianceOperator(CovarianceOperator, _DiagonalMatrix): - """Diagonal covariance operator.""" - - name: Optional[str] = "Diagonal covariance operator" - - def __mul__(self, other: float) -> "CovarianceOperator": - """Multiply covariance operator by scalar. - - Args: - other (CovarianceOperator): Scalar. - - Returns: - CovarianceOperator: Covariance operator multiplied by a scalar. - """ - - return DiagonalCovarianceOperator(diag=self.diag * other) - - def _add_diagonal( - self, other: "DiagonalCovarianceOperator" - ) -> "CovarianceOperator": - """Add diagonal to the covariance operator, useful for computing, - :math:`\\mathbf{K}_{xx} + \\mathbf{I}\\sigma^2` - - Args: - other (DiagonalCovarianceOperator): Diagonal covariance - operator to add to the covariance operator. - - Returns: - CovarianceOperator: Covariance operator with the diagonal added. - """ - - return DiagonalCovarianceOperator(diag=self.diag + other.diagonal()) - - @property - def shape(self) -> Tuple[int, int]: - """Covaraince matrix shape. - - Returns: - Tuple[int, int]: shape of the covariance operator. - """ - N = self.diag.shape[0] - return (N, N) - - def to_dense(self) -> Float[Array, "N N"]: - """Construct dense Covaraince matrix from the covariance operator. - - Returns: - Float[Array, "N N"]: Dense covariance matrix. - """ - return jnp.diag(self.diag) - - def diagonal(self) -> Float[Array, "N"]: - """ - Diagonal of the covariance operator. - - Returns: - Float[Array, "N"]: The diagonal of the covariance operator. - """ - return self.diag - - def __matmul__(self, x: Float[Array, "N M"]) -> Float[Array, "N M"]: - """Matrix multiplication. - - Args: - x (Float[Array, "N M"]): Matrix to multiply with. - - Returns: - Float[Array, "N M"]: Result of matrix multiplication. - """ - diag_mat = jnp.expand_dims(self.diag, -1) - return diag_mat * x - - def triangular_lower(self) -> Float[Array, "N N"]: - """ - Lower triangular. - - Returns: - Float[Array, "N N"]: Lower triangular matrix. - """ - return jnp.diag(jnp.sqrt(self.diag)) - - def log_det(self) -> Float[Array, "1"]: - """Log determinant. - - Returns: - Float[Array, "1"]: Log determinant of the covariance matrix. - """ - return 2.0 * jnp.sum(jnp.log(self.diag)) - - def solve(self, rhs: Float[Array, "N M"]) -> Float[Array, "N M"]: - """Solve linear system. - - Args: - rhs (Float[Array, "N M"]): Right hand side of the linear system. - - Returns: - Float[Array, "N M"]: Solution of the linear system. - """ - inv_diag_mat = jnp.expand_dims(1.0 / self.diag, -1) - return rhs * inv_diag_mat - - -def I(n: int) -> DiagonalCovarianceOperator: - """Identity matrix. - - Args: - n (int): Size of the identity matrix. - - Returns: - DiagonalCovarianceOperator: Identity matrix of shape nxn. - """ - - I = DiagonalCovarianceOperator( - diag=jnp.ones(n), - name="Identity matrix", - ) - - return I - - -__all__ = [ - "CovarianceOperator", - "DenseCoarianceOperator", - "DiagonalCovarianceOperator", - "I", -] diff --git a/gpjax/gaussian_distribution.py b/gpjax/gaussian_distribution.py new file mode 100644 index 000000000..db5a487cd --- /dev/null +++ b/gpjax/gaussian_distribution.py @@ -0,0 +1,252 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import jax.numpy as jnp +from jaxlinop import LinearOperator, IdentityLinearOperator + +from jaxtyping import Array, Float +from jax import vmap + +from typing import Tuple, Optional, Any + +import distrax as dx +import jax.random as jr +from jax.random import KeyArray + + +def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None: + """Checks that the inputs are correct.""" + + if loc is None and scale is None: + raise ValueError("At least one of `loc` or `scale` must be specified.") + + if loc is not None and loc.ndim < 1: + raise ValueError("The parameter `loc` must have at least one dimension.") + + if scale is not None and scale.ndim < 2: + raise ValueError( + f"The `scale` must have at least two dimensions, but " + f"`scale.shape = {scale.shape}`." + ) + + if scale is not None and not isinstance(scale, LinearOperator): + raise ValueError( + f"scale must be a LinearOperator or a JAX array, but got {type(scale)}" + ) + + if scale is not None and (scale.shape[-1] != scale.shape[-2]): + raise ValueError( + f"The `scale` must be a square matrix, but " + f"`scale.shape = {scale.shape}`." + ) + + if loc is not None: + num_dims = loc.shape[-1] + if scale is not None and (scale.shape[-1] != num_dims): + raise ValueError( + f"Shapes are not compatible: `loc.shape = {loc.shape}` and " + f"`scale.shape = {scale.shape}`." + ) + + +class GaussianDistribution(dx.Distribution): + """Multivariate Gaussian distribution with a linear operator scale matrix. + + Args: + loc (Optional[Float[Array, "N"]]): The mean of the distribution. Defaults to None. + scale (Optional[LinearOperator]): The scale matrix of the distribution. Defaults to None. + + Returns: + GaussianDistribution: A multivariate Gaussian distribution with a linear operator scale matrix. + """ + + # TODO: Consider `distrax.transformed.Transformed` object. Can we create a LinearOperator to `distrax.bijector` representation + # and modify `distrax.MultivariateNormalFromBijector`? + + # TODO: Consider natural and expectation parameterisations in future work. + + def __init__( + self, + loc: Optional[Float[Array, "N"]] = None, + scale: Optional[LinearOperator] = None, + ) -> None: + """Initialises the distribution.""" + + _check_loc_scale(loc, scale) + + # Find dimensionality of the distribution. + if loc is not None: + num_dims = loc.shape[-1] + + elif scale is not None: + num_dims = scale.shape[-1] + + # Set the location to zero vector if unspecified. + if loc is None: + loc = jnp.zeros((num_dims,)) + + # If not specified, set the scale to the identity matrix. + if scale is None: + scale = IdentityLinearOperator(num_dims) + + self.loc = loc + self.scale = scale + + def mean(self) -> Float[Array, "N"]: + """Calculates the mean.""" + return self.loc + + def median(self) -> Float[Array, "N"]: + """Calculates the median.""" + return self.loc + + def mode(self) -> Float[Array, "N"]: + """Calculates the mode.""" + return self.loc + + def covariance(self) -> Float[Array, "N N"]: + """Calculates the covariance matrix.""" + return self.scale.to_dense() + + def variance(self) -> Float[Array, "N"]: + """Calculates the variance.""" + return self.scale.diagonal() + + def stddev(self) -> Float[Array, "N"]: + """Calculates the standard deviation.""" + return jnp.sqrt(self.scale.diagonal()) + + @property + def event_shape(self) -> Tuple: + """Returns the event shape.""" + return self.loc.shape[-1:] + + def entropy(self) -> Float[Array, "1"]: + """Calculates the entropy of the distribution.""" + return 0.5 * ( + self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi)) + self.scale.log_det() + ) + + def log_prob(self, y: Float[Array, "N"]) -> Float[Array, "1"]: + """Calculates the log pdf of the multivariate Gaussian. + + Args: + y (Float[Array, "N"]): The value to calculate the log probability of. + + Returns: + Float[Array, "1"]: The log probability of the value. + """ + mu = self.loc + sigma = self.scale + n = mu.shape[-1] + + # diff, y - µ + diff = y - mu + + # compute the pdf, -1/2[ n log(2π) + log|Σ| + (y - µ)ᵀΣ⁻¹(y - µ) ] + return -0.5 * ( + n * jnp.log(2.0 * jnp.pi) + sigma.log_det() + diff.T @ sigma.solve(diff) + ) + + def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]: + """Samples from the distribution. + + Args: + key (KeyArray): The key to use for sampling. + + Returns: + Float[Array, "n N"]: The samples. + """ + # Obtain covariance root. + sqrt = self.scale.to_root() + + # Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ. + Z = jr.normal(key, shape=(n, *self.event_shape)) + + # xᵢ ~ N(loc, cov) <=> xᵢ = loc + sqrt zᵢ, where zᵢ ~ N(0, I). + affine_transformation = lambda x: self.loc + sqrt @ x + + return vmap(affine_transformation)(Z) + + def kl_divergence(self, other: "GaussianDistribution") -> Float[Array, "1"]: + return _kl_divergence(self, other) + + +def _check_and_return_dimension( + q: GaussianDistribution, p: GaussianDistribution +) -> int: + """Checks that the dimensions of the distributions are compatible.""" + if q.event_shape != p.event_shape: + raise ValueError( + f"Distribution event shapes are not compatible: `q.event_shape = {q.event_shape}` and " + f"`p.event_shape = {p.event_shape}`. Please check your mean and covariance shapes." + ) + + return q.event_shape[-1] + + +def _frobeinius_norm_squared(matrix: Float[Array, "N N"]) -> Float[Array, "1"]: + """Calculates the squared Frobenius norm of a matrix.""" + return jnp.sum(jnp.square(matrix)) + + +def _kl_divergence( + q: GaussianDistribution, p: GaussianDistribution +) -> Float[Array, "1"]: + """Computes the KL divergence, KL[q||p], between two multivariate Gaussian distributions + q(x) = N(x; μq, Σq) and p(x) = N(x; μp, Σp). + + Args: + q (GaussianDistribution): A multivariate Gaussian distribution. + p (GaussianDistribution): A multivariate Gaussia distribution. + + Returns: + Float[Array, "1"]: The KL divergence between q and p. + """ + + n_dim = _check_and_return_dimension(q, p) + + # Extract q mean and covariance. + mu_q = q.loc + sigma_q = q.scale + + # Extract p mean and covariance. + mu_p = p.loc + sigma_p = p.scale + + # Find covariance roots. + sqrt_p = sigma_p.to_root() + sqrt_q = sigma_q.to_root() + + # diff, μp - μq + diff = mu_p - mu_q + + # trace term, tr[Σp⁻¹ Σq] = tr[(LpLpᵀ)⁻¹(LqLqᵀ)] = tr[(Lp⁻¹Lq)(Lp⁻¹Lq)ᵀ] = (fr[LqLp⁻¹])² + trace = _frobeinius_norm_squared( + sqrt_p.solve(sqrt_q.to_dense()) + ) # TODO: Not most efficient, given the `to_dense()` call (e.g., consider diagonal p and q). Need to abstract solving linear operator against another linear operator. + + # Mahalanobis term, (μp - μq)ᵀ Σp⁻¹ (μp - μq) = tr [(μp - μq)ᵀ [LpLpᵀ]⁻¹ (μp - μq)] = (fr[Lp⁻¹(μp - μq)])² + mahalanobis = _frobeinius_norm_squared( + sqrt_p.solve(diff) + ) # TODO: Need to improve this. Perhaps add a Mahalanobis method to LinearOperators. + + # KL[q(x)||p(x)] = [ [(μp - μq)ᵀ Σp⁻¹ (μp - μq)] - n - log|Σq| + log|Σp| + tr[Σp⁻¹ Σq] ] / 2 + return (mahalanobis - n_dim - sigma_q.log_det() + sigma_p.log_det() + trace) / 2.0 + + +__all__ = [ + "GaussianDistribution", +] diff --git a/gpjax/gps.py b/gpjax/gps.py index 73a68a334..66dd7e1be 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -19,17 +19,18 @@ import distrax as dx import jax.numpy as jnp import jax.random as jr -import jax.scipy as jsp from chex import dataclass from jaxtyping import Array, Float +from jaxlinop import identity + from .config import get_defaults -from .covariance_operator import I from .kernels import AbstractKernel from .likelihoods import AbstractLikelihood, Conjugate, Gaussian, NonConjugate from .mean_functions import AbstractMeanFunction, Zero from .types import Dataset, PRNGKeyType from .utils import concat_dictionaries +from .gaussian_distribution import GaussianDistribution @dataclass @@ -180,7 +181,7 @@ def __rmul__(self, other: AbstractLikelihood): def predict( self, params: Dict - ) -> Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: + ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: """Compute the predictive prior distribution for a given set of parameters. The output of this function is a function that computes a distrx distribution for a given set of inputs. @@ -204,7 +205,7 @@ def predict( function should be defined for. Returns: - Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A mean + Callable[[Float[Array, "N D"]], GaussianDistribution]: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned. @@ -218,7 +219,7 @@ def predict( # Unpack kernel computation gram = kernel.gram - def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: + def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: # Unpack test inputs t = test_inputs @@ -226,10 +227,9 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: μt = mean_function(params["mean_function"], t) Ktt = gram(kernel, params["kernel"], t) - Ktt += I(n_test) * jitter - Lt = Ktt.triangular_lower() + Ktt += identity(n_test) * jitter - return dx.MultivariateNormalTri(jnp.atleast_1d(μt.squeeze()), Lt) + return GaussianDistribution(jnp.atleast_1d(μt.squeeze()), Ktt) return predict_fn @@ -273,7 +273,7 @@ class AbstractPosterior(AbstractPrior): name: Optional[str] = "GP posterior" @abstractmethod - def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: + def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: """Compute the predictive posterior distribution of the latent function for a given set of parameters. For any class inheriting the ``AbstractPosterior`` class, this method must be implemented. @@ -283,7 +283,7 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: Keyword arguments to the predict method. Returns: - dx.Distribution: A multivariate normal random variable + GaussianDistribution: A multivariate normal random variable representation of the Gaussian process. """ raise NotImplementedError @@ -350,7 +350,7 @@ def predict( self, params: Dict, train_data: Dataset, - ) -> Callable[[Float[Array, "N D"]], dx.MultivariateNormalFullCovariance]: + ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: """Conditional on a training data set, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the @@ -392,9 +392,9 @@ def predict( input and output data used for training dataset. Returns: - Callable[[Float[Array, "N D"]], dx.MultivariateNormalFullCovariance]: A + Callable[[Float[Array, "N D"]], GaussianDistribution]: A function that accepts an input array and returns the predictive - distribution as a ``dx.MultivariateNormalTri``. + distribution as a ``GaussianDistribution``. """ jitter = get_defaults()["jitter"] @@ -415,19 +415,19 @@ def predict( # Precompute Gram matrix, Kxx, at training inputs, x Kxx = gram(kernel, params["kernel"], x) - Kxx += I(n) * jitter + Kxx += identity(n) * jitter # Σ = Kxx + Iσ² - Sigma = Kxx + I(n) * obs_noise + Sigma = Kxx + identity(n) * obs_noise - def predict(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + def predict(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: """Compute the predictive distribution at a set of test inputs. Args: test_inputs (Float[Array, "N D"]): A Jax array of test inputs. Returns: - dx.Distribution: A ``dx.MultivariateNormalFullCovariance`` + GaussianDistribution: A ``GaussianDistribution`` object that represents the predictive distribution. """ @@ -439,22 +439,17 @@ def predict(test_inputs: Float[Array, "N D"]) -> dx.Distribution: Ktt = gram(kernel, params["kernel"], t) Kxt = cross_covariance(kernel, params["kernel"], x, t) - # TODO: Investigate lower triangular solves for general covariance operators - # this is more efficient than the full solve for dense matrices in the current implimentation. - # Σ⁻¹ Kxt Sigma_inv_Kxt = Sigma.solve(Kxt) # μt + Ktx (Kxx + Iσ²)⁻¹ (y - μx) mean = μt + jnp.matmul(Sigma_inv_Kxt.T, y - μx) - # Ktt - Ktx (Kxx + Iσ²)⁻¹ Kxt + # Ktt - Ktx (Kxx + Iσ²)⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently. covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt) - covariance += I(n_test) * jitter + covariance += identity(n_test) * jitter - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance.to_dense() - ) + return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) return predict @@ -560,14 +555,12 @@ def mll( # Σ = (Kxx + Iσ²) = LLᵀ Kxx = gram(kernel, params["kernel"], x) - Kxx += I(n) * jitter - Sigma = Kxx + I(n) * obs_noise - L = Sigma.triangular_lower() + Kxx += identity(n) * jitter + Sigma = Kxx + identity(n) * obs_noise # p(y | x, θ), where θ are the model hyperparameters: - - marginal_likelihood = dx.MultivariateNormalTri( - jnp.atleast_1d(μx.squeeze()), L + marginal_likelihood = GaussianDistribution( + jnp.atleast_1d(μx.squeeze()), Sigma ) return constant * ( @@ -656,8 +649,8 @@ def predict( # Precompute lower triangular of Gram matrix, Lx, at training inputs, x Kxx = gram(kernel, params["kernel"], x) - Kxx += I(n) * jitter - Lx = Kxx.triangular_lower() + Kxx += identity(n) * jitter + Lx = Kxx.to_root() def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: """Predictive distribution of the latent function for a given set of test inputs. @@ -674,11 +667,11 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: # Compute terms of the posterior predictive distribution Ktx = cross_covariance(kernel, params["kernel"], t, x) - Ktt = gram(kernel, params["kernel"], t) + I(n_test) * jitter + Ktt = gram(kernel, params["kernel"], t) + identity(n_test) * jitter μt = mean_function(params["mean_function"], t) # Lx⁻¹ Kxt - Lx_inv_Kxt = jsp.linalg.solve_triangular(Lx, Ktx.T, lower=True) + Lx_inv_Kxt = Lx.solve(Ktx.T) # Whitened function values, wx, correponding to the inputs, x wx = params["latent"] @@ -686,13 +679,11 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: # μt + Ktx Lx⁻¹ wx mean = μt + jnp.matmul(Lx_inv_Kxt.T, wx) - # Ktt - Ktx Kxx⁻¹ Kxt + # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently. covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) - covariance += I(n_test) * jitter + covariance += identity(n_test) * jitter - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance.to_dense() - ) + return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) return predict_fn @@ -761,8 +752,8 @@ def mll(params: Dict): # Compute lower triangular of the kernel Gram matrix Kxx = gram(kernel, params["kernel"], x) - Kxx += I(n) * jitter - Lx = Kxx.triangular_lower() + Kxx += identity(n) * jitter + Lx = Kxx.to_root() # Compute the prior mean function μx = mean_function(params["mean_function"], x) @@ -771,7 +762,7 @@ def mll(params: Dict): wx = params["latent"] # f(x) = μx + Lx wx - fx = μx + jnp.matmul(Lx, wx) + fx = μx + Lx @ wx # p(y | f(x), θ), where θ are the model hyperparameters likelihood = link_function(params, fx) diff --git a/gpjax/kernels.py b/gpjax/kernels.py index df4d7f439..827008c9b 100644 --- a/gpjax/kernels.py +++ b/gpjax/kernels.py @@ -16,17 +16,19 @@ import abc from typing import Callable, Dict, List, Optional, Sequence +from jaxlinop import( + LinearOperator, + DenseLinearOperator, + DiagonalLinearOperator, + ConstantDiagonalLinearOperator, +) + import jax.numpy as jnp from chex import dataclass from jax import vmap from jaxtyping import Array, Float from .config import get_defaults -from .covariance_operator import ( - CovarianceOperator, - DenseCovarianceOperator, - DiagonalCovarianceOperator, -) from .types import PRNGKeyType JITTER = get_defaults()["jitter"] @@ -131,7 +133,7 @@ def gram( kernel: AbstractKernel, params: Dict, inputs: Float[Array, "N D"], - ) -> CovarianceOperator: + ) -> LinearOperator: """Compute Gram covariance operator of the kernel function. @@ -141,7 +143,7 @@ def gram( inputs (Float[Array, "N N"]): The inputs to the kernel function. Returns: - CovarianceOperator: Gram covariance operator of the kernel function. + LinearOperator: Gram covariance operator of the kernel function. """ raise NotImplementedError @@ -176,7 +178,7 @@ def diagonal( kernel: AbstractKernel, params: Dict, inputs: Float[Array, "N D"], - ) -> CovarianceOperator: + ) -> DiagonalLinearOperator: """For a given kernel, compute the elementwise diagonal of the NxN gram matrix on an input matrix of shape NxD. @@ -187,12 +189,12 @@ def diagonal( inputs (Float[Array, "N D"]): The input matrix. Returns: - CovarianceOperator: The computed diagonal variance entries. + LinearOperator: The computed diagonal variance entries. """ diag = vmap(lambda x: kernel(params, x, x))(inputs) - return DiagonalCovarianceOperator(diag=diag) + return DiagonalLinearOperator(diag=diag) class DenseKernelComputation(AbstractKernelComputation): @@ -205,7 +207,7 @@ def gram( kernel: AbstractKernel, params: Dict, inputs: Float[Array, "N D"], - ) -> CovarianceOperator: + ) -> DenseLinearOperator: """For a given kernel, compute the NxN gram matrix on an input matrix of shape NxD. @@ -221,7 +223,7 @@ def gram( matrix = vmap(lambda x: vmap(lambda y: kernel(params, x, y))(inputs))(inputs) - return DenseCovarianceOperator(matrix=matrix) + return DenseLinearOperator(matrix=matrix) class DiagonalKernelComputation(AbstractKernelComputation): @@ -230,7 +232,7 @@ def gram( kernel: AbstractKernel, params: Dict, inputs: Float[Array, "N D"], - ) -> CovarianceOperator: + ) -> DiagonalLinearOperator: """For a kernel with diagonal structure, compute the NxN gram matrix on an input matrix of shape NxD. @@ -246,7 +248,56 @@ def gram( diag = vmap(lambda x: kernel(params, x, x))(inputs) - return DiagonalCovarianceOperator(diag=diag) + return DiagonalLinearOperator(diag=diag) + + +class ConstantDiagonalKernelComputation(AbstractKernelComputation): + @staticmethod + def gram( + kernel: AbstractKernel, + params: Dict, + inputs: Float[Array, "N D"], + ) -> ConstantDiagonalLinearOperator: + """For a kernel with diagonal structure, compute the NxN gram matrix on + an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): The kernel for which the Gram matrix + should be computed for. + params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. + + Returns: + CovarianceOperator: The computed square Gram matrix. + """ + + value = kernel(params, inputs[0], inputs[0]) + + return ConstantDiagonalLinearOperator(value=value, size=inputs.shape[0]) + + + @staticmethod + def diagonal( + kernel: AbstractKernel, + params: Dict, + inputs: Float[Array, "N D"], + ) -> DiagonalLinearOperator: + """For a given kernel, compute the elementwise diagonal of the + NxN gram matrix on an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): The kernel for which the variance + vector should be computed for. + params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. + + Returns: + LinearOperator: The computed diagonal variance entries. + """ + + diag = vmap(lambda x: kernel(params, x, x))(inputs) + + return DiagonalLinearOperator(diag=diag) @dataclass @@ -536,7 +587,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: @dataclass(repr=False) -class White(AbstractKernel, DiagonalKernelComputation): +class White(AbstractKernel, ConstantDiagonalKernelComputation): def __post_init__(self) -> None: self.ndims = 1 if not self.active_dims else len(self.active_dims) diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 0e15ab0ac..6b39c1d2c 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -15,6 +15,7 @@ import abc from typing import Any, Callable, Dict, Optional +from jaxlinop.utils import to_dense import distrax as dx import jax.numpy as jnp @@ -22,8 +23,7 @@ from chex import dataclass from jaxtyping import Array, Float -from .config import get_defaults -from .types import PRNGKeyType +from jax.random import KeyArray @dataclass @@ -59,11 +59,11 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: raise NotImplementedError @abc.abstractmethod - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Return the parameters of the likelihood function. Args: - key (PRNGKeyType): A PRNG key. + key (KeyArray): A PRNG key. Returns: Dict: The parameters of the likelihood function. @@ -98,11 +98,11 @@ class Gaussian(AbstractLikelihood, Conjugate): name: Optional[str] = "Gaussian" - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Return the variance parameter of the likelihood function. Args: - key (PRNGKeyType): A PRNG key. + key (KeyArray): A PRNG key. Returns: Dict: The parameters of the likelihood function. @@ -149,7 +149,7 @@ def predict(self, params: Dict, dist: dx.MultivariateNormalTri) -> dx.Distributi dx.Distribution: The predictive distribution. """ n_data = dist.event_shape[0] - cov = dist.covariance() + cov = to_dense(dist.covariance()) noisy_cov = cov.at[jnp.diag_indices(n_data)].add( params["likelihood"]["obs_noise"] ) @@ -161,11 +161,11 @@ def predict(self, params: Dict, dist: dx.MultivariateNormalTri) -> dx.Distributi class Bernoulli(AbstractLikelihood, NonConjugate): name: Optional[str] = "Bernoulli" - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Initialise the parameter set of a Bernoulli likelihood. Args: - key (PRNGKeyType): A PRNG key. + key (KeyArray): A PRNG key. Returns: Dict: The parameters of the likelihood function (empty for the Bernoulli likelihood). diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index e3bc56b1f..31150ffdf 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -22,12 +22,15 @@ from chex import dataclass from jaxtyping import Array, Float +from jaxlinop import identity +import jaxlinop as jlo + from .config import get_defaults -from .covariance_operator import I from .gps import Prior from .likelihoods import AbstractLikelihood, Gaussian from .types import Dataset, PRNGKeyType from .utils import concat_dictionaries +from .gaussian_distribution import GaussianDistribution @dataclass @@ -37,7 +40,7 @@ class AbstractVariationalFamily: used within variational inference. """ - def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: + def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution: """For a given set of parameters, compute the latent function's prediction under the variational approximation. @@ -47,7 +50,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: method. Returns: - Any: The output of the variational family's `predict` method. + GaussianDistribution: The output of the variational family's `predict` method. """ return self.predict(*args, **kwargs) @@ -66,7 +69,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: raise NotImplementedError @abc.abstractmethod - def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: + def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: """Predict the GP's output given the input. Args: @@ -76,7 +79,7 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: ``predict`` method. Returns: - Any: The output of the variational family's ``predict`` method. + GaussianDistribution: The output of the variational family's ``predict`` method. """ raise NotImplementedError @@ -167,17 +170,19 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: μz = mean_function(params["mean_function"], z) Kzz = gram(kernel, params["kernel"], z) - Kzz += I(m) * jitter - Lz = Kzz.triangular_lower() + Kzz += identity(m) * jitter + + sqrt = jlo.LowerTriangularLinearOperator.from_dense(sqrt) + S = jlo.DenseLinearOperator.from_root(sqrt) - qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) - pu = dx.MultivariateNormalTri(jnp.atleast_1d(μz.squeeze()), Lz) + qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S) + pu = GaussianDistribution(loc=jnp.atleast_1d(μz.squeeze()), scale=Kzz) - return kld_dense_dense(qu, pu) + return qu.kl_divergence(pu) def predict( self, params: Dict - ) -> Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: + ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: """ Compute the predictive distribution of the GP at the test inputs t. @@ -212,13 +217,11 @@ def predict( cross_covariance = kernel.cross_covariance Kzz = gram(kernel, params["kernel"], z) - Kzz += I(m) * jitter - Lz = Kzz.triangular_lower() + Kzz += identity(m) * jitter + Lz = Kzz.to_root() μz = mean_function(params["mean_function"], z) - def predict_fn( - test_inputs: Float[Array, "N D"] - ) -> dx.MultivariateNormalFullCovariance: + def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: # Unpack test inputs t, n_test = test_inputs, test_inputs.shape[0] @@ -228,10 +231,10 @@ def predict_fn( μt = mean_function(params["mean_function"], t) # Lz⁻¹ Kzt - Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) + Lz_inv_Kzt = Lz.solve(Kzt) # Kzz⁻¹ Kzt - Kzz_inv_Kzt = jsp.linalg.solve_triangular(Lz.T, Lz_inv_Kzt, lower=False) + Kzz_inv_Kzt = Lz.T.solve(Lz_inv_Kzt) # Ktz Kzz⁻¹ sqrt Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt) @@ -245,10 +248,10 @@ def predict_fn( - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) ) - covariance += I(n_test) * jitter + covariance += identity(n_test) * jitter - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance.to_dense() + return GaussianDistribution( + loc=jnp.atleast_1d(mean.squeeze()), scale=covariance ) return predict_fn @@ -287,13 +290,17 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] + sqrt = jlo.LowerTriangularLinearOperator.from_dense(sqrt) + S = jlo.DenseLinearOperator.from_root(sqrt) + # Compute whitened KL divergence - qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) - return kld_dense_white(qu) + qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S) + pu = GaussianDistribution(loc=jnp.zeros_like(jnp.atleast_1d(mu.squeeze()))) + return qu.kl_divergence(pu) def predict( self, params: Dict - ) -> Callable[[Float[Array, "N D"]], dx.MultivariateNormalFullCovariance]: + ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as @@ -323,12 +330,10 @@ def predict( cross_covariance = kernel.cross_covariance Kzz = gram(kernel, params["kernel"], z) - Kzz += I(m) * jitter - Lz = Kzz.triangular_lower() + Kzz += identity(m) * jitter + Lz = Kzz.to_root() - def predict_fn( - test_inputs: Float[Array, "N D"] - ) -> dx.MultivariateNormalFullCovariance: + def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: # Unpack test inputs t, n_test = test_inputs, test_inputs.shape[0] @@ -338,7 +343,7 @@ def predict_fn( μt = mean_function(params["mean_function"], t) # Lz⁻¹ Kzt - Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) + Lz_inv_Kzt = Lz.solve(Kzt) # Ktz Lz⁻ᵀ sqrt Ktz_Lz_invT_sqrt = jnp.matmul(Lz_inv_Kzt.T, sqrt) @@ -352,10 +357,10 @@ def predict_fn( - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T) ) - covariance += I(n_test) * jitter + covariance += identity(n_test) * jitter - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance.to_dense() + return GaussianDistribution( + loc=jnp.atleast_1d(mean.squeeze()), scale=covariance ) return predict_fn @@ -431,26 +436,26 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: # L = (L⁻¹)⁻¹I sqrt = jsp.linalg.solve_triangular(sqrt_inv, jnp.eye(m), lower=True) + sqrt = jlo.LowerTriangularLinearOperator.from_dense(sqrt) # S = LLᵀ: - S = jnp.matmul(sqrt, sqrt.T) + S = jlo.DenseLinearOperator.from_root(sqrt) # μ = Sθ₁ - mu = jnp.matmul(S, natural_vector) + mu = S @ natural_vector μz = mean_function(params["mean_function"], z) Kzz = gram(kernel, params["kernel"], z) - Kzz += I(m) * jitter - Lz = Kzz.triangular_lower() + Kzz += identity(m) * jitter - qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) - pu = dx.MultivariateNormalTri(jnp.atleast_1d(μz.squeeze()), Lz) + qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S) + pu = GaussianDistribution(loc=jnp.atleast_1d(μz.squeeze()), scale=Kzz) - return kld_dense_dense(qu, pu) + return qu.kl_divergence(pu) def predict( self, params: Dict - ) -> Callable[[Float[Array, "N D"]], dx.MultivariateNormalFullCovariance]: + ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as @@ -463,7 +468,7 @@ def predict( params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[Float[Array, "N D"]], GaussianDistribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ jitter = get_defaults()["jitter"] @@ -500,8 +505,8 @@ def predict( mu = jnp.matmul(S, natural_vector) Kzz = gram(kernel, params["kernel"], z) - Kzz += I(m) * jitter - Lz = Kzz.triangular_lower() + Kzz += identity(m) * jitter + Lz = Kzz.to_root() μz = mean_function(params["mean_function"], z) def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: @@ -514,10 +519,10 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: μt = mean_function(params["mean_function"], t) # Lz⁻¹ Kzt - Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) + Lz_inv_Kzt = Lz.solve(Kzt) # Kzz⁻¹ Kzt - Kzz_inv_Kzt = jsp.linalg.solve_triangular(Lz.T, Lz_inv_Kzt, lower=False) + Kzz_inv_Kzt = Lz.T.solve(Lz_inv_Kzt) # Ktz Kzz⁻¹ L Ktz_Kzz_inv_L = jnp.matmul(Kzz_inv_Kzt.T, sqrt) @@ -531,10 +536,10 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T) ) - covariance += I(n_test) * jitter + covariance += identity(n_test) * jitter - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance.to_dense() + return GaussianDistribution( + loc=jnp.atleast_1d(mean.squeeze()), scale=covariance ) return predict_fn @@ -610,24 +615,21 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: # S = η₂ - η₁ η₁ᵀ S = expectation_matrix - jnp.outer(mu, mu) - S += jnp.eye(m) * jitter - - # S = sqrt sqrtᵀ - sqrt = jnp.linalg.cholesky(S) + S = jlo.DenseLinearOperator(S) + S += identity(m) * jitter μz = mean_function(params["mean_function"], z) Kzz = gram(kernel, params["kernel"], z) - Kzz += I(m) * jitter - Lz = Kzz.triangular_lower() + Kzz += identity(m) * jitter - qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) - pu = dx.MultivariateNormalTri(jnp.atleast_1d(μz.squeeze()), Lz) + qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S) + pu = GaussianDistribution(loc=jnp.atleast_1d(μz.squeeze()), scale=Kzz) - return kld_dense_dense(qu, pu) + return qu.kl_divergence(pu) def predict( self, params: Dict - ) -> Callable[[Float[Array, "N D"]], dx.MultivariateNormalFullCovariance]: + ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as @@ -640,7 +642,7 @@ def predict( params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[Float[Array, "N D"]], GaussianDistribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ jitter = get_defaults()["jitter"] @@ -667,19 +669,18 @@ def predict( # S = η₂ - η₁ η₁ᵀ S = expectation_matrix - jnp.matmul(mu, mu.T) - S += jnp.eye(m) * jitter + S = jlo.DenseLinearOperator(S) + S += identity(m) * jitter # S = sqrt sqrtᵀ - sqrt = jnp.linalg.cholesky(S) + sqrt = S.to_root().to_dense() Kzz = gram(kernel, params["kernel"], z) - Kzz += I(m) * jitter - Lz = Kzz.triangular_lower() + Kzz += identity(m) * jitter + Lz = Kzz.to_root() μz = mean_function(params["mean_function"], z) - def predict_fn( - test_inputs: Float[Array, "N D"] - ) -> dx.MultivariateNormalFullCovariance: + def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: # Unpack test inputs t, n_test = test_inputs, test_inputs.shape[0] @@ -689,10 +690,10 @@ def predict_fn( μt = mean_function(params["mean_function"], t) # Lz⁻¹ Kzt - Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) + Lz_inv_Kzt = Lz.solve(Kzt) # Kzz⁻¹ Kzt - Kzz_inv_Kzt = jsp.linalg.solve_triangular(Lz.T, Lz_inv_Kzt, lower=False) + Kzz_inv_Kzt = Lz.T.solve(Lz_inv_Kzt) # Ktz Kzz⁻¹ sqrt Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt) @@ -706,10 +707,10 @@ def predict_fn( - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) ) - covariance += I(n_test) * jitter + covariance += identity(n_test) * jitter - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance.to_dense() + return GaussianDistribution( + loc=jnp.atleast_1d(mean.squeeze()), scale=covariance ) return predict_fn @@ -749,7 +750,7 @@ def predict( self, params: Dict, train_data: Dataset, - ) -> Callable[[Float[Array, "N D"]], dx.MultivariateNormalFullCovariance]: + ) -> Callable[[Float[Array, "N D"]], GaussianDistribution]: """Compute the predictive distribution of the GP at the test inputs. Args: @@ -761,10 +762,7 @@ def predict( """ jitter = get_defaults()["jitter"] - def predict_fn( - test_inputs: Float[Array, "N D"] - ) -> dx.MultivariateNormalFullCovariance: - # TODO - can we cache some of this? + def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: # Unpack test inputs t, n_test = test_inputs, test_inputs.shape[0] @@ -787,13 +785,13 @@ def predict_fn( Kzx = cross_covariance(kernel, params["kernel"], z, x) Kzz = gram(kernel, params["kernel"], z) - Kzz += I(m) * jitter + Kzz += identity(m) * jitter # Lz Lzᵀ = Kzz - Lz = Kzz.triangular_lower() + Lz = Kzz.to_root() # Lz⁻¹ Kzx - Lz_inv_Kzx = jsp.linalg.solve_triangular(Lz, Kzx, lower=True) + Lz_inv_Kzx = Lz.solve(Kzx) # A = Lz⁻¹ Kzt / σ A = Lz_inv_Kzx / jnp.sqrt(noise) @@ -813,16 +811,14 @@ def predict_fn( ) # Kzz⁻¹ Kzx (y - μx) - Kzz_inv_Kzx_diff = jsp.linalg.solve_triangular( - Lz.T, Lz_inv_Kzx_diff, lower=False - ) + Kzz_inv_Kzx_diff = Lz.T.solve(Lz_inv_Kzx_diff) Ktt = gram(kernel, params["kernel"], t) Kzt = cross_covariance(kernel, params["kernel"], z, t) μt = mean_function(params["mean_function"], t) # Lz⁻¹ Kzt - Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) + Lz_inv_Kzt = Lz.solve(Kzt) # L⁻¹ Lz⁻¹ Kzt L_inv_Lz_inv_Kzt = jsp.linalg.solve_triangular(L, Lz_inv_Kzt, lower=True) @@ -836,90 +832,15 @@ def predict_fn( - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt) ) - covariance += I(n_test) * jitter + covariance += identity(n_test) * jitter - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance.to_dense() + return GaussianDistribution( + loc=jnp.atleast_1d(mean.squeeze()), scale=covariance ) return predict_fn -# TODO: Abstract these out to a KL divergence that accepts a linear operator to facilate structured covarainces other than dense. -def kld_dense_dense( - q: dx.MultivariateNormalTri, p: dx.MultivariateNormalTri -) -> Float[Array, "1"]: - """Kullback-Leibler divergence KL[q(x)||p(x)] between two dense covariance Gaussian distributions - q(x) = N(x; μq, Σq) and p(x) = N(x; μp, Σp). - - Args: - q (dx.MultivariateNormalTri): A multivariate Gaussian distribution. - p (dx.MultivariateNormalTri): A multivariate Gaussian distribution. - - Returns: - Float[Array, "1"]: The KL divergence between the two distributions. - """ - - q_mu = q.loc - q_sqrt = q.scale_tri - n = q_mu.shape[-1] - - p_mu = p.loc - p_sqrt = p.scale_tri - - diag = jnp.diag(q_sqrt) - - # Trace term tr(Σp⁻¹ Σq) - trace = jnp.sum(jnp.square(jsp.linalg.solve_triangular(p_sqrt, q_sqrt, lower=True))) - - # Mahalanobis term: μqᵀ Σp⁻¹ μq - alpha = jsp.linalg.solve_triangular(p_sqrt, p_mu - q_mu, lower=True) - mahalanobis = jnp.sum(jnp.square(alpha)) - - # log|Σq| - logdet_qcov = jnp.sum(jnp.log(jnp.square(diag))) - two_kl = mahalanobis - n - logdet_qcov + trace - - # log|Σp| - log_det_pcov = jnp.sum(jnp.log(jnp.square(jnp.diag(p_sqrt)))) - two_kl += log_det_pcov - - return two_kl / 2.0 - - -def kld_dense_white(q: dx.MultivariateNormalTri) -> Float[Array, "1"]: - """Kullback-Leibler divergence KL[q(x)||p(x)] between a dense covariance Gaussian distribution - q(x) = N(x; μq, Σq), and white indenity Gaussian p(x) = N(x; 0, I). - - This is useful for variational inference with a whitened variational family. - - Args: - q (dx.MultivariateNormalTri): A multivariate Gaussian distribution. - - Returns: - Float[Array, "1"]: The KL divergence between the two distributions. - """ - - q_mu = q.loc - q_sqrt = q.scale_tri - n = q_mu.shape[-1] - - diag = jnp.diag(q_sqrt) - - # Trace term tr(Σp⁻¹ Σq), and alpha for Mahalanobis term: - alpha = q_mu - trace = jnp.sum(jnp.square(q_sqrt)) - - # Mahalanobis term: μqᵀ Σp⁻¹ μq - mahalanobis = jnp.sum(jnp.square(alpha)) - - # log|Σq| (no log|Σp| as this is just zero!) - logdet_qcov = jnp.sum(jnp.log(jnp.square(diag))) - two_kl = mahalanobis - n - logdet_qcov + trace - - return two_kl / 2.0 - - __all__ = [ "AbstractVariationalFamily", "AbstractVariationalGaussian", diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index b16a82b75..6817e9784 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -22,8 +22,9 @@ from jax import vmap from jaxtyping import Array, Float +from jaxlinop import identity + from .config import get_defaults -from .covariance_operator import I from .gps import AbstractPosterior from .likelihoods import Gaussian from .quadrature import gauss_hermite_quadrature @@ -124,12 +125,14 @@ def variational_expectation( x, y = batch.X, batch.y # Variational distribution q(f(·)) = N(f(·); μ(·), Σ(·, ·)) - q = self.variational_family + q = self.variational_family(params) # Compute variational mean, μ(x), and variance, √diag(Σ(x, x)), at training inputs, x - qx = vmap(q(params))(x[:, None]) - mean = qx.mean().val.reshape(-1, 1) - variance = qx.variance().val.reshape(-1, 1) + def q_moments(x): + qx = q(x) + return qx.mean(), qx.variance() + + mean, variance = vmap(q_moments)(x[:, None]) # log(p(y|f(x))) link_function = self.likelihood.link_function @@ -175,7 +178,7 @@ def elbo( # Unpack mean function and kernel mean_function = self.prior.mean_function - kernel = self.prior.kernel + kernel = self.prior.kernel # Unpack kernel computation gram, cross_covariance = kernel.gram, kernel.cross_covariance @@ -190,12 +193,12 @@ def elbo_fn(params: Dict) -> Float[Array, "1"]: noise = params["likelihood"]["obs_noise"] z = params["variational_family"]["inducing_inputs"] Kzz = gram(kernel, params["kernel"], z) - Kzz += I(m) * jitter + Kzz += identity(m) * jitter Kzx = cross_covariance(kernel, params["kernel"], z, x) Kxx_diag = vmap(kernel, in_axes=(None, 0, 0))(params["kernel"], x, x) μx = mean_function(params["mean_function"], x) - Lz = Kzz.triangular_lower() + Lz = Kzz.to_root() # Notation and derivation: # @@ -221,7 +224,7 @@ def elbo_fn(params: Dict) -> Float[Array, "1"]: # # with A and B defined as above. - A = jsp.linalg.solve_triangular(Lz, Kzx, lower=True) / jnp.sqrt(noise) + A = Lz.solve(Kzx) / jnp.sqrt(noise) # AAᵀ AAT = jnp.matmul(A, A.T) diff --git a/setup.py b/setup.py index c28fedebf..bf1e3233f 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ def find_version(*file_paths): "tqdm>=4.0.0", "ml-collections==0.1.0", "jaxtyping>=0.0.2", + "jaxlinop>=0.0.2", ] EXTRAS = { diff --git a/tests/test_covariance_operator.py b/tests/test_covariance_operator.py deleted file mode 100644 index 15b3aa59d..000000000 --- a/tests/test_covariance_operator.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - - -import jax.numpy as jnp -import jax.random as jr -import pytest -from jax.config import config - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) -from gpjax.covariance_operator import ( - CovarianceOperator, - DenseCovarianceOperator, - DiagonalCovarianceOperator, - I, -) - -_key = jr.PRNGKey(seed=42) - - -def test_covariance_operator() -> None: - with pytest.raises(TypeError): - CovarianceOperator() - - -@pytest.mark.parametrize("n", [1, 10, 100]) -def test_adding_jax_arrays(n: int) -> None: - import jax.random as jr - - # Create PSD jax arrays matricies A and B: - key_a, key_b = jr.split(_key) - - sqrt_A = jr.uniform(key_a, (n, n)) - sqrt_B = jr.uniform(key_b, (n, n)) - - A = sqrt_A @ sqrt_A.T - B = sqrt_B @ sqrt_B.T - - # Create corresponding covariance operators: - Dense_A = DenseCovarianceOperator(matrix=A) - Dense_B = DenseCovarianceOperator(matrix=B) - - # Test addition: - assert jnp.all((Dense_A + B).to_dense() == A + B) - assert jnp.all((B + Dense_A).to_dense() == B + A) - assert jnp.all((Dense_A + Dense_B).to_dense() == A + B) - - # Test subtraction: - assert jnp.all((Dense_A - Dense_B).to_dense() == A - B) - assert jnp.all((Dense_A - B).to_dense() == A - B) - assert jnp.all((B - Dense_A).to_dense() == B - A) - - -@pytest.mark.parametrize("n", [1, 10, 100]) -def test_dense_covariance_operator(n: int) -> None: - - sqrt = jr.normal(_key, (n, n)) - dense = sqrt.T @ sqrt # Dense random matrix is positive definite. - cov = DenseCovarianceOperator(matrix=dense) - - # Test shape: - assert cov.shape == (n, n) - - # Test solve: - b = jr.normal(_key, (n, 1)) - x = cov.solve(b) - assert jnp.allclose(b, dense @ x) - - # Test to_dense method: - assert jnp.allclose(dense, cov.to_dense()) - - # Test to_diag method: - assert jnp.allclose(jnp.diag(dense), cov.diagonal()) - - # Test log determinant: - assert jnp.allclose(jnp.linalg.slogdet(dense)[1], cov.log_det()) - - # Test trace: - assert jnp.allclose(jnp.trace(dense), cov.trace()) - - # Test lower triangular: - assert jnp.allclose(jnp.linalg.cholesky(dense), cov.triangular_lower()) - - # Test adding diagonal covariance operator to dense linear operator: - diag = DiagonalCovarianceOperator(diag=jnp.diag(dense)) - cov = cov + (diag * jnp.pi) - assert jnp.allclose(dense + jnp.pi * jnp.diag(jnp.diag(dense)), cov.to_dense()) - - -@pytest.mark.parametrize("constant", [1.0, 3.5]) -@pytest.mark.parametrize("n", [1, 10, 100]) -def test_diagonal_covariance_operator(n: int, constant: float) -> None: - diag = 1.0 + jnp.arange(n, dtype=jnp.float64) - diag_cov = DiagonalCovarianceOperator(diag=diag) - - # Test shape: - assert diag_cov.shape == (n, n) - - # Test trace: - assert jnp.allclose(jnp.sum(diag), diag_cov.trace()) - - # Test diagonal: - assert jnp.allclose(diag, diag_cov.diagonal()) - - # Test multiplying with scalar: - assert ((diag_cov * constant).diagonal() == constant * diag).all() - - # Test solve: - assert (jnp.diagonal(diag_cov.solve(rhs=jnp.eye(n))) == 1.0 / diag).all() - - # Test to_dense method: - dense = diag_cov.to_dense() - assert (dense - jnp.diag(diag) == 0.0).all() - assert dense.shape == (n, n) - - # Test log determinant: - assert diag_cov.log_det() == 2.0 * jnp.sum(jnp.log(diag)) - - # Test lower triangular: - L = diag_cov.triangular_lower() - assert L.shape == (n, n) - assert (L == jnp.diag(jnp.sqrt(diag))).all() - - # Test adding two diagonal covariance operators: - diag_other = 5.1 + 2 * jnp.arange(n, dtype=jnp.float64) - other = DiagonalCovarianceOperator(diag=diag_other) - assert ((diag_cov + other).diagonal() == diag + diag_other).all() - - -@pytest.mark.parametrize("n", [1, 10, 100]) -def test_identity_covariance_operator(n: int) -> None: - - # Create identity matrix of size nxn: - Identity = I(n) - - # Check iniation: - assert Identity.diag.shape == (n,) - assert (Identity.diag == 1.0).all() - assert isinstance(Identity.diag, jnp.ndarray) - assert isinstance(Identity, DiagonalCovarianceOperator) - - # Check iid covariance construction: - noise = jnp.array([jnp.pi]) - cov = Identity * noise - assert cov.diag.shape == (n,) - assert (cov.diag == jnp.pi).all() - assert isinstance(cov.diag, jnp.ndarray) - assert isinstance(cov, DiagonalCovarianceOperator) - - # Check addition to diagonal covariance: - diag = jnp.arange(n) - diag_gram_matrix = DiagonalCovarianceOperator(diag=diag) - cov = diag_gram_matrix + Identity - assert cov.diag.shape == (n,) - assert (cov.diag == (1.0 + jnp.arange(n))).all() - assert isinstance(cov.diag, jnp.ndarray) - assert isinstance(cov, DiagonalCovarianceOperator) - - # Check addition to dense covariance: - dense = jnp.arange(n**2, dtype=jnp.float64).reshape(n, n) - dense_matrix = DenseCovarianceOperator(matrix=dense) - cov = dense_matrix + (Identity * noise) - assert cov.matrix.shape == (n, n) - assert (jnp.diag(cov.matrix) == jnp.diag((noise + dense))).all() - assert isinstance(cov.matrix, jnp.ndarray) - assert isinstance(cov, DenseCovarianceOperator) diff --git a/tests/test_gaussian_distribution.py b/tests/test_gaussian_distribution.py new file mode 100644 index 000000000..38824f0a9 --- /dev/null +++ b/tests/test_gaussian_distribution.py @@ -0,0 +1,129 @@ +# %% [markdown] +# Copyright 2022 The Jax Linear Operator Contributors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +import jax.numpy as jnp +import jax.random as jr +import pytest +from jax.config import config + +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) + +from jaxlinop.dense_linear_operator import DenseLinearOperator +from jaxlinop.diagonal_linear_operator import DiagonalLinearOperator + +from gpjax.gaussian_distribution import GaussianDistribution + +_key = jr.PRNGKey(seed=42) + +from distrax import MultivariateNormalDiag, MultivariateNormalFullCovariance + +def approx_equal(res: jnp.ndarray, actual: jnp.ndarray) -> bool: + """Check if two arrays are approximately equal.""" + return jnp.linalg.norm(res - actual) < 1e-6 + +@pytest.mark.parametrize("n", [1, 2, 5, 100]) +def test_array_arguments(n: int) -> None: + key_mean, key_sqrt = jr.split(_key, 2) + mean = jr.uniform(key_mean, shape=(n,)) + sqrt = jr.uniform(key_sqrt, shape=(n, n)) + covariance = sqrt @ sqrt.T + + dist = GaussianDistribution(loc=mean, scale=DenseLinearOperator(covariance)) + + assert approx_equal(dist.mean(), mean) + assert approx_equal(dist.variance(), covariance.diagonal()) + assert approx_equal(dist.stddev(), jnp.sqrt(covariance.diagonal())) + assert approx_equal(dist.covariance(), covariance) + + y = jr.uniform(_key, shape=(n,)) + + distrax_dist = MultivariateNormalFullCovariance(loc=mean, covariance_matrix=covariance) + + assert approx_equal(dist.log_prob(y), distrax_dist.log_prob(y)) + assert approx_equal(dist.kl_divergence(dist), 0.0) + + +@pytest.mark.parametrize("n", [1, 2, 5, 100]) +def test_diag_linear_operator(n: int) -> None: + key_mean, key_diag = jr.split(_key, 2) + mean = jr.uniform(key_mean, shape=(n,)) + diag = jr.uniform(key_diag, shape=(n,)) + + dist_diag = GaussianDistribution(loc=mean, scale=DiagonalLinearOperator(diag ** 2)) + distrax_dist = MultivariateNormalDiag(loc=mean, scale_diag=diag) + + assert approx_equal(dist_diag.mean(), distrax_dist.mean()) + assert approx_equal(dist_diag.variance(), distrax_dist.variance()) + assert approx_equal(dist_diag.stddev(), distrax_dist.stddev()) + assert approx_equal(dist_diag.covariance(), distrax_dist.covariance()) + + assert approx_equal(dist_diag.sample(seed=_key, sample_shape=(10,)), distrax_dist.sample(seed=_key, sample_shape=(10,))) + + y = jr.uniform(_key, shape=(n,)) + + assert approx_equal(dist_diag.log_prob(y), distrax_dist.log_prob(y)) + assert approx_equal(dist_diag.log_prob(y), distrax_dist.log_prob(y)) + + assert approx_equal(dist_diag.kl_divergence(dist_diag), 0.0) + + + +@pytest.mark.parametrize("n", [1, 2, 5, 100]) +def test_dense_linear_operator(n: int) -> None: + key_mean, key_sqrt = jr.split(_key, 2) + mean = jr.uniform(key_mean, shape=(n,)) + sqrt = jr.uniform(key_sqrt, shape=(n, n)) + covariance = sqrt @ sqrt.T + + sqrt = jnp.linalg.cholesky(covariance + jnp.eye(n) * 1e-10) + + dist_dense = GaussianDistribution(loc=mean, scale=DenseLinearOperator(covariance)) + distrax_dist = MultivariateNormalFullCovariance(loc=mean, covariance_matrix=covariance) + + assert approx_equal(dist_dense.mean(), distrax_dist.mean()) + assert approx_equal(dist_dense.variance(), distrax_dist.variance()) + assert approx_equal(dist_dense.stddev(), distrax_dist.stddev()) + assert approx_equal(dist_dense.covariance(), distrax_dist.covariance()) + + assert approx_equal(dist_dense.sample(seed=_key, sample_shape=(10,)), distrax_dist.sample(seed=_key, sample_shape=(10,))) + + y = jr.uniform(_key, shape=(n,)) + + assert approx_equal(dist_dense.log_prob(y), distrax_dist.log_prob(y)) + assert approx_equal(dist_dense.kl_divergence(dist_dense), 0.0) + + +@pytest.mark.parametrize("n", [1, 2, 5, 100]) +def test_kl_divergence(n: int) -> None: + key_a, key_b = jr.split(_key, 2) + mean_a = jr.uniform(key_a, shape=(n,)) + mean_b = jr.uniform(key_b, shape=(n,)) + sqrt_a = jr.uniform(key_a, shape=(n, n)) + sqrt_b = jr.uniform(key_b, shape=(n, n)) + covariance_a = sqrt_a @ sqrt_a.T + covariance_b = sqrt_b @ sqrt_b.T + + + dist_a = GaussianDistribution(loc=mean_a, scale=DenseLinearOperator(covariance_a)) + dist_b = GaussianDistribution(loc=mean_b, scale=DenseLinearOperator(covariance_b)) + + distrax_dist_a = MultivariateNormalFullCovariance(loc=mean_a, covariance_matrix=covariance_a) + distrax_dist_b = MultivariateNormalFullCovariance(loc=mean_b, covariance_matrix=covariance_b) + + assert approx_equal(dist_a.kl_divergence(dist_b), distrax_dist_a.kl_divergence(distrax_dist_b)) + diff --git a/tests/test_kernels.py b/tests/test_kernels.py index 1921feb1f..b7e931eaa 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -25,9 +25,9 @@ from jax.config import config from jaxtyping import Array, Float -from gpjax.covariance_operator import ( - CovarianceOperator, - I, +from jaxlinop import ( + LinearOperator, + identity, ) from gpjax.kernels import ( @@ -110,7 +110,7 @@ def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: # Test gram matrix: Kxx = gram(kernel, params, x) - assert isinstance(Kxx, CovarianceOperator) + assert isinstance(Kxx, LinearOperator) assert Kxx.shape == (n, n) @@ -173,7 +173,7 @@ def test_pos_def( # Test gram matrix eigenvalues are positive: Kxx = gram(kern, params, x) - Kxx += I(n) * _jitter + Kxx += identity(n) * _jitter eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) assert (eigen_values > 0.0).all() @@ -226,7 +226,7 @@ def test_polynomial( # Unpack kernel computation gram = kern.gram - + # Check name assert kern.name == f"Polynomial Degree: {degree}" @@ -246,7 +246,7 @@ def test_polynomial( assert Kxx.shape[0] == Kxx.shape[1] # Test positive definiteness - Kxx += I(n) * _jitter + Kxx += identity(n) * _jitter eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) assert (eigen_values > 0).all() @@ -257,7 +257,7 @@ def test_active_dim(kernel: AbstractKernel) -> None: perm_length = 2 dim_pairs = list(permutations(dim_list, r=perm_length)) n_dims = len(dim_list) - + # Generate random inputs x = jr.normal(_initialise_key, shape=(20, n_dims)) @@ -318,7 +318,7 @@ def test_combination_kernel( assert len(combination_kernel.kernel_set) == n_kerns assert isinstance(combination_kernel.kernel_set, list) assert isinstance(combination_kernel.kernel_set[0], AbstractKernel) - + # Compute gram matrix Kxx = gram(combination_kernel, params, x) @@ -327,7 +327,7 @@ def test_combination_kernel( assert Kxx.shape[1] == n # Check positive definiteness - Kxx += I(n) * _jitter + Kxx += identity(n) * _jitter eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) assert (eigen_values > 0).all() @@ -380,11 +380,11 @@ def test_sum_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: "k2", [RBF(), Matern12(), Matern32(), Matern52(), Polynomial()] ) def test_prod_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: - - # Create inputs + + # Create inputs n = 10 x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) - + # Create product kernel prod_kernel = ProductKernel(kernel_set=[k1, k2]) @@ -422,10 +422,10 @@ def test_graph_kernel(): n_edges = 40 G = nx.gnm_random_graph(n_verticies, n_edges, seed=123) x = jnp.arange(n_verticies).reshape(-1, 1) - + # Compute graph laplacian L = nx.laplacian_matrix(G).toarray() + jnp.eye(n_verticies) * 1e-12 - + # Create graph kernel kern = GraphKernel(laplacian=L) assert isinstance(kern, GraphKernel) @@ -445,16 +445,15 @@ def test_graph_kernel(): "smoothness", "variance", ] - + # Compute gram matrix Kxx = gram(kern, params, x) assert Kxx.shape == (n_verticies, n_verticies) # Check positive definiteness - Kxx += I(n_verticies) * _jitter + Kxx += identity(n_verticies) * _jitter eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) assert all(eigen_values > 0) - @pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52, Polynomial])