Remove try-except from gridsearch #199

Merged
9 commits merged on Jun 23, 2024
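For readers skimming the diff: the gridsearch previously swallowed every RuntimeError; this PR narrows the handling to genuine factorization failures. A minimal standalone sketch of the exception torch raises for a non-positive-definite Cholesky (the toy matrix `A` below is illustrative, not repository code):

import torch
from torch.linalg import LinAlgError

A = torch.zeros(2, 2)  # toy matrix; not positive definite

try:
    torch.linalg.cholesky(A)
except LinAlgError:
    # The gridsearch treats this prior precision as infeasible (loss = inf)
    # and moves on; any other RuntimeError now propagates to the caller.
    result = float("inf")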
66 changes: 38 additions & 28 deletions laplace/baselaplace.py
@@ -6,6 +6,7 @@
import torch
import torchmetrics as tm
import tqdm
from torch.linalg import LinAlgError
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from laplace.curvature import CurvlinopsGGN
@@ -90,7 +91,7 @@ def __init__(
):
if likelihood not in ["classification", "regression", "reward_modeling"]:
raise ValueError(f"Invalid likelihood type {likelihood}")

self.likelihood = likelihood
self.model = model

# Only do Laplace on params that require grad
@@ -109,13 +110,6 @@ def __init__(
if sigma_noise != 1 and likelihood != "regression":
raise ValueError("Sigma noise != 1 only available for regression.")

self.reward_modeling = likelihood == "reward_modeling"
if self.reward_modeling:
# For fitting only. After it's done, self.likelihood = 'regression', see self.fit()
self.likelihood = "classification"
else:
self.likelihood = likelihood

self.sigma_noise = sigma_noise
self.temperature = temperature
self.enable_backprop = enable_backprop
@@ -155,9 +149,14 @@ def _device(self):
@property
def backend(self):
if self._backend is None:
likelihood = (
"classification"
if self.likelihood == "reward_modeling"
else self.likelihood
)
self._backend = self._backend_cls(
self.model,
self.likelihood,
likelihood,
dict_key_x=self.dict_key_x,
dict_key_y=self.dict_key_y,
**self._backend_kwargs,
@@ -305,7 +304,6 @@ def optimize_prior_precision_base(
link_approx="probit",
n_samples=100,
verbose=False,
cv_loss_with_var=False,
progress_bar=False,
):
"""Optimize the prior precision post-hoc using the `method`
@@ -335,9 +333,6 @@ def optimize_prior_precision_base(
If torchmetrics.Metric, running loss is computed (efficient). The default
depends on the likelihood: `RunningNLLMetric()` for classification and
reward modeling, running `MeanSquaredError()` for regression.
cv_loss_with_var: bool, default=False
if true, `loss` takes three arguments `loss(output_mean, output_var, target)`,
otherwise, `loss` takes two arguments `loss(output_mean, target)`
log_prior_prec_min : float, default=-4
lower bound of gridsearch interval.
log_prior_prec_max : float, default=4
@@ -356,6 +351,12 @@
whether to show a progress bar; updated at every batch-Hessian computation.
Useful for very large models and large amounts of data, esp. when `subset_of_weights='all'`.
"""
likelihood = (
"classification"
if self.likelihood == "reward_modeling"
else self.likelihood
)

if method == "marglik":
self.prior_precision = init_prior_prec
if len(self.prior_precision) == 1 and prior_structure != "scalar":
@@ -393,9 +394,9 @@

if loss is None:
loss = (
tm.MeanSquaredError(num_outputs=self.n_outputs)
if self.likelihood == "regression"
else RunningNLLMetric()
tm.MeanSquaredError(num_outputs=self.n_outputs).to(self._device)
if likelihood == "regression"
else RunningNLLMetric().to(self._device)
)

self.prior_precision = self._gridsearch(
@@ -405,11 +406,11 @@
pred_type=pred_type,
link_approx=link_approx,
n_samples=n_samples,
loss_with_var=cv_loss_with_var,
progress_bar=progress_bar,
)
else:
raise ValueError("For now only marglik and gridsearch is implemented.")

if verbose:
print(f"Optimized prior precision is {self.prior_precision}.")

@@ -421,16 +422,17 @@ def _gridsearch(
pred_type,
link_approx="probit",
n_samples=100,
loss_with_var=False,
progress_bar=False,
):
assert callable(loss) or isinstance(loss, tm.Metric)

results = list()
prior_precs = list()
pbar = tqdm.tqdm(interval) if progress_bar else interval

for prior_prec in pbar:
self.prior_precision = prior_prec

try:
result = validate(
self,
@@ -439,11 +441,15 @@
pred_type=pred_type,
link_approx=link_approx,
n_samples=n_samples,
loss_with_var=loss_with_var,
dict_key_y=self.dict_key_y,
)
except RuntimeError:
except LinAlgError:
result = np.inf
except RuntimeError as err:
if "not positive definite" in str(err):
result = np.inf
else:
raise err

if progress_bar:
pbar.set_description(
@@ -741,6 +747,7 @@ def __call__(
n_samples=100,
diagonal_output=False,
generator=None,
fitting=False,
**model_kwargs,
):
"""Compute the posterior predictive on input data `x`.
@@ -779,6 +786,11 @@ def __call__(
generator : torch.Generator, optional
random number generator to control the samples (if sampling used).

fitting : bool, default=False
whether or not this predictive call is done during fitting. Only useful for
reward modeling: the likelihood is set to `"regression"` when `False` and
`"classification"` when `True`.

Returns
-------
predictive: torch.Tensor or Tuple[torch.Tensor]
@@ -807,16 +819,16 @@
):
raise ValueError("Invalid random generator (check type and device).")

# For reward modeling, replace the likelihood to regression
if self.reward_modeling and self.likelihood == "classification":
self.likelihood = "regression"
likelihood = self.likelihood
if likelihood == "reward_modeling":
likelihood = "classification" if fitting else "regression"

if pred_type == "glm":
f_mu, f_var = self._glm_predictive_distribution(
x, joint=joint and self.likelihood == "regression"
x, joint=joint and likelihood == "regression"
)
# regression
if self.likelihood == "regression":
if likelihood == "regression":
return f_mu, f_var
# classification
if link_approx == "mc":
@@ -854,7 +866,7 @@ def __call__(
alpha = (1 - 2 / K + f_mu.exp() / K**2 * sum_exp) / f_var_diag
return torch.nan_to_num(alpha / alpha.sum(dim=1).unsqueeze(-1), nan=1.0)
else:
if self.likelihood == "regression":
if likelihood == "regression":
samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
return samples.mean(dim=0), samples.var(dim=0)
else: # classification; the average is computed online
@@ -1033,7 +1045,6 @@ def optimize_prior_precision(
link_approx="probit",
n_samples=100,
verbose=False,
cv_loss_with_var=False,
progress_bar=False,
):
assert pred_type in ["glm", "nn"]
@@ -1052,7 +1063,6 @@
link_approx,
n_samples,
verbose,
cv_loss_with_var,
progress_bar,
)

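The recurring pattern in this file — resolving `reward_modeling` to an effective likelihood per call instead of mutating `self.likelihood` — can be summarized in a small sketch (the helper name `resolve_likelihood` is illustrative, not part of the library):

def resolve_likelihood(likelihood: str, fitting: bool) -> str:
    """Map 'reward_modeling' to the effective likelihood for one call."""
    if likelihood != "reward_modeling":
        return likelihood
    # During fitting a reward model is treated as a classifier;
    # at prediction time its outputs are treated as regression targets.
    return "classification" if fitting else "regression"

assert resolve_likelihood("reward_modeling", fitting=True) == "classification"
assert resolve_likelihood("reward_modeling", fitting=False) == "regression"
assert resolve_likelihood("regression", fitting=True) == "regression"

Because nothing is mutated, the same Laplace object can serve both the fitting path and later predictive calls without the old state-restoration logic.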
15 changes: 11 additions & 4 deletions laplace/utils/utils.py
@@ -35,7 +35,6 @@ def validate(
pred_type="glm",
link_approx="probit",
n_samples=100,
loss_with_var=False,
dict_key_y="labels",
) -> float:
laplace.model.eval()
@@ -54,7 +53,11 @@
X = X.to(laplace._device)
y = y.to(laplace._device)
out = laplace(
X, pred_type=pred_type, link_approx=link_approx, n_samples=n_samples
X,
pred_type=pred_type,
link_approx=link_approx,
n_samples=n_samples,
fitting=True,
)

if type(out) == tuple:
@@ -63,7 +66,10 @@
output_vars.append(out[1])
targets.append(y)
else:
loss.update(*out, y)
try:
loss.update(*out, y)
except TypeError: # If the online loss only accepts 2 args
loss.update(out[0], y)
else:
if is_offline:
output_means.append(out)
@@ -80,7 +86,8 @@
targets = torch.cat(targets, dim=0)
return loss(means, variances, targets).item()
else:
return loss.compute().item()
# Aggregate since torchmetrics outputs num_outputs values for the MSE metric
return loss.compute().sum().item()


def parameters_per_layer(model):
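The new `.sum()` in `validate` exists because `tm.MeanSquaredError(num_outputs=...)` reports one value per output dimension rather than a scalar. A standalone sketch with toy tensors (not repository code):

import torch
import torchmetrics as tm

mse = tm.MeanSquaredError(num_outputs=2)
mse.update(torch.tensor([[0.0, 1.0]]), torch.tensor([[1.0, 1.0]]))
per_output = mse.compute()        # tensor([1., 0.]) -- one MSE per output
scalar = per_output.sum().item()  # 1.0 -- single float for the gridsearch

For `RunningNLLMetric()`, `compute()` is already a scalar, so the extra `.sum()` is a no-op there.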
24 changes: 24 additions & 0 deletions tests/test_baselaplace.py
@@ -739,3 +739,27 @@ def test_backprop_nn(laplace, model, reg_loader, backend):
assert grad_X_var.shape == X.shape
except ValueError:
assert False


@pytest.mark.parametrize(
"likelihood", ["classification", "regression", "reward_modeling"]
)
@pytest.mark.parametrize("prior_prec_type", ["scalar", "layerwise", "diag"])
def test_gridsearch(model, likelihood, prior_prec_type, reg_loader, class_loader):
if likelihood == "regression":
dataloader = reg_loader
else:
dataloader = class_loader

if prior_prec_type == "scalar":
prior_prec = 1.0
elif prior_prec_type == "layerwise":
prior_prec = torch.ones(model.n_layers)
else:
prior_prec = torch.ones(model.n_params)

lap = DiagLaplace(model, likelihood, prior_precision=prior_prec)
lap.fit(dataloader)

# Should not raise an error
lap.optimize_prior_precision(method="gridsearch", val_loader=dataloader, n_steps=10)
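A usage sketch mirroring the new test (`model`, `train_loader`, and `val_loader` are assumed fixtures, not repository code): after this PR, gridsearch skips only candidates whose posterior is not positive definite and surfaces any other error instead of silently returning inf.

from laplace import DiagLaplace

la = DiagLaplace(model, likelihood="regression", prior_precision=1.0)
la.fit(train_loader)

# Scores each candidate prior precision on val_loader; factorization
# failures (LinAlgError / "not positive definite") count as loss = inf.
la.optimize_prior_precision(method="gridsearch", val_loader=val_loader, n_steps=10)
print(la.prior_precision)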