From 0b326407a57724791c9304d9c32134257d99d1a5 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 30 Nov 2021 09:24:03 +0000
Subject: [PATCH 01/49] Add support for subnetwork Laplace approximation
---
laplace/baselaplace.py | 2 +-
laplace/curvature/asdl.py | 20 ++++--
laplace/curvature/backpack.py | 17 +++--
laplace/curvature/curvature.py | 23 +++++--
laplace/laplace.py | 5 +-
laplace/subnetlaplace.py | 120 +++++++++++++++++++++++++++++++++
6 files changed, 169 insertions(+), 18 deletions(-)
create mode 100644 laplace/subnetlaplace.py
diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py
index 527a130c..4b9ce704 100644
--- a/laplace/baselaplace.py
+++ b/laplace/baselaplace.py
@@ -560,7 +560,7 @@ def predictive_samples(self, x, pred_type='glm', n_samples=100):
@torch.enable_grad()
def _glm_predictive_distribution(self, X):
- Js, f_mu = self.backend.jacobians(self.model, X)
+ Js, f_mu = self.backend.jacobians(self.model, X, self.backend.subnetwork_indices)
f_var = self.functional_variance(Js)
return f_mu.detach(), f_var.detach()
diff --git a/laplace/curvature/asdl.py b/laplace/curvature/asdl.py
index c5307484..afe3e673 100644
--- a/laplace/curvature/asdl.py
+++ b/laplace/curvature/asdl.py
@@ -15,13 +15,13 @@
class AsdlInterface(CurvatureInterface):
"""Interface for asdfghjkl backend.
"""
- def __init__(self, model, likelihood, last_layer=False):
+ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None):
if likelihood != 'classification':
raise ValueError('This backend only supports classification currently.')
- super().__init__(model, likelihood, last_layer)
+ super().__init__(model, likelihood, last_layer, subnetwork_indices)
@staticmethod
- def jacobians(model, x):
+ def jacobians(model, x, subnetwork_indices=None):
"""Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\)
using asdfghjkl's gradient per output dimension.
@@ -30,6 +30,9 @@ def jacobians(model, x):
model : torch.nn.Module
x : torch.Tensor
input data `(batch, input_shape)` on compatible device with model.
+ subnetwork_indices : torch.Tensor, default=None
+ indices of the vectorized model parameters that define the subnetwork
+ to apply the Laplace approximation over
Returns
-------
@@ -44,7 +47,10 @@ def loss_fn(outputs, targets):
return outputs[:, i].sum()
f = batch_gradient(model, loss_fn, x, None).detach()
- Js.append(_get_batch_grad(model))
+ Jk = _get_batch_grad(model)
+ if subnetwork_indices is not None:
+ Jk = Jk[:, subnetwork_indices]
+ Js.append(Jk)
Js = torch.stack(Js, dim=1)
return Js, f
@@ -66,6 +72,8 @@ def gradients(self, x, y):
"""
f = batch_gradient(self.model, self.lossfunc, x, y).detach()
Gs = _get_batch_grad(self._model)
+ if self.subnetwork_indices is not None:
+ Gs = Gs[:, self.subnetwork_indices]
loss = self.lossfunc(f, y)
return Gs, loss
@@ -134,8 +142,8 @@ def kron(self, X, y, N, **wkwargs) -> [torch.Tensor, Kron]:
class AsdlGGN(AsdlInterface, GGNInterface):
"""Implementation of the `GGNInterface` using asdfghjkl.
"""
- def __init__(self, model, likelihood, last_layer=False, stochastic=False):
- super().__init__(model, likelihood, last_layer)
+ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False):
+ super().__init__(model, likelihood, last_layer, subnetwork_indices)
self.stochastic = stochastic
@property
diff --git a/laplace/curvature/backpack.py b/laplace/curvature/backpack.py
index 885ee2b9..6e655944 100644
--- a/laplace/curvature/backpack.py
+++ b/laplace/curvature/backpack.py
@@ -11,13 +11,13 @@
class BackPackInterface(CurvatureInterface):
"""Interface for Backpack backend.
"""
- def __init__(self, model, likelihood, last_layer=False):
- super().__init__(model, likelihood, last_layer)
+ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None):
+ super().__init__(model, likelihood, last_layer, subnetwork_indices)
extend(self._model)
extend(self.lossfunc)
@staticmethod
- def jacobians(model, x):
+ def jacobians(model, x, subnetwork_indices=None):
"""Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\)
using backpack's BatchGrad per output dimension.
@@ -26,6 +26,9 @@ def jacobians(model, x):
model : torch.nn.Module
x : torch.Tensor
input data `(batch, input_shape)` on compatible device with model.
+ subnetwork_indices : torch.Tensor, default=None
+ indices of the vectorized model parameters that define the subnetwork
+ to apply the Laplace approximation over
Returns
-------
@@ -49,6 +52,8 @@ def jacobians(model, x):
to_cat.append(param.grad_batch.detach().reshape(x.shape[0], -1))
delattr(param, 'grad_batch')
Jk = torch.cat(to_cat, dim=1)
+ if subnetwork_indices is not None:
+ Jk = Jk[:, subnetwork_indices]
to_stack.append(Jk)
if i == 0:
f = out.detach()
@@ -83,14 +88,16 @@ def gradients(self, x, y):
loss.backward()
Gs = torch.cat([p.grad_batch.data.flatten(start_dim=1)
for p in self._model.parameters()], dim=1)
+ if self.subnetwork_indices is not None:
+ Gs = Gs[:, self.subnetwork_indices]
return Gs, loss
class BackPackGGN(BackPackInterface, GGNInterface):
"""Implementation of the `GGNInterface` using Backpack.
"""
- def __init__(self, model, likelihood, last_layer=False, stochastic=False):
- super().__init__(model, likelihood, last_layer)
+ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False):
+ super().__init__(model, likelihood, last_layer, subnetwork_indices)
self.stochastic = stochastic
def _get_diag_ggn(self):
diff --git a/laplace/curvature/curvature.py b/laplace/curvature/curvature.py
index 363e51d1..735fd74d 100644
--- a/laplace/curvature/curvature.py
+++ b/laplace/curvature/curvature.py
@@ -16,6 +16,9 @@ class CurvatureInterface:
likelihood : {'classification', 'regression'}
last_layer : bool, default=False
only consider curvature of last layer
+ subnetwork_indices : torch.Tensor, default=None
+ indices of the vectorized model parameters that define the subnetwork
+ to apply the Laplace approximation over
Attributes
----------
@@ -24,11 +27,12 @@ class CurvatureInterface:
conversion factor between torch losses and base likelihoods
For example, \\(\\frac{1}{2}\\) to get to \\(\\mathcal{N}(f, 1)\\) from MSELoss.
"""
- def __init__(self, model, likelihood, last_layer=False):
+ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None):
assert likelihood in ['regression', 'classification']
self.likelihood = likelihood
self.model = model
self.last_layer = last_layer
+ self.subnetwork_indices = subnetwork_indices
if likelihood == 'regression':
self.lossfunc = MSELoss(reduction='sum')
self.factor = 0.5
@@ -41,7 +45,7 @@ def _model(self):
return self.model.last_layer if self.last_layer else self.model
@staticmethod
- def jacobians(model, x):
+ def jacobians(model, x, subnetwork_indices=None):
"""Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\).
Parameters
@@ -49,6 +53,9 @@ def jacobians(model, x):
model : torch.nn.Module
x : torch.Tensor
input data `(batch, input_shape)` on compatible device with model.
+ subnetwork_indices : torch.Tensor, default=None
+ indices of the vectorized model parameters that define the subnetwork
+ to apply the Laplace approximation over
Returns
-------
@@ -180,12 +187,15 @@ class GGNInterface(CurvatureInterface):
likelihood : {'classification', 'regression'}
last_layer : bool, default=False
only consider curvature of last layer
+ subnetwork_indices : torch.Tensor, default=None
+ indices of the vectorized model parameters that define the subnetwork
+ to apply the Laplace approximation over
stochastic : bool, default=False
Fisher if stochastic else GGN
"""
- def __init__(self, model, likelihood, last_layer=False, stochastic=False):
+ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False):
self.stochastic = stochastic
- super().__init__(model, likelihood, last_layer)
+ super().__init__(model, likelihood, last_layer, subnetwork_indices)
def _get_full_ggn(self, Js, f, y):
"""Compute full GGN from Jacobians.
@@ -239,7 +249,7 @@ def full(self, x, y, **kwargs):
if self.last_layer:
Js, f = self.last_layer_jacobians(self.model, x)
else:
- Js, f = self.jacobians(self.model, x)
+ Js, f = self.jacobians(self.model, x, self.subnetwork_indices)
loss, H_ggn = self._get_full_ggn(Js, f, y)
return loss, H_ggn
@@ -256,6 +266,9 @@ class EFInterface(CurvatureInterface):
likelihood : {'classification', 'regression'}
last_layer : bool, default=False
only consider curvature of last layer
+ subnetwork_indices : torch.Tensor, default=None
+ indices of the vectorized model parameters that define the subnetwork
+ to apply the Laplace approximation over
Attributes
----------
diff --git a/laplace/laplace.py b/laplace/laplace.py
index 98f11cad..24241f17 100644
--- a/laplace/laplace.py
+++ b/laplace/laplace.py
@@ -10,7 +10,7 @@ def Laplace(model, likelihood, subset_of_weights='last_layer', hessian_structure
----------
model : torch.nn.Module
likelihood : {'classification', 'regression'}
- subset_of_weights : {'last_layer', 'all'}, default='last_layer'
+ subset_of_weights : {'last_layer', 'subnetwork', 'all'}, default='last_layer'
subset of weights to consider for inference
hessian_structure : {'diag', 'kron', 'full'}, default='kron'
structure of the Hessian approximation
@@ -20,6 +20,9 @@ def Laplace(model, likelihood, subset_of_weights='last_layer', hessian_structure
laplace : ParametricLaplace
chosen subclass of ParametricLaplace instantiated with additional arguments
"""
+ if subset_of_weights == 'subnetwork' and hessian_structure != 'full':
+ raise ValueError('Subnetwork Laplace requires using a full Hessian approximation!')
+
laplace_map = {subclass._key: subclass for subclass in _all_subclasses(ParametricLaplace)
if hasattr(subclass, '_key')}
laplace_class = laplace_map[(subset_of_weights, hessian_structure)]
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
new file mode 100644
index 00000000..30bf15d2
--- /dev/null
+++ b/laplace/subnetlaplace.py
@@ -0,0 +1,120 @@
+import torch
+
+from laplace.baselaplace import ParametricLaplace, FullLaplace
+
+from laplace.curvature import BackPackGGN
+
+
+__all__ = ['FullSubnetLaplace']
+
+
+class SubnetLaplace(ParametricLaplace):
+ """Baseclass for all subnetwork Laplace approximations in this library.
+ Subclasses specify the structure of the Hessian approximation.
+ See `BaseLaplace` for the full interface.
+
+ A Laplace approximation is represented by a MAP which is given by the
+ `model` parameter and a posterior precision or covariance specifying
+ a Gaussian distribution \\(\\mathcal{N}(\\theta_{MAP}, P^{-1})\\).
+ Here, only the parameters of a subnetwork of the neural network
+ are treated probabilistically.
+ The goal of this class is to compute the posterior precision \\(P\\)
+ which sums as
+ \\[
+ P = \\sum_{n=1}^N \\nabla^2_\\theta \\log p(\\mathcal{D}_n \\mid \\theta)
+ \\vert_{\\theta_{MAP}} + \\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}}.
+ \\]
+ There is one subclass, which implements the only supported option of a full
+ approximation to the log likelihood Hessian. The prior is assumed to be Gaussian and
+ therefore we have a simple form for
+ \\(\\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}} = P_0 \\).
+ In particular, we assume a scalar or diagonal prior precision so that in
+ all cases \\(P_0 = \\textrm{diag}(p_0)\\) and the structure of \\(p_0\\) can be varied.
+
+ Parameters
+ ----------
+ model : torch.nn.Module or `laplace.feature_extractor.FeatureExtractor`
+ likelihood : {'classification', 'regression'}
+ determines the log likelihood Hessian approximation
+ subnetwork_mask : torch.Tensor, default=None
+ mask defining the subnetwork to apply the Laplace approximation over
+ sigma_noise : torch.Tensor or float, default=1
+ observation noise for the regression setting; must be 1 for classification
+ prior_precision : torch.Tensor or float, default=1
+ prior precision of a Gaussian prior (= weight decay);
+ can be scalar, per-layer, or diagonal in the most general case
+ prior_mean : torch.Tensor or float, default=0
+ prior mean of a Gaussian prior, useful for continual learning
+ temperature : float, default=1
+ temperature of the likelihood; lower temperature leads to more
+ concentrated posterior and vice versa.
+ backend : subclasses of `laplace.curvature.CurvatureInterface`
+ backend for access to curvature/Hessian approximations
+ backend_kwargs : dict, default=None
+ arguments passed to the backend on initialization, for example to
+ set the number of MC samples for stochastic approximations.
+ """
+ def __init__(self, model, likelihood, subnetwork_mask=None, sigma_noise=1., prior_precision=1.,
+ prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None):
+ super().__init__(model, likelihood, sigma_noise=sigma_noise, prior_precision=prior_precision,
+ prior_mean=prior_mean, temperature=temperature, backend=backend,
+ backend_kwargs=backend_kwargs)
+ self.subnetwork_mask = subnetwork_mask
+
+ @property
+ def subnetwork_mask(self):
+ return self._subnetwork_mask
+
+ @subnetwork_mask.setter
+ def subnetwork_mask(self, subnetwork_mask):
+ """Check validity of subnetwork mask and convert it to a vector of indices of the vectorized
+ model parameters that define the subnetwork to apply the Laplace approximation over.
+ """
+ if isinstance(subnetwork_mask, torch.Tensor) and len(subnetwork_mask.shape) == 1:
+ if len(subnetwork_mask) == self.n_params and\
+ len(subnetwork_mask[subnetwork_mask == 0]) +\
+ len(subnetwork_mask[subnetwork_mask == 1]) == self.n_params:
+ self._subnetwork_mask = subnetwork_mask.nonzero(as_tuple=True)[0]
+
+ elif len(subnetwork_mask) <= self.n_params and\
+ len(subnetwork_mask[subnetwork_mask >= self.n_params]) == 0:
+ self._subnetwork_mask = subnetwork_mask
+
+ else:
+ raise ValueError('Subnetwork mask needs to identify the subnetwork parameters\
+ from the vectorized model parameters as:\
+ 1) a vector of indices of the subnetwork parameters,\
+ 2) a binary vector of size (parameters) where 1s locate the subnetwork parameters')
+
+ elif subnetwork_mask is None:
+ raise ValueError('You need to specify a subnetwork mask!')
+
+ else:
+ raise ValueError('Subnetwork mask needs to be 1-dimensional torch.Tensor!')
+
+ # Q: do we allow changing the subnetwork after instantiation, or should it stay fixed?
+ #self._backend_kwargs['subnetwork_mask'] = self._subnetwork_mask
+ self.backend.subnetwork_indices = self._subnetwork_mask
+
+ # Q: documentation: should I mention subnetworks everywhere and write down the number
+ # of parameters?
+
+ # Q jacobian() is static and therefore cannot access self.subnetwork_indices (need to pass it)
+ # what about making it non-static? it's also ugly in l. 563 of baselaplace.py!
+
+
+
+class FullSubnetLaplace(SubnetLaplace, FullLaplace):
+ """Subnetwork Laplace approximation with full, i.e., dense, log likelihood Hessian approximation
+ and hence posterior precision. Based on the chosen `backend` parameter, the full
+ approximation can be, for example, a generalized Gauss-Newton matrix.
+ Mathematically, we have \\(P \\in \\mathbb{R}^{P \\times P}\\).
+ See `FullLaplace`, `LLLaplace`, and `BaseLaplace` for the full interface.
+ """
+ # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure)
+ _key = ('subnetwork', 'full')
+
+ def __init__(self, model, likelihood, subnetwork_mask=None, sigma_noise=1., prior_precision=1.,
+ prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None):
+ super().__init__(model, likelihood, subnetwork_mask, sigma_noise, prior_precision,
+ prior_mean, temperature, backend, backend_kwargs)
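---
The core mechanism this patch threads through every backend is column selection on the Jacobian: the per-input Jacobian has one column per vectorized model parameter, and the subnetwork approximation keeps only the columns listed in `subnetwork_indices`. A minimal sketch of that slicing (all names here are illustrative placeholders, not code from the patch; the patch slices the 2D per-output matrix `Jk` before stacking, which gives the same result):

    import torch

    batch_size, n_outputs, n_params = 4, 2, 10
    Js = torch.randn(batch_size, n_outputs, n_params)  # full Jacobians, one column per parameter
    subnetwork_indices = torch.tensor([0, 3, 7])       # subnetwork of 3 out of 10 parameters
    Js_subnet = Js[:, :, subnetwork_indices]           # keep only the subnetwork columns
    assert Js_subnet.shape == (4, 2, 3)
---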
From 88e806ce49fe56579eafef694648160503c9e5c8 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 30 Nov 2021 10:26:26 +0000
Subject: [PATCH 02/49] Fix issues with subnetwork Laplace integration
---
laplace/__init__.py | 5 ++++-
laplace/subnetlaplace.py | 45 ++++++++++++++++++++++++++++++++--------
2 files changed, 40 insertions(+), 10 deletions(-)
diff --git a/laplace/__init__.py b/laplace/__init__.py
index 2f866125..32f67929 100644
--- a/laplace/__init__.py
+++ b/laplace/__init__.py
@@ -8,10 +8,13 @@
from laplace.baselaplace import BaseLaplace, ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace
from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace
+from laplace.subnetlaplace import SubnetLaplace, FullSubnetLaplace
from laplace.laplace import Laplace
__all__ = ['Laplace', # direct access to all Laplace classes via unified interface
'BaseLaplace', 'ParametricLaplace', # base-class and its (first-level) subclasses
'FullLaplace', 'KronLaplace', 'DiagLaplace', # all-weights
'LLLaplace', # base-class last-layer
- 'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace'] # last-layer
+ 'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace', # last-layer
+ 'SubnetLaplace', # base-class subnetwork
+ 'FullSubnetLaplace'] # subnetwork
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 30bf15d2..213b1ef0 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -60,6 +60,7 @@ def __init__(self, model, likelihood, subnetwork_mask=None, sigma_noise=1., prio
prior_mean=prior_mean, temperature=temperature, backend=backend,
backend_kwargs=backend_kwargs)
self.subnetwork_mask = subnetwork_mask
+ self.n_params_subnet = len(self.subnetwork_mask)
@property
def subnetwork_mask(self):
@@ -70,8 +71,12 @@ def subnetwork_mask(self, subnetwork_mask):
"""Check validity of subnetwork mask and convert it to a vector of indices of the vectorized
model parameters that define the subnetwork to apply the Laplace approximation over.
"""
- if isinstance(subnetwork_mask, torch.Tensor) and len(subnetwork_mask.shape) == 1:
- if len(subnetwork_mask) == self.n_params and\
+ if isinstance(subnetwork_mask, torch.Tensor):
+ if subnetwork_mask.type() not in ['torch.ByteTensor', 'torch.IntTensor', 'torch.LongTensor'] or\
+ len(subnetwork_mask.shape) != 1:
+ raise ValueError('Subnetwork mask needs to be 1-dimensional torch.{Byte,Int,Long}Tensor!')
+
+ elif len(subnetwork_mask) == self.n_params and\
len(subnetwork_mask[subnetwork_mask == 0]) +\
len(subnetwork_mask[subnetwork_mask == 1]) == self.n_params:
self._subnetwork_mask = subnetwork_mask.nonzero(as_tuple=True)[0]
@@ -81,16 +86,16 @@ def subnetwork_mask(self, subnetwork_mask):
self._subnetwork_mask = subnetwork_mask
else:
- raise ValueError('Subnetwork mask needs to identify the subnetwork parameters\
- from the vectorized model parameters as:\
- 1) a vector of indices of the subnetwork parameters,\
- 2) a binary vector of size (parameters) where 1s locate the subnetwork parameters')
+ raise ValueError('Subnetwork mask needs to identify the subnetwork parameters '\
+ 'from the vectorized model parameters as:\n'\
+ '1) a vector of indices of the subnetwork parameters, or\n'\
+ '2) a binary vector of size (parameters) where 1s locate the subnetwork parameters.')
elif subnetwork_mask is None:
- raise ValueError('You need to specify a subnetwork mask!')
+ raise ValueError('Subnetwork Laplace requires passing a subnetwork mask!')
else:
- raise ValueError('Subnetwork mask needs to be 1-dimensional torch.Tensor!')
+ raise ValueError('Subnetwork mask needs to be torch.Tensor!')
# Q: do we allow changing the subnetwork after instantiation, or should it stay fixed?
#self._backend_kwargs['subnetwork_mask'] = self._subnetwork_mask
@@ -101,7 +106,26 @@ def subnetwork_mask(self, subnetwork_mask):
# Q jacobian() is static and therefore cannot access self.subnetwork_indices (need to pass it)
# what about making it non-static? it's also ugly in l. 563 of baselaplace.py!
-
+
+ # still need to implement nn mc predictive (need to sample subnet separately and then put samples together)
+
+ @property
+ def prior_precision_diag(self):
+ """Obtain the diagonal prior precision \\(p_0\\) constructed from either
+ a scalar or diagonal prior precision.
+
+ Returns
+ -------
+ prior_precision_diag : torch.Tensor
+ """
+ if len(self.prior_precision) == 1: # scalar
+ return self.prior_precision * torch.ones(self.n_params_subnet, device=self._device)
+
+ elif len(self.prior_precision) == self.n_params_subnet: # diagonal
+ return self.prior_precision
+
+ else:
+ raise ValueError('Mismatch of prior and model. Diagonal or scalar prior.')
class FullSubnetLaplace(SubnetLaplace, FullLaplace):
@@ -118,3 +142,6 @@ def __init__(self, model, likelihood, subnetwork_mask=None, sigma_noise=1., prio
prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None):
super().__init__(model, likelihood, subnetwork_mask, sigma_noise, prior_precision,
prior_mean, temperature, backend, backend_kwargs)
+
+ def _init_H(self):
+ self.H = torch.zeros(self.n_params_subnet, self.n_params_subnet, device=self._device)
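---
After this patch, the `subnetwork_mask` setter accepts two encodings, both as a 1-dimensional Byte/Int/Long tensor: a binary vector over all parameters, or a vector of parameter indices. A small sketch of the two forms and their equivalence (values made up for illustration):

    import torch

    n_params = 10
    binary_mask = torch.zeros(n_params, dtype=torch.long)
    binary_mask[torch.tensor([0, 3, 7])] = 1   # 1s locate the subnetwork parameters
    index_mask = torch.tensor([0, 3, 7])       # directly the subnetwork indices

    # the setter converts the binary form to indices internally:
    assert torch.equal(binary_mask.nonzero(as_tuple=True)[0], index_mask)
---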
From 768fa63e56a57f7343801f4474ff3f13effbcb81 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 30 Nov 2021 10:28:41 +0000
Subject: [PATCH 03/49] Remove notes to myself
---
laplace/subnetlaplace.py | 10 ----------
1 file changed, 10 deletions(-)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 213b1ef0..298ddffa 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -97,18 +97,8 @@ def subnetwork_mask(self, subnetwork_mask):
else:
raise ValueError('Subnetwork mask needs to be torch.Tensor!')
- # Q: do we allow changing the subnetwork after instantiation, or should it stay fixed?
- #self._backend_kwargs['subnetwork_mask'] = self._subnetwork_mask
self.backend.subnetwork_indices = self._subnetwork_mask
- # Q: documentation: should I mention subnetworks everywhere and write down the number
- # of parameters?
-
- # Q jacobian() is static and therefore cannot access self.subnetwork_indices (need to pass it)
- # what about making it non-static? it's also ugly in l. 563 of baselaplace.py!
-
- # still need to implement nn mc predictive (need to sample subnet separately and then put samples together)
-
@property
def prior_precision_diag(self):
"""Obtain the diagonal prior precision \\(p_0\\) constructed from either
From f8ab8ac4532c24fbbf097ba6a85672b4a690d880 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Fri, 10 Dec 2021 07:44:17 +0000
Subject: [PATCH 04/49] Remove SubnetLaplace base class; FullSubnetLaplace is
 the only remaining option
---
laplace/__init__.py | 3 +--
laplace/subnetlaplace.py | 50 +++++++++++++++++-----------------------
2 files changed, 22 insertions(+), 31 deletions(-)
diff --git a/laplace/__init__.py b/laplace/__init__.py
index 32f67929..e34989b0 100644
--- a/laplace/__init__.py
+++ b/laplace/__init__.py
@@ -8,7 +8,7 @@
from laplace.baselaplace import BaseLaplace, ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace
from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace
-from laplace.subnetlaplace import SubnetLaplace, FullSubnetLaplace
+from laplace.subnetlaplace import FullSubnetLaplace
from laplace.laplace import Laplace
__all__ = ['Laplace', # direct access to all Laplace classes via unified interface
@@ -16,5 +16,4 @@
'FullLaplace', 'KronLaplace', 'DiagLaplace', # all-weights
'LLLaplace', # base-class last-layer
'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace', # last-layer
- 'SubnetLaplace', # base-class subnetwork
'FullSubnetLaplace'] # subnetwork
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 298ddffa..9749c1f6 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -1,6 +1,6 @@
import torch
-from laplace.baselaplace import ParametricLaplace, FullLaplace
+from laplace.baselaplace import FullLaplace
from laplace.curvature import BackPackGGN
@@ -8,29 +8,34 @@
__all__ = ['FullSubnetLaplace']
-class SubnetLaplace(ParametricLaplace):
- """Baseclass for all subnetwork Laplace approximations in this library.
- Subclasses specify the structure of the Hessian approximation.
- See `BaseLaplace` for the full interface.
+class FullSubnetLaplace(FullLaplace):
+ """Class for subnetwork Laplace, which computes the Laplace approximation over
+ just a subset of the model parameters (i.e. a subnetwork within the neural network).
+ Subnetwork Laplace only supports a full Hessian approximation; other Hessian
+ approximations could be used in theory, but would not make as much sense conceptually.
A Laplace approximation is represented by a MAP which is given by the
`model` parameter and a posterior precision or covariance specifying
a Gaussian distribution \\(\\mathcal{N}(\\theta_{MAP}, P^{-1})\\).
- Here, only the parameters of a subnetwork of the neural network
- are treated probabilistically.
+ Here, only a subset of the model parameters (i.e. a subnetwork of the
+ neural network) are treated probabilistically.
The goal of this class is to compute the posterior precision \\(P\\)
which sums as
\\[
P = \\sum_{n=1}^N \\nabla^2_\\theta \\log p(\\mathcal{D}_n \\mid \\theta)
\\vert_{\\theta_{MAP}} + \\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}}.
\\]
- There is one subclass, which implements the only supported option of a full
- approximation to the log likelihood Hessian. The prior is assumed to be Gaussian and
- therefore we have a simple form for
+ The prior is assumed to be Gaussian and therefore we have a simple form for
\\(\\nabla^2_\\theta \\log p(\\theta) \\vert_{\\theta_{MAP}} = P_0 \\).
In particular, we assume a scalar or diagonal prior precision so that in
all cases \\(P_0 = \\textrm{diag}(p_0)\\) and the structure of \\(p_0\\) can be varied.
+ The subnetwork Laplace approximation only supports a full, i.e., dense, log likelihood
+ Hessian approximation and hence posterior precision. Based on the chosen `backend`
+ parameter, the full approximation can be, for example, a generalized Gauss-Newton
+ matrix. Mathematically, we have \\(P \\in \\mathbb{R}^{P \\times P}\\).
+ See `FullLaplace` and `BaseLaplace` for the full interface.
+
Parameters
----------
model : torch.nn.Module or `laplace.feature_extractor.FeatureExtractor`
@@ -54,6 +59,9 @@ class SubnetLaplace(ParametricLaplace):
arguments passed to the backend on initialization, for example to
set the number of MC samples for stochastic approximations.
"""
+ # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure)
+ _key = ('subnetwork', 'full')
+
def __init__(self, model, likelihood, subnetwork_mask=None, sigma_noise=1., prior_precision=1.,
prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None):
super().__init__(model, likelihood, sigma_noise=sigma_noise, prior_precision=prior_precision,
@@ -62,6 +70,9 @@ def __init__(self, model, likelihood, subnetwork_mask=None, sigma_noise=1., prio
self.subnetwork_mask = subnetwork_mask
self.n_params_subnet = len(self.subnetwork_mask)
+ def _init_H(self):
+ self.H = torch.zeros(self.n_params_subnet, self.n_params_subnet, device=self._device)
+
@property
def subnetwork_mask(self):
return self._subnetwork_mask
@@ -116,22 +127,3 @@ def prior_precision_diag(self):
else:
raise ValueError('Mismatch of prior and model. Diagonal or scalar prior.')
-
-
-class FullSubnetLaplace(SubnetLaplace, FullLaplace):
- """Subnetwork Laplace approximation with full, i.e., dense, log likelihood Hessian approximation
- and hence posterior precision. Based on the chosen `backend` parameter, the full
- approximation can be, for example, a generalized Gauss-Newton matrix.
- Mathematically, we have \\(P \\in \\mathbb{R}^{P \\times P}\\).
- See `FullLaplace`, `LLLaplace`, and `BaseLaplace` for the full interface.
- """
- # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure)
- _key = ('subnetwork', 'full')
-
- def __init__(self, model, likelihood, subnetwork_mask=None, sigma_noise=1., prior_precision=1.,
- prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None):
- super().__init__(model, likelihood, subnetwork_mask, sigma_noise, prior_precision,
- prior_mean, temperature, backend, backend_kwargs)
-
- def _init_H(self):
- self.H = torch.zeros(self.n_params_subnet, self.n_params_subnet, device=self._device)
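---
With the consolidation into a single `FullSubnetLaplace` class, instantiation looks roughly as follows (a sketch against the state of the code after this patch; the toy model and mask are placeholders, and fitting additionally needs a real `train_loader`):

    import torch
    from torch import nn
    from laplace import FullSubnetLaplace  # exported from laplace/__init__.py as of this patch

    model = nn.Linear(3, 2)  # 8 parameters in total
    n_params = sum(p.numel() for p in model.parameters())
    mask = torch.zeros(n_params, dtype=torch.long)
    mask[:4] = 1  # treat the first 4 parameters probabilistically
    la = FullSubnetLaplace(model, 'regression', subnetwork_mask=mask)
    # la.fit(train_loader) would then build a (4, 4) Hessian via _init_H
---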
From 57f46d2b4f9909eecd1117635ffb6bf3fb7f8da5 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Fri, 10 Dec 2021 08:16:39 +0000
Subject: [PATCH 05/49] Make jacobians and last_layer_jacobians non-static and
 adapt code accordingly (incl. tests)
---
laplace/baselaplace.py | 2 +-
laplace/curvature/asdl.py | 10 +++-------
laplace/curvature/backpack.py | 10 +++-------
laplace/curvature/curvature.py | 11 +++--------
tests/test_jacobians.py | 25 +++++++++++++++----------
5 files changed, 25 insertions(+), 33 deletions(-)
diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py
index 4b9ce704..527a130c 100644
--- a/laplace/baselaplace.py
+++ b/laplace/baselaplace.py
@@ -560,7 +560,7 @@ def predictive_samples(self, x, pred_type='glm', n_samples=100):
@torch.enable_grad()
def _glm_predictive_distribution(self, X):
- Js, f_mu = self.backend.jacobians(self.model, X, self.backend.subnetwork_indices)
+ Js, f_mu = self.backend.jacobians(self.model, X)
f_var = self.functional_variance(Js)
return f_mu.detach(), f_var.detach()
diff --git a/laplace/curvature/asdl.py b/laplace/curvature/asdl.py
index afe3e673..2d9dc0bb 100644
--- a/laplace/curvature/asdl.py
+++ b/laplace/curvature/asdl.py
@@ -20,8 +20,7 @@ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None)
raise ValueError('This backend only supports classification currently.')
super().__init__(model, likelihood, last_layer, subnetwork_indices)
- @staticmethod
- def jacobians(model, x, subnetwork_indices=None):
+ def jacobians(self, model, x):
"""Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\)
using asdfghjkl's gradient per output dimension.
@@ -30,9 +29,6 @@ def jacobians(model, x, subnetwork_indices=None):
model : torch.nn.Module
x : torch.Tensor
input data `(batch, input_shape)` on compatible device with model.
- subnetwork_indices : torch.Tensor, default=None
- indices of the vectorized model parameters that define the subnetwork
- to apply the Laplace approximation over
Returns
-------
@@ -48,8 +44,8 @@ def loss_fn(outputs, targets):
f = batch_gradient(model, loss_fn, x, None).detach()
Jk = _get_batch_grad(model)
- if subnetwork_indices is not None:
- Jk = Jk[:, subnetwork_indices]
+ if self.subnetwork_indices is not None:
+ Jk = Jk[:, self.subnetwork_indices]
Js.append(Jk)
Js = torch.stack(Js, dim=1)
return Js, f
diff --git a/laplace/curvature/backpack.py b/laplace/curvature/backpack.py
index 6e655944..42599729 100644
--- a/laplace/curvature/backpack.py
+++ b/laplace/curvature/backpack.py
@@ -16,8 +16,7 @@ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None)
extend(self._model)
extend(self.lossfunc)
- @staticmethod
- def jacobians(model, x, subnetwork_indices=None):
+ def jacobians(self, model, x):
"""Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\)
using backpack's BatchGrad per output dimension.
@@ -26,9 +25,6 @@ def jacobians(model, x, subnetwork_indices=None):
model : torch.nn.Module
x : torch.Tensor
input data `(batch, input_shape)` on compatible device with model.
- subnetwork_indices : torch.Tensor, default=None
- indices of the vectorized model parameters that define the subnetwork
- to apply the Laplace approximation over
Returns
-------
@@ -52,8 +48,8 @@ def jacobians(model, x, subnetwork_indices=None):
to_cat.append(param.grad_batch.detach().reshape(x.shape[0], -1))
delattr(param, 'grad_batch')
Jk = torch.cat(to_cat, dim=1)
- if subnetwork_indices is not None:
- Jk = Jk[:, subnetwork_indices]
+ if self.subnetwork_indices is not None:
+ Jk = Jk[:, self.subnetwork_indices]
to_stack.append(Jk)
if i == 0:
f = out.detach()
diff --git a/laplace/curvature/curvature.py b/laplace/curvature/curvature.py
index 735fd74d..373bf398 100644
--- a/laplace/curvature/curvature.py
+++ b/laplace/curvature/curvature.py
@@ -44,8 +44,7 @@ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None)
def _model(self):
return self.model.last_layer if self.last_layer else self.model
- @staticmethod
- def jacobians(model, x, subnetwork_indices=None):
+ def jacobians(self, model, x):
"""Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\).
Parameters
@@ -53,9 +52,6 @@ def jacobians(model, x, subnetwork_indices=None):
model : torch.nn.Module
x : torch.Tensor
input data `(batch, input_shape)` on compatible device with model.
- subnetwork_indices : torch.Tensor, default=None
- indices of the vectorized model parameters that define the subnetwork
- to apply the Laplace approximation over
Returns
-------
@@ -66,8 +62,7 @@ def jacobians(model, x, subnetwork_indices=None):
"""
raise NotImplementedError
- @staticmethod
- def last_layer_jacobians(model, x):
+ def last_layer_jacobians(self, model, x):
"""Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).
@@ -249,7 +244,7 @@ def full(self, x, y, **kwargs):
if self.last_layer:
Js, f = self.last_layer_jacobians(self.model, x)
else:
- Js, f = self.jacobians(self.model, x, self.subnetwork_indices)
+ Js, f = self.jacobians(self.model, x)
loss, H_ggn = self._get_full_ggn(Js, f, y)
return loss, H_ggn
diff --git a/tests/test_jacobians.py b/tests/test_jacobians.py
index 7a5a22ef..8f676db1 100644
--- a/tests/test_jacobians.py
+++ b/tests/test_jacobians.py
@@ -35,9 +35,10 @@ def X():
return torch.randn(200, 3)
-@pytest.mark.parametrize('backend', [AsdlInterface, BackPackInterface])
-def test_linear_jacobians(linear_model, X, backend):
+@pytest.mark.parametrize('backend_cls', [AsdlInterface, BackPackInterface])
+def test_linear_jacobians(linear_model, X, backend_cls):
# jacobian of linear model is input X.
+ backend = backend_cls(linear_model, 'classification')
Js, f = backend.jacobians(linear_model, X)
# into Jacs shape (batch_size, output_size, params)
true_Js = X.reshape(len(X), 1, -1)
@@ -46,9 +47,10 @@ def test_linear_jacobians(linear_model, X, backend):
assert torch.allclose(f, linear_model(X), atol=1e-5)
-@pytest.mark.parametrize('backend', [AsdlInterface, BackPackInterface])
-def test_jacobians_singleoutput(singleoutput_model, X, backend):
+@pytest.mark.parametrize('backend_cls', [AsdlInterface, BackPackInterface])
+def test_jacobians_singleoutput(singleoutput_model, X, backend_cls):
model = singleoutput_model
+ backend = backend_cls(model, 'classification')
Js, f = backend.jacobians(model, X)
Js_naive, f_naive = jacobians_naive(model, X)
assert Js.shape == Js_naive.shape
@@ -57,9 +59,10 @@ def test_jacobians_singleoutput(singleoutput_model, X, backend):
assert torch.allclose(f, f_naive)
-@pytest.mark.parametrize('backend', [AsdlInterface, BackPackInterface])
-def test_jacobians_multioutput(multioutput_model, X, backend):
+@pytest.mark.parametrize('backend_cls', [AsdlInterface, BackPackInterface])
+def test_jacobians_multioutput(multioutput_model, X, backend_cls):
model = multioutput_model
+ backend = backend_cls(model, 'classification')
Js, f = backend.jacobians(model, X)
Js_naive, f_naive = jacobians_naive(model, X)
assert Js.shape == Js_naive.shape
@@ -68,9 +71,10 @@ def test_jacobians_multioutput(multioutput_model, X, backend):
assert torch.allclose(f, f_naive)
-@pytest.mark.parametrize('backend', [AsdlInterface, BackPackInterface])
-def test_last_layer_jacobians_singleoutput(singleoutput_model, X, backend):
+@pytest.mark.parametrize('backend_cls', [AsdlInterface, BackPackInterface])
+def test_last_layer_jacobians_singleoutput(singleoutput_model, X, backend_cls):
model = FeatureExtractor(singleoutput_model)
+ backend = backend_cls(model, 'classification')
Js, f = backend.last_layer_jacobians(model, X)
_, phi = model.forward_with_features(X)
Js_naive, f_naive = jacobians_naive(model.last_layer, phi)
@@ -80,9 +84,10 @@ def test_last_layer_jacobians_singleoutput(singleoutput_model, X, backend):
assert torch.allclose(f, f_naive)
-@pytest.mark.parametrize('backend', [AsdlInterface, BackPackInterface])
-def test_last_layer_jacobians_multioutput(multioutput_model, X, backend):
+@pytest.mark.parametrize('backend_cls', [AsdlInterface, BackPackInterface])
+def test_last_layer_jacobians_multioutput(multioutput_model, X, backend_cls):
model = FeatureExtractor(multioutput_model)
+ backend = backend_cls(model, 'classification')
Js, f = backend.last_layer_jacobians(model, X)
_, phi = model.forward_with_features(X)
Js_naive, f_naive = jacobians_naive(model.last_layer, phi)
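---
The test changes above illustrate the new calling convention: since `jacobians` is now an instance method, a backend object must be constructed first so that `self.subnetwork_indices` is available inside it. A minimal sketch (toy model and data are placeholders):

    import torch
    from torch import nn
    from laplace.curvature import BackPackInterface

    model = nn.Linear(3, 2)
    backend = BackPackInterface(model, 'classification')  # jacobians is now a bound method
    X = torch.randn(5, 3)
    Js, f = backend.jacobians(model, X)
    print(Js.shape)  # (5, 2, 8): batch x outputs x parameters
---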
From 253012204613b699186c6ec80fc79d0f7315e24e Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Fri, 10 Dec 2021 15:12:51 +0000
Subject: [PATCH 06/49] Add SubnetMask baseclass and subclasses for random,
largest magnitude, and last-layer subnet masks
---
laplace/subnetmask.py | 193 ++++++++++++++++++++++++++++++++++++++++++
1 file changed, 193 insertions(+)
create mode 100644 laplace/subnetmask.py
diff --git a/laplace/subnetmask.py b/laplace/subnetmask.py
new file mode 100644
index 00000000..e85790fc
--- /dev/null
+++ b/laplace/subnetmask.py
@@ -0,0 +1,193 @@
+import torch
+from torch.nn.utils import parameters_to_vector
+
+from laplace.feature_extractor import FeatureExtractor
+
+__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LastLayerSubnetMask']
+
+
+class SubnetMask:
+ """Baseclass for all subnetwork masks in this library (for subnetwork Laplace).
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ """
+ def __init__(self, model):
+ self.model = model
+ self.parameter_vector = parameters_to_vector(self.model.parameters()).detach()
+ self._n_params = len(self.parameter_vector)
+ self._device = next(self.model.parameters()).device
+ self._indices = None
+ self._n_params_subnet = None
+
+ @property
+ def n_params_subnet(self):
+ raise NotImplementedError
+
+ def _check_select(self):
+ if self._indices is None:
+ raise AttributeError('Subnetwork mask not selected. Run select() first.')
+
+ @property
+ def indices(self):
+ self._check_select()
+ return self._indices
+
+ def convert_subnet_mask_to_indices(self, subnet_mask):
+ """Converts a subnetwork mask into subnetwork indices.
+
+ Parameters
+ ----------
+ subnet_mask : torch.Tensor
+ a binary vector of size (n_params) where 1s locate the subnetwork parameters
+ within the vectorized model parameters
+
+ Returns
+ -------
+ subnet_mask_indices : torch.Tensor
+ a vector of indices of the vectorized model parameters that define the subnetwork
+ """
+ if not isinstance(subnet_mask, torch.Tensor):
+ raise ValueError('Subnetwork mask needs to be torch.Tensor!')
+ elif subnet_mask.type() not in ['torch.ByteTensor', 'torch.IntTensor', 'torch.LongTensor'] or\
+ len(subnet_mask.shape) != 1:
+ raise ValueError('Subnetwork mask needs to be 1-dimensional torch.{Byte,Int,Long}Tensor!')
+ elif len(subnet_mask) != self._n_params or\
+ len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1]) != self._n_params:
+ raise ValueError('Subnetwork mask needs to be a binary vector of size (n_params) where 1s '\
+ 'locate the subnetwork parameters within the vectorized model parameters!')
+
+ subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
+ return subnet_mask_indices
+
+ def select(self, train_loader):
+ """ Select the subnetwork mask.
+
+ Parameters
+ ----------
+ train_loader : torch.data.utils.DataLoader
+ each iterate is a training batch (X, y);
+ `train_loader.dataset` needs to be set to access \\(N\\), size of the data set
+ """
+ if self._indices is not None:
+ raise ValueError('Subnetwork mask already selected.')
+
+ subnet_mask = self.get_subnet_mask(train_loader)
+ self._indices = self.convert_subnet_mask_to_indices(subnet_mask)
+
+ def get_subnet_mask(self, train_loader):
+ """ Get the subnetwork mask.
+
+ Parameters
+ ----------
+ train_loader : torch.data.utils.DataLoader
+ each iterate is a training batch (X, y);
+ `train_loader.dataset` needs to be set to access \\(N\\), size of the data set
+
+ Returns
+ -------
+ subnet_mask: torch.Tensor
+ a binary vector of size (n_params) where 1s locate the subnetwork parameters
+ within the vectorized model parameters
+ """
+ raise NotImplementedError
+
+
+class ScoreBasedSubnetMask(SubnetMask):
+ """Baseclass for subnetwork masks defined by selecting the top-scoring parameters according to some criterion.
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ n_params_subnet : int
+ the number of parameters in the subnetwork (i.e. the number of top-scoring parameters to select)
+ """
+ def __init__(self, model, n_params_subnet):
+ super().__init__(model)
+
+ if n_params_subnet is None:
+ raise ValueError('Need to pass number of subnetwork parameters when using subnetwork Laplace.')
+ if n_params_subnet > self._n_params:
+ raise ValueError(f'Subnetwork ({n_params_subnet}) cannot be larger than model ({self._n_params}).')
+ self._n_params_subnet = n_params_subnet
+ self._param_scores = None
+
+ @property
+ def n_params_subnet(self):
+ return self._n_params_subnet
+
+ def compute_param_scores(self, train_loader):
+ raise NotImplementedError
+
+ def _check_param_scores(self):
+ if self._param_scores.shape != self.parameter_vector.shape:
+ raise ValueError('Parameter scores need to be of same shape as parameter vector.')
+
+ def get_subnet_mask(self, train_loader):
+ """ Get the subnetwork mask by ranking parameters based on their scores."""
+
+ if self._param_scores is None:
+ self._param_scores = self.compute_param_scores(train_loader)
+ self._check_param_scores()
+
+ idx = torch.argsort(self._param_scores, descending=True)[:self._n_params_subnet]
+ idx = idx.sort()[0]
+ subnet_mask = torch.zeros_like(self.parameter_vector).byte()
+ subnet_mask[idx] = 1
+ return subnet_mask
+
+
+class RandomSubnetMask(ScoreBasedSubnetMask):
+ """Subnetwork mask of parameters sampled uniformly at random."""
+ def compute_param_scores(self, train_loader):
+ return torch.rand_like(self.parameter_vector)
+
+
+class LargestMagnitudeSubnetMask(ScoreBasedSubnetMask):
+ """Subnetwork mask identifying the parameters with the largest magnitude. """
+ def compute_param_scores(self, train_loader):
+ return self.parameter_vector.abs()
+
+
+class LastLayerSubnetMask(SubnetMask):
+ """Subnetwork mask corresponding to the last layer of the neural network.
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ last_layer_name: str, default=None
+ name of the model's last layer, if None it will be determined automatically
+ """
+ def __init__(self, model, last_layer_name=None):
+ super().__init__(model)
+ self.model = FeatureExtractor(self.model, last_layer_name=last_layer_name)
+ self._n_params_subnet = None
+
+ @property
+ def n_params_subnet(self):
+ if self._n_params_subnet is None:
+ self._check_select()
+ self._n_params_subnet = len(self._indices)
+ return self._n_params_subnet
+
+ def get_subnet_mask(self, train_loader):
+ """ Get the subnetwork mask identifying the last layer."""
+
+ self.model.eval()
+ if self.model.last_layer is None:
+ X, _ = next(iter(train_loader))
+ with torch.no_grad():
+ self.model.find_last_layer(X[:1].to(self._device))
+
+ subnet_mask_list = []
+ for name, layer in self.model.model.named_modules():
+ if len(list(layer.children())) > 0:
+ continue
+ if name == self.model._last_layer_name:
+ mask_method = torch.ones_like
+ else:
+ mask_method = torch.zeros_like
+ subnet_mask_list.append(mask_method(parameters_to_vector(layer.parameters())))
+ subnet_mask = torch.cat(subnet_mask_list).byte()
+ return subnet_mask
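---
Usage of the new mask classes is two-step: construct a mask for a model, then `select()` it on training data before reading `indices`. A sketch with a toy model and synthetic data (all placeholders):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from laplace.subnetmask import LargestMagnitudeSubnetMask

    model = nn.Linear(3, 2)
    train_loader = DataLoader(TensorDataset(torch.randn(8, 3), torch.randn(8, 2)),
                              batch_size=4)
    subnet_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=4)
    subnet_mask.select(train_loader)    # magnitude scores do not actually use the loader
    print(subnet_mask.indices)          # indices of the 4 largest-magnitude parameters
    print(subnet_mask.n_params_subnet)  # 4
---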
From c0be3f92c5fc8f5bc041c3a44f555e6928e5fcab Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Fri, 10 Dec 2021 15:13:47 +0000
Subject: [PATCH 07/49] Adapt FullSubnetLaplace to use new SubnetMask class
interface
---
laplace/subnetlaplace.py | 66 +++++++++++++++-------------------------
1 file changed, 25 insertions(+), 41 deletions(-)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 9749c1f6..01d8735e 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -41,7 +41,7 @@ class FullSubnetLaplace(FullLaplace):
model : torch.nn.Module or `laplace.feature_extractor.FeatureExtractor`
likelihood : {'classification', 'regression'}
determines the log likelihood Hessian approximation
- subnetwork_mask : torch.Tensor, default=None
+ subnetwork_mask : subclasses of `laplace.subnetmask.SubnetMask`, default=None
mask defining the subnetwork to apply the Laplace approximation over
sigma_noise : torch.Tensor or float, default=1
observation noise for the regression setting; must be 1 for classification
@@ -58,58 +58,24 @@ class FullSubnetLaplace(FullLaplace):
backend_kwargs : dict, default=None
arguments passed to the backend on initialization, for example to
set the number of MC samples for stochastic approximations.
+ subnetmask_kwargs : dict, default=None
+ arguments passed to the subnetwork mask on initialization.
"""
# key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure)
_key = ('subnetwork', 'full')
def __init__(self, model, likelihood, subnetwork_mask=None, sigma_noise=1., prior_precision=1.,
- prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None):
+ prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None, subnetmask_kwargs=None):
super().__init__(model, likelihood, sigma_noise=sigma_noise, prior_precision=prior_precision,
prior_mean=prior_mean, temperature=temperature, backend=backend,
backend_kwargs=backend_kwargs)
- self.subnetwork_mask = subnetwork_mask
- self.n_params_subnet = len(self.subnetwork_mask)
+ self._subnetmask_kwargs = dict() if subnetmask_kwargs is None else subnetmask_kwargs
+ self.subnetwork_mask = subnetwork_mask(self.model, **self._subnetmask_kwargs)
+ self.n_params_subnet = None
def _init_H(self):
self.H = torch.zeros(self.n_params_subnet, self.n_params_subnet, device=self._device)
- @property
- def subnetwork_mask(self):
- return self._subnetwork_mask
-
- @subnetwork_mask.setter
- def subnetwork_mask(self, subnetwork_mask):
- """Check validity of subnetwork mask and convert it to a vector of indices of the vectorized
- model parameters that define the subnetwork to apply the Laplace approximation over.
- """
- if isinstance(subnetwork_mask, torch.Tensor):
- if subnetwork_mask.type() not in ['torch.ByteTensor', 'torch.IntTensor', 'torch.LongTensor'] or\
- len(subnetwork_mask.shape) != 1:
- raise ValueError('Subnetwork mask needs to be 1-dimensional torch.{Byte,Int,Long}Tensor!')
-
- elif len(subnetwork_mask) == self.n_params and\
- len(subnetwork_mask[subnetwork_mask == 0]) +\
- len(subnetwork_mask[subnetwork_mask == 1]) == self.n_params:
- self._subnetwork_mask = subnetwork_mask.nonzero(as_tuple=True)[0]
-
- elif len(subnetwork_mask) <= self.n_params and\
- len(subnetwork_mask[subnetwork_mask >= self.n_params]) == 0:
- self._subnetwork_mask = subnetwork_mask
-
- else:
- raise ValueError('Subnetwork mask needs to identify the subnetwork parameters '\
- 'from the vectorized model parameters as:\n'\
- '1) a vector of indices of the subnetwork parameters, or\n'\
- '2) a binary vector of size (parameters) where 1s locate the subnetwork parameters.')
-
- elif subnetwork_mask is None:
- raise ValueError('Subnetwork Laplace requires passing a subnetwork mask!')
-
- else:
- raise ValueError('Subnetwork mask needs to be torch.Tensor!')
-
- self.backend.subnetwork_indices = self._subnetwork_mask
-
@property
def prior_precision_diag(self):
"""Obtain the diagonal prior precision \\(p_0\\) constructed from either
@@ -127,3 +93,21 @@ def prior_precision_diag(self):
else:
raise ValueError('Mismatch of prior and model. Diagonal or scalar prior.')
+
+ def fit(self, train_loader):
+ """Fit the local Laplace approximation at the parameters of the subnetwork.
+
+ Parameters
+ ----------
+ train_loader : torch.data.utils.DataLoader
+ each iterate is a training batch (X, y);
+ `train_loader.dataset` needs to be set to access \\(N\\), size of the data set
+ """
+
+ # select subnetwork and pass it to backend
+ self.subnetwork_mask.select(train_loader)
+ self.backend.subnetwork_indices = self.subnetwork_mask.indices
+ self.n_params_subnet = self.subnetwork_mask.n_params_subnet
+
+ # fit Laplace approximation over subnetwork
+ super().fit(train_loader)
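---
After this patch, the user passes a `SubnetMask` subclass (plus its constructor arguments) rather than a raw index tensor, and mask selection happens inside `fit()`. A sketch of the resulting workflow (toy model and data are placeholders):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from laplace import FullSubnetLaplace
    from laplace.subnetmask import RandomSubnetMask

    model = nn.Linear(3, 2)
    train_loader = DataLoader(TensorDataset(torch.randn(8, 3), torch.randn(8, 2)),
                              batch_size=4)
    la = FullSubnetLaplace(model, 'regression', subnetwork_mask=RandomSubnetMask,
                           subnetmask_kwargs=dict(n_params_subnet=4))
    la.fit(train_loader)  # selects the mask, hands indices to the backend, then fits
---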
From 257d33bd775f993b4f63dc61a1a40135cb572713 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Fri, 10 Dec 2021 15:54:47 +0000
Subject: [PATCH 08/49] Add support for largest variance subnet selection
(using diagonal Laplace)
---
laplace/subnetlaplace.py | 7 ++++++-
laplace/subnetmask.py | 25 +++++++++++++++++++++++--
2 files changed, 29 insertions(+), 3 deletions(-)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 01d8735e..f67403be 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -1,8 +1,9 @@
import torch
-from laplace.baselaplace import FullLaplace
+from laplace.baselaplace import FullLaplace, DiagLaplace
from laplace.curvature import BackPackGGN
+from laplace.subnetmask import LargestVarianceDiagLaplaceSubnetMask
__all__ = ['FullSubnetLaplace']
@@ -70,6 +71,10 @@ def __init__(self, model, likelihood, subnetwork_mask=None, sigma_noise=1., prio
prior_mean=prior_mean, temperature=temperature, backend=backend,
backend_kwargs=backend_kwargs)
self._subnetmask_kwargs = dict() if subnetmask_kwargs is None else subnetmask_kwargs
+ if subnetwork_mask == LargestVarianceDiagLaplaceSubnetMask:
+ # instantiate and pass diagonal Laplace model for largest variance subnetwork selection
+ self._subnetmask_kwargs.update(diag_laplace_model=DiagLaplace(self.model, likelihood, sigma_noise,
+ prior_precision, prior_mean, temperature, backend, backend_kwargs))
self.subnetwork_mask = subnetwork_mask(self.model, **self._subnetmask_kwargs)
self.n_params_subnet = None
diff --git a/laplace/subnetmask.py b/laplace/subnetmask.py
index e85790fc..b32a7ba4 100644
--- a/laplace/subnetmask.py
+++ b/laplace/subnetmask.py
@@ -3,7 +3,7 @@
from laplace.feature_extractor import FeatureExtractor
-__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LastLayerSubnetMask']
+__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LastLayerSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask']
class SubnetMask:
@@ -125,7 +125,7 @@ def _check_param_scores(self):
raise ValueError('Parameter scores need to be of same shape as parameter vector.')
def get_subnet_mask(self, train_loader):
- """ Get the subnetwork mask by ranking parameters based on their scores."""
+ """ Get the subnetwork mask by (descendingly) ranking parameters based on their scores."""
if self._param_scores is None:
self._param_scores = self.compute_param_scores(train_loader)
@@ -150,6 +150,27 @@ def compute_param_scores(self, train_loader):
return self.parameter_vector
+class LargestVarianceDiagLaplaceSubnetMask(ScoreBasedSubnetMask):
+ """Subnetwork mask identifying the parameters with the largest marginal variances
+ (estimated using a diagonal Laplace approximation over all model parameters).
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ n_params_subnet : int
+ the number of parameters in the subnetwork (i.e. the number of top-scoring parameters to select)
+ diag_laplace_model : `laplace.baselaplace.DiagLaplace`
+ diagonal Laplace model to use for variance estimation
+ """
+ def __init__(self, model, n_params_subnet, diag_laplace_model):
+ super().__init__(model, n_params_subnet)
+ self.diag_laplace_model = diag_laplace_model
+
+ def compute_param_scores(self, train_loader):
+ self.diag_laplace_model.fit(train_loader)
+ return self.diag_laplace_model.posterior_variance
+
+
class LastLayerSubnetMask(SubnetMask):
"""Subnetwork mask corresponding to the last layer of the neural network.
From 3771e94920c2e53c700dbc47115a706604df220b Mon Sep 17 00:00:00 2001
From: Alex Immer
Date: Fri, 10 Dec 2021 17:18:24 +0100
Subject: [PATCH 09/49] Remove change
---
laplace/curvature/asdl.py | 4 ----
1 file changed, 4 deletions(-)
diff --git a/laplace/curvature/asdl.py b/laplace/curvature/asdl.py
index 8bdc19ca..d65c4334 100644
--- a/laplace/curvature/asdl.py
+++ b/laplace/curvature/asdl.py
@@ -18,10 +18,6 @@
class AsdlInterface(CurvatureInterface):
"""Interface for asdfghjkl backend.
"""
- def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None):
- if likelihood != 'classification':
- raise ValueError('This backend only supports classification currently.')
- super().__init__(model, likelihood, last_layer, subnetwork_indices)
def jacobians(self, model, x):
"""Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\)
From 38cd0f67620df75398e4c1bbc0c6df8f7ef386fd Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Fri, 10 Dec 2021 16:21:32 +0000
Subject: [PATCH 10/49] Change FullSubnetLaplace to SubnetLaplace as it's the
only option
---
laplace/__init__.py | 4 ++--
laplace/subnetlaplace.py | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/laplace/__init__.py b/laplace/__init__.py
index d092d50a..429b9f0d 100644
--- a/laplace/__init__.py
+++ b/laplace/__init__.py
@@ -9,7 +9,7 @@
from laplace.baselaplace import BaseLaplace, ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace
from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace
-from laplace.subnetlaplace import FullSubnetLaplace
+from laplace.subnetlaplace import SubnetLaplace
from laplace.laplace import Laplace
from laplace.marglik_training import marglik_training
@@ -18,5 +18,5 @@
'FullLaplace', 'KronLaplace', 'DiagLaplace', 'LowRankLaplace', # all-weights
'LLLaplace', # base-class last-layer
'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace', # last-layer
- 'FullSubnetLaplace', # subnetwork
+ 'SubnetLaplace', # subnetwork
'marglik_training'] # methods
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index f67403be..90cd8aeb 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -6,10 +6,10 @@
from laplace.subnetmask import LargestVarianceDiagLaplaceSubnetMask
-__all__ = ['FullSubnetLaplace']
+__all__ = ['SubnetLaplace']
-class FullSubnetLaplace(FullLaplace):
+class SubnetLaplace(FullLaplace):
"""Class for subnetwork Laplace, which computes the Laplace approximation over
just a subset of the model parameters (i.e. a subnetwork within the neural network).
Subnetwork Laplace only supports a full Hessian approximation; other Hessian
From f933dad8ec6137e2a0921b6dd8d1baac854c7555 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Fri, 10 Dec 2021 16:25:00 +0000
Subject: [PATCH 11/49] Convert indentation from tabs to spaces
---
laplace/subnetmask.py | 314 +++++++++++++++++++++---------------------
1 file changed, 157 insertions(+), 157 deletions(-)
diff --git a/laplace/subnetmask.py b/laplace/subnetmask.py
index b32a7ba4..468734d4 100644
--- a/laplace/subnetmask.py
+++ b/laplace/subnetmask.py
@@ -7,77 +7,77 @@
class SubnetMask:
- """Baseclass for all subnetwork masks in this library (for subnetwork Laplace).
-
- Parameters
- ----------
- model : torch.nn.Module
- """
- def __init__(self, model):
- self.model = model
- self.parameter_vector = parameters_to_vector(self.model.parameters()).detach()
- self._n_params = len(self.parameter_vector)
- self._device = next(self.model.parameters()).device
- self._indices = None
- self._n_params_subnet = None
-
- @property
- def n_params_subnet(self):
- raise NotImplementedError
-
- def _check_select(self):
- if self._indices is None:
- raise AttributeError('Subnetwork mask not selected. Run select() first.')
-
- @property
- def indices(self):
- self._check_select()
- return self._indices
-
- def convert_subnet_mask_to_indices(self, subnet_mask):
- """Converts a subnetwork mask into subnetwork indices.
-
- Parameters
- ----------
- subnet_mask : torch.Tensor
- a binary vector of size (n_params) where 1s locate the subnetwork parameters
- within the vectorized model parameters
-
- Returns
- -------
- subnet_mask_indices : torch.Tensor
- a vector of indices of the vectorized model parameters that define the subnetwork
- """
- if not isinstance(subnet_mask, torch.Tensor):
- raise ValueError('Subnetwork mask needs to be torch.Tensor!')
- elif subnet_mask.type() not in ['torch.ByteTensor', 'torch.IntTensor', 'torch.LongTensor'] or\
+ """Baseclass for all subnetwork masks in this library (for subnetwork Laplace).
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ """
+ def __init__(self, model):
+ self.model = model
+ self.parameter_vector = parameters_to_vector(self.model.parameters()).detach()
+ self._n_params = len(self.parameter_vector)
+ self._device = next(self.model.parameters()).device
+ self._indices = None
+ self._n_params_subnet = None
+
+ @property
+ def n_params_subnet(self):
+ raise NotImplementedError
+
+ def _check_select(self):
+ if self._indices is None:
+ raise AttributeError('Subnetwork mask not selected. Run select() first.')
+
+ @property
+ def indices(self):
+ self._check_select()
+ return self._indices
+
+ def convert_subnet_mask_to_indices(self, subnet_mask):
+ """Converts a subnetwork mask into subnetwork indices.
+
+ Parameters
+ ----------
+ subnet_mask : torch.Tensor
+ a binary vector of size (n_params) where 1s locate the subnetwork parameters
+ within the vectorized model parameters
+
+ Returns
+ -------
+ subnet_mask_indices : torch.Tensor
+ a vector of indices of the vectorized model parameters that define the subnetwork
+ """
+ if not isinstance(subnet_mask, torch.Tensor):
+ raise ValueError('Subnetwork mask needs to be a torch.Tensor!')
+ elif subnet_mask.type() not in ['torch.ByteTensor', 'torch.IntTensor', 'torch.LongTensor'] or\
len(subnet_mask.shape) != 1:
- raise ValueError('Subnetwork mask needs to be a 1-dimensional torch.{Byte,Int,Long}Tensor!')
- elif len(subnet_mask) != self._n_params or\
+ raise ValueError('Subnetwork mask needs to be a 1-dimensional torch.{Byte,Int,Long}Tensor!')
+ elif len(subnet_mask) != self._n_params or\
len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1]) != self._n_params:
- raise ValueError('Subnetwork mask needs to be a binary vector of size (n_params) where 1s '\
- 'locate the subnetwork parameters within the vectorized model parameters!')
+ raise ValueError('Subnetwork mask needs to be a binary vector of size (n_params) where 1s '\
+ 'locate the subnetwork parameters within the vectorized model parameters!')
- subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
- return subnet_mask_indices
+ subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
+ return subnet_mask_indices
- def select(self, train_loader):
- """ Select the subnetwork mask.
+ def select(self, train_loader):
+ """ Select the subnetwork mask.
Parameters
----------
train_loader : torch.utils.data.DataLoader
each iterate is a training batch (X, y);
`train_loader.dataset` needs to be set to access \\(N\\), size of the data set
- """
- if self._indices is not None:
- raise ValueError('Subnetwork mask already selected.')
+ """
+ if self._indices is not None:
+ raise ValueError('Subnetwork mask already selected.')
- subnet_mask = self.get_subnet_mask(train_loader)
- self._indices = self.convert_subnet_mask_to_indices(subnet_mask)
+ subnet_mask = self.get_subnet_mask(train_loader)
+ self._indices = self.convert_subnet_mask_to_indices(subnet_mask)
- def get_subnet_mask(self, train_loader):
- """ Get the subnetwork mask.
+ def get_subnet_mask(self, train_loader):
+ """ Get the subnetwork mask.
Parameters
----------
@@ -85,130 +85,130 @@ def get_subnet_mask(self, train_loader):
each iterate is a training batch (X, y);
`train_loader.dataset` needs to be set to access \\(N\\), size of the data set
- Returns
- -------
- subnet_mask: torch.Tensor
- a binary vector of size (n_params) where 1s locate the subnetwork parameters
- within the vectorized model parameters
- """
- raise NotImplementedError
+ Returns
+ -------
+ subnet_mask: torch.Tensor
+ a binary vector of size (n_params) where 1s locate the subnetwork parameters
+ within the vectorized model parameters
+ """
+ raise NotImplementedError
class ScoreBasedSubnetMask(SubnetMask):
- """Baseclass for subnetwork masks defined by selecting the top-scoring parameters according to some criterion.
+ """Baseclass for subnetwork masks defined by selecting the top-scoring parameters according to some criterion.
- Parameters
- ----------
- model : torch.nn.Module
- n_params_subnet : int
- the number of parameters in the subnetwork (i.e. the number of top-scoring parameters to select)
- """
- def __init__(self, model, n_params_subnet):
- super().__init__(model)
+ Parameters
+ ----------
+ model : torch.nn.Module
+ n_params_subnet : int
+ the number of parameters in the subnetwork (i.e. the number of top-scoring parameters to select)
+ """
+ def __init__(self, model, n_params_subnet):
+ super().__init__(model)
- if n_params_subnet is None:
- raise ValueError('Need to pass number of subnetwork parameters when using subnetwork Laplace.')
- if n_params_subnet > self._n_params:
- raise ValueError(f'Subnetwork ({n_params_subnet}) cannot be larger than model ({self._n_params}).')
- self._n_params_subnet = n_params_subnet
- self._param_scores = None
+ if n_params_subnet is None:
+ raise ValueError('Need to pass number of subnetwork parameters when using subnetwork Laplace.')
+ if n_params_subnet > self._n_params:
+ raise ValueError(f'Subnetwork ({n_params_subnet}) cannot be larger than model ({self._n_params}).')
+ self._n_params_subnet = n_params_subnet
+ self._param_scores = None
- @property
- def n_params_subnet(self):
- return self._n_params_subnet
+ @property
+ def n_params_subnet(self):
+ return self._n_params_subnet
- def compute_param_scores(self, train_loader):
- raise NotImplementedError
+ def compute_param_scores(self, train_loader):
+ raise NotImplementedError
- def _check_param_scores(self):
- if self._param_scores.shape != self.parameter_vector.shape:
- raise ValueError('Parameter scores need to be of same shape as parameter vector.')
+ def _check_param_scores(self):
+ if self._param_scores.shape != self.parameter_vector.shape:
+ raise ValueError('Parameter scores need to be of same shape as parameter vector.')
- def get_subnet_mask(self, train_loader):
- """ Get the subnetwork mask by (descendingly) ranking parameters based on their scores."""
+ def get_subnet_mask(self, train_loader):
+ """ Get the subnetwork mask by (descendingly) ranking parameters based on their scores."""
- if self._param_scores is None:
- self._param_scores = self.compute_param_scores(train_loader)
- self._check_param_scores()
+ if self._param_scores is None:
+ self._param_scores = self.compute_param_scores(train_loader)
+ self._check_param_scores()
- idx = torch.argsort(self._param_scores, descending=True)[:self._n_params_subnet]
- idx = idx.sort()[0]
- subnet_mask = torch.zeros_like(self.parameter_vector).byte()
- subnet_mask[idx] = 1
- return subnet_mask
+ idx = torch.argsort(self._param_scores, descending=True)[:self._n_params_subnet]
+ idx = idx.sort()[0]
+ subnet_mask = torch.zeros_like(self.parameter_vector).byte()
+ subnet_mask[idx] = 1
+ return subnet_mask
class RandomSubnetMask(ScoreBasedSubnetMask):
- """Subnetwork mask of parameters sampled uniformly at random."""
- def compute_param_scores(self, train_loader):
- return torch.rand_like(self.parameter_vector)
+ """Subnetwork mask of parameters sampled uniformly at random."""
+ def compute_param_scores(self, train_loader):
+ return torch.rand_like(self.parameter_vector)
class LargestMagnitudeSubnetMask(ScoreBasedSubnetMask):
- """Subnetwork mask identifying the parameters with the largest magnitude. """
- def compute_param_scores(self, train_loader):
- return self.parameter_vector
+ """Subnetwork mask identifying the parameters with the largest magnitude. """
+ def compute_param_scores(self, train_loader):
+ return self.parameter_vector
class LargestVarianceDiagLaplaceSubnetMask(ScoreBasedSubnetMask):
- """Subnetwork mask identifying the parameters with the largest marginal variances
- (estimated using a diagional Laplace approximation over all model parameters).
-
- Parameters
- ----------
- model : torch.nn.Module
- n_params_subnet : int
- the number of parameters in the subnetwork (i.e. the number of top-scoring parameters to select)
+ """Subnetwork mask identifying the parameters with the largest marginal variances
+ (estimated using a diagional Laplace approximation over all model parameters).
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ n_params_subnet : int
+ the number of parameters in the subnetwork (i.e. the number of top-scoring parameters to select)
diag_laplace_model : `laplace.baselaplace.DiagLaplace`
diagonal Laplace model to use for variance estimation
- """
- def __init__(self, model, n_params_subnet, diag_laplace_model):
- super().__init__(model, n_params_subnet)
- self.diag_laplace_model = diag_laplace_model
+ """
+ def __init__(self, model, n_params_subnet, diag_laplace_model):
+ super().__init__(model, n_params_subnet)
+ self.diag_laplace_model = diag_laplace_model
- def compute_param_scores(self, train_loader):
- self.diag_laplace_model.fit(train_loader)
- return self.diag_laplace_model.posterior_variance
+ def compute_param_scores(self, train_loader):
+ self.diag_laplace_model.fit(train_loader)
+ return self.diag_laplace_model.posterior_variance
class LastLayerSubnetMask(SubnetMask):
- """Subnetwork mask corresponding to the last layer of the neural network.
+ """Subnetwork mask corresponding to the last layer of the neural network.
- Parameters
- ----------
- model : torch.nn.Module
+ Parameters
+ ----------
+ model : torch.nn.Module
last_layer_name: str, default=None
name of the model's last layer, if None it will be determined automatically
- """
- def __init__(self, model, last_layer_name=None):
- super().__init__(model)
- self.model = FeatureExtractor(self.model, last_layer_name=last_layer_name)
- self._n_params_subnet = None
-
- @property
- def n_params_subnet(self):
- if self._n_params_subnet is None:
- self._check_select()
- self._n_params_subnet = torch.count_nonzero(self._indices).item()
- return self._n_params_subnet
-
- def get_subnet_mask(self, train_loader):
- """ Get the subnetwork mask identifying the last layer."""
-
- self.model.eval()
- if self.model.last_layer is None:
- X, _ = next(iter(train_loader))
- with torch.no_grad():
- self.model.find_last_layer(X[:1].to(self._device))
-
- subnet_mask_list = []
- for name, layer in self.model.model.named_modules():
- if len(list(layer.children())) > 0:
- continue
- if name == self.model._last_layer_name:
- mask_method = torch.ones_like
- else:
- mask_method = torch.zeros_like
- subnet_mask_list.append(mask_method(parameters_to_vector(layer.parameters())))
- subnet_mask = torch.cat(subnet_mask_list).byte()
- return subnet_mask
+ """
+ def __init__(self, model, last_layer_name=None):
+ super().__init__(model)
+ self.model = FeatureExtractor(self.model, last_layer_name=last_layer_name)
+ self._n_params_subnet = None
+
+ @property
+ def n_params_subnet(self):
+ if self._n_params_subnet is None:
+ self._check_select()
+ self._n_params_subnet = torch.count_nonzero(self._indices).item()
+ return self._n_params_subnet
+
+ def get_subnet_mask(self, train_loader):
+ """ Get the subnetwork mask identifying the last layer."""
+
+ self.model.eval()
+ if self.model.last_layer is None:
+ X, _ = next(iter(train_loader))
+ with torch.no_grad():
+ self.model.find_last_layer(X[:1].to(self._device))
+
+ subnet_mask_list = []
+ for name, layer in self.model.model.named_modules():
+ if len(list(layer.children())) > 0:
+ continue
+ if name == self.model._last_layer_name:
+ mask_method = torch.ones_like
+ else:
+ mask_method = torch.zeros_like
+ subnet_mask_list.append(mask_method(parameters_to_vector(layer.parameters())))
+ subnet_mask = torch.cat(subnet_mask_list).byte()
+ return subnet_mask
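
As a minimal sketch of the flow implemented above (with a hypothetical toy model, not part of the patch): get_subnet_mask() produces a binary byte vector over the vectorized parameters, and convert_subnet_mask_to_indices() reduces it to the positions of its 1s.

    import torch
    from torch import nn
    from torch.nn.utils import parameters_to_vector

    model = nn.Linear(3, 2)  # 8 parameters once vectorized (6 weights + 2 biases)
    n_params = len(parameters_to_vector(model.parameters()))

    subnet_mask = torch.zeros(n_params).byte()  # binary mask over all parameters
    subnet_mask[torch.tensor([0, 2, 5])] = 1    # place three parameters in the subnetwork

    # what convert_subnet_mask_to_indices() computes internally
    subnet_indices = subnet_mask.nonzero(as_tuple=True)[0]
    print(subnet_indices)                       # tensor([0, 2, 5])
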
From ac4542c47489af328191e98dc434864e0b202367 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Fri, 10 Dec 2021 16:37:26 +0000
Subject: [PATCH 12/49] Remove model as argument from jacobians() as it has
access to self.model now (same for last_layer_jacobians)
---
laplace/baselaplace.py | 2 +-
laplace/curvature/asdl.py | 9 ++++-----
laplace/curvature/backpack.py | 5 ++---
laplace/curvature/curvature.py | 14 ++++++--------
laplace/lllaplace.py | 2 +-
tests/test_jacobians.py | 10 +++++-----
tests/test_lllaplace.py | 2 +-
7 files changed, 20 insertions(+), 24 deletions(-)
diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py
index 47449350..df2d2c27 100644
--- a/laplace/baselaplace.py
+++ b/laplace/baselaplace.py
@@ -592,7 +592,7 @@ def predictive_samples(self, x, pred_type='glm', n_samples=100):
@torch.enable_grad()
def _glm_predictive_distribution(self, X):
- Js, f_mu = self.backend.jacobians(self.model, X)
+ Js, f_mu = self.backend.jacobians(X)
f_var = self.functional_variance(Js)
return f_mu.detach(), f_var.detach()
diff --git a/laplace/curvature/asdl.py b/laplace/curvature/asdl.py
index d65c4334..25dcdcc9 100644
--- a/laplace/curvature/asdl.py
+++ b/laplace/curvature/asdl.py
@@ -19,13 +19,12 @@ class AsdlInterface(CurvatureInterface):
"""Interface for asdfghjkl backend.
"""
- def jacobians(self, model, x):
+ def jacobians(self, x):
"""Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\)
using asdfghjkl's gradient per output dimension.
Parameters
----------
- model : torch.nn.Module
x : torch.Tensor
input data `(batch, input_shape)` on compatible device with model.
@@ -37,12 +36,12 @@ def jacobians(self, model, x):
output function `(batch, outputs)`
"""
Js = list()
- for i in range(model.output_size):
+ for i in range(self.model.output_size):
def loss_fn(outputs, targets):
return outputs[:, i].sum()
- f = batch_gradient(model, loss_fn, x, None).detach()
- Jk = _get_batch_grad(model)
+ f = batch_gradient(self.model, loss_fn, x, None).detach()
+ Jk = _get_batch_grad(self.model)
if self.subnetwork_indices is not None:
Jk = Jk[:, self.subnetwork_indices]
Js.append(Jk)
diff --git a/laplace/curvature/backpack.py b/laplace/curvature/backpack.py
index 42599729..a0885800 100644
--- a/laplace/curvature/backpack.py
+++ b/laplace/curvature/backpack.py
@@ -16,13 +16,12 @@ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None)
extend(self._model)
extend(self.lossfunc)
- def jacobians(self, model, x):
+ def jacobians(self, x):
"""Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\)
using backpack's BatchGrad per output dimension.
Parameters
----------
- model : torch.nn.Module
x : torch.Tensor
input data `(batch, input_shape)` on compatible device with model.
@@ -33,7 +32,7 @@ def jacobians(self, model, x):
f : torch.Tensor
output function `(batch, outputs)`
"""
- model = extend(model)
+ model = extend(self.model)
to_stack = []
for i in range(model.output_size):
model.zero_grad()
diff --git a/laplace/curvature/curvature.py b/laplace/curvature/curvature.py
index 47d87730..98b703b7 100644
--- a/laplace/curvature/curvature.py
+++ b/laplace/curvature/curvature.py
@@ -44,12 +44,11 @@ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None)
def _model(self):
return self.model.last_layer if self.last_layer else self.model
- def jacobians(self, model, x):
+ def jacobians(self, x):
"""Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\).
Parameters
----------
- model : torch.nn.Module
x : torch.Tensor
input data `(batch, input_shape)` on compatible device with model.
@@ -62,13 +61,12 @@ def jacobians(self, model, x):
"""
raise NotImplementedError
- def last_layer_jacobians(self, model, x):
+ def last_layer_jacobians(self, x):
"""Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).
Parameters
----------
- model : laplace.feature_extractor.FeatureExtractor
x : torch.Tensor
Returns
@@ -78,7 +76,7 @@ def last_layer_jacobians(self, model, x):
f : torch.Tensor
output function `(batch, outputs)`
"""
- f, phi = model.forward_with_features(x)
+ f, phi = self.model.forward_with_features(x)
bsize = phi.shape[0]
output_size = f.shape[-1]
@@ -86,7 +84,7 @@ def last_layer_jacobians(self, model, x):
identity = torch.eye(output_size, device=x.device).unsqueeze(0).tile(bsize, 1, 1)
# Jacobians are batch x output x params
Js = torch.einsum('kp,kij->kijp', phi, identity).reshape(bsize, output_size, -1)
- if model.last_layer.bias is not None:
+ if self.model.last_layer.bias is not None:
Js = torch.cat([Js, identity], dim=2)
return Js, f.detach()
@@ -242,9 +240,9 @@ def full(self, x, y, **kwargs):
raise ValueError('Stochastic approximation not implemented for full GGN.')
if self.last_layer:
- Js, f = self.last_layer_jacobians(self.model, x)
+ Js, f = self.last_layer_jacobians(x)
else:
- Js, f = self.jacobians(self.model, x)
+ Js, f = self.jacobians(x)
loss, H_ggn = self._get_full_ggn(Js, f, y)
return loss, H_ggn
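
The einsum in last_layer_jacobians() encodes that, for a linear last layer f = W phi + b, the Jacobian with respect to vec(W) is the outer product of the features phi with an identity over the outputs, plus an identity block for the bias. A standalone shape check of that construction, with illustrative dimensions:

    import torch

    bsize, feat_dim, output_size = 4, 5, 3
    phi = torch.randn(bsize, feat_dim)  # last-layer features, as from forward_with_features
    identity = torch.eye(output_size).unsqueeze(0).tile(bsize, 1, 1)

    # batch x outputs x weight-params, mirroring the einsum above
    Js = torch.einsum('kp,kij->kijp', phi, identity).reshape(bsize, output_size, -1)
    Js = torch.cat([Js, identity], dim=2)  # bias block, if the layer has a bias

    assert Js.shape == (bsize, output_size, output_size * feat_dim + output_size)
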
diff --git a/laplace/lllaplace.py b/laplace/lllaplace.py
index 8336e670..8053e519 100644
--- a/laplace/lllaplace.py
+++ b/laplace/lllaplace.py
@@ -115,7 +115,7 @@ def fit(self, train_loader, override=True):
self.mean = parameters_to_vector(self.model.last_layer.parameters()).detach()
def _glm_predictive_distribution(self, X):
- Js, f_mu = self.backend.last_layer_jacobians(self.model, X)
+ Js, f_mu = self.backend.last_layer_jacobians(X)
f_var = self.functional_variance(Js)
return f_mu.detach(), f_var.detach()
diff --git a/tests/test_jacobians.py b/tests/test_jacobians.py
index 8f676db1..45cd2f37 100644
--- a/tests/test_jacobians.py
+++ b/tests/test_jacobians.py
@@ -39,7 +39,7 @@ def X():
def test_linear_jacobians(linear_model, X, backend_cls):
# jacobian of linear model is input X.
backend = backend_cls(linear_model, 'classification')
- Js, f = backend.jacobians(linear_model, X)
+ Js, f = backend.jacobians(X)
# into Jacs shape (batch_size, output_size, params)
true_Js = X.reshape(len(X), 1, -1)
assert true_Js.shape == Js.shape
@@ -51,7 +51,7 @@ def test_linear_jacobians(linear_model, X, backend_cls):
def test_jacobians_singleoutput(singleoutput_model, X, backend_cls):
model = singleoutput_model
backend = backend_cls(model, 'classification')
- Js, f = backend.jacobians(model, X)
+ Js, f = backend.jacobians(X)
Js_naive, f_naive = jacobians_naive(model, X)
assert Js.shape == Js_naive.shape
assert torch.abs(Js-Js_naive).max() < 1e-6
@@ -63,7 +63,7 @@ def test_jacobians_singleoutput(singleoutput_model, X, backend_cls):
def test_jacobians_multioutput(multioutput_model, X, backend_cls):
model = multioutput_model
backend = backend_cls(model, 'classification')
- Js, f = backend.jacobians(model, X)
+ Js, f = backend.jacobians(X)
Js_naive, f_naive = jacobians_naive(model, X)
assert Js.shape == Js_naive.shape
assert torch.abs(Js-Js_naive).max() < 1e-6
@@ -75,7 +75,7 @@ def test_jacobians_multioutput(multioutput_model, X, backend_cls):
def test_last_layer_jacobians_singleoutput(singleoutput_model, X, backend_cls):
model = FeatureExtractor(singleoutput_model)
backend = backend_cls(model, 'classification')
- Js, f = backend.last_layer_jacobians(model, X)
+ Js, f = backend.last_layer_jacobians(X)
_, phi = model.forward_with_features(X)
Js_naive, f_naive = jacobians_naive(model.last_layer, phi)
assert Js.shape == Js_naive.shape
@@ -88,7 +88,7 @@ def test_last_layer_jacobians_singleoutput(singleoutput_model, X, backend_cls):
def test_last_layer_jacobians_multioutput(multioutput_model, X, backend_cls):
model = FeatureExtractor(multioutput_model)
backend = backend_cls(model, 'classification')
- Js, f = backend.last_layer_jacobians(model, X)
+ Js, f = backend.last_layer_jacobians(X)
_, phi = model.forward_with_features(X)
Js_naive, f_naive = jacobians_naive(model.last_layer, phi)
assert Js.shape == Js_naive.shape
diff --git a/tests/test_lllaplace.py b/tests/test_lllaplace.py
index 8b565687..40068d15 100644
--- a/tests/test_lllaplace.py
+++ b/tests/test_lllaplace.py
@@ -261,7 +261,7 @@ def test_laplace_functionality(laplace, lh, model, reg_loader, class_loader):
Js, f = jacobians_naive(feature_extractor.last_layer, phi)
true_f_var = torch.einsum('mkp,pq,mcq->mkc', Js, Sigma, Js)
# test last-layer Jacobians
- comp_Js, comp_f = lap.backend.last_layer_jacobians(lap.model, X)
+ comp_Js, comp_f = lap.backend.last_layer_jacobians(X)
assert torch.allclose(Js, comp_Js)
assert torch.allclose(f, comp_f)
comp_f_var = lap.functional_variance(comp_Js)
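
With this patch the backends resolve the network from self.model, so callers pass only the inputs. A usage sketch with the BackPackGGN backend; the toy model and the hand-set output_size attribute (which jacobians() iterates over) are assumptions of the sketch, not part of the patch:

    import torch
    from torch import nn
    from laplace.curvature import BackPackGGN

    model = nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2))
    setattr(model, 'output_size', 2)  # jacobians() loops over range(self.model.output_size)
    X = torch.randn(8, 3)

    backend = BackPackGGN(model, 'regression')
    Js, f = backend.jacobians(X)  # previously: backend.jacobians(model, X)
    assert Js.shape == (8, 2, sum(p.numel() for p in model.parameters()))
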
From 6a88c8254e16beb1cf7e9133b876c3dad2b8a984 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Fri, 10 Dec 2021 20:42:50 +0000
Subject: [PATCH 13/49] Minor fixes for SubnetLaplace
---
laplace/subnetlaplace.py | 2 +-
laplace/subnetmask.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 90cd8aeb..67c6eb0e 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -65,7 +65,7 @@ class SubnetLaplace(FullLaplace):
# key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure)
_key = ('subnetwork', 'full')
- def __init__(self, model, likelihood, subnetwork_mask=None, sigma_noise=1., prior_precision=1.,
+ def __init__(self, model, likelihood, subnetwork_mask, sigma_noise=1., prior_precision=1.,
prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None, subnetmask_kwargs=None):
super().__init__(model, likelihood, sigma_noise=sigma_noise, prior_precision=prior_precision,
prior_mean=prior_mean, temperature=temperature, backend=backend,
diff --git a/laplace/subnetmask.py b/laplace/subnetmask.py
index 468734d4..45bb2c3e 100644
--- a/laplace/subnetmask.py
+++ b/laplace/subnetmask.py
@@ -189,7 +189,7 @@ def __init__(self, model, last_layer_name=None):
def n_params_subnet(self):
if self._n_params_subnet is None:
self._check_select()
- self._n_params_subnet = torch.count_nonzero(self._indices).item()
+ self._n_params_subnet = len(self._indices)
return self._n_params_subnet
def get_subnet_mask(self, train_loader):
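
The second hunk is worth spelling out: self._indices holds index values rather than a binary mask, so torch.count_nonzero(self._indices) undercounts whenever parameter index 0 is part of the subnetwork. In isolation:

    import torch

    indices = torch.tensor([0, 3, 7])           # subnetwork indices, not a mask
    print(torch.count_nonzero(indices).item())  # 2 -- index 0 is wrongly dropped
    print(len(indices))                         # 3 -- the actual subnetwork size
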
From c47c7d078af7451fa382940cc961056fab1c1b41 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Fri, 10 Dec 2021 20:43:14 +0000
Subject: [PATCH 14/49] Add tests for SubnetLaplace and SubnetMasks
---
tests/test_subnetlaplace.py | 205 ++++++++++++++++++++++++++++++++++++
1 file changed, 205 insertions(+)
create mode 100644 tests/test_subnetlaplace.py
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
new file mode 100644
index 00000000..83162cfc
--- /dev/null
+++ b/tests/test_subnetlaplace.py
@@ -0,0 +1,205 @@
+import pytest
+from itertools import product
+
+import torch
+from torch import nn
+from torch.nn.utils import parameters_to_vector
+from torch.utils.data import DataLoader, TensorDataset
+
+from laplace import Laplace, SubnetLaplace
+from laplace.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LastLayerSubnetMask, LargestVarianceDiagLaplaceSubnetMask
+
+
+torch.manual_seed(240)
+torch.set_default_tensor_type(torch.DoubleTensor)
+score_based_subnet_masks = [RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask]
+likelihoods = ['classification', 'regression']
+
+
+@pytest.fixture
+def model():
+ model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2))
+ model_params = list(model.parameters())
+ setattr(model, 'n_params', len(parameters_to_vector(model_params)))
+ return model
+
+
+@pytest.fixture
+def class_loader():
+ X = torch.randn(10, 3)
+ y = torch.randint(2, (10,))
+ return DataLoader(TensorDataset(X, y), batch_size=3)
+
+
+@pytest.fixture
+def reg_loader():
+ X = torch.randn(10, 3)
+ y = torch.randn(10, 2)
+ return DataLoader(TensorDataset(X, y), batch_size=3)
+
+
+@pytest.mark.parametrize('likelihood', likelihoods)
+def test_subnet_laplace_init(model, likelihood):
+ # use last-layer subnet mask for this test
+ subnetwork_mask = LastLayerSubnetMask
+
+ # subnet Laplace with full Hessian should work
+ hessian_structure = 'full'
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
+ assert isinstance(lap, SubnetLaplace)
+
+ # subnet Laplace with diag, kron or lowrank Hessians should raise errors
+ hessian_structure = 'diag'
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
+ hessian_structure = 'kron'
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
+ hessian_structure = 'lowrank'
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
+
+
+@pytest.mark.parametrize('subnetwork_mask,likelihood', product(score_based_subnet_masks, likelihoods))
+def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_loader, reg_loader):
+ loader = class_loader if likelihood == 'classification' else reg_loader
+
+ # should raise error if we don't pass number of subnet parameters within the subnetmask_kwargs
+ subnetmask_kwargs = dict()
+ with pytest.raises(TypeError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+
+ # should raise error if we set number of subnet parameters to None
+ subnetmask_kwargs = dict(n_params_subnet=None)
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+
+ # should raise error if we set number of subnet parameters to be larger than number of model parameters
+ subnetmask_kwargs = dict(n_params_subnet=99999)
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+
+ # define valid subnet Laplace model
+ n_params_subnet = 32
+ subnetmask_kwargs = dict(n_params_subnet=n_params_subnet)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+
+ # should raise error if we try to access the subnet indices before the subnet has been selected
+ with pytest.raises(AttributeError):
+ lap.subnetwork_mask.indices
+
+ # select subnet mask
+ lap.subnetwork_mask.select(loader)
+
+ # should raise error if we try to select the subnet again
+ with pytest.raises(ValueError):
+ lap.subnetwork_mask.select(loader)
+
+ # re-define valid subnet Laplace model
+ n_params_subnet = 32
+ subnetmask_kwargs = dict(n_params_subnet=n_params_subnet)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+
+ # fit Laplace model (which internally selects the subnet mask)
+ lap.fit(loader)
+
+ # check some parameters
+ assert lap.subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
+ assert lap.subnetwork_mask.n_params_subnet == n_params_subnet
+ assert lap.n_params_subnet == n_params_subnet
+
+ # check that Hessian and prior precision are of correct shape
+ assert lap.H.shape == (n_params_subnet, n_params_subnet)
+ assert lap.prior_precision_diag.shape == (n_params_subnet,)
+
+ # should raise error if we try to fit the Laplace model again
+ with pytest.raises(ValueError):
+ lap.fit(loader)
+
+
+@pytest.mark.parametrize('likelihood', likelihoods)
+def test_last_layer_subnet_mask(model, likelihood, class_loader, reg_loader):
+ subnetwork_mask = LastLayerSubnetMask
+ loader = class_loader if likelihood == 'classification' else reg_loader
+
+ # should raise error if we pass number of subnet parameters
+ subnetmask_kwargs = dict(n_params_subnet=32)
+ with pytest.raises(TypeError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+
+ # should raise error if we pass invalid last-layer name
+ subnetmask_kwargs = dict(last_layer_name='123')
+ with pytest.raises(KeyError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+
+ # define valid last-layer subnet Laplace model (without passing the last-layer name)
+ subnetmask_kwargs = dict()
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+
+ # define valid last-layer subnet Laplace model (with passing the last-layer name)
+ subnetmask_kwargs = dict(last_layer_name='1')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+
+ # should raise error if we access number of subnet parameters before selecting the subnet
+ with pytest.raises(AttributeError):
+ n_params_subnet = lap.subnetwork_mask.n_params_subnet
+
+ # fit Laplace model
+ lap.fit(loader)
+
+ # check some parameters
+ n_params_subnet = 42
+ assert lap.subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
+ assert lap.subnetwork_mask.n_params_subnet == n_params_subnet
+ assert lap.n_params_subnet == n_params_subnet
+
+ # check that Hessian and prior precision are of correct shape
+ assert lap.H.shape == (n_params_subnet, n_params_subnet)
+ assert lap.prior_precision_diag.shape == (n_params_subnet,)
+
+ # check that Hessian is identical to that of a full LLLaplace model
+ lllap = Laplace(model, likelihood=likelihood, subset_of_weights='last_layer', hessian_structure='full')
+ lllap.fit(loader)
+ assert lllap.H.equal(lap.H)
+
+
+@pytest.mark.parametrize('likelihood', likelihoods)
+def test_full_subnet_mask(model, likelihood, class_loader, reg_loader):
+ loader = class_loader if likelihood == 'classification' else reg_loader
+
+ # define full model 'subnet' mask class (i.e. where all parameters are part of the subnet)
+ class FullSubnetMask(SubnetMask):
+ @property
+ def n_params_subnet(self):
+ if self._n_params_subnet is None:
+ self._check_select()
+ self._n_params_subnet = len(self._indices)
+ return self._n_params_subnet
+
+ def get_subnet_mask(self, train_loader):
+ return torch.ones(model.n_params).byte()
+
+ # define and fit valid full subnet Laplace model
+ subnetwork_mask = FullSubnetMask
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full')
+ lap.fit(loader)
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+
+ # check some parameters
+ assert lap.subnetwork_mask.indices.equal(torch.tensor(list(range(model.n_params))))
+ assert lap.subnetwork_mask.n_params_subnet == model.n_params
+ assert lap.n_params_subnet == model.n_params
+
+ # check that the Hessian is identical to that of an all-weights FullLaplace model
+ full_lap = Laplace(model, likelihood=likelihood, subset_of_weights='all', hessian_structure='full')
+ full_lap.fit(loader)
+ assert full_lap.H.equal(lap.H)
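
The tests above double as usage documentation; a condensed sketch of the happy path they exercise, assuming the default BackPackGGN backend and a toy regression setup (output_size is set by hand because the backends read the number of outputs from that attribute):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from laplace import Laplace
    from laplace.subnetmask import LargestMagnitudeSubnetMask

    model = nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2))
    setattr(model, 'output_size', 2)
    X, y = torch.randn(10, 3), torch.randn(10, 2)
    loader = DataLoader(TensorDataset(X, y), batch_size=3)

    la = Laplace(model, likelihood='regression', subset_of_weights='subnetwork',
                 subnetwork_mask=LargestMagnitudeSubnetMask, hessian_structure='full',
                 subnetmask_kwargs=dict(n_params_subnet=16))
    la.fit(loader)  # selects the 16-parameter mask internally, then fits over it
    assert la.H.shape == (16, 16)
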
From da23af97263199d8fe2ab23a1db92e51f5cee550 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Fri, 10 Dec 2021 20:44:23 +0000
Subject: [PATCH 15/49] Change indentation to spaces in test_subnetlaplace.py
---
tests/test_subnetlaplace.py | 278 ++++++++++++++++++------------------
1 file changed, 139 insertions(+), 139 deletions(-)
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
index 83162cfc..875d9ca8 100644
--- a/tests/test_subnetlaplace.py
+++ b/tests/test_subnetlaplace.py
@@ -40,166 +40,166 @@ def reg_loader():
@pytest.mark.parametrize('likelihood', likelihoods)
def test_subnet_laplace_init(model, likelihood):
- # use last-layer subnet mask for this test
- subnetwork_mask = LastLayerSubnetMask
-
- # subnet Laplace with full Hessian should work
- hessian_structure = 'full'
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
- assert isinstance(lap, SubnetLaplace)
-
- # subnet Laplace with diag, kron or lowrank Hessians should raise errors
- hessian_structure = 'diag'
- with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
- hessian_structure = 'kron'
- with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
- hessian_structure = 'lowrank'
- with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
+ # use last-layer subnet mask for this test
+ subnetwork_mask = LastLayerSubnetMask
+
+ # subnet Laplace with full Hessian should work
+ hessian_structure = 'full'
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
+ assert isinstance(lap, SubnetLaplace)
+
+ # subnet Laplace with diag, kron or lowrank Hessians should raise errors
+ hessian_structure = 'diag'
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
+ hessian_structure = 'kron'
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
+ hessian_structure = 'lowrank'
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
@pytest.mark.parametrize('subnetwork_mask,likelihood', product(score_based_subnet_masks, likelihoods))
def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_loader, reg_loader):
- loader = class_loader if likelihood == 'classification' else reg_loader
+ loader = class_loader if likelihood == 'classification' else reg_loader
- # should raise error if we don't pass number of subnet parameters within the subnetmask_kwargs
- subnetmask_kwargs = dict()
- with pytest.raises(TypeError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ # should raise error if we don't pass number of subnet parameters within the subnetmask_kwargs
+ subnetmask_kwargs = dict()
+ with pytest.raises(TypeError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- # should raise error if we set number of subnet parameters to None
- subnetmask_kwargs = dict(n_params_subnet=None)
- with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ # should raise error if we set number of subnet parameters to None
+ subnetmask_kwargs = dict(n_params_subnet=None)
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- # should raise error if we set number of subnet parameters to be larger than number of model parameters
- subnetmask_kwargs = dict(n_params_subnet=99999)
- with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ # should raise error if we set number of subnet parameters to be larger than number of model parameters
+ subnetmask_kwargs = dict(n_params_subnet=99999)
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- # define valid subnet Laplace model
- n_params_subnet = 32
- subnetmask_kwargs = dict(n_params_subnet=n_params_subnet)
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+ # define valid subnet Laplace model
+ n_params_subnet = 32
+ subnetmask_kwargs = dict(n_params_subnet=n_params_subnet)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap.subnetwork_mask, subnetwork_mask)
- # should raise error if we try to access the subnet indices before the subnet has been selected
- with pytest.raises(AttributeError):
- lap.subnetwork_mask.indices
+ # should raise error if we try to access the subnet indices before the subnet has been selected
+ with pytest.raises(AttributeError):
+ lap.subnetwork_mask.indices
- # select subnet mask
- lap.subnetwork_mask.select(loader)
+ # select subnet mask
+ lap.subnetwork_mask.select(loader)
- # should raise error if we try to select the subnet again
- with pytest.raises(ValueError):
- lap.subnetwork_mask.select(loader)
+ # should raise error if we try to select the subnet again
+ with pytest.raises(ValueError):
+ lap.subnetwork_mask.select(loader)
- # re-define valid subnet Laplace model
- n_params_subnet = 32
- subnetmask_kwargs = dict(n_params_subnet=n_params_subnet)
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+ # re-define valid subnet Laplace model
+ n_params_subnet = 32
+ subnetmask_kwargs = dict(n_params_subnet=n_params_subnet)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap.subnetwork_mask, subnetwork_mask)
- # fit Laplace model (which internally selects the subnet mask)
- lap.fit(loader)
+ # fit Laplace model (which internally selects the subnet mask)
+ lap.fit(loader)
- # check some parameters
- assert lap.subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
- assert lap.subnetwork_mask.n_params_subnet == n_params_subnet
- assert lap.n_params_subnet == n_params_subnet
+ # check some parameters
+ assert lap.subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
+ assert lap.subnetwork_mask.n_params_subnet == n_params_subnet
+ assert lap.n_params_subnet == n_params_subnet
- # check that Hessian and prior precision are of correct shape
- assert lap.H.shape == (n_params_subnet, n_params_subnet)
- assert lap.prior_precision_diag.shape == (n_params_subnet,)
+ # check that Hessian and prior precision are of correct shape
+ assert lap.H.shape == (n_params_subnet, n_params_subnet)
+ assert lap.prior_precision_diag.shape == (n_params_subnet,)
- # should raise error if we try to fit the Laplace model again
- with pytest.raises(ValueError):
- lap.fit(loader)
+ # should raise error if we try to fit the Laplace model again
+ with pytest.raises(ValueError):
+ lap.fit(loader)
@pytest.mark.parametrize('likelihood', likelihoods)
def test_last_layer_subnet_mask(model, likelihood, class_loader, reg_loader):
- subnetwork_mask = LastLayerSubnetMask
- loader = class_loader if likelihood == 'classification' else reg_loader
-
- # should raise error if we pass number of subnet parameters
- subnetmask_kwargs = dict(n_params_subnet=32)
- with pytest.raises(TypeError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
-
- # should raise error if we pass invalid last-layer name
- subnetmask_kwargs = dict(last_layer_name='123')
- with pytest.raises(KeyError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
-
- # define valid last-layer subnet Laplace model (without passing the last-layer name)
- subnetmask_kwargs = dict()
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap.subnetwork_mask, subnetwork_mask)
-
- # define valid last-layer subnet Laplace model (with passing the last-layer name)
- subnetmask_kwargs = dict(last_layer_name='1')
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap.subnetwork_mask, subnetwork_mask)
-
- # should raise error if we access number of subnet parameters before selecting the subnet
- with pytest.raises(AttributeError):
- n_params_subnet = lap.subnetwork_mask.n_params_subnet
-
- # fit Laplace model
- lap.fit(loader)
-
- # check some parameters
- n_params_subnet = 42
- assert lap.subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
- assert lap.subnetwork_mask.n_params_subnet == n_params_subnet
- assert lap.n_params_subnet == n_params_subnet
-
- # check that Hessian and prior precision are of correct shape
- assert lap.H.shape == (n_params_subnet, n_params_subnet)
- assert lap.prior_precision_diag.shape == (n_params_subnet,)
-
- # check that Hessian is identical to that of a full LLLaplace model
- lllap = Laplace(model, likelihood=likelihood, subset_of_weights='last_layer', hessian_structure='full')
- lllap.fit(loader)
- assert lllap.H.equal(lap.H)
+ subnetwork_mask = LastLayerSubnetMask
+ loader = class_loader if likelihood == 'classification' else reg_loader
+
+ # should raise error if we pass number of subnet parameters
+ subnetmask_kwargs = dict(n_params_subnet=32)
+ with pytest.raises(TypeError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+
+ # should raise error if we pass invalid last-layer name
+ subnetmask_kwargs = dict(last_layer_name='123')
+ with pytest.raises(KeyError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+
+ # define valid last-layer subnet Laplace model (without passing the last-layer name)
+ subnetmask_kwargs = dict()
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+
+ # define valid last-layer subnet Laplace model (with passing the last-layer name)
+ subnetmask_kwargs = dict(last_layer_name='1')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+
+ # should raise error if we access number of subnet parameters before selecting the subnet
+ with pytest.raises(AttributeError):
+ n_params_subnet = lap.subnetwork_mask.n_params_subnet
+
+ # fit Laplace model
+ lap.fit(loader)
+
+ # check some parameters
+ n_params_subnet = 42
+ assert lap.subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
+ assert lap.subnetwork_mask.n_params_subnet == n_params_subnet
+ assert lap.n_params_subnet == n_params_subnet
+
+ # check that Hessian and prior precision are of correct shape
+ assert lap.H.shape == (n_params_subnet, n_params_subnet)
+ assert lap.prior_precision_diag.shape == (n_params_subnet,)
+
+ # check that Hessian is identical to that of a full LLLaplace model
+ lllap = Laplace(model, likelihood=likelihood, subset_of_weights='last_layer', hessian_structure='full')
+ lllap.fit(loader)
+ assert lllap.H.equal(lap.H)
@pytest.mark.parametrize('likelihood', likelihoods)
def test_full_subnet_mask(model, likelihood, class_loader, reg_loader):
- loader = class_loader if likelihood == 'classification' else reg_loader
-
- # define full model 'subnet' mask class (i.e. where all parameters are part of the subnet)
- class FullSubnetMask(SubnetMask):
- @property
- def n_params_subnet(self):
- if self._n_params_subnet is None:
- self._check_select()
- self._n_params_subnet = len(self._indices)
- return self._n_params_subnet
-
- def get_subnet_mask(self, train_loader):
- return torch.ones(model.n_params).byte()
-
- # define and fit valid full subnet Laplace model
- subnetwork_mask = FullSubnetMask
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full')
- lap.fit(loader)
- assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap.subnetwork_mask, subnetwork_mask)
-
- # check some parameters
- assert lap.subnetwork_mask.indices.equal(torch.tensor(list(range(model.n_params))))
- assert lap.subnetwork_mask.n_params_subnet == model.n_params
- assert lap.n_params_subnet == model.n_params
-
- # check that the Hessian is identical to that of an all-weights FullLaplace model
- full_lap = Laplace(model, likelihood=likelihood, subset_of_weights='all', hessian_structure='full')
- full_lap.fit(loader)
- assert full_lap.H.equal(lap.H)
+ loader = class_loader if likelihood == 'classification' else reg_loader
+
+ # define full model 'subnet' mask class (i.e. where all parameters are part of the subnet)
+ class FullSubnetMask(SubnetMask):
+ @property
+ def n_params_subnet(self):
+ if self._n_params_subnet is None:
+ self._check_select()
+ self._n_params_subnet = len(self._indices)
+ return self._n_params_subnet
+
+ def get_subnet_mask(self, train_loader):
+ return torch.ones(model.n_params).byte()
+
+ # define and fit valid full subnet Laplace model
+ subnetwork_mask = FullSubnetMask
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full')
+ lap.fit(loader)
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+
+ # check some parameters
+ assert lap.subnetwork_mask.indices.equal(torch.tensor(list(range(model.n_params))))
+ assert lap.subnetwork_mask.n_params_subnet == model.n_params
+ assert lap.n_params_subnet == model.n_params
+
+ # check that the Hessian is identical to that of an all-weights FullLaplace model
+ full_lap = Laplace(model, likelihood=likelihood, subset_of_weights='all', hessian_structure='full')
+ full_lap.fit(loader)
+ assert full_lap.H.equal(lap.H)
From ddf840cc127fc1fb32c8e9519c093fc8517222d6 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Wed, 15 Dec 2021 17:11:55 +0000
Subject: [PATCH 16/49] Implement sample() method for SubnetLaplace (as e.g.
required for the NN predictive)
---
laplace/subnetlaplace.py | 20 ++++++++++++++++----
1 file changed, 16 insertions(+), 4 deletions(-)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 67c6eb0e..cce841fa 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -1,4 +1,5 @@
import torch
+from torch.distributions import MultivariateNormal
from laplace.baselaplace import FullLaplace, DiagLaplace
@@ -75,7 +76,7 @@ def __init__(self, model, likelihood, subnetwork_mask, sigma_noise=1., prior_pre
# instantiate and pass diagonal Laplace model for largest variance subnetwork selection
self._subnetmask_kwargs.update(diag_laplace_model=DiagLaplace(self.model, likelihood, sigma_noise,
prior_precision, prior_mean, temperature, backend, backend_kwargs))
- self.subnetwork_mask = subnetwork_mask(self.model, **self._subnetmask_kwargs)
+ self._subnetwork_mask = subnetwork_mask(self.model, **self._subnetmask_kwargs)
self.n_params_subnet = None
def _init_H(self):
@@ -110,9 +111,20 @@ def fit(self, train_loader):
"""
# select subnetwork and pass it to backend
- self.subnetwork_mask.select(train_loader)
- self.backend.subnetwork_indices = self.subnetwork_mask.indices
- self.n_params_subnet = self.subnetwork_mask.n_params_subnet
+ self._subnetwork_mask.select(train_loader)
+ self.backend.subnetwork_indices = self._subnetwork_mask.indices
+ self.n_params_subnet = self._subnetwork_mask.n_params_subnet
# fit Laplace approximation over subnetwork
super().fit(train_loader)
+
+ def sample(self, n_samples=100):
+ # sample parameters just of the subnetwork
+ subnet_mean = self.mean[self._subnetwork_mask.indices]
+ dist = MultivariateNormal(loc=subnet_mean, scale_tril=self.posterior_scale)
+ subnet_samples = dist.sample((n_samples,))
+
+ # set all other parameters to their MAP estimates
+ full_samples = self.mean.repeat(n_samples, 1)
+ full_samples[:, self._subnetwork_mask.indices] = subnet_samples
+ return full_samples
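
The scatter step in sample(), isolated with dummy shapes: draw from the subnetwork posterior, then paste the draws into copies of the full MAP vector so that all remaining parameters stay at their point estimates. The identity scale below is a stand-in for self.posterior_scale:

    import torch
    from torch.distributions import MultivariateNormal

    n_params, n_samples = 6, 3
    mean = torch.randn(n_params)                # full MAP estimate
    indices = torch.tensor([1, 4])              # subnetwork indices
    scale_tril = torch.eye(len(indices))        # stand-in for the posterior scale

    dist = MultivariateNormal(loc=mean[indices], scale_tril=scale_tril)
    subnet_samples = dist.sample((n_samples,))  # (n_samples, n_params_subnet)

    full_samples = mean.repeat(n_samples, 1)    # every parameter at its MAP value
    full_samples[:, indices] = subnet_samples   # overwrite only the subnetwork
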
From 7a4848982b1946af1b5be5050d26864f614dcc24 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Wed, 15 Dec 2021 17:12:12 +0000
Subject: [PATCH 17/49] Add tests for SubnetLaplace predictives
---
tests/test_subnetlaplace.py | 91 +++++++++++++++++++++++++++++++------
1 file changed, 76 insertions(+), 15 deletions(-)
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
index 875d9ca8..8ce14e6f 100644
--- a/tests/test_subnetlaplace.py
+++ b/tests/test_subnetlaplace.py
@@ -13,6 +13,7 @@
torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
score_based_subnet_masks = [RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask]
+all_subnet_masks = score_based_subnet_masks + [LastLayerSubnetMask]
likelihoods = ['classification', 'regression']
@@ -84,32 +85,32 @@ def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_load
subnetmask_kwargs = dict(n_params_subnet=n_params_subnet)
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+ assert isinstance(lap._subnetwork_mask, subnetwork_mask)
# should raise error if we try to access the subnet indices before the subnet has been selected
with pytest.raises(AttributeError):
- lap.subnetwork_mask.indices
+ lap._subnetwork_mask.indices
# select subnet mask
- lap.subnetwork_mask.select(loader)
+ lap._subnetwork_mask.select(loader)
# should raise error if we try to select the subnet again
with pytest.raises(ValueError):
- lap.subnetwork_mask.select(loader)
+ lap._subnetwork_mask.select(loader)
# re-define valid subnet Laplace model
n_params_subnet = 32
subnetmask_kwargs = dict(n_params_subnet=n_params_subnet)
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+ assert isinstance(lap._subnetwork_mask, subnetwork_mask)
# fit Laplace model (which internally selects the subnet mask)
lap.fit(loader)
# check some parameters
- assert lap.subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
- assert lap.subnetwork_mask.n_params_subnet == n_params_subnet
+ assert lap._subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
+ assert lap._subnetwork_mask.n_params_subnet == n_params_subnet
assert lap.n_params_subnet == n_params_subnet
# check that Hessian and prior precision are of correct shape
@@ -140,25 +141,25 @@ def test_last_layer_subnet_mask(model, likelihood, class_loader, reg_loader):
subnetmask_kwargs = dict()
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+ assert isinstance(lap._subnetwork_mask, subnetwork_mask)
# define valid last-layer subnet Laplace model (with passing the last-layer name)
subnetmask_kwargs = dict(last_layer_name='1')
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+ assert isinstance(lap._subnetwork_mask, subnetwork_mask)
# should raise error if we access number of subnet parameters before selecting the subnet
with pytest.raises(AttributeError):
- n_params_subnet = lap.subnetwork_mask.n_params_subnet
+ n_params_subnet = lap._subnetwork_mask.n_params_subnet
# fit Laplace model
lap.fit(loader)
# check some parameters
n_params_subnet = 42
- assert lap.subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
- assert lap.subnetwork_mask.n_params_subnet == n_params_subnet
+ assert lap._subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
+ assert lap._subnetwork_mask.n_params_subnet == n_params_subnet
assert lap.n_params_subnet == n_params_subnet
# check that Hessian and prior precision are of correct shape
@@ -192,14 +193,74 @@ def get_subnet_mask(self, train_loader):
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full')
lap.fit(loader)
assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap.subnetwork_mask, subnetwork_mask)
+ assert isinstance(lap._subnetwork_mask, subnetwork_mask)
# check some parameters
- assert lap.subnetwork_mask.indices.equal(torch.tensor(list(range(model.n_params))))
- assert lap.subnetwork_mask.n_params_subnet == model.n_params
+ assert lap._subnetwork_mask.indices.equal(torch.tensor(list(range(model.n_params))))
+ assert lap._subnetwork_mask.n_params_subnet == model.n_params
assert lap.n_params_subnet == model.n_params
# check that the Hessian is identical to that of a all-weights FullLaplace model
full_lap = Laplace(model, likelihood=likelihood, subset_of_weights='all', hessian_structure='full')
full_lap.fit(loader)
assert full_lap.H.equal(lap.H)
+
+
+@pytest.mark.parametrize('subnetwork_mask', all_subnet_masks)
+def test_regression_predictive(model, reg_loader, subnetwork_mask):
+ subnetmask_kwargs = dict(n_params_subnet=32) if subnetwork_mask in score_based_subnet_masks else dict()
+ lap = Laplace(model, likelihood='regression', subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap._subnetwork_mask, subnetwork_mask)
+
+ lap.fit(reg_loader)
+ X, _ = reg_loader.dataset.tensors
+ f = model(X)
+
+ # error
+ with pytest.raises(ValueError):
+ lap(X, pred_type='linear')
+
+ # GLM predictive
+ f_mu, f_var = lap(X, pred_type='glm')
+ assert torch.allclose(f_mu, f)
+ assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1], f_mu.shape[1]])
+ assert len(f_mu) == len(X)
+
+ # NN predictive (only diagonal variance estimation)
+ f_mu, f_var = lap(X, pred_type='nn')
+ assert f_mu.shape == f_var.shape
+ assert f_var.shape == torch.Size([f_mu.shape[0], f_mu.shape[1]])
+ assert len(f_mu) == len(X)
+
+
+@pytest.mark.parametrize('subnetwork_mask', all_subnet_masks)
+def test_classification_predictive(model, class_loader, subnetwork_mask):
+ subnetmask_kwargs = dict(n_params_subnet=32) if subnetwork_mask in score_based_subnet_masks else dict()
+ lap = Laplace(model, likelihood='classification', subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap._subnetwork_mask, subnetwork_mask)
+
+ lap.fit(class_loader)
+ X, _ = class_loader.dataset.tensors
+ f = torch.softmax(model(X), dim=-1)
+
+ # error
+ with pytest.raises(ValueError):
+ lap(X, pred_type='linear')
+
+ # GLM predictive
+ f_pred = lap(X, pred_type='glm', link_approx='mc', n_samples=100)
+ assert f_pred.shape == f.shape
+ assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # each row sums to 1
+ f_pred = lap(X, pred_type='glm', link_approx='probit')
+ assert f_pred.shape == f.shape
+ assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1
+ f_pred = lap(X, pred_type='glm', link_approx='bridge')
+ assert f_pred.shape == f.shape
+ assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1
+
+ # NN predictive
+ f_pred = lap(X, pred_type='nn', n_samples=100)
+ assert f_pred.shape == f.shape
+ assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1
\ No newline at end of file
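The two predictive modes exercised by these tests can also be called directly outside the test harness. Below is a minimal sketch under the same setup, where `model` is a trained regression network and `loader` a `DataLoader` over a `TensorDataset`; both are placeholders, and `RandomSubnetMask` stands in for any score-based mask:

```python
from laplace import Laplace
from laplace.subnetmask import RandomSubnetMask

# minimal sketch: `model` and `loader` are placeholders for a trained
# torch.nn.Module and a DataLoader over a TensorDataset of (X, y) pairs
lap = Laplace(model, likelihood='regression',
              subset_of_weights='subnetwork', subnetwork_mask=RandomSubnetMask,
              hessian_structure='full', subnetmask_kwargs=dict(n_params_subnet=32))
lap.fit(loader)

X, _ = loader.dataset.tensors
f_mu, f_var = lap(X, pred_type='glm')  # closed-form GLM predictive: full output covariance
f_mu, f_var = lap(X, pred_type='nn')   # sampled NN predictive: diagonal variance only
```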
From bcc9ca7df8983a59357dd6e38ac5fbfe53d4bb44 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Thu, 16 Dec 2021 13:36:36 +0000
Subject: [PATCH 18/49] Fix small bug in LastLayerSubnetMask
---
laplace/subnetmask.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/laplace/subnetmask.py b/laplace/subnetmask.py
index 45bb2c3e..44d5c2fc 100644
--- a/laplace/subnetmask.py
+++ b/laplace/subnetmask.py
@@ -101,7 +101,7 @@ class ScoreBasedSubnetMask(SubnetMask):
----------
model : torch.nn.Module
n_params_subnet : int
- the number of parameters in the subnetwork (i.e. the number of top-scoring parameters to select)
+ number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
"""
def __init__(self, model, n_params_subnet):
super().__init__(model)
@@ -152,13 +152,13 @@ def compute_param_scores(self, train_loader):
class LargestVarianceDiagLaplaceSubnetMask(ScoreBasedSubnetMask):
"""Subnetwork mask identifying the parameters with the largest marginal variances
- (estimated using a diagional Laplace approximation over all model parameters).
+ (estimated using a diagonal Laplace approximation over all model parameters).
Parameters
----------
model : torch.nn.Module
n_params_subnet : int
- the number of parameters in the subnetwork (i.e. the number of top-scoring parameters to select)
+ number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
diag_laplace_model : `laplace.baselaplace.DiagLaplace`
diagonal Laplace model to use for variance estimation
"""
@@ -203,7 +203,7 @@ def get_subnet_mask(self, train_loader):
subnet_mask_list = []
for name, layer in self.model.model.named_modules():
- if len(list(layer.children())) > 0:
+ if len(list(layer.children())) > 0 or len(list(layer.parameters())) == 0:
continue
if name == self.model._last_layer_name:
mask_method = torch.ones_like
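The extra `len(list(layer.parameters())) == 0` condition matters because parameter-free leaf modules (e.g. activations) would otherwise reach `parameters_to_vector`, which cannot concatenate an empty parameter list. A small illustration of the failure mode, assuming current PyTorch behavior:

```python
import torch
from torch.nn.utils import parameters_to_vector

relu = torch.nn.ReLU()
assert len(list(relu.children())) == 0      # a leaf module ...
assert len(list(relu.parameters())) == 0    # ... but with no parameters

try:
    parameters_to_vector(relu.parameters())
except RuntimeError as e:
    print(e)  # torch.cat() on an empty list of tensors fails
```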
From 381f79d44e3d48ad997e001531d9fd8348f02b01 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Thu, 16 Dec 2021 13:37:22 +0000
Subject: [PATCH 19/49] Add reference to subnetwork inference paper to
SubnetLaplace docstring
---
laplace/subnetlaplace.py | 10 ++++++++--
1 file changed, 8 insertions(+), 2 deletions(-)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index cce841fa..16ed87cb 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -12,8 +12,8 @@
class SubnetLaplace(FullLaplace):
"""Class for subnetwork Laplace, which computes the Laplace approximation over
- just a subset of the model parameters (i.e. a subnetwork within the neural network).
- Subnetwork Laplace only supports a full Hessian approximation; other Hessian
+ just a subset of the model parameters (i.e. a subnetwork within the neural network),
+ as proposed in [1]. Subnetwork Laplace only supports a full Hessian approximation; other
approximations could be used in theory, but would not make as much sense conceptually.
A Laplace approximation is represented by a MAP which is given by the
@@ -38,6 +38,12 @@ class SubnetLaplace(FullLaplace):
matrix. Mathematically, we have \\(P \\in \\mathbb{R}^{P \\times P}\\).
See `FullLaplace` and `BaseLaplace` for the full interface.
+ References
+ ----------
+ [1] Daxberger, E., Nalisnick, E., Allingham, JU., Antorán, J., Hernández-Lobato, JM.
+ [*Bayesian Deep Learning via Subnetwork Inference*](https://arxiv.org/abs/2010.14689).
+ ICML 2021.
+
Parameters
----------
model : torch.nn.Module or `laplace.feature_extractor.FeatureExtractor`
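Subnetwork inference [1] is reached through the same `Laplace` factory as the other approximations. A minimal sketch, assuming a trained `model` and a `train_loader` (both placeholders):

```python
from laplace import Laplace
from laplace.subnetmask import LargestMagnitudeSubnetMask

# sketch: apply a full-Hessian Laplace approximation over only the
# 128 largest-magnitude weights of the network
la = Laplace(model, likelihood='classification',
             subset_of_weights='subnetwork',
             subnetwork_mask=LargestMagnitudeSubnetMask,
             hessian_structure='full',
             subnetmask_kwargs=dict(n_params_subnet=128))
la.fit(train_loader)
```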
From f345ac55728ce74b7edcff465b3cd0d380a57b2d Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Mon, 20 Dec 2021 13:50:49 +0000
Subject: [PATCH 20/49] Add SubnetMask that allows specifying subnet parameters
or modules by name
---
laplace/subnetmask.py | 137 ++++++++++++++++++++-------
tests/test_subnetlaplace.py | 179 ++++++++++++++++++++++++++++--------
2 files changed, 248 insertions(+), 68 deletions(-)
diff --git a/laplace/subnetmask.py b/laplace/subnetmask.py
index 44d5c2fc..8b0b8d29 100644
--- a/laplace/subnetmask.py
+++ b/laplace/subnetmask.py
@@ -1,9 +1,11 @@
+from copy import deepcopy
+
import torch
from torch.nn.utils import parameters_to_vector
from laplace.feature_extractor import FeatureExtractor
-__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LastLayerSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask']
+__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask']
class SubnetMask:
@@ -21,10 +23,6 @@ def __init__(self, model):
self._indices = None
self._n_params_subnet = None
- @property
- def n_params_subnet(self):
- raise NotImplementedError
-
def _check_select(self):
if self._indices is None:
raise AttributeError('Subnetwork mask not selected. Run select() first.')
@@ -34,6 +32,13 @@ def indices(self):
self._check_select()
return self._indices
+ @property
+ def n_params_subnet(self):
+ if self._n_params_subnet is None:
+ self._check_select()
+ self._n_params_subnet = len(self._indices)
+ return self._n_params_subnet
+
def convert_subnet_mask_to_indices(self, subnet_mask):
"""Converts a subnetwork mask into subnetwork indices.
@@ -113,10 +118,6 @@ def __init__(self, model, n_params_subnet):
self._n_params_subnet = n_params_subnet
self._param_scores = None
- @property
- def n_params_subnet(self):
- return self._n_params_subnet
-
def compute_param_scores(self, train_loader):
raise NotImplementedError
@@ -171,44 +172,118 @@ def compute_param_scores(self, train_loader):
return self.diag_laplace_model.posterior_variance
-class LastLayerSubnetMask(SubnetMask):
- """Subnetwork mask corresponding to the last layer of the neural network.
+class ParamNameSubnetMask(SubnetMask):
+ """Subnetwork mask corresponding to the specified parameters of the neural network.
Parameters
----------
model : torch.nn.Module
- last_layer_name: str, default=None
- name of the model's last layer, if None it will be determined automatically
+ parameter_names: List[str]
+ list of names of the parameters (as in `model.named_parameters()`) that define the subnetwork
"""
- def __init__(self, model, last_layer_name=None):
+ def __init__(self, model, parameter_names):
super().__init__(model)
- self.model = FeatureExtractor(self.model, last_layer_name=last_layer_name)
+ self._parameter_names = parameter_names
self._n_params_subnet = None
- @property
- def n_params_subnet(self):
- if self._n_params_subnet is None:
- self._check_select()
- self._n_params_subnet = len(self._indices)
- return self._n_params_subnet
+ def _check_param_names(self):
+ param_names = deepcopy(self._parameter_names)
+ if len(param_names) == 0:
+ raise ValueError('Parameter name list cannot be empty.')
+
+ for name, _ in self.model.named_parameters():
+ if name in param_names:
+ param_names.remove(name)
+ if len(param_names) > 0:
+ raise ValueError(f'Parameters {param_names} do not exist in model.')
def get_subnet_mask(self, train_loader):
- """ Get the subnetwork mask identifying the last layer."""
+ """ Get the subnetwork mask identifying the specified parameters."""
- self.model.eval()
- if self.model.last_layer is None:
- X, _ = next(iter(train_loader))
- with torch.no_grad():
- self.model.find_last_layer(X[:1].to(self._device))
+ self._check_param_names()
subnet_mask_list = []
- for name, layer in self.model.model.named_modules():
- if len(list(layer.children())) > 0 or len(list(layer.parameters())) == 0:
+ for name, param in self.model.named_parameters():
+ if name in self._parameter_names:
+ mask_method = torch.ones_like
+ else:
+ mask_method = torch.zeros_like
+ subnet_mask_list.append(mask_method(parameters_to_vector(param)))
+ subnet_mask = torch.cat(subnet_mask_list).byte()
+ return subnet_mask
+
+
+class ModuleNameSubnetMask(SubnetMask):
+ """Subnetwork mask corresponding to the specified modules of the neural network.
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ module_names: List[str]
+ list of names of the modules (as in `model.named_modules()`) that define the subnetwork;
+ the modules cannot have children, i.e. need to be leaf modules
+ """
+ def __init__(self, model, module_names):
+ super().__init__(model)
+ self._module_names = module_names
+ self._n_params_subnet = None
+
+ def _check_module_names(self):
+ module_names = deepcopy(self._module_names)
+ if len(module_names) == 0:
+ raise ValueError('Module name list cannot be empty.')
+
+ for name, module in self.model.named_modules():
+ if name in module_names:
+ if len(list(module.children())) > 0:
+ raise ValueError(f'Module "{name}" has children, which is not supported.')
+ elif len(list(module.parameters())) == 0:
+ raise ValueError(f'Module "{name}" does not have any parameters.')
+ else:
+ module_names.remove(name)
+ if len(module_names) > 0:
+ raise ValueError(f'Modules {module_names} do not exist in model.')
+
+ def get_subnet_mask(self, train_loader):
+ """ Get the subnetwork mask identifying the specified modules."""
+
+ self._check_module_names()
+
+ subnet_mask_list = []
+ for name, module in self.model.named_modules():
+ if len(list(module.children())) > 0 or len(list(module.parameters())) == 0:
continue
- if name == self.model._last_layer_name:
+ if name in self._module_names:
mask_method = torch.ones_like
else:
mask_method = torch.zeros_like
- subnet_mask_list.append(mask_method(parameters_to_vector(layer.parameters())))
+ subnet_mask_list.append(mask_method(parameters_to_vector(module.parameters())))
subnet_mask = torch.cat(subnet_mask_list).byte()
return subnet_mask
+
+
+class LastLayerSubnetMask(ModuleNameSubnetMask):
+ """Subnetwork mask corresponding to the last layer of the neural network.
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ last_layer_name: str, default=None
+ name of the model's last layer; if None, it will be determined automatically
+ """
+ def __init__(self, model, last_layer_name=None):
+ super().__init__(model, None)
+ self._feature_extractor = FeatureExtractor(self.model, last_layer_name=last_layer_name)
+ self._n_params_subnet = None
+
+ def get_subnet_mask(self, train_loader):
+ """ Get the subnetwork mask identifying the last layer."""
+
+ self._feature_extractor.eval()
+ if self._feature_extractor.last_layer is None:
+ X = next(iter(train_loader))[0]
+ with torch.no_grad():
+ self._feature_extractor.find_last_layer(X[:1].to(self._device))
+ self._module_names = [self._feature_extractor._last_layer_name]
+
+ return super().get_subnet_mask(train_loader)
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
index 8ce14e6f..b6687c96 100644
--- a/tests/test_subnetlaplace.py
+++ b/tests/test_subnetlaplace.py
@@ -7,13 +7,14 @@
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace, SubnetLaplace
-from laplace.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LastLayerSubnetMask, LargestVarianceDiagLaplaceSubnetMask
+from laplace.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask
torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
score_based_subnet_masks = [RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask]
-all_subnet_masks = score_based_subnet_masks + [LastLayerSubnetMask]
+layer_subnet_masks = [ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask]
+all_subnet_masks = score_based_subnet_masks + layer_subnet_masks
likelihoods = ['classification', 'regression']
@@ -64,25 +65,26 @@ def test_subnet_laplace_init(model, likelihood):
@pytest.mark.parametrize('subnetwork_mask,likelihood', product(score_based_subnet_masks, likelihoods))
def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_loader, reg_loader):
loader = class_loader if likelihood == 'classification' else reg_loader
+ model_params = parameters_to_vector(model.parameters())
+ subnetmask_kwargs = dict()
# should raise error if we don't pass number of subnet parameters within the subnetmask_kwargs
- subnetmask_kwargs = dict()
with pytest.raises(TypeError):
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
# should raise error if we set number of subnet parameters to None
- subnetmask_kwargs = dict(n_params_subnet=None)
+ subnetmask_kwargs.update(n_params_subnet=None)
with pytest.raises(ValueError):
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
# should raise error if we set number of subnet parameters to be larger than number of model parameters
- subnetmask_kwargs = dict(n_params_subnet=99999)
+ subnetmask_kwargs.update(n_params_subnet=99999)
with pytest.raises(ValueError):
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
# define valid subnet Laplace model
n_params_subnet = 32
- subnetmask_kwargs = dict(n_params_subnet=n_params_subnet)
+ subnetmask_kwargs.update(n_params_subnet=n_params_subnet)
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
assert isinstance(lap, SubnetLaplace)
assert isinstance(lap._subnetwork_mask, subnetwork_mask)
@@ -100,7 +102,7 @@ def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_load
# re-define valid subnet Laplace model
n_params_subnet = 32
- subnetmask_kwargs = dict(n_params_subnet=n_params_subnet)
+ subnetmask_kwargs.update(n_params_subnet=n_params_subnet)
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
assert isinstance(lap, SubnetLaplace)
assert isinstance(lap._subnetwork_mask, subnetwork_mask)
@@ -112,6 +114,7 @@ def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_load
assert lap._subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
assert lap._subnetwork_mask.n_params_subnet == n_params_subnet
assert lap.n_params_subnet == n_params_subnet
+ assert parameters_to_vector(model.parameters()).equal(model_params)
 # check that Hessian and prior precision are of the correct shape
assert lap.H.shape == (n_params_subnet, n_params_subnet)
@@ -122,42 +125,135 @@ def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_load
lap.fit(loader)
-@pytest.mark.parametrize('likelihood', likelihoods)
-def test_last_layer_subnet_mask(model, likelihood, class_loader, reg_loader):
- subnetwork_mask = LastLayerSubnetMask
+@pytest.mark.parametrize('subnetwork_mask,likelihood', product(layer_subnet_masks, likelihoods))
+def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, reg_loader):
loader = class_loader if likelihood == 'classification' else reg_loader
+ # fit last-layer Laplace model
+ lllap = Laplace(model, likelihood=likelihood, subset_of_weights='last_layer', hessian_structure='full')
+ lllap.fit(loader)
+
# should raise error if we pass number of subnet parameters
subnetmask_kwargs = dict(n_params_subnet=32)
with pytest.raises(TypeError):
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- # should raise error if we pass invalid last-layer name
- subnetmask_kwargs = dict(last_layer_name='123')
- with pytest.raises(KeyError):
+ if subnetwork_mask == ParamNameSubnetMask:
+ # should raise error if we pass no parameter name list
+ subnetmask_kwargs = dict()
+ with pytest.raises(TypeError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+
+ # should raise error if we pass an empty parameter name list
+ subnetmask_kwargs = dict(parameter_names=[])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ lap.fit(loader)
+
+ # should raise error if we pass a parameter name list with invalid parameter names
+ subnetmask_kwargs = dict(parameter_names=['123'])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ lap.fit(loader)
+
+ # define last-layer Laplace model by parameter names and check that Hessian is identical to that of a full LLLaplace model
+ subnetmask_kwargs = dict(parameter_names=['1.weight', '1.bias'])
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ lap.fit(loader)
+ assert lllap.H.equal(lap.H)
- # define valid last-layer subnet Laplace model (without passing the last-layer name)
- subnetmask_kwargs = dict()
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap._subnetwork_mask, subnetwork_mask)
+ # define valid parameter name subnet Laplace model
+ subnetmask_kwargs = dict(parameter_names=['0.weight', '1.bias'])
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ n_params_subnet = 62
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap._subnetwork_mask, subnetwork_mask)
- # define valid last-layer subnet Laplace model (with passing the last-layer name)
- subnetmask_kwargs = dict(last_layer_name='1')
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap._subnetwork_mask, subnetwork_mask)
+ # should raise error if we access number of subnet parameters before selecting the subnet
+ with pytest.raises(AttributeError):
+ n_params_subnet = lap._subnetwork_mask.n_params_subnet
- # should raise error if we access number of subnet parameters before selecting the subnet
- with pytest.raises(AttributeError):
- n_params_subnet = lap._subnetwork_mask.n_params_subnet
+ # fit Laplace model
+ lap.fit(loader)
- # fit Laplace model
- lap.fit(loader)
+ elif subnetwork_mask == ModuleNameSubnetMask:
+ # should raise error if we pass no module name list
+ subnetmask_kwargs = dict()
+ with pytest.raises(TypeError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+
+ # should raise error if we pass an empty module name list
+ subnetmask_kwargs = dict(module_names=[])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ lap.fit(loader)
+
+ # should raise error if we pass a module name list with invalid module names
+ subnetmask_kwargs = dict(module_names=['123'])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ lap.fit(loader)
+
+ # define last-layer Laplace model by module name and check that Hessian is identical to that of a full LLLaplace model
+ subnetmask_kwargs = dict(module_names=['1'])
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ lap.fit(loader)
+ assert lllap.H.equal(lap.H)
+
+ # define valid parameter name subnet Laplace model
+ subnetmask_kwargs = dict(module_names=['0'])
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ n_params_subnet = 80
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap._subnetwork_mask, subnetwork_mask)
+
+ # should raise error if we access number of subnet parameters before selecting the subnet
+ with pytest.raises(AttributeError):
+ n_params_subnet = lap._subnetwork_mask.n_params_subnet
+
+ # fit Laplace model
+ lap.fit(loader)
+
+ elif subnetwork_mask == LastLayerSubnetMask:
+ # should raise error if we pass invalid last-layer name
+ subnetmask_kwargs = dict(last_layer_name='123')
+ with pytest.raises(KeyError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+
+ # define valid last-layer subnet Laplace model (without passing the last-layer name)
+ subnetmask_kwargs = dict()
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap._subnetwork_mask, subnetwork_mask)
+
+ # should raise error if we access number of subnet parameters before selecting the subnet
+ with pytest.raises(AttributeError):
+ n_params_subnet = lap._subnetwork_mask.n_params_subnet
+
+ # fit Laplace model
+ lap.fit(loader)
+
+ # check that Hessian is identical to that of a full LLLaplace model
+ assert lllap.H.equal(lap.H)
+
+ # define valid last-layer subnet Laplace model (with passing the last-layer name)
+ subnetmask_kwargs = dict(last_layer_name='1')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ n_params_subnet = 42
+ assert isinstance(lap, SubnetLaplace)
+ assert isinstance(lap._subnetwork_mask, subnetwork_mask)
+
+ # should raise error if we access number of subnet parameters before selecting the subnet
+ with pytest.raises(AttributeError):
+ n_params_subnet = lap._subnetwork_mask.n_params_subnet
+
+ # fit Laplace model
+ lap.fit(loader)
+
+ # check that Hessian is identical to that of a full LLLaplace model
+ assert lllap.H.equal(lap.H)
# check some parameters
- n_params_subnet = 42
assert lap._subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
assert lap._subnetwork_mask.n_params_subnet == n_params_subnet
assert lap.n_params_subnet == n_params_subnet
@@ -166,11 +262,6 @@ def test_last_layer_subnet_mask(model, likelihood, class_loader, reg_loader):
assert lap.H.shape == (n_params_subnet, n_params_subnet)
assert lap.prior_precision_diag.shape == (n_params_subnet,)
- # check that Hessian is identical to that of a full LLLaplace model
- lllap = Laplace(model, likelihood=likelihood, subset_of_weights='last_layer', hessian_structure='full')
- lllap.fit(loader)
- assert lllap.H.equal(lap.H)
-
@pytest.mark.parametrize('likelihood', likelihoods)
def test_full_subnet_mask(model, likelihood, class_loader, reg_loader):
@@ -208,7 +299,14 @@ def get_subnet_mask(self, train_loader):
@pytest.mark.parametrize('subnetwork_mask', all_subnet_masks)
def test_regression_predictive(model, reg_loader, subnetwork_mask):
- subnetmask_kwargs = dict(n_params_subnet=32) if subnetwork_mask in score_based_subnet_masks else dict()
+ if subnetwork_mask in score_based_subnet_masks:
+ subnetmask_kwargs = dict(n_params_subnet=32)
+ elif subnetwork_mask == ParamNameSubnetMask:
+ subnetmask_kwargs = dict(parameter_names=['0.weight', '1.bias'])
+ elif subnetwork_mask == ModuleNameSubnetMask:
+ subnetmask_kwargs = dict(module_names=['0'])
+ else:
+ subnetmask_kwargs = dict()
lap = Laplace(model, likelihood='regression', subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
assert isinstance(lap, SubnetLaplace)
assert isinstance(lap._subnetwork_mask, subnetwork_mask)
@@ -236,7 +334,14 @@ def test_regression_predictive(model, reg_loader, subnetwork_mask):
@pytest.mark.parametrize('subnetwork_mask', all_subnet_masks)
def test_classification_predictive(model, class_loader, subnetwork_mask):
- subnetmask_kwargs = dict(n_params_subnet=32) if subnetwork_mask in score_based_subnet_masks else dict()
+ if subnetwork_mask in score_based_subnet_masks:
+ subnetmask_kwargs = dict(n_params_subnet=32)
+ elif subnetwork_mask == ParamNameSubnetMask:
+ subnetmask_kwargs = dict(parameter_names=['0.weight', '1.bias'])
+ elif subnetwork_mask == ModuleNameSubnetMask:
+ subnetmask_kwargs = dict(module_names=['0'])
+ else:
+ subnetmask_kwargs = dict()
lap = Laplace(model, likelihood='classification', subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
assert isinstance(lap, SubnetLaplace)
assert isinstance(lap._subnetwork_mask, subnetwork_mask)
@@ -263,4 +368,4 @@ def test_classification_predictive(model, class_loader, subnetwork_mask):
# NN predictive
f_pred = lap(X, pred_type='nn', n_samples=100)
assert f_pred.shape == f.shape
- assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1
\ No newline at end of file
+ assert torch.allclose(f_pred.sum(), torch.tensor(len(f_pred), dtype=torch.double)) # sum up to 1
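The two new name-based masks are exercised above against a small two-layer network. A sketch of their direct use, assuming a model like `torch.nn.Sequential(torch.nn.Linear(3, 20), torch.nn.Linear(20, 2))` so that the names below match `model.named_parameters()` and `model.named_modules()`:

```python
from laplace import Laplace
from laplace.subnetmask import ParamNameSubnetMask, ModuleNameSubnetMask

# subnetwork given by explicit parameter names
la_by_param = Laplace(model, likelihood='regression',
                      subset_of_weights='subnetwork',
                      subnetwork_mask=ParamNameSubnetMask,
                      hessian_structure='full',
                      subnetmask_kwargs=dict(parameter_names=['0.weight', '1.bias']))

# subnetwork given by explicit (leaf) module names
la_by_module = Laplace(model, likelihood='regression',
                       subset_of_weights='subnetwork',
                       subnetwork_mask=ModuleNameSubnetMask,
                       hessian_structure='full',
                       subnetmask_kwargs=dict(module_names=['1']))
```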
From ea16f3467181f5e7dafd1559404cee2a3805224b Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Mon, 20 Dec 2021 17:30:15 +0000
Subject: [PATCH 21/49] Make subnet mask type check independent of CUDA and
change default type to bool
---
laplace/subnetmask.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/laplace/subnetmask.py b/laplace/subnetmask.py
index 8b0b8d29..caa29094 100644
--- a/laplace/subnetmask.py
+++ b/laplace/subnetmask.py
@@ -55,9 +55,9 @@ def convert_subnet_mask_to_indices(self, subnet_mask):
"""
if not isinstance(subnet_mask, torch.Tensor):
raise ValueError('Subnetwork mask needs to be torch.Tensor!')
- elif subnet_mask.type() not in ['torch.ByteTensor', 'torch.IntTensor', 'torch.LongTensor'] or\
+ elif subnet_mask.dtype not in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool] or\
len(subnet_mask.shape) != 1:
- raise ValueError('Subnetwork mask needs to be 1-dimensional torch.{Byte,Int,Long}Tensor!')
+ raise ValueError('Subnetwork mask needs to be 1-dimensional integral or boolean tensor!')
elif len(subnet_mask) != self._n_params or\
len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1]) != self._n_params:
raise ValueError('Subnetwork mask needs to be a binary vector of size (n_params) where 1s'\
@@ -134,7 +134,7 @@ def get_subnet_mask(self, train_loader):
idx = torch.argsort(self._param_scores, descending=True)[:self._n_params_subnet]
idx = idx.sort()[0]
- subnet_mask = torch.zeros_like(self.parameter_vector).byte()
+ subnet_mask = torch.zeros_like(self.parameter_vector).bool()
subnet_mask[idx] = 1
return subnet_mask
@@ -209,7 +209,7 @@ def get_subnet_mask(self, train_loader):
else:
mask_method = torch.zeros_like
subnet_mask_list.append(mask_method(parameters_to_vector(param)))
- subnet_mask = torch.cat(subnet_mask_list).byte()
+ subnet_mask = torch.cat(subnet_mask_list).bool()
return subnet_mask
@@ -258,7 +258,7 @@ def get_subnet_mask(self, train_loader):
else:
mask_method = torch.zeros_like
subnet_mask_list.append(mask_method(parameters_to_vector(module.parameters())))
- subnet_mask = torch.cat(subnet_mask_list).byte()
+ subnet_mask = torch.cat(subnet_mask_list).bool()
return subnet_mask
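Checking `subnet_mask.dtype` instead of the `subnet_mask.type()` string is what makes the check device-independent: the string form carries the device, the dtype does not. A small illustration (the CUDA comment assumes a GPU is available):

```python
import torch

mask = torch.ones(10, dtype=torch.bool)
print(mask.type())  # 'torch.BoolTensor' on CPU, 'torch.cuda.BoolTensor' on GPU
print(mask.dtype)   # torch.bool on either device

# the dtype-based membership test used above
assert mask.dtype in [torch.int64, torch.int32, torch.int16,
                      torch.int8, torch.uint8, torch.bool]
```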
From 176d678093fa41aa267d3d522890491571ac3934 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 21 Dec 2021 09:36:38 +0000
Subject: [PATCH 22/49] Change LargestMagnitudeSubnetMask to use absolute
parameter values
---
laplace/subnetmask.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/laplace/subnetmask.py b/laplace/subnetmask.py
index caa29094..3030f163 100644
--- a/laplace/subnetmask.py
+++ b/laplace/subnetmask.py
@@ -148,7 +148,7 @@ def compute_param_scores(self, train_loader):
class LargestMagnitudeSubnetMask(ScoreBasedSubnetMask):
"""Subnetwork mask identifying the parameters with the largest magnitude. """
def compute_param_scores(self, train_loader):
- return self.parameter_vector
+ return self.parameter_vector.abs()
class LargestVarianceDiagLaplaceSubnetMask(ScoreBasedSubnetMask):
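Without `.abs()`, ranking by raw value would systematically exclude large negative weights, which have large magnitude but the lowest scores. A two-line demonstration:

```python
import torch

params = torch.tensor([-5.0, 0.1, 2.0])
by_value = torch.argsort(params, descending=True)[0]            # index 2 -> 2.0
by_magnitude = torch.argsort(params.abs(), descending=True)[0]  # index 0 -> -5.0
print(by_value.item(), by_magnitude.item())
```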
From 59603e8e7b0d1cc6c0cab165162565aead74e196 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 21 Dec 2021 09:38:34 +0000
Subject: [PATCH 23/49] Small refactoring of SubnetLaplace tests
---
tests/test_subnetlaplace.py | 7 -------
1 file changed, 7 deletions(-)
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
index b6687c96..3d15e736 100644
--- a/tests/test_subnetlaplace.py
+++ b/tests/test_subnetlaplace.py
@@ -269,13 +269,6 @@ def test_full_subnet_mask(model, likelihood, class_loader, reg_loader):
# define full model 'subnet' mask class (i.e. where all parameters are part of the subnet)
class FullSubnetMask(SubnetMask):
- @property
- def n_params_subnet(self):
- if self._n_params_subnet is None:
- self._check_select()
- self._n_params_subnet = len(self._indices)
- return self._n_params_subnet
-
def get_subnet_mask(self, train_loader):
return torch.ones(model.n_params).byte()
From 8410ecaf23477e6e75207ff0fa6fbe1105f1a2f2 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 21 Dec 2021 09:43:52 +0000
Subject: [PATCH 24/49] Add implementation of SubnetMask that selects params
with largest variance, estimated via diagonal SWAG
---
laplace/subnetmask.py | 38 ++++++++++++++++-
laplace/swag.py | 82 +++++++++++++++++++++++++++++++++++++
tests/test_subnetlaplace.py | 8 ++--
3 files changed, 124 insertions(+), 4 deletions(-)
create mode 100644 laplace/swag.py
diff --git a/laplace/subnetmask.py b/laplace/subnetmask.py
index 3030f163..36a135bd 100644
--- a/laplace/subnetmask.py
+++ b/laplace/subnetmask.py
@@ -1,11 +1,13 @@
from copy import deepcopy
import torch
+from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn.utils import parameters_to_vector
from laplace.feature_extractor import FeatureExtractor
+from laplace.swag import fit_diagonal_swag
-__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask']
+__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask', 'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask']
class SubnetMask:
@@ -172,6 +174,40 @@ def compute_param_scores(self, train_loader):
return self.diag_laplace_model.posterior_variance
+class LargestVarianceSWAGSubnetMask(ScoreBasedSubnetMask):
+ """Subnetwork mask identifying the parameters with the largest marginal variances
+ (estimated using diagonal SWAG over all model parameters).
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ n_params_subnet : int
+ number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
+ likelihood : str
+ 'classification' or 'regression'
+ swag_n_snapshots : int
+ number of model snapshots to collect for SWAG
+ swag_snapshot_freq : int
+ SWAG snapshot collection frequency (in epochs)
+ swag_lr : float
+ learning rate for SWAG snapshot collection
+ """
+ def __init__(self, model, n_params_subnet, likelihood='classification', swag_n_snapshots=40, swag_snapshot_freq=1, swag_lr=0.01):
+ super().__init__(model, n_params_subnet)
+ self.likelihood = likelihood
+ self.swag_n_snapshots = swag_n_snapshots
+ self.swag_snapshot_freq = swag_snapshot_freq
+ self.swag_lr = swag_lr
+
+ def compute_param_scores(self, train_loader):
+ if self.likelihood == 'classification':
+ criterion = CrossEntropyLoss(reduction='mean')
+ elif self.likelihood == 'regression':
+ criterion = MSELoss(reduction='mean')
+ param_variances = fit_diagonal_swag(self.model, train_loader, criterion, n_snapshots_total=self.swag_n_snapshots, snapshot_freq=self.swag_snapshot_freq, lr=self.swag_lr)
+ return param_variances
+
+
class ParamNameSubnetMask(SubnetMask):
"""Subnetwork mask corresponding to the specified parameters of the neural network.
diff --git a/laplace/swag.py b/laplace/swag.py
new file mode 100644
index 00000000..c8461fbb
--- /dev/null
+++ b/laplace/swag.py
@@ -0,0 +1,82 @@
+from copy import deepcopy
+from tqdm import tqdm
+
+import torch
+from torch.nn.utils import parameters_to_vector
+
+
+def param_vector(model):
+ return parameters_to_vector(model.parameters()).detach()
+
+
+def fit_diagonal_swag(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1, lr=0.01, momentum=0.9, weight_decay=3e-4, min_var=1e-30):
+ """
+ Fit diagonal SWAG [1], which estimates marginal variances of model parameters by
+ computing the first and second moment of SGD iterates with a large learning rate.
+
+ Implementation partly adapted from:
+ - https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
+ - https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py
+
+ References
+ ----------
+ [1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, AG.
+ [*A Simple Baseline for Bayesian Uncertainty in Deep Learning*](https://arxiv.org/abs/1902.02476).
+ NeurIPS 2019.
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ train_loader : torch.data.utils.DataLoader
+ training data loader to use for snapshot collection
+ criterion : torch.nn.CrossEntropyLoss or torch.nn.MSELoss
+ loss function to use for snapshot collection
+ n_snapshots_total : int
+ total number of model snapshots to collect
+ snapshot_freq : int
+ snapshot collection frequency (in epochs)
+ lr : float
+ SGD learning rate for collecting snapshots
+ momentum : float
+ SGD momentum
+ weight_decay : float
+ SGD weight decay
+ min_var : float
+ minimum parameter variance to clamp to (for numerical stability)
+
+ Returns
+ -------
+ param_variances : torch.Tensor
+ vector of marginal variances for each model parameter
+ """
+
+ # create a copy of the model to avoid undesired changes to the original model parameters
+ _model = deepcopy(model)
+ _model.train()
+ device = next(_model.parameters()).device
+
+ # initialize running estimates of first and second moment of model parameters
+ mean = torch.zeros_like(param_vector(_model))
+ sq_mean = torch.zeros_like(param_vector(_model))
+ n_snapshots = 0
+
+ # run SGD to collect model snapshots
+ optimizer = torch.optim.SGD(_model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
+ n_epochs = snapshot_freq * n_snapshots_total
+ for epoch in tqdm(range(n_epochs)):
+ for inputs, targets in train_loader:
+ inputs, targets = inputs.to(device), targets.to(device)
+ optimizer.zero_grad()
+ loss = criterion(_model(inputs), targets)
+ loss.backward()
+ optimizer.step()
+
+ if epoch % snapshot_freq == 0:
+ # update running estimates of first and second moment of model parameters
+ mean = mean * n_snapshots / (n_snapshots + 1) + param_vector(_model) / (n_snapshots + 1)
+ sq_mean = sq_mean * n_snapshots / (n_snapshots + 1) + param_vector(_model) ** 2 / (n_snapshots + 1)
+ n_snapshots += 1
+
+ # compute marginal parameter variances, Var[P] = E[P^2] - E[P]^2
+ param_variances = torch.clamp(sq_mean - mean ** 2, min_var)
+ return param_variances
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
index 3d15e736..1b10bb25 100644
--- a/tests/test_subnetlaplace.py
+++ b/tests/test_subnetlaplace.py
@@ -7,12 +7,12 @@
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace, SubnetLaplace
-from laplace.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask
+from laplace.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask
torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
-score_based_subnet_masks = [RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask]
+score_based_subnet_masks = [RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask]
layer_subnet_masks = [ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask]
all_subnet_masks = score_based_subnet_masks + layer_subnet_masks
likelihoods = ['classification', 'regression']
@@ -66,7 +66,7 @@ def test_subnet_laplace_init(model, likelihood):
def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_loader, reg_loader):
loader = class_loader if likelihood == 'classification' else reg_loader
model_params = parameters_to_vector(model.parameters())
- subnetmask_kwargs = dict()
+ subnetmask_kwargs = dict(likelihood=likelihood) if subnetwork_mask == LargestVarianceSWAGSubnetMask else dict()
# should raise error if we don't pass number of subnet parameters within the subnetmask_kwargs
with pytest.raises(TypeError):
@@ -300,6 +300,7 @@ def test_regression_predictive(model, reg_loader, subnetwork_mask):
subnetmask_kwargs = dict(module_names=['0'])
else:
subnetmask_kwargs = dict()
+ subnetmask_kwargs.update(dict(likelihood='regression') if subnetwork_mask == LargestVarianceSWAGSubnetMask else dict())
lap = Laplace(model, likelihood='regression', subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
assert isinstance(lap, SubnetLaplace)
assert isinstance(lap._subnetwork_mask, subnetwork_mask)
@@ -335,6 +336,7 @@ def test_classification_predictive(model, class_loader, subnetwork_mask):
subnetmask_kwargs = dict(module_names=['0'])
else:
subnetmask_kwargs = dict()
+ subnetmask_kwargs.update(dict(likelihood='classification') if subnetwork_mask == LargestVarianceSWAGSubnetMask else dict())
lap = Laplace(model, likelihood='classification', subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
assert isinstance(lap, SubnetLaplace)
assert isinstance(lap._subnetwork_mask, subnetwork_mask)
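Because `LargestVarianceSWAGSubnetMask` runs a short SGD phase internally via `fit_diagonal_swag`, it needs to know the likelihood (to pick the loss function) in addition to the subnetwork size. A usage sketch, with `model` and `train_loader` as placeholders:

```python
from laplace import Laplace
from laplace.subnetmask import LargestVarianceSWAGSubnetMask

la = Laplace(model, likelihood='classification',
             subset_of_weights='subnetwork',
             subnetwork_mask=LargestVarianceSWAGSubnetMask,
             hessian_structure='full',
             subnetmask_kwargs=dict(n_params_subnet=32,
                                    likelihood='classification',
                                    swag_n_snapshots=40,
                                    swag_snapshot_freq=1,
                                    swag_lr=0.01))
la.fit(train_loader)
```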
From e91646f2d1512128ee86acf99a532cdfad6100c8 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 21 Dec 2021 09:58:33 +0000
Subject: [PATCH 25/49] Remove tqdm dependency in SWAG
---
laplace/swag.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/laplace/swag.py b/laplace/swag.py
index c8461fbb..a9abf42b 100644
--- a/laplace/swag.py
+++ b/laplace/swag.py
@@ -1,5 +1,4 @@
from copy import deepcopy
-from tqdm import tqdm
import torch
from torch.nn.utils import parameters_to_vector
@@ -63,7 +62,7 @@ def fit_diagonal_swag(model, train_loader, criterion, n_snapshots_total=40, snap
# run SGD to collect model snapshots
optimizer = torch.optim.SGD(_model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
n_epochs = snapshot_freq * n_snapshots_total
- for epoch in tqdm(range(n_epochs)):
+ for epoch in range(n_epochs):
for inputs, targets in train_loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
From 3a48e021efd69252096748eb6ad460bb1dd37c47 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 21 Dec 2021 10:24:31 +0000
Subject: [PATCH 26/49] Set H=None in SubnetLaplace before calling super()
constructor for compatibility with fixed H init
---
laplace/subnetlaplace.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 16ed87cb..3aeb2454 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -74,6 +74,7 @@ class SubnetLaplace(FullLaplace):
def __init__(self, model, likelihood, subnetwork_mask, sigma_noise=1., prior_precision=1.,
prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None, subnetmask_kwargs=None):
+ self.H = None
super().__init__(model, likelihood, sigma_noise=sigma_noise, prior_precision=prior_precision,
prior_mean=prior_mean, temperature=temperature, backend=backend,
backend_kwargs=backend_kwargs)
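The point of assigning `self.H = None` first is ordering: the base constructor now inspects `self.H`, so the attribute must exist before `super().__init__()` runs. A schematic of the pattern with hypothetical classes (not the library's actual initialization logic):

```python
class Base:
    def __init__(self):
        # the base constructor reads self.H, so it must already be defined
        if self.H is None:
            print('deferring H initialization until its size is known')

class Sub(Base):
    def __init__(self):
        self.H = None    # define the attribute before the base constructor reads it
        super().__init__()

Sub()  # without `self.H = None`, Base.__init__ would raise AttributeError
```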
From 7ae867f46a2e194c0a0a09be112eafca83e2c23c Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 21 Dec 2021 10:43:17 +0000
Subject: [PATCH 27/49] Change indentation from tabs to spaces
---
laplace/subnetmask.py | 8 +--
laplace/swag.py | 116 +++++++++++++++++++++---------------------
2 files changed, 62 insertions(+), 62 deletions(-)
diff --git a/laplace/subnetmask.py b/laplace/subnetmask.py
index 36a135bd..939cff97 100644
--- a/laplace/subnetmask.py
+++ b/laplace/subnetmask.py
@@ -185,10 +185,10 @@ class LargestVarianceSWAGSubnetMask(ScoreBasedSubnetMask):
number of parameters in the subnetwork (i.e. number of top-scoring parameters to select)
likelihood : str
'classification' or 'regression'
- swag_n_snapshots : int
- number of model snapshots to collect for SWAG
- swag_snapshot_freq : int
- SWAG snapshot collection frequency (in epochs)
+ swag_n_snapshots : int
+ number of model snapshots to collect for SWAG
+ swag_snapshot_freq : int
+ SWAG snapshot collection frequency (in epochs)
swag_lr : float
learning rate for SWAG snapshot collection
"""
diff --git a/laplace/swag.py b/laplace/swag.py
index a9abf42b..d780fe16 100644
--- a/laplace/swag.py
+++ b/laplace/swag.py
@@ -5,17 +5,17 @@
def param_vector(model):
- return parameters_to_vector(model.parameters()).detach()
+ return parameters_to_vector(model.parameters()).detach()
def fit_diagonal_swag(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1, lr=0.01, momentum=0.9, weight_decay=3e-4, min_var=1e-30):
- """
- Fit diagonal SWAG [1], which estimates marginal variances of model parameters by
- computing the first and second moment of SGD iterates with a large learning rate.
-
- Implementation partly adapted from:
- - https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
- - https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py
+ """
+ Fit diagonal SWAG [1], which estimates marginal variances of model parameters by
+ computing the first and second moment of SGD iterates with a large learning rate.
+
+ Implementation partly adapted from:
+ - https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
+ - https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py
References
----------
@@ -23,59 +23,59 @@ def fit_diagonal_swag(model, train_loader, criterion, n_snapshots_total=40, snap
[*A Simple Baseline for Bayesian Uncertainty in Deep Learning*](https://arxiv.org/abs/1902.02476).
NeurIPS 2019.
- Parameters
- ----------
- model : torch.nn.Module
- train_loader : torch.data.utils.DataLoader
- training data loader to use for snapshot collection
- criterion : torch.nn.CrossEntropyLoss or torch.nn.MSELoss
- loss function to use for snapshot collection
- n_snapshots_total : int
- total number of model snapshots to collect
- snapshot_freq : int
- snapshot collection frequency (in epochs)
- lr : float
- SGD learning rate for collecting snapshots
- momentum : float
- SGD momentum
- weight_decay : float
- SGD weight decay
- min_var : float
- minimum parameter variance to clamp to (for numerical stability)
+ Parameters
+ ----------
+ model : torch.nn.Module
+ train_loader : torch.data.utils.DataLoader
+ training data loader to use for snapshot collection
+ criterion : torch.nn.CrossEntropyLoss or torch.nn.MSELoss
+ loss function to use for snapshot collection
+ n_snapshots_total : int
+ total number of model snapshots to collect
+ snapshot_freq : int
+ snapshot collection frequency (in epochs)
+ lr : float
+ SGD learning rate for collecting snapshots
+ momentum : float
+ SGD momentum
+ weight_decay : float
+ SGD weight decay
+ min_var : float
+ minimum parameter variance to clamp to (for numerical stability)
- Returns
- -------
- param_variances : torch.Tensor
- vector of marginal variances for each model parameter
- """
+ Returns
+ -------
+ param_variances : torch.Tensor
+ vector of marginal variances for each model parameter
+ """
- # create a copy of the model to avoid undesired changes to the original model parameters
- _model = deepcopy(model)
- _model.train()
- device = next(_model.parameters()).device
+ # create a copy of the model to avoid undesired changes to the original model parameters
+ _model = deepcopy(model)
+ _model.train()
+ device = next(_model.parameters()).device
- # initialize running estimates of first and second moment of model parameters
- mean = torch.zeros_like(param_vector(_model))
- sq_mean = torch.zeros_like(param_vector(_model))
- n_snapshots = 0
+ # initialize running estimates of first and second moment of model parameters
+ mean = torch.zeros_like(param_vector(_model))
+ sq_mean = torch.zeros_like(param_vector(_model))
+ n_snapshots = 0
- # run SGD to collect model snapshots
- optimizer = torch.optim.SGD(_model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
- n_epochs = snapshot_freq * n_snapshots_total
- for epoch in range(n_epochs):
- for inputs, targets in train_loader:
- inputs, targets = inputs.to(device), targets.to(device)
- optimizer.zero_grad()
- loss = criterion(_model(inputs), targets)
- loss.backward()
- optimizer.step()
+ # run SGD to collect model snapshots
+ optimizer = torch.optim.SGD(_model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
+ n_epochs = snapshot_freq * n_snapshots_total
+ for epoch in range(n_epochs):
+ for inputs, targets in train_loader:
+ inputs, targets = inputs.to(device), targets.to(device)
+ optimizer.zero_grad()
+ loss = criterion(_model(inputs), targets)
+ loss.backward()
+ optimizer.step()
- if epoch % snapshot_freq == 0:
- # update running estimates of first and second moment of model parameters
- mean = mean * n_snapshots / (n_snapshots + 1) + param_vector(_model) / (n_snapshots + 1)
- sq_mean = sq_mean * n_snapshots / (n_snapshots + 1) + param_vector(_model) ** 2 / (n_snapshots + 1)
- n_snapshots += 1
+ if epoch % snapshot_freq == 0:
+ # update running estimates of first and second moment of model parameters
+ mean = mean * n_snapshots / (n_snapshots + 1) + param_vector(_model) / (n_snapshots + 1)
+ sq_mean = sq_mean * n_snapshots / (n_snapshots + 1) + param_vector(_model) ** 2 / (n_snapshots + 1)
+ n_snapshots += 1
- # compute marginal parameter variances, Var[P] = E[P^2] - E[P]^2
- param_variances = torch.clamp(sq_mean - mean ** 2, min_var)
- return param_variances
+ # compute marginal parameter variances, Var[P] = E[P^2] - E[P]^2
+ param_variances = torch.clamp(sq_mean - mean ** 2, min_var)
+ return param_variances
From 015e91781933c0fc2669986d78e96f8af76943ce Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 21 Dec 2021 11:51:55 +0000
Subject: [PATCH 28/49] Update README: include subnetwork and low-rank Laplace
& update paper reference to NeurIPS'21
---
README.md | 33 +++++++++++++++++++++------------
1 file changed, 21 insertions(+), 12 deletions(-)
diff --git a/README.md b/README.md
index 079d1229..7d72853d 100644
--- a/README.md
+++ b/README.md
@@ -4,17 +4,17 @@
[![Main](https://travis-ci.com/AlexImmer/Laplace.svg?token=rpuRxEjQS6cCZi7ptL9y&branch=main)](https://travis-ci.com/AlexImmer/Laplace)
-The laplace package facilitates the application of Laplace approximations for entire neural networks or just their last layer.
+The laplace package facilitates the application of Laplace approximations for entire neural networks (NNs), subnetworks of NNs, or just their last layer.
The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations.
The library documentation is available at [https://aleximmer.github.io/Laplace](https://aleximmer.github.io/Laplace).
There is also a corresponding paper, [*Laplace Redux — Effortless Bayesian Deep Learning*](https://arxiv.org/abs/2106.14806), which introduces the library, provides an introduction to the Laplace approximation, reviews its use in deep learning, and empirically demonstrates its versatility and competitiveness. Please consider referring to the paper when using our library:
```bibtex
-@article{daxberger2021laplace,
- title={Laplace Redux--Effortless Bayesian Deep Learning},
- author={Daxberger, Erik and Kristiadi, Agustinus and Immer, Alexander
- and Eschenhagen, Runa and Bauer, Matthias and Hennig, Philipp},
- journal={arXiv preprint arXiv:2106.14806},
+@inproceedings{laplace2021,
+ title={Laplace Redux--Effortless {B}ayesian Deep Learning},
+ author={Erik Daxberger and Agustinus Kristiadi and Alexander Immer
+ and Runa Eschenhagen and Matthias Bauer and Philipp Hennig},
+ booktitle={{N}eur{IPS}},
year={2021}
}
```
@@ -39,7 +39,7 @@ pytest tests/
## Structure
The laplace package consists of two main components:
-1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, and `'diag'`). This results in six currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, and the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace`, which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function.
+1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'` and `'diag'`). This results in eight currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace` (which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py)), `laplace.SubnetLaplace` (which only supports a `'full'` Hessian approximation) and `laplace.LowRankLaplace` (which only supports inference over `'all'` weights). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function.
2. The backends in [`laplace.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/) which provide access to Hessian approximations of
the corresponding sparsity structures, for example, the diagonal GGN.
@@ -48,9 +48,15 @@ decomposing a neural network into feature extractor and last layer for `LLLaplac
and
effectively dealing with Kronecker factors ([`laplace.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/matrix.py)).
+Finally, the package implements several options to select/specify the subnetwork for `laplace.SubnetLaplace` (as subclasses of [`laplace.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetmask.py)).
+Automatic subnetwork selection strategies include: uniformly at random (`laplace.subnetmask.RandomSubnetMask`), by largest parameter magnitudes (`laplace.subnetmask.LargestMagnitudeSubnetMask`), and by largest marginal parameter variances (`laplace.subnetmask.LargestVarianceDiagLaplaceSubnetMask` and `laplace.subnetmask.LargestVarianceSWAGSubnetMask`).
+In addition, subnetworks can also be specified manually by listing the names of either the model parameters (`laplace.subnetmask.ParamNameSubnetMask`) or modules (`laplace.subnetmask.ModuleNameSubnetMask`) to perform Laplace inference over.
+
## Extendability
To extend the laplace package, new `BaseLaplace` subclasses can be designed, for example,
-a block-diagonal structure or subset-of-weights Laplace.
+Laplace with a block-diagonal Hessian structure.
+One can also implement custom subnetwork selection strategies as new subclasses of `SubnetMask`.
+
Alternatively, extending or integrating backends (subclasses of [`curvature.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvature.py)) allows one to provide different Hessian
approximations to the Laplace approximations.
For example, currently the [`curvature.BackPackInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/backpack.py) based on [BackPACK](https://github.com/f-dangel/backpack/) and [`curvature.AsdlInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/asdl.py) based on [ASDL](https://github.com/kazukiosawa/asdfghjkl) are available.
@@ -60,10 +66,11 @@ for a regression (MSELoss) loss function.
## Example usage
-### *Post-hoc* prior precision tuning of last-layer LA
+### *Post-hoc* prior precision tuning of diagonal LA
In the following example, a pre-trained model is loaded,
-then the Laplace approximation is fit to the training data,
+then the Laplace approximation is fit to the training data
+(using a diagonal Hessian approximation over all parameters),
and the prior precision is optimized with cross-validation `'CV'`.
After that, the resulting LA is used for prediction with
the `'probit'` predictive for classification.
@@ -122,7 +129,7 @@ pdoc --http 0.0.0.0:8080 laplace --template-dir template
## References
-This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1].
+This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1]. Please consider citing the respective papers if you use any of their proposed methods via our laplace library.
- [1] MacKay, DJC. [*A Practical Bayesian Framework for Backpropagation Networks*](https://authors.library.caltech.edu/13793/). Neural Computation 1992.
- [2] Gibbs, M. N. [*Bayesian Gaussian Processes for Regression and Classification*](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.147.1130&rep=rep1&type=pdf). PhD Thesis 1997.
@@ -132,4 +139,6 @@ This package relies on various improvements to the Laplace approximation for neu
- [6] Khan, M. E., Immer, A., Abedi, E., Korzepa, M. [*Approximate Inference Turns Deep Networks into Gaussian Processes*](https://arxiv.org/abs/1906.01930). NeurIPS 2019.
- [7] Kristiadi, A., Hein, M., Hennig, P. [*Being Bayesian, Even Just a Bit, Fixes Overconfidence in ReLU Networks*](https://arxiv.org/abs/2002.10118). ICML 2020.
- [8] Immer, A., Korzepa, M., Bauer, M. [*Improving predictions of Bayesian neural nets via local linearization*](https://arxiv.org/abs/2008.08400). AISTATS 2021.
-- [9] Immer, A., Bauer, M., Fortuin, V., Rätsch, G., Khan, EM. [*Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning*](https://arxiv.org/abs/2104.04975). ICML 2021.
+- [9] Sharma, A., Azizan, N., Pavone, M. [*Sketching Curvature for Efficient Out-of-Distribution Detection for Deep Neural Networks*](https://arxiv.org/abs/2102.12567). UAI 2021.
+- [10] Immer, A., Bauer, M., Fortuin, V., Rätsch, G., Khan, EM. [*Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning*](https://arxiv.org/abs/2104.04975). ICML 2021.
+- [11] Daxberger, E., Nalisnick, E., Allingham, JU., Antorán, J., Hernández-Lobato, JM. [*Bayesian Deep Learning via Subnetwork Inference*](https://arxiv.org/abs/2010.14689). ICML 2021.
\ No newline at end of file
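The updated "Example usage" section describes the standard post-hoc workflow; as a sketch (with `model`, `train_loader`, `val_loader`, and `x` as placeholders), it corresponds to:

```python
from laplace import Laplace

# post-hoc Laplace over all weights with a diagonal Hessian approximation
la = Laplace(model, likelihood='classification',
             subset_of_weights='all', hessian_structure='diag')
la.fit(train_loader)

# tune the prior precision by cross-validation on a held-out loader
la.optimize_prior_precision(method='CV', val_loader=val_loader)

# probit-approximated predictive for classification
pred = la(x, link_approx='probit')
```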
From fce8545d76a3b53a56a112e909460b1f0efe1666 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 21 Dec 2021 11:57:33 +0000
Subject: [PATCH 29/49] Minor changes to README
---
README.md | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/README.md b/README.md
index 7d72853d..1773728e 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
[![Main](https://travis-ci.com/AlexImmer/Laplace.svg?token=rpuRxEjQS6cCZi7ptL9y&branch=main)](https://travis-ci.com/AlexImmer/Laplace)
-The laplace package facilitates the application of Laplace approximations for entire neural networks (NNs), subnetworks of NNs, or just their last layer.
+The laplace package facilitates the application of Laplace approximations for entire neural networks, subnetworks of neural networks, or just their last layer.
The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations.
The library documentation is available at [https://aleximmer.github.io/Laplace](https://aleximmer.github.io/Laplace).
@@ -39,7 +39,7 @@ pytest tests/
## Structure
The laplace package consists of two main components:
-1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'` and `'diag'`). This results in eight currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace` (which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py)), `laplace.SubnetLaplace` (which only supports a `'full'` Hessian approximation) and `laplace.LowRankLaplace` (which only supports inference over `'all'` weights). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function.
+1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'` and `'diag'`). This results in _eight_ currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace` (which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py)), `laplace.SubnetLaplace` (which only supports a `'full'` Hessian approximation) and `laplace.LowRankLaplace` (which only supports inference over `'all'` weights). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function.
2. The backends in [`laplace.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/) which provide access to Hessian approximations of
the corresponding sparsity structures, for example, the diagonal GGN.
@@ -48,9 +48,9 @@ decomposing a neural network into feature extractor and last layer for `LLLaplac
and
effectively dealing with Kronecker factors ([`laplace.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/matrix.py)).
-Finally, the package implements several options to select/specify the subnetwork for `laplace.SubnetLaplace` (as subclasses of ([`laplace.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetmask.py)).
-Automatic subnetwork selection strategies include: uniformly at random (`laplace.subnetmask.RandomSubnetMask`), by largest parameter magnitudes (`laplace.subnetmask.LargestMagnitudeSubnetMask`) and by largest marginal parameter variances (`laplace.subnetmask.LargestVarianceDiagLaplaceSubnetMask` and `laplace.subnetmask.LargestVarianceSWAGSubnetMask`).
-In addition to that, subnetworks can also be specified manually, by listing the names of either the model parameters (`laplace.subnetmask.ParamNameSubnetMask`) or modules (`laplace.subnetmask.ModuleNameSubnetMask`) to perform Laplace inference over.
+Finally, the package implements several options to select/specify a subnetwork for `laplace.SubnetLaplace` (as subclasses of [`laplace.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetmask.py)).
+Automatic subnetwork selection strategies include: uniformly at random (`laplace.subnetmask.RandomSubnetMask`), by largest parameter magnitudes (`LargestMagnitudeSubnetMask`), and by largest marginal parameter variances (`LargestVarianceDiagLaplaceSubnetMask` and `LargestVarianceSWAGSubnetMask`).
+In addition to that, subnetworks can also be specified manually, by listing the names of either the model parameters (`ParamNameSubnetMask`) or modules (`ModuleNameSubnetMask`) to perform Laplace inference over.
## Extendability
To extend the laplace package, new `BaseLaplace` subclasses can be designed, for example,
From 55e418a3b561f3a5c874aa8086a1350ecad61d2b Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Mon, 3 Jan 2022 10:40:19 +0000
Subject: [PATCH 30/49] Move utility files to new utils directory
---
README.md | 8 ++++----
laplace/baselaplace.py | 8 ++++----
laplace/curvature/asdl.py | 4 ++--
laplace/curvature/backpack.py | 2 +-
laplace/curvature/curvature.py | 8 ++++----
laplace/lllaplace.py | 8 ++++----
laplace/marglik_training.py | 2 +-
laplace/subnetlaplace.py | 6 +++---
laplace/{ => utils}/feature_extractor.py | 0
laplace/{ => utils}/matrix.py | 2 +-
laplace/{ => utils}/subnetmask.py | 4 ++--
laplace/{ => utils}/swag.py | 0
laplace/{ => utils}/utils.py | 0
tests/test_baselaplace.py | 2 +-
tests/test_feature_extractor.py | 2 +-
tests/test_jacobians.py | 3 +--
tests/test_lllaplace.py | 4 ++--
tests/test_matrix.py | 6 +++---
tests/test_subnetlaplace.py | 2 +-
tests/test_utils.py | 2 +-
20 files changed, 36 insertions(+), 37 deletions(-)
rename laplace/{ => utils}/feature_extractor.py (100%)
rename laplace/{ => utils}/matrix.py (99%)
rename laplace/{ => utils}/subnetmask.py (99%)
rename laplace/{ => utils}/swag.py (100%)
rename laplace/{ => utils}/utils.py (100%)
diff --git a/README.md b/README.md
index 1773728e..fbe90ee0 100644
--- a/README.md
+++ b/README.md
@@ -44,12 +44,12 @@ The laplace package consists of two main components:
the corresponding sparsity structures, for example, the diagonal GGN.
Additionally, the package provides utilities for
-decomposing a neural network into feature extractor and last layer for `LLLaplace` subclasses ([`laplace.feature_extractor`](https://github.com/AlexImmer/Laplace/blob/main/laplace/feature_extractor.py))
+decomposing a neural network into feature extractor and last layer for `LLLaplace` subclasses ([`laplace.utils.feature_extractor`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/feature_extractor.py))
and
-effectively dealing with Kronecker factors ([`laplace.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/matrix.py)).
+effectively dealing with Kronecker factors ([`laplace.utils.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/matrix.py)).
-Finally, the package implements several options to select/specify a subnetwork for `laplace.SubnetLaplace` (as subclasses of [`laplace.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetmask.py)).
-Automatic subnetwork selection strategies include: uniformly at random (`laplace.subnetmask.RandomSubnetMask`), by largest parameter magnitudes (`LargestMagnitudeSubnetMask`), and by largest marginal parameter variances (`LargestVarianceDiagLaplaceSubnetMask` and `LargestVarianceSWAGSubnetMask`).
+Finally, the package implements several options to select/specify a subnetwork for `laplace.SubnetLaplace` (as subclasses of [`laplace.utils.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/subnetmask.py)).
+Automatic subnetwork selection strategies include: uniformly at random (`laplace.utils.subnetmask.RandomSubnetMask`), by largest parameter magnitudes (`LargestMagnitudeSubnetMask`), and by largest marginal parameter variances (`LargestVarianceDiagLaplaceSubnetMask` and `LargestVarianceSWAGSubnetMask`).
In addition to that, subnetworks can also be specified manually, by listing the names of either the model parameters (`ParamNameSubnetMask`) or modules (`ModuleNameSubnetMask`) to perform Laplace inference over.
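As an illustrative sketch of the automatic selection strategies listed above (`model` and `train_loader` are assumed to be defined; the constructor signature matches the one used in the tests further below):

```python
from laplace.utils.subnetmask import LargestMagnitudeSubnetMask

# select the 128 largest-magnitude weights as the subnetwork
subnetmask = LargestMagnitudeSubnetMask(model=model, n_params_subnet=128)
subnetmask.select(train_loader)
subnetwork_indices = subnetmask.indices  # LongTensor of selected parameter indices
```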
## Extendability
diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py
index 522fe047..2c382d98 100644
--- a/laplace/baselaplace.py
+++ b/laplace/baselaplace.py
@@ -4,8 +4,8 @@
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.distributions import MultivariateNormal, Dirichlet, Normal
-from laplace.utils import parameters_per_layer, invsqrt_precision, get_nll, validate
-from laplace.matrix import Kron
+from laplace.utils.utils import parameters_per_layer, invsqrt_precision, get_nll, validate
+from laplace.utils.matrix import Kron
from laplace.curvature import BackPackGGN, AsdlHessian
@@ -754,7 +754,7 @@ class KronLaplace(ParametricLaplace):
Mathematically, we have for each parameter group, e.g., torch.nn.Module,
that \\(P \\approx Q \\otimes H\\).
See `BaseLaplace` for the full interface and see
- `laplace.matrix.Kron` and `laplace.matrix.KronDecomposed` for the structure of
+ `laplace.utils.matrix.Kron` and `laplace.utils.matrix.KronDecomposed` for the structure of
the Kronecker factors. `Kron` is used to aggregate factors by summing up and
`KronDecomposed` is used to add the prior, a Hessian factor (e.g. temperature),
and computing posterior covariances, marginal likelihood, etc.
@@ -812,7 +812,7 @@ def posterior_precision(self):
Returns
-------
- precision : `laplace.matrix.KronDecomposed`
+ precision : `laplace.utils.matrix.KronDecomposed`
"""
self._check_H_init()
return self.H * self._H_factor + self.prior_precision
diff --git a/laplace/curvature/asdl.py b/laplace/curvature/asdl.py
index fc7eda9a..a3ccd3d6 100644
--- a/laplace/curvature/asdl.py
+++ b/laplace/curvature/asdl.py
@@ -9,8 +9,8 @@
from asdfghjkl.gradient import batch_gradient
from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface
-from laplace.matrix import Kron
-from laplace.utils import _is_batchnorm
+from laplace.utils.matrix import Kron
+from laplace.utils.utils import _is_batchnorm
EPS = 1e-6
diff --git a/laplace/curvature/backpack.py b/laplace/curvature/backpack.py
index a0885800..5c78a093 100644
--- a/laplace/curvature/backpack.py
+++ b/laplace/curvature/backpack.py
@@ -5,7 +5,7 @@
from backpack.context import CTX
from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface
-from laplace.matrix import Kron
+from laplace.utils.matrix import Kron
class BackPackInterface(CurvatureInterface):
diff --git a/laplace/curvature/curvature.py b/laplace/curvature/curvature.py
index 98b703b7..72b0a041 100644
--- a/laplace/curvature/curvature.py
+++ b/laplace/curvature/curvature.py
@@ -11,7 +11,7 @@ class CurvatureInterface:
Parameters
----------
- model : torch.nn.Module or `laplace.feature_extractor.FeatureExtractor`
+ model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`
torch model (neural network)
likelihood : {'classification', 'regression'}
last_layer : bool, default=False
@@ -143,7 +143,7 @@ def kron(self, x, y, **kwargs):
Returns
-------
loss : torch.Tensor
- H : `laplace.matrix.Kron`
+ H : `laplace.utils.matrix.Kron`
Kronecker factored Hessian approximation.
"""
raise NotImplementedError
@@ -175,7 +175,7 @@ class GGNInterface(CurvatureInterface):
Parameters
----------
- model : torch.nn.Module or `laplace.feature_extractor.FeatureExtractor`
+ model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`
torch model (neural network)
likelihood : {'classification', 'regression'}
last_layer : bool, default=False
@@ -254,7 +254,7 @@ class EFInterface(CurvatureInterface):
Parameters
----------
- model : torch.nn.Module or `laplace.feature_extractor.FeatureExtractor`
+ model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`
torch model (neural network)
likelihood : {'classification', 'regression'}
last_layer : bool, default=False
diff --git a/laplace/lllaplace.py b/laplace/lllaplace.py
index b00dbf55..8f93ff53 100644
--- a/laplace/lllaplace.py
+++ b/laplace/lllaplace.py
@@ -3,9 +3,9 @@
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from laplace.baselaplace import ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace
-from laplace.feature_extractor import FeatureExtractor
+from laplace.utils.feature_extractor import FeatureExtractor
-from laplace.matrix import Kron
+from laplace.utils.matrix import Kron
from laplace.curvature import BackPackGGN
@@ -36,7 +36,7 @@ class LLLaplace(ParametricLaplace):
Parameters
----------
- model : torch.nn.Module or `laplace.feature_extractor.FeatureExtractor`
+ model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`
likelihood : {'classification', 'regression'}
determines the log likelihood Hessian approximation
sigma_noise : torch.Tensor or float, default=1
@@ -168,7 +168,7 @@ class KronLLLaplace(LLLaplace, KronLaplace):
Mathematically, we have for the last parameter group, i.e., torch.nn.Linear,
that \\(P \\approx Q \\otimes H\\).
See `KronLaplace`, `LLLaplace`, and `BaseLaplace` for the full interface and see
- `laplace.matrix.Kron` and `laplace.matrix.KronDecomposed` for the structure of
+ `laplace.utils.matrix.Kron` and `laplace.utils.matrix.KronDecomposed` for the structure of
the Kronecker factors. `Kron` is used to aggregate factors by summing up and
`KronDecomposed` is used to add the prior, a Hessian factor (e.g. temperature),
and computing posterior covariances, marginal likelihood, etc.
diff --git a/laplace/marglik_training.py b/laplace/marglik_training.py
index ec100542..3a2b2161 100644
--- a/laplace/marglik_training.py
+++ b/laplace/marglik_training.py
@@ -9,7 +9,7 @@
from laplace import Laplace
from laplace.curvature import AsdlGGN
-from laplace.utils import expand_prior_precision
+from laplace.utils.utils import expand_prior_precision
def marglik_training(
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 3aeb2454..dc952d0a 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -4,7 +4,7 @@
from laplace.baselaplace import FullLaplace, DiagLaplace
from laplace.curvature import BackPackGGN
-from laplace.subnetmask import LargestVarianceDiagLaplaceSubnetMask
+from laplace.utils.subnetmask import LargestVarianceDiagLaplaceSubnetMask
__all__ = ['SubnetLaplace']
@@ -46,10 +46,10 @@ class SubnetLaplace(FullLaplace):
Parameters
----------
- model : torch.nn.Module or `laplace.feature_extractor.FeatureExtractor`
+ model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`
likelihood : {'classification', 'regression'}
determines the log likelihood Hessian approximation
- subnetwork_mask : subclasses of `laplace.subnetmask.SubnetMask`, default=None
+ subnetwork_mask : subclasses of `laplace.utils.subnetmask.SubnetMask`, default=None
mask defining the subnetwork to apply the Laplace approximation over
sigma_noise : torch.Tensor or float, default=1
observation noise for the regression setting; must be 1 for classification
diff --git a/laplace/feature_extractor.py b/laplace/utils/feature_extractor.py
similarity index 100%
rename from laplace/feature_extractor.py
rename to laplace/utils/feature_extractor.py
diff --git a/laplace/matrix.py b/laplace/utils/matrix.py
similarity index 99%
rename from laplace/matrix.py
rename to laplace/utils/matrix.py
index 61c07ab5..30b0245d 100644
--- a/laplace/matrix.py
+++ b/laplace/utils/matrix.py
@@ -3,7 +3,7 @@
import numpy as np
from typing import Union
-from laplace.utils import _is_valid_scalar, symeig, kron, block_diag
+from laplace.utils.utils import _is_valid_scalar, symeig, kron, block_diag
class Kron:
diff --git a/laplace/subnetmask.py b/laplace/utils/subnetmask.py
similarity index 99%
rename from laplace/subnetmask.py
rename to laplace/utils/subnetmask.py
index 939cff97..a5cdc3d0 100644
--- a/laplace/subnetmask.py
+++ b/laplace/utils/subnetmask.py
@@ -4,8 +4,8 @@
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn.utils import parameters_to_vector
-from laplace.feature_extractor import FeatureExtractor
-from laplace.swag import fit_diagonal_swag
+from laplace.utils.feature_extractor import FeatureExtractor
+from laplace.utils.swag import fit_diagonal_swag
__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask', 'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask']
diff --git a/laplace/swag.py b/laplace/utils/swag.py
similarity index 100%
rename from laplace/swag.py
rename to laplace/utils/swag.py
diff --git a/laplace/utils.py b/laplace/utils/utils.py
similarity index 100%
rename from laplace/utils.py
rename to laplace/utils/utils.py
diff --git a/tests/test_baselaplace.py b/tests/test_baselaplace.py
index 75529be8..a9292e8d 100644
--- a/tests/test_baselaplace.py
+++ b/tests/test_baselaplace.py
@@ -12,7 +12,7 @@
from torchvision.models import wide_resnet50_2
from laplace.laplace import FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace
-from laplace.matrix import KronDecomposed
+from laplace.utils.matrix import KronDecomposed
from tests.utils import jacobians_naive
diff --git a/tests/test_feature_extractor.py b/tests/test_feature_extractor.py
index 37494d76..b80bbcb4 100644
--- a/tests/test_feature_extractor.py
+++ b/tests/test_feature_extractor.py
@@ -2,7 +2,7 @@
import torch.nn as nn
import torchvision.models as models
-from laplace.feature_extractor import FeatureExtractor
+from laplace.utils.feature_extractor import FeatureExtractor
class CNN(nn.Module):
diff --git a/tests/test_jacobians.py b/tests/test_jacobians.py
index 45cd2f37..0495adb3 100644
--- a/tests/test_jacobians.py
+++ b/tests/test_jacobians.py
@@ -1,10 +1,9 @@
import pytest
import torch
from torch import nn
-from torch.nn.utils import parameters_to_vector
from laplace.curvature import AsdlInterface, BackPackInterface
-from laplace.feature_extractor import FeatureExtractor
+from laplace.utils.feature_extractor import FeatureExtractor
from tests.utils import jacobians_naive
diff --git a/tests/test_lllaplace.py b/tests/test_lllaplace.py
index 65fbf1a3..bc6f7a5e 100644
--- a/tests/test_lllaplace.py
+++ b/tests/test_lllaplace.py
@@ -8,8 +8,8 @@
from torch.distributions import Normal, Categorical
from torchvision.models import wide_resnet50_2
-from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace
-from laplace.feature_extractor import FeatureExtractor
+from laplace.lllaplace import FullLLLaplace, KronLLLaplace, DiagLLLaplace
+from laplace.utils.feature_extractor import FeatureExtractor
from tests.utils import jacobians_naive
diff --git a/tests/test_matrix.py b/tests/test_matrix.py
index fb5bef1e..66a5da48 100644
--- a/tests/test_matrix.py
+++ b/tests/test_matrix.py
@@ -4,10 +4,10 @@
from torch import nn
from torch.nn.utils import parameters_to_vector
-from laplace.matrix import Kron, KronDecomposed
-from laplace.utils import kron as kron_prod
+from laplace.utils.matrix import Kron
+from laplace.utils.utils import kron as kron_prod
from laplace.curvature import BackPackGGN
-from laplace.utils import block_diag
+from laplace.utils.utils import block_diag
from tests.utils import get_psd_matrix, jacobians_naive
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
index 1b10bb25..9007a7a5 100644
--- a/tests/test_subnetlaplace.py
+++ b/tests/test_subnetlaplace.py
@@ -7,7 +7,7 @@
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace, SubnetLaplace
-from laplace.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask
+from laplace.utils.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask
torch.manual_seed(240)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 1ad0f517..b673be3d 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,5 +1,5 @@
import torch
-from laplace.utils import invsqrt_precision, diagonal_add_scalar, symeig
+from laplace.utils.utils import invsqrt_precision, diagonal_add_scalar, symeig
def test_sqrt_precision():
From 6044316168ad38cd77e017f1d6facb6745df42b7 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Mon, 3 Jan 2022 14:56:51 +0000
Subject: [PATCH 31/49] Change SubnetLaplace to take subnetwork indices instead
of a subclass of SubnetMask
---
laplace/subnetlaplace.py | 62 ++++-----
laplace/utils/subnetmask.py | 13 +-
tests/test_subnetlaplace.py | 243 ++++++++++++++++++++----------------
3 files changed, 174 insertions(+), 144 deletions(-)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index dc952d0a..25c5ae96 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -1,10 +1,8 @@
import torch
from torch.distributions import MultivariateNormal
-from laplace.baselaplace import FullLaplace, DiagLaplace
-
+from laplace.baselaplace import FullLaplace
from laplace.curvature import BackPackGGN
-from laplace.utils.subnetmask import LargestVarianceDiagLaplaceSubnetMask
__all__ = ['SubnetLaplace']
@@ -49,8 +47,9 @@ class SubnetLaplace(FullLaplace):
model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`
likelihood : {'classification', 'regression'}
determines the log likelihood Hessian approximation
- subnetwork_mask : subclasses of `laplace.utils.subnetmask.SubnetMask`, default=None
- mask defining the subnetwork to apply the Laplace approximation over
+ subnetwork_indices : torch.Tensor, default=None
+ indices of the vectorized model parameters that define the subnetwork
+ to apply the Laplace approximation over
sigma_noise : torch.Tensor or float, default=1
observation noise for the regression setting; must be 1 for classification
prior_precision : torch.Tensor or float, default=1
@@ -66,29 +65,38 @@ class SubnetLaplace(FullLaplace):
backend_kwargs : dict, default=None
arguments passed to the backend on initialization, for example to
set the number of MC samples for stochastic approximations.
- subnetmask_kwargs : dict, default=None
- arguments passed to the subnetwork mask on initialization.
"""
# key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure)
_key = ('subnetwork', 'full')
- def __init__(self, model, likelihood, subnetwork_mask, sigma_noise=1., prior_precision=1.,
- prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None, subnetmask_kwargs=None):
+ def __init__(self, model, likelihood, subnetwork_indices=None, sigma_noise=1., prior_precision=1.,
+ prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None):
self.H = None
super().__init__(model, likelihood, sigma_noise=sigma_noise, prior_precision=prior_precision,
prior_mean=prior_mean, temperature=temperature, backend=backend,
backend_kwargs=backend_kwargs)
- self._subnetmask_kwargs = dict() if subnetmask_kwargs is None else subnetmask_kwargs
- if subnetwork_mask == LargestVarianceDiagLaplaceSubnetMask:
- # instantiate and pass diagonal Laplace model for largest variance subnetwork selection
- self._subnetmask_kwargs.update(diag_laplace_model=DiagLaplace(self.model, likelihood, sigma_noise,
- prior_precision, prior_mean, temperature, backend, backend_kwargs))
- self._subnetwork_mask = subnetwork_mask(self.model, **self._subnetmask_kwargs)
- self.n_params_subnet = None
+ # check validity of subnetwork indices and pass them to backend
+ self._check_subnetwork_indices(subnetwork_indices)
+ self.backend.subnetwork_indices = subnetwork_indices
+ self.n_params_subnet = len(subnetwork_indices)
def _init_H(self):
self.H = torch.zeros(self.n_params_subnet, self.n_params_subnet, device=self._device)
+ def _check_subnetwork_indices(self, subnetwork_indices):
+ """Check that subnetwork indices are valid indices of the vectorized model parameters.
+ """
+ if subnetwork_indices is None:
+ raise ValueError('Subnetwork indices cannot be None.')
+ elif not (isinstance(subnetwork_indices, torch.Tensor) and len(subnetwork_indices.shape) == 1 and\
+ subnetwork_indices.dtype in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8]):
+ raise ValueError('Subnetwork indices need to be 1-dimensional integral torch.Tensor!')
+ elif not (len(subnetwork_indices[subnetwork_indices < 0]) == 0 and\
+ len(subnetwork_indices[subnetwork_indices >= self.n_params]) == 0):
+ raise ValueError(f'Subnetwork indices must lie between 0 and n_params={self.n_params}.')
+ elif not (subnetwork_indices.sort()[0].equal(torch.unique(subnetwork_indices, sorted=True))):
+ raise ValueError('Subnetwork indices must be unique.')
+
@property
def prior_precision_diag(self):
"""Obtain the diagonal prior precision \\(p_0\\) constructed from either
@@ -107,31 +115,13 @@ def prior_precision_diag(self):
else:
raise ValueError('Mismatch of prior and model. Diagonal or scalar prior.')
- def fit(self, train_loader):
- """Fit the local Laplace approximation at the parameters of the subnetwork.
-
- Parameters
- ----------
- train_loader : torch.data.utils.DataLoader
- each iterate is a training batch (X, y);
- `train_loader.dataset` needs to be set to access \\(N\\), size of the data set
- """
-
- # select subnetwork and pass it to backend
- self._subnetwork_mask.select(train_loader)
- self.backend.subnetwork_indices = self._subnetwork_mask.indices
- self.n_params_subnet = self._subnetwork_mask.n_params_subnet
-
- # fit Laplace approximation over subnetwork
- super().fit(train_loader)
-
def sample(self, n_samples=100):
# sample parameters just of the subnetwork
- subnet_mean = self.mean[self._subnetwork_mask.indices]
+ subnet_mean = self.mean[self.backend.subnetwork_indices]
dist = MultivariateNormal(loc=subnet_mean, scale_tril=self.posterior_scale)
subnet_samples = dist.sample((n_samples,))
# set all other parameters to their MAP estimates
full_samples = self.mean.repeat(n_samples, 1)
- full_samples[:, self._subnetwork_mask.indices] = subnet_samples
+ full_samples[:, self.backend.subnetwork_indices] = subnet_samples
return full_samples
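The `sample` method above draws posterior samples for the subnetwork only and keeps the MAP estimate everywhere else; a hypothetical check of that property, assuming `la` is a fitted `SubnetLaplace`:

```python
import torch

samples = la.sample(n_samples=16)                      # shape (16, n_params)
non_subnet = torch.ones(la.n_params, dtype=torch.bool)
non_subnet[la.backend.subnetwork_indices] = False
# outside the subnetwork, every sample equals the MAP estimate
assert torch.allclose(samples[:, non_subnet],
                      la.mean[non_subnet].expand(16, -1))
```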
diff --git a/laplace/utils/subnetmask.py b/laplace/utils/subnetmask.py
index a5cdc3d0..f1213d55 100644
--- a/laplace/utils/subnetmask.py
+++ b/laplace/utils/subnetmask.py
@@ -68,12 +68,12 @@ def convert_subnet_mask_to_indices(self, subnet_mask):
subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
return subnet_mask_indices
- def select(self, train_loader):
+ def select(self, train_loader=None):
""" Select the subnetwork mask.
Parameters
----------
- train_loader : torch.data.utils.DataLoader
+ train_loader : torch.data.utils.DataLoader, default=None
each iterate is a training batch (X, y);
`train_loader.dataset` needs to be set to access \\(N\\), size of the data set
"""
@@ -170,6 +170,9 @@ def __init__(self, model, n_params_subnet, diag_laplace_model):
self.diag_laplace_model = diag_laplace_model
def compute_param_scores(self, train_loader):
+ if train_loader is None:
+ raise ValueError('Need to pass train loader for subnet selection.')
+
self.diag_laplace_model.fit(train_loader)
return self.diag_laplace_model.posterior_variance
@@ -200,6 +203,9 @@ def __init__(self, model, n_params_subnet, likelihood='classification', swag_n_s
self.swag_lr = swag_lr
def compute_param_scores(self, train_loader):
+ if train_loader is None:
+ raise ValueError('Need to pass train loader for subnet selection.')
+
if self.likelihood == 'classification':
criterion = CrossEntropyLoss(reduction='mean')
elif self.likelihood == 'regression':
@@ -315,6 +321,9 @@ def __init__(self, model, last_layer_name=None):
def get_subnet_mask(self, train_loader):
""" Get the subnetwork mask identifying the last layer."""
+ if train_loader is None:
+ raise ValueError('Need to pass train loader for subnet selection.')
+
self._feature_extractor.eval()
if self._feature_extractor.last_layer is None:
X = next(iter(train_loader))[0]
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
index 9007a7a5..1a024a75 100644
--- a/tests/test_subnetlaplace.py
+++ b/tests/test_subnetlaplace.py
@@ -7,6 +7,7 @@
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace, SubnetLaplace
+from laplace.baselaplace import DiagLaplace
from laplace.utils.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask
@@ -42,77 +43,86 @@ def reg_loader():
@pytest.mark.parametrize('likelihood', likelihoods)
def test_subnet_laplace_init(model, likelihood):
- # use last-layer subnet mask for this test
- subnetwork_mask = LastLayerSubnetMask
+ # use random subnet mask for this test
+ subnetwork_mask = RandomSubnetMask
+ subnetmask_kwargs = dict(model=model, n_params_subnet=10)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select()
# subnet Laplace with full Hessian should work
hessian_structure = 'full'
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
assert isinstance(lap, SubnetLaplace)
# subnet Laplace with diag, kron or lowrank Hessians should raise errors
hessian_structure = 'diag'
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
hessian_structure = 'kron'
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
hessian_structure = 'lowrank'
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure=hessian_structure)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
@pytest.mark.parametrize('subnetwork_mask,likelihood', product(score_based_subnet_masks, likelihoods))
def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_loader, reg_loader):
loader = class_loader if likelihood == 'classification' else reg_loader
model_params = parameters_to_vector(model.parameters())
- subnetmask_kwargs = dict(likelihood=likelihood) if subnetwork_mask == LargestVarianceSWAGSubnetMask else dict()
+
+ # set subnetwork mask arguments
+ if subnetwork_mask == LargestVarianceDiagLaplaceSubnetMask:
+ diag_laplace_model = DiagLaplace(model, likelihood)
+ subnetmask_kwargs = dict(model=model, diag_laplace_model=diag_laplace_model)
+ elif subnetwork_mask == LargestVarianceSWAGSubnetMask:
+ subnetmask_kwargs = dict(model=model, likelihood=likelihood)
+ else:
+ subnetmask_kwargs = dict(model=model)
# should raise error if we don't pass number of subnet parameters within the subnetmask_kwargs
with pytest.raises(TypeError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(loader)
# should raise error if we set number of subnet parameters to None
subnetmask_kwargs.update(n_params_subnet=None)
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(loader)
# should raise error if we set number of subnet parameters to be larger than number of model parameters
subnetmask_kwargs.update(n_params_subnet=99999)
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(loader)
- # define valid subnet Laplace model
+ # define subnetwork mask
n_params_subnet = 32
subnetmask_kwargs.update(n_params_subnet=n_params_subnet)
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap._subnetwork_mask, subnetwork_mask)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
# should raise error if we try to access the subnet indices before the subnet has been selected
with pytest.raises(AttributeError):
- lap._subnetwork_mask.indices
+ subnetmask.indices
# select subnet mask
- lap._subnetwork_mask.select(loader)
+ subnetmask.select(loader)
# should raise error if we try to select the subnet again
with pytest.raises(ValueError):
- lap._subnetwork_mask.select(loader)
+ subnetmask.select(loader)
- # re-define valid subnet Laplace model
- n_params_subnet = 32
- subnetmask_kwargs.update(n_params_subnet=n_params_subnet)
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ # define valid subnet Laplace model
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap._subnetwork_mask, subnetwork_mask)
- # fit Laplace model (which internally selects the subnet mask)
+ # fit Laplace model
lap.fit(loader)
# check some parameters
- assert lap._subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
- assert lap._subnetwork_mask.n_params_subnet == n_params_subnet
+ assert subnetmask.indices.equal(lap.backend.subnetwork_indices)
+ assert subnetmask.n_params_subnet == n_params_subnet
assert lap.n_params_subnet == n_params_subnet
assert parameters_to_vector(model.parameters()).equal(model_params)
@@ -120,142 +130,152 @@ def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_load
assert lap.H.shape == (n_params_subnet, n_params_subnet)
assert lap.prior_precision_diag.shape == (n_params_subnet,)
- # should raise error if we try to fit the Laplace mdoel again
- with pytest.raises(ValueError):
- lap.fit(loader)
-
@pytest.mark.parametrize('subnetwork_mask,likelihood', product(layer_subnet_masks, likelihoods))
def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, reg_loader):
loader = class_loader if likelihood == 'classification' else reg_loader
+ subnetmask_kwargs = dict(model=model)
# fit last-layer Laplace model
lllap = Laplace(model, likelihood=likelihood, subset_of_weights='last_layer', hessian_structure='full')
lllap.fit(loader)
# should raise error if we pass number of subnet parameters
- subnetmask_kwargs = dict(n_params_subnet=32)
+ subnetmask_kwargs.update(n_params_subnet=32)
with pytest.raises(TypeError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(loader)
+ subnetmask_kwargs = dict(model=model)
if subnetwork_mask == ParamNameSubnetMask:
# should raise error if we pass no parameter name list
- subnetmask_kwargs = dict()
+ subnetmask_kwargs.update()
with pytest.raises(TypeError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(loader)
# should raise error if we pass an empty parameter name list
- subnetmask_kwargs = dict(parameter_names=[])
+ subnetmask_kwargs.update(parameter_names=[])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- lap.fit(loader)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(loader)
# should raise error if we pass a parameter name list with invalid parameter names
- subnetmask_kwargs = dict(parameter_names=['123'])
+ subnetmask_kwargs.update(parameter_names=['123'])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- lap.fit(loader)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(loader)
# define last-layer Laplace model by parameter names and check that Hessian is identical to that of a full LLLaplace model
- subnetmask_kwargs = dict(parameter_names=['1.weight', '1.bias'])
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ subnetmask_kwargs.update(parameter_names=['1.weight', '1.bias'])
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(loader)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
assert lllap.H.equal(lap.H)
- # define valid parameter name subnet Laplace model
- subnetmask_kwargs = dict(parameter_names=['0.weight', '1.bias'])
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- n_params_subnet = 62
- assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap._subnetwork_mask, subnetwork_mask)
+ # define valid parameter name subnet mask
+ subnetmask_kwargs.update(parameter_names=['0.weight', '1.bias'])
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
# should raise error if we access number of subnet parameters before selecting the subnet
+ n_params_subnet = 62
with pytest.raises(AttributeError):
- n_params_subnet = lap._subnetwork_mask.n_params_subnet
+ n_params_subnet = subnetmask.n_params_subnet
- # fit Laplace model
+ # select subnet mask and fit Laplace model
+ subnetmask.select(loader)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
+ assert isinstance(lap, SubnetLaplace)
elif subnetwork_mask == ModuleNameSubnetMask:
# should raise error if we pass no module name list
- subnetmask_kwargs = dict()
+ subnetmask_kwargs.update()
with pytest.raises(TypeError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(loader)
# should raise error if we pass an empty module name list
- subnetmask_kwargs = dict(module_names=[])
+ subnetmask_kwargs.update(module_names=[])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- lap.fit(loader)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(loader)
# should raise error if we pass a module name list with invalid module names
- subnetmask_kwargs = dict(module_names=['123'])
+ subnetmask_kwargs.update(module_names=['123'])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- lap.fit(loader)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(loader)
# define last-layer Laplace model by module name and check that Hessian is identical to that of a full LLLaplace model
- subnetmask_kwargs = dict(module_names=['1'])
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ subnetmask_kwargs.update(module_names=['1'])
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(loader)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
assert lllap.H.equal(lap.H)
- # define valid parameter name subnet Laplace model
- subnetmask_kwargs = dict(module_names=['0'])
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- n_params_subnet = 80
- assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap._subnetwork_mask, subnetwork_mask)
+ # define valid parameter name subnet mask
+ subnetmask_kwargs.update(module_names=['0'])
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
# should raise error if we access number of subnet parameters before selecting the subnet
+ n_params_subnet = 80
with pytest.raises(AttributeError):
- n_params_subnet = lap._subnetwork_mask.n_params_subnet
+ n_params_subnet = subnetmask.n_params_subnet
- # fit Laplace model
+ # select subnet mask and fit Laplace model
+ subnetmask.select(loader)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
+ assert isinstance(lap, SubnetLaplace)
elif subnetwork_mask == LastLayerSubnetMask:
# should raise error if we pass invalid last-layer name
- subnetmask_kwargs = dict(last_layer_name='123')
+ subnetmask_kwargs.update(last_layer_name='123')
with pytest.raises(KeyError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(loader)
- # define valid last-layer subnet Laplace model (without passing the last-layer name)
- subnetmask_kwargs = dict()
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap._subnetwork_mask, subnetwork_mask)
+ # define valid last-layer subnet mask (without passing the last-layer name)
+ subnetmask_kwargs = dict(model=model)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
# should raise error if we access number of subnet parameters before selecting the subnet
with pytest.raises(AttributeError):
- n_params_subnet = lap._subnetwork_mask.n_params_subnet
+ n_params_subnet = subnetmask.n_params_subnet
- # fit Laplace model
+ # select subnet mask and fit Laplace model
+ subnetmask.select(loader)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
+ assert isinstance(lap, SubnetLaplace)
# check that Hessian is identical to that of a full LLLaplace model
assert lllap.H.equal(lap.H)
- # define valid last-layer subnet Laplace model (with passing the last-layer name)
- subnetmask_kwargs = dict(last_layer_name='1')
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
- n_params_subnet = 42
- assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap._subnetwork_mask, subnetwork_mask)
+ # define valid last-layer subnet mask (with passing the last-layer name)
+ subnetmask_kwargs.update(last_layer_name='1')
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
# should raise error if we access number of subnet parameters before selecting the subnet
+ n_params_subnet = 42
with pytest.raises(AttributeError):
- n_params_subnet = lap._subnetwork_mask.n_params_subnet
+ n_params_subnet = subnetmask.n_params_subnet
- # fit Laplace model
+ # select subnet mask and fit Laplace model
+ subnetmask.select(loader)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
+ assert isinstance(lap, SubnetLaplace)
# check that Hessian is identical to that of a full LLLaplace model
assert lllap.H.equal(lap.H)
# check some parameters
- assert lap._subnetwork_mask.indices.equal(lap.backend.subnetwork_indices)
- assert lap._subnetwork_mask.n_params_subnet == n_params_subnet
+ assert subnetmask.indices.equal(lap.backend.subnetwork_indices)
+ assert subnetmask.n_params_subnet == n_params_subnet
assert lap.n_params_subnet == n_params_subnet
# check that Hessian and prior precision is of correct shape
@@ -274,14 +294,15 @@ def get_subnet_mask(self, train_loader):
# define and fit valid full subnet Laplace model
subnetwork_mask = FullSubnetMask
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full')
+ subnetmask = subnetwork_mask(model=model)
+ subnetmask.select(loader)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap._subnetwork_mask, subnetwork_mask)
# check some parameters
- assert lap._subnetwork_mask.indices.equal(torch.tensor(list(range(model.n_params))))
- assert lap._subnetwork_mask.n_params_subnet == model.n_params
+ assert subnetmask.indices.equal(torch.tensor(list(range(model.n_params))))
+ assert subnetmask.n_params_subnet == model.n_params
assert lap.n_params_subnet == model.n_params
# check that the Hessian is identical to that of a all-weights FullLaplace model
@@ -292,18 +313,23 @@ def get_subnet_mask(self, train_loader):
@pytest.mark.parametrize('subnetwork_mask', all_subnet_masks)
def test_regression_predictive(model, reg_loader, subnetwork_mask):
+ subnetmask_kwargs = dict(model=model)
if subnetwork_mask in score_based_subnet_masks:
- subnetmask_kwargs = dict(n_params_subnet=32)
+ subnetmask_kwargs.update(n_params_subnet=32)
+ if subnetwork_mask == LargestVarianceSWAGSubnetMask:
+ subnetmask_kwargs.update(likelihood='regression')
+ elif subnetwork_mask == LargestVarianceDiagLaplaceSubnetMask:
+ diag_laplace_model = DiagLaplace(model, 'regression')
+ subnetmask_kwargs.update(diag_laplace_model=diag_laplace_model)
elif subnetwork_mask == ParamNameSubnetMask:
- subnetmask_kwargs = dict(parameter_names=['0.weight', '1.bias'])
+ subnetmask_kwargs.update(parameter_names=['0.weight', '1.bias'])
elif subnetwork_mask == ModuleNameSubnetMask:
- subnetmask_kwargs = dict(module_names=['0'])
- else:
- subnetmask_kwargs = dict()
- subnetmask_kwargs.update(dict(likelihood='regression') if subnetwork_mask == LargestVarianceSWAGSubnetMask else dict())
- lap = Laplace(model, likelihood='regression', subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ subnetmask_kwargs.update(module_names=['0'])
+
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(reg_loader)
+ lap = Laplace(model, likelihood='regression', subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap._subnetwork_mask, subnetwork_mask)
lap.fit(reg_loader)
X, _ = reg_loader.dataset.tensors
@@ -328,18 +354,23 @@ def test_regression_predictive(model, reg_loader, subnetwork_mask):
@pytest.mark.parametrize('subnetwork_mask', all_subnet_masks)
def test_classification_predictive(model, class_loader, subnetwork_mask):
+ subnetmask_kwargs = dict(model=model)
if subnetwork_mask in score_based_subnet_masks:
- subnetmask_kwargs = dict(n_params_subnet=32)
+ subnetmask_kwargs.update(n_params_subnet=32)
+ if subnetwork_mask == LargestVarianceSWAGSubnetMask:
+ subnetmask_kwargs.update(likelihood='classification')
+ elif subnetwork_mask == LargestVarianceDiagLaplaceSubnetMask:
+ diag_laplace_model = DiagLaplace(model, 'classification')
+ subnetmask_kwargs.update(diag_laplace_model=diag_laplace_model)
elif subnetwork_mask == ParamNameSubnetMask:
- subnetmask_kwargs = dict(parameter_names=['0.weight', '1.bias'])
+ subnetmask_kwargs.update(parameter_names=['0.weight', '1.bias'])
elif subnetwork_mask == ModuleNameSubnetMask:
- subnetmask_kwargs = dict(module_names=['0'])
- else:
- subnetmask_kwargs = dict()
- subnetmask_kwargs.update(dict(likelihood='classification') if subnetwork_mask == LargestVarianceSWAGSubnetMask else dict())
- lap = Laplace(model, likelihood='classification', subset_of_weights='subnetwork', subnetwork_mask=subnetwork_mask, hessian_structure='full', subnetmask_kwargs=subnetmask_kwargs)
+ subnetmask_kwargs.update(module_names=['0'])
+
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select(class_loader)
+ lap = Laplace(model, likelihood='classification', subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
assert isinstance(lap, SubnetLaplace)
- assert isinstance(lap._subnetwork_mask, subnetwork_mask)
lap.fit(class_loader)
X, _ = class_loader.dataset.tensors
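Taken together, this patch decouples subnetwork selection from Laplace fitting. A minimal sketch of the resulting two-step workflow (`model` and `train_loader` assumed to be defined):

```python
import torch
from laplace import Laplace

# step 1: choose the subnetwork, here via hand-picked parameter indices
subnetwork_indices = torch.LongTensor([0, 5, 11, 42])

# step 2: fit a full-Hessian Laplace approximation over just those weights
la = Laplace(model, 'classification', subset_of_weights='subnetwork',
             subnetwork_indices=subnetwork_indices, hessian_structure='full')
la.fit(train_loader)
```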
From 6de38c9821247a0ebf2a60e16c78c471b7113fb7 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Mon, 3 Jan 2022 18:13:18 +0000
Subject: [PATCH 32/49] Change subnetwork indices validity check to only allow
long tensors
---
laplace/subnetlaplace.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 25c5ae96..f624a29b 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -88,9 +88,9 @@ def _check_subnetwork_indices(self, subnetwork_indices):
"""
if subnetwork_indices is None:
raise ValueError('Subnetwork indices cannot be None.')
- elif not (isinstance(subnetwork_indices, torch.Tensor) and len(subnetwork_indices.shape) == 1 and\
- subnetwork_indices.dtype in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8]):
- raise ValueError('Subnetwork indices need to be 1-dimensional integral torch.Tensor!')
+ elif not (isinstance(subnetwork_indices, torch.Tensor) and subnetwork_indices.numel() > 0\
+ and len(subnetwork_indices.shape) == 1 and subnetwork_indices.dtype == torch.int64):
+ raise ValueError('Subnetwork indices must be non-empty, 1-dimensional torch.LongTensor.')
elif not (len(subnetwork_indices[subnetwork_indices < 0]) == 0 and\
len(subnetwork_indices[subnetwork_indices >= self.n_params]) == 0):
raise ValueError(f'Subnetwork indices must lie between 0 and n_params={self.n_params}.')
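With this stricter check, only non-empty one-dimensional `int64` tensors pass; a small illustration of the assumed accepted and rejected inputs:

```python
import torch

valid = torch.LongTensor([0, 5, 11, 42])    # 1-D, non-empty, int64: accepted
assert valid.dtype == torch.int64 and valid.dim() == 1 and valid.numel() > 0

invalid = torch.IntTensor([0, 5, 11, 42])   # int32: now rejected with a ValueError
assert invalid.dtype != torch.int64
```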
From 7058b399469ec4dbfd06d2f51c7fb097c2da5ce1 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Mon, 3 Jan 2022 18:13:58 +0000
Subject: [PATCH 33/49] Add test for SubnetLaplace with custom subnetwork
indices specification
---
tests/test_subnetlaplace.py | 84 +++++++++++++++++++++++++++++++++++++
1 file changed, 84 insertions(+)
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
index 1a024a75..011fc0f9 100644
--- a/tests/test_subnetlaplace.py
+++ b/tests/test_subnetlaplace.py
@@ -66,6 +66,90 @@ def test_subnet_laplace_init(model, likelihood):
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
+@pytest.mark.parametrize('likelihood', likelihoods)
+def test_custom_subnetwork_indices(model, likelihood, class_loader, reg_loader):
+ loader = class_loader if likelihood == 'classification' else reg_loader
+
+ # subnetwork indices that are None should raise an error
+ subnetwork_indices = None
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # subnetwork indices that are not PyTorch tensors should raise an error
+ subnetwork_indices = [0, 5, 11, 42]
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # subnetwork indices that are empty tensors should raise an error
+ subnetwork_indices = torch.LongTensor([])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # subnetwork indices that are not 1D PyTorch tensors should raise an error
+ subnetwork_indices = torch.LongTensor([[0, 5], [11, 42]])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # subnetwork indices that are double tensors should raise an error
+ subnetwork_indices = torch.DoubleTensor([0.0, 5.0, 11.0, 42.0])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # subnetwork indices that are float tensors should raise an error
+ subnetwork_indices = torch.FloatTensor([0.0, 5.0, 11.0, 42.0])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # subnetwork indices that are half tensors should raise an error
+ subnetwork_indices = torch.HalfTensor([0.0, 5.0, 11.0, 42.0])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # subnetwork indices that are int tensors should raise an error
+ subnetwork_indices = torch.IntTensor([0, 5, 11, 42])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # subnetwork indices that are short tensors should raise an error
+ subnetwork_indices = torch.ShortTensor([0, 5, 11, 42])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # subnetwork indices that are char tensors should raise an error
+ subnetwork_indices = torch.CharTensor([0, 5, 11, 42])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # subnetwork indices that are bool tensors should raise an error
+ subnetwork_indices = torch.BoolTensor([0, 5, 11, 42])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # subnetwork indices that contain elements smaller than zero should raise an error
+ subnetwork_indices = torch.LongTensor([0, -1, -11])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # subnetwork indices that contain elements larger than n_params should raise an error
+ subnetwork_indices = torch.LongTensor([model.n_params + 1, model.n_params + 42])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # subnetwork indices that contain duplicate entries should raise an error
+ subnetwork_indices = torch.LongTensor([0, 0, 5, 11, 11, 42])
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
+ # Non-empty, 1-dimensional torch.LongTensor with valid entries should work
+ subnetwork_indices = torch.LongTensor([0, 5, 11, 42])
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap.fit(loader)
+ assert isinstance(lap, SubnetLaplace)
+ assert lap.n_params_subnet == 4
+ assert lap.H.shape == (4, 4)
+ assert lap.backend.subnetwork_indices.equal(subnetwork_indices)
+
+
@pytest.mark.parametrize('subnetwork_mask,likelihood', product(score_based_subnet_masks, likelihoods))
def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_loader, reg_loader):
loader = class_loader if likelihood == 'classification' else reg_loader
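Taken together, these tests pin down the accepted input format: a non-empty, 1-dimensional `torch.LongTensor` with unique entries in `[0, n_params)`. A minimal sketch of coercing arbitrary integer indices into that form (the helper name is illustrative, not part of the library):

```python
import torch

def as_valid_subnetwork_indices(indices, n_params):
    # illustrative helper: coerce indices into the accepted form, i.e. a
    # non-empty, 1-dimensional torch.LongTensor with unique in-range entries
    idx = torch.as_tensor(indices, dtype=torch.long).flatten()
    idx = torch.unique(idx)  # drops duplicates (and sorts)
    if idx.numel() == 0 or (idx < 0).any() or (idx >= n_params).any():
        raise ValueError('indices must be non-empty and lie in [0, n_params)')
    return idx

print(as_valid_subnetwork_indices([0, 5, 11, 42], n_params=100))  # tensor([ 0,  5, 11, 42])
```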
From e66ba5147720c1f88c4777527808739516980e7e Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Mon, 3 Jan 2022 18:24:58 +0000
Subject: [PATCH 34/49] Remove None default value for subnetwork indices in
SubnetLaplace
---
laplace/subnetlaplace.py | 4 ++--
tests/test_subnetlaplace.py | 4 ++++
2 files changed, 6 insertions(+), 2 deletions(-)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index f624a29b..010215e5 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -47,7 +47,7 @@ class SubnetLaplace(FullLaplace):
model : torch.nn.Module or `laplace.utils.feature_extractor.FeatureExtractor`
likelihood : {'classification', 'regression'}
determines the log likelihood Hessian approximation
- subnetwork_indices : torch.Tensor, default=None
+ subnetwork_indices : torch.LongTensor
indices of the vectorized model parameters that define the subnetwork
to apply the Laplace approximation over
sigma_noise : torch.Tensor or float, default=1
@@ -69,7 +69,7 @@ class SubnetLaplace(FullLaplace):
# key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure)
_key = ('subnetwork', 'full')
- def __init__(self, model, likelihood, subnetwork_indices=None, sigma_noise=1., prior_precision=1.,
+ def __init__(self, model, likelihood, subnetwork_indices, sigma_noise=1., prior_precision=1.,
prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None):
self.H = None
super().__init__(model, likelihood, sigma_noise=sigma_noise, prior_precision=prior_precision,
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
index 011fc0f9..b472b6aa 100644
--- a/tests/test_subnetlaplace.py
+++ b/tests/test_subnetlaplace.py
@@ -54,6 +54,10 @@ def test_subnet_laplace_init(model, likelihood):
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
assert isinstance(lap, SubnetLaplace)
+ # subnet Laplace without specifying subnetwork indices should raise an error
+ with pytest.raises(TypeError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', hessian_structure=hessian_structure)
+
# subnet Laplace with diag, kron or lowrank Hessians should raise errors
hessian_structure = 'diag'
with pytest.raises(ValueError):
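With the `None` default removed, forgetting to pass `subnetwork_indices` now fails fast with a `TypeError` from Python's argument binding, which is exactly what the new test asserts. A minimal sketch, assuming a toy model:

```python
import torch
from laplace import Laplace

model = torch.nn.Linear(3, 2)
try:
    la = Laplace(model, 'classification',
                 subset_of_weights='subnetwork', hessian_structure='full')
except TypeError as err:
    print(f'as expected: {err}')  # missing required argument: 'subnetwork_indices'
```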
From 589b8462cf074b7df80abe72e1f853058bbc56cf Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 4 Jan 2022 09:55:18 +0000
Subject: [PATCH 35/49] Add failing test case for scalar subnetwork indices
---
tests/test_subnetlaplace.py | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
index b472b6aa..1e98622c 100644
--- a/tests/test_subnetlaplace.py
+++ b/tests/test_subnetlaplace.py
@@ -89,6 +89,11 @@ def test_custom_subnetwork_indices(model, likelihood, class_loader, reg_loader):
with pytest.raises(ValueError):
lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ # subnetwork indices that are scalar tensors should raise an error
+ subnetwork_indices = torch.LongTensor(11)
+ with pytest.raises(ValueError):
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+
# subnetwork indices that are not 1D PyTorch tensors should raise an error
subnetwork_indices = torch.LongTensor([[0, 5], [11, 42]])
with pytest.raises(ValueError):
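A subtlety worth flagging for readers of this test: the legacy constructor `torch.LongTensor(11)` allocates an *uninitialized* tensor of shape `(11,)`, whereas `torch.tensor(11)` produces a true 0-dimensional scalar; neither is a meaningful index vector, so both should be rejected. A quick illustration:

```python
import torch

print(torch.LongTensor(11).shape)    # torch.Size([11]) -- uninitialized values
print(torch.tensor(11).shape)        # torch.Size([])   -- 0-dim scalar tensor
print(torch.LongTensor([11]).shape)  # torch.Size([1])  -- valid 1-element index tensor
```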
From be245c8f39a62879259eb5fddcccca65196203ec Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 4 Jan 2022 09:56:02 +0000
Subject: [PATCH 36/49] Change SubnetMask.select() to return subnet indices and
improve documentation
---
laplace/utils/subnetmask.py | 19 ++++++++++++++++---
1 file changed, 16 insertions(+), 3 deletions(-)
diff --git a/laplace/utils/subnetmask.py b/laplace/utils/subnetmask.py
index f1213d55..9a74e705 100644
--- a/laplace/utils/subnetmask.py
+++ b/laplace/utils/subnetmask.py
@@ -49,11 +49,14 @@ def convert_subnet_mask_to_indices(self, subnet_mask):
subnet_mask : torch.Tensor
a binary vector of size (n_params) where 1s locate the subnetwork parameters
within the vectorized model parameters
+ (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
Returns
-------
- subnet_mask_indices : torch.Tensor
- a vector of indices of the vectorized model parameters that define the subnetwork
+ subnet_mask_indices : torch.LongTensor
+ a vector of indices of the vectorized model parameters
+ (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
+ that define the subnetwork
"""
if not isinstance(subnet_mask, torch.Tensor):
raise ValueError('Subnetwork mask needs to be torch.Tensor!')
@@ -63,7 +66,8 @@ def convert_subnet_mask_to_indices(self, subnet_mask):
elif len(subnet_mask) != self._n_params or\
len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1]) != self._n_params:
raise ValueError('Subnetwork mask needs to be a binary vector of size (n_params) where 1s'\
- 'locate the subnetwork parameters within the vectorized model parameters!')
+                          ' locate the subnetwork parameters within the vectorized model parameters '\
+ '(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!')
subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
return subnet_mask_indices
@@ -76,12 +80,20 @@ def select(self, train_loader=None):
        train_loader : torch.utils.data.DataLoader, default=None
each iterate is a training batch (X, y);
`train_loader.dataset` needs to be set to access \\(N\\), size of the data set
+
+ Returns
+ -------
+ subnet_mask_indices : torch.LongTensor
+ a vector of indices of the vectorized model parameters
+ (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
+ that define the subnetwork
"""
if self._indices is not None:
raise ValueError('Subnetwork mask already selected.')
subnet_mask = self.get_subnet_mask(train_loader)
self._indices = self.convert_subnet_mask_to_indices(subnet_mask)
+ return self._indices
def get_subnet_mask(self, train_loader):
""" Get the subnetwork mask.
@@ -97,6 +109,7 @@ def get_subnet_mask(self, train_loader):
subnet_mask: torch.Tensor
a binary vector of size (n_params) where 1s locate the subnetwork parameters
within the vectorized model parameters
+ (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
"""
raise NotImplementedError
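Since `select()` now returns the indices directly, callers no longer need to read the mask's `indices` attribute in a separate step. A hedged usage sketch with a toy model (`RandomSubnetMask` scores parameters randomly, so no data loader is required):

```python
import torch
from laplace.utils.subnetmask import RandomSubnetMask

model = torch.nn.Sequential(torch.nn.Linear(3, 20), torch.nn.Linear(20, 2))
subnetwork_mask = RandomSubnetMask(model, n_params_subnet=16)
subnetwork_indices = subnetwork_mask.select()  # now returns a torch.LongTensor
assert subnetwork_indices.equal(subnetwork_mask.indices)
```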
From 6ae4f9f63560db08ac7263ccc5d5db0ac9fd4bf8 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 4 Jan 2022 09:56:34 +0000
Subject: [PATCH 37/49] Minor refactorings (subnet indices checks and
documentation)
---
laplace/subnetlaplace.py | 14 ++++++++------
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 010215e5..6b04cf10 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -48,8 +48,9 @@ class SubnetLaplace(FullLaplace):
likelihood : {'classification', 'regression'}
determines the log likelihood Hessian approximation
subnetwork_indices : torch.LongTensor
- indices of the vectorized model parameters that define the subnetwork
- to apply the Laplace approximation over
+ indices of the vectorized model parameters
+ (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)
+ that define the subnetwork to apply the Laplace approximation over
sigma_noise : torch.Tensor or float, default=1
observation noise for the regression setting; must be 1 for classification
prior_precision : torch.Tensor or float, default=1
@@ -84,18 +85,19 @@ def _init_H(self):
self.H = torch.zeros(self.n_params_subnet, self.n_params_subnet, device=self._device)
def _check_subnetwork_indices(self, subnetwork_indices):
- """Check that subnetwork indices are valid indices of the vectorized model parameters.
+ """Check that subnetwork indices are valid indices of the vectorized model parameters
+ (i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`).
"""
if subnetwork_indices is None:
raise ValueError('Subnetwork indices cannot be None.')
- elif not (isinstance(subnetwork_indices, torch.Tensor) and subnetwork_indices.numel() > 0\
- and len(subnetwork_indices.shape) == 1 and subnetwork_indices.dtype == torch.int64):
+ elif not (isinstance(subnetwork_indices, torch.LongTensor) and\
+ subnetwork_indices.numel() > 0 and len(subnetwork_indices.shape) == 1):
raise ValueError('Subnetwork indices must be non-empty, 1-dimensional torch.LongTensor.')
elif not (len(subnetwork_indices[subnetwork_indices < 0]) == 0 and\
len(subnetwork_indices[subnetwork_indices >= self.n_params]) == 0):
raise ValueError(f'Subnetwork indices must lie between 0 and n_params={self.n_params}.')
elif not (subnetwork_indices.sort()[0].equal(torch.unique(subnetwork_indices, sorted=True))):
- raise ValueError('Subnetwork indices must be unique.')
+ raise ValueError('Subnetwork indices must not contain duplicate entries.')
@property
def prior_precision_diag(self):
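Condensed, the refactored validation accepts exactly one input shape. A standalone sketch of the same logic (not the library code verbatim; the duplicate test here is an equivalent length-based check):

```python
import torch

def check_subnetwork_indices(idx, n_params):
    # mirrors SubnetLaplace._check_subnetwork_indices at this point in the series
    if idx is None:
        raise ValueError('Subnetwork indices cannot be None.')
    if not (isinstance(idx, torch.LongTensor) and idx.numel() > 0 and idx.dim() == 1):
        raise ValueError('Subnetwork indices must be non-empty, 1-dimensional torch.LongTensor.')
    if (idx < 0).any() or (idx >= n_params).any():
        raise ValueError(f'Subnetwork indices must lie between 0 and n_params={n_params}.')
    if len(idx.unique()) != len(idx):
        raise ValueError('Subnetwork indices must not contain duplicate entries.')

check_subnetwork_indices(torch.LongTensor([0, 5, 11, 42]), n_params=100)  # passes silently
```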
From 288f7678aea11c3f97b7ffa21dded511841c1a1e Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Tue, 4 Jan 2022 09:57:15 +0000
Subject: [PATCH 38/49] Add README example for SubnetLaplace
---
README.md | 47 +++++++++++++++++++++++++++++++++++++++++++----
1 file changed, 43 insertions(+), 4 deletions(-)
diff --git a/README.md b/README.md
index fbe90ee0..b7543727 100644
--- a/README.md
+++ b/README.md
@@ -39,7 +39,7 @@ pytest tests/
## Structure
The laplace package consists of two main components:
-1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'` and `'diag'`). This results in _eight_ currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace` (which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py)), `laplace.SubnetLaplace` (which only supports a `'full'` Hessian approximation) and `laplace.LowRankLaplace` (which only supports inference over `'all'` weights). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function.
+1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'` and `'diag'`). This results in _eight_ currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace` (which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py)), [`laplace.SubnetLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetlaplace.py) (which only supports a `'full'` Hessian approximation) and `laplace.LowRankLaplace` (which only supports inference over `'all'` weights). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function.
2. The backends in [`laplace.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/) which provide access to Hessian approximations of
the corresponding sparsity structures, for example, the diagonal GGN.
@@ -48,7 +48,7 @@ decomposing a neural network into feature extractor and last layer for `LLLaplac
and
effectively dealing with Kronecker factors ([`laplace.utils.matrix`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/matrix.py)).
-Finally, the package implements several options to select/specify a subnetwork for `laplace.SubnetLaplace` (as subclasses of [`laplace.utils.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/subnetmask.py).
+Finally, the package implements several options to select/specify a subnetwork for `SubnetLaplace` (as subclasses of [`laplace.utils.subnetmask.SubnetMask`](https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/subnetmask.py)).
Automatic subnetwork selection strategies include: uniformly at random (`laplace.utils.subnetmask.RandomSubnetMask`), by largest parameter magnitudes (`LargestMagnitudeSubnetMask`), and by largest marginal parameter variances (`LargestVarianceDiagLaplaceSubnetMask` and `LargestVarianceSWAGSubnetMask`).
In addition to that, subnetworks can also be specified manually, by listing the names of either the model parameters (`ParamNameSubnetMask`) or modules (`ModuleNameSubnetMask`) to perform Laplace inference over.
@@ -78,7 +78,7 @@ the `'probit'` predictive for classification.
```python
from laplace import Laplace
-# pre-trained model
+# Pre-trained model
model = load_map_model()
# User-specified LA flavor
@@ -94,7 +94,7 @@ pred = la(x, link_approx='probit')
### Differentiating the log marginal likelihood w.r.t. hyperparameters
-The marginal likelihood can be used for model selection and is differentiable
+The marginal likelihood can be used for model selection [10] and is differentiable
for continuous hyperparameters like the prior precision or observation noise.
Here, we fit the library default, KFAC last-layer LA and differentiate
the log marginal likelihood.
@@ -114,6 +114,45 @@ ml = la.log_marginal_likelihood(prior_prec, obs_noise)
ml.backward()
```
+### Applying the LA over only a subset of the model parameters
+
+This example shows how to fit the Laplace approximation over only
+a subnetwork within a neural network (while keeping all other parameters
+fixed at their MAP estimates), as proposed in [11]. It also illustrates
+different ways of specifying the subnetwork to perform inference over.
+
+```python
+from laplace import Laplace
+
+# Pre-trained model
+model = load_model()
+
+# Examples of different ways to specify the subnetwork
+# via indices of the vectorized model parameters
+#
+# Example 1: select the 128 parameters with the largest magnitude
+from laplace.utils.subnetmask import LargestMagnitudeSubnetMask
+subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=128)
+subnetwork_indices = subnetwork_mask.select()
+
+# Example 2: specify the layers that define the subnetwork
+from laplace.utils.subnetmask import ModuleNameSubnetMask
+subnetwork_mask = ModuleNameSubnetMask(model, module_names=['layer.1', 'layer.3'])
+subnetwork_mask.select()
+subnetwork_indices = subnetwork_mask.indices
+
+# Example 3: manually define the subnetwork via custom subnetwork indices
+import torch
+subnetwork_indices = torch.tensor([0, 4, 11, 42, 123, 2021])
+
+# Define and fit subnetwork LA using the specified subnetwork indices
+la = Laplace(model, 'classification',
+ subset_of_weights='subnetwork',
+ hessian_structure='full',
+ subnetwork_indices=subnetwork_indices)
+la.fit(train_loader)
+```
+
## Documentation
The documentation is available [here](https://aleximmer.github.io/Laplace) or can be generated and/or viewed locally:
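As a follow-up to the README example added above: once `la.fit(train_loader)` has run, the subnetwork posterior is used like any other LA flavor. A self-contained toy sketch (data, shapes, and indices are illustrative only):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace

model = torch.nn.Sequential(torch.nn.Linear(3, 20), torch.nn.ReLU(), torch.nn.Linear(20, 2))
X, y = torch.randn(32, 3), torch.randint(2, (32,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=8)

la = Laplace(model, 'classification',
             subset_of_weights='subnetwork',
             hessian_structure='full',
             subnetwork_indices=torch.LongTensor([0, 4, 11, 42]))
la.fit(train_loader)
pred = la(X[:4], link_approx='probit')  # posterior-predictive class probabilities
```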
From b5d8adfe34d55e96152eb832a9da8082b50a9034 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Wed, 12 Jan 2022 09:24:04 +0000
Subject: [PATCH 39/49] Add __all__ for utils.py
---
laplace/utils/utils.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/laplace/utils/utils.py b/laplace/utils/utils.py
index 5b059d31..a00dc2f4 100644
--- a/laplace/utils/utils.py
+++ b/laplace/utils/utils.py
@@ -8,6 +8,10 @@
from torch.distributions.multivariate_normal import _precision_to_scale_tril
+__all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron',
+ 'diagonal_add_scalar', 'symeig', 'block_diag', 'expand_prior_precision']
+
+
def get_nll(out_dist, targets):
return F.nll_loss(torch.log(out_dist), targets)
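For context, `__all__` only constrains star-imports; the explicitly private helpers were already hidden by their leading underscore. A two-line illustration:

```python
from laplace.utils.utils import *  # exposes only the names listed in __all__

print('get_nll' in dir())        # True
print('_is_batchnorm' in dir())  # False
```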
From 557aa05e98484da75d0cbe114799f0055fd02c12 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Wed, 12 Jan 2022 09:24:20 +0000
Subject: [PATCH 40/49] Add __all__ for swag.py
---
laplace/utils/swag.py | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/laplace/utils/swag.py b/laplace/utils/swag.py
index d780fe16..8e6fb514 100644
--- a/laplace/utils/swag.py
+++ b/laplace/utils/swag.py
@@ -4,7 +4,10 @@
from torch.nn.utils import parameters_to_vector
-def param_vector(model):
+__all__ = ['fit_diagonal_swag']
+
+
+def _param_vector(model):
return parameters_to_vector(model.parameters()).detach()
@@ -55,8 +58,8 @@ def fit_diagonal_swag(model, train_loader, criterion, n_snapshots_total=40, snap
device = next(_model.parameters()).device
# initialize running estimates of first and second moment of model parameters
- mean = torch.zeros_like(param_vector(_model))
- sq_mean = torch.zeros_like(param_vector(_model))
+ mean = torch.zeros_like(_param_vector(_model))
+ sq_mean = torch.zeros_like(_param_vector(_model))
n_snapshots = 0
# run SGD to collect model snapshots
@@ -72,8 +75,8 @@ def fit_diagonal_swag(model, train_loader, criterion, n_snapshots_total=40, snap
if epoch % snapshot_freq == 0:
# update running estimates of first and second moment of model parameters
- mean = mean * n_snapshots / (n_snapshots + 1) + param_vector(_model) / (n_snapshots + 1)
- sq_mean = sq_mean * n_snapshots / (n_snapshots + 1) + param_vector(_model) ** 2 / (n_snapshots + 1)
+ mean = mean * n_snapshots / (n_snapshots + 1) + _param_vector(_model) / (n_snapshots + 1)
+ sq_mean = sq_mean * n_snapshots / (n_snapshots + 1) + _param_vector(_model) ** 2 / (n_snapshots + 1)
n_snapshots += 1
# compute marginal parameter variances, Var[P] = E[P^2] - E[P]^2
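The update above is an online average over snapshots: after `n` snapshots, `mean <- mean * n/(n+1) + theta/(n+1)`, and likewise for the second moment. A quick sanity check of the recursion:

```python
import torch

snapshots = [torch.tensor([1., 2., 3.]), torch.tensor([3., 2., 1.])]
mean, sq_mean = torch.zeros(3), torch.zeros(3)
for n, theta in enumerate(snapshots):
    mean = mean * n / (n + 1) + theta / (n + 1)
    sq_mean = sq_mean * n / (n + 1) + theta ** 2 / (n + 1)

print(mean)                 # tensor([2., 2., 2.]) -- plain average of the snapshots
print(sq_mean - mean ** 2)  # tensor([1., 0., 1.]) -- Var[P] = E[P^2] - E[P]^2
```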
From 4ab30df9c7668a024a8be593a1a781e3a9bd912a Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Wed, 12 Jan 2022 09:25:57 +0000
Subject: [PATCH 41/49] Add __init__.py to utils/ to simplify utility imports
---
README.md | 4 ++--
laplace/baselaplace.py | 3 +--
laplace/curvature/asdl.py | 3 +--
laplace/curvature/backpack.py | 2 +-
laplace/lllaplace.py | 4 +---
laplace/marglik_training.py | 2 +-
laplace/utils/__init__.py | 14 ++++++++++++++
laplace/utils/matrix.py | 2 +-
laplace/utils/subnetmask.py | 7 ++++---
tests/test_baselaplace.py | 2 +-
tests/test_feature_extractor.py | 2 +-
tests/test_jacobians.py | 2 +-
tests/test_lllaplace.py | 2 +-
tests/test_matrix.py | 5 ++---
tests/test_subnetlaplace.py | 4 ++--
tests/test_utils.py | 2 +-
16 files changed, 35 insertions(+), 25 deletions(-)
create mode 100644 laplace/utils/__init__.py
diff --git a/README.md b/README.md
index b7543727..d3590abd 100644
--- a/README.md
+++ b/README.md
@@ -131,12 +131,12 @@ model = load_model()
# via indices of the vectorized model parameters
#
# Example 1: select the 128 parameters with the largest magnitude
-from laplace.utils.subnetmask import LargestMagnitudeSubnetMask
+from laplace.utils import LargestMagnitudeSubnetMask
subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=128)
subnetwork_indices = subnetwork_mask.select()
# Example 2: specify the layers that define the subnetwork
-from laplace.utils.subnetmask import ModuleNameSubnetMask
+from laplace.utils import ModuleNameSubnetMask
subnetwork_mask = ModuleNameSubnetMask(model, module_names=['layer.1', 'layer.3'])
subnetwork_mask.select()
subnetwork_indices = subnetwork_mask.indices
diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py
index 2c382d98..09c05ca3 100644
--- a/laplace/baselaplace.py
+++ b/laplace/baselaplace.py
@@ -4,8 +4,7 @@
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.distributions import MultivariateNormal, Dirichlet, Normal
-from laplace.utils.utils import parameters_per_layer, invsqrt_precision, get_nll, validate
-from laplace.utils.matrix import Kron
+from laplace.utils import parameters_per_layer, invsqrt_precision, get_nll, validate, Kron
from laplace.curvature import BackPackGGN, AsdlHessian
diff --git a/laplace/curvature/asdl.py b/laplace/curvature/asdl.py
index a3ccd3d6..dac769fe 100644
--- a/laplace/curvature/asdl.py
+++ b/laplace/curvature/asdl.py
@@ -9,8 +9,7 @@
from asdfghjkl.gradient import batch_gradient
from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface
-from laplace.utils.matrix import Kron
-from laplace.utils.utils import _is_batchnorm
+from laplace.utils import Kron, _is_batchnorm
EPS = 1e-6
diff --git a/laplace/curvature/backpack.py b/laplace/curvature/backpack.py
index 5c78a093..8cffc154 100644
--- a/laplace/curvature/backpack.py
+++ b/laplace/curvature/backpack.py
@@ -5,7 +5,7 @@
from backpack.context import CTX
from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface
-from laplace.utils.matrix import Kron
+from laplace.utils import Kron
class BackPackInterface(CurvatureInterface):
diff --git a/laplace/lllaplace.py b/laplace/lllaplace.py
index 8f93ff53..73c552df 100644
--- a/laplace/lllaplace.py
+++ b/laplace/lllaplace.py
@@ -3,9 +3,7 @@
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from laplace.baselaplace import ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace
-from laplace.utils.feature_extractor import FeatureExtractor
-
-from laplace.utils.matrix import Kron
+from laplace.utils import FeatureExtractor, Kron
from laplace.curvature import BackPackGGN
diff --git a/laplace/marglik_training.py b/laplace/marglik_training.py
index 3a2b2161..ec100542 100644
--- a/laplace/marglik_training.py
+++ b/laplace/marglik_training.py
@@ -9,7 +9,7 @@
from laplace import Laplace
from laplace.curvature import AsdlGGN
-from laplace.utils.utils import expand_prior_precision
+from laplace.utils import expand_prior_precision
def marglik_training(
diff --git a/laplace/utils/__init__.py b/laplace/utils/__init__.py
new file mode 100644
index 00000000..cc05d4d9
--- /dev/null
+++ b/laplace/utils/__init__.py
@@ -0,0 +1,14 @@
+from laplace.utils.utils import get_nll, validate, parameters_per_layer, invsqrt_precision, _is_batchnorm, _is_valid_scalar, kron, diagonal_add_scalar, symeig, block_diag, expand_prior_precision
+from laplace.utils.feature_extractor import FeatureExtractor
+from laplace.utils.matrix import Kron, KronDecomposed
+from laplace.utils.swag import fit_diagonal_swag
+from laplace.utils.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask
+
+
+__all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron',
+ 'diagonal_add_scalar', 'symeig', 'block_diag', 'expand_prior_precision',
+ 'FeatureExtractor',
+ 'Kron', 'KronDecomposed',
+ 'fit_diagonal_swag',
+ 'SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask',
+ 'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask']
diff --git a/laplace/utils/matrix.py b/laplace/utils/matrix.py
index 30b0245d..61c07ab5 100644
--- a/laplace/utils/matrix.py
+++ b/laplace/utils/matrix.py
@@ -3,7 +3,7 @@
import numpy as np
from typing import Union
-from laplace.utils.utils import _is_valid_scalar, symeig, kron, block_diag
+from laplace.utils import _is_valid_scalar, symeig, kron, block_diag
class Kron:
diff --git a/laplace/utils/subnetmask.py b/laplace/utils/subnetmask.py
index 9a74e705..9609401a 100644
--- a/laplace/utils/subnetmask.py
+++ b/laplace/utils/subnetmask.py
@@ -4,10 +4,11 @@
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn.utils import parameters_to_vector
-from laplace.utils.feature_extractor import FeatureExtractor
-from laplace.utils.swag import fit_diagonal_swag
+from laplace.utils import FeatureExtractor, fit_diagonal_swag
-__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask', 'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask']
+
+__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask',
+ 'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask']
class SubnetMask:
diff --git a/tests/test_baselaplace.py b/tests/test_baselaplace.py
index a9292e8d..36fe8a16 100644
--- a/tests/test_baselaplace.py
+++ b/tests/test_baselaplace.py
@@ -12,7 +12,7 @@
from torchvision.models import wide_resnet50_2
from laplace.laplace import FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace
-from laplace.utils.matrix import KronDecomposed
+from laplace.utils import KronDecomposed
from tests.utils import jacobians_naive
diff --git a/tests/test_feature_extractor.py b/tests/test_feature_extractor.py
index b80bbcb4..d3b95ad5 100644
--- a/tests/test_feature_extractor.py
+++ b/tests/test_feature_extractor.py
@@ -2,7 +2,7 @@
import torch.nn as nn
import torchvision.models as models
-from laplace.utils.feature_extractor import FeatureExtractor
+from laplace.utils import FeatureExtractor
class CNN(nn.Module):
diff --git a/tests/test_jacobians.py b/tests/test_jacobians.py
index 0495adb3..13d2466e 100644
--- a/tests/test_jacobians.py
+++ b/tests/test_jacobians.py
@@ -3,7 +3,7 @@
from torch import nn
from laplace.curvature import AsdlInterface, BackPackInterface
-from laplace.utils.feature_extractor import FeatureExtractor
+from laplace.utils import FeatureExtractor
from tests.utils import jacobians_naive
diff --git a/tests/test_lllaplace.py b/tests/test_lllaplace.py
index bc6f7a5e..0e6855aa 100644
--- a/tests/test_lllaplace.py
+++ b/tests/test_lllaplace.py
@@ -9,7 +9,7 @@
from torchvision.models import wide_resnet50_2
from laplace.lllaplace import FullLLLaplace, KronLLLaplace, DiagLLLaplace
-from laplace.utils.feature_extractor import FeatureExtractor
+from laplace.utils import FeatureExtractor
from tests.utils import jacobians_naive
diff --git a/tests/test_matrix.py b/tests/test_matrix.py
index 66a5da48..7c366990 100644
--- a/tests/test_matrix.py
+++ b/tests/test_matrix.py
@@ -4,10 +4,9 @@
from torch import nn
from torch.nn.utils import parameters_to_vector
-from laplace.utils.matrix import Kron
-from laplace.utils.utils import kron as kron_prod
+from laplace.utils import Kron, block_diag
+from laplace.utils import kron as kron_prod
from laplace.curvature import BackPackGGN
-from laplace.utils.utils import block_diag
from tests.utils import get_psd_matrix, jacobians_naive
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
index 1e98622c..8e0d3d18 100644
--- a/tests/test_subnetlaplace.py
+++ b/tests/test_subnetlaplace.py
@@ -1,5 +1,5 @@
import pytest
-from itertools import product
+from itertools import product
import torch
from torch import nn
@@ -8,7 +8,7 @@
from laplace import Laplace, SubnetLaplace
from laplace.baselaplace import DiagLaplace
-from laplace.utils.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask
+from laplace.utils import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask
torch.manual_seed(240)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index b673be3d..1ad0f517 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,5 +1,5 @@
import torch
-from laplace.utils.utils import invsqrt_precision, diagonal_add_scalar, symeig
+from laplace.utils import invsqrt_precision, diagonal_add_scalar, symeig
def test_sqrt_precision():
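The net effect of this commit on user code is one flat import surface. Before/after:

```python
# before: utilities imported from their defining submodules
from laplace.utils.subnetmask import LargestMagnitudeSubnetMask
from laplace.utils.matrix import Kron
from laplace.utils.feature_extractor import FeatureExtractor

# after: everything is re-exported by the laplace.utils package
from laplace.utils import LargestMagnitudeSubnetMask, Kron, FeatureExtractor
```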
From 99d3fb2dd54444fb7e3c11cc903d5c1a00d80bd2 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Wed, 12 Jan 2022 09:32:33 +0000
Subject: [PATCH 42/49] Simplify check for duplicate indices in SubnetLaplace
---
laplace/subnetlaplace.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 6b04cf10..4f815bc4 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -96,7 +96,7 @@ def _check_subnetwork_indices(self, subnetwork_indices):
elif not (len(subnetwork_indices[subnetwork_indices < 0]) == 0 and\
len(subnetwork_indices[subnetwork_indices >= self.n_params]) == 0):
raise ValueError(f'Subnetwork indices must lie between 0 and n_params={self.n_params}.')
- elif not (subnetwork_indices.sort()[0].equal(torch.unique(subnetwork_indices, sorted=True))):
+ elif not (len(subnetwork_indices.unique()) == len(subnetwork_indices)):
raise ValueError('Subnetwork indices must not contain duplicate entries.')
@property
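The simplified check works because `torch.Tensor.unique()` drops duplicates, so a length mismatch flags them:

```python
import torch

idx = torch.LongTensor([0, 0, 5, 11, 11, 42])
print(len(idx.unique()), len(idx))    # 4 6
print(len(idx.unique()) == len(idx))  # False -> duplicates present
```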
From 61865737eeeeb76040642fd3f9fbcd59494c3a2c Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Wed, 12 Jan 2022 09:35:19 +0000
Subject: [PATCH 43/49] Add __all__ for matrix.py
---
laplace/utils/matrix.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/laplace/utils/matrix.py b/laplace/utils/matrix.py
index 61c07ab5..14a84bfe 100644
--- a/laplace/utils/matrix.py
+++ b/laplace/utils/matrix.py
@@ -6,6 +6,9 @@
from laplace.utils import _is_valid_scalar, symeig, kron, block_diag
+__all__ = ['Kron', 'KronDecomposed']
+
+
class Kron:
"""Kronecker factored approximate curvature representation for a corresponding
neural network.
From 86b91bc4e806d5fb5e659eefbf4c70fa6ab473f4 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Wed, 12 Jan 2022 09:37:13 +0000
Subject: [PATCH 44/49] Add line breaks with proper indents for __all__ in
subnetmask.py
---
laplace/utils/subnetmask.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/laplace/utils/subnetmask.py b/laplace/utils/subnetmask.py
index 9609401a..607b8f60 100644
--- a/laplace/utils/subnetmask.py
+++ b/laplace/utils/subnetmask.py
@@ -7,8 +7,9 @@
from laplace.utils import FeatureExtractor, fit_diagonal_swag
-__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask',
- 'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask']
+__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask',
+ 'LargestVarianceDiagLaplaceSubnetMask', 'LargestVarianceSWAGSubnetMask',
+ 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask']
class SubnetMask:
From 3b54e8a858b0086cbc546292f32339798258339a Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Wed, 12 Jan 2022 09:45:52 +0000
Subject: [PATCH 45/49] Change name of fit_diagonal_swag() to
fit_diagonal_swag_var()
---
laplace/utils/__init__.py | 4 ++--
laplace/utils/subnetmask.py | 4 ++--
laplace/utils/swag.py | 4 ++--
3 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/laplace/utils/__init__.py b/laplace/utils/__init__.py
index cc05d4d9..10f559e0 100644
--- a/laplace/utils/__init__.py
+++ b/laplace/utils/__init__.py
@@ -1,7 +1,7 @@
from laplace.utils.utils import get_nll, validate, parameters_per_layer, invsqrt_precision, _is_batchnorm, _is_valid_scalar, kron, diagonal_add_scalar, symeig, block_diag, expand_prior_precision
from laplace.utils.feature_extractor import FeatureExtractor
from laplace.utils.matrix import Kron, KronDecomposed
-from laplace.utils.swag import fit_diagonal_swag
+from laplace.utils.swag import fit_diagonal_swag_var
from laplace.utils.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask
@@ -9,6 +9,6 @@
'diagonal_add_scalar', 'symeig', 'block_diag', 'expand_prior_precision',
'FeatureExtractor',
'Kron', 'KronDecomposed',
- 'fit_diagonal_swag',
+ 'fit_diagonal_swag_var',
'SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask',
'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask']
diff --git a/laplace/utils/subnetmask.py b/laplace/utils/subnetmask.py
index 607b8f60..d76751d3 100644
--- a/laplace/utils/subnetmask.py
+++ b/laplace/utils/subnetmask.py
@@ -4,7 +4,7 @@
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn.utils import parameters_to_vector
-from laplace.utils import FeatureExtractor, fit_diagonal_swag
+from laplace.utils import FeatureExtractor, fit_diagonal_swag_var
__all__ = ['SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask',
@@ -225,7 +225,7 @@ def compute_param_scores(self, train_loader):
criterion = CrossEntropyLoss(reduction='mean')
elif self.likelihood == 'regression':
criterion = MSELoss(reduction='mean')
- param_variances = fit_diagonal_swag(self.model, train_loader, criterion, n_snapshots_total=self.swag_n_snapshots, snapshot_freq=self.swag_snapshot_freq, lr=self.swag_lr)
+ param_variances = fit_diagonal_swag_var(self.model, train_loader, criterion, n_snapshots_total=self.swag_n_snapshots, snapshot_freq=self.swag_snapshot_freq, lr=self.swag_lr)
return param_variances
diff --git a/laplace/utils/swag.py b/laplace/utils/swag.py
index 8e6fb514..7b2e529f 100644
--- a/laplace/utils/swag.py
+++ b/laplace/utils/swag.py
@@ -4,14 +4,14 @@
from torch.nn.utils import parameters_to_vector
-__all__ = ['fit_diagonal_swag']
+__all__ = ['fit_diagonal_swag_var']
def _param_vector(model):
return parameters_to_vector(model.parameters()).detach()
-def fit_diagonal_swag(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1, lr=0.01, momentum=0.9, weight_decay=3e-4, min_var=1e-30):
+def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1, lr=0.01, momentum=0.9, weight_decay=3e-4, min_var=1e-30):
"""
Fit diagonal SWAG [1], which estimates marginal variances of model parameters by
computing the first and second moment of SGD iterates with a large learning rate.
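A hedged usage sketch of the renamed helper (toy regression setup; all data here is illustrative): it returns one marginal variance per vectorized model parameter, hence the `_var` suffix.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from laplace.utils import fit_diagonal_swag_var

model = torch.nn.Linear(3, 1)
loader = DataLoader(TensorDataset(torch.randn(32, 3), torch.randn(32, 1)), batch_size=8)

param_variances = fit_diagonal_swag_var(model, loader,
                                        torch.nn.MSELoss(reduction='mean'),
                                        n_snapshots_total=5)
print(param_variances.shape)  # torch.Size([4]) -- one variance per parameter
```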
From d5d2d2348c46ab3f770712293b86aa63b10c2e35 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Wed, 12 Jan 2022 11:43:43 +0000
Subject: [PATCH 46/49] Shorten lines to 100 chars and change line breaks from
 backslashes to brackets
---
laplace/subnetlaplace.py | 12 ++--
laplace/utils/subnetmask.py | 36 ++++++++----
laplace/utils/swag.py | 11 ++--
tests/test_subnetlaplace.py | 114 ++++++++++++++++++++++++------------
4 files changed, 112 insertions(+), 61 deletions(-)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 4f815bc4..32767ba2 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -73,9 +73,9 @@ class SubnetLaplace(FullLaplace):
def __init__(self, model, likelihood, subnetwork_indices, sigma_noise=1., prior_precision=1.,
prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None):
self.H = None
- super().__init__(model, likelihood, sigma_noise=sigma_noise, prior_precision=prior_precision,
- prior_mean=prior_mean, temperature=temperature, backend=backend,
- backend_kwargs=backend_kwargs)
+ super().__init__(model, likelihood, sigma_noise=sigma_noise,
+ prior_precision=prior_precision, prior_mean=prior_mean,
+ temperature=temperature, backend=backend, backend_kwargs=backend_kwargs)
# check validity of subnetwork indices and pass them to backend
self._check_subnetwork_indices(subnetwork_indices)
self.backend.subnetwork_indices = subnetwork_indices
@@ -90,10 +90,10 @@ def _check_subnetwork_indices(self, subnetwork_indices):
"""
if subnetwork_indices is None:
raise ValueError('Subnetwork indices cannot be None.')
- elif not (isinstance(subnetwork_indices, torch.LongTensor) and\
+ elif not (isinstance(subnetwork_indices, torch.LongTensor) and
subnetwork_indices.numel() > 0 and len(subnetwork_indices.shape) == 1):
- raise ValueError('Subnetwork indices must be non-empty, 1-dimensional torch.LongTensor.')
- elif not (len(subnetwork_indices[subnetwork_indices < 0]) == 0 and\
+ raise ValueError('Subnetwork indices must be non-empty 1-dimensional torch.LongTensor.')
+ elif not (len(subnetwork_indices[subnetwork_indices < 0]) == 0 and
len(subnetwork_indices[subnetwork_indices >= self.n_params]) == 0):
raise ValueError(f'Subnetwork indices must lie between 0 and n_params={self.n_params}.')
elif not (len(subnetwork_indices.unique()) == len(subnetwork_indices)):
diff --git a/laplace/utils/subnetmask.py b/laplace/utils/subnetmask.py
index d76751d3..00d73ff4 100644
--- a/laplace/utils/subnetmask.py
+++ b/laplace/utils/subnetmask.py
@@ -62,13 +62,15 @@ def convert_subnet_mask_to_indices(self, subnet_mask):
"""
if not isinstance(subnet_mask, torch.Tensor):
raise ValueError('Subnetwork mask needs to be torch.Tensor!')
- elif subnet_mask.dtype not in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool] or\
- len(subnet_mask.shape) != 1:
- raise ValueError('Subnetwork mask needs to be 1-dimensional integral or boolean tensor!')
- elif len(subnet_mask) != self._n_params or\
- len(subnet_mask[subnet_mask == 0]) + len(subnet_mask[subnet_mask == 1]) != self._n_params:
- raise ValueError('Subnetwork mask needs to be a binary vector of size (n_params) where 1s'\
-                          ' locate the subnetwork parameters within the vectorized model parameters '\
+ elif subnet_mask.dtype not in [torch.int64, torch.int32, torch.int16, torch.int8,
+ torch.uint8, torch.bool] or len(subnet_mask.shape) != 1:
+ raise ValueError(
+ 'Subnetwork mask needs to be 1-dimensional integral or boolean tensor!')
+ elif (len(subnet_mask) != self._n_params or len(subnet_mask[subnet_mask == 0])
+ + len(subnet_mask[subnet_mask == 1]) != self._n_params):
+            raise ValueError('Subnetwork mask needs to be a binary vector of '
+                             'size (n_params) where 1s locate the subnetwork '
+                             'parameters within the vectorized model parameters '
'(i.e. `torch.nn.utils.parameters_to_vector(model.parameters())`)!')
subnet_mask_indices = subnet_mask.nonzero(as_tuple=True)[0]
@@ -117,7 +119,8 @@ def get_subnet_mask(self, train_loader):
class ScoreBasedSubnetMask(SubnetMask):
- """Baseclass for subnetwork masks defined by selecting the top-scoring parameters according to some criterion.
+ """Baseclass for subnetwork masks defined by selecting
+ the top-scoring parameters according to some criterion.
Parameters
----------
@@ -129,9 +132,11 @@ def __init__(self, model, n_params_subnet):
super().__init__(model)
if n_params_subnet is None:
- raise ValueError(f'Need to pass number of subnetwork parameters when using subnetwork Laplace.')
+ raise ValueError(
+ 'Need to pass number of subnetwork parameters when using subnetwork Laplace.')
if n_params_subnet > self._n_params:
- raise ValueError(f'Subnetwork ({n_params_subnet}) cannot be larger than model ({self._n_params}).')
+ raise ValueError(
+ f'Subnetwork ({n_params_subnet}) cannot be larger than model ({self._n_params}).')
self._n_params_subnet = n_params_subnet
self._param_scores = None
@@ -210,7 +215,8 @@ class LargestVarianceSWAGSubnetMask(ScoreBasedSubnetMask):
swag_lr : float
learning rate for SWAG snapshot collection
"""
- def __init__(self, model, n_params_subnet, likelihood='classification', swag_n_snapshots=40, swag_snapshot_freq=1, swag_lr=0.01):
+ def __init__(self, model, n_params_subnet, likelihood='classification',
+ swag_n_snapshots=40, swag_snapshot_freq=1, swag_lr=0.01):
super().__init__(model, n_params_subnet)
self.likelihood = likelihood
self.swag_n_snapshots = swag_n_snapshots
@@ -225,7 +231,10 @@ def compute_param_scores(self, train_loader):
criterion = CrossEntropyLoss(reduction='mean')
elif self.likelihood == 'regression':
criterion = MSELoss(reduction='mean')
- param_variances = fit_diagonal_swag_var(self.model, train_loader, criterion, n_snapshots_total=self.swag_n_snapshots, snapshot_freq=self.swag_snapshot_freq, lr=self.swag_lr)
+ param_variances = fit_diagonal_swag_var(self.model, train_loader, criterion,
+ n_snapshots_total=self.swag_n_snapshots,
+ snapshot_freq=self.swag_snapshot_freq,
+ lr=self.swag_lr)
return param_variances
@@ -236,7 +245,8 @@ class ParamNameSubnetMask(SubnetMask):
----------
model : torch.nn.Module
parameter_names: List[str]
- list of names of the parameters (as in `model.named_parameters()`) that define the subnetwork
+ list of names of the parameters (as in `model.named_parameters()`)
+ that define the subnetwork
"""
def __init__(self, model, parameter_names):
super().__init__(model)
diff --git a/laplace/utils/swag.py b/laplace/utils/swag.py
index 7b2e529f..a6aba701 100644
--- a/laplace/utils/swag.py
+++ b/laplace/utils/swag.py
@@ -11,7 +11,8 @@ def _param_vector(model):
return parameters_to_vector(model.parameters()).detach()
-def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1, lr=0.01, momentum=0.9, weight_decay=3e-4, min_var=1e-30):
+def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40, snapshot_freq=1,
+ lr=0.01, momentum=0.9, weight_decay=3e-4, min_var=1e-30):
"""
Fit diagonal SWAG [1], which estimates marginal variances of model parameters by
computing the first and second moment of SGD iterates with a large learning rate.
@@ -63,7 +64,8 @@ def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40,
n_snapshots = 0
# run SGD to collect model snapshots
- optimizer = torch.optim.SGD(_model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
+ optimizer = torch.optim.SGD(
+ _model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
n_epochs = snapshot_freq * n_snapshots_total
for epoch in range(n_epochs):
for inputs, targets in train_loader:
@@ -75,8 +77,9 @@ def fit_diagonal_swag_var(model, train_loader, criterion, n_snapshots_total=40,
if epoch % snapshot_freq == 0:
# update running estimates of first and second moment of model parameters
- mean = mean * n_snapshots / (n_snapshots + 1) + _param_vector(_model) / (n_snapshots + 1)
- sq_mean = sq_mean * n_snapshots / (n_snapshots + 1) + _param_vector(_model) ** 2 / (n_snapshots + 1)
+ old_fac, new_fac = n_snapshots / (n_snapshots + 1), 1 / (n_snapshots + 1)
+ mean = mean * old_fac + _param_vector(_model) * new_fac
+ sq_mean = sq_mean * old_fac + _param_vector(_model) ** 2 * new_fac
n_snapshots += 1
# compute marginal parameter variances, Var[P] = E[P^2] - E[P]^2
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
index 8e0d3d18..10b3c319 100644
--- a/tests/test_subnetlaplace.py
+++ b/tests/test_subnetlaplace.py
@@ -8,12 +8,15 @@
from laplace import Laplace, SubnetLaplace
from laplace.baselaplace import DiagLaplace
-from laplace.utils import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask
+from laplace.utils import (SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask,
+ LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask,
+ ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask)
torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
-score_based_subnet_masks = [RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask]
+score_based_subnet_masks = [RandomSubnetMask, LargestMagnitudeSubnetMask,
+ LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask]
layer_subnet_masks = [ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask]
all_subnet_masks = score_based_subnet_masks + layer_subnet_masks
likelihoods = ['classification', 'regression']
@@ -51,23 +54,28 @@ def test_subnet_laplace_init(model, likelihood):
# subnet Laplace with full Hessian should work
hessian_structure = 'full'
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
assert isinstance(lap, SubnetLaplace)
# subnet Laplace without specifying subnetwork indices should raise an error
with pytest.raises(TypeError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', hessian_structure=hessian_structure)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ hessian_structure=hessian_structure)
# subnet Laplace with diag, kron or lowrank Hessians should raise errors
hessian_structure = 'diag'
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
hessian_structure = 'kron'
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
hessian_structure = 'lowrank'
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
@pytest.mark.parametrize('likelihood', likelihoods)
@@ -77,81 +85,97 @@ def test_custom_subnetwork_indices(model, likelihood, class_loader, reg_loader):
# subnetwork indices that are None should raise an error
subnetwork_indices = None
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that are not PyTorch tensors should raise an error
subnetwork_indices = [0, 5, 11, 42]
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that are empty tensors should raise an error
subnetwork_indices = torch.LongTensor([])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that are scalar tensors should raise an error
subnetwork_indices = torch.LongTensor(11)
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that are not 1D PyTorch tensors should raise an error
subnetwork_indices = torch.LongTensor([[0, 5], [11, 42]])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that are double tensors should raise an error
subnetwork_indices = torch.DoubleTensor([0.0, 5.0, 11.0, 42.0])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that are float tensors should raise an error
subnetwork_indices = torch.FloatTensor([0.0, 5.0, 11.0, 42.0])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that are half tensors should raise an error
subnetwork_indices = torch.HalfTensor([0.0, 5.0, 11.0, 42.0])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that are int tensors should raise an error
subnetwork_indices = torch.IntTensor([0, 5, 11, 42])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that are short tensors should raise an error
subnetwork_indices = torch.ShortTensor([0, 5, 11, 42])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that are char tensors should raise an error
subnetwork_indices = torch.CharTensor([0, 5, 11, 42])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that are bool tensors should raise an error
subnetwork_indices = torch.BoolTensor([0, 5, 11, 42])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that contain elements smaller than zero should raise an error
subnetwork_indices = torch.LongTensor([0, -1, -11])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that contain elements larger than n_params should raise an error
subnetwork_indices = torch.LongTensor([model.n_params + 1, model.n_params + 42])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# subnetwork indices that contain duplicate entries should raise an error
subnetwork_indices = torch.LongTensor([0, 0, 5, 11, 11, 42])
with pytest.raises(ValueError):
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
# Non-empty, 1-dimensional torch.LongTensor with valid entries should work
subnetwork_indices = torch.LongTensor([0, 5, 11, 42])
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetwork_indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetwork_indices, hessian_structure='full')
lap.fit(loader)
assert isinstance(lap, SubnetLaplace)
assert lap.n_params_subnet == 4
@@ -184,7 +208,7 @@ def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_load
subnetmask = subnetwork_mask(**subnetmask_kwargs)
subnetmask.select(loader)
- # should raise error if we set number of subnet parameters to be larger than number of model parameters
+ # should raise error if number of subnet parameters is larger than number of model parameters
subnetmask_kwargs.update(n_params_subnet=99999)
with pytest.raises(ValueError):
subnetmask = subnetwork_mask(**subnetmask_kwargs)
@@ -207,7 +231,8 @@ def test_score_based_subnet_masks(model, likelihood, subnetwork_mask, class_load
subnetmask.select(loader)
# define valid subnet Laplace model
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure='full')
assert isinstance(lap, SubnetLaplace)
# fit Laplace model
@@ -230,7 +255,8 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re
subnetmask_kwargs = dict(model=model)
# fit last-layer Laplace model
- lllap = Laplace(model, likelihood=likelihood, subset_of_weights='last_layer', hessian_structure='full')
+ lllap = Laplace(model, likelihood=likelihood, subset_of_weights='last_layer',
+ hessian_structure='full')
lllap.fit(loader)
# should raise error if we pass number of subnet parameters
@@ -259,11 +285,13 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re
subnetmask = subnetwork_mask(**subnetmask_kwargs)
subnetmask.select(loader)
- # define last-layer Laplace model by parameter names and check that Hessian is identical to that of a full LLLaplace model
+ # define last-layer Laplace model by parameter names and check that
+ # Hessian is identical to that of a full LLLaplace model
subnetmask_kwargs.update(parameter_names=['1.weight', '1.bias'])
subnetmask = subnetwork_mask(**subnetmask_kwargs)
subnetmask.select(loader)
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
assert lllap.H.equal(lap.H)
@@ -278,7 +306,8 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re
# select subnet mask and fit Laplace model
subnetmask.select(loader)
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
assert isinstance(lap, SubnetLaplace)
@@ -301,11 +330,13 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re
subnetmask = subnetwork_mask(**subnetmask_kwargs)
subnetmask.select(loader)
- # define last-layer Laplace model by module name and check that Hessian is identical to that of a full LLLaplace model
+ # define last-layer Laplace model by module name and check that
+ # Hessian is identical to that of a full LLLaplace model
subnetmask_kwargs.update(module_names=['1'])
subnetmask = subnetwork_mask(**subnetmask_kwargs)
subnetmask.select(loader)
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
assert lllap.H.equal(lap.H)
@@ -320,7 +351,8 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re
# select subnet mask and fit Laplace model
subnetmask.select(loader)
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
assert isinstance(lap, SubnetLaplace)
@@ -341,7 +373,8 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re
# select subnet mask and fit Laplace model
subnetmask.select(loader)
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
assert isinstance(lap, SubnetLaplace)
@@ -359,7 +392,8 @@ def test_layer_subnet_masks(model, likelihood, subnetwork_mask, class_loader, re
# select subnet mask and fit Laplace model
subnetmask.select(loader)
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
assert isinstance(lap, SubnetLaplace)
@@ -389,7 +423,8 @@ def get_subnet_mask(self, train_loader):
subnetwork_mask = FullSubnetMask
subnetmask = subnetwork_mask(model=model)
subnetmask.select(loader)
- lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
+ lap = Laplace(model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure='full')
lap.fit(loader)
assert isinstance(lap, SubnetLaplace)
@@ -399,7 +434,8 @@ def get_subnet_mask(self, train_loader):
assert lap.n_params_subnet == model.n_params
# check that the Hessian is identical to that of an all-weights FullLaplace model
- full_lap = Laplace(model, likelihood=likelihood, subset_of_weights='all', hessian_structure='full')
+ full_lap = Laplace(model, likelihood=likelihood, subset_of_weights='all',
+ hessian_structure='full')
full_lap.fit(loader)
assert full_lap.H.equal(lap.H)
@@ -421,7 +457,8 @@ def test_regression_predictive(model, reg_loader, subnetwork_mask):
subnetmask = subnetwork_mask(**subnetmask_kwargs)
subnetmask.select(reg_loader)
- lap = Laplace(model, likelihood='regression', subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
+ lap = Laplace(model, likelihood='regression', subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure='full')
assert isinstance(lap, SubnetLaplace)
lap.fit(reg_loader)
@@ -462,7 +499,8 @@ def test_classification_predictive(model, class_loader, subnetwork_mask):
subnetmask = subnetwork_mask(**subnetmask_kwargs)
subnetmask.select(class_loader)
- lap = Laplace(model, likelihood='classification', subset_of_weights='subnetwork', subnetwork_indices=subnetmask.indices, hessian_structure='full')
+ lap = Laplace(model, likelihood='classification', subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure='full')
assert isinstance(lap, SubnetLaplace)
lap.fit(class_loader)
From ddc15c55e37b36ba2826661155c4f849da38454f Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Wed, 12 Jan 2022 11:56:42 +0000
Subject: [PATCH 47/49] Add call to _init_H() in the Subnet Laplace constructor
---
laplace/subnetlaplace.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/laplace/subnetlaplace.py b/laplace/subnetlaplace.py
index 32767ba2..86178ba6 100644
--- a/laplace/subnetlaplace.py
+++ b/laplace/subnetlaplace.py
@@ -80,6 +80,7 @@ def __init__(self, model, likelihood, subnetwork_indices, sigma_noise=1., prior_
self._check_subnetwork_indices(subnetwork_indices)
self.backend.subnetwork_indices = subnetwork_indices
self.n_params_subnet = len(subnetwork_indices)
+ self._init_H()
def _init_H(self):
self.H = torch.zeros(self.n_params_subnet, self.n_params_subnet, device=self._device)
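The point of calling _init_H() only after n_params_subnet is set is that H gets allocated at subnetwork size rather than at the full parameter count. A rough sketch of the memory arithmetic this avoids (hypothetical sizes, not code from the patch):

import torch

n_params, n_params_subnet = 69_000_000, 10  # e.g. wide_resnet50_2 with a tiny subnet

# Subnetwork Hessian: a 10 x 10 buffer, trivially cheap.
H = torch.zeros(n_params_subnet, n_params_subnet)
assert H.shape == (10, 10)

# A full-size H would need n_params ** 2 * 4 bytes (on the order of 10^16 bytes),
# which is infeasible; hence H must be sized by the subnetwork.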
From 0952ead86a528a73ddfc07493923fb1847a65833 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Wed, 12 Jan 2022 12:07:50 +0000
Subject: [PATCH 48/49] Add test for instantiating Subnet Laplace with large
model
---
tests/test_subnetlaplace.py | 25 +++++++++++++++++++++++++
1 file changed, 25 insertions(+)
diff --git a/tests/test_subnetlaplace.py b/tests/test_subnetlaplace.py
index 10b3c319..f51f5a5e 100644
--- a/tests/test_subnetlaplace.py
+++ b/tests/test_subnetlaplace.py
@@ -5,6 +5,7 @@
from torch import nn
from torch.nn.utils import parameters_to_vector
from torch.utils.data import DataLoader, TensorDataset
+from torchvision.models import wide_resnet50_2
from laplace import Laplace, SubnetLaplace
from laplace.baselaplace import DiagLaplace
@@ -30,6 +31,12 @@ def model():
return model
+@pytest.fixture
+def large_model():
+ model = wide_resnet50_2()
+ return model
+
+
@pytest.fixture
def class_loader():
X = torch.randn(10, 3)
@@ -78,6 +85,24 @@ def test_subnet_laplace_init(model, likelihood):
subnetwork_indices=subnetmask.indices, hessian_structure=hessian_structure)
+@pytest.mark.parametrize('likelihood', likelihoods)
+def test_subnet_laplace_large_init(large_model, likelihood):
+ # use random subnet mask for this test
+ subnetwork_mask = RandomSubnetMask
+ n_param_subnet = 10
+ subnetmask_kwargs = dict(model=large_model, n_params_subnet=n_param_subnet)
+ subnetmask = subnetwork_mask(**subnetmask_kwargs)
+ subnetmask.select()
+
+ lap = Laplace(large_model, likelihood=likelihood, subset_of_weights='subnetwork',
+ subnetwork_indices=subnetmask.indices, hessian_structure='full')
+ assert lap.n_params_subnet == n_param_subnet
+ assert lap.H.shape == (lap.n_params_subnet, lap.n_params_subnet)
+ H = lap.H.clone()
+ lap._init_H()
+ assert torch.allclose(H, lap.H)
+
+
@pytest.mark.parametrize('likelihood', likelihoods)
def test_custom_subnetwork_indices(model, likelihood, class_loader, reg_loader):
loader = class_loader if likelihood == 'classification' else reg_loader
From 83877b29bcc1a040718a1d127cfb5773348cca84 Mon Sep 17 00:00:00 2001
From: "Erik A. Daxberger"
Date: Wed, 12 Jan 2022 12:12:41 +0000
Subject: [PATCH 49/49] Update docs
---
docs/baselaplace.html | 505 ++++++-----
docs/curvature/asdl.html | 54 +-
docs/curvature/backpack.html | 15 +-
docs/curvature/curvature.html | 36 +-
docs/curvature/index.html | 105 ++-
docs/index.html | 441 ++++++++--
docs/laplace.html | 4 +-
docs/lllaplace.html | 208 +++--
docs/regression_example.png | Bin 27924 -> 28052 bytes
docs/regression_example_online.png | Bin 28674 -> 28716 bytes
docs/subnetlaplace.html | 171 ++++
docs/{ => utils}/feature_extractor.html | 30 +-
docs/utils/index.html | 1017 +++++++++++++++++++++++
docs/{ => utils}/matrix.html | 80 +-
docs/utils/subnetmask.html | 466 +++++++++++
docs/utils/swag.html | 102 +++
docs/{ => utils}/utils.html | 44 +-
17 files changed, 2808 insertions(+), 470 deletions(-)
create mode 100644 docs/subnetlaplace.html
rename docs/{ => utils}/feature_extractor.html (84%)
create mode 100644 docs/utils/index.html
rename docs/{ => utils}/matrix.html (75%)
create mode 100644 docs/utils/subnetmask.html
create mode 100644 docs/utils/swag.html
rename docs/{ => utils}/utils.html (84%)
diff --git a/docs/baselaplace.html b/docs/baselaplace.html
index fbaa0a07..ea13c62b 100644
--- a/docs/baselaplace.html
+++ b/docs/baselaplace.html
@@ -172,6 +172,253 @@ Parameters
+
+class ParametricLaplace
+(model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)
+
+
+Parametric Laplace class.
+
Subclasses need to specify how the Hessian approximation is initialized,
+how to add up curvature over training data, how to sample from the
+Laplace approximation, and how to compute the functional variance.
+
A Laplace approximation is represented by a MAP which is given by the
+model
parameter and a posterior precision or covariance specifying
+a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}).
+The goal of this class is to compute the posterior precision P
+which sums as
+
+P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
+\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
+
+Every subclass implements different approximations to the log likelihood Hessians,
+for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have
+a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 .
+In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in
+all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
+Ancestors
+
+Subclasses
+
+Instance variables
+
+var scatter
+-
+
Computes the scatter, a term of the log marginal likelihood that
+corresponds to L-2 regularization:
+scatter
= (\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) .
+
Returns
+
[type]
+[description]
+
+var log_det_prior_precision
+-
+
Compute log determinant of the prior precision
+\log \det P_0
+
Returns
+
+log_det
: torch.Tensor
+-
+
+
+var log_det_posterior_precision
+-
+
Compute log determinant of the posterior precision
+\log \det P which depends on the subclasses structure
+used for the Hessian approximation.
+
Returns
+
+log_det
: torch.Tensor
+-
+
+
+var log_det_ratio
+-
+
Compute the log determinant ratio, a part of the log marginal likelihood.
+
+\log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0
+
+
Returns
+
+log_det_ratio
: torch.Tensor
+-
+
+
+var posterior_precision
+-
+
Compute or return the posterior precision P.
+
Returns
+
+posterior_prec
: torch.Tensor
+-
+
+
+
+Methods
+
+
+def fit(self, train_loader, override=True)
+
+-
+
Fit the local Laplace approximation at the parameters of the model.
+
Parameters
+
+train_loader
: torch.data.utils.DataLoader
+- each iterate is a training batch (X, y);
+
train_loader.dataset
needs to be set to access N, size of the data set
+override
: bool
, default=True
+- whether to initialize H, loss, and n_data again; setting to False is useful for
+online learning settings to accumulate a sequential posterior approximation.
+
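A minimal usage sketch of the new override flag for sequential fitting (toy model and loaders in the style of the tests; all names here are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace

model = torch.nn.Sequential(torch.nn.Linear(3, 20), torch.nn.Linear(20, 2))
task1 = DataLoader(TensorDataset(torch.randn(10, 3), torch.randn(10, 2)), batch_size=5)
task2 = DataLoader(TensorDataset(torch.randn(10, 3), torch.randn(10, 2)), batch_size=5)

la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
la.fit(task1)                  # initializes H, loss, and n_data
la.fit(task2, override=False)  # accumulates curvature instead of resetting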
+
+
+def square_norm(self, value)
+
+-
+
Compute the square norm under the posterior precision, with value - self.mean as 𝛥:
+
+\Delta^\top P \Delta
+op P \Delta
+
+Returns
+
+
+square_form
+-
+
+
+
+def log_prob(self, value, normalized=True)
+
+-
+
Compute the log probability under the (current) Laplace approximation.
+
Parameters
+
+normalized
: bool
, default=True
+- whether to return log of a properly normalized Gaussian or just the
+terms that depend on
value
.
+
+
Returns
+
+log_prob
: torch.Tensor
+-
+
+
+
+def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None)
+
+-
+
Compute the Laplace approximation to the log marginal likelihood subject
+to specific Hessian approximations that subclasses implement.
+Requires that the Laplace approximation has been fit before.
+The resulting torch.Tensor is differentiable in prior_precision
and
+sigma_noise
if these have gradients enabled.
+By passing prior_precision
or sigma_noise
, the current value is
+overwritten. This is useful for iterating on the log marginal likelihood.
+
Parameters
+
+prior_precision
: torch.Tensor
, optional
+- prior precision if should be changed from current
prior_precision
value
+sigma_noise
: [type]
, optional
+- observation noise standard deviation if should be changed
+
+
Returns
+
+log_marglik
: torch.Tensor
+-
+
+
+
+def predictive_samples(self, x, pred_type='glm', n_samples=100)
+
+-
+
Sample from the posterior predictive on input data x
.
+Can be used, for example, for Thompson sampling.
+
Parameters
+
+x
: torch.Tensor
+- input data
(batch_size, input_shape)
+pred_type
: {'glm', 'nn'}
, default='glm'
+- type of posterior predictive, linearized GLM predictive or neural
+network sampling predictive. The GLM predictive is consistent with
+the curvature approximations used here.
+n_samples
: int
+- number of samples
+
+
Returns
+
+samples
: torch.Tensor
+- samples
(n_samples, batch_size, output_shape)
+
+
+
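As a hedged illustration of the Thompson-sampling use case mentioned above (assumes a fitted la object and candidate inputs x_candidates; both names are illustrative):

# One posterior draw per decision; act greedily under that draw.
samples = la.predictive_samples(x_candidates, pred_type='glm', n_samples=1)
action = samples[0].argmax(dim=-1)  # (batch_size,) chosen arm per input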
+def functional_variance(self, Jacs)
+
+-
+
Compute functional variance for the 'glm'
predictive:
+f_var[i] = Jacs[i] @ P.inv() @ Jacs[i].T
, which is an output x output
+predictive covariance matrix.
+Mathematically, we have for a single Jacobian
+\mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}}
+the output covariance matrix
+ \mathcal{J} P^{-1} \mathcal{J}^T .
+
Parameters
+
+Jacs
: torch.Tensor
+- Jacobians of model output wrt parameters
+
(batch, outputs, parameters)
+
+
Returns
+
+f_var
: torch.Tensor
+- output covariance
(batch, outputs, outputs)
+
+
+
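A small self-contained sketch of this contraction with toy shapes (illustrative only; an identity posterior covariance stands in for P^{-1}):

import torch

batch, outputs, params = 8, 3, 20
Js = torch.randn(batch, outputs, params)  # (batch, outputs, parameters)
post_cov = torch.eye(params)              # stands in for P^{-1}
f_var = torch.einsum('ncp,pq,nkq->nck', Js, post_cov, Js)
assert f_var.shape == (batch, outputs, outputs)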
+def sample(self, n_samples=100)
+
+-
+
Sample from the Laplace posterior approximation, i.e.,
+ \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}).
+
Parameters
+
+n_samples
: int
, default=100
+- number of samples
+
+
+
+def optimize_prior_precision(self, method='marglik', pred_type='glm', n_steps=100, lr=0.1, init_prior_prec=1.0, val_loader=None, loss=<function get_nll>, log_prior_prec_min=-4, log_prior_prec_max=4, grid_size=100, link_approx='probit', n_samples=100, verbose=False, cv_loss_with_var=False)
+
+-
+
+
+
+Inherited members
+
+
class FullLaplace
(model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)
@@ -190,6 +437,7 @@ Ancestors
Subclasses
Instance variables
@@ -233,11 +481,13 @@ Inherited members
log_det_ratio
log_likelihood
log_marginal_likelihood
+log_prob
optimize_prior_precision_base
predictive_samples
prior_precision_diag
sample
scatter
+square_norm
@@ -252,7 +502,7 @@ Inherited members
Mathematically, we have for each parameter group, e.g., torch.nn.Module,
that P \approx Q \otimes H.
See BaseLaplace
for the full interface and see
-Kron
and KronDecomposed
for the structure of
+Kron
and KronDecomposed
for the structure of
the Kronecker factors. Kron
is used to aggregate factors by summing up and
KronDecomposed
is used to add the prior, a Hessian factor (e.g. temperature),
and computing posterior covariances, marginal likelihood, etc.
@@ -273,7 +523,7 @@ Instance variables
@@ -293,11 +543,13 @@ Inherited members
log_det_ratio
log_likelihood
log_marginal_likelihood
+log_prob
optimize_prior_precision_base
predictive_samples
prior_precision_diag
sample
scatter
+square_norm
@@ -361,218 +613,80 @@ Inherited members
log_det_ratio
log_likelihood
log_marginal_likelihood
+log_prob
optimize_prior_precision_base
predictive_samples
prior_precision_diag
sample
scatter
+square_norm
-
-class ParametricLaplace
-(model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)
+
+class LowRankLaplace
+(model, likelihood, sigma_noise=1, prior_precision=1, prior_mean=0, temperature=1, backend=laplace.curvature.asdl.AsdlHessian, backend_kwargs=None)
-Parametric Laplace class.
-
Subclasses need to specify how the Hessian approximation is initialized,
-how to add up curvature over training data, how to sample from the
-Laplace approximation, and how to compute the functional variance.
-
A Laplace approximation is represented by a MAP which is given by the
-model
parameter and a posterior precision or covariance specifying
-a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}).
-The goal of this class is to compute the posterior precision P
-which sums as
-
-P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
-\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
-
-Every subclass implements different approximations to the log likelihood Hessians,
-for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have
-a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 .
-In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in
-all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
+Laplace approximation with low-rank log likelihood Hessian (approximation).
+The low-rank matrix is represented by an eigendecomposition (vecs, values).
+Based on the chosen backend
, either a true Hessian or, for example, GGN
+approximation could be used.
+The posterior precision is computed as
+ P = V diag(l) V^T + P_0.
+To sample, compute the functional variance, and compute the log determinant, algebraic tricks
+are used to reduce the cost of inversion to that of a K \times K matrix
+if we have a rank of K.
+
See BaseLaplace
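A standard way to get this K x K reduction is the Woodbury identity; a hedged sketch with hypothetical shapes (not the class's internal code):

import torch

n, K = 1000, 10
V = torch.randn(n, K)       # eigenvectors
l = torch.rand(K) + 0.1     # positive eigenvalues
p0_inv = torch.ones(n)      # inverse of a diagonal prior precision P_0

# P^{-1} = P_0^{-1} - P_0^{-1} V (diag(1/l) + V^T P_0^{-1} V)^{-1} V^T P_0^{-1};
# only the inner K x K matrix is ever inverted.
inner = torch.inverse(torch.diag(1.0 / l) + V.T @ (p0_inv[:, None] * V))
P_inv = torch.diag(p0_inv) - (p0_inv[:, None] * V) @ inner @ (V.T * p0_inv[None, :])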
for the full interface.
Ancestors
-Subclasses
-
Instance variables
-var scatter
--
-
Computes the scatter, a term of the log marginal likelihood that
-corresponds to L-2 regularization:
-scatter
= (\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) .
-
Returns
-
[type]
-[description]
-
-var log_det_prior_precision
--
-
Compute log determinant of the prior precision
-\log \det P_0
-
Returns
-
-log_det
: torch.Tensor
--
-
-
-var log_det_posterior_precision
+var V
-
-
Compute log determinant of the posterior precision
-\log \det P which depends on the subclasses structure
-used for the Hessian approximation.
-
Returns
-
-log_det
: torch.Tensor
--
-
-
-var log_det_ratio
--
-
Compute the log determinant ratio, a part of the log marginal likelihood.
-
-\log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0
-
-
Returns
-
-log_det_ratio
: torch.Tensor
--
-
-
-var posterior_precision
--
-
Compute or return the posterior precision P.
-
Returns
-
-posterior_prec
: torch.Tensor
--
-
-
-
-Methods
-
-
-def fit(self, train_loader)
-
--
-
Fit the local Laplace approximation at the parameters of the model.
-
Parameters
-
-train_loader
: torch.data.utils.DataLoader
-- each iterate is a training batch (X, y);
-
train_loader.dataset
needs to be set to access N, size of the data set
-
-
-
-def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None)
-
--
-
Compute the Laplace approximation to the log marginal likelihood subject
-to specific Hessian approximations that subclasses implement.
-Requires that the Laplace approximation has been fit before.
-The resulting torch.Tensor is differentiable in prior_precision
and
-sigma_noise
if these have gradients enabled.
-By passing prior_precision
or sigma_noise
, the current value is
-overwritten. This is useful for iterating on the log marginal likelihood.
-
Parameters
-
-prior_precision
: torch.Tensor
, optional
-- prior precision if should be changed from current
prior_precision
value
-sigma_noise
: [type]
, optional
-- observation noise standard deviation if should be changed
-
-
Returns
-
-log_marglik
: torch.Tensor
--
-
+
-
-def predictive_samples(self, x, pred_type='glm', n_samples=100)
-
+var Kinv
-
-
Sample from the posterior predictive on input data x
.
-Can be used, for example, for Thompson sampling.
-
Parameters
-
-x
: torch.Tensor
-- input data
(batch_size, input_shape)
-pred_type
: {'glm', 'nn'}
, default='glm'
-- type of posterior predictive, linearized GLM predictive or neural
-network sampling predictive. The GLM predictive is consistent with
-the curvature approximations used here.
-n_samples
: int
-- number of samples
-
-
Returns
-
-samples
: torch.Tensor
-- samples
(n_samples, batch_size, output_shape)
-
+
-
-def functional_variance(self, Jacs)
-
+var posterior_precision
-
-
Compute functional variance for the 'glm'
predictive:
-f_var[i] = Jacs[i] @ P.inv() @ Jacs[i].T
, which is a output x output
-predictive covariance matrix.
-Mathematically, we have for a single Jacobian
-\mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}}
-the output covariance matrix
- \mathcal{J} P^{-1} \mathcal{J}^T .
-
Parameters
-
-Jacs
: torch.Tensor
-- Jacobians of model output wrt parameters
-
(batch, outputs, parameters)
-
+
Return correctly scaled posterior precision that would be constructed
+as H[0] @ diag(H[1]) @ H[0].T + self.prior_precision_diag.
Returns
-f_var
: torch.Tensor
-- output covariance
(batch, outputs, outputs)
-
-
-
-def sample(self, n_samples=100)
-
-
-
-
Sample from the Laplace posterior approximation, i.e.,
- \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}).
-
Parameters
-
-n_samples
: int
, default=100
-- number of samples
+H
: tuple(eigenvectors, eigenvalues)
+- scaled self.H with temperature and loss factors.
+prior_precision_diag
: torch.Tensor
+- diagonal prior precision shape
parameters
to be added to H.
-
-def optimize_prior_precision(self, method='marglik', pred_type='glm', n_steps=100, lr=0.1, init_prior_prec=1.0, val_loader=None, loss=<function get_nll>, log_prior_prec_min=-4, log_prior_prec_max=4, grid_size=100, link_approx='probit', n_samples=100, verbose=False, cv_loss_with_var=False)
-
-
-
-
-
Inherited members
@@ -603,18 +717,11 @@
-
-
-
-
-
-
-
-
-
-
-
+
-
+
+
+
-
+
+
+
-
+
+
+
-
+
+
diff --git a/docs/curvature/asdl.html b/docs/curvature/asdl.html
index 23b40a34..ecba76dd 100644
--- a/docs/curvature/asdl.html
+++ b/docs/curvature/asdl.html
@@ -35,7 +35,7 @@
class AsdlInterface
-(model, likelihood, last_layer=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None)
-
Interface for asdfghjkl backend.
@@ -47,19 +47,18 @@ Subclasses
-Static methods
+Methods
-def jacobians(model, x)
+def jacobians(self, x)
-
Compute Jacobians \nabla_\theta f(x;\theta) at current parameter \theta
using asdfghjkl's gradient per output dimension.
Parameters
-model
: torch.nn.Module
--
x
: torch.Tensor
- input data
(batch, input_shape)
on compatible device with model.
@@ -71,9 +70,6 @@
Returns
- output function
(batch, outputs)
-
-Methods
-
def gradients(self, x, y)
@@ -108,9 +104,43 @@ Inherited members
+
+class AsdlHessian
+(model, likelihood, last_layer=False, low_rank=10)
+
+-
+
Interface for asdfghjkl backend.
+Ancestors
+
+Methods
+
+
+def eig_lowrank(self, data_loader)
+
+-
+
+
+
+Inherited members
+
+
class AsdlGGN
-(model, likelihood, last_layer=False, stochastic=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)
-
Implementation of the GGNInterface
using asdfghjkl.
@@ -184,6 +214,12 @@
+
+
+-
-
diff --git a/docs/curvature/backpack.html b/docs/curvature/backpack.html
index 0e610d54..1ae69561 100644
--- a/docs/curvature/backpack.html
+++ b/docs/curvature/backpack.html
@@ -35,7 +35,7 @@
class BackPackInterface
-(model, likelihood, last_layer=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None)
-
Interface for Backpack backend.
@@ -48,18 +48,16 @@ Subclasses
- BackPackEF
- BackPackGGN
-Static methods
+Methods
-def jacobians(model, x)
+def jacobians(self, x)
-
Compute Jacobians \nabla_{\theta} f(x;\theta) at current parameter \theta
using backpack's BatchGrad per output dimension.
Parameters
-model
: torch.nn.Module
--
x
: torch.Tensor
- input data
(batch, input_shape)
on compatible device with model.
@@ -71,9 +69,6 @@
Returns
- output function
(batch, outputs)
-
-Methods
-
def gradients(self, x, y)
@@ -110,7 +105,7 @@ Inherited members
class BackPackGGN
-(model, likelihood, last_layer=False, stochastic=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)
-
Implementation of the GGNInterface
using Backpack.
@@ -136,7 +131,7 @@ Inherited members
class BackPackEF
-(model, likelihood, last_layer=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None)
-
Implementation of EFInterface
using Backpack.
diff --git a/docs/curvature/curvature.html b/docs/curvature/curvature.html
index 084432df..645baae7 100644
--- a/docs/curvature/curvature.html
+++ b/docs/curvature/curvature.html
@@ -35,7 +35,7 @@
class CurvatureInterface
-(model, likelihood, last_layer=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None)
-
Interface to access curvature for a model and corresponding likelihood.
@@ -45,12 +45,15 @@
structures, for example, a block-diagonal one.
Parameters
-model
: torch.nn.Module
or FeatureExtractor
+model
: torch.nn.Module
or FeatureExtractor
- torch model (neural network)
likelihood
: {'classification', 'regression'}
-
last_layer
: bool
, default=False
- only consider curvature of last layer
+subnetwork_indices
: torch.Tensor
, default=None
+- indices of the vectorized model parameters that define the subnetwork
+to apply the Laplace approximation over
Attributes
@@ -67,17 +70,15 @@ Subclasses
- EFInterface
- GGNInterface
-
Static methods
+
Methods
-def jacobians(model, x)
+def jacobians(self, x)
-
Compute Jacobians \nabla_\theta f(x;\theta) at current parameter \theta.
Parameters
-model
: torch.nn.Module
--
x
: torch.Tensor
- input data
(batch, input_shape)
on compatible device with model.
@@ -90,15 +91,13 @@
Returns
-def last_layer_jacobians(model, x)
+def last_layer_jacobians(self, x)
-
Compute Jacobians \nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})
only at current last-layer parameter \theta_{\textrm{last}}.
Parameters
-model
: FeatureExtractor
--
x
: torch.Tensor
-
@@ -110,9 +109,6 @@
Returns
- output function
(batch, outputs)
-
-
Methods
-
def gradients(self, x, y)
@@ -175,7 +171,7 @@ Returns
loss
: torch.Tensor
-
-H
: Kron
+H
: Kron
- Kronecker factored Hessian approximation.
@@ -204,7 +200,7 @@ Returns
class GGNInterface
-(model, likelihood, last_layer=False, stochastic=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)
-
Generalized Gauss-Newton or Fisher Curvature Interface.
@@ -212,12 +208,15 @@
Returns
In addition to
CurvatureInterface
, methods for Jacobians are required by subclasses.
Parameters
-model
: torch.nn.Module
or FeatureExtractor
+model
: torch.nn.Module
or FeatureExtractor
- torch model (neural network)
likelihood
: {'classification', 'regression'}
-
last_layer
: bool
, default=False
- only consider curvature of last layer
+subnetwork_indices
: torch.Tensor
, default=None
+- indices of the vectorized model parameters that define the subnetwork
+to apply the Laplace approximation over
stochastic
: bool
, default=False
- Fisher if stochastic else GGN
@@ -270,19 +269,22 @@ Inherited members
class EFInterface
-(model, likelihood, last_layer=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None)
-
Interface for Empirical Fisher as Hessian approximation.
In addition to CurvatureInterface
, methods for gradients are required by subclasses.
Parameters
-model
: torch.nn.Module
or FeatureExtractor
+model
: torch.nn.Module
or FeatureExtractor
- torch model (neural network)
likelihood
: {'classification', 'regression'}
-
last_layer
: bool
, default=False
- only consider curvature of last layer
+subnetwork_indices
: torch.Tensor
, default=None
+- indices of the vectorized model parameters that define the subnetwork
+to apply the Laplace approximation over
Attributes
diff --git a/docs/curvature/index.html b/docs/curvature/index.html
index 72e1203b..00001e2f 100644
--- a/docs/curvature/index.html
+++ b/docs/curvature/index.html
@@ -50,7 +50,7 @@
class CurvatureInterface
-(model, likelihood, last_layer=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None)
-
Interface to access curvature for a model and corresponding likelihood.
@@ -60,12 +60,15 @@
structures, for example, a block-diagonal one.
Parameters
-model
: torch.nn.Module
or FeatureExtractor
+model
: torch.nn.Module
or FeatureExtractor
- torch model (neural network)
likelihood
: {'classification', 'regression'}
-
last_layer
: bool
, default=False
- only consider curvature of last layer
+subnetwork_indices
: torch.Tensor
, default=None
+- indices of the vectorized model parameters that define the subnetwork
+to apply the Laplace approximation over
Attributes
@@ -82,17 +85,15 @@ Subclasses
- EFInterface
- GGNInterface
-
Static methods
+
Methods
-def jacobians(model, x)
+def jacobians(self, x)
-
Compute Jacobians \nabla_\theta f(x;\theta) at current parameter \theta.
Parameters
-model
: torch.nn.Module
--
x
: torch.Tensor
- input data
(batch, input_shape)
on compatible device with model.
@@ -105,15 +106,13 @@
Returns
-def last_layer_jacobians(model, x)
+def last_layer_jacobians(self, x)
-
Compute Jacobians \nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last})
only at current last-layer parameter \theta_{\textrm{last}}.
Parameters
-model
: FeatureExtractor
--
x
: torch.Tensor
-
@@ -125,9 +124,6 @@
Returns
- output function
(batch, outputs)
-
-
Methods
-
def gradients(self, x, y)
@@ -190,7 +186,7 @@ Returns
loss
: torch.Tensor
-
-H
: Kron
+H
: Kron
- Kronecker factored Hessian approximation.
@@ -219,7 +215,7 @@ Returns
class GGNInterface
-(model, likelihood, last_layer=False, stochastic=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)
-
Generalized Gauss-Newton or Fisher Curvature Interface.
@@ -227,12 +223,15 @@
Returns
In addition to
CurvatureInterface
, methods for Jacobians are required by subclasses.
Parameters
-model
: torch.nn.Module
or FeatureExtractor
+model
: torch.nn.Module
or FeatureExtractor
- torch model (neural network)
likelihood
: {'classification', 'regression'}
-
last_layer
: bool
, default=False
- only consider curvature of last layer
+subnetwork_indices
: torch.Tensor
, default=None
+- indices of the vectorized model parameters that define the subnetwork
+to apply the Laplace approximation over
stochastic
: bool
, default=False
- Fisher if stochastic else GGN
@@ -285,19 +284,22 @@ Inherited members
class EFInterface
-(model, likelihood, last_layer=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None)
-
Interface for Empirical Fisher as Hessian approximation.
In addition to CurvatureInterface
, methods for gradients are required by subclasses.
Parameters
-model
: torch.nn.Module
or FeatureExtractor
+model
: torch.nn.Module
or FeatureExtractor
- torch model (neural network)
likelihood
: {'classification', 'regression'}
-
last_layer
: bool
, default=False
- only consider curvature of last layer
+subnetwork_indices
: torch.Tensor
, default=None
+- indices of the vectorized model parameters that define the subnetwork
+to apply the Laplace approximation over
Attributes
@@ -356,7 +358,7 @@ Inherited members
class BackPackInterface
-(model, likelihood, last_layer=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None)
-
Interface for Backpack backend.
@@ -369,18 +371,16 @@ Subclasses
- BackPackEF
- BackPackGGN
-Static methods
+Methods
-def jacobians(model, x)
+def jacobians(self, x)
-
Compute Jacobians \nabla_{\theta} f(x;\theta) at current parameter \theta
using backpack's BatchGrad per output dimension.
Parameters
-model
: torch.nn.Module
--
x
: torch.Tensor
- input data
(batch, input_shape)
on compatible device with model.
@@ -392,9 +392,6 @@
Returns
- output function
(batch, outputs)
-
-Methods
-
def gradients(self, x, y)
@@ -431,7 +428,7 @@ Inherited members
class BackPackGGN
-(model, likelihood, last_layer=False, stochastic=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)
-
@@ -457,7 +454,7 @@
Inherited members
class BackPackEF
-(model, likelihood, last_layer=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None)
-
@@ -483,7 +480,7 @@
Inherited members
class AsdlInterface
-(model, likelihood, last_layer=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None)
-
Interface for asdfghjkl backend.
@@ -495,19 +492,18 @@ Subclasses
-Static methods
+Methods
-def jacobians(model, x)
+def jacobians(self, x)
-
Compute Jacobians \nabla_\theta f(x;\theta) at current parameter \theta
using asdfghjkl's gradient per output dimension.
Parameters
-model
: torch.nn.Module
--
x
: torch.Tensor
- input data
(batch, input_shape)
on compatible device with model.
@@ -519,9 +515,6 @@
Returns
- output function
(batch, outputs)
-
-Methods
-
def gradients(self, x, y)
@@ -558,7 +551,7 @@ Inherited members
class AsdlGGN
-(model, likelihood, last_layer=False, stochastic=False)
+(model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False)
-
@@ -608,6 +601,40 @@
Inherited members
+
+class AsdlHessian
+(model, likelihood, last_layer=False, low_rank=10)
+
+-
+
Interface for asdfghjkl backend.
+Ancestors
+
+Methods
+
+
+def eig_lowrank(self, data_loader)
+
+-
+
+
+
+Inherited members
+
+
@@ -680,6 +707,12 @@
-The laplace package facilitates the application of Laplace approximations for entire neural networks or just their last layer.
+
The laplace package facilitates the application of Laplace approximations for entire neural networks, subnetworks of neural networks, or just their last layer.
The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations.
The library documentation is available at https://aleximmer.github.io/Laplace.
There is also a corresponding paper, Laplace Redux — Effortless Bayesian Deep Learning, which introduces the library, provides an introduction to the Laplace approximation, reviews its use in deep learning, and empirically demonstrates its versatility and competitiveness. Please consider referring to the paper when using our library:
-@article{daxberger2021laplace,
- title={Laplace Redux--Effortless Bayesian Deep Learning},
- author={Daxberger, Erik and Kristiadi, Agustinus and Immer, Alexander
- and Eschenhagen, Runa and Bauer, Matthias and Hennig, Philipp},
- journal={arXiv preprint arXiv:2106.14806},
+@inproceedings{laplace2021,
+ title={Laplace Redux--Effortless {B}ayesian Deep Learning},
+ author={Erik Daxberger and Agustinus Kristiadi and Alexander Immer
+ and Runa Eschenhagen and Matthias Bauer and Philipp Hennig},
+ booktitle={{N}eur{IPS}},
year={2021}
}
@@ -56,34 +56,39 @@ Setup
Structure
The laplace package consists of two main components:
-- The subclasses of
laplace.BaseLaplace
that implement different sparsity structures: different subsets of weights ('all'
and 'last_layer'
) and different structures of the Hessian approximation ('full'
, 'kron'
, and 'diag'
). This results in six currently available options: FullLaplace
, KronLaplace
, DiagLaplace
, and the corresponding last-layer variations FullLLLaplace
, KronLLLaplace
,
-and DiagLLLaplace
, which are all subclasses of laplace.LLLaplace
. All of these can be conveniently accessed via the laplace.Laplace
function.
+- The subclasses of
laplace.BaseLaplace
that implement different sparsity structures: different subsets of weights ('all'
, 'subnetwork'
and 'last_layer'
) and different structures of the Hessian approximation ('full'
, 'kron'
, 'lowrank'
and 'diag'
). This results in eight currently available options: FullLaplace
, KronLaplace
, DiagLaplace
, the corresponding last-layer variations FullLLLaplace
, KronLLLaplace
,
+and DiagLLLaplace
(which are all subclasses of laplace.LLLaplace
), laplace.SubnetLaplace
(which only supports a 'full'
Hessian approximation) and LowRankLaplace
(which only supports inference over 'all'
weights). All of these can be conveniently accessed via the laplace.Laplace
function.
- The backends in
laplace.curvature
which provide access to Hessian approximations of
the corresponding sparsity structures, for example, the diagonal GGN.
Additionally, the package provides utilities for
-decomposing a neural network into feature extractor and last layer for LLLaplace
subclasses (laplace.feature_extractor
)
+decomposing a neural network into feature extractor and last layer for LLLaplace
subclasses (laplace.utils.feature_extractor
)
and
-effectively dealing with Kronecker factors (laplace.matrix
).
+effectively dealing with Kronecker factors (laplace.utils.matrix
).
+Finally, the package implements several options to select/specify a subnetwork for SubnetLaplace
(as subclasses of laplace.utils.subnetmask.SubnetMask
).
+Automatic subnetwork selection strategies include: uniformly at random (RandomSubnetMask
), by largest parameter magnitudes (LargestMagnitudeSubnetMask
), and by largest marginal parameter variances (LargestVarianceDiagLaplaceSubnetMask
and LargestVarianceSWAGSubnetMask
).
+In addition to that, subnetworks can also be specified manually, by listing the names of either the model parameters (ParamNameSubnetMask
) or modules (ModuleNameSubnetMask
) to perform Laplace inference over.
Extendability
To extend the laplace package, new BaseLaplace
subclasses can be designed, for example,
-a block-diagonal structure or subset-of-weights Laplace.
-Alternatively, extending or integrating backends (subclasses of curvature.curvature
) allows to provide different Hessian
+Laplace with a block-diagonal Hessian structure.
+One can also implement custom subnetwork selection strategies as new subclasses of SubnetMask
.
+Alternatively, extending or integrating backends (subclasses of curvature.curvature
) allows one to provide different Hessian
approximations to the Laplace approximations.
For example, currently the curvature.BackPackInterface
based on BackPACK and curvature.AsdlInterface
based on ASDL are available.
The AsdlInterface
provides a Kronecker factored empirical Fisher while the BackPackInterface
does not, and only the BackPackInterface
provides access to Hessian approximations
for a regression (MSELoss) loss function.
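As an illustration of the custom-SubnetMask extension point described above, a hedged sketch of a user-defined selection strategy (the constructor signature, self.model attribute, and boolean-mask return convention follow the pattern of the existing masks but are assumptions here):

import torch
from torch.nn.utils import parameters_to_vector
from laplace.utils.subnetmask import SubnetMask

class FirstKSubnetMask(SubnetMask):
    """Hypothetical mask: select the first n_params_subnet parameters."""
    def __init__(self, model, n_params_subnet):
        super().__init__(model)
        self._n = n_params_subnet

    def get_subnet_mask(self, train_loader):
        # boolean mask over the vectorized model parameters (assumed convention)
        n_params = len(parameters_to_vector(self.model.parameters()))
        mask = torch.zeros(n_params, dtype=torch.bool)
        mask[:self._n] = True
        return mask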
Example usage
-Post-hoc prior precision tuning of last-layer LA
+Post-hoc prior precision tuning of diagonal LA
In the following example, a pre-trained model is loaded,
-then the Laplace approximation is fit to the training data,
+then the Laplace approximation is fit to the training data
+(using a diagonal Hessian approximation over all parameters),
and the prior precision is optimized with cross-validation 'CV'
.
After that, the resulting LA is used for prediction with
the 'probit'
predictive for classification.
from laplace import Laplace
-# pre-trained model
+# Pre-trained model
model = load_map_model()
# User-specified LA flavor
@@ -97,7 +102,7 @@ Post-hoc prio
pred = la(x, link_approx='probit')
Differentiating the log marginal likelihood w.r.t. hyperparameters
-The marginal likelihood can be used for model selection and is differentiable
+
The marginal likelihood can be used for model selection [10] and is differentiable
for continuous hyperparameters like the prior precision or observation noise.
Here, we fit the library default, KFAC last-layer LA and differentiate
the log marginal likelihood.
@@ -114,6 +119,41 @@ Differe
ml = la.log_marginal_likelihood(prior_prec, obs_noise)
ml.backward()
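The same differentiability supports gradient-based tuning; a hedged sketch of a small optimization loop over the log hyperparameters (continuing from the fitted la above; learning rate and step count are illustrative):

import torch

log_prior, log_sigma = torch.zeros(1, requires_grad=True), torch.zeros(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
for _ in range(100):
    hyper_optimizer.zero_grad()
    neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    hyper_optimizer.step()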
+Applying the LA over only a subset of the model parameters
+This example shows how to fit the Laplace approximation over only
+a subnetwork within a neural network (while keeping all other parameters
+fixed at their MAP estimates), as proposed in [11]. It also exemplifies
+different ways to specify the subnetwork to perform inference over.
+from laplace import Laplace
+
+# Pre-trained model
+model = load_model()
+
+# Examples of different ways to specify the subnetwork
+# via indices of the vectorized model parameters
+#
+# Example 1: select the 128 parameters with the largest magnitude
+from laplace.utils import LargestMagnitudeSubnetMask
+subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=128)
+subnetwork_indices = subnetwork_mask.select()
+
+# Example 2: specify the layers that define the subnetwork
+from laplace.utils import ModuleNameSubnetMask
+subnetwork_mask = ModuleNameSubnetMask(model, module_names=['layer.1', 'layer.3'])
+subnetwork_mask.select()
+subnetwork_indices = subnetwork_mask.indices
+
+# Example 3: manually define the subnetwork via custom subnetwork indices
+import torch
+subnetwork_indices = torch.tensor([0, 4, 11, 42, 123, 2021])
+
+# Define and fit subnetwork LA using the specified subnetwork indices
+la = Laplace(model, 'classification',
+ subset_of_weights='subnetwork',
+ hessian_structure='full',
+ subnetwork_indices=subnetwork_indices)
+la.fit(train_loader)
+
Documentation
The documentation is available here or can be generated and/or viewed locally:
# assuming the repository was cloned
@@ -124,7 +164,7 @@ Documentation
pdoc --http 0.0.0.0:8080 laplace --template-dir template
References
-This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1].
+This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1]. Please consider citing the respective papers if you use any of their proposed methods via our laplace library.
- [1] MacKay, DJC. A Practical Bayesian Framework for Backpropagation Networks. Neural Computation 1992.
- [2] Gibbs, M. N. Bayesian Gaussian Processes for Regression and Classification. PhD Thesis 1997.
@@ -134,7 +174,9 @@ References
- [6] Khan, M. E., Immer, A., Abedi, E., Korzepa, M. Approximate Inference Turns Deep Networks into Gaussian Processes. NeurIPS 2019.
- [7] Kristiadi, A., Hein, M., Hennig, P. Being Bayesian, Even Just a Bit, Fixes Overconfidence in ReLU Networks. ICML 2020.
- [8] Immer, A., Korzepa, M., Bauer, M. Improving predictions of Bayesian neural nets via local linearization. AISTATS 2021.
-- [9] Immer, A., Bauer, M., Fortuin, V., Rätsch, G., Khan, EM. Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning. ICML 2021.
+- [9] Sharma, A., Azizan, N., Pavone, M. Sketching Curvature for Efficient Out-of-Distribution Detection for Deep Neural Networks. UAI 2021.
+- [10] Immer, A., Bauer, M., Fortuin, V., Rätsch, G., Khan, EM. Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning. ICML 2021.
+- [11] Daxberger, E., Nalisnick, E., Allingham, JU., Antorán, J., Hernández-Lobato, JM. Bayesian Deep Learning via Subnetwork Inference. ICML 2021.
Full example: Optimization of the marginal likelihood and prediction
Sinusoidal toy data
@@ -326,10 +368,6 @@
-
- laplace.feature_extractor
--
-
-
laplace.laplace
-
@@ -338,11 +376,11 @@
-
- laplace.matrix
+laplace.subnetlaplace
-
-laplace.utils
+laplace.utils
-
@@ -364,9 +402,9 @@ Parameters
-
likelihood
: {'classification', 'regression'}
-
-subset_of_weights
: {'last_layer', 'all'}
, default='last_layer'
+subset_of_weights
: {'last_layer', 'subnetwork', 'all'}
, default='last_layer'
- subset of weights to consider for inference
-hessian_structure
: {'diag', 'kron', 'full'}
, default='kron'
+hessian_structure
: {'diag', 'kron', 'full', 'lowrank'}
, default='kron'
- structure of the Hessian approximation
Returns
@@ -636,7 +674,8 @@ Subclasses
- DiagLaplace
- FullLaplace
- KronLaplace
-
- laplace.lllaplace.LLLaplace
+
- LowRankLaplace
+
- LLLaplace
Instance variables
@@ -697,7 +736,7 @@ Returns
Methods
-def fit(self, train_loader)
+def fit(self, train_loader, override=True)
-
Fit the local Laplace approximation at the parameters of the model.
@@ -706,6 +745,45 @@
Parameters
train_loader
: torch.data.utils.DataLoader
- each iterate is a training batch (X, y);
train_loader.dataset
needs to be set to access N, size of the data set
+
override
: bool
, default=True
+
- whether to initialize H, loss, and n_data again; setting to False is useful for
+online learning settings to accumulate a sequential posterior approximation.
+
+
+
+def square_norm(self, value)
+
+-
+
Compute the square norm under the posterior precision, with value - self.mean as 𝛥:
+
+
+\Delta^\top P \Delta
+
+Returns
+
+
+square_form
+-
+
+
+
+def log_prob(self, value, normalized=True)
+
+-
+
Compute the log probability under the (current) Laplace approximation.
+
Parameters
+
+normalized
: bool
, default=True
+- whether to return log of a properly normalized Gaussian or just the
+terms that depend on
value
.
+
+
Returns
+
+log_prob
: torch.Tensor
+-
@@ -826,6 +904,7 @@ Ancestors
Subclasses
Instance variables
@@ -869,11 +948,13 @@ Inherited members
log_det_ratio
log_likelihood
log_marginal_likelihood
+log_prob
optimize_prior_precision_base
predictive_samples
prior_precision_diag
sample
scatter
+square_norm
@@ -888,7 +969,7 @@ Inherited members
Mathematically, we have for each parameter group, e.g., torch.nn.Module,
that P \approx Q \otimes H.
See BaseLaplace
for the full interface and see
-Kron
and KronDecomposed
for the structure of
+Kron
and KronDecomposed
for the structure of
the Kronecker factors. Kron
is used to aggregate factors by summing up and
KronDecomposed
is used to add the prior, a Hessian factor (e.g. temperature),
and computing posterior covariances, marginal likelihood, etc.
@@ -909,7 +990,7 @@ Instance variables
@@ -929,11 +1010,13 @@ Inherited members
log_det_ratio
log_likelihood
log_marginal_likelihood
+log_prob
optimize_prior_precision_base
predictive_samples
prior_precision_diag
sample
scatter
+square_norm
@@ -997,11 +1080,80 @@ Inherited members
log_det_ratio
log_likelihood
log_marginal_likelihood
+log_prob
optimize_prior_precision_base
predictive_samples
prior_precision_diag
sample
scatter
+square_norm
+
+
+
+
+
+class LowRankLaplace
+(model, likelihood, sigma_noise=1, prior_precision=1, prior_mean=0, temperature=1, backend=laplace.curvature.asdl.AsdlHessian, backend_kwargs=None)
+
+-
+
Laplace approximation with low-rank log likelihood Hessian (approximation).
+The low-rank matrix is represented by an eigendecomposition (vecs, values).
+Based on the chosen backend
, either a true Hessian or, for example, GGN
+approximation could be used.
+The posterior precision is computed as
+ P = V diag(l) V^T + P_0.
+To sample, compute the functional variance, and compute the log determinant, algebraic tricks
+are used to reduce the cost of inversion to that of a K \times K matrix
+if we have a rank of K.
+
See BaseLaplace
for the full interface.
+Ancestors
+
+Instance variables
+
+var V
+-
+
+
+var Kinv
+-
+
+
+var posterior_precision
+-
+
Return correctly scaled posterior precision that would be constructed
+as H[0] @ diag(H[1]) @ H[0].T + self.prior_precision_diag.
+
Returns
+
+H
: tuple(eigenvectors, eigenvalues)
+- scaled self.H with temperature and loss factors.
+prior_precision_diag
: torch.Tensor
+- diagonal prior precision shape
parameters
to be added to H.
+
+
+
+Inherited members
+
@@ -1035,7 +1187,7 @@ Inherited members
all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
Parameters
-model
: torch.nn.Module
or FeatureExtractor
+model
: torch.nn.Module
or FeatureExtractor
-
likelihood
: {'classification', 'regression'}
- determines the log likelihood Hessian approximation
@@ -1092,11 +1244,13 @@ Inherited members
log_det_ratio
log_likelihood
log_marginal_likelihood
+log_prob
optimize_prior_precision_base
posterior_precision
predictive_samples
sample
scatter
+square_norm
@@ -1113,30 +1267,36 @@ Inherited members
See FullLaplace
, LLLaplace
, and BaseLaplace
for the full interface.
Ancestors
Inherited members
@@ -1151,35 +1311,37 @@ Inherited members
Mathematically, we have for the last parameter group, i.e., torch.nn.Linear,
that P \approx Q \otimes H.
See KronLaplace
, LLLaplace
, and BaseLaplace
for the full interface and see
-Kron
and KronDecomposed
for the structure of
+Kron
and KronDecomposed
for the structure of
the Kronecker factors. Kron
is used to aggregate factors by summing up and
KronDecomposed
is used to add the prior, a Hessian factor (e.g. temperature),
and computing posterior covariances, marginal likelihood, etc.
Use of damping
is possible by initializing or setting damping=True
.
Ancestors
Inherited members
@@ -1195,30 +1357,143 @@ Inherited members
See DiagLaplace
, LLLaplace
, and BaseLaplace
for the full interface.
Ancestors
Inherited members
+
+
+class SubnetLaplace
+(model, likelihood, subnetwork_indices, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)
+
+
+Class for subnetwork Laplace, which computes the Laplace approximation over
+just a subset of the model parameters (i.e. a subnetwork within the neural network),
+as proposed in [1]. Subnetwork Laplace only supports a full Hessian approximation; other
+approximations could be used in theory, but would not make as much sense conceptually.
+
A Laplace approximation is represented by a MAP which is given by the
+model
parameter and a posterior precision or covariance specifying
+a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}).
+Here, only a subset of the model parameters (i.e. a subnetwork of the
+neural network) are treated probabilistically.
+The goal of this class is to compute the posterior precision P
+which sums as
+
+P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
+\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
+
+The prior is assumed to be Gaussian and therefore we have a simple form for
+\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 .
+In particular, we assume a scalar or diagonal prior precision so that in
+all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
+
The subnetwork Laplace approximation only supports a full, i.e., dense, log likelihood
+Hessian approximation and hence posterior precision.
+Based on the chosen backend
+parameter, the full approximation can be, for example, a generalized Gauss-Newton
+matrix.
+Mathematically, we have P \in \mathbb{R}^{P \times P}.
+See FullLaplace
and BaseLaplace
for the full interface.
+
References
+
[1] Daxberger, E., Nalisnick, E., Allingham, JU., Antorán, J., Hernández-Lobato, JM.
+Bayesian Deep Learning via Subnetwork Inference.
+ICML 2021.
+
Parameters
+
+model
: torch.nn.Module
or FeatureExtractor
+-
+likelihood
: {'classification', 'regression'}
+- determines the log likelihood Hessian approximation
+subnetwork_indices
: torch.LongTensor
+- indices of the vectorized model parameters
+(i.e.
torch.nn.utils.parameters_to_vector(model.parameters())
)
+that define the subnetwork to apply the Laplace approximation over
+sigma_noise
: torch.Tensor
or float
, default=1
+- observation noise for the regression setting; must be 1 for classification
+prior_precision
: torch.Tensor
or float
, default=1
+- prior precision of a Gaussian prior (= weight decay);
+can be scalar, per-layer, or diagonal in the most general case
+prior_mean
: torch.Tensor
or float
, default=0
+- prior mean of a Gaussian prior, useful for continual learning
+temperature
: float
, default=1
+- temperature of the likelihood; lower temperature leads to more
+concentrated posterior and vice versa.
+backend
: subclasses
of CurvatureInterface
+- backend for access to curvature/Hessian approximations
+backend_kwargs
: dict
, default=None
+- arguments passed to the backend on initialization, for example to
+set the number of MC samples for stochastic approximations.
+
+Ancestors
+
+Instance variables
+
+var prior_precision_diag
+-
+
Obtain the diagonal prior precision p_0 constructed from either
+a scalar or diagonal prior precision.
+
Returns
+
+prior_precision_diag
: torch.Tensor
+-
+
+
+
+Inherited members
+
@@ -1234,8 +1509,9 @@ Index
Structure
Extendability
Example usage
Documentation
@@ -1262,11 +1538,10 @@ Index
Functions
@@ -1290,6 +1565,8 @@ BaseLaplace
ParametricLaplace
+LowRankLaplace
+
+
LLLaplace
@@ -1318,6 +1598,9 @@ KronLLL
DiagLLLaplace
+
+SubnetLaplace
+
diff --git a/docs/laplace.html b/docs/laplace.html
index d72602d6..99dae2b8 100644
--- a/docs/laplace.html
+++ b/docs/laplace.html
@@ -42,9 +42,9 @@ Parameters
likelihood
: {'classification', 'regression'}
-subset_of_weights
: {'last_layer', 'all'}
, default='last_layer'
+subset_of_weights
: {'last_layer', 'subnetwork', 'all'}
, default='last_layer'
subset of weights to consider for inference
-hessian_structure
: {'diag', 'kron', 'full'}
, default='kron'
+hessian_structure
: {'diag', 'kron', 'full', 'lowrank'}
, default='kron'
structure of the Hessian approximation
Returns
diff --git a/docs/lllaplace.html b/docs/lllaplace.html
index 108e9b0b..6ea940b2 100644
--- a/docs/lllaplace.html
+++ b/docs/lllaplace.html
@@ -33,6 +33,103 @@ Module laplace.lllaplace
+
+class LLLaplace
+(model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, last_layer_name=None, backend_kwargs=None)
+
+-
+
Baseclass for all last-layer Laplace approximations in this library.
+Subclasses specify the structure of the Hessian approximation.
+See BaseLaplace
for the full interface.
+
A Laplace approximation is represented by a MAP which is given by the
+model
parameter and a posterior precision or covariance specifying
+a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}).
+Here, only the parameters of the last layer of the neural network
+are treated probabilistically.
+The goal of this class is to compute the posterior precision P
+which sums as
+
+P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
+\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
+
+Every subclass implements different approximations to the log likelihood Hessians,
+for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have
+a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 .
+In particular, we assume a scalar or diagonal prior precision so that in
+all cases P_0 = \textrm{diag}(p_0) and the structure of p_0 can be varied.
+
Parameters
+
+model
: torch.nn.Module
or FeatureExtractor
+-
+likelihood
: {'classification', 'regression'}
+- determines the log likelihood Hessian approximation
+sigma_noise
: torch.Tensor
or float
, default=1
+- observation noise for the regression setting; must be 1 for classification
+prior_precision
: torch.Tensor
or float
, default=1
+- prior precision of a Gaussian prior (= weight decay);
+can be scalar, per-layer, or diagonal in the most general case
+prior_mean
: torch.Tensor
or float
, default=0
+- prior mean of a Gaussian prior, useful for continual learning
+temperature
: float
, default=1
+- temperature of the likelihood; lower temperature leads to more
+concentrated posterior and vice versa.
+backend
: subclasses
of CurvatureInterface
+- backend for access to curvature/Hessian approximations
+last_layer_name
: str
, default=None
+- name of the model's last layer, if None it will be determined automatically
+backend_kwargs
: dict
, default=None
+- arguments passed to the backend on initialization, for example to
+set the number of MC samples for stochastic approximations.
+
+Ancestors
+
+Subclasses
+
+Instance variables
+
+var prior_precision_diag
+-
+
Obtain the diagonal prior precision p_0 constructed from either
+a scalar or diagonal prior precision.
+
Returns
+
+prior_precision_diag
: torch.Tensor
+-
+
+
+
+Inherited members
+
+
class FullLLLaplace
(model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, last_layer_name=None, backend_kwargs=None)
@@ -42,33 +139,39 @@
and hence posterior precision. Based on the chosen backend
parameter, the full
approximation can be, for example, a generalized Gauss-Newton matrix.
Mathematically, we have P \in \mathbb{R}^{P \times P}.
-See FullLaplace
, LLLaplace
, and BaseLaplace
for the full interface.
+See FullLaplace
, LLLaplace
, and BaseLaplace
for the full interface.
Ancestors
Inherited members
@@ -82,36 +185,38 @@ Inherited members
and hence posterior precision.
Mathematically, we have for the last parameter group, i.e., torch.nn.Linear,
that P \approx Q \otimes H.
-See KronLaplace
, LLLaplace
, and BaseLaplace
for the full interface and see
-Kron
and KronDecomposed
for the structure of
+See KronLaplace
, LLLaplace
, and BaseLaplace
for the full interface and see
+Kron
and KronDecomposed
for the structure of
the Kronecker factors. Kron
is used to aggregate factors by summing up and
KronDecomposed
is used to add the prior, a Hessian factor (e.g. temperature),
and computing posterior covariances, marginal likelihood, etc.
Use of damping
is possible by initializing or setting damping=True
.
Ancestors
Inherited members
@@ -124,33 +229,39 @@ Inherited members
Last-layer Laplace approximation with diagonal log likelihood Hessian approximation
and hence posterior precision.
Mathematically, we have P \approx \textrm{diag}(P).
-See DiagLaplace
, LLLaplace
, and BaseLaplace
for the full interface.
+See DiagLaplace
, LLLaplace
, and BaseLaplace
for the full interface.
Ancestors
Inherited members
@@ -172,6 +283,9 @@ Index
Classes