Skip to content

Commit

Permalink
Replace all ger() with outer()
Browse files Browse the repository at this point in the history
  • Loading branch information
aleximmer authored Mar 22, 2023
1 parent b1fd9c4 commit db8f532
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions laplace/utils/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def diag(self) -> torch.Tensor:
if len(F) == 1:
diags.append(F[0].diagonal())
else:
diags.append(torch.ger(F[0].diagonal(), F[1].diagonal()).flatten())
diags.append(torch.outer(F[0].diagonal(), F[1].diagonal()).flatten())
return torch.cat(diags)

def to_matrix(self) -> torch.Tensor:
Expand Down Expand Up @@ -345,9 +345,9 @@ def logdet(self) -> torch.Tensor:
l1, l2 = ls
if self.damping:
l1d, l2d = l1 + torch.sqrt(delta), l2 + torch.sqrt(delta)
logdet += torch.log(torch.ger(l1d, l2d)).sum()
logdet += torch.log(torch.outer(l1d, l2d)).sum()
else:
logdet += torch.log(torch.ger(l1, l2) + delta).sum()
logdet += torch.log(torch.outer(l1, l2) + delta).sum()
else:
raise ValueError('Too many Kronecker factors. Something went wrong.')
return logdet
Expand Down Expand Up @@ -386,9 +386,9 @@ def _bmm(self, W: torch.Tensor, exponent: float = -1) -> torch.Tensor:
p = len(l1) * len(l2)
if self.damping:
l1d, l2d = l1 + torch.sqrt(delta), l2 + torch.sqrt(delta)
ldelta_exp = torch.pow(torch.ger(l1d, l2d), exponent).unsqueeze(0)
ldelta_exp = torch.pow(torch.outer(l1d, l2d), exponent).unsqueeze(0)
else:
ldelta_exp = torch.pow(torch.ger(l1, l2) + delta, exponent).unsqueeze(0)
ldelta_exp = torch.pow(torch.outer(l1, l2) + delta, exponent).unsqueeze(0)
p_in, p_out = len(l1), len(l2)
W_p = W[:, cur_p:cur_p+p].reshape(B * K, p_in, p_out)
W_p = (Q1.T @ W_p @ Q2) * ldelta_exp
Expand Down Expand Up @@ -455,9 +455,9 @@ def diag(self, exponent: float = 1) -> torch.Tensor:
l1, l2 = ls
if self.damping:
delta_sqrt = torch.sqrt(delta)
l = torch.pow(torch.ger(l1 + delta_sqrt, l2 + delta_sqrt), exponent)
l = torch.pow(torch.outer(l1 + delta_sqrt, l2 + delta_sqrt), exponent)
else:
l = torch.pow(torch.ger(l1, l2) + delta, exponent)
l = torch.pow(torch.outer(l1, l2) + delta, exponent)
d = torch.einsum('mp,nq,pq,mp,nq->mn', Q1, Q2, l, Q1, Q2).flatten()
diags.append(d)
return torch.cat(diags)
Expand Down Expand Up @@ -487,9 +487,9 @@ def to_matrix(self, exponent: float = 1) -> torch.Tensor:
Q = kron(Q1, Q2)
if self.damping:
delta_sqrt = torch.sqrt(delta)
l = torch.pow(torch.ger(l1 + delta_sqrt, l2 + delta_sqrt), exponent)
l = torch.pow(torch.outer(l1 + delta_sqrt, l2 + delta_sqrt), exponent)
else:
l = torch.pow(torch.ger(l1, l2) + delta, exponent)
l = torch.pow(torch.outer(l1, l2) + delta, exponent)
L = torch.diag(l.flatten())
blocks.append(Q @ L @ Q.T)
return block_diag(blocks)
Expand Down

0 comments on commit db8f532

Please sign in to comment.