Raise an error if, for regression, the output dim is different from the target dim
wiseodd committed Sep 12, 2024
1 parent d9fe1a6 commit 0331f05
Showing 8 changed files with 36 additions and 36 deletions.
42 changes: 24 additions & 18 deletions laplace/baselaplace.py
@@ -847,8 +847,11 @@ def fit(
X, y = data
X, y = X.to(self._device), y.to(self._device)

if self.likelihood == Likelihood.REGRESSION and y.ndim == out.ndim - 1:
y = y.unsqueeze(-1)
if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
raise ValueError(
f"The model's output is of shape {tuple(out.shape)} but "
f"the target has shape {tuple(y.shape)}."
)

self.model.zero_grad()
loss_batch, H_batch = self._curv_closure(X, y, N=N)
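
In short: for regression, `fit` now requires the target to have the same number of dimensions as the model's output instead of silently unsqueezing a flat target. A minimal sketch of the new contract follows; the model, shapes, and names are illustrative, not taken from the library.

```python
# Hypothetical 1-output regression setup, for illustration only.
import torch
from torch import nn

model = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 1))
X = torch.randn(8, 3)
out = model(X)                 # shape (8, 1), ndim == 2

y_flat = torch.randn(8)        # ndim == 1 -> fit() now raises ValueError
y_ok = y_flat.unsqueeze(-1)    # shape (8, 1) -> accepted

assert y_ok.ndim == out.ndim   # the condition the new check enforces
```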
@@ -1934,7 +1937,7 @@ class FunctionalLaplace(BaseLaplace):
See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400)
for more details.
Note that for `likelihood='classification'`, we approximate \( L_{NN} \\) with a diagonal matrix
Note that for `likelihood='classification'`, we approximate \\( L_{NN} \\) with a diagonal matrix
( \\( L_{NN} \\) is a block-diagonal matrix, where blocks represent Hessians of per-data-point log-likelihood w.r.t.
neural network output \\( f \\), See Appendix [A.2.1](https://arxiv.org/abs/2008.08400) for exact definition). We
resort to such an approximation because of the (possible) errors found in Laplace approximation for
@@ -2027,9 +2030,9 @@ def _check_prior_precision(prior_precision: float | torch.Tensor):

def _init_K_MM(self):
"""Allocates memory for the kernel matrix evaluated at the subset of the training
data points. If the subset is of size \(M\) and the problem has \(C\) outputs,
this is a list of C \((M,M\)) tensors for diagonal kernel and \((M x C, M x C)\)
otherwise.
data points. If the subset is of size \\(M\\) and the problem has \\(C\\) outputs,
this is a list of \\(C\\) \\((M, M)\\) tensors for diagonal kernel and
\\((M \\times C, M \\times C)\\) otherwise.
"""
if self.independent_outputs:
self.K_MM = [
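
The hunk is cut off here, but the docstring already pins down the shapes. A hedged sketch of what the allocation amounts to; `M`, `C`, the zero initialization, and the flag name are assumptions for illustration, not the library's code.

```python
# Shapes described in the docstring above, illustrated with made-up sizes.
import torch

M, C = 10, 3  # subset size and number of outputs (assumed values)
independent_outputs = True

if independent_outputs:
    # diagonal kernel: one (M, M) block per output
    K_MM = [torch.zeros(M, M) for _ in range(C)]
else:
    # joint kernel over all outputs: a single (M*C, M*C) matrix
    K_MM = torch.zeros(M * C, M * C)
```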
@@ -2044,9 +2047,9 @@ def _init_K_MM(self):

def _init_Sigma_inv(self):
"""Allocates memory for the cholesky decomposition of
\[
K_{MM} + \Lambda_{MM}^{-1}.
\]
\\[
K_{MM} + \\Lambda_{MM}^{-1}.
\\]
See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400)
Equation 15 for more information.
"""
@@ -2119,13 +2122,13 @@ class for more details.

def _build_Sigma_inv(self):
"""Computes the cholesky decomposition of
\[
K_{MM} + \Lambda_{MM}^{-1}.
\]
\\[
K_{MM} + \\Lambda_{MM}^{-1}.
\\]
See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400)
Equation 15 for more information.
As the diagonal approximation is performed with \Lambda_{MM} (which is stored in self.L),
As the diagonal approximation is performed with \\(\\Lambda_{MM}\\) (which is stored in self.L),
the code is greatly simplified.
"""
if self.independent_outputs:
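
This hunk is also truncated, but the docstring gives the formula, so the computation can be sketched: factor \\(K_{MM} + \\Lambda_{MM}^{-1}\\) per output block when outputs are independent, jointly otherwise. The function name, arguments, and use of `torch.linalg.cholesky` below are assumptions, not the library's implementation.

```python
# Hedged sketch: Cholesky factor of K_MM + Lambda_MM^{-1} (cf. Immer et al., 2021, Eq. 15).
import torch

def build_sigma_inv_chol(K_MM, Lambda_inv, independent_outputs: bool):
    if independent_outputs:
        # one (M, M) factorization per output
        return [torch.linalg.cholesky(K + L_inv) for K, L_inv in zip(K_MM, Lambda_inv)]
    # joint case: a single (M*C, M*C) factorization
    return torch.linalg.cholesky(K_MM + Lambda_inv)
```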
@@ -2235,8 +2238,11 @@ def fit(

Js_batch, f_batch = self._jacobians(X, enable_backprop=False)

if self.likelihood == Likelihood.REGRESSION and y.ndim == f_batch.ndim - 1:
y = y.unsqueeze(1)
if self.likelihood == Likelihood.REGRESSION and y.ndim != f_batch.ndim:
raise ValueError(
f"The model's output is of shape {tuple(f_batch.shape)} but "
f"the target has shape {tuple(y.shape)}."
)

with torch.no_grad():
loss_batch = self.backend.factor * self.backend.lossfunc(f_batch, y)
@@ -2559,11 +2565,11 @@ def log_det_ratio(self) -> torch.Tensor:
[GP book R&W 2006](http://www.gaussianprocess.org/gpml/chapters/) with
(note that we always use diagonal approximation \\(D\\) of the Hessian of log likelihood w.r.t. \\(f\\)):
log determinant term := \\( \log | I + D^{1/2}K D^{1/2} | \\)
log determinant term := \\( \\log | I + D^{1/2}K D^{1/2} | \\)
For `regression`, we use ["standard" GP marginal likelihood](https://stats.stackexchange.com/questions/280105/log-marginal-likelihood-for-gaussian-process):
log determinant term := \\( \log | K + \\sigma_2 I | \\)
log determinant term := \\( \\log | K + \\sigma^2 I | \\)
"""
if self.likelihood == Likelihood.REGRESSION:
if self.independent_outputs:
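
As a worked illustration of the two log-determinant terms quoted in the docstring; the kernel, the diagonal Hessian `D`, and the noise variance below are stand-ins, not values produced by the library.

```python
# Illustrative computation of the two log-determinant terms.
import torch

M = 10
K = torch.eye(M) + 0.1 * torch.ones(M, M)   # stand-in PSD kernel matrix
D = torch.rand(M) + 0.5                      # diagonal Hessian approximation (classification)
sigma_sq = 0.25                              # observation noise variance (regression)

# classification: log | I + D^{1/2} K D^{1/2} |
D_sqrt = torch.diag(D.sqrt())
logdet_clf = torch.logdet(torch.eye(M) + D_sqrt @ K @ D_sqrt)

# regression: log | K + sigma^2 I |
logdet_reg = torch.logdet(K + sigma_sq * torch.eye(M))
```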
@@ -2603,7 +2609,7 @@ def scatter(self, eps: float = 0.00001) -> torch.Tensor:
"""Compute scatter term in GP log marginal likelihood.
For `classification` we use eq. (3.44) from Chapter 3.5 from
[GP book R&W 2006](http://www.gaussianprocess.org/gpml/chapters/) with \\(\hat{f} = f \\):
[GP book R&W 2006](http://www.gaussianprocess.org/gpml/chapters/) with \\(\\hat{f} = f \\):
scatter term := \\( f K^{-1} f^{T} \\)
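
And the scatter term, again with made-up tensors; a linear solve avoids forming \\(K^{-1}\\) explicitly.

```python
# Illustrative scatter term f K^{-1} f^T.
import torch

M = 10
K = torch.eye(M) + 0.1 * torch.ones(M, M)   # stand-in kernel matrix
f = torch.randn(M)                           # stand-in function values

scatter = f @ torch.linalg.solve(K, f)       # equals f K^{-1} f^T for a vector f
```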
13 changes: 4 additions & 9 deletions tests/test_baselaplace.py
@@ -24,7 +24,7 @@
from tests.utils import ListDataset, dict_data_collator, jacobians_naive

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)
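
For context, `torch.set_default_tensor_type` is deprecated in recent PyTorch releases; the replacement used throughout these test files only sets the default dtype, which is all the tests rely on.

```python
# The tests only need float64 as the default dtype; the default device is untouched.
import torch

torch.set_default_dtype(torch.double)
assert torch.empty(1).dtype == torch.float64
```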

flavors = [FullLaplace, KronLaplace, DiagLaplace]
if find_spec("asdfghjkl") is not None:
@@ -847,16 +847,11 @@ def test_gridsearch(model, likelihood, prior_prec_type, reg_loader, class_loader


@pytest.mark.parametrize("laplace", flavors)
def test_prametric_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat, laplace):
def test_parametric_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat, laplace):
lap = laplace(model_1d, likelihood="regression")
lap.fit(reg_loader_1d) # OK

lap2 = laplace(model_1d, likelihood="regression")
lap2.fit(reg_loader_1d_flat) # Also OK!

H1, H2 = lap.H, lap2.H

if isinstance(H1, KronDecomposed) and isinstance(H2, KronDecomposed):
H1, H2 = H1.to_matrix(), H2.to_matrix()

assert torch.allclose(H1, H2)
with pytest.raises(ValueError):
lap2.fit(reg_loader_1d_flat)
7 changes: 3 additions & 4 deletions tests/test_functional_laplace_unit.py
@@ -314,12 +314,11 @@ def mock_jacobians(self, x):
)


def test_prametric_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat):
def test_functional_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat):
la = FunctionalLaplace(model_1d, "regression", 10, independent_outputs=False)
la.fit(reg_loader_1d)

la2 = FunctionalLaplace(model_1d, "regression", 10, independent_outputs=False)
la2.fit(reg_loader_1d_flat)

assert torch.allclose(la.mu, la2.mu)
assert torch.allclose(la.L, la2.L)
with pytest.raises(ValueError):
la2.fit(reg_loader_1d_flat)
2 changes: 1 addition & 1 deletion tests/test_laplace.py
@@ -8,7 +8,7 @@
from laplace.lllaplace import DiagLLLaplace, FullLLLaplace, KronLLLaplace

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)
flavors = [
FullLaplace,
KronLaplace,
2 changes: 1 addition & 1 deletion tests/test_matrix.py
@@ -9,7 +9,7 @@
from laplace.utils import kron as kron_prod
from tests.utils import get_diag_psd_matrix, get_psd_matrix, jacobians_naive

torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)


@pytest.fixture
2 changes: 1 addition & 1 deletion tests/test_serialization.py
@@ -26,7 +26,7 @@
)

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)

lrlaplace_param = pytest.param(
LowRankLaplace, marks=pytest.mark.xfail(reason="Unimplemented in the new ASDL")
2 changes: 1 addition & 1 deletion tests/test_subnetlaplace.py
@@ -21,7 +21,7 @@
)

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)
score_based_subnet_masks = [
RandomSubnetMask,
LargestMagnitudeSubnetMask,
2 changes: 1 addition & 1 deletion tests/test_subset_params.py
@@ -12,7 +12,7 @@
from laplace.curvature.curvlinops import CurvlinopsEF, CurvlinopsGGN, CurvlinopsHessian

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)
flavors = [KronLaplace, DiagLaplace, FullLaplace]
valid_backends = [CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF]

