From 91c0d170865df21a3ad5e9a875e5e82e74d01db4 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Thu, 4 Jul 2024 15:49:16 -0400 Subject: [PATCH 1/3] Add caveats in README --- README.md | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index baab30f..af80878 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,10 @@
- Laplace -
+Laplace -[![Main](https://travis-ci.com/AlexImmer/Laplace.svg?token=rpuRxEjQS6cCZi7ptL9y&branch=main)](https://travis-ci.com/AlexImmer/Laplace) +![pytest](https://github.com/aleximmer/laplace/actions/workflows/pytest.yml/badge.svg) +![lint](https://github.com/aleximmer/laplace/actions/workflows/lint-ruff.yml/badge.svg) +![format](https://github.com/aleximmer/laplace/actions/workflows/format-ruff.yml/badge.svg) + The laplace package facilitates the application of Laplace approximations for entire neural networks, subnetworks of neural networks, or just their last layer. The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations. @@ -45,6 +47,13 @@ pytest tests/ ## Example usage +> [!IMPORTANT] +> As a user, one should not expect Laplace to work automatically. +> That is, one should experiment with different Laplace's options +> (hessian_factorization, prior precision tuning method, predictive method, backend, +> etc!). Try looking at various papers that use Laplace for references on how to +> set all those options depending on the applications/problems at hand. + ### _Post-hoc_ prior precision tuning of diagonal LA In the following example, a pre-trained model is loaded, @@ -283,6 +292,11 @@ trained on a GPU but want to run predictions on CPU. In this case, use torch.load(..., map_location="cpu") ``` +> [!WARNING] +> Currently, this library always assumes that the model has an +> output tensor of shape `(batch_size, ..., n_classes)`, so in +> the case of image outputs, you need to rearrange from NCHW to NHWC. + ## Structure The laplace package consists of two main components: From ee633bf086373e5b7d2e98bfc4b7b9706b3230cc Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Thu, 4 Jul 2024 15:50:01 -0400 Subject: [PATCH 2/3] Fall back to `functorch_jacobians` in GLM predictive --- laplace/baselaplace.py | 23 ++++++++++++++--------- laplace/curvature/curvature.py | 2 ++ tests/test_baselaplace.py | 16 ++++++++++++---- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index da956d0..de246b8 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -1021,18 +1021,23 @@ def _glm_predictive_distribution( self, X: torch.Tensor | MutableMapping[str, torch.Tensor | Any], joint: bool = False, - diagonal_output=False, + diagonal_output: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: - backend_name = self._backend_cls.__name__.lower() - if self.enable_backprop and ( - "curvlinops" not in backend_name and "backpack" not in backend_name - ): - raise ValueError( - "Backprop through the GLM predictive is only available for the " - "Curvlinops and BackPACK backends." + if "asdl" in self._backend_cls.__name__.lower(): + # Asdl's doesn't support backprop over Jacobians + # falling back to functorch + warnings.warn( + "ASDL backend is used which does not support backprop through " + "the functional variance, but `self.enable_backprop = True`. " + "Falling back to using `self.backend.functorch_jacobians` " + "which can be memory intensive for large models." ) - Js, f_mu = self.backend.jacobians(X, enable_backprop=self.enable_backprop) + Js, f_mu = self.backend.functorch_jacobians( + X, enable_backprop=self.enable_backprop + ) + else: + Js, f_mu = self.backend.jacobians(X, enable_backprop=self.enable_backprop) if joint: f_mu = f_mu.flatten() # (batch*out) diff --git a/laplace/curvature/curvature.py b/laplace/curvature/curvature.py index 7b5fe60..c0397f1 100644 --- a/laplace/curvature/curvature.py +++ b/laplace/curvature/curvature.py @@ -287,6 +287,8 @@ def diag( """ raise NotImplementedError + functorch_jacobians = jacobians + class GGNInterface(CurvatureInterface): """Generalized Gauss-Newton or Fisher Curvature Interface. diff --git a/tests/test_baselaplace.py b/tests/test_baselaplace.py index d7393fb..5015a61 100644 --- a/tests/test_baselaplace.py +++ b/tests/test_baselaplace.py @@ -662,7 +662,9 @@ def test_dict_data(laplace, backend, lik, custom_loader, custom_model, request): @pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace]) -@pytest.mark.parametrize("backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF]) +@pytest.mark.parametrize( + "backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF] +) def test_backprop_glm(laplace, model, reg_loader, backend): X, y = reg_loader.dataset.tensors X.requires_grad = True @@ -682,7 +684,9 @@ def test_backprop_glm(laplace, model, reg_loader, backend): @pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace]) -@pytest.mark.parametrize("backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF]) +@pytest.mark.parametrize( + "backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF] +) def test_backprop_glm_joint(laplace, model, reg_loader, backend): X, y = reg_loader.dataset.tensors X.requires_grad = True @@ -702,7 +706,9 @@ def test_backprop_glm_joint(laplace, model, reg_loader, backend): @pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace]) -@pytest.mark.parametrize("backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF]) +@pytest.mark.parametrize( + "backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF] +) def test_backprop_glm_mc(laplace, model, reg_loader, backend): X, y = reg_loader.dataset.tensors X.requires_grad = True @@ -722,7 +728,9 @@ def test_backprop_glm_mc(laplace, model, reg_loader, backend): @pytest.mark.parametrize("laplace", [FullLaplace, KronLaplace, DiagLaplace]) -@pytest.mark.parametrize("backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF]) +@pytest.mark.parametrize( + "backend", [BackPackGGN, CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF] +) def test_backprop_nn(laplace, model, reg_loader, backend): X, y = reg_loader.dataset.tensors X.requires_grad = True From 0069835c4fbdfc7581d745f7389d89455eddb07a Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Fri, 5 Jul 2024 08:53:52 -0400 Subject: [PATCH 3/3] Fix workflows error due to curvlinops --- .github/workflows/pytest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 1c49f7f..2cba2f8 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -19,7 +19,7 @@ jobs: # You can test your matrix by printing the current Python version - name: Install dependencies run: | - python -m pip install --upgrade pip + python -m pip install --upgrade pip wheel packaging pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu pip install -e . - name: Test with pytest