Subnetwork Laplace #58

Merged on Jan 12, 2022 (54 commits from subnetlaplace into main)
Commits (changes shown from 53 of the 54 commits):
0b32640  Add support for subnetwork Laplace approximation (edaxberger, Nov 30, 2021)
88e806c  Fix issues with subnetwork Laplace integration (edaxberger, Nov 30, 2021)
768fa63  Remove notes to myself (edaxberger, Nov 30, 2021)
f8ab8ac  Remove SubnetLaplace base class; only option remains FullSubnetLaplace (edaxberger, Dec 10, 2021)
57f46d2  Make jacobians and last_layer_jacobians non-static and adapted code a… (edaxberger, Dec 10, 2021)
2530122  Add SubnetMask baseclass and subclasses for random, largest magnitude… (edaxberger, Dec 10, 2021)
c0be3f9  Adapt FullSubnetLaplace to use new SubnetMask class interface (edaxberger, Dec 10, 2021)
257d33b  Add support for largest variance subnet selection (using diagonal Lap… (edaxberger, Dec 10, 2021)
5d29dbf  Merge branch 'main' into subnetlaplace (edaxberger, Dec 10, 2021)
91931c5  Merge branch 'main' into subnetlaplace (aleximmer, Dec 10, 2021)
3771e94  Remove change (aleximmer, Dec 10, 2021)
38cd0f6  Change FullSubnetLaplace to SubnetLaplace as it's the only option (edaxberger, Dec 10, 2021)
dddf746  Merge branch 'subnetlaplace' of https://github.com/AlexImmer/Laplace … (edaxberger, Dec 10, 2021)
f933dad  Convert indentation from tabs to spaces (edaxberger, Dec 10, 2021)
ac4542c  Remove model as argument from jacobians() as it has access to self.mo… (edaxberger, Dec 10, 2021)
6a88c82  Minor fixes for SubnetLaplace (edaxberger, Dec 10, 2021)
c47c7d0  Add tests for SubnetLaplace and SubnetMasks (edaxberger, Dec 10, 2021)
da23af9  Change indentation to spaces in test_subnetlaplace.py (edaxberger, Dec 10, 2021)
ddf840c  Implement sample() method for SubnetLaplace (as e.g. required for the… (edaxberger, Dec 15, 2021)
7a48489  Add tests for SubnetLaplace predictives (edaxberger, Dec 15, 2021)
bcc9ca7  Fix small bug in LastLayerSubnetMask (edaxberger, Dec 16, 2021)
381f79d  Add reference to subnetwork inference paper to SubnetLaplace docstring (edaxberger, Dec 16, 2021)
f345ac5  Add SubnetMask that allows specifying subnet parameters or modules by… (edaxberger, Dec 20, 2021)
ea16f34  Make subnet mask type check independent of CUDA and change default ty… (edaxberger, Dec 20, 2021)
176d678  Change LargestMagnitudeSubnetMask to use absolute parameter values (edaxberger, Dec 21, 2021)
59603e8  Small refactoring of SubnetLaplace tests (edaxberger, Dec 21, 2021)
8410eca  Add implementation of SubnetMask that selects params with largest var… (edaxberger, Dec 21, 2021)
e91646f  Remove tqdm dependency in SWAG (edaxberger, Dec 21, 2021)
d43b2e4  Merge remote-tracking branch 'origin/main' into subnetlaplace (edaxberger, Dec 21, 2021)
3a48e02  Set H=None in SubnetLaplace before calling super() constructor for co… (edaxberger, Dec 21, 2021)
7ae867f  Change indentation from tabs to spaces (edaxberger, Dec 21, 2021)
015e917  Update README: include subnetwork and low-rank Laplace & update paper… (edaxberger, Dec 21, 2021)
fce8545  Minor changes to README (edaxberger, Dec 21, 2021)
55e418a  Move utility files to new utils directory (edaxberger, Jan 3, 2022)
6044316  Change SubnetLaplace to take subnetwork indices instead of a subclass… (edaxberger, Jan 3, 2022)
6de38c9  Change subnetwork indices validity check to only allow long tensors (edaxberger, Jan 3, 2022)
7058b39  Add test for SubnetLaplace with custom subnetwork indices specification (edaxberger, Jan 3, 2022)
e66ba51  Remove None default value for subnetwork indices in SubnetLaplace (edaxberger, Jan 3, 2022)
589b846  Add failing test case for scalar subnetwork indices (edaxberger, Jan 4, 2022)
be245c8  Change SubnetMask.select() to return subnet indices and improve docum… (edaxberger, Jan 4, 2022)
6ae4f9f  Minor refactorings (subnet indices checks and documentation) (edaxberger, Jan 4, 2022)
288f767  Add README example for SubnetLaplace (edaxberger, Jan 4, 2022)
b5d8adf  Add __all__ for utils.py (edaxberger, Jan 12, 2022)
557aa05  Add __all__ for swag.py (edaxberger, Jan 12, 2022)
4ab30df  Add __init__.py to utils/ to simplify utility imports (edaxberger, Jan 12, 2022)
99d3fb2  Simplify check for duplicate indices in Subnet Laplace (edaxberger, Jan 12, 2022)
6186573  Add __all__ for matrix.py (edaxberger, Jan 12, 2022)
86b91bc  Add line breaks with proper indents for __all__ in subnetmask.py (edaxberger, Jan 12, 2022)
3b54e8a  Change name of fit_diagonal_swag() to fit_diagonal_swag_var() (edaxberger, Jan 12, 2022)
d5d2d23  Shorten lines to 100 chars and change line breaks from backslash to b… (edaxberger, Jan 12, 2022)
ddc15c5  Add call to _init_H() in the Subnet Laplace constructor (edaxberger, Jan 12, 2022)
0952ead  Add test for instantiating Subnet Laplace with large model (edaxberger, Jan 12, 2022)
83877b2  Update docs (edaxberger, Jan 12, 2022)
b956356  Merge remote-tracking branch 'origin/main' into subnetlaplace (edaxberger, Jan 12, 2022)
README.md: 80 changes (64 additions, 16 deletions)

@@ -4,17 +4,17 @@

[![Main](https://travis-ci.com/AlexImmer/Laplace.svg?token=rpuRxEjQS6cCZi7ptL9y&branch=main)](https://travis-ci.com/AlexImmer/Laplace)

-The laplace package facilitates the application of Laplace approximations for entire neural networks or just their last layer.
+The laplace package facilitates the application of Laplace approximations for entire neural networks, subnetworks of neural networks, or just their last layer.
The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations.
The library documentation is available at [https://aleximmer.github.io/Laplace](https://aleximmer.github.io/Laplace).

There is also a corresponding paper, [*Laplace Redux — Effortless Bayesian Deep Learning*](https://arxiv.org/abs/2106.14806), which introduces the library, provides an introduction to the Laplace approximation, reviews its use in deep learning, and empirically demonstrates its versatility and competitiveness. Please consider referring to the paper when using our library:
```bibtex
-@article{daxberger2021laplace,
-  title={Laplace Redux--Effortless Bayesian Deep Learning},
-  author={Daxberger, Erik and Kristiadi, Agustinus and Immer, Alexander
-  and Eschenhagen, Runa and Bauer, Matthias and Hennig, Philipp},
-  journal={arXiv preprint arXiv:2106.14806},
+@inproceedings{laplace2021,
+  title={Laplace Redux--Effortless {B}ayesian Deep Learning},
+  author={Erik Daxberger and Agustinus Kristiadi and Alexander Immer
+  and Runa Eschenhagen and Matthias Bauer and Philipp Hennig},
+  booktitle={{N}eur{IPS}},
  year={2021}
}
```
@@ -39,18 +39,24 @@ pytest tests/
## Structure
The laplace package consists of two main components:

-1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, and `'diag'`). This results in six currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, and the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace`, which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function.
+1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'` and `'diag'`). This results in _eight_ currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace` (which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py)), [`laplace.SubnetLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetlaplace.py) (which only supports a `'full'` Hessian approximation) and `laplace.LowRankLaplace` (which only supports inference over `'all'` weights). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function.
2. The backends in [`laplace.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/) which provide access to Hessian approximations of
the corresponding sparsity structures, for example, the diagonal GGN.
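
For orientation, here is a minimal sketch (not part of the README) of how the `laplace.Laplace` function relates to the subclasses listed above; the toy model is an assumption, and the mapping in the comments is our reading of the eight options:

```python
import torch
from laplace import Laplace, FullLaplace

model = torch.nn.Linear(10, 2)  # stand-in model, for illustration only

# Laplace() dispatches on (subset_of_weights, hessian_structure);
# e.g. ('all', 'full') should yield a FullLaplace instance.
la = Laplace(model, 'classification',
             subset_of_weights='all', hessian_structure='full')
assert isinstance(la, FullLaplace)
```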

Additionally, the package provides utilities for
-decomposing a neural network into feature extractor and last layer for `LLLaplace` subclasses ([`laplace.feature_extractor`](https://github.com/AlexImmer/Laplace/blob/main/laplace/feature_extractor.py))
+decomposing a neural network into feature extractor and last layer for `LLLaplace` subclasses ([`laplace.utils.feature_extractor`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/feature_extractor.py))
and
-effectively dealing with Kronecker factors ([`laplace.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/matrix.py)).
+effectively dealing with Kronecker factors ([`laplace.utils.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/matrix.py)).
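
As a brief illustration of the feature-extractor utility just mentioned (a hedged sketch: the `FeatureExtractor` class and its `forward_with_features` method reflect our understanding of `laplace.utils.feature_extractor`, and the toy model is an assumption):

```python
import torch
from laplace.utils.feature_extractor import FeatureExtractor

model = torch.nn.Sequential(
    torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 2))

fe = FeatureExtractor(model)  # should split off the last layer automatically
x = torch.randn(8, 10)
out, features = fe.forward_with_features(x)  # logits and last-layer inputs
print(out.shape, features.shape)  # expected: (8, 2) and (8, 20)
```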

+Finally, the package implements several options to select/specify a subnetwork for `SubnetLaplace` (as subclasses of [`laplace.utils.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/subnetmask.py)).
+Automatic subnetwork selection strategies include: uniformly at random (`laplace.utils.subnetmask.RandomSubnetMask`), by largest parameter magnitudes (`LargestMagnitudeSubnetMask`), and by largest marginal parameter variances (`LargestVarianceDiagLaplaceSubnetMask` and `LargestVarianceSWAGSubnetMask`).
+In addition to that, subnetworks can also be specified manually, by listing the names of either the model parameters (`ParamNameSubnetMask`) or modules (`ModuleNameSubnetMask`) to perform Laplace inference over.
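
To illustrate one of the variance-based strategies named above (a sketch: the `diag_laplace_model` argument and the `select(train_loader)` call are our assumptions about the interface, and `model`/`train_loader` are taken to be defined as in the usage examples below):

```python
from laplace import DiagLaplace
from laplace.utils import LargestVarianceDiagLaplaceSubnetMask

# Select the 128 weights with the largest marginal variance under a
# diagonal Laplace approximation fit on the training data.
diag_la = DiagLaplace(model, 'classification')
subnetwork_mask = LargestVarianceDiagLaplaceSubnetMask(
    model, n_params_subnet=128, diag_laplace_model=diag_la)
subnetwork_indices = subnetwork_mask.select(train_loader)
```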

## Extendability
To extend the laplace package, new `BaseLaplace` subclasses can be designed, for example,
-a block-diagonal structure or subset-of-weights Laplace.
+Laplace with a block-diagonal Hessian structure.
+One can also implement custom subnetwork selection strategies as new subclasses of `SubnetMask`.

Alternatively, extending or integrating backends (subclasses of [`curvature.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvature.py)) allows one to provide different Hessian
approximations to the Laplace approximations.
For example, currently the [`curvature.BackPackInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/backpack.py) based on [BackPACK](https://github.com/f-dangel/backpack/) and [`curvature.AsdlInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/asdl.py) based on [ASDL](https://github.com/kazukiosawa/asdfghjkl) are available.
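
As a sketch of the custom subnetwork-selection extension point mentioned above (we assume subclasses override a `get_subnet_mask` hook returning a boolean mask over the vectorized parameters; the hook name and the `select()` call are assumptions about the `SubnetMask` interface):

```python
import torch
from laplace.utils import SubnetMask

class EveryOtherParamSubnetMask(SubnetMask):
    """Toy strategy: put every second parameter into the subnetwork."""

    def get_subnet_mask(self, train_loader):
        n_params = sum(p.numel() for p in self.model.parameters())
        subnet_mask = torch.zeros(n_params, dtype=torch.bool)
        subnet_mask[::2] = True  # this strategy needs no training data
        return subnet_mask

# Usage (assuming a `model` as in the examples below):
# subnetwork_indices = EveryOtherParamSubnetMask(model).select()
```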
@@ -60,18 +66,19 @@ for a regression (MSELoss) loss function.

## Example usage

-### *Post-hoc* prior precision tuning of last-layer LA
+### *Post-hoc* prior precision tuning of diagonal LA

In the following example, a pre-trained model is loaded,
-then the Laplace approximation is fit to the training data,
+then the Laplace approximation is fit to the training data
+(using a diagonal Hessian approximation over all parameters),
and the prior precision is optimized with cross-validation `'CV'`.
After that, the resulting LA is used for prediction with
the `'probit'` predictive for classification.

```python
from laplace import Laplace

-# pre-trained model
+# Pre-trained model
model = load_map_model()

# User-specified LA flavor
@@ -87,7 +94,7 @@ pred = la(x, link_approx='probit')

### Differentiating the log marginal likelihood w.r.t. hyperparameters

-The marginal likelihood can be used for model selection and is differentiable
+The marginal likelihood can be used for model selection [10] and is differentiable
for continuous hyperparameters like the prior precision or observation noise.
Here, we fit the library default, KFAC last-layer LA and differentiate
the log marginal likelihood.
@@ -107,6 +114,45 @@ ml = la.log_marginal_likelihood(prior_prec, obs_noise)
ml.backward()
```
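
For context, a minimal sketch of how such gradients might drive hyperparameter tuning; the loop itself (log-parameterization, optimizer, iteration count) is our illustration, not README text, and it assumes `la` was fit as above on a regression task:

```python
import torch

# Gradient descent on the negative log marginal likelihood
log_prior = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
for _ in range(100):
    hyper_optimizer.zero_grad()
    neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    hyper_optimizer.step()
```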

+### Applying the LA over only a subset of the model parameters
+
+This example shows how to fit the Laplace approximation over only
+a subnetwork within a neural network (while keeping all other parameters
+fixed at their MAP estimates), as proposed in [11]. It also exemplifies
+different ways to specify the subnetwork to perform inference over.
+
+```python
+from laplace import Laplace
+
+# Pre-trained model
+model = load_model()
+
+# Examples of different ways to specify the subnetwork
+# via indices of the vectorized model parameters
+#
+# Example 1: select the 128 parameters with the largest magnitude
+from laplace.utils import LargestMagnitudeSubnetMask
+subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=128)
+subnetwork_indices = subnetwork_mask.select()
+
+# Example 2: specify the layers that define the subnetwork
+from laplace.utils import ModuleNameSubnetMask
+subnetwork_mask = ModuleNameSubnetMask(model, module_names=['layer.1', 'layer.3'])
+subnetwork_mask.select()
+subnetwork_indices = subnetwork_mask.indices
+
+# Example 3: manually define the subnetwork via custom subnetwork indices
+import torch
+subnetwork_indices = torch.tensor([0, 4, 11, 42, 123, 2021])
+
+# Define and fit subnetwork LA using the specified subnetwork indices
+la = Laplace(model, 'classification',
+             subset_of_weights='subnetwork',
+             hessian_structure='full',
+             subnetwork_indices=subnetwork_indices)
+la.fit(train_loader)
+```
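
After fitting, the subnetwork posterior should be usable through the library's usual predictive call; a hedged sketch (the `link_approx='mc'` option and `n_samples` argument are assumptions based on the predictive interface shown earlier, not README text):

```python
# Monte Carlo predictive using samples from the subnetwork posterior
pred = la(x, link_approx='mc', n_samples=100)
```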

## Documentation

The documentation is available [here](https://aleximmer.github.io/Laplace) or can be generated and/or viewed locally:
@@ -122,7 +168,7 @@ pdoc --http 0.0.0.0:8080 laplace --template-dir template

## References

-This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1].
+This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1]. Please consider citing the respective papers if you use any of their proposed methods via our laplace library.

- [1] MacKay, DJC. [*A Practical Bayesian Framework for Backpropagation Networks*](https://authors.library.caltech.edu/13793/). Neural Computation 1992.
- [2] Gibbs, M. N. [*Bayesian Gaussian Processes for Regression and Classification*](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.147.1130&rep=rep1&type=pdf). PhD Thesis 1997.
@@ -132,4 +178,6 @@ This package relies on various improvements to the Laplace approximation for neural networks
- [6] Khan, M. E., Immer, A., Abedi, E., Korzepa, M. [*Approximate Inference Turns Deep Networks into Gaussian Processes*](https://arxiv.org/abs/1906.01930). NeurIPS 2019.
- [7] Kristiadi, A., Hein, M., Hennig, P. [*Being Bayesian, Even Just a Bit, Fixes Overconfidence in ReLU Networks*](https://arxiv.org/abs/2002.10118). ICML 2020.
- [8] Immer, A., Korzepa, M., Bauer, M. [*Improving predictions of Bayesian neural nets via local linearization*](https://arxiv.org/abs/2008.08400). AISTATS 2021.
-- [9] Immer, A., Bauer, M., Fortuin, V., Rätsch, G., Khan, EM. [*Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning*](https://arxiv.org/abs/2104.04975). ICML 2021.
+- [9] Sharma, A., Azizan, N., Pavone, M. [*Sketching Curvature for Efficient Out-of-Distribution Detection for Deep Neural Networks*](https://arxiv.org/abs/2102.12567). UAI 2021.
+- [10] Immer, A., Bauer, M., Fortuin, V., Rätsch, G., Khan, EM. [*Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning*](https://arxiv.org/abs/2104.04975). ICML 2021.
+- [11] Daxberger, E., Nalisnick, E., Allingham, JU., Antorán, J., Hernández-Lobato, JM. [*Bayesian Deep Learning via Subnetwork Inference*](https://arxiv.org/abs/2010.14689). ICML 2021.