Skip to content

Commit

Permalink
Fix device of eye in symeig
Browse files Browse the repository at this point in the history
  • Loading branch information
aleximmer committed Dec 10, 2021
1 parent 0915472 commit 2827538
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions laplace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def diagonal_add_scalar(X, value):
values = X.new_ones(X.shape[0]).mul(value)
return X.index_put(tuple(indices.t()), values, accumulate=True)


def symeig(M):
"""Symetric eigendecomposition avoiding failure cases by
adding and removing jitter to the diagonal.
Expand All @@ -152,7 +152,7 @@ def symeig(M):
except RuntimeError: # did not converge
logging.info('SYMEIG: adding jitter, did not converge.')
# use W L W^T + I = W (L + I) W^T
M = M + torch.eye(M.shape[0])
M = M + torch.eye(M.shape[0], device=M.device)
try:
L, W = torch.linalg.eigh(M, UPLO='U')
L -= 1.
Expand Down Expand Up @@ -195,13 +195,13 @@ def expand_prior_precision(prior_prec, model):
Parameters
----------
prior_prec : torch.Tensor 1-dimensional
prior precision
prior precision
model : torch.nn.Module
torch model with parameters that are regularized by prior_prec
Returns
-------
expanded_prior_prec : torch.Tensor
expanded_prior_prec : torch.Tensor
expanded prior precision has the same shape as model parameters
"""
theta = parameters_to_vector(model.parameters())
Expand Down

0 comments on commit 2827538

Please sign in to comment.