Add support for diagonal Kronecker factors in Kron matrix class #136

Merged · 4 commits · Aug 16, 2023 · Changes from 3 commits
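The diff adopts one convention throughout: a Kronecker factor stored as a 2-D tensor is dense, while a 1-D tensor holds only the diagonal of a factor, and `ndim` is the discriminator. A minimal standalone sketch of the convention and why it saves memory (illustration only, not code from the PR):

```python
import torch

# Convention used in this PR: 2-D tensor = dense Kronecker factor,
# 1-D tensor = the diagonal of a factor.
h_diag = torch.tensor([1.0, 2.0, 3.0])   # diagonal form, P entries
H_dense = torch.diag(h_diag)             # dense form, P^2 entries

assert h_diag.ndim == 1 and H_dense.ndim > 1  # how matrix.py tells them apart

# Products with the diagonal form reduce to broadcasting; no P x P matrix needed.
x = torch.randn(3)
assert torch.allclose(H_dense @ x, h_diag * x)
```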
61 changes: 33 additions & 28 deletions laplace/utils/matrix.py
@@ -46,18 +46,9 @@ def init_from_model(cls, model, device):
kfacs = list()
for p in params:
if p.ndim == 1: # bias
P = p.size(0)
kfacs.append([torch.zeros(P, P, device=device)])
elif 4 >= p.ndim >= 2: # fully connected or conv
if p.ndim == 2: # fully connected
P_in, P_out = p.size()
elif p.ndim > 2:
P_in, P_out = p.shape[0], np.prod(p.shape[1:])

kfacs.append([
torch.zeros(P_in, P_in, device=device),
torch.zeros(P_out, P_out, device=device)
])
kfacs.append([0.])
elif 4 >= p.ndim >= 2:  # fully connected or embedding or conv
kfacs.append([0., 0.])
else:
raise ValueError('Invalid parameter shape in network.')
return cls(kfacs)
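`init_from_model` no longer pre-allocates zero matrices; since a factor may end up dense or diagonal, it records scalar placeholders instead. A quick illustration of the resulting initial state (the two-layer model is assumed for illustration and mirrors the updated tests below):

```python
import torch
from torch import nn
from laplace.utils import Kron

model = nn.Sequential(nn.Linear(3, 20), nn.Tanh(), nn.Linear(20, 2))
kron = Kron.init_from_model(model, 'cpu')
# One placeholder per Kronecker factor: weights get two factors, biases one.
print(kron.kfacs)  # [[0., 0.], [0.], [0., 0.], [0.]]
```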
@@ -76,7 +67,7 @@ def __add__(self, other):
if not isinstance(other, Kron):
raise ValueError('Can only add Kron to Kron.')

kfacs = [[Hi.add(Hj) for Hi, Hj in zip(Fi, Fj)]
kfacs = [[Hi + Hj for Hi, Hj in zip(Fi, Fj)]
for Fi, Fj in zip(self.kfacs, other.kfacs)]
return Kron(kfacs)
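Switching `Hi.add(Hj)` to `Hi + Hj` matters because a freshly initialized `Kron` now holds float placeholders, and plain floats have no `.add` method; `+` handles floats and tensors uniformly. A trivial check of the distinction:

```python
import torch

assert 0. + 5. == 5.  # float placeholder: works with +, but (0.).add(5.) would fail
assert torch.allclose(torch.zeros(2) + torch.ones(2), torch.ones(2))
```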

@@ -118,7 +109,14 @@ def decompose(self, damping=False):
for F in self.kfacs:
Qs, ls = list(), list()
for Hi in F:
l, Q = symeig(Hi)
if Hi.ndim > 1:
# Dense Kronecker factor.
l, Q = symeig(Hi)
else:
# Diagonal Kronecker factor.
l = Hi
# This might be too memory intensive since len(Hi) can be large.

Owner: Remove this comment — is it left over from the earlier sorting?

Collaborator (PR author): No, it is due to the torch.eye: since the Kronecker factor is diagonal, it might be too large to explicitly build a len(H_i) x len(H_i) matrix.

Q = torch.eye(len(Hi), dtype=Hi.dtype, device=Hi.device)
Qs.append(Q)
ls.append(l)
eigvecs.append(Qs)
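For a diagonal factor the eigendecomposition is trivial: the eigenvalues are the diagonal entries and the eigenvectors form the identity, so only the `torch.eye` (flagged in the conversation above) carries a real cost. A standalone check of the equivalence:

```python
import torch

h = torch.tensor([3.0, 1.0, 2.0])   # diagonal factor, stored as 1-D
H = torch.diag(h)                   # dense equivalent

l, Q = torch.linalg.eigh(H)         # the dense path (what symeig computes)
# Same spectrum up to ordering; the diagonal path skips the O(P^3) solve,
# but Q = torch.eye(P) still costs O(P^2) memory, hence the code comment.
assert torch.allclose(torch.sort(l).values, torch.sort(h).values)
assert torch.allclose(Q @ torch.diag(l) @ Q.T, H)
```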
@@ -149,14 +147,16 @@ def _bmm(self, W: torch.Tensor) -> torch.Tensor:
Q = Fs[0]
p = len(Q)
W_p = W[:, cur_p:cur_p+p].T
SW.append((Q @ W_p).T)
SW.append((Q @ W_p).T if Q.ndim > 1 else (Q.view(-1, 1) * W_p).T)
cur_p += p
elif len(Fs) == 2:
Q, H = Fs
p_in, p_out = len(Q), len(H)
p = p_in * p_out
W_p = W[:, cur_p:cur_p+p].reshape(B * K, p_in, p_out)
SW.append((Q @ W_p @ H.T).reshape(B * K, p_in * p_out))
QW_p = Q @ W_p if Q.ndim > 1 else Q.view(-1, 1) * W_p
QW_pHt = QW_p @ H.T if H.ndim > 1 else QW_p * H.view(1, -1)
SW.append(QW_pHt.reshape(B * K, p_in * p_out))
cur_p += p
else:
raise AttributeError('Shape mismatch')
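The `ndim` branches replace dense matmuls with broadcasting: a diagonal matrix applied from the left scales rows, and from the right scales columns. A small sanity check of the two identities `_bmm` now relies on:

```python
import torch

q = torch.rand(4)    # diagonal left factor, stored 1-D
h = torch.rand(5)    # diagonal right factor, stored 1-D
W = torch.randn(4, 5)

# diag(q) @ W is a row-wise scaling.
assert torch.allclose(torch.diag(q) @ W, q.view(-1, 1) * W)
# W @ diag(h).T == W @ diag(h) is a column-wise scaling.
assert torch.allclose(W @ torch.diag(h).T, W * h.view(1, -1))
```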
@@ -204,11 +204,12 @@ def logdet(self) -> torch.Tensor:
logdet = 0
for F in self.kfacs:
if len(F) == 1:
logdet += F[0].logdet()
logdet += F[0].logdet() if F[0].ndim > 1 else F[0].prod().log()
else: # len(F) == 2
Hi, Hj = F
p_in, p_out = len(Hi), len(Hj)
logdet += p_out * Hi.logdet() + p_in * Hj.logdet()
logdet += p_out * Hi.logdet() if Hi.ndim > 1 else p_out * Hi.prod().log()
logdet += p_in * Hj.logdet() if Hj.ndim > 1 else p_in * Hj.prod().log()
return logdet
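The factored log-determinant uses det(A ⊗ B) = det(A)^q · det(B)^p for A of size p × p and B of size q × q, plus the fact that a diagonal matrix's determinant is the product of its entries. A standalone check (note `Hi.prod().log()` can under- or overflow for large factors; `Hi.log().sum()` would be the numerically safer equivalent):

```python
import torch

p, q = 3, 4
A = torch.diag(torch.rand(p) + 0.5)   # dense p x p factor
b = torch.rand(q) + 0.5               # diagonal q x q factor, stored 1-D

logdet_full = torch.kron(A, torch.diag(b)).logdet()
logdet_fact = q * A.logdet() + p * b.prod().log()
assert torch.allclose(logdet_full, logdet_fact)
```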

def diag(self) -> torch.Tensor:
@@ -220,10 +221,12 @@ def diag(self) -> torch.Tensor:
"""
diags = list()
for F in self.kfacs:
F0 = F[0].diag() if F[0].ndim > 1 else F[0]
if len(F) == 1:
diags.append(F[0].diagonal())
diags.append(F0)
else:
diags.append(torch.ger(F[0].diagonal(), F[1].diagonal()).flatten())
F1 = F[1].diag() if F[1].ndim > 1 else F[1]
diags.append(torch.outer(F0, F1).flatten())
return torch.cat(diags)
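The two-factor branch uses the identity that the diagonal of a Kronecker product is the flattened outer product of the factor diagonals, which holds whether factors are stored dense or as 1-D diagonals. A quick check:

```python
import torch

A, B = torch.rand(3, 3), torch.rand(4, 4)
assert torch.allclose(
    torch.kron(A, B).diagonal(),
    torch.outer(A.diagonal(), B.diagonal()).flatten(),
)
```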

def to_matrix(self) -> torch.Tensor:
@@ -237,10 +240,12 @@ def to_matrix(self) -> torch.Tensor:
"""
blocks = list()
for F in self.kfacs:
F0 = F[0] if F[0].ndim > 1 else F[0].diag()
if len(F) == 1:
blocks.append(F[0])
blocks.append(F0)
else:
blocks.append(kron(F[0], F[1]))
F1 = F[1] if F[1].ndim > 1 else F[1].diag()
blocks.append(kron(F0, F1))
return block_diag(blocks)

# for commutative operations
@@ -350,9 +355,9 @@ def logdet(self) -> torch.Tensor:
l1, l2 = ls
if self.damping:
l1d, l2d = l1 + torch.sqrt(delta), l2 + torch.sqrt(delta)
logdet += torch.log(torch.ger(l1d, l2d)).sum()
logdet += torch.log(torch.outer(l1d, l2d)).sum()
else:
logdet += torch.log(torch.ger(l1, l2) + delta).sum()
logdet += torch.log(torch.outer(l1, l2) + delta).sum()
else:
raise ValueError('Too many Kronecker factors. Something went wrong.')
return logdet
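This hunk and the remaining ones also swap `torch.ger` for `torch.outer`; `torch.ger` is a deprecated alias for the same outer-product operation, so the change is behavior-preserving. A one-line reminder of what `torch.outer` computes:

```python
import torch

u, v = torch.rand(3), torch.rand(4)
assert torch.allclose(torch.outer(u, v), u.view(-1, 1) * v.view(1, -1))
```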
@@ -391,9 +396,9 @@ def _bmm(self, W: torch.Tensor, exponent: float = -1) -> torch.Tensor:
p = len(l1) * len(l2)
if self.damping:
l1d, l2d = l1 + torch.sqrt(delta), l2 + torch.sqrt(delta)
ldelta_exp = torch.pow(torch.ger(l1d, l2d), exponent).unsqueeze(0)
ldelta_exp = torch.pow(torch.outer(l1d, l2d), exponent).unsqueeze(0)
else:
ldelta_exp = torch.pow(torch.ger(l1, l2) + delta, exponent).unsqueeze(0)
ldelta_exp = torch.pow(torch.outer(l1, l2) + delta, exponent).unsqueeze(0)
p_in, p_out = len(l1), len(l2)
W_p = W[:, cur_p:cur_p+p].reshape(B * K, p_in, p_out)
W_p = (Q1.T @ W_p @ Q2) * ldelta_exp
@@ -457,9 +462,9 @@ def to_matrix(self, exponent: float = 1) -> torch.Tensor:
Q = kron(Q1, Q2)
if self.damping:
delta_sqrt = torch.sqrt(delta)
l = torch.pow(torch.ger(l1 + delta_sqrt, l2 + delta_sqrt), exponent)
l = torch.pow(torch.outer(l1 + delta_sqrt, l2 + delta_sqrt), exponent)
else:
l = torch.pow(torch.ger(l1, l2) + delta, exponent)
l = torch.pow(torch.outer(l1, l2) + delta, exponent)
L = torch.diag(l.flatten())
blocks.append(Q @ L @ Q.T)
return block_diag(blocks)
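`KronDecomposed` applies arbitrary matrix powers through the stored eigendecompositions, H^e = Q diag(l)^e Qᵀ, which is why `bmm` and `to_matrix` take an `exponent` argument. A standalone illustration for a single SPD block:

```python
import torch

M = torch.randn(4, 4)
H = M @ M.T + 0.1 * torch.eye(4)   # SPD block
l, Q = torch.linalg.eigh(H)

# H^{-1/2} from the eigendecomposition; H^{-1/2} @ H @ H^{-1/2} = I.
H_inv_sqrt = Q @ torch.diag(l.pow(-0.5)) @ Q.T
assert torch.allclose(H_inv_sqrt @ H @ H_inv_sqrt, torch.eye(4), atol=1e-5)
```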
4 changes: 2 additions & 2 deletions tests/test_baselaplace.py
@@ -63,11 +63,11 @@ def test_laplace_init(laplace, model):
elif laplace == LowRankLaplace:
assert lap.H is None
else:
H = [[k.clone() for k in kfac] for kfac in lap.H.kfacs]
H = [[torch.tensor(k) for k in kfac] for kfac in lap.H.kfacs]
lap._init_H()
for kfac1, kfac2 in zip(H, lap.H.kfacs):
for k1, k2 in zip(kfac1, kfac2):
assert torch.allclose(k1, k2)
assert torch.allclose(k1, torch.tensor(k2))


@pytest.mark.xfail(strict=True)
8 changes: 4 additions & 4 deletions tests/test_lllaplace.py
@@ -54,11 +54,11 @@ def test_laplace_init(laplace, model):
lap._init_H()
assert torch.allclose(H, lap.H)
else:
H = [[k.clone() for k in kfac] for kfac in lap.H.kfacs]
H = [[torch.tensor(k) for k in kfac] for kfac in lap.H.kfacs]
lap._init_H()
for kfac1, kfac2 in zip(H, lap.H.kfacs):
for k1, k2 in zip(kfac1, kfac2):
assert torch.allclose(k1, k2)
assert torch.allclose(k1, torch.tensor(k2))


@pytest.mark.parametrize('laplace', flavors)
@@ -77,11 +77,11 @@ def test_laplace_large_init(laplace, large_model):
lap._init_H()
assert torch.allclose(H, lap.H)
else:
H = [[k.clone() for k in kfac] for kfac in lap.H.kfacs]
H = [[torch.tensor(k) for k in kfac] for kfac in lap.H.kfacs]
lap._init_H()
for kfac1, kfac2 in zip(H, lap.H.kfacs):
for k1, k2 in zip(kfac1, kfac2):
assert torch.allclose(k1, k2)
assert torch.allclose(k1, torch.tensor(k2))


@pytest.mark.parametrize('laplace', flavors)
132 changes: 119 additions & 13 deletions tests/test_matrix.py
@@ -6,8 +6,8 @@

from laplace.utils import Kron, block_diag
from laplace.utils import kron as kron_prod
from laplace.curvature import BackPackGGN
from tests.utils import get_psd_matrix, jacobians_naive
from laplace.curvature import BackPackGGN, AsdlGGN
from tests.utils import get_psd_matrix, get_diag_psd_matrix, jacobians_naive


torch.set_default_tensor_type(torch.DoubleTensor)
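The tests import a new helper, `get_diag_psd_matrix`, from `tests/utils`, which is not part of the displayed diff. A plausible minimal implementation, shown only so the tests below are self-contained (an assumption, not the PR's actual helper):

```python
import torch

def get_diag_psd_matrix(size: int) -> torch.Tensor:
    """Random positive diagonal factor, stored as a 1-D tensor."""
    return torch.rand(size) + 1e-2
```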
@@ -32,20 +32,20 @@ def small_model():

def test_init_from_model(model):
kron = Kron.init_from_model(model, 'cpu')
expected_sizes = [[20*20, 3*3], [20*20], [2*2, 20*20], [2*2]]
for facs, exp_facs in zip(kron.kfacs, expected_sizes):
expected_init_values = [[0., 0.], [0.], [0., 0.], [0.]]
for facs, exp_facs in zip(kron.kfacs, expected_init_values):
assert len(facs) == len(exp_facs)
for fi, exp_fi in zip(facs, exp_facs):
assert torch.all(fi == 0)
assert np.prod(fi.shape) == exp_fi
assert fi == exp_fi


def test_init_from_iterable(model):
kron = Kron.init_from_model(model.parameters(), 'cpu')
expected_sizes = [[20*20, 3*3], [20*20], [2*2, 20*20], [2*2]]
for facs, exp_facs in zip(kron.kfacs, expected_sizes):
expected_init_values = [[0., 0.], [0.], [0., 0.], [0.]]
for facs, exp_facs in zip(kron.kfacs, expected_init_values):
assert len(facs) == len(exp_facs)
for fi, exp_fi in zip(facs, exp_facs):
assert torch.all(fi == 0)
assert np.prod(fi.shape) == exp_fi
assert fi == exp_fi


def test_addition(model):
@@ -57,6 +57,7 @@ def test_addition(model):
for fi, exp_fi in zip(facs, exp_facs):
assert torch.allclose(fi, exp_fi)


def test_multiplication():
# kron * x should be the same as the expanded kronecker product * x
expected_sizes = [[20, 3], [20], [2, 20], [2]]
@@ -71,10 +72,12 @@ def test_multiplication():
facs = kron_prod(*facs)
assert torch.allclose(exp, facs)


def test_decompose():
expected_sizes = [[20, 3], [20], [2, 20], [2]]
P = 20 * 3 + 20 + 2 * 20 + 2
torch.manual_seed(7171)
# Dense Kronecker factors.
kfacs = [[get_psd_matrix(i) for i in sizes] for sizes in expected_sizes]
kron = Kron(kfacs)
kron_decomp = kron.decompose()
Expand All @@ -93,6 +96,25 @@ def test_decompose():
SW_kron = kron.bmm(W)
SW_kron_decomp = kron_decomp.bmm(W, exponent=1)
assert torch.allclose(SW_kron, SW_kron_decomp)
# Diagonal Kronecker factors.
diag_kfacs = [[get_diag_psd_matrix(i) for i in sizes] for sizes in expected_sizes]
kron = Kron(diag_kfacs)
kron_decomp = kron.decompose()
for facs, Qs, ls in zip(kron.kfacs, kron_decomp.eigenvectors, kron_decomp.eigenvalues):
if len(facs) == 1:
H, Q, l = facs[0], Qs[0], ls[0]
reconstructed = (Q @ torch.diag(l) @ Q.T).diag()
assert torch.allclose(H, reconstructed, rtol=1e-3)
if len(facs) == 2:
gtruth = kron_prod(facs[0].diag(), facs[1].diag())
rec_1 = Qs[0] @ torch.diag(ls[0]) @ Qs[0].T
rec_2 = Qs[1] @ torch.diag(ls[1]) @ Qs[1].T
reconstructed = kron_prod(rec_1, rec_2)
assert torch.allclose(gtruth, reconstructed, rtol=1e-2)
W = torch.randn(P)
SW_kron = kron.bmm(W)
SW_kron_decomp = kron_decomp.bmm(W, exponent=1)
assert torch.allclose(SW_kron, SW_kron_decomp)


def test_logdet_consistent():
@@ -102,17 +124,23 @@ def test_logdet_consistent():
kron = Kron(kfacs)
kron_decomp = kron.decompose()
assert torch.allclose(kron.logdet(), kron_decomp.logdet())
diag_kfacs = [[get_diag_psd_matrix(i) for i in sizes] for sizes in expected_sizes]
kron = Kron(diag_kfacs)
kron_decomp = kron.decompose()
assert torch.allclose(kron.logdet(), kron_decomp.logdet())


def test_bmm(small_model):
def test_bmm_dense(small_model):
model = small_model
# model = single_output_model
X = torch.randn(5, 3)
y = torch.randn(5, 2)

# Dense Kronecker factors.
backend = BackPackGGN(model, 'regression', stochastic=False)
loss, kron = backend.kron(X, y, N=5)
_, kron = backend.kron(X, y, N=5)
kron_decomp = kron.decompose()
Js, f = jacobians_naive(model, X)
Js, _ = jacobians_naive(model, X)
blocks = list()
for F in kron.kfacs:
if len(F) == 1:
@@ -167,9 +195,77 @@ def test_bmm(small_model):
assert torch.allclose(JS, JS_nodecomp)


def test_bmm_diag(small_model):
model = small_model
# model = single_output_model
X = torch.randn(5, 3)
y = torch.randn(5, 2)

# Diagonal Kronecker factors.
backend = AsdlGGN(model, 'regression', stochastic=False)
_, kron = backend.kron(X, y, N=5, diag_A=True, diag_B=True)
kron_decomp = kron.decompose()
Js, _ = jacobians_naive(model, X)
blocks = list()
for F in kron.kfacs:
F0 = F[0] if F[0].ndim > 1 else F[0].diag()
if len(F) == 1:
blocks.append(F0)
else:
F1 = F[1] if F[1].ndim > 1 else F[1].diag()
blocks.append(kron_prod(F0, F1))
S = block_diag(blocks)
assert torch.allclose(S, S.T)
assert torch.allclose(S.diagonal(), kron.diag())

# test J @ Kron_decomp @ Jt (square form)
JS = kron_decomp.bmm(Js, exponent=1)
JS_true = Js @ S
JSJ_true = torch.bmm(JS_true, Js.transpose(1,2))
JSJ = torch.bmm(JS, Js.transpose(1,2))
assert torch.allclose(JSJ, JSJ_true)
assert torch.allclose(JS, JS_true)

# test J @ Kron @ Jt (square form)
JS_nodecomp = kron.bmm(Js)
JSJ_nodecomp = torch.bmm(JS_nodecomp, Js.transpose(1,2))
assert torch.allclose(JSJ_nodecomp, JSJ)
assert torch.allclose(JS_nodecomp, JS)

# test J @ S_inv @ J (functional variance)
JSJ = kron_decomp.inv_square_form(Js)
S_inv = S.inverse()
JSJ_true = torch.bmm(Js @ S_inv, Js.transpose(1,2))
assert torch.allclose(JSJ, JSJ_true)

# test J @ S^-1/2 (sampling)
JS = kron_decomp.bmm(Js, exponent=-1/2)
JSJ = torch.bmm(JS, Js.transpose(1,2))
l, Q = torch.linalg.eigh(S_inv, UPLO='U')
JS_true = Js @ Q @ torch.diag(torch.sqrt(l)) @ Q.T
JSJ_true = torch.bmm(JS_true, Js.transpose(1,2))
assert torch.allclose(JS, JS_true)
assert torch.allclose(JSJ, JSJ_true)

# test different Js shapes:
# 2 - dimensional
JS = kron_decomp.bmm(Js[:, 0, :].squeeze(), exponent=1)
JS_nodecomp = kron.bmm(Js[:, 0, :].squeeze())
JS_true = Js[:, 0, :].squeeze() @ S
assert torch.allclose(JS, JS_true)
assert torch.allclose(JS, JS_nodecomp)
# 1 - dimensional
JS = kron_decomp.bmm(Js[0, 0, :].squeeze(), exponent=1)
JS_nodecomp = kron.bmm(Js[0, 0, :].squeeze())
JS_true = Js[0, 0, :].squeeze() @ S
assert torch.allclose(JS, JS_true)
assert torch.allclose(JS, JS_nodecomp)


def test_matrix_consistent():
expected_sizes = [[20, 3], [20], [2, 20], [2]]
torch.manual_seed(7171)
# Dense Kronecker factors.
kfacs = [[get_psd_matrix(i) for i in sizes] for sizes in expected_sizes]
kron = Kron(kfacs)
kron_decomp = kron.decompose()
Expand All @@ -179,3 +275,13 @@ def test_matrix_consistent():
M_true.diagonal().add_(3.4)
kron_decomp += torch.tensor(3.4)
assert torch.allclose(M_true, kron_decomp.to_matrix(exponent=1))
# Diagonal Kronecker factors.
diag_kfacs = [[get_diag_psd_matrix(i) for i in sizes] for sizes in expected_sizes]
kron = Kron(diag_kfacs)
kron_decomp = kron.decompose()
assert torch.allclose(kron.to_matrix(), kron_decomp.to_matrix(exponent=1))
assert torch.allclose(kron.to_matrix().inverse(), kron_decomp.to_matrix(exponent=-1))
M_true = kron.to_matrix()
M_true.diagonal().add_(3.4)
kron_decomp += torch.tensor(3.4)
assert torch.allclose(M_true, kron_decomp.to_matrix(exponent=1))