From ef6717c5518cf72fb86d600d3b2ef30c4a67d894 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Sat, 9 Mar 2024 11:24:34 -0500 Subject: [PATCH 01/10] Serialization for all Laplaces except Kron and Subnet --- laplace/baselaplace.py | 66 +++++++++++++++++++++++++++++------------- laplace/lllaplace.py | 30 +++++++++++++++---- 2 files changed, 71 insertions(+), 25 deletions(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index 3db62184..44947bec 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -4,7 +4,7 @@ from torch.nn.utils import parameters_to_vector, vector_to_parameters from torch.distributions import MultivariateNormal, Dirichlet, Normal -from laplace.utils import (parameters_per_layer, invsqrt_precision, +from laplace.utils import (parameters_per_layer, invsqrt_precision, get_nll, validate, Kron, normal_samples, fix_prior_prec_structure) from laplace.curvature import AsdlGGN, BackPackGGN, AsdlHessian @@ -357,7 +357,7 @@ class ParametricLaplace(BaseLaplace): """ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1., - prior_mean=0., temperature=1., enable_backprop=False, + prior_mean=0., temperature=1., enable_backprop=False, backend=None, backend_kwargs=None): super().__init__(model, likelihood, sigma_noise, prior_precision, prior_mean, temperature, enable_backprop, backend, backend_kwargs) @@ -527,7 +527,7 @@ def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None): return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter) - def __call__(self, x, pred_type='glm', joint=False, link_approx='probit', + def __call__(self, x, pred_type='glm', joint=False, link_approx='probit', n_samples=100, diagonal_output=False, generator=None): """Compute the posterior predictive on input data `x`. @@ -543,13 +543,13 @@ def __call__(self, x, pred_type='glm', joint=False, link_approx='probit', link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'} how to approximate the classification link function for the `'glm'`. - For `pred_type='nn'`, only 'mc' is possible. + For `pred_type='nn'`, only 'mc' is possible. joint : bool Whether to output a joint predictive distribution in regression with `pred_type='glm'`. If set to `True`, the predictive distribution has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). - If `False`, then only outputs the marginal predictive distribution. + If `False`, then only outputs the marginal predictive distribution. Only available for regression and GLM predictive. n_samples : int @@ -569,8 +569,8 @@ def __call__(self, x, pred_type='glm', joint=False, link_approx='probit', a distribution over classes (similar to a Softmax). For `likelihood='regression'`, a tuple of torch.Tensor is returned with the mean and the predictive variance. - For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor - is returned with the mean and the predictive covariance. + For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor + is returned with the mean and the predictive covariance. 
""" if pred_type not in ['glm', 'nn']: raise ValueError('Only glm and nn supported as prediction types.') @@ -580,7 +580,7 @@ def __call__(self, x, pred_type='glm', joint=False, link_approx='probit', if pred_type == 'nn' and link_approx != 'mc': raise ValueError('Only mc link approximation is supported for nn prediction type.') - + if generator is not None: if not isinstance(generator, torch.Generator) or generator.device != x.device: raise ValueError('Invalid random generator (check type and device).') @@ -594,7 +594,7 @@ def __call__(self, x, pred_type='glm', joint=False, link_approx='probit', return f_mu, f_var # classification if link_approx == 'mc': - return self.predictive_samples(x, pred_type='glm', n_samples=n_samples, + return self.predictive_samples(x, pred_type='glm', n_samples=n_samples, diagonal_output=diagonal_output).mean(dim=0) elif link_approx == 'probit': kappa = 1 / torch.sqrt(1. + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2)) @@ -623,7 +623,7 @@ def __call__(self, x, pred_type='glm', joint=False, link_approx='probit', return samples.mean(dim=0), samples.var(dim=0) return samples.mean(dim=0) - def predictive_samples(self, x, pred_type='glm', n_samples=100, + def predictive_samples(self, x, pred_type='glm', n_samples=100, diagonal_output=False, generator=None): """Sample from the posterior predictive on input data `x`. Can be used, for example, for Thompson sampling. @@ -720,7 +720,7 @@ def functional_covariance(self, Jacs): `f_cov = Jacs @ P.inv() @ Jacs.T`, which is a batch*output x batch*output predictive covariance matrix. - This emulates the GP posterior covariance N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). + This emulates the GP posterior covariance N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]). Useful for joint predictions, such as in batched Bayesian optimization. Parameters @@ -770,6 +770,32 @@ def posterior_precision(self): """ raise NotImplementedError + def state_dict(self) -> dict: + self._check_H_init() + state_dict = { + 'mean': self.mean, + 'H': self.H, + 'loss': self.loss, + 'prior_precision': self.prior_precision, + 'sigma_noise': self.sigma_noise, + 'n_data': self.n_data, + 'n_outputs': self.n_outputs, + } + return state_dict + + def load_state_dict(self, state_dict: dict): + self.mean = state_dict['mean'] + if self.n_params is not None and len(self.mean) != self.n_params: + raise ValueError('Attempting to load Laplace with different number of parameters than the model.') + + self.H = state_dict['H'] + self.loss = state_dict['loss'] + self.prior_precision = state_dict['prior_precision'] + self.sigma_noise = state_dict['sigma_noise'] + self.n_data = state_dict['n_data'] + self.n_outputs = state_dict['n_outputs'] + setattr(self.model, 'output_size', self.n_outputs) + class FullLaplace(ParametricLaplace): """Laplace approximation with full, i.e., dense, log likelihood Hessian approximation @@ -875,7 +901,7 @@ class KronLaplace(ParametricLaplace): _key = ('all', 'kron') def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1., - prior_mean=0., temperature=1., enable_backprop=False, backend=None, + prior_mean=0., temperature=1., enable_backprop=False, backend=None, damping=False, **backend_kwargs): self.damping = damping self.H_facs = None @@ -965,26 +991,26 @@ def prior_precision(self, prior_precision): class LowRankLaplace(ParametricLaplace): - """Laplace approximation with low-rank log likelihood Hessian (approximation). + """Laplace approximation with low-rank log likelihood Hessian (approximation). 
+    """Laplace approximation with low-rank log likelihood Hessian (approximation).
     The low-rank matrix is represented by an eigendecomposition (vecs, values).
     Based on the chosen `backend`, either a true Hessian or, for example, GGN
     approximation could be used. The posterior precision is computed as
     \\( P = V diag(l) V^T + P_0.\\)
-    To sample, compute the functional variance, and log determinant, algebraic tricks 
-    are usedto reduce the costs of inversion to the that of a \\(K \times K\\) matrix
-    if we have a rank of K.
-    
+    To sample, compute the functional variance, and log determinant, algebraic tricks
+    are used to reduce the cost of inversion to that of a \\(K \times K\\) matrix
+    if we have a rank of K.
+
     See `BaseLaplace` for the full interface.
     """
     _key = ('all', 'lowrank')
 
-    def __init__(self, model, likelihood, sigma_noise=1, prior_precision=1, prior_mean=0, 
+    def __init__(self, model, likelihood, sigma_noise=1, prior_precision=1, prior_mean=0,
                  temperature=1, enable_backprop=False, backend=AsdlHessian, backend_kwargs=None):
-        super().__init__(model, likelihood, sigma_noise=sigma_noise, 
-                         prior_precision=prior_precision, prior_mean=prior_mean, 
-                         temperature=temperature, enable_backprop=enable_backprop, 
+        super().__init__(model, likelihood, sigma_noise=sigma_noise,
+                         prior_precision=prior_precision, prior_mean=prior_mean,
+                         temperature=temperature, enable_backprop=enable_backprop,
                          backend=backend, backend_kwargs=backend_kwargs)
-        
+
     def _init_H(self):
         self.H = None
 
diff --git a/laplace/lllaplace.py b/laplace/lllaplace.py
index 827a29e1..dc3163d8 100644
--- a/laplace/lllaplace.py
+++ b/laplace/lllaplace.py
@@ -62,7 +62,7 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
                  backend_kwargs=None):
         self.H = None
         super().__init__(model, likelihood, sigma_noise=sigma_noise, prior_precision=1.,
-                         prior_mean=0., temperature=temperature, 
+                         prior_mean=0., temperature=temperature,
                          enable_backprop=enable_backprop, backend=backend,
                          backend_kwargs=backend_kwargs)
         self.model = FeatureExtractor(
@@ -103,12 +103,13 @@ def fit(self, train_loader, override=True):
         self.model.eval()
 
         if self.model.last_layer is None:
-            X, _ = next(iter(train_loader))
+            # Save an example batch for when loading the serialized Laplace
+            self.X, _ = next(iter(train_loader))
             with torch.no_grad():
                 try:
-                    self.model.find_last_layer(X[:1].to(self._device))
+                    self.model.find_last_layer(self.X[:1].to(self._device))
                 except (TypeError, AttributeError):
-                    self.model.find_last_layer(X.to(self._device))
+                    self.model.find_last_layer(self.X.to(self._device))
             params = parameters_to_vector(self.model.last_layer.parameters()).detach()
             self.n_params = len(params)
             self.n_layers = len(list(self.model.last_layer.parameters()))
@@ -125,7 +126,7 @@ def fit(self, train_loader, override=True):
 
     def _glm_predictive_distribution(self, X, joint=False):
         Js, f_mu = self.backend.last_layer_jacobians(X)
-        
+
         if joint:
             f_mu = f_mu.flatten()  # (batch*out)
             f_var = self.functional_covariance(Js)  # (batch*out, batch*out)
@@ -164,6 +165,25 @@ def prior_precision_diag(self):
         else:
             raise ValueError('Mismatch of prior and model. 
Diagonal or scalar prior.') + def state_dict(self) -> dict: + state_dict = super().state_dict() + state_dict['X'] = self.X + return state_dict + + def load_state_dict(self, state_dict: dict): + super().load_state_dict(state_dict) + + self.X = state_dict['X'] + with torch.no_grad(): + try: + self.model.find_last_layer(self.X[:1].to(self._device)) + except (TypeError, AttributeError): + self.model.find_last_layer(self.X.to(self._device)) + + params = parameters_to_vector(self.model.last_layer.parameters()).detach() + self.n_params = len(params) + self.n_layers = len(list(self.model.last_layer.parameters())) + class FullLLLaplace(LLLaplace, FullLaplace): """Last-layer Laplace approximation with full, i.e., dense, log likelihood Hessian approximation From 9ca7bf849f8bf2c9f9b7a12861dbb810f0bde682 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Sat, 9 Mar 2024 11:57:21 -0500 Subject: [PATCH 02/10] Serialization for KronLaplace & KronLLLaplace --- .gitignore | 2 ++ examples/regression_example.py | 20 ++++++++++++++------ laplace/baselaplace.py | 13 +++++++++++++ laplace/lllaplace.py | 4 ++-- laplace/utils/matrix.py | 4 ++-- 5 files changed, 33 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index e7759bc2..2af921b3 100644 --- a/.gitignore +++ b/.gitignore @@ -132,3 +132,5 @@ dmypy.json data/ .DS_Store + +state_dict.bin diff --git a/examples/regression_example.py b/examples/regression_example.py index cded77f7..2eb5e13e 100644 --- a/examples/regression_example.py +++ b/examples/regression_example.py @@ -43,6 +43,14 @@ def get_model(): neg_marglik.backward() hyper_optimizer.step() +# Serialization for fitted quantities +state_dict = la.state_dict() +torch.save(state_dict, 'state_dict.bin') + +la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full') +# Load serialized, fitted quantities +la.load_state_dict(torch.load('state_dict.bin')) + print(f'sigma={la.sigma_noise.item():.2f}', f'prior precision={la.prior_precision.item():.2f}') @@ -51,11 +59,11 @@ def get_model(): # Two options: # 1.) Marginal predictive distribution N(f_map(x_i), var(x_i)) # The mean is (m,k), the var is (m,k,k) -f_mu, f_var = la(X_test) +f_mu, f_var = la(X_test) # 2.) Joint pred. dist. N((f_map(x_1),...,f_map(x_m)), Cov(f(x_1),...,f(x_m))) # The mean is (m*k,) where k is the output dim. 
The cov is (m*k,m*k) -f_mu_joint, f_cov = la(X_test, joint=True) +f_mu_joint, f_cov = la(X_test, joint=True) # Both should be true assert torch.allclose(f_mu.flatten(), f_mu_joint) @@ -65,14 +73,14 @@ def get_model(): f_sigma = f_var.squeeze().detach().sqrt().cpu().numpy() pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item()**2) -plot_regression(X_train, y_train, x, f_mu, pred_std, - file_name='regression_example', plot=False) +plot_regression(X_train, y_train, x, f_mu, pred_std, + file_name='regression_example', plot=True) # alternatively, optimize parameters and hyperparameters of the prior jointly model = get_model() la, model, margliks, losses = marglik_training( model=model, train_loader=train_loader, likelihood='regression', - hessian_structure='full', backend=BackPackGGN, n_epochs=n_epochs, + hessian_structure='full', backend=BackPackGGN, n_epochs=n_epochs, optimizer_kwargs={'lr': 1e-2}, prior_structure='scalar' ) @@ -83,5 +91,5 @@ def get_model(): f_mu = f_mu.squeeze().detach().cpu().numpy() f_sigma = f_var.squeeze().sqrt().cpu().numpy() pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item()**2) -plot_regression(X_train, y_train, x, f_mu, pred_std, +plot_regression(X_train, y_train, x, f_mu, pred_std, file_name='regression_example_online', plot=False) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index 44947bec..beb61cdb 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -989,6 +989,19 @@ def prior_precision(self, prior_precision): if len(self.prior_precision) not in [1, self.n_layers]: raise ValueError('Prior precision for Kron either scalar or per-layer.') + def state_dict(self) -> dict: + self._check_H_init() + state_dict = super().state_dict() + state_dict['H'] = self.H_facs.kfacs + return state_dict + + def load_state_dict(self, state_dict: dict): + super().load_state_dict(state_dict) + self._init_H() + self.H_facs = self.H + self.H_facs.kfacs = state_dict['H'] + self.H = self.H_facs.decompose(damping=self.damping) + class LowRankLaplace(ParametricLaplace): """Laplace approximation with low-rank log likelihood Hessian (approximation). 
diff --git a/laplace/lllaplace.py b/laplace/lllaplace.py index dc3163d8..d89ddd44 100644 --- a/laplace/lllaplace.py +++ b/laplace/lllaplace.py @@ -171,8 +171,6 @@ def state_dict(self) -> dict: return state_dict def load_state_dict(self, state_dict: dict): - super().load_state_dict(state_dict) - self.X = state_dict['X'] with torch.no_grad(): try: @@ -180,6 +178,8 @@ def load_state_dict(self, state_dict: dict): except (TypeError, AttributeError): self.model.find_last_layer(self.X.to(self._device)) + super().load_state_dict(state_dict) + params = parameters_to_vector(self.model.last_layer.parameters()).detach() self.n_params = len(params) self.n_layers = len(list(self.model.last_layer.parameters())) diff --git a/laplace/utils/matrix.py b/laplace/utils/matrix.py index c60d8e1c..cddee587 100644 --- a/laplace/utils/matrix.py +++ b/laplace/utils/matrix.py @@ -2,7 +2,7 @@ import torch import numpy as np from typing import Union -import opt_einsum as oe +import opt_einsum as oe from laplace.utils import _is_valid_scalar, symeig, kron, block_diag @@ -451,7 +451,7 @@ def diag(self, exponent: float = 1) -> torch.Tensor: Ql = Qs[0] * torch.pow(ls[0] + delta, exponent).reshape(1, -1) d = torch.einsum('mp,mp->m', Ql, Qs[0]) # only compute inner products for diag diags.append(d) - else: + else: Q1, Q2 = Qs l1, l2 = ls if self.damping: From c45277c97e8abdcbfa792c7678ae09fca79aae7f Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Sat, 9 Mar 2024 12:36:10 -0500 Subject: [PATCH 03/10] Add serialization tests --- tests/test_serialization.py | 83 +++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 tests/test_serialization.py diff --git a/tests/test_serialization.py b/tests/test_serialization.py new file mode 100644 index 00000000..d0aa3b7e --- /dev/null +++ b/tests/test_serialization.py @@ -0,0 +1,83 @@ +from math import sqrt, prod +import pytest +from itertools import product +import numpy as np +from copy import deepcopy +import torch +from torch import nn +from torch.distributions.multivariate_normal import MultivariateNormal +from torch.nn.utils import parameters_to_vector +from torch.utils.data import DataLoader, TensorDataset +import os + +from laplace.laplace import FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace +from laplace.utils import KronDecomposed +from tests.utils import jacobians_naive + + +torch.manual_seed(240) +torch.set_default_tensor_type(torch.DoubleTensor) +flavors = [FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace] + + +@pytest.fixture +def model(): + model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2)) + setattr(model, 'output_size', 2) + model_params = list(model.parameters()) + setattr(model, 'n_layers', len(model_params)) # number of parameter groups + setattr(model, 'n_params', len(parameters_to_vector(model_params))) + return model + + +@pytest.fixture +def reg_loader(): + X = torch.randn(10, 3) + y = torch.randn(10, 2) + return DataLoader(TensorDataset(X, y), batch_size=3) + + +def _cleanup(): + if os.path.exists('state_dict.bin'): + os.remove('state_dict.bin') + + +@pytest.mark.parametrize('laplace', flavors) +def test_serialize(laplace, model, reg_loader): + la = laplace(model, 'regression') + la.fit(reg_loader) + la.optimize_prior_precision() + la.sigma_noise = 1231 + torch.save(la.state_dict(), 'state_dict.bin') + + la2 = laplace(model, 'regression') + la2.load_state_dict(torch.load('state_dict.bin')) + + assert 
la.sigma_noise == la2.sigma_noise + + X, _ = next(iter(reg_loader)) + f_mean, f_var = la(X) + f_mean2, f_var2 = la2(X) + assert torch.allclose(f_mean, f_mean2) + assert torch.allclose(f_var, f_var2) + + _cleanup() + + +@pytest.mark.parametrize('laplace', flavors) +def test_serialize_no_pickle(laplace, model, reg_loader): + la = laplace(model, 'regression') + la.fit(reg_loader) + la.optimize_prior_precision() + la.sigma_noise = 1231 + torch.save(la.state_dict(), 'state_dict.bin') + state_dict = torch.load('state_dict.bin') + + # Make sure no pickle object + for val in state_dict.values(): + assert isinstance(val, (list, tuple, int, float, torch.Tensor)) + + _cleanup() + + + From 853eca69ceedaf8fcec09553bd160777fcc6644e Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Sat, 9 Mar 2024 18:07:22 -0500 Subject: [PATCH 04/10] Add serialization test for SubnetLaplace --- tests/test_serialization.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index d0aa3b7e..328638d9 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, TensorDataset import os -from laplace.laplace import FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace +from laplace.laplace import FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace, DiagSubnetLaplace, FullSubnetLaplace from laplace.utils import KronDecomposed from tests.utils import jacobians_naive @@ -18,6 +18,7 @@ torch.manual_seed(240) torch.set_default_tensor_type(torch.DoubleTensor) flavors = [FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace] +subnet_flavors = [DiagSubnetLaplace, FullSubnetLaplace] @pytest.fixture @@ -37,7 +38,10 @@ def reg_loader(): return DataLoader(TensorDataset(X, y), batch_size=3) -def _cleanup(): +@pytest.fixture(autouse=True) +def cleanup(): + yield + # Run after test if os.path.exists('state_dict.bin'): os.remove('state_dict.bin') @@ -61,8 +65,6 @@ def test_serialize(laplace, model, reg_loader): assert torch.allclose(f_mean, f_mean2) assert torch.allclose(f_var, f_var2) - _cleanup() - @pytest.mark.parametrize('laplace', flavors) def test_serialize_no_pickle(laplace, model, reg_loader): @@ -77,7 +79,24 @@ def test_serialize_no_pickle(laplace, model, reg_loader): for val in state_dict.values(): assert isinstance(val, (list, tuple, int, float, torch.Tensor)) - _cleanup() +@pytest.mark.parametrize('laplace', subnet_flavors) +def test_serialize_subnetlaplace(laplace, model, reg_loader): + subnetwork_indices = torch.LongTensor([1, 10, 104, 44]) + la = laplace(model, 'regression', subnetwork_indices=subnetwork_indices) + la.fit(reg_loader) + la.optimize_prior_precision() + la.sigma_noise = 1231 + torch.save(la.state_dict(), 'state_dict.bin') + + la2 = laplace(model, 'regression', subnetwork_indices=subnetwork_indices) + la2.load_state_dict(torch.load('state_dict.bin')) + + assert la.sigma_noise == la2.sigma_noise + X, _ = next(iter(reg_loader)) + f_mean, f_var = la(X) + f_mean2, f_var2 = la2(X) + assert torch.allclose(f_mean, f_mean2) + assert torch.allclose(f_var, f_var2) From d2860021522b4db7d00025f683aac0aa9e250183 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Sat, 9 Mar 2024 18:46:04 -0500 Subject: [PATCH 05/10] Add example of serialization on README.md --- README.md | 68 
+++++++++++++++++++++++++++++++------------------------ 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index a9d6146a..90ea7d8b 100644 --- a/README.md +++ b/README.md @@ -38,36 +38,6 @@ pip install -e .[tests] pytest tests/ ``` -## Structure -The laplace package consists of two main components: - -1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'` and `'diag'`). This results in _nine_ currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace` (which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py)), [`laplace.SubnetLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetlaplace.py) (which only supports `'full'` and `'diag'` Hessian approximations) and `laplace.LowRankLaplace` (which only supports inference over `'all'` weights). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function. -2. The backends in [`laplace.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/) which provide access to Hessian approximations of -the corresponding sparsity structures, for example, the diagonal GGN. - -Additionally, the package provides utilities for -decomposing a neural network into feature extractor and last layer for `LLLaplace` subclasses ([`laplace.utils.feature_extractor`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/feature_extractor.py)) -and -effectively dealing with Kronecker factors ([`laplace.utils.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/matrix.py)). - -Finally, the package implements several options to select/specify a subnetwork for `SubnetLaplace` (as subclasses of [`laplace.utils.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/subnetmask.py)). -Automatic subnetwork selection strategies include: uniformly at random (`laplace.utils.subnetmask.RandomSubnetMask`), by largest parameter magnitudes (`LargestMagnitudeSubnetMask`), and by largest marginal parameter variances (`LargestVarianceDiagLaplaceSubnetMask` and `LargestVarianceSWAGSubnetMask`). -In addition to that, subnetworks can also be specified manually, by listing the names of either the model parameters (`ParamNameSubnetMask`) or modules (`ModuleNameSubnetMask`) to perform Laplace inference over. - -## Extendability -To extend the laplace package, new `BaseLaplace` subclasses can be designed, for example, -Laplace with a block-diagonal Hessian structure. -One can also implement custom subnetwork selection strategies as new subclasses of `SubnetMask`. - -Alternatively, extending or integrating backends (subclasses of [`curvature.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvature.py)) allows to provide different Hessian -approximations to the Laplace approximations. 
-For example, currently the [`curvature.CurvlinopsInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvlinops.py) based on [Curvlinops](https://github.com/f-dangel/curvlinops) and the native `torch.func` (previously known as `functorch`), [`curvature.BackPackInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/backpack.py) based on [BackPACK](https://github.com/f-dangel/backpack/) and [`curvature.AsdlInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/asdl.py) based on [ASDL](https://github.com/kazukiosawa/asdfghjkl) are available. - -The `curvature.CurvlinopsInterface` backend is the default and provides all Hessian approximation variants except the low-rank Hessian. -For the latter, `curvature.AsdlInterface` can be used. -Note that `curvature.AsdlInterface` and `curvature.BackPackInterface` are less complete and less compatible than `curvature.CurvlinopsInterface`. -So, we recommend to stick with `curvature.CurvlinopsInterface` unless you have a specific need of ASDL or BackPACK. - ## Example usage ### *Post-hoc* prior precision tuning of diagonal LA @@ -94,6 +64,13 @@ la.optimize_prior_precision(method='gridsearch', val_loader=val_loader) # User-specified predictive approx. pred = la(x, link_approx='probit') + +# Serialization +torch.save(la.state_dict(), 'state_dict.bin') + +# Load serialized Laplace +la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full') +la2.load_state_dict(torch.load('state_dict.bin')) ``` ### Differentiating the log marginal likelihood w.r.t. hyperparameters @@ -157,6 +134,37 @@ la = Laplace(model, 'classification', la.fit(train_loader) ``` + +## Structure +The laplace package consists of two main components: + +1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'` and `'diag'`). This results in _nine_ currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace` (which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py)), [`laplace.SubnetLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetlaplace.py) (which only supports `'full'` and `'diag'` Hessian approximations) and `laplace.LowRankLaplace` (which only supports inference over `'all'` weights). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function. +2. The backends in [`laplace.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/) which provide access to Hessian approximations of +the corresponding sparsity structures, for example, the diagonal GGN. + +Additionally, the package provides utilities for +decomposing a neural network into feature extractor and last layer for `LLLaplace` subclasses ([`laplace.utils.feature_extractor`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/feature_extractor.py)) +and +effectively dealing with Kronecker factors ([`laplace.utils.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/matrix.py)). 
+
+Finally, the package implements several options to select/specify a subnetwork for `SubnetLaplace` (as subclasses of [`laplace.utils.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/subnetmask.py)).
+Automatic subnetwork selection strategies include: uniformly at random (`laplace.utils.subnetmask.RandomSubnetMask`), by largest parameter magnitudes (`LargestMagnitudeSubnetMask`), and by largest marginal parameter variances (`LargestVarianceDiagLaplaceSubnetMask` and `LargestVarianceSWAGSubnetMask`).
+In addition to that, subnetworks can also be specified manually, by listing the names of either the model parameters (`ParamNameSubnetMask`) or modules (`ModuleNameSubnetMask`) to perform Laplace inference over.
+
+## Extendability
+To extend the laplace package, new `BaseLaplace` subclasses can be designed, for example,
+Laplace with a block-diagonal Hessian structure.
+One can also implement custom subnetwork selection strategies as new subclasses of `SubnetMask`.
+
+Alternatively, extending or integrating backends (subclasses of [`curvature.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvature.py)) allows providing different Hessian
+approximations to the Laplace approximations.
+For example, currently the [`curvature.CurvlinopsInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvlinops.py) based on [Curvlinops](https://github.com/f-dangel/curvlinops) and the native `torch.func` (previously known as `functorch`), [`curvature.BackPackInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/backpack.py) based on [BackPACK](https://github.com/f-dangel/backpack/) and [`curvature.AsdlInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/asdl.py) based on [ASDL](https://github.com/kazukiosawa/asdfghjkl) are available.
+
+The `curvature.CurvlinopsInterface` backend is the default and provides all Hessian approximation variants except the low-rank Hessian.
+For the latter, `curvature.AsdlInterface` can be used.
+Note that `curvature.AsdlInterface` and `curvature.BackPackInterface` are less complete and less compatible than `curvature.CurvlinopsInterface`.
+So, we recommend sticking with `curvature.CurvlinopsInterface` unless you have a specific need for ASDL or BackPACK.
+
 ## Documentation
 
 The documentation is available [here](https://aleximmer.github.io/Laplace) or can be generated and/or viewed locally:

From 894f640b716d4386c5b68c555ce9d6dbb5b1d4fd Mon Sep 17 00:00:00 2001
From: Agustinus Kristiadi
Date: Thu, 14 Mar 2024 12:17:29 -0400
Subject: [PATCH 06/10] Provide explicit checks when loading serialized Laplace

---
 laplace/baselaplace.py      |  40 ++++++++++++--
 laplace/lllaplace.py        |  18 ++++---
 tests/test_serialization.py | 104 ++++++++++++++++++++++++++++++++++--
 3 files changed, 149 insertions(+), 13 deletions(-)

diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py
index d2e23908..dc1e6285 100644
--- a/laplace/baselaplace.py
+++ b/laplace/baselaplace.py
@@ -3,6 +3,7 @@
 import torch
 from torch.nn.utils import parameters_to_vector, vector_to_parameters
 from torch.distributions import MultivariateNormal
+import warnings
 
 from laplace.utils import (parameters_per_layer, invsqrt_precision, get_nll,
                            validate, Kron, normal_samples,
@@ -774,25 +775,57 @@ def state_dict(self) -> dict:
             'mean': self.mean,
             'H': self.H,
             'loss': self.loss,
+            'prior_mean': self.prior_mean,
             'prior_precision': self.prior_precision,
             'sigma_noise': self.sigma_noise,
             'n_data': self.n_data,
             'n_outputs': self.n_outputs,
+            'likelihood': self.likelihood,
+            'temperature': self.temperature,
+            'enable_backprop': self.enable_backprop,
+            'cls_name': self.__class__.__name__
         }
         return state_dict
 
     def load_state_dict(self, state_dict: dict):
-        self.mean = state_dict['mean']
-        if self.n_params is not None and len(self.mean) != self.n_params:
-            raise ValueError('Attempting to load Laplace with different number of parameters than the model.')
+        # Dealbreaker errors
+        if self.__class__.__name__ != state_dict['cls_name']:
+            raise ValueError(
+                'Loading the wrong Laplace type. Make sure `subset_of_weights` and'
+                + ' `hessian_structure` are correct!'
+            )
+        if self.n_params is not None and len(state_dict['mean']) != self.n_params:
+            raise ValueError(
+                'Attempting to load Laplace with different number of parameters than the model.'
+                + ' Make sure that you use the same `subset_of_weights` value and the same `.requires_grad`'
+                + ' switch on `model.parameters()`.'
+            )
+        if self.likelihood != state_dict['likelihood']:
+            raise ValueError('Different likelihoods detected!')
+
+        # Ignorable warnings
+        if self.prior_mean is None and state_dict['prior_mean'] is not None:
+            warnings.warn('Loading a non-`None` prior mean into a `None` prior mean. You might get wrong results.')
+        if self.temperature != state_dict['temperature']:
+            warnings.warn('Different `temperature` parameters detected. Some calculations might be off!')
+        if self.enable_backprop != state_dict['enable_backprop']:
+            warnings.warn(
+                'Different `enable_backprop` values. You might encounter errors when differentiating'
+                + ' the predictive mean and variance.' 
+ ) + self.mean = state_dict['mean'] self.H = state_dict['H'] self.loss = state_dict['loss'] + self.prior_mean = state_dict['prior_mean'] self.prior_precision = state_dict['prior_precision'] self.sigma_noise = state_dict['sigma_noise'] self.n_data = state_dict['n_data'] self.n_outputs = state_dict['n_outputs'] setattr(self.model, 'output_size', self.n_outputs) + self.likelihood = state_dict['likelihood'] + self.temperature = state_dict['temperature'] + self.enable_backprop = state_dict['enable_backprop'] class FullLaplace(ParametricLaplace): @@ -988,7 +1021,6 @@ def prior_precision(self, prior_precision): raise ValueError('Prior precision for Kron either scalar or per-layer.') def state_dict(self) -> dict: - self._check_H_init() state_dict = super().state_dict() state_dict['H'] = self.H_facs.kfacs return state_dict diff --git a/laplace/lllaplace.py b/laplace/lllaplace.py index d89ddd44..ead2b6af 100644 --- a/laplace/lllaplace.py +++ b/laplace/lllaplace.py @@ -84,6 +84,7 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1., self.mean = self.prior_mean self._init_H() self._backend_kwargs['last_layer'] = True + self._last_layer_name = last_layer_name def fit(self, train_loader, override=True): """Fit the local Laplace approximation at the parameters of the model. @@ -167,16 +168,21 @@ def prior_precision_diag(self): def state_dict(self) -> dict: state_dict = super().state_dict() - state_dict['X'] = self.X + state_dict['X'] = getattr(self, 'X', None) + state_dict['_last_layer_name'] = self._last_layer_name return state_dict def load_state_dict(self, state_dict: dict): + if self._last_layer_name != state_dict['_last_layer_name']: + raise ValueError('Different `last_layer_name` detected!') + self.X = state_dict['X'] - with torch.no_grad(): - try: - self.model.find_last_layer(self.X[:1].to(self._device)) - except (TypeError, AttributeError): - self.model.find_last_layer(self.X.to(self._device)) + if self.X is not None: + with torch.no_grad(): + try: + self.model.find_last_layer(self.X[:1].to(self._device)) + except (TypeError, AttributeError): + self.model.find_last_layer(self.X.to(self._device)) super().load_state_dict(state_dict) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 328638d9..405cc979 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -9,7 +9,9 @@ from torch.nn.utils import parameters_to_vector from torch.utils.data import DataLoader, TensorDataset import os +from collections import OrderedDict +from laplace import Laplace from laplace.laplace import FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace, DiagSubnetLaplace, FullSubnetLaplace from laplace.utils import KronDecomposed from tests.utils import jacobians_naive @@ -18,7 +20,9 @@ torch.manual_seed(240) torch.set_default_tensor_type(torch.DoubleTensor) flavors = [FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace] -subnet_flavors = [DiagSubnetLaplace, FullSubnetLaplace] +flavors_no_llla = [FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace] +flavors_llla = [FullLLLaplace, KronLLLaplace, DiagLLLaplace] +flavors_subnet = [DiagSubnetLaplace, FullSubnetLaplace] @pytest.fixture @@ -31,6 +35,31 @@ def model(): return model +@pytest.fixture +def model2(): + model = torch.nn.Sequential(nn.Linear(3, 25), nn.Linear(25, 2)) + setattr(model, 'output_size', 2) + model_params = list(model.parameters()) + setattr(model, 'n_layers', len(model_params)) # number of 
parameter groups + setattr(model, 'n_params', len(parameters_to_vector(model_params))) + return model + + +@pytest.fixture +def model3(): + model = torch.nn.Sequential( + OrderedDict([ + ('fc1', nn.Linear(3, 20)), + ('clf', nn.Linear(20, 2)) + ]) + ) + setattr(model, 'output_size', 2) + model_params = list(model.parameters()) + setattr(model, 'n_layers', len(model_params)) # number of parameter groups + setattr(model, 'n_params', len(parameters_to_vector(model_params))) + return model + + @pytest.fixture def reg_loader(): X = torch.randn(10, 3) @@ -77,10 +106,11 @@ def test_serialize_no_pickle(laplace, model, reg_loader): # Make sure no pickle object for val in state_dict.values(): - assert isinstance(val, (list, tuple, int, float, torch.Tensor)) + if val is not None: + assert isinstance(val, (list, tuple, int, float, str, bool, torch.Tensor)) -@pytest.mark.parametrize('laplace', subnet_flavors) +@pytest.mark.parametrize('laplace', flavors_subnet) def test_serialize_subnetlaplace(laplace, model, reg_loader): subnetwork_indices = torch.LongTensor([1, 10, 104, 44]) la = laplace(model, 'regression', subnetwork_indices=subnetwork_indices) @@ -100,3 +130,71 @@ def test_serialize_subnetlaplace(laplace, model, reg_loader): assert torch.allclose(f_mean, f_mean2) assert torch.allclose(f_var, f_var2) + +@pytest.mark.parametrize('laplace', flavors_no_llla) +def test_serialize_fail_different_models(laplace, model, model2, reg_loader): + la = laplace(model, 'regression') + la.fit(reg_loader) + la.optimize_prior_precision() + la.sigma_noise = 1231 + torch.save(la.state_dict(), 'state_dict.bin') + + la2 = laplace(model2, 'regression') + + with pytest.raises(ValueError): + la2.load_state_dict(torch.load('state_dict.bin')) + + +def test_serialize_fail_different_hess_structures(model, reg_loader): + la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='kron') + la.fit(reg_loader) + la.optimize_prior_precision() + la.sigma_noise = 1231 + torch.save(la.state_dict(), 'state_dict.bin') + + la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='diag') + + with pytest.raises(ValueError): + la2.load_state_dict(torch.load('state_dict.bin')) + + +def test_serialize_fail_different_subset_of_weights(model, reg_loader): + la = Laplace(model, 'regression', subset_of_weights='last_layer', hessian_structure='diag') + la.fit(reg_loader) + la.optimize_prior_precision() + la.sigma_noise = 1231 + torch.save(la.state_dict(), 'state_dict.bin') + + la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='diag') + + with pytest.raises(ValueError): + la2.load_state_dict(torch.load('state_dict.bin')) + + +@pytest.mark.parametrize('laplace', flavors) +def test_serialize_fail_different_liks(laplace, model, reg_loader): + la = laplace(model, 'regression') + la.fit(reg_loader) + la.optimize_prior_precision() + la.sigma_noise = 1231 + torch.save(la.state_dict(), 'state_dict.bin') + + la2 = laplace(model, 'classification') + + with pytest.raises(ValueError): + la2.load_state_dict(torch.load('state_dict.bin')) + + +@pytest.mark.parametrize('laplace', flavors_llla) +def test_serialize_fail_llla_different_last_layer_name(laplace, model, model3, reg_loader): + print([n for n, _ in model.named_parameters()]) + la = laplace(model, 'regression', last_layer_name='1') + la.fit(reg_loader) + la.optimize_prior_precision() + la.sigma_noise = 1231 + torch.save(la.state_dict(), 'state_dict.bin') + + la2 = laplace(model3, 'classification', last_layer_name='clf') + + with 
pytest.raises(ValueError): + la2.load_state_dict(torch.load('state_dict.bin')) From 30e28f2d94590f7b8d265f68cf98489c3a7af5dd Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Thu, 14 Mar 2024 12:24:14 -0400 Subject: [PATCH 07/10] Fix readme --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 90ea7d8b..80a82d1d 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,9 @@ pred = la(x, link_approx='probit') torch.save(la.state_dict(), 'state_dict.bin') # Load serialized Laplace -la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full') +la2 = Laplace(model, 'classification', + subset_of_weights='all', + hessian_structure='diag') la2.load_state_dict(torch.load('state_dict.bin')) ``` From 49b19d346f532a06449d027391e58c6e995edd19 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Thu, 14 Mar 2024 15:57:42 -0400 Subject: [PATCH 08/10] Add test serialization override --- tests/test_serialization.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 405cc979..0bbb77cd 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -95,6 +95,25 @@ def test_serialize(laplace, model, reg_loader): assert torch.allclose(f_var, f_var2) +@pytest.mark.parametrize('laplace', set(flavors_no_llla) - {LowRankLaplace, KronLaplace, FullLaplace}) +def test_serialize_override(laplace, model, reg_loader): + la = laplace(model, 'regression') + la.fit(reg_loader) + la.optimize_prior_precision() + la.sigma_noise = 1231 + H_orig = la.H_kfacs.to_matrix() if laplace == KronLaplace else la.H + torch.save(la.state_dict(), 'state_dict.bin') + + la2 = laplace(model, 'regression') + la2.load_state_dict(torch.load('state_dict.bin')) + + # Emulating continual learning + la2.fit(reg_loader, override=False) + + H_new = la2.H_kfacs.to_matrix() if laplace == KronLaplace else la2.H + assert not torch.allclose(H_orig, H_new) + + @pytest.mark.parametrize('laplace', flavors) def test_serialize_no_pickle(laplace, model, reg_loader): la = laplace(model, 'regression') From e3e42e2cd2cd47a71fe08d255960f4ff56fc8d48 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Thu, 14 Mar 2024 15:58:34 -0400 Subject: [PATCH 09/10] Fix typos --- tests/test_serialization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 0bbb77cd..f6951275 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -95,13 +95,13 @@ def test_serialize(laplace, model, reg_loader): assert torch.allclose(f_var, f_var2) -@pytest.mark.parametrize('laplace', set(flavors_no_llla) - {LowRankLaplace, KronLaplace, FullLaplace}) +@pytest.mark.parametrize('laplace', set(flavors_no_llla) - {LowRankLaplace}) def test_serialize_override(laplace, model, reg_loader): la = laplace(model, 'regression') la.fit(reg_loader) la.optimize_prior_precision() la.sigma_noise = 1231 - H_orig = la.H_kfacs.to_matrix() if laplace == KronLaplace else la.H + H_orig = la.H_facs.to_matrix() if laplace == KronLaplace else la.H torch.save(la.state_dict(), 'state_dict.bin') la2 = laplace(model, 'regression') @@ -110,7 +110,7 @@ def test_serialize_override(laplace, model, reg_loader): # Emulating continual learning la2.fit(reg_loader, override=False) - H_new = la2.H_kfacs.to_matrix() if laplace == KronLaplace else la2.H + H_new = la2.H_facs.to_matrix() if laplace == KronLaplace else la2.H assert not 
torch.allclose(H_orig, H_new) From cad5d248acb93369b700530859860dd7432a4bfd Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Thu, 14 Mar 2024 20:35:26 -0400 Subject: [PATCH 10/10] More definitive assert on --- tests/test_serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index f6951275..b3fe4796 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -111,7 +111,7 @@ def test_serialize_override(laplace, model, reg_loader): la2.fit(reg_loader, override=False) H_new = la2.H_facs.to_matrix() if laplace == KronLaplace else la2.H - assert not torch.allclose(H_orig, H_new) + assert torch.allclose(2 * H_orig, H_new) @pytest.mark.parametrize('laplace', flavors)
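
Taken together, the series gives the `ParametricLaplace` classes (and their last-layer and subnetwork variants) a `state_dict()`/`load_state_dict()` round trip. Below is a minimal end-to-end sketch of the intended workflow, assuming a trained `model` and a `train_loader` as in `examples/regression_example.py`; the path `state_dict.bin` is just a placeholder:

```python
import torch
from laplace import Laplace

# Fit and tune a Laplace approximation, then persist only its fitted
# quantities (mean, Hessian factors, loss, hyperparameters). The model
# weights themselves are saved separately, as usual for PyTorch modules.
la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='kron')
la.fit(train_loader)
la.optimize_prior_precision()
torch.save(la.state_dict(), 'state_dict.bin')

# Later, e.g. in a fresh process: rebuild the model, instantiate a Laplace
# object with the same `subset_of_weights`, `hessian_structure`, and
# likelihood, and load. Mismatches raise ValueError (the `cls_name` and
# `likelihood` checks from PATCH 06/10), and a subsequent
# `la2.fit(train_loader, override=False)` accumulates curvature on top of
# the loaded state (tested in PATCH 08/10).
la2 = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='kron')
la2.load_state_dict(torch.load('state_dict.bin'))
```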