
feat: adding load_model helper for huggingface causal LM models #226

Merged
13 commits merged on Nov 10, 2024
5 changes: 4 additions & 1 deletion docs/training_saes.md
@@ -114,6 +114,10 @@ The learning rate scheduler can be controlled with the `lr_scheduler_name` param

To avoid dead features, it's often helpful to slowly increase the L1 penalty. This can be done by setting `l1_warm_up_steps` to a value larger than 0. This will linearly increase the L1 penalty over the first `l1_warm_up_steps` training steps.

## Training on Huggingface Models

While TransformerLens is the recommended way to use SAELens, it is also possible to use any Huggingface AutoModelForCausalLM as the model. This is useful if you want to use a model that is not supported by TransformerLens, or if you cannot use TransformerLens for memory or performance reasons. To use a Huggingface AutoModelForCausalLM, specify `model_class_name = 'AutoModelForCausalLM'` in the SAE config. Your hook points will then need to correspond to the named modules of the Huggingface model rather than the typical TransformerLens hook points. For instance, if you were using GPT2 from Huggingface, you would use `hook_name = 'transformer.h.1'` rather than `hook_name = 'blocks.1.hook_resid_post'`. Otherwise, everything should work the same as with TransformerLens models.
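As a rough illustration, the sketch below shows what such a training config might look like. The runner and config class names (`LanguageModelSAERunnerConfig`, `SAETrainingRunner`) and every value other than `model_class_name` and `hook_name` are assumptions for illustration, not part of this change.

```python
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

# Minimal sketch (assumed class/field names): train an SAE on a Huggingface GPT-2.
cfg = LanguageModelSAERunnerConfig(
    model_class_name="AutoModelForCausalLM",  # use a Huggingface model instead of TransformerLens
    model_name="gpt2",
    hook_name="transformer.h.1",  # a Huggingface module path, not a TransformerLens hook point
    hook_layer=1,
    d_in=768,  # residual stream width of GPT-2 small
    dataset_path="monology/pile-uncopyrighted",
    is_dataset_tokenized=False,
    training_tokens=1_000_000,
    device="cuda",
)
SAETrainingRunner(cfg).run()
```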


## Datasets, streaming, and context size

@@ -190,7 +194,6 @@ CacheActivationsRunner(cfg).run()

To use the cached activations during training, set `use_cached_activations=True` and set `cached_activations_path` to match the `new_cached_activations_path` option above in your training configuration.
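A minimal sketch of this, reusing the (assumed) config class from the sketch above; only `use_cached_activations` and `cached_activations_path` come from the text, the rest is illustrative:

```python
# Sketch: train from previously cached activations instead of running the model.
cfg = LanguageModelSAERunnerConfig(
    model_name="gpt2",
    hook_name="blocks.1.hook_resid_post",
    hook_layer=1,
    d_in=768,
    use_cached_activations=True,
    cached_activations_path="./cached_activations/gpt2",  # same path as new_cached_activations_path above
    device="cuda",
)
```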


## Uploading SAEs to Huggingface

Once you have a set of SAEs that you're happy with, your next step is to share them with the world! SAELens has an `upload_saes_to_huggingface()` function which makes this easy to do. We also provide an [uploading SAEs to Huggingface tutorial](https://github.com/jbloomAus/SAELens/blob/main/tutorials/uploading_saes_to_huggingface.ipynb) with more details.
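A minimal sketch of how that upload might look; the dictionary-of-SAEs argument and the `hf_repo_id` keyword are assumptions here, so refer to the linked tutorial for the authoritative usage:

```python
from sae_lens import SAE, upload_saes_to_huggingface

# Sketch: load an SAE, then upload it; dict keys become the SAE ids inside the repo.
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.4.hook_resid_pre",
    device="cpu",
)[0]
upload_saes_to_huggingface(
    {"blocks.4.hook_resid_pre": sae},
    hf_repo_id="your-username/your-sae-release",
)
```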
136 changes: 134 additions & 2 deletions sae_lens/load_model.py
@@ -1,8 +1,15 @@
-from typing import Any, cast
+from typing import Any, Literal, cast

import torch
from transformer_lens import HookedTransformer
-from transformer_lens.hook_points import HookedRootModule
+from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens.HookedTransformer import Loss, Output
from transformer_lens.utils import (
USE_DEFAULT_VALUE,
get_tokens_with_bos_removed,
lm_cross_entropy_loss,
)
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase


def load_model(
@@ -40,5 +47,130 @@
model_name, device=cast(Any, device), **model_from_pretrained_kwargs
),
)
elif model_class_name == "AutoModelForCausalLM":
Contributor

I see many elifs and elses after returns. I find it easier to read when the elifs are ifs and there's no else, but I've known others who are against this, so I'm curious to hear your thoughts.

There's a Pylint message for this that seems to be enabled in .pylintrc, but it seems we don't use .pylintrc. Perhaps I should open an issue on Pylint for discussion?

Collaborator Author

I don't have a strong opinion on this; if we can have a linting rule set up to enforce a style here, I'm happy with that. The .pylintrc is likely legacy, so I'll delete it. I'd support moving to Ruff, since it supports everything from pylint, flake8, pyflakes, isort, black, etc., is a lot faster, and seems to be what the industry is moving towards. I'll open an issue for this.

hf_model = AutoModelForCausalLM.from_pretrained(
model_name, **model_from_pretrained_kwargs
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
return HookedProxyLM(hf_model, tokenizer)

else: # pragma: no cover
raise ValueError(f"Unknown model class: {model_class_name}")


class HookedProxyLM(HookedRootModule):
"""
A HookedRootModule that wraps a Huggingface AutoModelForCausalLM.
"""

tokenizer: PreTrainedTokenizerBase
model: torch.nn.Module

def __init__(self, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase):
super().__init__()
self.model = model
self.tokenizer = tokenizer
self.setup()

# copied and modified from base HookedRootModule
def setup(self):
self.mod_dict = {}
self.hook_dict: dict[str, HookPoint] = {}
for name, module in self.model.named_modules():
if name == "":
continue

hook_point = HookPoint()
hook_point.name = name # type: ignore

module.register_forward_hook(get_hook_fn(hook_point))

self.hook_dict[name] = hook_point
self.mod_dict[name] = hook_point

def forward(
self,
tokens: torch.Tensor,
return_type: Literal["both", "logits"] = "logits",
loss_per_token: bool = False,
# TODO: implement real support for stop_at_layer
stop_at_layer: int | None = None,
**kwargs: Any,
) -> Output | Loss:
# This is just what's needed for evals, not everything that HookedTransformer has
assert return_type in (
Contributor

I see a lot of assert in the codebase, outside of tests, but a lot of people say not to use assert in production code and to raise exceptions instead. Should we have assert outside of tests?

Perhaps I could open an issue for discussion.

Collaborator Author (@chanind, Nov 10, 2024)

I can see the argument for that. I would at least agree that a user of the library should never see an assert fail. We use a lot of asserts throughout the codebase to narrow types for pyright, which seems fine IMO, but we should change to throwing exceptions for anything we expect a user to see. Here, this is something that should never be triggered; I just wanted to make sure that if we did try to pass another return type into this wrapper we'd at least get an error. I'll change this to a NotImplementedError - that should accomplish the same goal.

"both",
"logits",
), "Only return_type supported is 'both' or 'logits' to match what's in evals.py and ActivationsStore"
output = self.model(tokens)
logits = _extract_logits_from_output(output)

if return_type == "logits":
return logits

if tokens.device != logits.device:
tokens = tokens.to(logits.device)

loss = lm_cross_entropy_loss(logits, tokens, per_token=loss_per_token)
return Output(logits, loss)

def to_tokens(
self,
input: str | list[str],
prepend_bos: bool | None = USE_DEFAULT_VALUE,
padding_side: Literal["left", "right"] | None = USE_DEFAULT_VALUE,
move_to_device: bool = True,
truncate: bool = True,
) -> torch.Tensor:
# Hackily modified version of HookedTransformer.to_tokens to work with ActivationsStore
# Assumes that prepend_bos is always False, move_to_device is always False, and truncate is always False
# copied from HookedTransformer.to_tokens

assert (
prepend_bos is False
), "Only works with prepend_bos=False, to match ActivationsStore usage"
assert (
padding_side is None
), "Only works with padding_side=None, to match ActivationsStore usage"
assert (
truncate is False
), "Only works with truncate=False, to match ActivationsStore usage"
assert (
move_to_device is False
), "Only works with move_to_device=False, to match ActivationsStore usage"

tokens = self.tokenizer(
input,
return_tensors="pt",
truncation=False,
max_length=None,
)["input_ids"]

# We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
if hasattr(self.tokenizer, "add_bos_token") and self.tokenizer.add_bos_token: # type: ignore
tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)

return tokens # type: ignore


def _extract_logits_from_output(output: Any) -> torch.Tensor:
if isinstance(output, torch.Tensor):
return output

elif isinstance(output, tuple) and isinstance(output[0], torch.Tensor):
return output[0]
elif isinstance(output, dict) and "logits" in output:
return output["logits"]
else:
raise ValueError(f"Unknown output type: {type(output)}")



def get_hook_fn(hook_point: HookPoint):

def hook_fn(module: Any, input: Any, output: Any) -> Any:
if isinstance(output, torch.Tensor):
return hook_point(output)
elif isinstance(output, tuple) and isinstance(output[0], torch.Tensor):
return (hook_point(output[0]), *output[1:])
else:
# if this isn't a tensor, just skip the hook entirely as this will break otherwise
return output

return hook_fn
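To illustrate how the wrapper is meant to be used end to end, here is a sketch assuming `HookedRootModule.run_with_cache` behaves for `HookedProxyLM` as it does for TransformerLens models (the hook name and prompt are illustrative):

```python
from sae_lens.load_model import load_model

# Load GPT-2 through Huggingface and wrap it in HookedProxyLM.
model = load_model(
    model_class_name="AutoModelForCausalLM",
    model_name="gpt2",
    device="cpu",
)

# to_tokens only supports the ActivationsStore-style arguments asserted above.
tokens = model.to_tokens(
    ["hello world"],
    prepend_bos=False,
    padding_side=None,
    move_to_device=False,
    truncate=False,
)

# Hook names are Huggingface module paths; HookPoints were registered in setup().
logits, cache = model.run_with_cache(tokens, names_filter="transformer.h.1")
print(cache["transformer.h.1"].shape)  # (batch, seq_len, d_model)
```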
13 changes: 11 additions & 2 deletions sae_lens/training/activations_store.py
@@ -273,7 +273,7 @@
self.model.to_tokens(
row,
truncate=False,
-move_to_device=True,
+move_to_device=False,  # we move to device below
prepend_bos=False,
)
.squeeze(0)
@@ -436,7 +436,7 @@
else:
sequences.append(next(self.iterable_sequences))

-return torch.stack(sequences, dim=0).to(self.model.W_E.device)
+return torch.stack(sequences, dim=0).to(_get_model_device(self.model))

@torch.no_grad()
def get_activations(self, batch_tokens: torch.Tensor):
@@ -700,3 +700,12 @@
raise ValueError(
f"Dataset tokenizer {tokenizer_name} does not match model tokenizer {model_tokenizer}."
)


def _get_model_device(model: HookedRootModule) -> torch.device:
if hasattr(model, "W_E"):
return model.W_E.device
elif hasattr(model, "cfg") and hasattr(model.cfg, "device"):
return model.cfg.device

else:
return next(model.parameters()).device
125 changes: 125 additions & 0 deletions tests/unit/test_evals.py
@@ -1,5 +1,6 @@
import argparse
import json
import math
from pathlib import Path
from unittest.mock import MagicMock, patch

@@ -11,12 +12,15 @@
from sae_lens.evals import (
EvalConfig,
all_loadable_saes,
get_downstream_reconstruction_metrics,
get_eval_everything_config,
get_saes_from_regex,
get_sparsity_and_variance_metrics,
process_results,
run_evals,
run_evaluations,
)
from sae_lens.load_model import load_model
from sae_lens.sae import SAE
from sae_lens.toolkit.pretrained_saes_directory import PretrainedSAELookup
from sae_lens.training.activations_store import ActivationsStore
@@ -31,6 +35,11 @@
)


# not sure why we have NaNs in the feature metrics, but this is a quick fix for tests
def _replace_nan(list: list[float]) -> list[float]:
return [0 if math.isnan(x) else x for x in list]


@pytest.fixture(
params=[
{
@@ -341,6 +350,122 @@ def test_process_results(tmp_path: Path):
assert csv_path.exists()


def test_get_downstream_reconstruction_metrics_with_hf_model_gives_same_results_as_tlens_model():
sae = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.4.hook_resid_pre",
device="cpu",
)[0]
hf_model = load_model(
model_class_name="AutoModelForCausalLM",
model_name="gpt2",
device="cpu",
)
tlens_model = HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu")

example_ds = Dataset.from_list(
[
{"text": "hello world1"},
{"text": "hello world2"},
{"text": "hello world3"},
]
* 20
)
cfg = build_sae_cfg(hook_name="transformer.h.3")
sae.cfg.hook_name = "transformer.h.3"
hf_store = ActivationsStore.from_config(hf_model, cfg, override_dataset=example_ds)
Contributor

This block is repeated in test_get_sparsity_and_variance_metrics_with_hf_model_gives_same_results_as_tlens_model(). Should we extract the repeated code to pytest fixtures?
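A sketch of the fixture extraction suggested above; the fixture names are hypothetical:

```python
import pytest
from transformer_lens import HookedTransformer

from sae_lens.load_model import load_model
from sae_lens.sae import SAE


@pytest.fixture
def gpt2_res_jb_sae() -> SAE:
    # Load the pretrained SAE once per test that needs it.
    return SAE.from_pretrained(
        release="gpt2-small-res-jb",
        sae_id="blocks.4.hook_resid_pre",
        device="cpu",
    )[0]


@pytest.fixture
def gpt2_hf_model():
    # Huggingface GPT-2 wrapped in HookedProxyLM via load_model.
    return load_model(
        model_class_name="AutoModelForCausalLM",
        model_name="gpt2",
        device="cpu",
    )


@pytest.fixture
def gpt2_tlens_model():
    return HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu")
```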

hf_metrics = get_downstream_reconstruction_metrics(
sae=sae,
model=hf_model,
activation_store=hf_store,
compute_kl=True,
compute_ce_loss=True,
n_batches=1,
eval_batch_size_prompts=4,
)

cfg = build_sae_cfg(hook_name="blocks.4.hook_resid_pre")
sae.cfg.hook_name = "blocks.4.hook_resid_pre"
tlens_store = ActivationsStore.from_config(
tlens_model, cfg, override_dataset=example_ds
)
tlens_metrics = get_downstream_reconstruction_metrics(
sae=sae,
model=tlens_model,
activation_store=tlens_store,
compute_kl=True,
compute_ce_loss=True,
n_batches=1,
eval_batch_size_prompts=4,
)

for key in hf_metrics.keys():
assert hf_metrics[key] == pytest.approx(tlens_metrics[key], abs=1e-3)


def test_get_sparsity_and_variance_metrics_with_hf_model_gives_same_results_as_tlens_model():
sae = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.4.hook_resid_pre",
device="cpu",
)[0]
hf_model = load_model(
model_class_name="AutoModelForCausalLM",
model_name="gpt2",
device="cpu",
)
tlens_model = HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu")

example_ds = Dataset.from_list(
[
{"text": "hello world1"},
{"text": "hello world2"},
{"text": "hello world3"},
]
* 20
)
cfg = build_sae_cfg(hook_name="transformer.h.3")
sae.cfg.hook_name = "transformer.h.3"
hf_store = ActivationsStore.from_config(hf_model, cfg, override_dataset=example_ds)
hf_metrics, hf_feat_metrics = get_sparsity_and_variance_metrics(
sae=sae,
model=hf_model,
activation_store=hf_store,
n_batches=1,
compute_l2_norms=True,
compute_sparsity_metrics=True,
compute_variance_metrics=True,
compute_featurewise_density_statistics=True,
eval_batch_size_prompts=4,
model_kwargs={},
)

cfg = build_sae_cfg(hook_name="blocks.4.hook_resid_pre")
sae.cfg.hook_name = "blocks.4.hook_resid_pre"
tlens_store = ActivationsStore.from_config(
tlens_model, cfg, override_dataset=example_ds
)
tlens_metrics, tlens_feat_metrics = get_sparsity_and_variance_metrics(
sae=sae,
model=tlens_model,
activation_store=tlens_store,
n_batches=1,
compute_l2_norms=True,
compute_sparsity_metrics=True,
compute_variance_metrics=True,
compute_featurewise_density_statistics=True,
eval_batch_size_prompts=4,
model_kwargs={},
)

for key in hf_metrics.keys():
assert hf_metrics[key] == pytest.approx(tlens_metrics[key], rel=1e-4)
for key in hf_feat_metrics.keys():
assert _replace_nan(hf_feat_metrics[key]) == pytest.approx(
_replace_nan(tlens_feat_metrics[key]), rel=1e-4
)


@patch("sae_lens.evals.get_pretrained_saes_directory")
def test_all_loadable_saes(mock_get_pretrained_saes_directory: MagicMock):
mock_get_pretrained_saes_directory.return_value = {