diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py
index 1127934..7b6b7df 100644
--- a/laplace/baselaplace.py
+++ b/laplace/baselaplace.py
@@ -10,6 +10,7 @@
 import torchmetrics
 import tqdm
 from torch import nn
+from torch.linalg import LinAlgError
 from torch.nn.utils import parameters_to_vector, vector_to_parameters
 from torch.utils.data import DataLoader
 
@@ -104,6 +105,7 @@ def __init__(
             raise ValueError(f"Invalid likelihood type {likelihood}")
 
         self.model: nn.Module = model
+        self.likelihood: Likelihood | str = likelihood
 
         # Only do Laplace on params that require grad
         self.params: list[torch.Tensor] = []
@@ -121,14 +123,6 @@ def __init__(
         if sigma_noise != 1 and likelihood != Likelihood.REGRESSION:
             raise ValueError("Sigma noise != 1 only available for regression.")
 
-        self.reward_modeling: bool = likelihood == Likelihood.REWARD_MODELING
-
-        if self.reward_modeling:
-            # For fitting only. After it's done, self.likelihood = 'regression', see self.fit()
-            self.likelihood = Likelihood.CLASSIFICATION
-        else:
-            self.likelihood = likelihood
-
         self.sigma_noise: float | torch.Tensor = sigma_noise
         self.temperature: float = temperature
         self.enable_backprop: bool = enable_backprop
@@ -176,9 +170,14 @@ def _device(self) -> torch.device:
     @property
     def backend(self) -> CurvatureInterface:
         if self._backend is None:
+            likelihood = (
+                "classification"
+                if self.likelihood == "reward_modeling"
+                else self.likelihood
+            )
             self._backend = self._backend_cls(
                 self.model,
-                self.likelihood,
+                likelihood,
                 dict_key_x=self.dict_key_x,
                 dict_key_y=self.dict_key_y,
                 **self._backend_kwargs,
@@ -353,7 +352,6 @@ def optimize_prior_precision(
         link_approx: LinkApprox | str = LinkApprox.PROBIT,
         n_samples: int = 100,
         verbose: bool = False,
-        cv_loss_with_var: bool = False,
         progress_bar: bool = False,
     ) -> None:
         """Optimize the prior precision post-hoc using the `method`
@@ -383,9 +381,6 @@ def optimize_prior_precision(
             If torchmetrics.Metric, running loss is computed (efficient). The default
             depends on the likelihood: `RunningNLLMetric()` for classification and
             reward modeling, running `MeanSquaredError()` for regression.
-        cv_loss_with_var: bool, default=False
-            if true, `loss` takes three arguments `loss(output_mean, output_var, target)`,
-            otherwise, `loss` takes two arguments `loss(output_mean, target)`
         log_prior_prec_min : float, default=-4
             lower bound of gridsearch interval.
         log_prior_prec_max : float, default=4
@@ -404,6 +399,12 @@ def optimize_prior_precision(
         progress_bar : bool, default=False
             whether to show a progress bar; updated at every batch-Hessian computation.
             Useful for very large model and large amount of data, esp. when
            `subset_of_weights='all'`.
         """
+        likelihood = (
+            Likelihood.CLASSIFICATION
+            if self.likelihood == Likelihood.REWARD_MODELING
+            else self.likelihood
+        )
+
         if method == TuningMethod.MARGLIK:
             self.prior_precision = (
                 init_prior_prec
@@ -451,9 +452,11 @@ def optimize_prior_precision(
             if loss is None:
                 loss = (
-                    torchmetrics.MeanSquaredError(num_outputs=self.n_outputs)
-                    if self.likelihood == "regression"
-                    else RunningNLLMetric()
+                    torchmetrics.MeanSquaredError(num_outputs=self.n_outputs).to(
+                        self._device
+                    )
+                    if likelihood == Likelihood.REGRESSION
+                    else RunningNLLMetric().to(self._device)
                 )
 
             self.prior_precision = self._gridsearch(
@@ -463,11 +466,11 @@ def optimize_prior_precision(
                 pred_type=pred_type,
                 link_approx=link_approx,
                 n_samples=n_samples,
-                loss_with_var=cv_loss_with_var,
                 progress_bar=progress_bar,
             )
         else:
             raise ValueError("For now only marglik and gridsearch is implemented.")
+
         if verbose:
             print(f"Optimized prior precision is {self.prior_precision}.")
 
@@ -479,7 +482,6 @@ def _gridsearch(
         pred_type: PredType | str,
         link_approx: LinkApprox | str = LinkApprox.PROBIT,
         n_samples: int = 100,
-        loss_with_var: bool = False,
         progress_bar: bool = False,
     ) -> torch.Tensor:
         assert callable(loss) or isinstance(loss, torchmetrics.Metric)
@@ -490,6 +492,7 @@ def _gridsearch(
 
         for prior_prec in pbar:
             self.prior_precision = prior_prec
+
             try:
                 result = validate(
                     self,
@@ -498,11 +501,15 @@ def _gridsearch(
                     pred_type=pred_type,
                     link_approx=link_approx,
                     n_samples=n_samples,
-                    loss_with_var=loss_with_var,
                     dict_key_y=self.dict_key_y,
                 )
-            except RuntimeError:
+            except LinAlgError:
                 result = np.inf
+            except RuntimeError as err:
+                if "not positive definite" in str(err):
+                    result = np.inf
+                else:
+                    raise err
 
             if progress_bar:
                 pbar.set_description(
@@ -813,6 +820,7 @@ def __call__(
         n_samples: int = 100,
         diagonal_output: bool = False,
         generator: torch.Generator | None = None,
+        fitting: bool = False,
         **model_kwargs: dict[str, Any],
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """Compute the posterior predictive on input data `x`.
@@ -851,6 +859,11 @@ def __call__(
         generator : torch.Generator, optional
             random number generator to control the samples (if sampling used).
 
+        fitting : bool, default=False
+            whether or not this predictive call is done during fitting. Only useful for
+            reward modeling: the likelihood is set to `"regression"` when `False` and
+            `"classification"` when `True`.
+
         Returns
         -------
         predictive: torch.Tensor or tuple[torch.Tensor]
@@ -879,16 +892,16 @@ def __call__(
         ):
             raise ValueError("Invalid random generator (check type and device).")
 
-        # For reward modeling, replace the likelihood to regression and override model state
-        if self.reward_modeling and self.likelihood == Likelihood.CLASSIFICATION:
-            self.likelihood = Likelihood.REGRESSION
+        likelihood = self.likelihood
+        if likelihood == Likelihood.REWARD_MODELING:
+            likelihood = Likelihood.CLASSIFICATION if fitting else Likelihood.REGRESSION
 
         if pred_type == PredType.GLM:
             f_mu, f_var = self._glm_predictive_distribution(
-                x, joint=joint and self.likelihood == "regression"
+                x, joint=joint and likelihood == Likelihood.REGRESSION
             )
 
-            if self.likelihood == Likelihood.REGRESSION:
+            if likelihood == Likelihood.REGRESSION:
                 return f_mu, f_var
 
             if link_approx == LinkApprox.MC:
@@ -933,7 +946,7 @@ def __call__(
                     "Prediction path invalid. Check the likelihood, pred_type, link_approx combination!"
                 )
         else:
-            if self.likelihood == Likelihood.REGRESSION:
+            if likelihood == Likelihood.REGRESSION:
                 samples = self._nn_predictive_samples(x, n_samples, **model_kwargs)
                 return samples.mean(dim=0), samples.var(dim=0)
             else:  # classification; the average is computed online
@@ -1148,7 +1161,6 @@ def optimize_prior_precision(
         link_approx: LinkApprox | str = LinkApprox.PROBIT,
         n_samples: int = 100,
         verbose: bool = False,
-        cv_loss_with_var: bool = False,
         progress_bar: bool = False,
     ) -> None:
         assert pred_type in PredType.__members__.values()
@@ -1168,7 +1180,6 @@ def optimize_prior_precision(
             link_approx,
             n_samples,
             verbose,
-            cv_loss_with_var,
             progress_bar,
         )
 
diff --git a/laplace/utils/utils.py b/laplace/utils/utils.py
index dc26b5c..4fe459c 100644
--- a/laplace/utils/utils.py
+++ b/laplace/utils/utils.py
@@ -45,7 +45,6 @@ def validate(
     pred_type: PredType | str = PredType.GLM,
     link_approx: LinkApprox | str = LinkApprox.PROBIT,
     n_samples: int = 100,
-    loss_with_var: int = False,
     dict_key_y: str = "labels",
 ) -> float:
     laplace.model.eval()
@@ -64,7 +63,11 @@ def validate(
         X = X.to(laplace._device)
         y = y.to(laplace._device)
         out = laplace(
-            X, pred_type=pred_type, link_approx=link_approx, n_samples=n_samples
+            X,
+            pred_type=pred_type,
+            link_approx=link_approx,
+            n_samples=n_samples,
+            fitting=True,
         )
 
         if type(out) == tuple:
@@ -73,7 +76,10 @@ def validate(
                 output_vars.append(out[1])
                 targets.append(y)
             else:
-                loss.update(*out, y)
+                try:
+                    loss.update(*out, y)
+                except TypeError:  # If the online loss only accepts 2 args
+                    loss.update(out[0], y)
         else:
             if is_offline:
                 output_means.append(out)
@@ -90,7 +96,8 @@ def validate(
         targets = torch.cat(targets, dim=0)
         return loss(means, variances, targets).item()
     else:
-        return loss.compute().item()
+        # Aggregate since torchmetrics output n_classes values for the MSE metric
+        return loss.compute().sum().item()
 
 
 def parameters_per_layer(model: nn.Module) -> list[int]:
diff --git a/tests/test_baselaplace.py b/tests/test_baselaplace.py
index 35ccc3c..1455c76 100644
--- a/tests/test_baselaplace.py
+++ b/tests/test_baselaplace.py
@@ -739,3 +739,27 @@ def test_backprop_nn(laplace, model, reg_loader, backend):
         assert grad_X_var.shape == X.shape
     except ValueError:
         assert False
+
+
+@pytest.mark.parametrize(
+    "likelihood", ["classification", "regression", "reward_modeling"]
+)
+@pytest.mark.parametrize("prior_prec_type", ["scalar", "layerwise", "diag"])
+def test_gridsearch(model, likelihood, prior_prec_type, reg_loader, class_loader):
+    if likelihood == "regression":
+        dataloader = reg_loader
+    else:
+        dataloader = class_loader
+
+    if prior_prec_type == "scalar":
+        prior_prec = 1.0
+    elif prior_prec_type == "layerwise":
+        prior_prec = torch.ones(model.n_layers)
+    else:
+        prior_prec = torch.ones(model.n_params)
+
+    lap = DiagLaplace(model, likelihood, prior_precision=prior_prec)
+    lap.fit(dataloader)
+
+    # Should not raise an error
+    lap.optimize_prior_precision(method="gridsearch", val_loader=dataloader, n_steps=10)
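
Usage sketch (not part of the patch): the snippet below mirrors the new
test_gridsearch to show the reworked behavior end to end. The toy model, data,
and DataLoader are hypothetical stand-ins for the test's `model` and
`class_loader` fixtures; the `DiagLaplace`, `fit`, and
`optimize_prior_precision` calls are taken from the test above.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from laplace.baselaplace import DiagLaplace

# Hypothetical stand-ins for the test fixtures.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 20), nn.ReLU(), nn.Linear(20, 2))
X = torch.randn(32, 3)
y = torch.randint(0, 2, (32,))
loader = DataLoader(TensorDataset(X, y), batch_size=8)

# The likelihood is now stored verbatim; for "reward_modeling" the backend and
# the predictive resolve it per call (classification while fitting, regression
# at prediction time) instead of mutating self.likelihood.
la = DiagLaplace(model, "reward_modeling")
la.fit(loader)

# Gridsearch should now complete rather than crash: the default metric is
# moved to the model's device, and Cholesky failures (LinAlgError, or
# RuntimeErrors mentioning "not positive definite") score as np.inf.
la.optimize_prior_precision(method="gridsearch", val_loader=loader, n_steps=10)
print(la.prior_precision)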