Support for refitting LA with option to override or update #62

Merged · 7 commits · Dec 10, 2021
135 changes: 97 additions & 38 deletions laplace/baselaplace.py
@@ -1,4 +1,4 @@
from math import sqrt, pi
from math import sqrt, pi, log
import numpy as np
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters
@@ -60,7 +60,7 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
# log likelihood = g(loss)
self.loss = 0.
self.n_outputs = None
self.n_data = None
self.n_data = 0

@property
def backend(self):
@@ -72,9 +72,6 @@ def backend(self):
def _curv_closure(self, X, y, N):
raise NotImplementedError

def _check_fit(self):
raise NotImplementedError

def fit(self, train_loader):
raise NotImplementedError

@@ -92,8 +89,6 @@ def log_likelihood(self):
-------
log_likelihood : torch.Tensor
"""
self._check_fit()

factor = - self._H_factor
if self.likelihood == 'regression':
# loss used is just MSE, need to add normalizer for gaussian likelihood
@@ -186,7 +181,7 @@ def prior_precision(self, prior_precision):
def optimize_prior_precision_base(self, pred_type, method='marglik', n_steps=100, lr=1e-1,
init_prior_prec=1., val_loader=None, loss=get_nll,
log_prior_prec_min=-4, log_prior_prec_max=4, grid_size=100,
link_approx='probit', n_samples=100, verbose=False,
link_approx='probit', n_samples=100, verbose=False,
cv_loss_with_var=False):
"""Optimize the prior precision post-hoc using the `method`
specified by the user.
@@ -336,34 +331,35 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
'GGN or EF backends required in ParametricLaplace.'
super().__init__(model, likelihood, sigma_noise, prior_precision,
prior_mean, temperature, backend, backend_kwargs)

self.H = None

try:
self._init_H()
except AttributeError: # necessary information not yet available
pass
# posterior mean/mode
self.mean = parameters_to_vector(self.model.parameters()).detach()
self.mean = self.prior_mean

def _init_H(self):
raise NotImplementedError

def _check_fit(self):
if self.H is None:
raise AttributeError('ParametricLaplace not fitted. Run fit() first.')

def fit(self, train_loader):
def fit(self, train_loader, override=True):
"""Fit the local Laplace approximation at the parameters of the model.

Parameters
----------
train_loader : torch.data.utils.DataLoader
each iterate is a training batch (X, y);
`train_loader.dataset` needs to be set to access \\(N\\), size of the data set
override : bool, default=True
whether to re-initialize H, loss, and n_data; setting this to False is useful in
online learning settings, where a sequential posterior approximation is accumulated
over successive calls to `fit`.
"""
if self.H is not None:
raise ValueError('Already fit.')

self._init_H()
if override:
self._init_H()
self.loss = 0
self.n_data = 0

self.model.eval()
self.mean = parameters_to_vector(self.model.parameters()).detach()

X, _ = next(iter(train_loader))
with torch.no_grad():
@@ -382,7 +378,7 @@ def fit(self, train_loader):
self.loss += loss_batch
self.H += H_batch

self.n_data = N
self.n_data += N
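
A minimal usage sketch of the new `override` flag, assuming the library's usual `Laplace` factory interface; `model`, `first_loader`, and `second_loader` are hypothetical placeholders. With `override=False`, a second call to `fit` accumulates curvature, loss, and data count instead of starting over:

from laplace import Laplace

# Any torch.nn.Module and pair of DataLoaders work here; the names are placeholders.
la = Laplace(model, 'classification',
             subset_of_weights='all', hessian_structure='diag')
la.fit(first_loader)                   # initializes H, loss, and n_data
la.fit(second_loader, override=False)  # accumulates on top of the first fit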

@property
def scatter(self):
@@ -434,6 +430,36 @@ def log_det_ratio(self):
"""
return self.log_det_posterior_precision - self.log_det_prior_precision

def square_norm(self, value):
"""Compute the square norm under the posterior precision with `value - self.mean` as \\(\\Delta\\):
\\[
\\Delta^\\top P \\Delta
\\]

Returns
-------
square_norm : torch.Tensor
"""
raise NotImplementedError

def log_prob(self, value, normalized=True):
"""Compute the log probability under the (current) Laplace approximation.

Parameters
----------
value : torch.Tensor
point at which to evaluate the log probability.
normalized : bool, default=True
whether to return the log of a properly normalized Gaussian or just the
terms that depend on `value`.

Returns
-------
log_prob : torch.Tensor
"""
if not normalized:
return - self.square_norm(value) / 2
log_prob = - self.n_params / 2 * log(2 * pi) + self.log_det_posterior_precision / 2
log_prob -= self.square_norm(value) / 2
return log_prob
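
For reference, the normalized branch above is the standard multivariate Gaussian log-density under the Laplace approximation, with \(P\) the posterior precision, \(n\) = `self.n_params`, and \(\theta_{\mathrm{MAP}}\) = `self.mean`:

\[
\log p(\theta) = -\frac{n}{2}\log(2\pi) + \frac{1}{2}\log\det P - \frac{1}{2}\,(\theta - \theta_{\mathrm{MAP}})^\top P\,(\theta - \theta_{\mathrm{MAP}})
\]

With `normalized=False`, only the last, value-dependent term is returned.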

def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None):
"""Compute the Laplace approximation to the log marginal likelihood subject
to specific Hessian approximations that subclasses implement.
@@ -454,9 +480,6 @@ def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None):
-------
log_marglik : torch.Tensor
"""
# make sure we can differentiate wrt prior and sigma_noise for regression
self._check_fit()

# update prior precision (useful when iterating on marglik)
if prior_precision is not None:
self.prior_precision = prior_precision
@@ -497,8 +520,6 @@ def __call__(self, x, pred_type='glm', link_approx='probit', n_samples=100):
For `likelihood='regression'`, a tuple of torch.Tensor is returned
with the mean and the predictive variance.
"""
self._check_fit()

if pred_type not in ['glm', 'nn']:
raise ValueError('Only glm and nn supported as prediction types.')

@@ -555,8 +576,6 @@ def predictive_samples(self, x, pred_type='glm', n_samples=100):
samples : torch.Tensor
samples `(n_samples, batch_size, output_shape)`
"""
self._check_fit()

if pred_type not in ['glm', 'nn']:
raise ValueError('Only glm and nn supported as prediction types.')

@@ -625,7 +644,7 @@ def sample(self, n_samples=100):
def optimize_prior_precision(self, method='marglik', pred_type='glm', n_steps=100, lr=1e-1,
init_prior_prec=1., val_loader=None, loss=get_nll,
log_prior_prec_min=-4, log_prior_prec_max=4, grid_size=100,
link_approx='probit', n_samples=100, verbose=False,
link_approx='probit', n_samples=100, verbose=False,
cv_loss_with_var=False):
assert pred_type in ['glm', 'nn']
self.optimize_prior_precision_base(pred_type, method, n_steps, lr,
@@ -667,6 +686,10 @@ def _init_H(self):
def _curv_closure(self, X, y, N):
return self.backend.full(X, y, N=N)

def fit(self, train_loader, override=True):
self._posterior_scale = None
return super().fit(train_loader, override=override)

def _compute_scale(self):
self._posterior_scale = invsqrt_precision(self.posterior_precision)

@@ -705,13 +728,16 @@ def posterior_precision(self):
precision : torch.tensor
`(parameters, parameters)`
"""
self._check_fit()
return self._H_factor * self.H + torch.diag(self.prior_precision_diag)

@property
def log_det_posterior_precision(self):
return self.posterior_precision.logdet()

def square_norm(self, value):
delta = value - self.mean
return delta @ self.posterior_precision @ delta

def functional_variance(self, Js):
return torch.einsum('ncp,pq,nkq->nck', Js, self.posterior_covariance, Js)

@@ -739,6 +765,7 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
prior_mean=0., temperature=1., backend=BackPackGGN, damping=False,
**backend_kwargs):
self.damping = damping
self.H_facs = None
super().__init__(model, likelihood, sigma_noise, prior_precision,
prior_mean, temperature, backend, **backend_kwargs)

@@ -748,12 +775,34 @@ def _init_H(self):
def _curv_closure(self, X, y, N):
return self.backend.kron(X, y, N=N)

def fit(self, train_loader, keep_factors=False):
super().fit(train_loader)
# Kron requires postprocessing as all quantities depend on the decomposition.
if keep_factors:
@staticmethod
def _rescale_factors(kron, factor):
for F in kron.kfacs:
if len(F) == 2:
F[1] *= factor
return kron
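
Rescaling only the second matrix of each two-factor block is sufficient because a scalar can be pulled in and out of a Kronecker product:

\[
A \otimes (c\,B) = c\,(A \otimes B),
\]

so multiplying `F[1]` by `factor` rescales the whole curvature block the pair represents.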

def fit(self, train_loader, override=True):
if override:
self.H_facs = None

if self.H_facs is not None:
n_data_old = self.n_data
n_data_new = len(train_loader.dataset)
self._init_H() # re-init H non-decomposed
# discount previous Kronecker factors so they sum properly with the new ones
self.H_facs = self._rescale_factors(self.H_facs, n_data_old / (n_data_old + n_data_new))

super().fit(train_loader, override=override)

if self.H_facs is None:
self.H_facs = self.H
self.H = self.H.decompose(damping=self.damping)
else:
# discount new factors that were computed assuming N = n_data_new
self.H = self._rescale_factors(self.H, n_data_new / (n_data_new + n_data_old))
self.H_facs += self.H
# Decompose to self.H for all required quantities but keep H_facs for further inference
self.H = self.H_facs.decompose(damping=self.damping)
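
In the non-override branch, the old and new Kronecker factors are combined as a convex combination weighted by the number of data points each has seen, so that `H_facs` reflects the curvature of the combined data set:

\[
H_{\mathrm{facs}} \leftarrow \frac{N_{\mathrm{old}}}{N_{\mathrm{old}} + N_{\mathrm{new}}}\, H_{\mathrm{facs}}^{\mathrm{old}} + \frac{N_{\mathrm{new}}}{N_{\mathrm{old}} + N_{\mathrm{new}}}\, H^{\mathrm{new}}
\]

after which the sum is decomposed once for downstream quantities while `H_facs` itself is kept undecomposed for further updates.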

@property
def posterior_precision(self):
@@ -763,13 +812,20 @@ def posterior_precision(self):
-------
precision : `laplace.matrix.KronDecomposed`
"""
self._check_fit()
return self.H * self._H_factor + self.prior_precision

@property
def log_det_posterior_precision(self):
if type(self.H) is Kron: # Fall back to diag prior
return self.prior_precision_diag.log().sum()
return self.posterior_precision.logdet()

def square_norm(self, value):
delta = value - self.mean
if type(self.H) is Kron: # fall back to prior
return (delta * self.prior_precision_diag) @ delta
return delta @ self.posterior_precision.bmm(delta, exponent=1)

def functional_variance(self, Js):
return self.posterior_precision.inv_square_form(Js)

@@ -810,7 +866,6 @@ def posterior_precision(self):
precision : torch.tensor
`(parameters)`
"""
self._check_fit()
return self._H_factor * self.H + self.prior_precision_diag

@property
Expand Down Expand Up @@ -839,6 +894,10 @@ def posterior_variance(self):
def log_det_posterior_precision(self):
return self.posterior_precision.log().sum()

def square_norm(self, value):
delta = value - self.mean
return delta @ (delta * self.posterior_precision)
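
For the diagonal case the quadratic form reduces to an elementwise product and sum, which is what the expression above computes:

\[
\Delta^\top P \Delta = \sum_i p_i\, \Delta_i^2 .
\]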

def functional_variance(self, Js: torch.Tensor) -> torch.Tensor:
self._check_jacobians(Js)
return torch.einsum('ncp,p,nkp->nck', Js, self.posterior_variance, Js)
25 changes: 15 additions & 10 deletions laplace/lllaplace.py
@@ -65,32 +65,32 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
backend_kwargs=backend_kwargs)
self.model = FeatureExtractor(deepcopy(model), last_layer_name=last_layer_name)
if self.model.last_layer is None:
self.mean = None
self.mean = prior_mean
self.n_params = None
self.n_layers = None
# ignore checks of prior mean setter temporarily, check on .fit()
self._prior_precision = prior_precision
self._prior_mean = prior_mean
else:
self.mean = parameters_to_vector(self.model.last_layer.parameters()).detach()
self.n_params = len(self.mean)
self.n_params = len(parameters_to_vector(self.model.last_layer.parameters()))
self.n_layers = len(list(self.model.last_layer.parameters()))
self.prior_precision = prior_precision
self.prior_mean = prior_mean
self.mean = self.prior_mean
self._backend_kwargs['last_layer'] = True

def fit(self, train_loader):
def fit(self, train_loader, override=True):
"""Fit the local Laplace approximation at the parameters of the model.

Parameters
----------
train_loader : torch.data.utils.DataLoader
each iterate is a training batch (X, y);
`train_loader.dataset` needs to be set to access \\(N\\), size of the data set
override : bool, default=True
whether to re-initialize H, loss, and n_data; setting this to False is useful in
online learning settings, where a sequential posterior approximation is accumulated
over successive calls to `fit`.
"""
if self.H is not None:
raise ValueError('Already fit.')

self.model.eval()

if self.model.last_layer is None:
@@ -100,14 +100,19 @@ def fit(self, train_loader):
self.model.find_last_layer(X[:1].to(self._device))
except (TypeError, AttributeError):
self.model.find_last_layer(X.to(self._device))
self.mean = parameters_to_vector(self.model.last_layer.parameters()).detach()
self.n_params = len(self.mean)
params = parameters_to_vector(self.model.last_layer.parameters()).detach()
self.n_params = len(params)
self.n_layers = len(list(self.model.last_layer.parameters()))
# here, check the already set prior precision again
self.prior_precision = self._prior_precision
self.prior_mean = self._prior_mean
self._init_H()

if override:
self._init_H()

super().fit(train_loader)
super().fit(train_loader, override=override)
self.mean = parameters_to_vector(self.model.last_layer.parameters()).detach()
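
The last-layer variant behaves the same way from the caller's side; the last layer is detected on the first `fit` call via a forward pass. A sketch under the same assumptions as the earlier example (hypothetical `model` and loaders):

from laplace import Laplace

la = Laplace(model, 'regression',
             subset_of_weights='last_layer', hessian_structure='kron')
la.fit(first_loader)                   # finds the last layer, then fits
la.fit(second_loader, override=False)  # updates the same last-layer posterior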

def _glm_predictive_distribution(self, X):
Js, f_mu = self.backend.last_layer_jacobians(self.model, X)