Add native serialization support #148

Merged · 11 commits · Mar 15, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -132,3 +132,5 @@ dmypy.json
data/

.DS_Store

state_dict.bin
70 changes: 40 additions & 30 deletions README.md
@@ -38,36 +38,6 @@ pip install -e .[tests]
pytest tests/
```

## Example usage

### *Post-hoc* prior precision tuning of diagonal LA
@@ -94,6 +64,15 @@ la.optimize_prior_precision(method='gridsearch', val_loader=val_loader)

# User-specified predictive approx.
pred = la(x, link_approx='probit')

# Serialization
torch.save(la.state_dict(), 'state_dict.bin')

# Load serialized Laplace
la2 = Laplace(model, 'classification',
              subset_of_weights='all',
              hessian_structure='diag')
la2.load_state_dict(torch.load('state_dict.bin'))
```
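
Note that `state_dict()` serializes only the fitted Laplace quantities (the MAP estimate `mean`, the Hessian approximation `H`, the `loss`, and the prior and likelihood hyperparameters; see `laplace/baselaplace.py` below). The wrapped `model` itself is not saved, so reconstruct the Laplace object around the same trained model before calling `load_state_dict`, as in the snippet above.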

### Differentiating the log marginal likelihood w.r.t. hyperparameters
@@ -157,6 +136,37 @@ la = Laplace(model, 'classification',
la.fit(train_loader)
```


## Structure
The laplace package consists of two main components:

1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'` and `'diag'`). This results in _nine_ currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace` (which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py)), [`laplace.SubnetLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetlaplace.py) (which only supports `'full'` and `'diag'` Hessian approximations) and `laplace.LowRankLaplace` (which only supports inference over `'all'` weights). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function.
2. The backends in [`laplace.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/) which provide access to Hessian approximations of
the corresponding sparsity structures, for example, the diagonal GGN.

Additionally, the package provides utilities for
decomposing a neural network into feature extractor and last layer for `LLLaplace` subclasses ([`laplace.utils.feature_extractor`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/feature_extractor.py))
and
effectively dealing with Kronecker factors ([`laplace.utils.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/matrix.py)).
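
For instance, here is a minimal sketch of the feature-extractor utility; the `forward_with_features` method and the automatic last-layer detection are assumptions based on `laplace.utils.feature_extractor`, not verbatim from this PR:

```python
import torch
from laplace.utils.feature_extractor import FeatureExtractor

# Wrap an arbitrary nn.Module; if `last_layer_name` is not given, the last
# linear layer is detected on the first forward pass.
feature_extractor = FeatureExtractor(model)
x = torch.randn(8, 3, 32, 32)  # hypothetical input batch
out, features = feature_extractor.forward_with_features(x)
```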

Finally, the package implements several options to select/specify a subnetwork for `SubnetLaplace` (as subclasses of [`laplace.utils.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/subnetmask.py)).
Automatic subnetwork selection strategies include selecting weights uniformly at random (`laplace.utils.subnetmask.RandomSubnetMask`), by largest parameter magnitudes (`LargestMagnitudeSubnetMask`), or by largest marginal parameter variances (`LargestVarianceDiagLaplaceSubnetMask` and `LargestVarianceSWAGSubnetMask`).
Subnetworks can also be specified manually by listing the names of either the model parameters (`ParamNameSubnetMask`) or modules (`ModuleNameSubnetMask`) over which to perform Laplace inference.
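
A minimal sketch of both routes, assuming the mask constructors take `n_params_subnet` and a list of module names, respectively, and that `select()` returns the subnetwork indices expected by `Laplace`:

```python
from laplace import Laplace
from laplace.utils import LargestMagnitudeSubnetMask, ModuleNameSubnetMask

# Automatic: keep the 128 weights with the largest magnitudes.
subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=128)
subnetwork_indices = subnetwork_mask.select()

# Manual alternative: all parameters of the named modules.
# subnetwork_mask = ModuleNameSubnetMask(model, module_names=['layer.1'])
# subnetwork_indices = subnetwork_mask.select()

# Subnetwork Laplace supports 'full' and 'diag' Hessian approximations.
la = Laplace(model, 'classification',
             subset_of_weights='subnetwork',
             hessian_structure='full',
             subnetwork_indices=subnetwork_indices)
la.fit(train_loader)
```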

## Extendability
To extend the laplace package, new `BaseLaplace` subclasses can be designed, for example,
a Laplace approximation with a block-diagonal Hessian structure.
One can also implement custom subnetwork selection strategies as new subclasses of `SubnetMask`.
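
As an illustration, a custom strategy might look like the following sketch; the `ScoreBasedSubnetMask` base class, its `compute_param_scores` hook, and the `parameter_vector` attribute are assumptions based on how the built-in magnitude and variance masks are structured:

```python
from laplace.utils.subnetmask import ScoreBasedSubnetMask

class SmallestMagnitudeSubnetMask(ScoreBasedSubnetMask):
    # Hypothetical strategy: select the weights with the *smallest* magnitudes.
    def compute_param_scores(self, train_loader):
        # Higher scores are selected first, so negate the magnitudes.
        return -self.parameter_vector.abs()
```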

Alternatively, extending or integrating backends (subclasses of [`curvature.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvature.py)) makes it possible to supply different Hessian
approximations to the Laplace approximations.
Currently available backends are [`curvature.CurvlinopsInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvlinops.py), based on [Curvlinops](https://github.com/f-dangel/curvlinops) and the native `torch.func` (previously known as `functorch`); [`curvature.BackPackInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/backpack.py), based on [BackPACK](https://github.com/f-dangel/backpack/); and [`curvature.AsdlInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/asdl.py), based on [ASDL](https://github.com/kazukiosawa/asdfghjkl).

The `curvature.CurvlinopsInterface` backend is the default and provides all Hessian approximation variants except the low-rank Hessian; for the latter, use `curvature.AsdlInterface`.
Note that `curvature.AsdlInterface` and `curvature.BackPackInterface` are less complete and less compatible than `curvature.CurvlinopsInterface`, so we recommend sticking with `curvature.CurvlinopsInterface` unless you specifically need ASDL or BackPACK.
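
For example, a low-rank Laplace approximation with an explicitly chosen ASDL backend might look like this sketch (`AsdlHessian` as the appropriate backend class is an assumption, not verbatim from this PR):

```python
from laplace import Laplace
from laplace.curvature import AsdlHessian

la = Laplace(model, 'classification',
             subset_of_weights='all',
             hessian_structure='lowrank',
             backend=AsdlHessian)
la.fit(train_loader)
```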

## Documentation

The documentation is available [here](https://aleximmer.github.io/Laplace) or can be generated and/or viewed locally:
20 changes: 14 additions & 6 deletions examples/regression_example.py
@@ -43,6 +43,14 @@ def get_model():
    neg_marglik.backward()
    hyper_optimizer.step()

# Serialization for fitted quantities
state_dict = la.state_dict()
torch.save(state_dict, 'state_dict.bin')

la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
# Load serialized, fitted quantities
la.load_state_dict(torch.load('state_dict.bin'))

print(f'sigma={la.sigma_noise.item():.2f}',
      f'prior precision={la.prior_precision.item():.2f}')

@@ -51,11 +59,11 @@ def get_model():
# Two options:
# 1.) Marginal predictive distribution N(f_map(x_i), var(x_i))
# The mean is (m,k), the var is (m,k,k)
f_mu, f_var = la(X_test)

# 2.) Joint pred. dist. N((f_map(x_1),...,f_map(x_m)), Cov(f(x_1),...,f(x_m)))
# The mean is (m*k,) where k is the output dim. The cov is (m*k,m*k)
f_mu_joint, f_cov = la(X_test, joint=True)

# Both should be true
assert torch.allclose(f_mu.flatten(), f_mu_joint)
@@ -65,14 +73,14 @@ def get_model():
f_sigma = f_var.squeeze().detach().sqrt().cpu().numpy()
pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item()**2)

plot_regression(X_train, y_train, x, f_mu, pred_std,
                file_name='regression_example', plot=True)

# alternatively, optimize parameters and hyperparameters of the prior jointly
model = get_model()
la, model, margliks, losses = marglik_training(
    model=model, train_loader=train_loader, likelihood='regression',
    hessian_structure='full', backend=BackPackGGN, n_epochs=n_epochs,
    optimizer_kwargs={'lr': 1e-2}, prior_structure='scalar'
)

@@ -83,5 +91,5 @@ def get_model():
f_mu = f_mu.squeeze().detach().cpu().numpy()
f_sigma = f_var.squeeze().sqrt().cpu().numpy()
pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item()**2)
plot_regression(X_train, y_train, x, f_mu, pred_std,
                file_name='regression_example_online', plot=False)
71 changes: 71 additions & 0 deletions laplace/baselaplace.py
@@ -3,6 +3,7 @@
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.distributions import MultivariateNormal
import warnings

from laplace.utils import (parameters_per_layer, invsqrt_precision,
                           get_nll, validate, Kron, normal_samples,
@@ -768,6 +769,64 @@ def posterior_precision(self):
"""
raise NotImplementedError

    def state_dict(self) -> dict:
        self._check_H_init()
        state_dict = {
            'mean': self.mean,
            'H': self.H,
            'loss': self.loss,
            'prior_mean': self.prior_mean,
            'prior_precision': self.prior_precision,
            'sigma_noise': self.sigma_noise,
            'n_data': self.n_data,
            'n_outputs': self.n_outputs,
            'likelihood': self.likelihood,
            'temperature': self.temperature,
            'enable_backprop': self.enable_backprop,
            'cls_name': self.__class__.__name__
        }
        return state_dict

    def load_state_dict(self, state_dict: dict):
        # Dealbreaker errors
        if self.__class__.__name__ != state_dict['cls_name']:
            raise ValueError(
                'Loading a wrong Laplace type. Make sure `subset_of_weights` and'
                + ' `hessian_structure` are correct!'
            )
        if self.n_params is not None and len(state_dict['mean']) != self.n_params:
            raise ValueError(
                'Attempting to load Laplace with a different number of parameters than the model.'
                + ' Make sure that you use the same `subset_of_weights` value and the same `.requires_grad`'
                + ' switch on `model.parameters()`.'
            )
        if self.likelihood != state_dict['likelihood']:
            raise ValueError('Different likelihoods detected!')

        # Ignorable warnings
        if self.prior_mean is None and state_dict['prior_mean'] is not None:
            warnings.warn('Loading a non-`None` prior mean into a `None` prior mean. You might get wrong results.')
        if self.temperature != state_dict['temperature']:
            warnings.warn('Different `temperature` parameters detected. Some calculations might be off!')
        if self.enable_backprop != state_dict['enable_backprop']:
            warnings.warn(
                'Different `enable_backprop` values. You might encounter errors when differentiating'
                + ' the predictive mean and variance.'
            )

        self.mean = state_dict['mean']
        self.H = state_dict['H']
        self.loss = state_dict['loss']
        self.prior_mean = state_dict['prior_mean']
        self.prior_precision = state_dict['prior_precision']
        self.sigma_noise = state_dict['sigma_noise']
        self.n_data = state_dict['n_data']
        self.n_outputs = state_dict['n_outputs']
        setattr(self.model, 'output_size', self.n_outputs)
        self.likelihood = state_dict['likelihood']
        self.temperature = state_dict['temperature']
        self.enable_backprop = state_dict['enable_backprop']


class FullLaplace(ParametricLaplace):
    """Laplace approximation with full, i.e., dense, log likelihood Hessian approximation
Expand Down Expand Up @@ -961,6 +1020,18 @@ def prior_precision(self, prior_precision):
        if len(self.prior_precision) not in [1, self.n_layers]:
            raise ValueError('Prior precision for Kron either scalar or per-layer.')

    def state_dict(self) -> dict:
        state_dict = super().state_dict()
        # Store the raw Kronecker factors rather than their eigendecomposition.
        state_dict['H'] = self.H_facs.kfacs
        return state_dict

    def load_state_dict(self, state_dict: dict):
        super().load_state_dict(state_dict)
        self._init_H()
        self.H_facs = self.H
        self.H_facs.kfacs = state_dict['H']
        # Re-decompose the loaded Kronecker factors for downstream computations.
        self.H = self.H_facs.decompose(damping=self.damping)


class LowRankLaplace(ParametricLaplace):
    """Laplace approximation with low-rank log likelihood Hessian (approximation).
32 changes: 29 additions & 3 deletions laplace/lllaplace.py
@@ -84,6 +84,7 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
        self.mean = self.prior_mean
        self._init_H()
        self._backend_kwargs['last_layer'] = True
        self._last_layer_name = last_layer_name

    def fit(self, train_loader, override=True):
        """Fit the local Laplace approximation at the parameters of the model.
@@ -103,12 +104,13 @@ def fit(self, train_loader, override=True):
        self.model.eval()

        if self.model.last_layer is None:
            # Save an example batch for when loading the serialized Laplace
            self.X, _ = next(iter(train_loader))
            with torch.no_grad():
                try:
                    self.model.find_last_layer(self.X[:1].to(self._device))
                except (TypeError, AttributeError):
                    self.model.find_last_layer(self.X.to(self._device))
            params = parameters_to_vector(self.model.last_layer.parameters()).detach()
            self.n_params = len(params)
            self.n_layers = len(list(self.model.last_layer.parameters()))
@@ -164,6 +166,30 @@ def prior_precision_diag(self):
        else:
            raise ValueError('Mismatch of prior and model. Diagonal or scalar prior.')

    def state_dict(self) -> dict:
        state_dict = super().state_dict()
        state_dict['X'] = getattr(self, 'X', None)
        state_dict['_last_layer_name'] = self._last_layer_name
        return state_dict

    def load_state_dict(self, state_dict: dict):
        if self._last_layer_name != state_dict['_last_layer_name']:
            raise ValueError('Different `last_layer_name` detected!')

        # Rediscover the last layer from the saved example batch before
        # loading the remaining quantities.
        self.X = state_dict['X']
        if self.X is not None:
            with torch.no_grad():
                try:
                    self.model.find_last_layer(self.X[:1].to(self._device))
                except (TypeError, AttributeError):
                    self.model.find_last_layer(self.X.to(self._device))

        super().load_state_dict(state_dict)

        params = parameters_to_vector(self.model.last_layer.parameters()).detach()
        self.n_params = len(params)
        self.n_layers = len(list(self.model.last_layer.parameters()))


class FullLLLaplace(LLLaplace, FullLaplace):
    """Last-layer Laplace approximation with full, i.e., dense, log likelihood Hessian approximation