From 28275382cac05c6e886e668dd9044746968ad708 Mon Sep 17 00:00:00 2001
From: Alexander Immer
Date: Fri, 10 Dec 2021 16:36:11 +0100
Subject: [PATCH] Fix device of eye in symeig

---
 laplace/utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/laplace/utils.py b/laplace/utils.py
index 607a03e7..5b059d31 100644
--- a/laplace/utils.py
+++ b/laplace/utils.py
@@ -131,7 +131,7 @@ def diagonal_add_scalar(X, value):
     values = X.new_ones(X.shape[0]).mul(value)
     return X.index_put(tuple(indices.t()), values, accumulate=True)
 
-    
+
 def symeig(M):
     """Symetric eigendecomposition avoiding failure cases by adding and
     removing jitter to the diagonal.
@@ -152,7 +152,7 @@ def symeig(M):
     except RuntimeError:  # did not converge
         logging.info('SYMEIG: adding jitter, did not converge.')
         # use W L W^T + I = W (L + I) W^T
-        M = M + torch.eye(M.shape[0])
+        M = M + torch.eye(M.shape[0], device=M.device)
         try:
             L, W = torch.linalg.eigh(M, UPLO='U')
             L -= 1.
@@ -195,13 +195,13 @@ def expand_prior_precision(prior_prec, model):
     Parameters
     ----------
     prior_prec : torch.Tensor 1-dimensional
-        prior precision 
+        prior precision
     model : torch.nn.Module
         torch model with parameters that are regularized by prior_prec
 
     Returns
     -------
-    expanded_prior_prec : torch.Tensor 
+    expanded_prior_prec : torch.Tensor
         expanded prior precision has the same shape as model parameters
     """
     theta = parameters_to_vector(model.parameters())
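
For context: torch.eye defaults to allocating on the CPU, so the jitter term added in symeig raised a device mismatch whenever M lived on a GPU; passing device=M.device fixes this. A minimal standalone sketch of the behaviour the patch addresses (the add_jitter helper below is only illustrative, not part of laplace/utils.py):

import torch

def add_jitter(M):
    # Illustrative helper mirroring the patched line: the identity matrix is
    # created on the same device as M, so the addition works on CPU and GPU.
    return M + torch.eye(M.shape[0], device=M.device)

if torch.cuda.is_available():
    A = torch.randn(4, 4, device='cuda')
    M = A @ A.T  # symmetric test matrix on the GPU
    # Before the fix, torch.eye(M.shape[0]) was a CPU tensor and the addition
    # failed with a "tensors on different devices" RuntimeError.
    L, W = torch.linalg.eigh(add_jitter(M), UPLO='U')
    print(L.device)  # cuda:0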