Keep initialization of H for all-weights and last-layer separate #72

Merged 4 commits on Dec 21, 2021.
Changes from all commits
laplace/baselaplace.py (24 changes: 15 additions & 9 deletions)

@@ -1,16 +1,16 @@
 from math import sqrt, pi, log
-from laplace.curvature.asdl import AsdlHessian
 import numpy as np
 import torch
 from torch.nn.utils import parameters_to_vector, vector_to_parameters
 from torch.distributions import MultivariateNormal, Dirichlet, Normal
 
 from laplace.utils import parameters_per_layer, invsqrt_precision, get_nll, validate
 from laplace.matrix import Kron
-from laplace.curvature import BackPackGGN, BackPackEF, AsdlGGN, AsdlEF
+from laplace.curvature import BackPackGGN, AsdlHessian
 
 
-__all__ = ['BaseLaplace', 'FullLaplace', 'KronLaplace', 'DiagLaplace', 'ParametricLaplace']
+__all__ = ['BaseLaplace', 'ParametricLaplace',
+           'FullLaplace', 'KronLaplace', 'DiagLaplace', 'LowRankLaplace']
 
 
 class BaseLaplace:

@@ -330,16 +330,18 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
                  prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None):
         super().__init__(model, likelihood, sigma_noise, prior_precision,
                          prior_mean, temperature, backend, backend_kwargs)
-        try:
+        if not hasattr(self, 'H'):
             self._init_H()
-        except AttributeError:  # necessary information not yet available
-            pass
-        # posterior mean/mode
-        self.mean = self.prior_mean
+        # posterior mean/mode
+        self.mean = self.prior_mean
 
     def _init_H(self):
         raise NotImplementedError
 
+    def _check_H_init(self):
+        if self.H is None:
+            raise AttributeError('Laplace not fitted. Run fit() first.')
+
     def fit(self, train_loader, override=True):
         """Fit the local Laplace approximation at the parameters of the model.

@@ -727,6 +729,7 @@ def posterior_precision(self):
         precision : torch.tensor
             `(parameters, parameters)`
         """
+        self._check_H_init()
         return self._H_factor * self.H + torch.diag(self.prior_precision_diag)
 
     @property

@@ -811,6 +814,7 @@ def posterior_precision(self):
         -------
         precision : `laplace.matrix.KronDecomposed`
         """
+        self._check_H_init()
         return self.H * self._H_factor + self.prior_precision
 
     @property

@@ -862,7 +866,7 @@ def __init__(self, model, likelihood, sigma_noise=1, prior_precision=1, prior_mean=0,
                          temperature=temperature, backend=backend, backend_kwargs=backend_kwargs)
 
     def _init_H(self):
-        pass
+        self.H = None
 
     @property
     def V(self):

@@ -910,6 +914,7 @@ def posterior_precision(self):
         prior_precision_diag : torch.Tensor
             diagonal prior precision shape `parameters` to be added to H.
         """
+        self._check_H_init()
         return (self.H[0], self._H_factor * self.H[1]), self.prior_precision_diag
 
     def functional_variance(self, Jacs):

@@ -964,6 +969,7 @@ def posterior_precision(self):
         precision : torch.tensor
             `(parameters)`
         """
+        self._check_H_init()
         return self._H_factor * self.H + self.prior_precision_diag
 
     @property
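The practical effect of the new `_check_H_init` guard: requesting a posterior quantity before `fit()` now fails fast with a clear message instead of silently folding an uncomputed `H` into the precision. A minimal sketch of the behavior (the toy model is ours; the import path mirrors the tests):

```python
import torch.nn as nn
from laplace.laplace import LowRankLaplace

# Toy two-layer network, purely illustrative.
model = nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2))
la = LowRankLaplace(model, 'classification')

assert la.H is None  # _init_H() now records "not yet computed" instead of `pass`
try:
    la.posterior_precision  # the property calls _check_H_init() first
except AttributeError as err:
    print(err)  # Laplace not fitted. Run fit() first.
```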
laplace/curvature/__init__.py (4 changes: 2 additions & 2 deletions)

@@ -8,10 +8,10 @@
     logging.info('Backpack not available.')
 
 try:
-    from laplace.curvature.asdl import AsdlGGN, AsdlEF, AsdlInterface
+    from laplace.curvature.asdl import AsdlHessian, AsdlGGN, AsdlEF, AsdlInterface
 except ModuleNotFoundError:
     logging.info('asdfghjkl backend not available.')
 
 __all__ = ['CurvatureInterface', 'GGNInterface', 'EFInterface',
            'BackPackInterface', 'BackPackGGN', 'BackPackEF',
-           'AsdlInterface', 'AsdlGGN', 'AsdlEF']
+           'AsdlInterface', 'AsdlGGN', 'AsdlEF', 'AsdlHessian']
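With `AsdlHessian` added to the package-level exports, consumers such as `baselaplace.py` can import it next to the other backends instead of reaching into the submodule, and the surrounding `try/except` keeps the package importable when `asdfghjkl` is absent. For example, assuming the backend is installed:

```python
# Both imports resolve to the same class after this change;
# the package-level one is what baselaplace.py now uses.
from laplace.curvature import AsdlHessian
from laplace.curvature.asdl import AsdlHessian as AsdlHessianDirect

assert AsdlHessian is AsdlHessianDirect
```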
laplace/curvature/asdl.py (2 changes: 1 addition & 1 deletion)

@@ -115,7 +115,7 @@ def diag(self, X, y, **kwargs):
         diag_ggn = curv.matrices_to_vector(None)
         return self.factor * loss, self.factor * diag_ggn
 
-    def kron(self, X, y, N, **wkwargs) -> [torch.Tensor, Kron]:
+    def kron(self, X, y, N, **wkwargs):
         with torch.no_grad():
             if self.last_layer:
                 f, X = self.model.forward_with_features(X)
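The dropped annotation `-> [torch.Tensor, Kron]` was a list literal, which is legal at runtime but meaningless to type checkers; removing it is the minimal fix. If a return type is wanted later, the conventional spelling would be a `Tuple` (our suggestion, not part of this PR; sketched here as a bare method signature):

```python
from typing import Tuple

import torch
from laplace.matrix import Kron

# Hypothetical typed signature for the kron method -- not in this PR.
def kron(self, X, y, N, **kwargs) -> Tuple[torch.Tensor, Kron]:
    ...
```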
laplace/lllaplace.py (26 changes: 8 additions & 18 deletions)

@@ -9,7 +9,7 @@
 from laplace.curvature import BackPackGGN
 
 
-__all__ = ['FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace']
+__all__ = ['LLLaplace', 'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace']
 
 
 class LLLaplace(ParametricLaplace):

@@ -60,12 +60,13 @@ class LLLaplace(ParametricLaplace):
     def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
                  prior_mean=0., temperature=1., backend=BackPackGGN, last_layer_name=None,
                  backend_kwargs=None):
+        self.H = None
         super().__init__(model, likelihood, sigma_noise=sigma_noise, prior_precision=1.,
                          prior_mean=0., temperature=temperature, backend=backend,
                          backend_kwargs=backend_kwargs)
         self.model = FeatureExtractor(deepcopy(model), last_layer_name=last_layer_name)
         if self.model.last_layer is None:
-            self.mean = prior_mean
+            self.mean = None
             self.n_params = None
             self.n_layers = None
             # ignore checks of prior mean setter temporarily, check on .fit()

@@ -76,7 +77,8 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
             self.n_layers = len(list(self.model.last_layer.parameters()))
             self.prior_precision = prior_precision
             self.prior_mean = prior_mean
-            self.mean = self.prior_mean
+            self.mean = self.prior_mean
+            self._init_H()
         self._backend_kwargs['last_layer'] = True

@@ -91,6 +93,9 @@ def fit(self, train_loader, override=True):
             whether to initialize H, loss, and n_data again; setting to False is useful for
             online learning settings to accumulate a sequential posterior approximation.
         """
+        if not override:
+            raise ValueError('Last-layer Laplace approximations do not support `override=False`.')
+
         self.model.eval()
 
         if self.model.last_layer is None:

@@ -108,9 +113,6 @@ def fit(self, train_loader, override=True):
             self.prior_mean = self._prior_mean
             self._init_H()
 
-        if override:
-            self._init_H()
-
         super().fit(train_loader, override=override)
         self.mean = parameters_to_vector(self.model.last_layer.parameters()).detach()

@@ -159,12 +161,6 @@ class FullLLLaplace(LLLaplace, FullLaplace):
     # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure)
     _key = ('last_layer', 'full')
 
-    def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
-                 prior_mean=0., temperature=1., backend=BackPackGGN, last_layer_name=None,
-                 backend_kwargs=None):
-        super().__init__(model, likelihood, sigma_noise, prior_precision,
-                         prior_mean, temperature, backend, last_layer_name, backend_kwargs)
-
 
 class KronLLLaplace(LLLaplace, KronLaplace):
     """Last-layer Laplace approximation with Kronecker factored log likelihood Hessian approximation

@@ -200,9 +196,3 @@ class DiagLLLaplace(LLLaplace, DiagLaplace):
     """
     # key to map to correct subclass of BaseLaplace, (subset of weights, Hessian structure)
     _key = ('last_layer', 'diag')
-
-    def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
-                 prior_mean=0., temperature=1., backend=BackPackGGN, last_layer_name=None,
-                 backend_kwargs=None):
-        super().__init__(model, likelihood, sigma_noise, prior_precision,
-                         prior_mean, temperature, backend, last_layer_name, backend_kwargs)
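Taken together, the `LLLaplace` changes make the two construction paths explicit: with a known last layer, everything including `H` is initialized eagerly in `__init__`; without one, `mean` and `H` stay `None` until `fit()` runs a batch through the model to locate the layer. A sketch of both paths (toy model is ours; `'1'` names its final module, as in the tests):

```python
import torch.nn as nn
from laplace.lllaplace import DiagLLLaplace

model = nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2))

# Last layer named up front: H is initialized in __init__.
eager = DiagLLLaplace(model, 'classification', last_layer_name='1')
assert eager.H is not None

# Last layer unknown: initialization is deferred until fit();
# posterior access fails fast in the meantime via _check_H_init().
lazy = DiagLLLaplace(model, 'classification')
assert lazy.mean is None and lazy.H is None
```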
tests/test_baselaplace.py (25 changes: 25 additions & 0 deletions)

@@ -9,6 +9,7 @@
 from torch.nn.utils import parameters_to_vector
 from torch.utils.data import DataLoader, TensorDataset
 from torch.distributions import Normal, Categorical
+from torchvision.models import wide_resnet50_2
 
 from laplace.laplace import FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace
 from laplace.matrix import KronDecomposed

@@ -31,6 +32,12 @@ def model():
     return model
 
 
+@pytest.fixture
+def large_model():
+    model = wide_resnet50_2()
+    return model
+
+
 @pytest.fixture
 def class_loader():
     X = torch.randn(10, 3)

@@ -48,6 +55,24 @@ def reg_loader():
 @pytest.mark.parametrize('laplace', flavors)
 def test_laplace_init(laplace, model):
     lap = laplace(model, 'classification')
+    assert torch.allclose(lap.mean, lap.prior_mean)
+    if laplace in [FullLaplace, DiagLaplace]:
+        H = lap.H.clone()
+        lap._init_H()
+        assert torch.allclose(H, lap.H)
+    elif laplace == LowRankLaplace:
+        assert lap.H is None
+    else:
+        H = [[k.clone() for k in kfac] for kfac in lap.H.kfacs]
+        lap._init_H()
+        for kfac1, kfac2 in zip(H, lap.H.kfacs):
+            for k1, k2 in zip(kfac1, kfac2):
+                assert torch.allclose(k1, k2)
+
+
+@pytest.mark.xfail(strict=True)
+def test_laplace_large_init(large_model):
+    lap = FullLaplace(large_model, 'classification')
+
 
 @pytest.mark.parametrize('laplace', flavors)
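`test_laplace_large_init` is marked `xfail(strict=True)` because `FullLaplace._init_H()` must allocate a dense `(P, P)` Hessian at construction time, which cannot succeed for `wide_resnet50_2`. Back-of-the-envelope arithmetic (parameter count from torchvision; the byte math is ours):

```python
from torchvision.models import wide_resnet50_2

n_params = sum(p.numel() for p in wide_resnet50_2().parameters())
print(f'{n_params:.1e}')           # roughly 6.9e+07 parameters

hessian_bytes = n_params ** 2 * 4  # dense float32 (P, P) matrix
print(hessian_bytes / 1e15)        # on the order of 19 petabytes -- allocation fails
```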
tests/test_lllaplace.py (48 changes: 48 additions & 0 deletions)

@@ -6,6 +6,7 @@
 from torch.nn.utils import parameters_to_vector
 from torch.utils.data import DataLoader, TensorDataset
 from torch.distributions import Normal, Categorical
+from torchvision.models import wide_resnet50_2
 
 from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace
 from laplace.feature_extractor import FeatureExtractor

@@ -24,6 +25,12 @@ def model():
     return model
 
 
+@pytest.fixture
+def large_model():
+    model = wide_resnet50_2()
+    return model
+
+
 @pytest.fixture
 def class_loader():
     X = torch.randn(10, 3)

@@ -41,6 +48,47 @@ def reg_loader():
 @pytest.mark.parametrize('laplace', flavors)
 def test_laplace_init(laplace, model):
     lap = laplace(model, 'classification', last_layer_name='1')
+    assert torch.allclose(lap.mean, lap.prior_mean)
+    if laplace != KronLLLaplace:
+        H = lap.H.clone()
+        lap._init_H()
+        assert torch.allclose(H, lap.H)
+    else:
+        H = [[k.clone() for k in kfac] for kfac in lap.H.kfacs]
+        lap._init_H()
+        for kfac1, kfac2 in zip(H, lap.H.kfacs):
+            for k1, k2 in zip(kfac1, kfac2):
+                assert torch.allclose(k1, k2)
+
+
+@pytest.mark.parametrize('laplace', flavors)
+def test_laplace_init_nollname(laplace, model):
+    lap = laplace(model, 'classification')
+    assert lap.mean is None
+    assert lap.H is None
+
+
+@pytest.mark.parametrize('laplace', [KronLLLaplace, DiagLLLaplace])
+def test_laplace_large_init(laplace, large_model):
+    lap = laplace(large_model, 'classification', last_layer_name='fc')
+    assert torch.allclose(lap.mean, lap.prior_mean)
+    if laplace == DiagLLLaplace:
+        H = lap.H.clone()
+        lap._init_H()
+        assert torch.allclose(H, lap.H)
+    else:
+        H = [[k.clone() for k in kfac] for kfac in lap.H.kfacs]
+        lap._init_H()
+        for kfac1, kfac2 in zip(H, lap.H.kfacs):
+            for k1, k2 in zip(kfac1, kfac2):
+                assert torch.allclose(k1, k2)
+
+
+@pytest.mark.parametrize('laplace', flavors)
+def test_laplace_large_init_nollname(laplace, large_model):
+    lap = laplace(large_model, 'classification')
+    assert lap.mean is None
+    assert lap.H is None
+
+
 @pytest.mark.parametrize('laplace', flavors)
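The last-layer variants are expected to pass on the same network because only the final `fc` layer (`Linear(2048, 1000)` in torchvision's `wide_resnet50_2`) enters the approximation. Rough counts (ours) for why `DiagLLLaplace` and `KronLLLaplace` stay cheap:

```python
from torchvision.models import wide_resnet50_2

fc = wide_resnet50_2().fc
n_ll = sum(p.numel() for p in fc.parameters())
print(n_ll)  # 2049000 weights plus biases

# Diagonal H: one float per parameter, about 2e6 entries.
# Kronecker H: two factors of roughly 2048^2 and 1000^2 entries.
print(2048 ** 2 + 1000 ** 2)  # 5194304 -- both fit comfortably in memory
```

`FullLLLaplace` is left out of the large-model test, presumably because a dense Hessian over those 2,049,000 parameters would still need roughly 17 TB in float32.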