Merge pull request #121 from AlexImmer/precision-diag
Add `KronDecomposed.diag()` feature
aleximmer authored Mar 22, 2023
2 parents e94a000 + db8f532 commit ada1c6f
Showing 2 changed files with 54 additions and 7 deletions.
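
For context, the new `KronDecomposed.diag()` returns the diagonal of the decomposed Kronecker-factored matrix (optionally raised to a power) without densifying it first. A minimal usage sketch, assuming `Kron` is importable from `laplace.utils.matrix` and using illustrative hand-made PSD factors (the helper `make_psd` is not part of the library):

```python
import torch
from laplace.utils.matrix import Kron  # assumed import path; the class lives in laplace/utils/matrix.py

def make_psd(n: int) -> torch.Tensor:
    """Small symmetric positive-definite matrix for illustration."""
    A = torch.randn(n, n)
    return A @ A.T + n * torch.eye(n)

# One Kronecker-factored block (e.g. a weight matrix) and one single-factor block (e.g. a bias).
kfacs = [[make_psd(5), make_psd(3)], [make_psd(4)]]
kron = Kron(kfacs)
kron_decomp = kron.decompose()  # eigendecompose each factor -> KronDecomposed

# New in this PR: the diagonal, with an optional exponent, without building the dense matrix.
d = kron_decomp.diag(exponent=1)
assert torch.allclose(d, kron_decomp.to_matrix(exponent=1).diagonal())
```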
laplace/utils/matrix.py (49 changes: 42 additions & 7 deletions)
@@ -218,7 +218,7 @@ def diag(self) -> torch.Tensor:
             if len(F) == 1:
                 diags.append(F[0].diagonal())
             else:
-                diags.append(torch.ger(F[0].diagonal(), F[1].diagonal()).flatten())
+                diags.append(torch.outer(F[0].diagonal(), F[1].diagonal()).flatten())
         return torch.cat(diags)

     def to_matrix(self) -> torch.Tensor:
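
The `torch.ger` → `torch.outer` swaps in this hunk and the ones below are behavior-preserving: `torch.ger` is deprecated in recent PyTorch releases in favour of the identically-behaving `torch.outer`, so the change silences deprecation warnings without affecting results. A quick sanity check of what `torch.outer` computes, for reference:

```python
import torch

a, b = torch.randn(3), torch.randn(4)
outer = torch.outer(a, b)                       # outer[i, j] == a[i] * b[j]
assert outer.shape == (3, 4)
assert torch.allclose(outer, a.unsqueeze(1) * b.unsqueeze(0))
```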
@@ -345,9 +345,9 @@ def logdet(self) -> torch.Tensor:
                 l1, l2 = ls
                 if self.damping:
                     l1d, l2d = l1 + torch.sqrt(delta), l2 + torch.sqrt(delta)
-                    logdet += torch.log(torch.ger(l1d, l2d)).sum()
+                    logdet += torch.log(torch.outer(l1d, l2d)).sum()
                 else:
-                    logdet += torch.log(torch.ger(l1, l2) + delta).sum()
+                    logdet += torch.log(torch.outer(l1, l2) + delta).sum()
             else:
                 raise ValueError('Too many Kronecker factors. Something went wrong.')
         return logdet
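
The `logdet` branch above relies on the fact that the eigenvalues of a Kronecker product are all pairwise products of the factors' eigenvalues, so summing `torch.log(torch.outer(l1, l2))` gives the log-determinant of the corresponding dense block. A standalone check of that identity (variable names are illustrative, not from the library):

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4); A = A @ A.T + 4 * torch.eye(4)   # SPD Kronecker factors
B = torch.randn(3, 3); B = B @ B.T + 3 * torch.eye(3)

l1 = torch.linalg.eigvalsh(A)
l2 = torch.linalg.eigvalsh(B)

# eig(A ⊗ B) = {λ_i * μ_j}, hence logdet(A ⊗ B) = Σ_ij log(λ_i * μ_j).
logdet_dense = torch.logdet(torch.kron(A, B))
logdet_eig = torch.log(torch.outer(l1, l2)).sum()
assert torch.allclose(logdet_dense, logdet_eig, atol=1e-4)
```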
@@ -386,9 +386,9 @@ def _bmm(self, W: torch.Tensor, exponent: float = -1) -> torch.Tensor:
                 p = len(l1) * len(l2)
                 if self.damping:
                     l1d, l2d = l1 + torch.sqrt(delta), l2 + torch.sqrt(delta)
-                    ldelta_exp = torch.pow(torch.ger(l1d, l2d), exponent).unsqueeze(0)
+                    ldelta_exp = torch.pow(torch.outer(l1d, l2d), exponent).unsqueeze(0)
                 else:
-                    ldelta_exp = torch.pow(torch.ger(l1, l2) + delta, exponent).unsqueeze(0)
+                    ldelta_exp = torch.pow(torch.outer(l1, l2) + delta, exponent).unsqueeze(0)
                 p_in, p_out = len(l1), len(l2)
                 W_p = W[:, cur_p:cur_p+p].reshape(B * K, p_in, p_out)
                 W_p = (Q1.T @ W_p @ Q2) * ldelta_exp
@@ -432,11 +432,46 @@ def bmm(self, W: torch.Tensor, exponent: float = -1) -> torch.Tensor:
         else:
             raise ValueError('Invalid shape for W')

+    def diag(self, exponent: float = 1) -> torch.Tensor:
+        """Extract diagonal of the entire decomposed Kronecker factorization.
+        Parameters
+        ----------
+        exponent: float, default=1
+            exponent of the Kronecker factorization
+        Returns
+        -------
+        diag : torch.Tensor
+        """
+        diags = list()
+        for Qs, ls, delta in zip(self.eigenvectors, self.eigenvalues, self.deltas):
+            if len(ls) == 1:
+                Ql = Qs[0] * torch.pow(ls[0] + delta, exponent).reshape(1, -1)
+                d = torch.einsum('mp,mp->m', Ql, Qs[0])  # only compute inner products for diag
+                diags.append(d)
+            else:
+                Q1, Q2 = Qs
+                l1, l2 = ls
+                if self.damping:
+                    delta_sqrt = torch.sqrt(delta)
+                    l = torch.pow(torch.outer(l1 + delta_sqrt, l2 + delta_sqrt), exponent)
+                else:
+                    l = torch.pow(torch.outer(l1, l2) + delta, exponent)
+                d = torch.einsum('mp,nq,pq,mp,nq->mn', Q1, Q2, l, Q1, Q2).flatten()
+                diags.append(d)
+        return torch.cat(diags)
+
     def to_matrix(self, exponent: float = 1) -> torch.Tensor:
         """Make the Kronecker factorization dense by computing the kronecker product.
         Warning: this should only be used for testing purposes as it will allocate
         large amounts of memory for big architectures.
         Parameters
         ----------
         exponent: float, default=1
             exponent of the Kronecker factorization
         Returns
         -------
         block_diag : torch.Tensor
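
The large einsum in the new `diag()` reads off the diagonal of `(Q1 ⊗ Q2) diag(l.flatten()) (Q1 ⊗ Q2)^T` directly: the `(m, n)`-th diagonal entry is `Σ_{p,q} Q1[m,p]² · Q2[n,q]² · l[p,q]`, so the Kronecker product never has to be materialized. A small standalone check of that identity (illustrative names only, not taken from the library):

```python
import torch

torch.manual_seed(0)
m, n = 4, 3
Q1, _ = torch.linalg.qr(torch.randn(m, m))   # orthogonal eigenvector factors
Q2, _ = torch.linalg.qr(torch.randn(n, n))
l = torch.rand(m, n) + 0.1                   # eigenvalues of the Kronecker block, as a grid

# Dense reference: (Q1 ⊗ Q2) diag(l.flatten()) (Q1 ⊗ Q2)^T
Q = torch.kron(Q1, Q2)
dense = Q @ torch.diag(l.flatten()) @ Q.T

# Same diagonal via the einsum used in KronDecomposed.diag()
d = torch.einsum('mp,nq,pq,mp,nq->mn', Q1, Q2, l, Q1, Q2).flatten()
assert torch.allclose(d, dense.diagonal(), atol=1e-6)
```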
@@ -452,9 +487,9 @@ def to_matrix(self, exponent: float = 1) -> torch.Tensor:
                 Q = kron(Q1, Q2)
                 if self.damping:
                     delta_sqrt = torch.sqrt(delta)
-                    l = torch.pow(torch.ger(l1 + delta_sqrt, l2 + delta_sqrt), exponent)
+                    l = torch.pow(torch.outer(l1 + delta_sqrt, l2 + delta_sqrt), exponent)
                 else:
-                    l = torch.pow(torch.ger(l1, l2) + delta, exponent)
+                    l = torch.pow(torch.outer(l1, l2) + delta, exponent)
                 L = torch.diag(l.flatten())
                 blocks.append(Q @ L @ Q.T)
         return block_diag(blocks)
tests/test_matrix.py (12 changes: 12 additions & 0 deletions)
@@ -170,3 +170,15 @@ def test_matrix_consistent():
     M_true.diagonal().add_(3.4)
     kron_decomp += torch.tensor(3.4)
     assert torch.allclose(M_true, kron_decomp.to_matrix(exponent=1))
+
+
+def test_diag():
+    expected_sizes = [[20, 3], [20], [2, 20], [2]]
+    torch.manual_seed(7171)
+    kfacs = [[get_psd_matrix(i) for i in sizes] for sizes in expected_sizes]
+    kron = Kron(kfacs)
+    kron_decomp = kron.decompose()
+    assert torch.allclose(kron.diag(), kron_decomp.diag(exponent=1))
+    assert torch.allclose(kron.diag(), torch.diag(kron.to_matrix()))
+    assert torch.allclose(kron_decomp.diag(), torch.diag(kron_decomp.to_matrix()))
+    assert torch.allclose(kron_decomp.diag(exponent=-1), torch.diag(kron_decomp.to_matrix(exponent=-1)))
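
One subtlety the last assertion pins down: `diag(exponent=-1)` is the diagonal of the matrix power `M^{-1}` (matching `to_matrix(exponent=-1)`), not the elementwise reciprocal of `diag(M)`. A tiny illustration of the difference, with an arbitrary SPD matrix:

```python
import torch

M = torch.tensor([[2.0, 1.0],
                  [1.0, 2.0]])
print(torch.linalg.inv(M).diagonal())  # tensor([0.6667, 0.6667])
print(1.0 / M.diagonal())              # tensor([0.5000, 0.5000])
```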
