feat: Topk SAE training #370

Open · wants to merge 9 commits into main
41 changes: 41 additions & 0 deletions docs/training_saes.md
@@ -93,6 +93,47 @@ sparse_autoencoder = SAETrainingRunner(cfg).run()

As you can see, the training setup provides a large number of options to explore. The full list of options can be found in the [LanguageModelSAERunnerConfig][sae_lens.LanguageModelSAERunnerConfig] class.

### Training Topk SAEs

By default, SAELens trains SAEs using an L1 loss term with a ReLU activation. A popular alternative is the [TopK](https://arxiv.org/abs/2406.04093) architecture, which fixes the L0 of the SAE by keeping only the `k` largest latent activations via a TopK activation function. To train a TopK SAE, set the `architecture` parameter to `"topk"` in the config. You can set the `k` parameter via `activation_fn_kwargs`; if not set, the default is `k=100`. The TopK architecture ignores the `l1_coefficient` parameter.

```python
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn_kwargs={"k": 100},
# ...
)
sparse_autoencoder = SAETrainingRunner(cfg).run()
```

### Training JumpReLU SAEs

[JumpReLU SAEs](https://arxiv.org/abs/2407.14435) are the current state-of-the-art SAE architecture, but are often trickier to train than other architectures. To train a JumpReLU SAE, set the `architecture` parameter to `"jumprelu"` in the config. JumpReLU SAEs use a sparsity penalty controlled by the `l1_coefficient` parameter. The name is technically a misnomer, since the JumpReLU sparsity penalty is not an L1 penalty, but we keep the parameter name for consistency with the L1 penalty used by the standard architecture. The JumpReLU architecture also has two additional parameters, `jumprelu_bandwidth` and `jumprelu_init_threshold`. Both are likely fine at their default values, but may be worth experimenting with if JumpReLU training is too slow to converge.

```python
cfg = LanguageModelSAERunnerConfig(
architecture="jumprelu",
l1_coefficient=5.0,
jumprelu_bandwidth=0.001,
jumprelu_init_threshold=0.001,
# ...
)
sparse_autoencoder = SAETrainingRunner(cfg).run()
```

### Training Gated SAEs

[Gated SAEs](https://arxiv.org/abs/2404.16014) are a precursor to JumpReLU SAEs, but use a simpler training procedure that should make them easier to train. To train a Gated SAE, set the `architecture` parameter to `"gated"` in the config. Gated SAEs use the `l1_coefficient` parameter to control the sparsity of the SAE, the same as standard SAEs. If JumpReLU training is too slow to converge, it may be worth trying a Gated SAE instead.

```python
cfg = LanguageModelSAERunnerConfig(
architecture="gated",
l1_coefficient=5.0,
# ...
)
sparse_autoencoder = SAETrainingRunner(cfg).run()
```

## CLI Runner

The SAE training runner can also be run from the command line via the `sae_lens.sae_training_runner` module. This can be useful for quickly testing different hyperparameters or running training on a remote server. The command line interface is shown below. All options to the CLI are the same as the [LanguageModelSAERunnerConfig][sae_lens.LanguageModelSAERunnerConfig] with a `--` prefix. E.g., `--model_name` is the same as `model_name` in the config.
26 changes: 22 additions & 4 deletions sae_lens/config.py
@@ -29,7 +29,7 @@ class LanguageModelSAERunnerConfig:
Configuration for training a sparse autoencoder on a language model.

Args:
architecture (str): The architecture to use, either "standard", "gated", or "jumprelu".
architecture (str): The architecture to use, either "standard", "gated", "topk", or "jumprelu".
model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
@@ -130,15 +130,15 @@ class LanguageModelSAERunnerConfig:
)

# SAE Parameters
architecture: Literal["standard", "gated", "jumprelu"] = "standard"
architecture: Literal["standard", "gated", "jumprelu", "topk"] = "standard"
d_in: int = 512
d_sae: Optional[int] = None
b_dec_init_method: str = "geometric_median"
expansion_factor: Optional[int] = (
None # defaults to 4 if d_sae and expansion_factor is None
)
activation_fn: str = "relu" # relu, tanh-relu, topk
activation_fn_kwargs: dict[str, Any] = field(default_factory=dict) # for topk
activation_fn: str = None # relu, tanh-relu, topk. Default is relu. # type: ignore
activation_fn_kwargs: dict[str, Any] = None # for topk # type: ignore
normalize_sae_decoder: bool = True
noise_scale: float = 0.0
from_pretrained_path: Optional[str] = None
@@ -164,6 +164,8 @@
seed: int = 42
dtype: str = "float32" # type: ignore #
prepend_bos: bool = True

# JumpReLU Parameters
jumprelu_init_threshold: float = 0.001
jumprelu_bandwidth: float = 0.001

@@ -251,6 +253,22 @@ def __post_init__(self):
self.hook_head_index,
)

if self.activation_fn is None:
self.activation_fn = "topk" if self.architecture == "topk" else "relu"

if self.architecture == "topk" and self.activation_fn != "topk":
raise ValueError("If using topk architecture, activation_fn must be topk.")

if self.activation_fn_kwargs is None:
self.activation_fn_kwargs = (
{"k": 100} if self.activation_fn == "topk" else {}
)

if self.architecture == "topk" and self.activation_fn_kwargs.get("k") is None:
raise ValueError(
"activation_fn_kwargs.k must be provided for topk architecture."
)

if self.d_sae is not None and self.expansion_factor is not None:
raise ValueError("You can't set both d_sae and expansion_factor.")

5 changes: 3 additions & 2 deletions sae_lens/sae.py
@@ -38,7 +38,7 @@
@dataclass
class SAEConfig:
# architecture details
architecture: Literal["standard", "gated", "jumprelu"]
architecture: Literal["standard", "gated", "jumprelu", "topk"]

# forward pass details.
d_in: int
@@ -157,7 +157,7 @@ def __init__(
self.device = torch.device(cfg.device)
self.use_error_term = use_error_term

if self.cfg.architecture == "standard":
if self.cfg.architecture == "standard" or self.cfg.architecture == "topk":
self.initialize_weights_basic()
self.encode = self.encode_standard
elif self.cfg.architecture == "gated":
@@ -681,6 +681,7 @@ def __init__(
self.k = k
self.postact_fn = postact_fn

# TODO: Use a fused kernel to speed up topk decoding like https://github.com/EleutherAI/sae/blob/main/sae/kernels.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
topk = torch.topk(x, k=self.k, dim=-1)
values = self.postact_fn(topk.values)
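For illustration, here is a minimal, self-contained sketch of a TopK activation along the lines of the `forward` above (the class name and the scatter-based completion are assumptions for illustration, not necessarily the PR's exact code):

```python
import torch
from torch import nn


class TopKActivation(nn.Module):
    """Sketch of a TopK activation: keep the k largest pre-activations per row, zero the rest."""

    def __init__(self, k: int, postact_fn: nn.Module = nn.ReLU()):
        super().__init__()
        self.k = k
        self.postact_fn = postact_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Select the k largest pre-activations along the feature dimension
        topk = torch.topk(x, k=self.k, dim=-1)
        values = self.postact_fn(topk.values)
        # Scatter the surviving activations back into a dense zero tensor
        result = torch.zeros_like(x)
        result.scatter_(-1, topk.indices, values)
        return result
```

As the TODO above notes, a fused kernel could avoid materialising the dense zero tensor; the scatter approach is the straightforward baseline.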
9 changes: 3 additions & 6 deletions sae_lens/sae_training_runner.py
@@ -113,8 +113,7 @@
sae = self.run_trainer_with_interruption_handling(trainer)

if self.cfg.log_to_wandb:
# remove this type ignore comment after https://github.com/wandb/wandb/issues/8248 is resolved
wandb.finish() # type: ignore
wandb.finish()

return sae

@@ -228,17 +227,15 @@
model_artifact.add_file(f"{path}/{SAE_WEIGHTS_PATH}")
model_artifact.add_file(f"{path}/{SAE_CFG_PATH}")

# remove this type ignore comment after https://github.com/wandb/wandb/issues/8248 is resolved
wandb.log_artifact(model_artifact, aliases=wandb_aliases) # type: ignore
wandb.log_artifact(model_artifact, aliases=wandb_aliases)

sparsity_artifact = wandb.Artifact(
f"{sae_name}_log_feature_sparsity",
type="log_feature_sparsity",
metadata=dict(trainer.cfg.__dict__),
)
sparsity_artifact.add_file(log_feature_sparsity_path)
# remove this type ignore comment after https://github.com/wandb/wandb/issues/8248 is resolved
wandb.log_artifact(sparsity_artifact) # type: ignore
wandb.log_artifact(sparsity_artifact)

return checkpoint_path

4 changes: 2 additions & 2 deletions sae_lens/training/sae_trainer.py
@@ -114,13 +114,13 @@ def __init__(
num_cycles=cfg.n_restart_cycles,
)
self.l1_scheduler = L1Scheduler(
l1_warm_up_steps=cfg.l1_warm_up_steps, # type: ignore
l1_warm_up_steps=cfg.l1_warm_up_steps,
total_steps=cfg.total_training_steps,
final_l1_coefficient=cfg.l1_coefficient,
)

# Setup autocast if using
self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)
self.scaler = torch.amp.GradScaler(device="cuda", enabled=self.cfg.autocast)

if self.cfg.autocast:
self.autocast_if_enabled = torch.autocast(
74 changes: 68 additions & 6 deletions sae_lens/training/training_sae.py
@@ -91,6 +91,7 @@
sae_in: torch.Tensor
sae_out: torch.Tensor
feature_acts: torch.Tensor
hidden_pre: torch.Tensor
loss: torch.Tensor # we need to call backwards on this
losses: dict[str, float | torch.Tensor]

@@ -237,7 +238,7 @@
super().__init__(base_sae_cfg)
self.cfg = cfg # type: ignore

if cfg.architecture == "standard":
if cfg.architecture == "standard" or cfg.architecture == "topk":
self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre
elif cfg.architecture == "gated":
self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_gated
@@ -281,11 +282,10 @@
return cls(TrainingSAEConfig.from_dict(config_dict))

def check_cfg_compatibility(self):
if self.cfg.architecture == "gated":
assert (
self.cfg.use_ghost_grads is False
), "Gated SAEs do not support ghost grads"
assert self.use_error_term is False, "Gated SAEs do not support error terms"
if self.cfg.architecture != "standard" and self.cfg.use_ghost_grads:
raise ValueError(f"{self.cfg.architecture} SAEs do not support ghost grads")

if self.cfg.architecture == "gated" and self.use_error_term:
raise ValueError("Gated SAEs do not support error terms")


def encode_standard(
self, x: Float[torch.Tensor, "... d_in"]
@@ -413,6 +413,15 @@
l0_loss = (current_l1_coefficient * l0).mean()
loss = mse_loss + l0_loss
losses["l0_loss"] = l0_loss
elif self.cfg.architecture == "topk":
topk_loss = self.calculate_topk_aux_loss(
sae_in=sae_in,
sae_out=sae_out,
hidden_pre=hidden_pre,
dead_neuron_mask=dead_neuron_mask,
)
losses["auxiliary_reconstruction_loss"] = topk_loss
loss = mse_loss + topk_loss
else:
# default SAE sparsity loss
weighted_feature_acts = feature_acts
@@ -446,10 +455,47 @@
sae_in=sae_in,
sae_out=sae_out,
feature_acts=feature_acts,
hidden_pre=hidden_pre,
loss=loss,
losses=losses,
)

def calculate_topk_aux_loss(
self,
sae_in: torch.Tensor,
sae_out: torch.Tensor,
hidden_pre: torch.Tensor,
dead_neuron_mask: torch.Tensor | None,
) -> torch.Tensor:
# Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
# NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
if (
dead_neuron_mask is not None
and (num_dead := int(dead_neuron_mask.sum())) > 0
):
residual = sae_in - sae_out

# Heuristic from Appendix B.1 in the paper
k_aux = hidden_pre.shape[-1] // 2

# Reduce the scale of the loss if there are a small number of dead latents
scale = min(num_dead / k_aux, 1.0)
k_aux = min(k_aux, num_dead)

auxk_acts = _calculate_topk_aux_acts(
k_aux=k_aux,
hidden_pre=hidden_pre,
dead_neuron_mask=dead_neuron_mask,
)

# Encourage the top ~50% of dead latents to predict the residual of the
# top k living latents
recons = self.decode(auxk_acts)
auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
return scale * auxk_loss
else:
return sae_out.new_tensor(0.0)

def calculate_ghost_grad_loss(
self,
x: torch.Tensor,
@@ -644,3 +690,19 @@
self.W_dec.data,
"d_sae, d_sae d_in -> d_sae d_in",
)


def _calculate_topk_aux_acts(
k_aux: int,
hidden_pre: torch.Tensor,
dead_neuron_mask: torch.Tensor,
) -> torch.Tensor:
# Don't include living latents in this loss
auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
# Top-k dead latents
auxk_topk = auxk_latents.topk(k_aux, sorted=False)
# Set the activations to zero for all but the top k_aux dead latents
auxk_acts = torch.zeros_like(hidden_pre)
auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
return auxk_acts
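As a quick illustration of the helper above, here is a toy example (the tensor values are made up for demonstration):

```python
import torch

hidden_pre = torch.tensor([[3.0, -1.0, 2.0, 0.5]])
dead_neuron_mask = torch.tensor([False, True, True, True])  # latent 0 is "alive"

# Only dead latents compete for the k_aux slots; living latents are masked to -inf
auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
auxk_topk = auxk_latents.topk(2, sorted=False)
auxk_acts = torch.zeros_like(hidden_pre)
auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
print(auxk_acts)  # tensor([[0.0000, 0.0000, 2.0000, 0.5000]])
```

Only the top `k_aux` dead latents retain their pre-activations; everything else, including the living latent, is zeroed before being decoded for the auxiliary loss.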
23 changes: 23 additions & 0 deletions tests/unit/training/test_config.py
@@ -138,3 +138,26 @@ def test_cache_activations_runner_config_seqpos(
seqpos_slice=seqpos_slice,
context_size=context_size,
)


def test_topk_architecture_requires_topk_activation():
with pytest.raises(
ValueError, match="If using topk architecture, activation_fn must be topk."
):
LanguageModelSAERunnerConfig(architecture="topk", activation_fn="relu")


def test_topk_architecture_requires_k_parameter():
with pytest.raises(
ValueError,
match="activation_fn_kwargs.k must be provided for topk architecture.",
):
LanguageModelSAERunnerConfig(
architecture="topk", activation_fn="topk", activation_fn_kwargs={}
)


def test_topk_architecture_sets_topk_defaults():
cfg = LanguageModelSAERunnerConfig(architecture="topk")
assert cfg.activation_fn == "topk"
assert cfg.activation_fn_kwargs == {"k": 100}
1 change: 1 addition & 0 deletions tests/unit/training/test_sae_trainer.py
@@ -168,6 +168,7 @@ def test_build_train_step_log_dict(trainer: SAETrainer) -> None:
sae_in=torch.tensor([[-1, 0], [0, 2], [1, 1]]).float(),
sae_out=torch.tensor([[0, 0], [0, 2], [0.5, 1]]).float(),
feature_acts=torch.tensor([[0, 0, 0, 1], [1, 0, 0, 1], [1, 0, 1, 1]]).float(),
hidden_pre=torch.tensor([[-1, 0, 0, 1], [1, -1, 0, 1], [1, -1, 1, 1]]).float(),
loss=torch.tensor(0.5),
losses={
"mse_loss": 0.25,