Encode with slice #334
Conversation
How does this interact with error term calculation, or is this only a performance optimization when not running a full forward pass? How significant is the performance speed-up in your use case?

I can see the benefit of this, but I also worry that it makes `encode()` more complicated if users decide to make SAE subclasses, adds a confusing feature to `TrainingSAE.encode()` where it looks like you can pass a `latents` arg but it throws an error if you try to, and silently breaks topk SAEs. If the performance improvement this gives is worth the added complexity / downsides, it can still make sense to do though.
sae_lens/sae.py (outdated)

```diff
@@ -544,11 +560,12 @@ def encode_jumprelu(
         return feature_acts

     def encode_standard(
-        self, x: Float[torch.Tensor, "... d_in"]
+        self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None
```
nit: should this also take a tensor as well? It seems like the code will work with a tensor of ints as well
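As a quick sanity check that a tensor of ints behaves the same as a Python sequence for this kind of column selection (shapes below are illustrative, not the real SAE dimensions):

```python
import torch

# Illustrative shapes only: (d_in, d_sae) = (16, 32).
W_enc = torch.randn(16, 32)

idx = torch.tensor([0, 3, 7])
# Advanced indexing with a tensor of ints selects the same columns as a list of ints.
assert torch.equal(W_enc[:, [0, 3, 7]], W_enc[:, idx])
```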
It also looks like topk SAEs go through this codepath and will silently break if anything is passed for `latents`. We should make it non-silent.
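A minimal sketch of what a non-silent guard could look like, assuming TopK SAEs reach `encode_standard` via a TopK activation function and that this is flagged by `cfg.activation_fn_str == "topk"` (that flag name is an assumption, not confirmed by the diff):

```python
from typing import Iterable

import torch
from jaxtyping import Float


def encode_standard(
    self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None
) -> Float[torch.Tensor, "... d_sae"]:
    # Hypothetical guard: fail loudly rather than silently returning wrong
    # activations when a TopK SAE is asked for only a subset of latents.
    if latents is not None and getattr(self.cfg, "activation_fn_str", None) == "topk":
        raise NotImplementedError(
            "`latents` is not supported for TopK SAEs: top-k is taken over all "
            "latents, so a partial encode would change the result."
        )
    ...
```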
```diff
@@ -70,6 +70,19 @@ def test_sae_init(cfg: LanguageModelSAERunnerConfig):
     assert sae.b_dec.shape == (cfg.d_in,)


+def test_sae_encode_with_slicing(cfg: LanguageModelSAERunnerConfig):
```
nit: It would be good to test with `pytest.mark.parametrize` across all architectures, since it seems like it's easy to accidentally mess up the implementation in one architecture and have it slip through the cracks. It could make sense to add a second test calling `build_sae_cfg()` directly so we can ensure we're hitting every architecture variant explicitly, rather than just the pre-defined SAEs in the `cfg` fixture.
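A rough sketch of such a parametrized test, assuming `build_sae_cfg()` accepts an `architecture` argument and that an SAE can be constructed directly from the resulting config (both assumptions; the real helpers may differ):

```python
import pytest
import torch


@pytest.mark.parametrize("architecture", ["standard", "gated", "jumprelu"])
def test_sae_encode_with_slicing_per_architecture(architecture: str):
    # build_sae_cfg() and SAE(...) are placeholders for however the test suite
    # actually constructs an SAE of the given architecture.
    cfg = build_sae_cfg(architecture=architecture)
    sae = SAE(cfg)
    x = torch.randn(4, cfg.d_in)

    latents = list(range(10))
    sliced = sae.encode(x, latents=latents)
    full = sae.encode(x)

    # Encoding a subset of latents should match encoding everything then slicing.
    assert torch.allclose(sliced, full[..., latents])
```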
Re: the separate encode implementations in forward, I noticed this as well and fixed it in #328, just waiting on a review. It seems like the reason is that we didn't want to trigger hooks in error term calculation, since then the error term would change when users do interventions on SAE latents. But it also seems like all this duplication is likely introducing subtle bugs.
It's designed only for the case where you're not running a full forward pass, so it's purely a performance optimization and doesn't interact with the error term calculation.
Yep, I can make that change! Makes sense that this would be a bit easier for the type errors. I'll do that in my next commit.
Oh yep I see what you mean, will make that explicit. Was confused because I thought topk isn't yet a supported architecture, but it is, it's just expressed in a different way, via the activation function rather than as a separate architecture.
Yep, will sort that out.
I can add a short docstring explaining this. Joseph mentioned he implemented this once using a wrapper around the other encode functions, but I think it really has to be done on a per-architecture basis, because things like TopK (and possibly more complicated architectures) need to be considered as special cases.
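To illustrate why TopK has to be special-cased rather than handled by a generic wrapper, here's a toy example (not from the codebase) showing that taking the top-k over a subset of latents gives a different answer than taking the top-k over all latents and then slicing:

```python
import torch

acts = torch.tensor([0.1, 0.9, 0.2, 0.8, 0.3])
k = 2

# Top-k over all latents, then slice to the first three.
full = torch.zeros_like(acts)
vals, idxs = acts.topk(k)
full[idxs] = vals
print(full[:3])  # [0.0, 0.9, 0.0] -- latent 3 won a top-k slot but falls outside the slice

# Slice first, then top-k over the subset: a different latent survives.
subset = acts[:3]
partial = torch.zeros_like(subset)
vals, idxs = subset.topk(k)
partial[idxs] = vals
print(partial)  # [0.0, 0.9, 0.2]
```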
@chanind the most recent test has failed because the runner config uses `activation_fn` while the SAE config uses `activation_fn_str`, and I'm not sure how these are meant to line up.
I don't follow the problem related to `activation_fn` for the runner config vs `activation_fn_str` for the SAE config. I agree it's weird these are different and I don't know why they're not the same, but why is it causing a problem? It looks like tests are failing currently because `initialize_weights()` is not being called; it doesn't look related to `activation_fn_str`.
```python
        The `latents` argument allows for the computation of a specific subset of the hidden
        values. If `latents` is not provided, all latent values will be computed.
        """
        latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents
```
The previous implementation you had, `latents_slice = slice(None) if latents is None else torch.tensor(latents)`, seems better to me, since this new version will create a new tensor of size `d_sae` on every SAE forward pass when not selecting specific latents. This would likely reduce performance for most users, which seems counter-productive since this PR is just meant to be a performance improvement, if I understand the goal correctly. Wouldn't the old implementation have worked fine just by adding `torch.Tensor` to the type of `latents`? e.g. `latents: Iterable[int] | torch.Tensor | None = None`
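A minimal sketch of the slice-based variant being suggested, assuming the usual `W_enc: (d_in, d_sae)` / `b_enc: (d_sae,)` layout; `_resolve_latents` is a hypothetical helper, not something in the codebase:

```python
from typing import Iterable

import torch


def _resolve_latents(
    latents: Iterable[int] | torch.Tensor | None,
) -> slice | torch.Tensor:
    """Return something usable as a column index into W_enc / b_enc."""
    if latents is None:
        # slice(None) selects every latent without allocating an index tensor,
        # so the common latents=None path stays allocation-free.
        return slice(None)
    if isinstance(latents, torch.Tensor):
        return latents
    return torch.tensor(list(latents), dtype=torch.long)


# Sketch of how encode_standard could use it:
#   sel = _resolve_latents(latents)
#   hidden_pre = sae_in @ self.W_enc[:, sel] + self.b_enc[sel]
```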
```python
            self.encode = self.encode_jumprelu
        else:
            raise (ValueError)
        if self.cfg.architecture not in ["standard", "gated", "jumprelu"]:
```
It looks like tests are failing because the weight initializations have been moved into their own `initialize_weights()` function, but that function is never called now. IMO this error can be raised in the new `initialize_weights()` method instead.
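A sketch of what that could look like; the body of `initialize_weights()` is elided and the exact error wording is illustrative:

```python
def initialize_weights(self) -> None:
    # Raise the unknown-architecture error here, at the point where weights
    # are set up, instead of in the else branch of the encode dispatch.
    if self.cfg.architecture not in ["standard", "gated", "jumprelu"]:
        raise ValueError(f"Unknown architecture: {self.cfg.architecture}")
    ...  # per-architecture weight initialization goes here
```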
Sharing a link to my previous solution for this: https://github.com/jbloomAus/SAEDashboard/blob/4f3d6e0d0320816f10de7045dc8e3cbc3ee29f0e/sae_dashboard/feature_data_generator.py#L220

@callummcdougall @chanind can I leave this with you both? Sorry for being MIA. Think this is a good idea.
Ah really sorry, forgot to follow up on this - yep I'll try and get this sorted this week (and apologies if I'd got the reason for the failing tests wrong).
It's useful to be able to get a slice of latent activations in an architecture-general way, without having to compute all the latent activation values (e.g. I've found this useful while working on autointerp for SAEBench, which only samples a fraction of the total latents). I've implemented this by adding an optional `latents` argument to all the `encode` functions - e.g. if we supply `latents=range(100)` then we'll just compute feature activations for the first 100 features. The default is `latents=None`, which works as normal.

Some other notes:

- Added a test in `test_sae_basic.py`, which just verifies that "run `encode` with `latents`" gives the same result as "run `encode` then slice the result with `latents`".

Also a somewhat unrelated thing - I'm not sure why the `forward` function doesn't use `encode` (we essentially have the `encode` code duplicated there) - if there's no good reason for this, then I'm happy to submit a PR to fix this.