Merge pull request #148 from aleximmer/serialization
Add native serialization support
runame authored Mar 15, 2024
2 parents 7541c9b + cad5d24 commit 2543274
Showing 6 changed files with 375 additions and 39 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -132,3 +132,5 @@ dmypy.json
data/

.DS_Store

state_dict.bin
70 changes: 40 additions & 30 deletions README.md
@@ -38,36 +38,6 @@ pip install -e .[tests]
pytest tests/
```

## Example usage

### *Post-hoc* prior precision tuning of diagonal LA
@@ -94,6 +64,15 @@ la.optimize_prior_precision(method='gridsearch', val_loader=val_loader)

# User-specified predictive approx.
pred = la(x, link_approx='probit')

# Serialization
torch.save(la.state_dict(), 'state_dict.bin')

# Load serialized Laplace
la2 = Laplace(model, 'classification',
              subset_of_weights='all',
              hessian_structure='diag')
la2.load_state_dict(torch.load('state_dict.bin'))
```
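If the state dict was saved on a different device (e.g., a GPU machine), the usual `torch.load` device mapping applies when deserializing. A minimal sketch, assuming the file from above should be loaded on a CPU-only machine:

```python
# Hypothetical: the state dict was saved on a GPU machine.
la2.load_state_dict(torch.load('state_dict.bin', map_location='cpu'))
```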

### Differentiating the log marginal likelihood w.r.t. hyperparameters
@@ -157,6 +136,37 @@ la = Laplace(model, 'classification',
la.fit(train_loader)
```
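The body of this example is collapsed in the diff. The pattern it follows, mirrored by `examples/regression_example.py` below, is gradient descent on the negative log marginal likelihood. A minimal sketch, assuming a fitted regression LA `la` and an epoch count `n_epochs` (for classification, only the prior precision is optimized):

```python
log_prior, log_sigma = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
for i in range(n_epochs):
    hyper_optimizer.zero_grad()
    # Differentiate the neg. log marginal likelihood w.r.t. the
    # log prior precision and the log observation noise.
    neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    hyper_optimizer.step()
```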


## Structure
The `laplace` package consists of two main components:

1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'`, and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'`, and `'diag'`). This results in _nine_ currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, and `laplace.DiagLaplace`; the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace`, which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py); [`laplace.SubnetLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetlaplace.py), which only supports the `'full'` and `'diag'` Hessian approximations; and `laplace.LowRankLaplace`, which only supports inference over `'all'` weights. All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function, as shown in the sketch after this list.
2. The backends in [`laplace.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/), which provide access to Hessian approximations for the corresponding sparsity structures, for example, the diagonal GGN.
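For instance, the following minimal sketch (assuming a trained classifier `model`) returns a `laplace.KronLLLaplace` instance:

```python
from laplace import Laplace

# 'last_layer' weights + 'kron' Hessian structure map to KronLLLaplace.
la = Laplace(model, 'classification',
             subset_of_weights='last_layer',
             hessian_structure='kron')
```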

Additionally, the package provides utilities for
decomposing a neural network into a feature extractor and a last layer for the `LLLaplace` subclasses ([`laplace.utils.feature_extractor`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/feature_extractor.py))
and for efficiently dealing with Kronecker factors ([`laplace.utils.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/matrix.py)).

Finally, the package implements several options to select or specify a subnetwork for `SubnetLaplace` (as subclasses of [`laplace.utils.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/subnetmask.py)).
Automatic subnetwork selection strategies include: uniformly at random (`laplace.utils.subnetmask.RandomSubnetMask`), by largest parameter magnitudes (`LargestMagnitudeSubnetMask`), and by largest marginal parameter variances (`LargestVarianceDiagLaplaceSubnetMask` and `LargestVarianceSWAGSubnetMask`).
Subnetworks can also be specified manually by listing the names of either the model parameters (`ParamNameSubnetMask`) or modules (`ModuleNameSubnetMask`) to perform Laplace inference over, as in the sketch below.
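A minimal sketch of the manual route, assuming a `model` whose modules include `'layer.1'` and `'layer.3'` and a `train_loader` (see `laplace.utils.subnetmask` for the exact selection API):

```python
from laplace import Laplace
from laplace.utils import ModuleNameSubnetMask

# Perform Laplace inference only over the parameters of two named modules.
subnetwork_mask = ModuleNameSubnetMask(model, module_names=['layer.1', 'layer.3'])
subnetwork_mask.select()
subnetwork_indices = subnetwork_mask.indices

la = Laplace(model, 'classification',
             subset_of_weights='subnetwork',
             hessian_structure='full',
             subnetwork_indices=subnetwork_indices)
la.fit(train_loader)
```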

## Extendability
To extend the `laplace` package, new `BaseLaplace` subclasses can be designed, for example,
a Laplace approximation with a block-diagonal Hessian structure.
One can also implement custom subnetwork selection strategies as new subclasses of `SubnetMask`, as in the sketch below.
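As an illustration, a minimal sketch of a hypothetical strategy, assuming the `ScoreBasedSubnetMask` base class in `laplace.utils.subnetmask` with its `compute_param_scores` hook and `parameter_vector` attribute (check the source for the exact contract):

```python
from laplace.utils.subnetmask import ScoreBasedSubnetMask

class SmallestMagnitudeSubnetMask(ScoreBasedSubnetMask):
    """Hypothetical strategy: keep the n_params_subnet smallest-magnitude weights."""
    def compute_param_scores(self, train_loader):
        # Negate the magnitudes so the top-scoring parameters are the smallest ones.
        return -self.parameter_vector.abs()
```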

Alternatively, extending or integrating backends (subclasses of [`curvature.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvature.py)) makes it possible to provide different Hessian
approximations to the Laplace approximations.
For example, the following backends are currently available: [`curvature.CurvlinopsInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvlinops.py), based on [Curvlinops](https://github.com/f-dangel/curvlinops) and the native `torch.func` (previously known as `functorch`); [`curvature.BackPackInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/backpack.py), based on [BackPACK](https://github.com/f-dangel/backpack/); and [`curvature.AsdlInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/asdl.py), based on [ASDL](https://github.com/kazukiosawa/asdfghjkl).

The `curvature.CurvlinopsInterface` backend is the default and provides all Hessian approximation variants except the low-rank Hessian.
For the latter, `curvature.AsdlInterface` can be used.
Note that `curvature.AsdlInterface` and `curvature.BackPackInterface` are less complete and less compatible than `curvature.CurvlinopsInterface`,
so we recommend sticking with `curvature.CurvlinopsInterface` unless you specifically need ASDL or BackPACK.

## Documentation

The documentation is available [here](https://aleximmer.github.io/Laplace) or can be generated and/or viewed locally:
20 changes: 14 additions & 6 deletions examples/regression_example.py
@@ -43,6 +43,14 @@ def get_model():
    neg_marglik.backward()
    hyper_optimizer.step()

# Serialization for fitted quantities
state_dict = la.state_dict()
torch.save(state_dict, 'state_dict.bin')

la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
# Load serialized, fitted quantities
la.load_state_dict(torch.load('state_dict.bin'))

print(f'sigma={la.sigma_noise.item():.2f}',
      f'prior precision={la.prior_precision.item():.2f}')

@@ -51,11 +59,11 @@ def get_model():
# Two options:
# 1.) Marginal predictive distribution N(f_map(x_i), var(x_i))
# The mean is (m,k), the var is (m,k,k)
f_mu, f_var = la(X_test)

# 2.) Joint pred. dist. N((f_map(x_1),...,f_map(x_m)), Cov(f(x_1),...,f(x_m)))
# The mean is (m*k,) where k is the output dim. The cov is (m*k,m*k)
f_mu_joint, f_cov = la(X_test, joint=True)

# Both should be true
assert torch.allclose(f_mu.flatten(), f_mu_joint)
@@ -65,14 +73,14 @@ def get_model():
f_sigma = f_var.squeeze().detach().sqrt().cpu().numpy()
pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item()**2)

plot_regression(X_train, y_train, x, f_mu, pred_std,
                file_name='regression_example', plot=True)

# alternatively, optimize parameters and hyperparameters of the prior jointly
model = get_model()
la, model, margliks, losses = marglik_training(
    model=model, train_loader=train_loader, likelihood='regression',
    hessian_structure='full', backend=BackPackGGN, n_epochs=n_epochs,
    optimizer_kwargs={'lr': 1e-2}, prior_structure='scalar'
)

@@ -83,5 +91,5 @@ def get_model():
f_mu = f_mu.squeeze().detach().cpu().numpy()
f_sigma = f_var.squeeze().sqrt().cpu().numpy()
pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item()**2)
plot_regression(X_train, y_train, x, f_mu, pred_std,
                file_name='regression_example_online', plot=False)
71 changes: 71 additions & 0 deletions laplace/baselaplace.py
@@ -3,6 +3,7 @@
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.distributions import MultivariateNormal
import warnings

from laplace.utils import (parameters_per_layer, invsqrt_precision,
                           get_nll, validate, Kron, normal_samples,
@@ -768,6 +769,64 @@ def posterior_precision(self):
"""
raise NotImplementedError

    def state_dict(self) -> dict:
        # Collect everything needed to restore a fitted Laplace approximation.
        self._check_H_init()
        state_dict = {
            'mean': self.mean,
            'H': self.H,
            'loss': self.loss,
            'prior_mean': self.prior_mean,
            'prior_precision': self.prior_precision,
            'sigma_noise': self.sigma_noise,
            'n_data': self.n_data,
            'n_outputs': self.n_outputs,
            'likelihood': self.likelihood,
            'temperature': self.temperature,
            'enable_backprop': self.enable_backprop,
            'cls_name': self.__class__.__name__
        }
        return state_dict

    def load_state_dict(self, state_dict: dict):
        # Dealbreaker errors
        if self.__class__.__name__ != state_dict['cls_name']:
            raise ValueError(
                'Loading a wrong Laplace type. Make sure `subset_of_weights` and'
                + ' `hessian_structure` are correct!'
            )
        if self.n_params is not None and len(state_dict['mean']) != self.n_params:
            raise ValueError(
                'Attempting to load Laplace with different number of parameters than the model.'
                + ' Make sure that you use the same `subset_of_weights` value and the same `.requires_grad`'
                + ' switch on `model.parameters()`.'
            )
        if self.likelihood != state_dict['likelihood']:
            raise ValueError('Different likelihoods detected!')

        # Ignorable warnings
        if self.prior_mean is None and state_dict['prior_mean'] is not None:
            warnings.warn('Loading non-`None` prior mean into a `None` prior mean. You might get wrong results.')
        if self.temperature != state_dict['temperature']:
            warnings.warn('Different `temperature` parameters detected. Some calculation might be off!')
        if self.enable_backprop != state_dict['enable_backprop']:
            warnings.warn(
                'Different `enable_backprop` values. You might encounter errors when differentiating'
                + ' the predictive mean and variance.'
            )

        self.mean = state_dict['mean']
        self.H = state_dict['H']
        self.loss = state_dict['loss']
        self.prior_mean = state_dict['prior_mean']
        self.prior_precision = state_dict['prior_precision']
        self.sigma_noise = state_dict['sigma_noise']
        self.n_data = state_dict['n_data']
        self.n_outputs = state_dict['n_outputs']
        setattr(self.model, 'output_size', self.n_outputs)
        self.likelihood = state_dict['likelihood']
        self.temperature = state_dict['temperature']
        self.enable_backprop = state_dict['enable_backprop']


class FullLaplace(ParametricLaplace):
"""Laplace approximation with full, i.e., dense, log likelihood Hessian approximation
@@ -961,6 +1020,18 @@ def prior_precision(self, prior_precision):
        if len(self.prior_precision) not in [1, self.n_layers]:
            raise ValueError('Prior precision for Kron either scalar or per-layer.')

    def state_dict(self) -> dict:
        state_dict = super().state_dict()
        # Serialize the raw Kronecker factors rather than the decomposed H.
        state_dict['H'] = self.H_facs.kfacs
        return state_dict

    def load_state_dict(self, state_dict: dict):
        super().load_state_dict(state_dict)
        self._init_H()
        self.H_facs = self.H
        self.H_facs.kfacs = state_dict['H']
        # Re-run the eigendecomposition on the loaded Kronecker factors.
        self.H = self.H_facs.decompose(damping=self.damping)


class LowRankLaplace(ParametricLaplace):
"""Laplace approximation with low-rank log likelihood Hessian (approximation).
32 changes: 29 additions & 3 deletions laplace/lllaplace.py
@@ -84,6 +84,7 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
        self.mean = self.prior_mean
        self._init_H()
        self._backend_kwargs['last_layer'] = True
        self._last_layer_name = last_layer_name

    def fit(self, train_loader, override=True):
        """Fit the local Laplace approximation at the parameters of the model.
@@ -103,12 +104,13 @@ def fit(self, train_loader, override=True):
        self.model.eval()

        if self.model.last_layer is None:
            # Save an example batch for when loading the serialized Laplace
            self.X, _ = next(iter(train_loader))
            with torch.no_grad():
                try:
                    self.model.find_last_layer(self.X[:1].to(self._device))
                except (TypeError, AttributeError):
                    self.model.find_last_layer(self.X.to(self._device))
            params = parameters_to_vector(self.model.last_layer.parameters()).detach()
            self.n_params = len(params)
            self.n_layers = len(list(self.model.last_layer.parameters()))
@@ -164,6 +166,30 @@ def prior_precision_diag(self):
        else:
            raise ValueError('Mismatch of prior and model. Diagonal or scalar prior.')

    def state_dict(self) -> dict:
        state_dict = super().state_dict()
        # The example batch X is needed to re-discover the last layer on load.
        state_dict['X'] = getattr(self, 'X', None)
        state_dict['_last_layer_name'] = self._last_layer_name
        return state_dict

    def load_state_dict(self, state_dict: dict):
        if self._last_layer_name != state_dict['_last_layer_name']:
            raise ValueError('Different `last_layer_name` detected!')

        self.X = state_dict['X']
        if self.X is not None:
            # Run the saved example batch through the model to identify the
            # last layer before restoring the fitted quantities.
            with torch.no_grad():
                try:
                    self.model.find_last_layer(self.X[:1].to(self._device))
                except (TypeError, AttributeError):
                    self.model.find_last_layer(self.X.to(self._device))

        super().load_state_dict(state_dict)

        params = parameters_to_vector(self.model.last_layer.parameters()).detach()
        self.n_params = len(params)
        self.n_layers = len(list(self.model.last_layer.parameters()))


class FullLLLaplace(LLLaplace, FullLaplace):
"""Last-layer Laplace approximation with full, i.e., dense, log likelihood Hessian approximation