diff --git a/docs/training_saes.md b/docs/training_saes.md
index 7413a52e..034bdeda 100644
--- a/docs/training_saes.md
+++ b/docs/training_saes.md
@@ -93,6 +93,47 @@ sparse_autoencoder = SAETrainingRunner(cfg).run()
 
 As you can see, the training setup provides a large number of options to explore. The full list of options can be found in the [LanguageModelSAERunnerConfig][sae_lens.LanguageModelSAERunnerConfig] class.
 
+### Training TopK SAEs
+
+By default, SAELens trains SAEs using an L1 loss term with a ReLU activation. A popular alternative is the [TopK](https://arxiv.org/abs/2406.04093) architecture, which fixes the L0 (the number of active latents) of the SAE using a TopK activation function. To train a TopK SAE, set the `architecture` parameter to `"topk"` in the config. You can set the `k` parameter via `activation_fn_kwargs`. If not set, the default is `k=100`. The TopK architecture ignores the `l1_coefficient` parameter.
+
+```python
+cfg = LanguageModelSAERunnerConfig(
+    architecture="topk",
+    activation_fn_kwargs={"k": 100},
+    # ...
+)
+sparse_autoencoder = SAETrainingRunner(cfg).run()
+```
+
+### Training JumpReLU SAEs
+
+[JumpReLU SAEs](https://arxiv.org/abs/2407.14435) are the current state-of-the-art SAE architecture, but they are often trickier to train than other architectures. To train a JumpReLU SAE, set the `architecture` parameter to `"jumprelu"` in the config. JumpReLU SAEs use a sparsity penalty that is controlled by the `l1_coefficient` parameter. The name is technically a misnomer, since the JumpReLU sparsity penalty is not an L1 penalty, but we keep the parameter name for consistency with the L1 penalty used by the standard architecture. The JumpReLU architecture also has two additional parameters: `jumprelu_bandwidth` and `jumprelu_init_threshold`. Both of these are likely fine at their default values, but they may be worth experimenting with if JumpReLU training is too slow to converge.
+
+```python
+cfg = LanguageModelSAERunnerConfig(
+    architecture="jumprelu",
+    l1_coefficient=5.0,
+    jumprelu_bandwidth=0.001,
+    jumprelu_init_threshold=0.001,
+    # ...
+)
+sparse_autoencoder = SAETrainingRunner(cfg).run()
+```
+
+### Training Gated SAEs
+
+[Gated SAEs](https://arxiv.org/abs/2404.16014) are a precursor to JumpReLU SAEs, but they use a simpler training procedure that should make them easier to train. To train a Gated SAE, set the `architecture` parameter to `"gated"` in the config. Gated SAEs use the `l1_coefficient` parameter to control the sparsity of the SAE, just like standard SAEs. If JumpReLU training is too slow to converge, it may be worth trying a Gated SAE instead.
+
+```python
+cfg = LanguageModelSAERunnerConfig(
+    architecture="gated",
+    l1_coefficient=5.0,
+    # ...
+)
+sparse_autoencoder = SAETrainingRunner(cfg).run()
+```
+
 ## CLI Runner
 
 The SAE training runner can also be run from the command line via the `sae_lens.sae_training_runner` module. This can be useful for quickly testing different hyperparameters or running training on a remote server. The command line interface is shown below. All options to the CLI are the same as the [LanguageModelSAERunnerConfig][sae_lens.LanguageModelSAERunnerConfig] with a `--` prefix. E.g., `--model_name` is the same as `model_name` in the config.
diff --git a/sae_lens/config.py b/sae_lens/config.py
index f3478533..8f7fe0f7 100644
--- a/sae_lens/config.py
+++ b/sae_lens/config.py
@@ -29,7 +29,7 @@ class LanguageModelSAERunnerConfig:
     Configuration for training a sparse autoencoder on a language model.
 
     Args:
-        architecture (str): The architecture to use, either "standard", "gated", or "jumprelu".
+        architecture (str): The architecture to use, either "standard", "gated", "topk", or "jumprelu".
         model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
         model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
        hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
@@ -130,15 +130,15 @@ class LanguageModelSAERunnerConfig:
     )
 
     # SAE Parameters
-    architecture: Literal["standard", "gated", "jumprelu"] = "standard"
+    architecture: Literal["standard", "gated", "jumprelu", "topk"] = "standard"
     d_in: int = 512
     d_sae: Optional[int] = None
     b_dec_init_method: str = "geometric_median"
     expansion_factor: Optional[int] = (
         None  # defaults to 4 if d_sae and expansion_factor is None
     )
-    activation_fn: str = "relu"  # relu, tanh-relu, topk
-    activation_fn_kwargs: dict[str, Any] = field(default_factory=dict)  # for topk
+    activation_fn: str = None  # relu, tanh-relu, topk. Default is relu.  # type: ignore
+    activation_fn_kwargs: dict[str, Any] = None  # for topk  # type: ignore
     normalize_sae_decoder: bool = True
     noise_scale: float = 0.0
     from_pretrained_path: Optional[str] = None
@@ -164,6 +164,8 @@ class LanguageModelSAERunnerConfig:
     seed: int = 42
     dtype: str = "float32"  # type: ignore #
     prepend_bos: bool = True
+
+    # JumpReLU Parameters
     jumprelu_init_threshold: float = 0.001
     jumprelu_bandwidth: float = 0.001
 
@@ -251,6 +253,22 @@ def __post_init__(self):
             self.hook_head_index,
         )
 
+        if self.activation_fn is None:
+            self.activation_fn = "topk" if self.architecture == "topk" else "relu"
+
+        if self.architecture == "topk" and self.activation_fn != "topk":
+            raise ValueError("If using topk architecture, activation_fn must be topk.")
+
+        if self.activation_fn_kwargs is None:
+            self.activation_fn_kwargs = (
+                {"k": 100} if self.activation_fn == "topk" else {}
+            )
+
+        if self.architecture == "topk" and self.activation_fn_kwargs.get("k") is None:
+            raise ValueError(
+                "activation_fn_kwargs.k must be provided for topk architecture."
+            )
+
         if self.d_sae is not None and self.expansion_factor is not None:
             raise ValueError("You can't set both d_sae and expansion_factor.")
 
diff --git a/sae_lens/sae.py b/sae_lens/sae.py
index 85d89ba5..cbdae95f 100644
--- a/sae_lens/sae.py
+++ b/sae_lens/sae.py
@@ -38,7 +38,7 @@
 @dataclass
 class SAEConfig:
     # architecture details
-    architecture: Literal["standard", "gated", "jumprelu"]
+    architecture: Literal["standard", "gated", "jumprelu", "topk"]
 
     # forward pass details.
     d_in: int
@@ -157,7 +157,7 @@ def __init__(
         self.device = torch.device(cfg.device)
         self.use_error_term = use_error_term
 
-        if self.cfg.architecture == "standard":
+        if self.cfg.architecture == "standard" or self.cfg.architecture == "topk":
             self.initialize_weights_basic()
             self.encode = self.encode_standard
         elif self.cfg.architecture == "gated":
@@ -681,6 +681,7 @@ def __init__(
         self.k = k
         self.postact_fn = postact_fn
 
+    # TODO: Use a fused kernel to speed up topk decoding like https://github.com/EleutherAI/sae/blob/main/sae/kernels.py
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         topk = torch.topk(x, k=self.k, dim=-1)
         values = self.postact_fn(topk.values)
diff --git a/sae_lens/sae_training_runner.py b/sae_lens/sae_training_runner.py
index a5637b23..b3a34cac 100644
--- a/sae_lens/sae_training_runner.py
+++ b/sae_lens/sae_training_runner.py
@@ -113,8 +113,7 @@ def run(self):
         sae = self.run_trainer_with_interruption_handling(trainer)
 
         if self.cfg.log_to_wandb:
-            # remove this type ignore comment after https://github.com/wandb/wandb/issues/8248 is resolved
-            wandb.finish()  # type: ignore
+            wandb.finish()
 
         return sae
 
@@ -228,8 +227,7 @@ def save_checkpoint(
 
             model_artifact.add_file(f"{path}/{SAE_WEIGHTS_PATH}")
             model_artifact.add_file(f"{path}/{SAE_CFG_PATH}")
-            # remove this type ignore comment after https://github.com/wandb/wandb/issues/8248 is resolved
-            wandb.log_artifact(model_artifact, aliases=wandb_aliases)  # type: ignore
+            wandb.log_artifact(model_artifact, aliases=wandb_aliases)
 
             sparsity_artifact = wandb.Artifact(
                 f"{sae_name}_log_feature_sparsity",
@@ -237,8 +235,7 @@ def save_checkpoint(
                 metadata=dict(trainer.cfg.__dict__),
             )
             sparsity_artifact.add_file(log_feature_sparsity_path)
-            # remove this type ignore comment after https://github.com/wandb/wandb/issues/8248 is resolved
-            wandb.log_artifact(sparsity_artifact)  # type: ignore
+            wandb.log_artifact(sparsity_artifact)
 
         return checkpoint_path
 
diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py
index fc59fb56..a5a17b10 100644
--- a/sae_lens/training/sae_trainer.py
+++ b/sae_lens/training/sae_trainer.py
@@ -114,13 +114,13 @@ def __init__(
             num_cycles=cfg.n_restart_cycles,
         )
 
         self.l1_scheduler = L1Scheduler(
-            l1_warm_up_steps=cfg.l1_warm_up_steps,  # type: ignore
+            l1_warm_up_steps=cfg.l1_warm_up_steps,
             total_steps=cfg.total_training_steps,
             final_l1_coefficient=cfg.l1_coefficient,
         )
 
         # Setup autocast if using
-        self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)
+        self.scaler = torch.amp.GradScaler(device="cuda", enabled=self.cfg.autocast)
         if self.cfg.autocast:
             self.autocast_if_enabled = torch.autocast(
diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py
index a2e7a7a3..b4ef2e28 100644
--- a/sae_lens/training/training_sae.py
+++ b/sae_lens/training/training_sae.py
@@ -91,6 +91,7 @@ class TrainStepOutput:
     sae_in: torch.Tensor
     sae_out: torch.Tensor
     feature_acts: torch.Tensor
+    hidden_pre: torch.Tensor
     loss: torch.Tensor  # we need to call backwards on this
     losses: dict[str, float | torch.Tensor]
 
@@ -237,7 +238,7 @@ def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
         super().__init__(base_sae_cfg)
         self.cfg = cfg  # type: ignore
 
-        if cfg.architecture == "standard":
+        if cfg.architecture == "standard" or cfg.architecture == "topk":
             self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre
         elif cfg.architecture == "gated":
             self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_gated
@@ -281,11 +282,10 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAE":
         return cls(TrainingSAEConfig.from_dict(config_dict))
 
     def check_cfg_compatibility(self):
-        if self.cfg.architecture == "gated":
-            assert (
-                self.cfg.use_ghost_grads is False
-            ), "Gated SAEs do not support ghost grads"
-            assert self.use_error_term is False, "Gated SAEs do not support error terms"
+        if self.cfg.architecture != "standard" and self.cfg.use_ghost_grads:
+            raise ValueError(f"{self.cfg.architecture} SAEs do not support ghost grads")
+        if self.cfg.architecture == "gated" and self.use_error_term:
+            raise ValueError("Gated SAEs do not support error terms")
 
     def encode_standard(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -413,6 +413,15 @@ def training_forward_pass(
             l0_loss = (current_l1_coefficient * l0).mean()
             loss = mse_loss + l0_loss
             losses["l0_loss"] = l0_loss
+        elif self.cfg.architecture == "topk":
+            topk_loss = self.calculate_topk_aux_loss(
+                sae_in=sae_in,
+                sae_out=sae_out,
+                hidden_pre=hidden_pre,
+                dead_neuron_mask=dead_neuron_mask,
+            )
+            losses["auxiliary_reconstruction_loss"] = topk_loss
+            loss = mse_loss + topk_loss
         else:
             # default SAE sparsity loss
             weighted_feature_acts = feature_acts
@@ -446,10 +455,47 @@ def training_forward_pass(
             sae_in=sae_in,
             sae_out=sae_out,
             feature_acts=feature_acts,
+            hidden_pre=hidden_pre,
             loss=loss,
             losses=losses,
         )
 
+    def calculate_topk_aux_loss(
+        self,
+        sae_in: torch.Tensor,
+        sae_out: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        dead_neuron_mask: torch.Tensor | None,
+    ) -> torch.Tensor:
+        # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
+        # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
+        if (
+            dead_neuron_mask is not None
+            and (num_dead := int(dead_neuron_mask.sum())) > 0
+        ):
+            residual = sae_in - sae_out
+
+            # Heuristic from Appendix B.1 in the paper
+            k_aux = hidden_pre.shape[-1] // 2
+
+            # Reduce the scale of the loss if there are a small number of dead latents
+            scale = min(num_dead / k_aux, 1.0)
+            k_aux = min(k_aux, num_dead)
+
+            auxk_acts = _calculate_topk_aux_acts(
+                k_aux=k_aux,
+                hidden_pre=hidden_pre,
+                dead_neuron_mask=dead_neuron_mask,
+            )
+
+            # Encourage the top ~50% of dead latents to predict the residual of the
+            # top k living latents
+            recons = self.decode(auxk_acts)
+            auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
+            return scale * auxk_loss
+        else:
+            return sae_out.new_tensor(0.0)
+
     def calculate_ghost_grad_loss(
         self,
         x: torch.Tensor,
@@ -644,3 +690,19 @@ def remove_gradient_parallel_to_decoder_directions(self):
             self.W_dec.data,
             "d_sae, d_sae d_in -> d_sae d_in",
         )
+
+
+def _calculate_topk_aux_acts(
+    k_aux: int,
+    hidden_pre: torch.Tensor,
+    dead_neuron_mask: torch.Tensor,
+) -> torch.Tensor:
+    # Don't include living latents in this loss
+    auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
+    # Top-k dead latents
+    auxk_topk = auxk_latents.topk(k_aux, sorted=False)
+    # Set the activations to zero for all but the top k_aux dead latents
+    auxk_acts = torch.zeros_like(hidden_pre)
+    auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
+    # auxk_acts now holds the original pre-activations of the selected dead latents and zeros everywhere else
+    return auxk_acts
diff --git a/tests/unit/training/test_config.py b/tests/unit/training/test_config.py
index 6b1a93d9..3b0bff2c 100644
--- a/tests/unit/training/test_config.py
+++ b/tests/unit/training/test_config.py
@@ -138,3 +138,26 @@ def test_cache_activations_runner_config_seqpos(
         seqpos_slice=seqpos_slice,
         context_size=context_size,
     )
+
+
+def test_topk_architecture_requires_topk_activation():
+    with pytest.raises(
+        ValueError, match="If using topk architecture, activation_fn must be topk."
+    ):
+        LanguageModelSAERunnerConfig(architecture="topk", activation_fn="relu")
+
+
+def test_topk_architecture_requires_k_parameter():
+    with pytest.raises(
+        ValueError,
+        match="activation_fn_kwargs.k must be provided for topk architecture.",
+    ):
+        LanguageModelSAERunnerConfig(
+            architecture="topk", activation_fn="topk", activation_fn_kwargs={}
+        )
+
+
+def test_topk_architecture_sets_topk_defaults():
+    cfg = LanguageModelSAERunnerConfig(architecture="topk")
+    assert cfg.activation_fn == "topk"
+    assert cfg.activation_fn_kwargs == {"k": 100}
diff --git a/tests/unit/training/test_sae_trainer.py b/tests/unit/training/test_sae_trainer.py
index 94bc906f..bf64adc4 100644
--- a/tests/unit/training/test_sae_trainer.py
+++ b/tests/unit/training/test_sae_trainer.py
@@ -168,6 +168,7 @@ def test_build_train_step_log_dict(trainer: SAETrainer) -> None:
         sae_in=torch.tensor([[-1, 0], [0, 2], [1, 1]]).float(),
         sae_out=torch.tensor([[0, 0], [0, 2], [0.5, 1]]).float(),
         feature_acts=torch.tensor([[0, 0, 0, 1], [1, 0, 0, 1], [1, 0, 1, 1]]).float(),
+        hidden_pre=torch.tensor([[-1, 0, 0, 1], [1, -1, 0, 1], [1, -1, 1, 1]]).float(),
         loss=torch.tensor(0.5),
         losses={
             "mse_loss": 0.25,
diff --git a/tests/unit/training/test_training_sae.py b/tests/unit/training/test_training_sae.py
index 16a84594..37cd1014 100644
--- a/tests/unit/training/test_training_sae.py
+++ b/tests/unit/training/test_training_sae.py
@@ -4,7 +4,11 @@
 import torch
 
 from sae_lens.sae import SAE
-from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig
+from sae_lens.training.training_sae import (
+    TrainingSAE,
+    TrainingSAEConfig,
+    _calculate_topk_aux_acts,
+)
 from tests.unit.helpers import build_sae_cfg
 
 
@@ -42,7 +46,142 @@ def test_TrainingSAE_training_forward_pass_can_scale_sparsity_penalty_by_decoder
     )
 
 
-@pytest.mark.parametrize("architecture", ["standard", "gated", "jumprelu"])
+def test_calculate_topk_aux_acts():
+    # Create test inputs
+    k_aux = 3
+    hidden_pre = torch.tensor(
+        [
+            [1.0, 2.0, -3.0, 4.0, -5.0, 6.0],
+            [-1.0, -2.0, 3.0, -4.0, 5.0, -6.0],
+            [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+            [-0.6, -0.5, -0.4, -0.3, -0.2, -0.1],
+        ]
+    )
+
+    # Create dead neuron mask where neurons 1,3,5 are dead
+    dead_neuron_mask = torch.tensor([False, True, False, True, False, True])
+
+    # Calculate expected result
+    # For each row, should select top k_aux=3 values from dead neurons (indices 1,3,5)
+    # and zero out all other values
+    expected = torch.zeros_like(hidden_pre)
+    expected[0, [1, 3, 5]] = torch.tensor([2.0, 4.0, 6.0])
+    expected[1, [1, 3, 5]] = torch.tensor([-2.0, -4.0, -6.0])
+    expected[2, [1, 3, 5]] = torch.tensor([0.2, 0.4, 0.6])
+    expected[3, [1, 3, 5]] = torch.tensor([-0.5, -0.3, -0.1])
+
+    result = _calculate_topk_aux_acts(k_aux, hidden_pre, dead_neuron_mask)
+
+    assert torch.allclose(result, expected)
+
+
+def test_calculate_topk_aux_acts_k_less_than_dead():
+    # Create test inputs with k_aux less than number of dead neurons
+    k_aux = 1  # Only select top 1 dead neuron
+    hidden_pre = torch.tensor(
+        [
+            [1.0, 2.0, -3.0, 4.0],  # 2 items in batch
+            [-1.0, -2.0, 3.0, -4.0],
+        ]
+    )
+
+    # Create dead neuron mask where neurons 1,3 are dead (2 dead neurons)
+    dead_neuron_mask = torch.tensor([False, True, False, True])
+
+    # Calculate expected result
+    # For each row, should select only top k_aux=1 value from dead neurons (indices 1,3)
+    # and zero out all other values
+    expected = torch.zeros_like(hidden_pre)
+    expected[0, 3] = 4.0  # Only highest value among dead neurons for first item
+    expected[1, 1] = -2.0  # Only highest value among dead neurons for second item
+
+    result = _calculate_topk_aux_acts(k_aux, hidden_pre, dead_neuron_mask)
+
+    assert torch.allclose(result, expected)
+
+
+def test_TrainingSAE_calculate_topk_aux_loss():
+    # Create a small test SAE with d_sae=4, d_in=3
+    cfg = build_sae_cfg(
+        d_in=3,
+        d_sae=4,
+        architecture="topk",
+        normalize_sae_decoder=False,
+    )
+
+    sae = TrainingSAE(TrainingSAEConfig.from_sae_runner_config(cfg))
+
+    # Set up test inputs
+    hidden_pre = torch.tensor(
+        [[1.0, -2.0, 3.0, -4.0], [1.0, 0.0, -3.0, -4.0]]  # batch size 2
+    )
+    sae.W_dec.data = 2 * torch.ones((4, 3))
+    sae.b_dec.data = torch.zeros(3)
+
+    sae_out = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+    sae_in = torch.tensor([[2.0, 1.0, 3.0], [5.0, 4.0, 6.0]])
+    # Mark neurons 1 and 3 as dead
+    dead_neuron_mask = torch.tensor([False, True, False, True])
+
+    # Calculate loss
+    loss = sae.calculate_topk_aux_loss(
+        sae_in=sae_in,
+        hidden_pre=hidden_pre,
+        sae_out=sae_out,
+        dead_neuron_mask=dead_neuron_mask,
+    )
+
+    # The loss should:
+    # 1. Select top k_aux=2 (half of d_sae) dead neurons
+    # 2. Decode their activations (should be 2x the sum of the activations of the dead neurons)
+    # thus, (-12, -12, -12), (-8, -8, -8)
+    # and the residual is (1, -1, 0), (1, -1, 0)
+    # Thus, squared errors are (169, 121, 144), (81, 49, 64)
+    # and the sums are (434, 194)
+    # and the mean of these is 314
+
+    assert loss == 314
+
+
+def test_TrainingSAE_forward_includes_topk_loss_with_topk_architecture():
+    cfg = build_sae_cfg(
+        d_in=3,
+        d_sae=4,
+        architecture="topk",
+        activation_fn_kwargs={"k": 2},
+        normalize_sae_decoder=False,
+    )
+    sae = TrainingSAE(TrainingSAEConfig.from_sae_runner_config(cfg))
+    x = torch.randn(32, 3)
+    train_step_output = sae.training_forward_pass(
+        sae_in=x,
+        current_l1_coefficient=2.0,
+        dead_neuron_mask=None,
+    )
+    assert "auxiliary_reconstruction_loss" in train_step_output.losses
+    assert train_step_output.losses["auxiliary_reconstruction_loss"] == 0.0
+
+
+def test_TrainingSAE_forward_includes_topk_loss_is_nonzero_if_dead_neurons_present():
+    cfg = build_sae_cfg(
+        d_in=3,
+        d_sae=4,
+        architecture="topk",
+        activation_fn_kwargs={"k": 2},
+        normalize_sae_decoder=False,
+    )
+    sae = TrainingSAE(TrainingSAEConfig.from_sae_runner_config(cfg))
+    x = torch.randn(32, 3)
+    train_step_output = sae.training_forward_pass(
+        sae_in=x,
+        current_l1_coefficient=2.0,
+        dead_neuron_mask=torch.tensor([False, True, False, True]),
+    )
+    assert "auxiliary_reconstruction_loss" in train_step_output.losses
+    assert train_step_output.losses["auxiliary_reconstruction_loss"] > 0.0
+
+
+@pytest.mark.parametrize("architecture", ["standard", "gated", "jumprelu", "topk"])
 def test_TrainingSAE_encode_returns_same_value_as_encode_with_hidden_pre(
     architecture: str,
 ):