From 98a9800db0a28c023ba73444be3cd8d490eeca62 Mon Sep 17 00:00:00 2001 From: runame Date: Fri, 11 Aug 2023 16:02:27 +0200 Subject: [PATCH 1/4] Make Kron init agnostic to the shape of the Kronecker factors --- laplace/utils/matrix.py | 17 ++++------------- tests/test_baselaplace.py | 4 ++-- tests/test_lllaplace.py | 8 ++++---- tests/test_matrix.py | 16 ++++++++-------- 4 files changed, 18 insertions(+), 27 deletions(-) diff --git a/laplace/utils/matrix.py b/laplace/utils/matrix.py index aed873a1..0c3e99d0 100644 --- a/laplace/utils/matrix.py +++ b/laplace/utils/matrix.py @@ -46,18 +46,9 @@ def init_from_model(cls, model, device): kfacs = list() for p in params: if p.ndim == 1: # bias - P = p.size(0) - kfacs.append([torch.zeros(P, P, device=device)]) - elif 4 >= p.ndim >= 2: # fully connected or conv - if p.ndim == 2: # fully connected - P_in, P_out = p.size() - elif p.ndim > 2: - P_in, P_out = p.shape[0], np.prod(p.shape[1:]) - - kfacs.append([ - torch.zeros(P_in, P_in, device=device), - torch.zeros(P_out, P_out, device=device) - ]) + kfacs.append([0.]) + elif 4 >= p.ndim >= 2: # fully connected or or embedding or conv + kfacs.append([0., 0.]) else: raise ValueError('Invalid parameter shape in network.') return cls(kfacs) @@ -76,7 +67,7 @@ def __add__(self, other): if not isinstance(other, Kron): raise ValueError('Can only add Kron to Kron.') - kfacs = [[Hi.add(Hj) for Hi, Hj in zip(Fi, Fj)] + kfacs = [[Hi + Hj for Hi, Hj in zip(Fi, Fj)] for Fi, Fj in zip(self.kfacs, other.kfacs)] return Kron(kfacs) diff --git a/tests/test_baselaplace.py b/tests/test_baselaplace.py index dedd14c4..b8b845de 100644 --- a/tests/test_baselaplace.py +++ b/tests/test_baselaplace.py @@ -63,11 +63,11 @@ def test_laplace_init(laplace, model): elif laplace == LowRankLaplace: assert lap.H is None else: - H = [[k.clone() for k in kfac] for kfac in lap.H.kfacs] + H = [[torch.tensor(k) for k in kfac] for kfac in lap.H.kfacs] lap._init_H() for kfac1, kfac2 in zip(H, lap.H.kfacs): for k1, k2 in zip(kfac1, kfac2): - assert torch.allclose(k1, k2) + assert torch.allclose(k1, torch.tensor(k2)) @pytest.mark.xfail(strict=True) diff --git a/tests/test_lllaplace.py b/tests/test_lllaplace.py index cbff89f8..abf4843e 100644 --- a/tests/test_lllaplace.py +++ b/tests/test_lllaplace.py @@ -54,11 +54,11 @@ def test_laplace_init(laplace, model): lap._init_H() assert torch.allclose(H, lap.H) else: - H = [[k.clone() for k in kfac] for kfac in lap.H.kfacs] + H = [[torch.tensor(k) for k in kfac] for kfac in lap.H.kfacs] lap._init_H() for kfac1, kfac2 in zip(H, lap.H.kfacs): for k1, k2 in zip(kfac1, kfac2): - assert torch.allclose(k1, k2) + assert torch.allclose(k1, torch.tensor(k2)) @pytest.mark.parametrize('laplace', flavors) @@ -77,11 +77,11 @@ def test_laplace_large_init(laplace, large_model): lap._init_H() assert torch.allclose(H, lap.H) else: - H = [[k.clone() for k in kfac] for kfac in lap.H.kfacs] + H = [[torch.tensor(k) for k in kfac] for kfac in lap.H.kfacs] lap._init_H() for kfac1, kfac2 in zip(H, lap.H.kfacs): for k1, k2 in zip(kfac1, kfac2): - assert torch.allclose(k1, k2) + assert torch.allclose(k1, torch.tensor(k2)) @pytest.mark.parametrize('laplace', flavors) diff --git a/tests/test_matrix.py b/tests/test_matrix.py index 6847eaf6..ad2d3308 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -32,20 +32,20 @@ def small_model(): def test_init_from_model(model): kron = Kron.init_from_model(model, 'cpu') - expected_sizes = [[20*20, 3*3], [20*20], [2*2, 20*20], [2*2]] - for facs, exp_facs in zip(kron.kfacs, 
expected_sizes): + expected_init_values = [[0., 0.], [0.], [0., 0.], [0.]] + for facs, exp_facs in zip(kron.kfacs, expected_init_values): + assert len(facs) == len(exp_facs) for fi, exp_fi in zip(facs, exp_facs): - assert torch.all(fi == 0) - assert np.prod(fi.shape) == exp_fi + assert fi == exp_fi def test_init_from_iterable(model): kron = Kron.init_from_model(model.parameters(), 'cpu') - expected_sizes = [[20*20, 3*3], [20*20], [2*2, 20*20], [2*2]] - for facs, exp_facs in zip(kron.kfacs, expected_sizes): + expected_init_values = [[0., 0.], [0.], [0., 0.], [0.]] + for facs, exp_facs in zip(kron.kfacs, expected_init_values): + assert len(facs) == len(exp_facs) for fi, exp_fi in zip(facs, exp_facs): - assert torch.all(fi == 0) - assert np.prod(fi.shape) == exp_fi + assert fi == exp_fi def test_addition(model): From c80954acead2ad903e575f86a2224ac887dbdd39 Mon Sep 17 00:00:00 2001 From: runame Date: Fri, 11 Aug 2023 21:49:03 +0200 Subject: [PATCH 2/4] Add support for diag Kronecker factors in Kron matrix class --- laplace/utils/matrix.py | 44 +++++++++------ tests/test_matrix.py | 116 ++++++++++++++++++++++++++++++++++++++-- tests/utils.py | 6 +++ 3 files changed, 146 insertions(+), 20 deletions(-) diff --git a/laplace/utils/matrix.py b/laplace/utils/matrix.py index 0c3e99d0..c4d13ad0 100644 --- a/laplace/utils/matrix.py +++ b/laplace/utils/matrix.py @@ -109,7 +109,14 @@ def decompose(self, damping=False): for F in self.kfacs: Qs, ls = list(), list() for Hi in F: - l, Q = symeig(Hi) + if Hi.ndim > 1: + # Dense Kronecker factor. + l, Q = symeig(Hi) + else: + # Diagonal Kronecker factor. + l, indices = Hi.sort() + # This might be too memory intensive since len(Hi) can be large. + Q = torch.eye(len(Hi), dtype=Hi.dtype, device=Hi.device)[indices].T Qs.append(Q) ls.append(l) eigvecs.append(Qs) @@ -140,14 +147,16 @@ def _bmm(self, W: torch.Tensor) -> torch.Tensor: Q = Fs[0] p = len(Q) W_p = W[:, cur_p:cur_p+p].T - SW.append((Q @ W_p).T) + SW.append((Q @ W_p).T if Q.ndim > 1 else (Q.view(-1, 1) * W_p).T) cur_p += p elif len(Fs) == 2: Q, H = Fs p_in, p_out = len(Q), len(H) p = p_in * p_out W_p = W[:, cur_p:cur_p+p].reshape(B * K, p_in, p_out) - SW.append((Q @ W_p @ H.T).reshape(B * K, p_in * p_out)) + QW_p= Q @ W_p if Q.ndim > 1 else Q.view(-1, 1) * W_p + QW_pHt = QW_p @ H.T if H.ndim > 1 else QW_p * H.view(1, -1) + SW.append(QW_pHt.reshape(B * K, p_in * p_out)) cur_p += p else: raise AttributeError('Shape mismatch') @@ -195,11 +204,12 @@ def logdet(self) -> torch.Tensor: logdet = 0 for F in self.kfacs: if len(F) == 1: - logdet += F[0].logdet() + logdet += F[0].logdet() if F[0].ndim > 1 else F[0].prod().log() else: # len(F) == 2 Hi, Hj = F p_in, p_out = len(Hi), len(Hj) - logdet += p_out * Hi.logdet() + p_in * Hj.logdet() + logdet += p_out * Hi.logdet() if Hi.ndim > 1 else p_out * Hi.prod().log() + logdet += p_in * Hj.logdet() if Hj.ndim > 1 else p_in * Hj.prod().log() return logdet def diag(self) -> torch.Tensor: @@ -211,10 +221,12 @@ def diag(self) -> torch.Tensor: """ diags = list() for F in self.kfacs: + F0 = F[0].diag() if F[0].ndim > 1 else F[0] if len(F) == 1: - diags.append(F[0].diagonal()) + diags.append(F0) else: - diags.append(torch.ger(F[0].diagonal(), F[1].diagonal()).flatten()) + F1 = F[1].diag() if F[1].ndim > 1 else F[1] + diags.append(torch.outer(F0, F1).flatten()) return torch.cat(diags) def to_matrix(self) -> torch.Tensor: @@ -228,10 +240,12 @@ def to_matrix(self) -> torch.Tensor: """ blocks = list() for F in self.kfacs: + F0 = F[0] if F[0].ndim > 1 else F[0].diag() if len(F) == 
1: - blocks.append(F[0]) + blocks.append(F0) else: - blocks.append(kron(F[0], F[1])) + F1 = F[1] if F[1].ndim > 1 else F[1].diag() + blocks.append(kron(F0, F1)) return block_diag(blocks) # for commutative operations @@ -341,9 +355,9 @@ def logdet(self) -> torch.Tensor: l1, l2 = ls if self.damping: l1d, l2d = l1 + torch.sqrt(delta), l2 + torch.sqrt(delta) - logdet += torch.log(torch.ger(l1d, l2d)).sum() + logdet += torch.log(torch.outer(l1d, l2d)).sum() else: - logdet += torch.log(torch.ger(l1, l2) + delta).sum() + logdet += torch.log(torch.outer(l1, l2) + delta).sum() else: raise ValueError('Too many Kronecker factors. Something went wrong.') return logdet @@ -382,9 +396,9 @@ def _bmm(self, W: torch.Tensor, exponent: float = -1) -> torch.Tensor: p = len(l1) * len(l2) if self.damping: l1d, l2d = l1 + torch.sqrt(delta), l2 + torch.sqrt(delta) - ldelta_exp = torch.pow(torch.ger(l1d, l2d), exponent).unsqueeze(0) + ldelta_exp = torch.pow(torch.outer(l1d, l2d), exponent).unsqueeze(0) else: - ldelta_exp = torch.pow(torch.ger(l1, l2) + delta, exponent).unsqueeze(0) + ldelta_exp = torch.pow(torch.outer(l1, l2) + delta, exponent).unsqueeze(0) p_in, p_out = len(l1), len(l2) W_p = W[:, cur_p:cur_p+p].reshape(B * K, p_in, p_out) W_p = (Q1.T @ W_p @ Q2) * ldelta_exp @@ -448,9 +462,9 @@ def to_matrix(self, exponent: float = 1) -> torch.Tensor: Q = kron(Q1, Q2) if self.damping: delta_sqrt = torch.sqrt(delta) - l = torch.pow(torch.ger(l1 + delta_sqrt, l2 + delta_sqrt), exponent) + l = torch.pow(torch.outer(l1 + delta_sqrt, l2 + delta_sqrt), exponent) else: - l = torch.pow(torch.ger(l1, l2) + delta, exponent) + l = torch.pow(torch.outer(l1, l2) + delta, exponent) L = torch.diag(l.flatten()) blocks.append(Q @ L @ Q.T) return block_diag(blocks) diff --git a/tests/test_matrix.py b/tests/test_matrix.py index ad2d3308..8ca3b987 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -6,8 +6,8 @@ from laplace.utils import Kron, block_diag from laplace.utils import kron as kron_prod -from laplace.curvature import BackPackGGN -from tests.utils import get_psd_matrix, jacobians_naive +from laplace.curvature import BackPackGGN, AsdlGGN +from tests.utils import get_psd_matrix, get_diag_psd_matrix, jacobians_naive torch.set_default_tensor_type(torch.DoubleTensor) @@ -57,6 +57,7 @@ def test_addition(model): for fi, exp_fi in zip(facs, exp_facs): assert torch.allclose(fi, exp_fi) + def test_multiplication(): # kron * x should be the same as the expanded kronecker product * x expected_sizes = [[20, 3], [20], [2, 20], [2]] @@ -71,10 +72,12 @@ def test_multiplication(): facs = kron_prod(*facs) assert torch.allclose(exp, facs) + def test_decompose(): expected_sizes = [[20, 3], [20], [2, 20], [2]] P = 20 * 3 + 20 + 2 * 20 + 2 torch.manual_seed(7171) + # Dense Kronecker factors. kfacs = [[get_psd_matrix(i) for i in sizes] for sizes in expected_sizes] kron = Kron(kfacs) kron_decomp = kron.decompose() @@ -93,6 +96,25 @@ def test_decompose(): SW_kron = kron.bmm(W) SW_kron_decomp = kron_decomp.bmm(W, exponent=1) assert torch.allclose(SW_kron, SW_kron_decomp) + # Diagonal Kronecker factors. 
+ diag_kfacs = [[get_diag_psd_matrix(i) for i in sizes] for sizes in expected_sizes] + kron = Kron(diag_kfacs) + kron_decomp = kron.decompose() + for facs, Qs, ls in zip(kron.kfacs, kron_decomp.eigenvectors, kron_decomp.eigenvalues): + if len(facs) == 1: + H, Q, l = facs[0], Qs[0], ls[0] + reconstructed = (Q @ torch.diag(l) @ Q.T).diag() + assert torch.allclose(H, reconstructed, rtol=1e-3) + if len(facs) == 2: + gtruth = kron_prod(facs[0].diag(), facs[1].diag()) + rec_1 = Qs[0] @ torch.diag(ls[0]) @ Qs[0].T + rec_2 = Qs[1] @ torch.diag(ls[1]) @ Qs[1].T + reconstructed = kron_prod(rec_1, rec_2) + assert torch.allclose(gtruth, reconstructed, rtol=1e-2) + W = torch.randn(P) + SW_kron = kron.bmm(W) + SW_kron_decomp = kron_decomp.bmm(W, exponent=1) + assert torch.allclose(SW_kron, SW_kron_decomp) def test_logdet_consistent(): @@ -102,17 +124,23 @@ def test_logdet_consistent(): kron = Kron(kfacs) kron_decomp = kron.decompose() assert torch.allclose(kron.logdet(), kron_decomp.logdet()) + diag_kfacs = [[get_diag_psd_matrix(i) for i in sizes] for sizes in expected_sizes] + kron = Kron(diag_kfacs) + kron_decomp = kron.decompose() + assert torch.allclose(kron.logdet(), kron_decomp.logdet()) -def test_bmm(small_model): +def test_bmm_dense(small_model): model = small_model # model = single_output_model X = torch.randn(5, 3) y = torch.randn(5, 2) + + # Dense Kronecker factors. backend = BackPackGGN(model, 'regression', stochastic=False) - loss, kron = backend.kron(X, y, N=5) + _, kron = backend.kron(X, y, N=5) kron_decomp = kron.decompose() - Js, f = jacobians_naive(model, X) + Js, _ = jacobians_naive(model, X) blocks = list() for F in kron.kfacs: if len(F) == 1: @@ -167,9 +195,77 @@ def test_bmm(small_model): assert torch.allclose(JS, JS_nodecomp) +def test_bmm_diag(small_model): + model = small_model + # model = single_output_model + X = torch.randn(5, 3) + y = torch.randn(5, 2) + + # Diagonal Kronecker factors. 
+    backend = AsdlGGN(model, 'regression', stochastic=False)
+    _, kron = backend.kron(X, y, N=5, diag_A=True, diag_B=True)
+    kron_decomp = kron.decompose()
+    Js, _ = jacobians_naive(model, X)
+    blocks = list()
+    for F in kron.kfacs:
+        F0 = F[0] if F[0].ndim > 1 else F[0].diag()
+        if len(F) == 1:
+            blocks.append(F0)
+        else:
+            F1 = F[1] if F[1].ndim > 1 else F[1].diag()
+            blocks.append(kron_prod(F0, F1))
+    S = block_diag(blocks)
+    assert torch.allclose(S, S.T)
+    assert torch.allclose(S.diagonal(), kron.diag())
+
+    # test J @ Kron_decomp @ Jt (square form)
+    JS = kron_decomp.bmm(Js, exponent=1)
+    JS_true = Js @ S
+    JSJ_true = torch.bmm(JS_true, Js.transpose(1,2))
+    JSJ = torch.bmm(JS, Js.transpose(1,2))
+    assert torch.allclose(JSJ, JSJ_true)
+    assert torch.allclose(JS, JS_true)
+
+    # test J @ Kron @ Jt (square form)
+    JS_nodecomp = kron.bmm(Js)
+    JSJ_nodecomp = torch.bmm(JS_nodecomp, Js.transpose(1,2))
+    assert torch.allclose(JSJ_nodecomp, JSJ)
+    assert torch.allclose(JS_nodecomp, JS)
+
+    # test J @ S_inv @ J (functional variance)
+    JSJ = kron_decomp.inv_square_form(Js)
+    S_inv = S.inverse()
+    JSJ_true = torch.bmm(Js @ S_inv, Js.transpose(1,2))
+    assert torch.allclose(JSJ, JSJ_true)
+
+    # test J @ S^-1/2 (sampling)
+    JS = kron_decomp.bmm(Js, exponent=-1/2)
+    JSJ = torch.bmm(JS, Js.transpose(1,2))
+    l, Q = torch.linalg.eigh(S_inv, UPLO='U')
+    JS_true = Js @ Q @ torch.diag(torch.sqrt(l)) @ Q.T
+    JSJ_true = torch.bmm(JS_true, Js.transpose(1,2))
+    assert torch.allclose(JS, JS_true)
+    assert torch.allclose(JSJ, JSJ_true)
+
+    # test different Js shapes:
+    # 2 - dimensional
+    JS = kron_decomp.bmm(Js[:, 0, :].squeeze(), exponent=1)
+    JS_nodecomp = kron.bmm(Js[:, 0, :].squeeze())
+    JS_true = Js[:, 0, :].squeeze() @ S
+    assert torch.allclose(JS, JS_true)
+    assert torch.allclose(JS, JS_nodecomp)
+    # 1 - dimensional
+    JS = kron_decomp.bmm(Js[0, 0, :].squeeze(), exponent=1)
+    JS_nodecomp = kron.bmm(Js[0, 0, :].squeeze())
+    JS_true = Js[0, 0, :].squeeze() @ S
+    assert torch.allclose(JS, JS_true)
+    assert torch.allclose(JS, JS_nodecomp)
+
+
 def test_matrix_consistent():
     expected_sizes = [[20, 3], [20], [2, 20], [2]]
     torch.manual_seed(7171)
+    # Dense Kronecker factors.
     kfacs = [[get_psd_matrix(i) for i in sizes] for sizes in expected_sizes]
     kron = Kron(kfacs)
     kron_decomp = kron.decompose()
@@ -179,3 +275,13 @@ def test_matrix_consistent():
     M_true.diagonal().add_(3.4)
     kron_decomp += torch.tensor(3.4)
     assert torch.allclose(M_true, kron_decomp.to_matrix(exponent=1))
+    # Diagonal Kronecker factors.
+ diag_kfacs = [[get_diag_psd_matrix(i) for i in sizes] for sizes in expected_sizes] + kron = Kron(diag_kfacs) + kron_decomp = kron.decompose() + assert torch.allclose(kron.to_matrix(), kron_decomp.to_matrix(exponent=1)) + assert torch.allclose(kron.to_matrix().inverse(), kron_decomp.to_matrix(exponent=-1)) + M_true = kron.to_matrix() + M_true.diagonal().add_(3.4) + kron_decomp += torch.tensor(3.4) + assert torch.allclose(M_true, kron_decomp.to_matrix(exponent=1)) diff --git a/tests/utils.py b/tests/utils.py index ccaef943..c1d3c03d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,9 +5,15 @@ def get_psd_matrix(dim): X = torch.randn(dim, dim*3) return X @ X.T / (dim * 3) + +def get_diag_psd_matrix(dim): + return torch.randn(dim) ** 2 + + def grad(model): return torch.cat([p.grad.data.flatten() for p in model.parameters()]).detach() + def jacobians_naive(model, data): model.zero_grad() f = model(data) From 848dd7c4204210c34b4d2b766db61864994b5305 Mon Sep 17 00:00:00 2001 From: runame Date: Wed, 16 Aug 2023 19:46:05 +0200 Subject: [PATCH 3/4] Undo sorting of eigenvalues --- laplace/utils/matrix.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/laplace/utils/matrix.py b/laplace/utils/matrix.py index c4d13ad0..b614ff46 100644 --- a/laplace/utils/matrix.py +++ b/laplace/utils/matrix.py @@ -114,9 +114,9 @@ def decompose(self, damping=False): l, Q = symeig(Hi) else: # Diagonal Kronecker factor. - l, indices = Hi.sort() + l = Hi # This might be too memory intensive since len(Hi) can be large. - Q = torch.eye(len(Hi), dtype=Hi.dtype, device=Hi.device)[indices].T + Q = torch.eye(len(Hi), dtype=Hi.dtype, device=Hi.device) Qs.append(Q) ls.append(l) eigvecs.append(Qs) From 49fdc8a1179c60badc25b75c8d35adb6acea6e5a Mon Sep 17 00:00:00 2001 From: runame Date: Thu, 17 Aug 2023 00:18:41 +0200 Subject: [PATCH 4/4] Improve logdet numerical stability --- laplace/utils/matrix.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/laplace/utils/matrix.py b/laplace/utils/matrix.py index b614ff46..1cb65609 100644 --- a/laplace/utils/matrix.py +++ b/laplace/utils/matrix.py @@ -47,7 +47,7 @@ def init_from_model(cls, model, device): for p in params: if p.ndim == 1: # bias kfacs.append([0.]) - elif 4 >= p.ndim >= 2: # fully connected or or embedding or conv + elif 4 >= p.ndim >= 2: # fully connected or embedding or conv kfacs.append([0., 0.]) else: raise ValueError('Invalid parameter shape in network.') @@ -204,12 +204,12 @@ def logdet(self) -> torch.Tensor: logdet = 0 for F in self.kfacs: if len(F) == 1: - logdet += F[0].logdet() if F[0].ndim > 1 else F[0].prod().log() + logdet += F[0].logdet() if F[0].ndim > 1 else F[0].log().sum() else: # len(F) == 2 Hi, Hj = F p_in, p_out = len(Hi), len(Hj) - logdet += p_out * Hi.logdet() if Hi.ndim > 1 else p_out * Hi.prod().log() - logdet += p_in * Hj.logdet() if Hj.ndim > 1 else p_in * Hj.prod().log() + logdet += p_out * Hi.logdet() if Hi.ndim > 1 else p_out * Hi.log().sum() + logdet += p_in * Hj.logdet() if Hj.ndim > 1 else p_in * Hj.log().sum() return logdet def diag(self) -> torch.Tensor:
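
Reviewer note (not part of the patches): the net effect of this series is that a Kronecker factor can now be stored either as a dense matrix (ndim == 2) or as a 1-D diagonal, with every operation in Kron/KronDecomposed branching on ndim, and with the diagonal log-determinant computed as log().sum() rather than prod().log(). The sketch below is a minimal standalone illustration of that convention; block_logdet is a hypothetical helper written for this note, not the library's API.

    import torch

    def block_logdet(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        # log det(A kron B) for factors stored dense (2-D) or diagonal (1-D),
        # using det(A kron B) = det(A)^p_out * det(B)^p_in.
        p_in, p_out = A.shape[0], B.shape[0]
        # For a diagonal factor d, log det = sum(log(d)); summing logs avoids the
        # underflow/overflow that prod().log() can hit when the factor is large.
        logdet_A = A.logdet() if A.ndim > 1 else A.log().sum()
        logdet_B = B.logdet() if B.ndim > 1 else B.log().sum()
        return p_out * logdet_A + p_in * logdet_B

    # Consistency check: a diagonal factor agrees with its dense embedding.
    torch.manual_seed(0)
    a = torch.rand(20) + 0.1                # diagonal factor, stored as a 1-D tensor
    X = torch.randn(3, 9)
    B = X @ X.T / 9 + 0.1 * torch.eye(3)    # dense PSD factor
    assert torch.allclose(block_logdet(torch.diag(a), B), block_logdet(a, B))

For a diagonal factor with a few thousand entries below 1, prod() underflows to 0 and prod().log() returns -inf, while log().sum() stays finite, which is what the last patch guards against.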