Skip to content

Commit

Permalink
Add Kronecker and BlockDiag to lower_cholesky.
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-dodd committed Aug 30, 2023
1 parent 177a5d7 commit 2c9ebce
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 2 deletions.
17 changes: 17 additions & 0 deletions gpjax/lower_cholesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ def lower_cholesky(A: cola.ops.LinearOperator): # noqa: F811
cola.ops.LinearOperator: The lower Cholesky factor of A.
"""

if cola.PSD not in A.annotations:
raise ValueError(
"Expected LinearOperator to be PSD, did you forget to use cola.PSD?"
)

return cola.ops.Triangular(jnp.linalg.cholesky(A.to_dense()), lower=True)


Expand All @@ -41,3 +46,15 @@ def _(A: cola.ops.Diagonal): # noqa: F811
@lower_cholesky.dispatch
def _(A: cola.ops.Identity): # noqa: F811
return A


@lower_cholesky.dispatch
def _(A: cola.ops.Kronecker): # noqa: F811
return cola.ops.Kronecker(*[lower_cholesky(Ai) for Ai in A.Ms])


@lower_cholesky.dispatch
def _(A: cola.ops.BlockDiag): # noqa: F811
return cola.ops.BlockDiag(
*[lower_cholesky(Ai) for Ai in A.Ms], multiplicities=A.multiplicities
)
80 changes: 78 additions & 2 deletions tests/test_lower_cholesky.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
import cola
from cola.ops import (
BlockDiag,
Dense,
Diagonal,
I_like,
Identity,
Kronecker,
Triangular,
)
import jax.numpy as jnp
import jax.scipy as jsp
import pytest

from gpjax.lower_cholesky import lower_cholesky


def test_dense() -> None:
array = jnp.array([[3.0, 1.0], [1.0, 3.0]])
A = Dense(array)

# Test that we get an error if we don't use cola.PSD!
with pytest.raises(ValueError):
A = Dense(array)
lower_cholesky(A)

# Now we annoate with cola.PSD and test for the correct output.
A = cola.PSD(Dense(array))

L = lower_cholesky(A)
assert isinstance(L, Triangular)
Expand All @@ -21,7 +33,7 @@ def test_dense() -> None:

def test_diagonal() -> None:
array = jnp.array([1.0, 2.0])
A = Diagonal(array)
A = cola.PSD(Diagonal(array))

L = lower_cholesky(A)
assert isinstance(L, Diagonal)
Expand All @@ -33,3 +45,67 @@ def test_identity() -> None:
L = lower_cholesky(A)
assert isinstance(L, Identity)
assert jnp.allclose(L.to_dense(), jnp.eye(2))


def test_kronecker() -> None:
array_a = jnp.array([[3.0, 1.0], [1.0, 3.0]])
array_b = jnp.array([[2.0, 0.0], [0.0, 2.0]])

# Create LinearOperators.
A = Dense(array_a)
B = Dense(array_b)

# Annotate with cola.PSD.
A = cola.PSD(A)
B = cola.PSD(B)

K = Kronecker(A, B)

# Cholesky decomposition.
L = lower_cholesky(K)

# Check types.
assert isinstance(L, Kronecker)
assert isinstance(L.Ms[0], Triangular)
assert isinstance(L.Ms[1], Triangular)

# Check values.
assert jnp.allclose(L.Ms[0].to_dense(), jnp.linalg.cholesky(array_a))
assert jnp.allclose(L.Ms[1].to_dense(), jnp.linalg.cholesky(array_b))
assert jnp.allclose(
L.to_dense(),
jnp.kron(jnp.linalg.cholesky(array_a), jnp.linalg.cholesky(array_b)),
)


def test_block_diag() -> None:
array_a = jnp.array([[3.0, 1.0], [1.0, 3.0]])
array_b = jnp.array([[2.0, 0.0], [0.0, 2.0]])

# Create LinearOperators.
A = Dense(array_a)
B = Dense(array_b)

# Annotate with cola.PSD.
A = cola.PSD(A)
B = cola.PSD(B)

B = BlockDiag(A, B, multiplicities=[2, 3])

# Cholesky decomposition.
L = lower_cholesky(B)

# Check types.
assert isinstance(L, BlockDiag)
assert isinstance(L.Ms[0], Triangular)
assert isinstance(L.Ms[1], Triangular)

# Check values.
assert jnp.allclose(L.Ms[0].to_dense(), jnp.linalg.cholesky(array_a))
assert jnp.allclose(L.Ms[1].to_dense(), jnp.linalg.cholesky(array_b))

# Check multiplicities.
assert L.multiplicities == [2, 3]

# Check dense.
assert jnp.allclose(jnp.linalg.cholesky(B.to_dense()), L.to_dense())

0 comments on commit 2c9ebce

Please sign in to comment.