Merge pull request #202 from aleximmer/caveats
Caveats
wiseodd authored Jul 7, 2024
2 parents 8e06180 + 0069835 commit ff5d068
Showing 5 changed files with 46 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
@@ -19,7 +19,7 @@ jobs:
       # You can test your matrix by printing the current Python version
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
+          python -m pip install --upgrade pip wheel packaging
           pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
           pip install -e .
       - name: Test with pytest
20 changes: 17 additions & 3 deletions README.md
@@ -1,8 +1,10 @@
 <div align="center">
-<img src="https://raw.githubusercontent.com/AlexImmer/Laplace/main/logo/laplace_logo.png" alt="Laplace" width="300"/>
-</div>
+<img src="https://raw.githubusercontent.com/AlexImmer/Laplace/main/logo/laplace_logo.png" alt="Laplace" width="300"/>

-[![Main](https://travis-ci.com/AlexImmer/Laplace.svg?token=rpuRxEjQS6cCZi7ptL9y&branch=main)](https://travis-ci.com/AlexImmer/Laplace)
+![pytest](https://github.com/aleximmer/laplace/actions/workflows/pytest.yml/badge.svg)
+![lint](https://github.com/aleximmer/laplace/actions/workflows/lint-ruff.yml/badge.svg)
+![format](https://github.com/aleximmer/laplace/actions/workflows/format-ruff.yml/badge.svg)
+</div>

 The laplace package facilitates the application of Laplace approximations for entire neural networks, subnetworks of neural networks, or just their last layer.
 The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations.
@@ -45,6 +47,13 @@ pytest tests/

 ## Example usage

+> [!IMPORTANT]
+> As a user, one should not expect Laplace to work automatically.
+> That is, one should experiment with different options
+> (Hessian factorization, prior-precision tuning method, predictive method,
+> backend, etc.). See the various papers that use Laplace for references on
+> how to set these options for the application or problem at hand.
+
 ### _Post-hoc_ prior precision tuning of diagonal LA

 In the following example, a pre-trained model is loaded,
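The body of this example is collapsed in the diff view above. As a minimal sketch of what the section describes, using the package's documented API (`model`, `train_loader`, and `x_test` are placeholders):

```python
from laplace import Laplace

# Sketch only: `model` is a pre-trained classifier, `train_loader` a
# PyTorch DataLoader, `x_test` a batch of inputs; all are placeholders.
la = Laplace(
    model,
    "classification",
    subset_of_weights="last_layer",
    hessian_structure="diag",
)
la.fit(train_loader)
la.optimize_prior_precision(method="marglik")

# These keyword arguments are the knobs the IMPORTANT note says to
# experiment with: Hessian factorization (hessian_structure), the
# prior-precision tuning method, the predictive settings, and the backend.
pred = la(x_test, pred_type="glm", link_approx="probit")
```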
@@ -283,6 +292,11 @@ trained on a GPU but want to run predictions on CPU. In this case, use
 torch.load(..., map_location="cpu")
 ```

+> [!WARNING]
+> Currently, this library always assumes that the model has an
+> output tensor of shape `(batch_size, ..., n_classes)`, so in
+> the case of image outputs, you need to rearrange from NCHW to NHWC.
+
 ## Structure

 The laplace package consists of two main components:
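To make the GPU-to-CPU note above concrete, here is a short sketch of the save/load round trip following the serialization pattern this README documents (paths and variable names are placeholders):

```python
import torch
from laplace import Laplace

# On the GPU machine: fit the Laplace approximation, then persist it.
la = Laplace(model, "classification", subset_of_weights="last_layer")
la.fit(train_loader)
torch.save(la.state_dict(), "state_dict.bin")

# On a CPU-only machine: rebuild `la` with the same arguments, then load
# with map_location so CUDA tensors are remapped onto the CPU.
la = Laplace(model, "classification", subset_of_weights="last_layer")
la.load_state_dict(torch.load("state_dict.bin", map_location="cpu"))
```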
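And for the output-shape warning, one way to satisfy the `(batch_size, ..., n_classes)` convention for image-shaped outputs is a thin wrapper; this `NHWCWrapper` is hypothetical, not part of the library:

```python
import torch
import torch.nn as nn


class NHWCWrapper(nn.Module):
    """Hypothetical wrapper that moves the class/channel axis last, so the
    output has shape (batch_size, H, W, n_classes) as Laplace expects."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.model(x)             # NCHW: (batch, n_classes, H, W)
        return out.permute(0, 2, 3, 1)  # NHWC: (batch, H, W, n_classes)
```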
23 changes: 14 additions & 9 deletions laplace/baselaplace.py
@@ -1021,18 +1021,23 @@ def _glm_predictive_distribution(
         self,
         X: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
         joint: bool = False,
-        diagonal_output=False,
+        diagonal_output: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        backend_name = self._backend_cls.__name__.lower()
-        if self.enable_backprop and (
-            "curvlinops" not in backend_name and "backpack" not in backend_name
-        ):
-            raise ValueError(
-                "Backprop through the GLM predictive is only available for the "
-                "Curvlinops and BackPACK backends."
+        if "asdl" in self._backend_cls.__name__.lower():
+            # ASDL doesn't support backprop over Jacobians;
+            # fall back to functorch
+            warnings.warn(
+                "ASDL backend is used which does not support backprop through "
+                "the functional variance, but `self.enable_backprop = True`. "
+                "Falling back to using `self.backend.functorch_jacobians` "
+                "which can be memory intensive for large models."
             )

-        Js, f_mu = self.backend.jacobians(X, enable_backprop=self.enable_backprop)
+            Js, f_mu = self.backend.functorch_jacobians(
+                X, enable_backprop=self.enable_backprop
+            )
+        else:
+            Js, f_mu = self.backend.jacobians(X, enable_backprop=self.enable_backprop)

         if joint:
             f_mu = f_mu.flatten()  # (batch*out)
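In user terms, this change means an ASDL backend no longer raises when backpropagating through the GLM predictive; it warns and falls back to functorch Jacobians. A sketch of the new behavior, with `model`, `train_loader`, and `X_test` as placeholders:

```python
import warnings

from laplace import Laplace
from laplace.curvature import AsdlGGN

la = Laplace(
    model,
    "regression",
    subset_of_weights="all",
    hessian_structure="diag",
    backend=AsdlGGN,
    enable_backprop=True,
)
la.fit(train_loader)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    f_mu, f_var = la(X_test, pred_type="glm")  # warns, uses functorch_jacobians
assert any("ASDL backend" in str(w.message) for w in caught)

f_var.sum().backward()  # now possible even with an ASDL backend
```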
2 changes: 2 additions & 0 deletions laplace/curvature/curvature.py
@@ -287,6 +287,8 @@ def diag(
"""
raise NotImplementedError

functorch_jacobians = jacobians


class GGNInterface(CurvatureInterface):
"""Generalized Gauss-Newton or Fisher Curvature Interface.
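The single added line is a class-level alias. Because `functorch_jacobians = jacobians` binds the function object defined in `CurvatureInterface` itself, a subclass that overrides `jacobians` with a backend-specific version still inherits the base (functorch) implementation under the `functorch_jacobians` name, which is what the `baselaplace.py` fallback relies on. A generic sketch of the pattern (names are illustrative, not the library's):

```python
class Base:
    def compute(self, x: int) -> int:
        # Generic, always-valid implementation.
        return x + x

    # The alias captures *this* function object at class-creation time, so
    # even if a subclass overrides `compute`, `fallback_compute` still runs
    # Base's version.
    fallback_compute = compute


class Fast(Base):
    def compute(self, x: int) -> int:
        # Specialized implementation that shadows Base.compute.
        return x << 1


f = Fast()
print(f.compute(3))           # 6, via Fast.compute
print(f.fallback_compute(3))  # 6, via Base.compute through the alias
```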
16 changes: 12 additions & 4 deletions tests/test_baselaplace.py
@@ -662,7 +662,9 @@ def test_dict_data(laplace, backend, lik, custom_loader, custom_model, request):


 @pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace])
-@pytest.mark.parametrize("backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF])
+@pytest.mark.parametrize(
+    "backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF]
+)
 def test_backprop_glm(laplace, model, reg_loader, backend):
     X, y = reg_loader.dataset.tensors
     X.requires_grad = True
@@ -682,7 +684,9 @@ def test_backprop_glm(laplace, model, reg_loader, backend):


 @pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace])
-@pytest.mark.parametrize("backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF])
+@pytest.mark.parametrize(
+    "backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF]
+)
 def test_backprop_glm_joint(laplace, model, reg_loader, backend):
     X, y = reg_loader.dataset.tensors
     X.requires_grad = True
@@ -702,7 +706,9 @@ def test_backprop_glm_joint(laplace, model, reg_loader, backend):


 @pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace])
-@pytest.mark.parametrize("backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF])
+@pytest.mark.parametrize(
+    "backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF]
+)
 def test_backprop_glm_mc(laplace, model, reg_loader, backend):
     X, y = reg_loader.dataset.tensors
     X.requires_grad = True
@@ -722,7 +728,9 @@ def test_backprop_glm_mc(laplace, model, reg_loader, backend):


 @pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace])
-@pytest.mark.parametrize("backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF])
+@pytest.mark.parametrize(
+    "backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF]
+)
 def test_backprop_nn(laplace, model, reg_loader, backend):
     X, y = reg_loader.dataset.tensors
     X.requires_grad = True
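All four parametrizations now assert the same user-facing property under the ASDL backends too: with `enable_backprop=True`, gradients flow from the predictive outputs back to the inputs. A sketch of that property, mirroring the test bodies (`model` and `reg_loader` stand in for the test fixtures):

```python
from laplace import DiagLaplace
from laplace.curvature import AsdlGGN

X, y = reg_loader.dataset.tensors
X.requires_grad = True

la = DiagLaplace(model, "regression", backend=AsdlGGN, enable_backprop=True)
la.fit(reg_loader)

f_mu, f_var = la(X, pred_type="glm")
f_var.sum().backward()       # backprop through the functional variance
assert X.grad is not None    # gradients reached the input
```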
