
feat: adding load_model helper for huggingface causal LM models #226

Merged
13 commits merged into main from huggingface-models on Nov 10, 2024

Conversation

chanind (Collaborator) commented Jul 10, 2024

Description

This PR adds support for loading huggingface AutoModelForCausalLM models by internally wrapping them with a HookedRootModule subclass.

To load a model from Huggingface, specify model_class_name = 'AutoModelForCausalLM' in the SAE runner config. The hook_name will need to match a name from the Huggingface model's named modules, so the usual blocks.0.hook_resid_pre won't work. Otherwise everything should work the same as when working with TransformerLens models.
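For illustration, a runner config using this might look roughly like the sketch below. Only model_class_name, model_name, and hook_name are the fields discussed in this PR; everything else is the usual SAE runner config and is only illustrative here.

# Rough sketch, not taken from this PR: training an SAE on a raw Huggingface
# model by setting model_class_name in the runner config. Remaining training
# options are omitted for brevity.
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

cfg = LanguageModelSAERunnerConfig(
    model_class_name="AutoModelForCausalLM",  # wrap the HF model in a HookedRootModule
    model_name="gpt2",
    hook_name="transformer.h.3",  # must match a module name of the HF model
    # ... other training options as usual ...
)

SAETrainingRunner(cfg).run()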

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

You have tested formatting, typing and unit tests (acceptance tests not currently in use)

  • I have run make check-ci to check format and linting. (you can run make format to format code if needed.)


codecov bot commented Oct 9, 2024

Codecov Report

Attention: Patch coverage is 81.70732% with 15 lines in your changes missing coverage. Please review.

Project coverage is 67.08%. Comparing base (f739500) to head (6e9350d).
Report is 1 commits behind head on main.

Files with missing lines                   Patch %   Lines
sae_lens/load_model.py                     84.37%    5 Missing and 5 partials ⚠️
sae_lens/sae_training_runner.py            66.66%    1 Missing and 1 partial ⚠️
sae_lens/training/activations_store.py     71.42%    1 Missing and 1 partial ⚠️
sae_lens/config.py                         80.00%    0 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #226      +/-   ##
==========================================
+ Coverage   66.80%   67.08%   +0.27%     
==========================================
  Files          25       25              
  Lines        3389     3463      +74     
  Branches      434      451      +17     
==========================================
+ Hits         2264     2323      +59     
- Misses       1005     1012       +7     
- Partials      120      128       +8     

☔ View full report in Codecov by Sentry.

@chanind chanind requested a review from jbloomAus October 9, 2024 21:56
@chanind chanind marked this pull request as ready for review October 9, 2024 21:56
@chanind chanind changed the title adding load_model helper for huggingface causal LM models feat: adding load_model helper for huggingface causal LM models Oct 9, 2024
@chanind chanind requested review from curt-tigges and removed request for jbloomAus October 15, 2024 20:13
chanind (Collaborator, Author) commented Nov 4, 2024

Tagging @anthonyduong9 as a reviewer informally, since GitHub won't let me use the "reviewers" list to do this.

@@ -40,5 +47,130 @@ def load_model(
model_name, device=cast(Any, device), **model_from_pretrained_kwargs
),
)
elif model_class_name == "AutoModelForCausalLM":
Contributor

I see many elifs and elses after returns. I find it easier to read when the elifs are ifs and there's no else, but I've known others who are against this, so I'm curious to hear your thoughts.

There's a Pylint message for this that appears to be enabled in .pylintrc, but it seems we don't actually use .pylintrc. Perhaps I should open an issue about Pylint for discussion?
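For illustration (a made-up sketch, not code from this PR), the two styles being compared:

# Style with elif/else even though every branch returns:
def classify_v1(model_class_name: str) -> str:
    if model_class_name == "HookedTransformer":
        return "tlens"
    elif model_class_name == "AutoModelForCausalLM":
        return "hf"
    else:
        raise ValueError(f"Unknown model class: {model_class_name}")

# Style with plain ifs and no else, relying on the early returns:
def classify_v2(model_class_name: str) -> str:
    if model_class_name == "HookedTransformer":
        return "tlens"
    if model_class_name == "AutoModelForCausalLM":
        return "hf"
    raise ValueError(f"Unknown model class: {model_class_name}")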

chanind (Collaborator, Author)

I don't have a strong opinion on this; if we can set up a linting rule to enforce a style here, I'm happy with that. The .pylintrc is likely legacy, so I'll delete it. I'd support moving to Ruff, since it supports everything from pylint, flake8, pyflakes, isort, black, etc., is a lot faster, and seems to be what the industry is moving towards. I'll open an issue for this.

**kwargs: Any,
) -> Output | Loss:
# This is just what's needed for evals, not everything that HookedTransformer has
assert return_type in (
Contributor


I see a lot of assert in the codebase outside of tests, but a lot of people say not to use assert in production code and to raise exceptions instead. Should we have assert outside of tests?

Perhaps I could open an issue for discussion.

chanind (Collaborator, Author) commented Nov 10, 2024

I can see the argument for that. I would at least agree that a user of the library should never see an assert fail. We use a lot of asserts throughout the codebase to narrow types for pyright, which seems fine IMO, but we should switch to throwing exceptions for anything we expect a user to see. Here, this should never be triggered; I just wanted to make sure that if we did try to pass another return type into this wrapper we'd at least get an error. I'll change this to a NotImplementedError - that should accomplish the same goal.
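Roughly this kind of change (a sketch only; the actual set of supported return types is whatever the tuple in the diff above contains):

# Sketch only: raise instead of assert so an unsupported return_type fails loudly.
SUPPORTED_RETURN_TYPES = ("logits",)  # placeholder; use the wrapper's real tuple

if return_type not in SUPPORTED_RETURN_TYPES:
    raise NotImplementedError(
        f"return_type {return_type!r} is not supported by this wrapper"
    )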

Comment on lines 354 to 376
sae = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.4.hook_resid_pre",
device="cpu",
)[0]
hf_model = load_model(
model_class_name="AutoModelForCausalLM",
model_name="gpt2",
device="cpu",
)
tlens_model = HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu")

example_ds = Dataset.from_list(
[
{"text": "hello world1"},
{"text": "hello world2"},
{"text": "hello world3"},
]
* 20
)
cfg = build_sae_cfg(hook_name="transformer.h.3")
sae.cfg.hook_name = "transformer.h.3"
hf_store = ActivationsStore.from_config(hf_model, cfg, override_dataset=example_ds)
Contributor


This block is repeated in test_get_sparsity_and_variance_metrics_with_hf_model_gives_same_results_as_tlens_model().

Should we extract the repeated code to pytest fixtures?
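For example, something along these lines (fixture name and scope are just suggestions, not code from this PR):

import pytest
from datasets import Dataset

# Sketch of a fixture for the repeated toy dataset used in these tests.
@pytest.fixture
def example_text_dataset() -> Dataset:
    return Dataset.from_list(
        [
            {"text": "hello world1"},
            {"text": "hello world2"},
            {"text": "hello world3"},
        ]
        * 20
    )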

model_name="gpt2",
device="cpu",
)
assert model is not None
Contributor


Should we remove this assert? If model is None, the next assert would fail anyway.

Comment on lines +43 to +47
model = load_model(
model_class_name="AutoModelForCausalLM",
model_name="gpt2",
device="cpu",
)
Contributor


We have this in four tests in this file; perhaps we should extract it to a fixture.
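For example (a sketch; the fixture name and module scope are just suggestions):

import pytest

from sae_lens.load_model import load_model

# Sketch of a shared fixture for the repeated load_model call, loaded once per module.
@pytest.fixture(scope="module")
def gpt2_hf_model():
    return load_model(
        model_class_name="AutoModelForCausalLM",
        model_name="gpt2",
        device="cpu",
    )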

@chanind chanind merged commit 044d4be into main Nov 10, 2024
7 checks passed
@chanind chanind deleted the huggingface-models branch November 10, 2024 18:51