Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Encode with slice #334

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

callummcdougall
Copy link
Contributor

@callummcdougall callummcdougall commented Oct 16, 2024

It's useful to be able to get a slice of latent activations in an architecture-general way, without having to compute all the latent activation values (e.g. I've found this useful while working on autointerp for SAEBench, which only samples a fraction of the total latents). I've implemented this by adding an optional latents argument to all the encode functions - e.g. if we supply latents=range(100) then we'll just compute feature activations for the first 100 features. The default is latents=None which works as normal.

Some other notes:

  • I've added a test function to test_sae_basic.py, which just verifies that "run encode with latents" gives the same result as "run encode then slice the result with latents"
  • I'm aware this won't work for TopK, but we can add an error message when it comes to that

Also a somewhat unrelated thing - I'm not sure why the forward function doesn't use encode (we essentially have the encode code duplicated here) - if there's no good reason for this, then I'm happy to submit a PR to fix this.

Copy link
Collaborator

@chanind chanind left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this interact with error term calculation, or is this only a performance optimization when not running a full forward pass? How significant of a performance speed-up in your use-case?

I can see the benefit of this, but also worry it makes encode() more complicated if users decide to make SAE subclasses, adds a confusing feature to the TrainingSAE.encode() where it looks like you can pass a latents arg but then it throws an error if you try to, and silently breaks topk SAEs. If the performance improvement this gives is worth the added complexity / downsides it can still make sense to do though.

sae_lens/sae.py Outdated
@@ -544,11 +560,12 @@ def encode_jumprelu(
return feature_acts

def encode_standard(
self, x: Float[torch.Tensor, "... d_in"]
self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should this also take a tensor as well? It seems like the code will work with a tensor of ints as well

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also looks like topk SAEs go through this codepath and will silently break if anything is passed for latents. We should make it non-silent.

@@ -70,6 +70,19 @@ def test_sae_init(cfg: LanguageModelSAERunnerConfig):
assert sae.b_dec.shape == (cfg.d_in,)


def test_sae_encode_with_slicing(cfg: LanguageModelSAERunnerConfig):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: It would be good to test with pytest.mark.parametrize with all architectures, since it seems like it's easy to accidentally mess up the implementation in one architecture and have it slip through the cracks. It could make sense to add a second test calling build_sae_cfg() directly so we can ensure we're hitting every architecture variant explicitly rather than just the pre-defined SAEs in the cfg fixture.

@chanind
Copy link
Collaborator

chanind commented Oct 17, 2024

Re: the separate encode implementations in forward, I noticed this as well and fixed it in #328, just waiting on a review. It seems like the reason is that we didn't want to trigger hooks in error term calculation since then the error term would change when users do interventions on SAE latents. But it also seems like all this duplication is likely introducing subtle bugs

@callummcdougall
Copy link
Contributor Author

callummcdougall commented Oct 17, 2024

How does this interact with error term calculation, or is this only a performance optimization when not running a full forward pass? How significant of a performance speed-up in your use-case?

It's designed only for encode, I don't anticipate a big use case for people doing full fwd passes. It's specifically if you want a way of computing a small subset of feature activations without computing all of them, because otherwise if we have a tensor of (batch, seq_len, d_sae) with very large d_sae, this can dominate memory costs and cause OOMs (which is obvs impractical if we only want e.g. 1 feature and d_sae is O(10k) or larger).

nit: should this also take a tensor as well? It seems like the code will work with a tensor of ints as well

Yep, I can make that change! Makes sense that this would be a bit easier for the type errors. I'll do that in my next commit.

It also looks like topk SAEs go through this codepath and will silently break if anything is passed for latents. We should make it non-silent.

Oh yep I see what you mean, will make that explicit. Was confused cause I thought topk isn't yet a supported architecture but it is, just is expressed in a different way, via activation_fn. I'll do this by changing how the self.encode = ... code works, i.e. actually having an encode method which checks for this.

nit: It would be good to test with pytest.mark.parametrize with all architectures, since it seems like it's easy to accidentally mess up the implementation in one architecture and have it slip through the cracks

Yep will sort that out

@callummcdougall
Copy link
Contributor Author

I can see the benefit of this, but also worry it makes encode() more complicated if users decide to make SAE subclasses, adds a confusing feature to the TrainingSAE.encode() where it looks like you can pass a latents arg but then it throws an error if you try to, and silently breaks topk SAEs

I can add a short docstring explaining about this. Joseph mentioned he implemented this once using a wrapper around the other encode functions but I think it really has to be done on a per-architecture basis, because things like TopK (and possibly more complicated architectures) need to be considered as special cases.

@callummcdougall
Copy link
Contributor Author

callummcdougall commented Oct 17, 2024

@chanind most recent test has failed because LanguageModelSAERunnerConfig has a different name for the activation function (it calls it activation_fn and doesn't have an activation_fn_str, whereas SAEConfig has an activation_fn_str and activation_fn is the actual function). Not pushing a fix yet because I don't have full context on why these are different and I don't think I'll be able to guess what the best solution is without that context. Do you have a sense for what the best solution here is?

Copy link
Collaborator

@chanind chanind left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow the problem related to activation_fn for the runner config vs activation_fn_str for the SAE config. I agree it's weird these are different and don't know why they're not the same, but why is it causing a problem? It looks like tests are failing currently because initialize_weights() is not being called, it doesn't look related to activation_fn_str.

The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not
provided, all latent values will be computed.
"""
latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous implementation you had, latents_slice = slice(None) if latents is None else torch.tensor(latents), seems better to me since this new version will create a new tensor of size d_sae on every SAE forward pass when not selecting specific latents. This would likely reduce performance for most users, which seems counter-productive since this PR is just meant to be a performance improvement if I understand the goal correctly. Wouldn't the old implementation have worked fine just adding torch.Tensor to the type of latents? e.g. latents: Iterable[int] | torch.Tensor | None = None

self.encode = self.encode_jumprelu
else:
raise (ValueError)
if self.cfg.architecture not in ["standard", "gated", "jumprelu"]:
Copy link
Collaborator

@chanind chanind Oct 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like tests are failing because initialize_weights() have been moved into their own function, but that function is never called now. IMO this error can be raised in the new initialize_weights() method instead.

@jbloomAus
Copy link
Owner

Sharing a link to my previous solution for this: https://github.com/jbloomAus/SAEDashboard/blob/4f3d6e0d0320816f10de7045dc8e3cbc3ee29f0e/sae_dashboard/feature_data_generator.py#L220

@callummcdougall @chanind can I leave this with you both? Sorry for being MIA. Think this is a good idea.

@callummcdougall
Copy link
Contributor Author

Ah really sorry forgot to follow up on this - yep I'll try and get this sorted this week (and apologies if I'd got wrong the reason for failing tests)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants