feat: adding load_model helper for huggingface causal LM models #226
Changes from 9 commits
@@ -1,8 +1,15 @@
from typing import Any, cast
from typing import Any, Literal, cast

import torch
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookedRootModule
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens.HookedTransformer import Loss, Output
from transformer_lens.utils import (
    USE_DEFAULT_VALUE,
    get_tokens_with_bos_removed,
    lm_cross_entropy_loss,
)
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase


def load_model(
@@ -40,5 +47,130 @@
                model_name, device=cast(Any, device), **model_from_pretrained_kwargs
            ),
        )
    elif model_class_name == "AutoModelForCausalLM":
        hf_model = AutoModelForCausalLM.from_pretrained(
            model_name, **model_from_pretrained_kwargs
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        return HookedProxyLM(hf_model, tokenizer)

    else:  # pragma: no cover
        raise ValueError(f"Unknown model class: {model_class_name}")


class HookedProxyLM(HookedRootModule):
    """
    A HookedRootModule that wraps a Huggingface AutoModelForCausalLM.
    """

    tokenizer: PreTrainedTokenizerBase
    model: torch.nn.Module

    def __init__(self, model: torch.nn.Module, tokenizer: PreTrainedTokenizerBase):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.setup()

    # copied and modified from base HookedRootModule
    def setup(self):
        self.mod_dict = {}
        self.hook_dict: dict[str, HookPoint] = {}
        for name, module in self.model.named_modules():
            if name == "":
                continue

            hook_point = HookPoint()
            hook_point.name = name  # type: ignore

            module.register_forward_hook(get_hook_fn(hook_point))

            self.hook_dict[name] = hook_point
            self.mod_dict[name] = hook_point

    def forward(
        self,
        tokens: torch.Tensor,
        return_type: Literal["both", "logits"] = "logits",
        loss_per_token: bool = False,
        # TODO: implement real support for stop_at_layer
        stop_at_layer: int | None = None,
        **kwargs: Any,
    ) -> Output | Loss:
        # This is just what's needed for evals, not everything that HookedTransformer has
        assert return_type in (
Review comment: I see a lot of `assert`s. Perhaps I could open an issue for discussion.

Reply: I can see the argument for that. I would at least agree that a user of the library should never see an assert fail. We use a lot of asserts throughout the codebase to narrow types for pyright, which seems fine IMO, but we should change to throwing exceptions if it's something we expect a user to see. Here, this is something that should never be triggered; I just wanted to make sure that if we did try to pass another return type into this wrapper we'd at least get an error. I'll change this to a raised exception.
            "both",
            "logits",
        ), "Only return_type supported is 'both' or 'logits' to match what's in evals.py and ActivationsStore"
        output = self.model(tokens)
        logits = _extract_logits_from_output(output)

        if return_type == "logits":
            return logits

        if tokens.device != logits.device:
            tokens = tokens.to(logits.device)
        loss = lm_cross_entropy_loss(logits, tokens, per_token=loss_per_token)
        return Output(logits, loss)
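A minimal sketch of the change discussed in the review thread above, with the assert replaced by a raised exception (the exact exception type is an assumption, not part of this diff):

```python
# Hypothetical follow-up to the review discussion: raise instead of assert,
# so library users see a normal exception rather than an AssertionError.
if return_type not in ("both", "logits"):
    raise ValueError(
        "Only return_type 'both' or 'logits' is supported, to match evals.py and ActivationsStore"
    )
```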

    def to_tokens(
        self,
        input: str | list[str],
        prepend_bos: bool | None = USE_DEFAULT_VALUE,
        padding_side: Literal["left", "right"] | None = USE_DEFAULT_VALUE,
        move_to_device: bool = True,
        truncate: bool = True,
    ) -> torch.Tensor:
        # Hackily modified version of HookedTransformer.to_tokens to work with ActivationsStore
        # Assumes that prepend_bos is always False, move_to_device is always False, and truncate is always False
        # copied from HookedTransformer.to_tokens

        assert (
            prepend_bos is False
        ), "Only works with prepend_bos=False, to match ActivationsStore usage"
        assert (
            padding_side is None
        ), "Only works with padding_side=None, to match ActivationsStore usage"
        assert (
            truncate is False
        ), "Only works with truncate=False, to match ActivationsStore usage"
        assert (
            move_to_device is False
        ), "Only works with move_to_device=False, to match ActivationsStore usage"

        tokens = self.tokenizer(
            input,
            return_tensors="pt",
            truncation=False,
            max_length=None,
        )["input_ids"]

        # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
        if hasattr(self.tokenizer, "add_bos_token") and self.tokenizer.add_bos_token:  # type: ignore
            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)
        return tokens  # type: ignore
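For reference, a rough sketch of how this method is expected to be called, mirroring the constraints asserted above (the variable name and input string are made up):

```python
# Hypothetical call on a HookedProxyLM instance, matching the asserted constraints.
tokens = proxy_lm.to_tokens(
    "hello world",
    prepend_bos=False,
    padding_side=None,
    move_to_device=False,
    truncate=False,
)
# tokens is a [1, seq_len] tensor of input ids, with any auto-added BOS stripped.
```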
def _extract_logits_from_output(output: Any) -> torch.Tensor:
    if isinstance(output, torch.Tensor):
        return output
    elif isinstance(output, tuple) and isinstance(output[0], torch.Tensor):
        return output[0]
    elif isinstance(output, dict) and "logits" in output:
        return output["logits"]
    else:
        raise ValueError(f"Unknown output type: {type(output)}")


def get_hook_fn(hook_point: HookPoint):

    def hook_fn(module: Any, input: Any, output: Any) -> Any:
        if isinstance(output, torch.Tensor):
            return hook_point(output)
        elif isinstance(output, tuple) and isinstance(output[0], torch.Tensor):
            return (hook_point(output[0]), *output[1:])
        else:
            # if this isn't a tensor, just skip the hook entirely as this will break otherwise
            return output

    return hook_fn
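Putting the pieces together, a minimal usage sketch of the new code path (the model name and hook name follow the tests below; treat this as an illustration rather than documented API):

```python
# Sketch: load GPT-2 through the new AutoModelForCausalLM branch and run a forward pass.
from sae_lens.load_model import load_model

model = load_model(
    model_class_name="AutoModelForCausalLM",
    model_name="gpt2",
    device="cpu",
)  # returns a HookedProxyLM wrapping the HF model and its tokenizer

tokens = model.to_tokens(
    "hello world",
    prepend_bos=False,
    padding_side=None,
    move_to_device=False,
    truncate=False,
)
logits = model(tokens, return_type="logits")
# Activations are exposed through HookPoints named after the HF modules,
# e.g. "transformer.h.3", which is how the ActivationsStore tests below use it.
```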
@@ -1,5 +1,6 @@
import argparse
import json
import math
from pathlib import Path
from unittest.mock import MagicMock, patch

@@ -11,12 +12,15 @@
from sae_lens.evals import (
    EvalConfig,
    all_loadable_saes,
    get_downstream_reconstruction_metrics,
    get_eval_everything_config,
    get_saes_from_regex,
    get_sparsity_and_variance_metrics,
    process_results,
    run_evals,
    run_evaluations,
)
from sae_lens.load_model import load_model
from sae_lens.sae import SAE
from sae_lens.toolkit.pretrained_saes_directory import PretrainedSAELookup
from sae_lens.training.activations_store import ActivationsStore

@@ -31,6 +35,11 @@
)


# not sure why we have NaNs in the feature metrics, but this is a quick fix for tests
def _replace_nan(list: list[float]) -> list[float]:
    return [0 if math.isnan(x) else x for x in list]


@pytest.fixture(
    params=[
        {
@@ -341,6 +350,122 @@ def test_process_results(tmp_path: Path):
    assert csv_path.exists()


def test_get_downstream_reconstruction_metrics_with_hf_model_gives_same_results_as_tlens_model():
    sae = SAE.from_pretrained(
        release="gpt2-small-res-jb",
        sae_id="blocks.4.hook_resid_pre",
        device="cpu",
    )[0]
    hf_model = load_model(
        model_class_name="AutoModelForCausalLM",
        model_name="gpt2",
        device="cpu",
    )
    tlens_model = HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu")

    example_ds = Dataset.from_list(
        [
            {"text": "hello world1"},
            {"text": "hello world2"},
            {"text": "hello world3"},
        ]
        * 20
    )
    cfg = build_sae_cfg(hook_name="transformer.h.3")
    sae.cfg.hook_name = "transformer.h.3"
    hf_store = ActivationsStore.from_config(hf_model, cfg, override_dataset=example_ds)
Review comment: This block is repeated in the test below. Should we extract the repeated code to a shared fixture or helper?
    hf_metrics = get_downstream_reconstruction_metrics(
        sae=sae,
        model=hf_model,
        activation_store=hf_store,
        compute_kl=True,
        compute_ce_loss=True,
        n_batches=1,
        eval_batch_size_prompts=4,
    )

    cfg = build_sae_cfg(hook_name="blocks.4.hook_resid_pre")
    sae.cfg.hook_name = "blocks.4.hook_resid_pre"
    tlens_store = ActivationsStore.from_config(
        tlens_model, cfg, override_dataset=example_ds
    )
    tlens_metrics = get_downstream_reconstruction_metrics(
        sae=sae,
        model=tlens_model,
        activation_store=tlens_store,
        compute_kl=True,
        compute_ce_loss=True,
        n_batches=1,
        eval_batch_size_prompts=4,
    )

    for key in hf_metrics.keys():
        assert hf_metrics[key] == pytest.approx(tlens_metrics[key], abs=1e-3)


def test_get_sparsity_and_variance_metrics_with_hf_model_gives_same_results_as_tlens_model():
    sae = SAE.from_pretrained(
        release="gpt2-small-res-jb",
        sae_id="blocks.4.hook_resid_pre",
        device="cpu",
    )[0]
    hf_model = load_model(
        model_class_name="AutoModelForCausalLM",
        model_name="gpt2",
        device="cpu",
    )
    tlens_model = HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu")

    example_ds = Dataset.from_list(
        [
            {"text": "hello world1"},
            {"text": "hello world2"},
            {"text": "hello world3"},
        ]
        * 20
    )
    cfg = build_sae_cfg(hook_name="transformer.h.3")
    sae.cfg.hook_name = "transformer.h.3"
    hf_store = ActivationsStore.from_config(hf_model, cfg, override_dataset=example_ds)
    hf_metrics, hf_feat_metrics = get_sparsity_and_variance_metrics(
        sae=sae,
        model=hf_model,
        activation_store=hf_store,
        n_batches=1,
        compute_l2_norms=True,
        compute_sparsity_metrics=True,
        compute_variance_metrics=True,
        compute_featurewise_density_statistics=True,
        eval_batch_size_prompts=4,
        model_kwargs={},
    )

    cfg = build_sae_cfg(hook_name="blocks.4.hook_resid_pre")
    sae.cfg.hook_name = "blocks.4.hook_resid_pre"
    tlens_store = ActivationsStore.from_config(
        tlens_model, cfg, override_dataset=example_ds
    )
    tlens_metrics, tlens_feat_metrics = get_sparsity_and_variance_metrics(
        sae=sae,
        model=tlens_model,
        activation_store=tlens_store,
        n_batches=1,
        compute_l2_norms=True,
        compute_sparsity_metrics=True,
        compute_variance_metrics=True,
        compute_featurewise_density_statistics=True,
        eval_batch_size_prompts=4,
        model_kwargs={},
    )

    for key in hf_metrics.keys():
        assert hf_metrics[key] == pytest.approx(tlens_metrics[key], rel=1e-4)
    for key in hf_feat_metrics.keys():
        assert _replace_nan(hf_feat_metrics[key]) == pytest.approx(
            _replace_nan(tlens_feat_metrics[key]), rel=1e-4
        )
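One possible answer to the review question above, with the repeated setup pulled into a shared helper (the helper name is hypothetical and not part of this diff):

```python
# Hypothetical helper collapsing the setup duplicated across the two tests above.
def _load_gpt2_sae_hf_and_tlens_models():
    sae = SAE.from_pretrained(
        release="gpt2-small-res-jb",
        sae_id="blocks.4.hook_resid_pre",
        device="cpu",
    )[0]
    hf_model = load_model(
        model_class_name="AutoModelForCausalLM",
        model_name="gpt2",
        device="cpu",
    )
    tlens_model = HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu")
    example_ds = Dataset.from_list(
        [
            {"text": "hello world1"},
            {"text": "hello world2"},
            {"text": "hello world3"},
        ]
        * 20
    )
    return sae, hf_model, tlens_model, example_ds
```

Each test could then start with a single call to this helper and keep only the metric computation and assertions.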
@patch("sae_lens.evals.get_pretrained_saes_directory")
def test_all_loadable_saes(mock_get_pretrained_saes_directory: MagicMock):
    mock_get_pretrained_saes_directory.return_value = {
Review comment: I see many `elif`s and `else`s after `return`s. I find it easier to read when the `elif`s are `if`s and there's no `else`, but I've known others who are against this, so am curious to hear your thoughts. There's a Pylint message for this that seems to be enabled in `.pylintrc`, but it seems we don't use `.pylintrc`. Perhaps I should open an issue on Pylint for discussion?

Reply: I don't have a strong opinion on this; if we can have a linting rule set up to enforce a style here, I'm happy with that. The `.pylintrc` is likely legacy, so I'll delete it. I'd support moving to Ruff, since it supports everything from pylint, flake8, pyflakes, isort, black, etc., is a lot faster, and seems to be what the industry is moving towards. I'll open an issue for this.
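For illustration, here is `_extract_logits_from_output` from this diff rewritten in the flat style the reviewer describes; the behaviour is unchanged, only the branching style differs:

```python
def _extract_logits_from_output(output: Any) -> torch.Tensor:
    # Same logic as in the diff, with the elif/else chain flattened:
    # each branch returns, so later checks only run when earlier ones didn't match.
    if isinstance(output, torch.Tensor):
        return output
    if isinstance(output, tuple) and isinstance(output[0], torch.Tensor):
        return output[0]
    if isinstance(output, dict) and "logits" in output:
        return output["logits"]
    raise ValueError(f"Unknown output type: {type(output)}")
```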