diff --git a/gpjax/lower_cholesky.py b/gpjax/lower_cholesky.py index 69286ead0..274dbee26 100644 --- a/gpjax/lower_cholesky.py +++ b/gpjax/lower_cholesky.py @@ -30,6 +30,11 @@ def lower_cholesky(A: cola.ops.LinearOperator): # noqa: F811 cola.ops.LinearOperator: The lower Cholesky factor of A. """ + if cola.PSD not in A.annotations: + raise ValueError( + "Expected LinearOperator to be PSD, did you forget to use cola.PSD?" + ) + return cola.ops.Triangular(jnp.linalg.cholesky(A.to_dense()), lower=True) @@ -41,3 +46,15 @@ def _(A: cola.ops.Diagonal): # noqa: F811 @lower_cholesky.dispatch def _(A: cola.ops.Identity): # noqa: F811 return A + + +@lower_cholesky.dispatch +def _(A: cola.ops.Kronecker): # noqa: F811 + return cola.ops.Kronecker(*[lower_cholesky(Ai) for Ai in A.Ms]) + + +@lower_cholesky.dispatch +def _(A: cola.ops.BlockDiag): # noqa: F811 + return cola.ops.BlockDiag( + *[lower_cholesky(Ai) for Ai in A.Ms], multiplicities=A.multiplicities + ) diff --git a/tests/test_lower_cholesky.py b/tests/test_lower_cholesky.py index e00c1ea23..d8487c522 100644 --- a/tests/test_lower_cholesky.py +++ b/tests/test_lower_cholesky.py @@ -1,18 +1,30 @@ +import cola from cola.ops import ( + BlockDiag, Dense, Diagonal, I_like, Identity, + Kronecker, Triangular, ) import jax.numpy as jnp +import jax.scipy as jsp +import pytest from gpjax.lower_cholesky import lower_cholesky def test_dense() -> None: array = jnp.array([[3.0, 1.0], [1.0, 3.0]]) - A = Dense(array) + + # Test that we get an error if we don't use cola.PSD! + with pytest.raises(ValueError): + A = Dense(array) + lower_cholesky(A) + + # Now we annoate with cola.PSD and test for the correct output. + A = cola.PSD(Dense(array)) L = lower_cholesky(A) assert isinstance(L, Triangular) @@ -21,7 +33,7 @@ def test_dense() -> None: def test_diagonal() -> None: array = jnp.array([1.0, 2.0]) - A = Diagonal(array) + A = cola.PSD(Diagonal(array)) L = lower_cholesky(A) assert isinstance(L, Diagonal) @@ -33,3 +45,67 @@ def test_identity() -> None: L = lower_cholesky(A) assert isinstance(L, Identity) assert jnp.allclose(L.to_dense(), jnp.eye(2)) + + +def test_kronecker() -> None: + array_a = jnp.array([[3.0, 1.0], [1.0, 3.0]]) + array_b = jnp.array([[2.0, 0.0], [0.0, 2.0]]) + + # Create LinearOperators. + A = Dense(array_a) + B = Dense(array_b) + + # Annotate with cola.PSD. + A = cola.PSD(A) + B = cola.PSD(B) + + K = Kronecker(A, B) + + # Cholesky decomposition. + L = lower_cholesky(K) + + # Check types. + assert isinstance(L, Kronecker) + assert isinstance(L.Ms[0], Triangular) + assert isinstance(L.Ms[1], Triangular) + + # Check values. + assert jnp.allclose(L.Ms[0].to_dense(), jnp.linalg.cholesky(array_a)) + assert jnp.allclose(L.Ms[1].to_dense(), jnp.linalg.cholesky(array_b)) + assert jnp.allclose( + L.to_dense(), + jnp.kron(jnp.linalg.cholesky(array_a), jnp.linalg.cholesky(array_b)), + ) + + +def test_block_diag() -> None: + array_a = jnp.array([[3.0, 1.0], [1.0, 3.0]]) + array_b = jnp.array([[2.0, 0.0], [0.0, 2.0]]) + + # Create LinearOperators. + A = Dense(array_a) + B = Dense(array_b) + + # Annotate with cola.PSD. + A = cola.PSD(A) + B = cola.PSD(B) + + B = BlockDiag(A, B, multiplicities=[2, 3]) + + # Cholesky decomposition. + L = lower_cholesky(B) + + # Check types. + assert isinstance(L, BlockDiag) + assert isinstance(L.Ms[0], Triangular) + assert isinstance(L.Ms[1], Triangular) + + # Check values. + assert jnp.allclose(L.Ms[0].to_dense(), jnp.linalg.cholesky(array_a)) + assert jnp.allclose(L.Ms[1].to_dense(), jnp.linalg.cholesky(array_b)) + + # Check multiplicities. + assert L.multiplicities == [2, 3] + + # Check dense. + assert jnp.allclose(jnp.linalg.cholesky(B.to_dense()), L.to_dense())