Merge pull request #172 from aleximmer/llla-feat-reduction
Add an option to reduce LLM features in `LLLaplace`
wiseodd authored Jun 11, 2024
2 parents 039b4b7 + 945ebd7 commit 83c5de7
Showing 6 changed files with 255 additions and 43 deletions.
43 changes: 37 additions & 6 deletions examples/huggingface_example.md
@@ -130,9 +130,40 @@ class MyGPT2(nn.Module):
model = MyGPT2(tokenizer)
```

Now, let's apply Laplace. Let's do a last-layer Laplace first. We do so by switching off the
gradients of all layers except the top layer. Laplace will automatically only compute the
Hessian (and Jacobians) of the parameters in which `requires_grad` is `True`.
Now, let's apply Laplace. Let's do a last-layer Laplace first. Notice that we add
an argument `feature_reduction` there. This is because Huggingface models reduce the
logits and [not the features](https://github.com/huggingface/transformers/blob/a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3/src/transformers/models/gpt2/modeling_gpt2.py#L1678-L1704).

```python
model = MyGPT2(tokenizer)
model.eval()

la = Laplace(
model,
likelihood='classification',
subset_of_weights='last_layer',
hessian_structure='full',
# This must faithfully reflect the reduction technique used in the model;
# otherwise, correctness is not guaranteed
feature_reduction=FeatureReduction.PICK_LAST,
)
la.fit(dataloader)
la.optimize_prior_precision()

X_test = next(iter(dataloader))
print(f'[Last-layer Laplace] The predictive tensor is of shape: {la(X_test).shape}.')
```

Here's the output:
```
[Last-layer Laplace] The predictive tensor is of shape: torch.Size([4, 2]).
```

## Subnetwork Laplace

Also, we can do the same thing by switching off the gradients of all layers except the
top layer. Laplace will automatically only compute the Hessian (and Jacobians) of the
parameters in which `requires_grad` is `True`.

Notice that you can "mix-and-match" this gradient switching. You can do a subnetwork Laplace
easily by doing so!
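
For concreteness, here is a minimal sketch of that gradient switching. It assumes the `MyGPT2` wrapper and `dataloader` defined above, and that the underlying HF model exposes its classification head as `score` (as `GPT2ForSequenceClassification` does):

```python
model = MyGPT2(tokenizer)
model.eval()

# Freeze everything, then re-enable gradients only for the classification head
for p in model.hf_model.parameters():
    p.requires_grad = False
for p in model.hf_model.score.parameters():
    p.requires_grad = True

la = Laplace(
    model,
    likelihood='classification',
    # 'all' here effectively means "all grad-enabled parameters", i.e. just the head
    subset_of_weights='all',
    hessian_structure='diag',
)
```
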
@@ -159,16 +190,16 @@ la.fit(dataloader)
la.optimize_prior_precision()

X_test = next(iter(dataloader))
print(f'[Foundation Model] The predictive tensor is of shape: {la(X_test).shape}.')
print(f'[Subnetwork Laplace] The predictive tensor is of shape: {la(X_test).shape}.')
```

Here are the outputs to validate that Laplace works:

```
[Foundation Model] The predictive tensor is of shape: torch.Size([4, 2]).
[Subnetwork Laplace] The predictive tensor is of shape: torch.Size([4, 2]).
```

## Laplace on LoRA parameters only
## Full Laplace on LoRA parameters only

Of course, you can also apply Laplace on the parameter-efficient fine tuning weights (like LoRA).
To do this, simply extend your LLM with LoRA, using HF's `peft` library, and apply Laplace as
31 changes: 28 additions & 3 deletions examples/huggingface_example.py
@@ -10,6 +10,8 @@

from laplace import Laplace

from laplace.utils.feature_extractor import FeatureReduction

logging.basicConfig(level='ERROR')
warnings.filterwarnings('ignore')

@@ -91,10 +93,33 @@ def forward(self, data: MutableMapping) -> torch.Tensor:
return output_dict.logits


# Last-layer Laplace
# ------------------
model = MyGPT2(tokenizer)
model.eval()

la = Laplace(
model,
likelihood='classification',
subset_of_weights='last_layer',
hessian_structure='full',
# This must faithfully reflect the reduction technique used in the model;
# otherwise, correctness is not guaranteed
feature_reduction=FeatureReduction.PICK_LAST,
)
la.fit(dataloader)
la.optimize_prior_precision()

X_test = next(iter(dataloader))
print(f'[Last-layer Laplace] The predictive tensor is of shape: {la(X_test).shape}.')

del model
del la

# Last-layer Laplace on the foundation model itself
# -------------------------------------------------

# Laplace on a subset of parameters by disabling gradients
# --------------------------------------------------------
model = MyGPT2(tokenizer)
model.eval()

# Enable grad only for the last layer
@@ -115,7 +140,7 @@ def forward(self, data: MutableMapping) -> torch.Tensor:
la.optimize_prior_precision()

X_test = next(iter(dataloader))
print(f'[Foundation Model] The predictive tensor is of shape: {la(X_test).shape}.')
print(f'[Subnetwork Laplace] The predictive tensor is of shape: {la(X_test).shape}.')

del model
del la
15 changes: 15 additions & 0 deletions laplace/lllaplace.py
@@ -8,6 +8,7 @@
from laplace.baselaplace import (DiagLaplace, FullLaplace, KronLaplace,
ParametricLaplace)
from laplace.utils import FeatureExtractor, Kron
from laplace.utils.feature_extractor import FeatureReduction

__all__ = ['LLLaplace', 'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace']

@@ -52,6 +53,15 @@ class LLLaplace(ParametricLaplace):
enable_backprop: bool, default=False
whether to enable backprop to the input `x` through the Laplace predictive.
Useful for e.g. Bayesian optimization.
feature_reduction: FeatureReduction or str, optional, default=None
when the last-layer `features` is a tensor of dim >= 3, this tells how to reduce
it into a dim-2 tensor. E.g. in LLMs for non-language modeling problems,
the penultimate output is a tensor of shape `(batch_size, seq_len, embd_dim)`.
But the last layer maps `(batch_size, embd_dim)` to `(batch_size, n_classes)`.
Note: Make sure that this option faithfully reflects the reduction in the model
definition. When inputting a string, available options are
`{'pick_first', 'pick_last', 'average'}`.
dict_key_x: str, default='input_ids'
The dictionary key under which the input tensor `x` is stored. Only has effect
when the model takes a `MutableMapping` as the input. Useful for Huggingface
@@ -78,6 +88,7 @@ def __init__(
prior_mean=0.0,
temperature=1.0,
enable_backprop=False,
feature_reduction=None,
dict_key_x='inputs_id',
dict_key_y='labels',
backend=None,
@@ -87,6 +98,7 @@
):
if asdl_fisher_kwargs is not None:
raise ValueError('Last-layer Laplace does not support asdl_fisher_kwargs.')

self.H = None
super().__init__(
model,
@@ -105,6 +117,7 @@ def __init__(
deepcopy(model),
last_layer_name=last_layer_name,
enable_backprop=enable_backprop,
feature_reduction=feature_reduction,
)
if self.model.last_layer is None:
self.mean = None
@@ -313,6 +326,7 @@ def __init__(
prior_mean=0.0,
temperature=1.0,
enable_backprop=False,
feature_reduction=None,
dict_key_x='inputs_id',
dict_key_y='labels',
backend=None,
@@ -329,6 +343,7 @@
prior_mean,
temperature,
enable_backprop,
feature_reduction,
dict_key_x,
dict_key_y,
backend,
46 changes: 44 additions & 2 deletions laplace/utils/feature_extractor.py
@@ -1,11 +1,18 @@
from enum import Enum
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
from typing import Tuple, Callable, Optional


__all__ = ['FeatureExtractor']


class FeatureReduction(str, Enum):
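# Strategies for collapsing the intermediate dims of the last-layer features
# (e.g. seq_len in LLMs) down to a (batch_size, embd_dim) tensor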
PICK_FIRST = 'pick_first'
PICK_LAST = 'pick_last'
AVERAGE = 'average'


class FeatureExtractor(nn.Module):
"""Feature extractor for a PyTorch neural network.
A wrapper which can return the output of the penultimate layer in addition to
@@ -23,18 +30,39 @@ class FeatureExtractor(nn.Module):
last_layer_name : str, default=None
if the name of the last layer is already known, otherwise it will
be determined automatically.
enable_backprop: bool, default=False
whether to enable backprop through the feature extractor to get the gradients of
the inputs. Useful for e.g. Bayesian optimization.
feature_reduction: FeatureReduction or str, default=None
when the last-layer `features` is a tensor of dim >= 3, this tells how to reduce
it into a dim-2 tensor. E.g. in LLMs for non-language modeling problems,
the penultimate output is a tensor of shape `(batch_size, seq_len, embd_dim)`.
But the last layer maps `(batch_size, embd_dim)` to `(batch_size, n_classes)`.
Note: Make sure that this option faithfully reflects the reduction in the model
definition. When inputting a string, available options are
`{'pick_first', 'pick_last', 'average'}`.
"""

def __init__(
self,
model: nn.Module,
last_layer_name: Optional[str] = None,
enable_backprop: bool = False,
feature_reduction: Optional[FeatureReduction] = None,
) -> None:
if feature_reduction is not None and feature_reduction not in [
fr.value for fr in FeatureReduction
]:
raise ValueError(
'`feature_reduction` must take a value in the `FeatureReduction` enum or '
"one of `{'pick_first', 'pick_last', 'average'}`!"
)

super().__init__()
self.model = model
self._features = dict()
self.enable_backprop = enable_backprop
self.feature_reduction = feature_reduction

if last_layer_name is None:
self.last_layer = None
@@ -72,6 +100,20 @@ def forward_with_features(
"""
out = self.forward(x)
features = self._features[self._last_layer_name]

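# If the last-layer features carry extra intermediate dims (e.g. seq_len in LLMs),
# reduce them according to `feature_reduction` to get a (batch_size, embd_dim) tensor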
if features.dim() > 2 and self.feature_reduction is not None:
n_intermediate_dims = len(features.shape) - 2

if self.feature_reduction == FeatureReduction.PICK_FIRST:
features = features[:, *([0] * n_intermediate_dims), :].squeeze()
elif self.feature_reduction == FeatureReduction.PICK_LAST:
features = features[:, *([-1] * n_intermediate_dims), :].squeeze()
else:
ndim = features.ndim
features = features.mean(
dim=tuple(d for d in range(ndim) if d not in [0, ndim - 1])
).squeeze()

return out, features

def set_last_layer(self, last_layer_name: str) -> None:
83 changes: 59 additions & 24 deletions tests/test_feature_extractor.py
@@ -4,49 +4,42 @@

from laplace.utils import FeatureExtractor

import pytest

from laplace.utils.feature_extractor import FeatureReduction


class CNN(nn.Module):
def __init__(self, num_classes):
nn.Module.__init__(self)
self.conv1 = nn.Sequential(
# Input shape (3, 64, 64)
nn.Conv2d(
in_channels=3,
out_channels=6,
kernel_size=5,
stride=1,
padding=2
in_channels=3, out_channels=6, kernel_size=5, stride=1, padding=2
),
# Output shape (6, 60, 60)
nn.ReLU(),
# Output shape (6, 30, 30)
nn.MaxPool2d(kernel_size=2)
nn.MaxPool2d(kernel_size=2),
)

self.fc = nn.Sequential(
nn.Linear(in_features=16 * 16 * 16,
out_features=300),
nn.Linear(in_features=16 * 16 * 16, out_features=300),
nn.ReLU(),
nn.Linear(in_features=300,
out_features=84),
nn.Linear(in_features=300, out_features=84),
nn.ReLU(),
nn.Linear(in_features=84,
out_features=num_classes)
nn.Linear(in_features=84, out_features=num_classes),
)

self.conv2 = nn.Sequential(
# Input shape (6, 30, 30)
nn.Conv2d(
in_channels=6,
out_channels=16,
kernel_size=5,
stride=1,
padding=2
in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=2
),
# Output shape (16, 26, 26)
nn.ReLU(),
# Output shape (16, 13, 13)
nn.MaxPool2d(kernel_size=2)
nn.MaxPool2d(kernel_size=2),
)

def forward(self, x):
@@ -94,9 +87,9 @@ def get_model(model_name):
nn.Conv2d(3, 6, 3, 1, 1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(6*64*64, 10),
nn.Linear(6 * 64 * 64, 10),
nn.ReLU(),
nn.Linear(10, 10)
nn.Linear(10, 10),
)
else:
raise ValueError(f'{model_name} is not supported.')
@@ -107,10 +100,23 @@ def get_model(model_name):
def test_feature_extractor():
# all torchvision classification models but 'squeezenet' (no linear last layer)
# + a model where modules are initialized in the wrong order + an nn.Sequential model
model_names = ['resnet18', 'alexnet', 'vgg16', 'densenet', 'inception',
'googlenet', 'shufflenet', 'mobilenet_v2', 'mobilenet_v3_large',
'mobilenet_v3_small', 'resnext50_32x4d', 'wide_resnet50_2',
'mnasnet', 'switchedCNN', 'sequential']
model_names = [
'resnet18',
'alexnet',
'vgg16',
'densenet',
'inception',
'googlenet',
'shufflenet',
'mobilenet_v2',
'mobilenet_v3_large',
'mobilenet_v3_small',
'resnext50_32x4d',
'wide_resnet50_2',
'mnasnet',
'switchedCNN',
'sequential',
]

# to test the last_layer_name argument
# last_layer_names = ['fc', 'classifier.6', 'classifier.6', 'classifier', 'fc',
@@ -135,3 +141,32 @@ def test_feature_extractor():
last_layer = feature_extractor.last_layer
out2 = last_layer(features)
assert (out == out2).all().item()


@torch.no_grad()
@pytest.mark.parametrize('reduction', [r.value for r in FeatureReduction] + [None])
@pytest.mark.parametrize('additional_dims', [tuple(), (7,), (7, 8, 9)])
def test_multidim_features(reduction, additional_dims):
BATCH_SIZE = 6
IN_DIM = 5
HIDDEN_DIM = 10
OUT_DIM = 2
EXPECTED_FEATS_SHAPE = (BATCH_SIZE, HIDDEN_DIM)

model = nn.Sequential(
nn.Linear(IN_DIM, HIDDEN_DIM),
nn.ReLU(),
nn.Linear(HIDDEN_DIM, OUT_DIM),
).eval()

X = torch.randn(BATCH_SIZE, *(additional_dims), IN_DIM)
out = model(X)
assert out.shape == (BATCH_SIZE, *(additional_dims), OUT_DIM)

feature_extractor = FeatureExtractor(model, feature_reduction=reduction)
out, feats = feature_extractor.forward_with_features(X)

if reduction is None:
assert feats.shape == (BATCH_SIZE, *(additional_dims), HIDDEN_DIM)
else:
assert feats.shape == EXPECTED_FEATS_SHAPE