From 7047f877979952836e6778827248918818716b96 Mon Sep 17 00:00:00 2001 From: Oliver De Candido <40824443+decandido@users.noreply.github.com> Date: Tue, 15 Oct 2024 22:11:57 +0200 Subject: [PATCH] feat: support othellogpt in SAELens (#317) * support seqpos slicing * add basic tests, ensure it's in the SAE config * format * fix tests * fix tests 2 * fix: Changing the activations store to handle context sizes smaller than dataset lengths for tokenized datasets. * fix: Found bug which allowed for negative context lengths. Removed the bug * Update pytest to test new logic for context size of tokenized dataset * Reformat code to pass CI tests * Add warning for when context_size is smaller than the dataset context_size * feat: adding support for start and end position offsets for token sequences * Add start_pos_offset and end_pos_offset to the SAERunnerConfig * Add tests for start_pos_offset and end_pos_offset in the LanguageModelSAERunnerConfig * feat: start and end position offset support for SAELens. * Add test for CacheActivationsRunnerConfig with start and end pos offset * Test cache activation runner wtih valid start and end pos offset * feat: Enabling loading of start and end pos offset from saes. Adding tests for this * fix: Renaming variables and a test * adds test for position offests for saes * reformats files with black * Add start and end pos offset to the base sae dict * fix test for sae training runner config with position offsets * add a benchmark test to train an SAE on OthelloGPT * Remove double import from typing * change dead_feature_window to int * remove print statements from test file * Rebase on seqpos tuple implementation and remove start/end pos offset * Reword docstring for seqpos to be clearer. * Added script to train an SAE on othelloGPT --------- Co-authored-by: callummcdougall Co-authored-by: jbloomAus Co-authored-by: liuman --- sae_lens/config.py | 38 +++++++ sae_lens/sae.py | 6 + sae_lens/training/activations_store.py | 41 +++---- sae_lens/training/training_sae.py | 13 +++ ...raining_a_sparse_autoencoder_othelloGPT.py | 107 ++++++++++++++++++ .../test_language_model_sae_runner.py | 78 +++++++++++++ tests/unit/training/test_activations_store.py | 31 ++++- .../training/test_cache_activations_runner.py | 80 ++++++++++++- tests/unit/training/test_config.py | 49 +++++++- tests/unit/training/test_sae_basic.py | 17 +++ 10 files changed, 438 insertions(+), 22 deletions(-) create mode 100644 scripts/training_a_sparse_autoencoder_othelloGPT.py diff --git a/sae_lens/config.py b/sae_lens/config.py index 3fc4ac75..5bc2a1a7 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -60,6 +60,7 @@ class LanguageModelSAERunnerConfig: store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating actiations. train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop. normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output). + seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0. 
device (str): The device to use. Usually cuda. act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram. seed (int): The seed to use. @@ -153,6 +154,7 @@ class LanguageModelSAERunnerConfig: normalize_activations: str = ( "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update) ) + seqpos_slice: tuple[int | None, ...] = (None,) # Misc device: str = "cpu" @@ -355,6 +357,13 @@ def __post_init__(self): if self.use_ghost_grads: print("Using Ghost Grads.") + if self.context_size < 0: + raise ValueError( + f"The provided context_size of {self.context_size} is negative. Expecting a positive context_size." + ) + + _validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size) + @property def total_training_tokens(self) -> int: return self.training_tokens + self.finetuning_tokens @@ -386,6 +395,7 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]: "normalize_activations": self.normalize_activations, "activation_fn_kwargs": self.activation_fn_kwargs, "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs, + "seqpos_slice": self.seqpos_slice, } def get_training_sae_cfg_dict(self) -> dict[str, Any]: @@ -427,6 +437,15 @@ def to_json(self, path: str) -> None: def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig": with open(path + "cfg.json", "r") as f: cfg = json.load(f) + + # Ensure seqpos_slice is a tuple + if "seqpos_slice" in cfg: + if isinstance(cfg["seqpos_slice"], list): + cfg["seqpos_slice"] = tuple(cfg["seqpos_slice"]) + elif not isinstance(cfg["seqpos_slice"], tuple): + cfg["seqpos_slice"] = (cfg["seqpos_slice"],) + return cls(**cfg) @@ -461,6 +480,7 @@ class CacheActivationsRunnerConfig: store_batch_size_prompts: int = 32 train_batch_size_tokens: int = 4096 normalize_activations: str = "none" # should always be none for activation caching + seqpos_slice: tuple[int | None, ...] = (None,) # Misc device: str = "cpu" @@ -491,6 +511,13 @@ def __post_init__(self): if self.act_store_device == "with_model": self.act_store_device = self.device + if self.context_size < 0: + raise ValueError( + f"The provided context_size of {self.context_size} is negative. Expecting a positive context_size." + ) + + _validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size) + @dataclass class ToyModelSAERunnerConfig: @@ -576,6 +603,17 @@ def _default_cached_activations_path( return path +def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None: + # Ensure that any explicitly provided step size is at least 1 + if len(seqpos) == 3 and seqpos[2] is not None: + step_size = seqpos[2] + assert ( + step_size >= 1 + ), f"Ensure the step_size {seqpos[2]=} for sequence slicing is positive." + # Ensure that the choice of seqpos doesn't end up with an empty list + assert len(list(range(context_size))[slice(*seqpos)]) > 0 + + @dataclass class PretokenizeRunnerConfig: tokenizer_name: str = "gpt2" diff --git a/sae_lens/sae.py b/sae_lens/sae.py index 125248ee..1610090c 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -63,6 +63,7 @@ class SAEConfig: activation_fn_kwargs: dict[str, Any] = field(default_factory=dict) neuronpedia_id: Optional[str] = None model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict) + seqpos_slice: tuple[int | None, ...]
= (None,) @classmethod def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig": @@ -82,6 +83,10 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig": for k, v in config_dict.items() if k in cls.__dataclass_fields__ # pylint: disable=no-member } + + if "seqpos_slice" in config_dict: + config_dict["seqpos_slice"] = tuple(config_dict["seqpos_slice"]) + return cls(**config_dict) # def __post_init__(self): @@ -109,6 +114,7 @@ def to_dict(self) -> dict[str, Any]: "normalize_activations": self.normalize_activations, "neuronpedia_id": self.neuronpedia_id, "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs, + "seqpos_slice": self.seqpos_slice, } diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 9f53c0eb..d4a2320d 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -88,6 +88,7 @@ def from_config( model_kwargs=cfg.model_kwargs, autocast_lm=cfg.autocast_lm, dataset_trust_remote_code=cfg.dataset_trust_remote_code, + seqpos_slice=cfg.seqpos_slice, ) @classmethod @@ -123,6 +124,7 @@ def from_sae( dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code, dtype=sae.cfg.dtype, device=torch.device(device), + seqpos_slice=sae.cfg.seqpos_slice, ) def __init__( @@ -147,6 +149,7 @@ def __init__( model_kwargs: dict[str, Any] | None = None, autocast_lm: bool = False, dataset_trust_remote_code: bool | None = None, + seqpos_slice: tuple[int | None, ...] = (None,), ): self.model = model if model_kwargs is None: @@ -188,6 +191,7 @@ def __init__( self.dtype = DTYPE_MAP[dtype] self.cached_activations_path = cached_activations_path self.autocast_lm = autocast_lm + self.seqpos_slice = seqpos_slice self.n_dataset_processed = 0 @@ -220,10 +224,6 @@ def __init__( f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. The context_size {ds_context_size} is expected to be larger than or equal to the provided context size {self.context_size}.""" ) - if self.context_size < 0: - raise ValueError( - f"The provided context_size is {self.context_size} is negative. Expecting positive context_size" - ) if self.context_size != ds_context_size: warnings.warn( f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. 
Some data will be discarded in this case.""", @@ -441,7 +441,7 @@ def get_activations(self, batch_tokens: torch.Tensor): autocast_if_enabled = contextlib.nullcontext() with autocast_if_enabled: - layerwise_activations = self.model.run_with_cache( + layerwise_activations_cache = self.model.run_with_cache( batch_tokens, names_filter=[self.hook_name], stop_at_layer=self.hook_layer + 1, @@ -449,29 +449,31 @@ def get_activations(self, batch_tokens: torch.Tensor): **self.model_kwargs, )[1] - n_batches, n_context = batch_tokens.shape + layerwise_activations = layerwise_activations_cache[self.hook_name][ + :, slice(*self.seqpos_slice) + ] + + n_batches, n_context = layerwise_activations.shape[:2] stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in)) if self.hook_head_index is not None: - stacked_activations[:, :, 0] = layerwise_activations[self.hook_name][ + stacked_activations[:, :, 0] = layerwise_activations[ :, :, self.hook_head_index ] - elif ( - layerwise_activations[self.hook_name].ndim > 3 - ): # if we have a head dimension + elif layerwise_activations.ndim > 3: # if we have a head dimension try: - stacked_activations[:, :, 0] = layerwise_activations[ - self.hook_name - ].view(n_batches, n_context, -1) + stacked_activations[:, :, 0] = layerwise_activations.view( + n_batches, n_context, -1 + ) except RuntimeError as e: print(f"Error during view operation: {e}") print("Attempting to use reshape instead...") - stacked_activations[:, :, 0] = layerwise_activations[ - self.hook_name - ].reshape(n_batches, n_context, -1) + stacked_activations[:, :, 0] = layerwise_activations.reshape( + n_batches, n_context, -1 + ) else: - stacked_activations[:, :, 0] = layerwise_activations[self.hook_name] + stacked_activations[:, :, 0] = layerwise_activations return stacked_activations @@ -487,6 +489,7 @@ def get_buffer( If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react. 
""" context_size = self.context_size + training_context_size = len(range(context_size)[slice(*self.seqpos_slice)]) batch_size = self.store_batch_size_prompts d_in = self.d_in total_size = batch_size * n_batches_in_buffer @@ -494,7 +497,7 @@ def get_buffer( if self.cached_activations_path is not None: # Load the activations from disk - buffer_size = total_size * context_size + buffer_size = total_size * training_context_size # Initialize an empty tensor with an additional dimension for layers new_buffer = torch.zeros( (buffer_size, num_layers, d_in), @@ -548,7 +551,7 @@ def get_buffer( refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size) # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers new_buffer = torch.zeros( - (total_size, context_size, num_layers, d_in), + (total_size, training_context_size, num_layers, d_in), dtype=self.dtype, # type: ignore device=self.device, ) diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py index 217e7252..b7925d4e 100644 --- a/sae_lens/training/training_sae.py +++ b/sae_lens/training/training_sae.py @@ -75,6 +75,7 @@ def from_sae_runner_config( context_size=cfg.context_size, dataset_path=cfg.dataset_path, prepend_bos=cfg.prepend_bos, + seqpos_slice=cfg.seqpos_slice, # Training cfg l1_coefficient=cfg.l1_coefficient, lp_norm=cfg.lp_norm, @@ -99,6 +100,18 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig": valid_config_dict = { key: val for key, val in config_dict.items() if key in valid_field_names } + + # ensure seqpos slice is tuple + # ensure that seqpos slices is a tuple + # Ensure seqpos_slice is a tuple + if "seqpos_slice" in valid_config_dict: + if isinstance(valid_config_dict["seqpos_slice"], list): + valid_config_dict["seqpos_slice"] = tuple( + valid_config_dict["seqpos_slice"] + ) + elif not isinstance(valid_config_dict["seqpos_slice"], tuple): + valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],) + return TrainingSAEConfig(**valid_config_dict) def to_dict(self) -> dict[str, Any]: diff --git a/scripts/training_a_sparse_autoencoder_othelloGPT.py b/scripts/training_a_sparse_autoencoder_othelloGPT.py new file mode 100644 index 00000000..ee48914d --- /dev/null +++ b/scripts/training_a_sparse_autoencoder_othelloGPT.py @@ -0,0 +1,107 @@ +import os + +import torch + +from sae_lens import ( + SAE, + HookedSAETransformer, + LanguageModelSAERunnerConfig, + SAETrainingRunner, + upload_saes_to_huggingface, +) + +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + device = "mps" +else: + device = "cpu" + +print("Using device:", device) +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +model_name = "othello-gpt" +model = HookedSAETransformer.from_pretrained(model_name) + +dataset_path = "taufeeque/othellogpt" +context_size = 59 + +layer = 5 +training_tokens = int(1e3) +train_batch_size_tokens = 2048 +n_steps = int(training_tokens / train_batch_size_tokens) + +print(LanguageModelSAERunnerConfig()) +runner_cfg = LanguageModelSAERunnerConfig( + # + # Data generation + model_name=model_name, + hook_name=f"blocks.{layer}.mlp.hook_post", + hook_layer=layer, + d_in=model.cfg.d_mlp, + dataset_path=dataset_path, + is_dataset_tokenized=True, + prepend_bos=False, + streaming=True, + train_batch_size_tokens=train_batch_size_tokens, + context_size=context_size, + seqpos_slice=(5, -5), + # + # SAE achitecture + architecture="gated", + expansion_factor=8, + b_dec_init_method="zeros", + 
apply_b_dec_to_input=True, + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + decoder_heuristic_init=True, + init_encoder_as_decoder_transpose=True, + # + # Activations store + n_batches_in_buffer=32, + store_batch_size_prompts=16, + training_tokens=training_tokens, + # + # Training hyperparameters (standard) + lr=2e-4, + adam_beta1=0.9, + adam_beta2=0.999, + lr_scheduler_name="constant", + lr_warm_up_steps=int(0.2 * n_steps), + lr_decay_steps=int(0.2 * n_steps), + # + # Training hyperparameters (SAE-specific) + l1_coefficient=5, + l1_warm_up_steps=int(0.2 * n_steps), + use_ghost_grads=False, + feature_sampling_window=1000, + dead_feature_window=500, + dead_feature_threshold=1e-5, + # + # Logging / evals + log_to_wandb=True, + wandb_project=f"othello_gpt_sae_{layer=}", + wandb_log_frequency=30, + eval_every_n_wandb_logs=10, + checkpoint_path="checkpoints", + # + # Misc. + device=str(device), + seed=42, + n_checkpoints=5, + dtype="float32", +) + +# t.set_grad_enabled(True) +runner = SAETrainingRunner(runner_cfg) +sae = runner.run() + +hf_repo_id = "callummcdougall/arena-demos-othellogpt" +sae_id = "blocks.5.mlp.hook_post-v1" + +upload_saes_to_huggingface({sae_id: sae}, hf_repo_id=hf_repo_id) + +othellogpt_sae = SAE.from_pretrained( + release=hf_repo_id, sae_id=sae_id, device=str(device) +)[0] diff --git a/tests/benchmark/test_language_model_sae_runner.py b/tests/benchmark/test_language_model_sae_runner.py index 29d99c9c..995af247 100644 --- a/tests/benchmark/test_language_model_sae_runner.py +++ b/tests/benchmark/test_language_model_sae_runner.py @@ -292,3 +292,81 @@ def test_language_model_sae_runner_top_k(): assert sae is not None # know whether or not this works by looking at the dashboard! + + +def test_language_model_sae_runner_othellogpt(): + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + # total_training_steps = 20_000 + total_training_steps = 500 + batch_size = 4096 + total_training_tokens = total_training_steps * batch_size + + lr_warm_up_steps = 0 + lr_decay_steps = 40_000 + l1_warmup_steps = 10_000 + + cfg = LanguageModelSAERunnerConfig( + # Data Generating Function (Model + Training Distribution) + model_name="othello-gpt", # othello-gpt model + hook_name="blocks.6.hook_resid_pre", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points) + hook_layer=6, # the layer we hook into. + d_in=512, # the width of the residual stream at this hook. + dataset_path="taufeeque/othellogpt", # a tokenized dataset of OthelloGPT game sequences on Huggingface. + is_dataset_tokenized=True, + streaming=True, # we could pre-download the token dataset if it was small. + # SAE Parameters + mse_loss_normalization=None, # We won't normalize the mse loss. + expansion_factor=16, # the width of the SAE. Larger will result in better stats but slower training. + b_dec_init_method="geometric_median", # The geometric median can be used to initialize the decoder bias (b_dec). + apply_b_dec_to_input=False, # We won't subtract the decoder bias from the input. + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + decoder_heuristic_init=True, + init_encoder_as_decoder_transpose=True, + normalize_activations="expected_average_only_in", + # Training Parameters + lr=0.00003, # the lower the better; we go fairly high here to speed up the test. 
+ adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.) + adam_beta2=0.999, + lr_scheduler_name="constant", # constant learning rate with warmup. Could be better schedules out there. + lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially. + lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting. + l1_coefficient=0.001, # will control how sparse the feature activations are + l1_warm_up_steps=l1_warmup_steps, # this can help avoid too many dead features initially. + lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1) + train_batch_size_tokens=batch_size, + context_size=59, # will control the length of the prompts we feed to the model. Larger is better but slower, so for this test we use a short one. + seqpos_slice=(5, -5), + # Activation Store Parameters + n_batches_in_buffer=32, # controls how many activations we store / shuffle. + training_tokens=total_training_tokens, # the total number of tokens we train on; more gives better stats but is slower. + store_batch_size_prompts=32, + # Resampling protocol + use_ghost_grads=False, # we don't use ghost grads anymore. + feature_sampling_window=500, # this controls our reporting of feature sparsity stats + dead_feature_window=1000000, # would affect resampling or ghost grads if we were using it. + dead_feature_threshold=1e-4, # would affect resampling or ghost grads if we were using it. + # WANDB + log_to_wandb=False, # always use wandb unless you are just testing code. + wandb_project="benchmark", + wandb_log_frequency=100, + eval_every_n_wandb_logs=20, + # Misc + device=device, + seed=42, + n_checkpoints=0, + checkpoint_path="checkpoints", + dtype="torch.float32", + ) + + # this will take a while to run. + sae = SAETrainingRunner(cfg).run() + + assert sae is not None + # know whether or not this works by looking at the dashboard! 
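To make the effect of seqpos_slice on the shapes above concrete, here is a minimal standalone sketch (not part of the diff): the context_size, seqpos_slice, and d_in values mirror the OthelloGPT settings above, while the batch and layer counts are illustrative. It mirrors how get_buffer computes the reduced training context length and how get_activations trims the cached activations along the sequence dimension:

import torch

context_size = 59       # othello-gpt context length used above
seqpos_slice = (5, -5)  # drop the first and last 5 positions
d_in, num_layers, n_batches = 512, 1, 2

# get_buffer: number of positions actually used for training per prompt
training_context_size = len(range(context_size)[slice(*seqpos_slice)])
assert training_context_size == 49  # positions 5..53 are kept

# get_activations: slice the hook cache along the sequence dimension
layerwise_activations = torch.randn(n_batches, context_size, d_in)  # stand-in for the cached hook output
sliced = layerwise_activations[:, slice(*seqpos_slice)]
assert sliced.shape == (n_batches, training_context_size, d_in)

# the activation buffer is then allocated with the reduced context length
new_buffer = torch.zeros(n_batches, training_context_size, num_layers, d_in)
print(new_buffer.shape)  # torch.Size([2, 49, 1, 512])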
diff --git a/tests/unit/training/test_activations_store.py b/tests/unit/training/test_activations_store.py index 7f14974d..6c845642 100644 --- a/tests/unit/training/test_activations_store.py +++ b/tests/unit/training/test_activations_store.py @@ -353,7 +353,7 @@ def test_activations_store___iterate_tokenized_sequences__yields_sequences_of_co # length of the dataset @pytest.mark.parametrize( "context_size, expected_error", - [(-1, ValueError), (5, RuntimeWarning), (10, None), (15, ValueError)], + [(5, RuntimeWarning), (10, None), (15, ValueError)], ) def test_activations_store__errors_on_context_size_mismatch( ts_model: HookedTransformer, context_size: int, expected_error: Optional[ValueError] @@ -389,6 +389,12 @@ def test_activations_store__errors_on_context_size_mismatch( ActivationsStore.from_config(ts_model, cfg, override_dataset=tokenized_dataset) +def test_activations_store__errors_on_negative_context_size(): + with pytest.raises(ValueError): + # We should raise an error when the context_size is negative + build_sae_cfg(prepend_bos=True, context_size=-1) + + def test_activations_store___iterate_tokenized_sequences__yields_identical_results_with_and_without_pretokenizing( ts_model: HookedTransformer, ): @@ -478,3 +484,26 @@ def test_validate_pretokenized_dataset_tokenizer_does_nothing_if_the_dataset_pat model_tokenizer = ts_model.tokenizer assert model_tokenizer is not None validate_pretokenized_dataset_tokenizer(ds_path, model_tokenizer) + + +def test_activations_store_respects_position_offsets(ts_model: HookedTransformer): + cfg = build_sae_cfg( + context_size=10, + seqpos_slice=(2, 8), # Only consider positions 2 to 7 (inclusive) + ) + dataset = Dataset.from_list( + [ + {"text": "This is a test sentence for slicing."}, + ] + * 100 + ) + + activation_store = ActivationsStore.from_config( + ts_model, cfg, override_dataset=dataset + ) + + batch = activation_store.get_batch_tokens(1) + activations = activation_store.get_activations(batch) + + assert batch.shape == (1, 10) # Full context size + assert activations.shape == (1, 6, 1, cfg.d_in) # Only 6 positions (2 to 7) diff --git a/tests/unit/training/test_cache_activations_runner.py b/tests/unit/training/test_cache_activations_runner.py index 7358fa5c..f10ee1c9 100644 --- a/tests/unit/training/test_cache_activations_runner.py +++ b/tests/unit/training/test_cache_activations_runner.py @@ -104,7 +104,6 @@ def test_load_cached_activations(): tokens_in_buffer = n_batches_in_buffer * store_batch_size * context_size total_training_tokens = n_buffers * tokens_in_buffer - print(f"Total Training Tokens: {total_training_tokens}") # better if we can look at the files cached_activations_fixture_path = os.path.join( @@ -216,3 +215,82 @@ def W_E(self) -> torch.Tensor: # no errors are ever raised if we do not ask for raise_at_epoch_end for _ in range(32): _ = activations_store.get_batch_tokens(batch_size, raise_at_epoch_end=False) + + +# The way to run this with this command: +# poetry run py.test tests/unit/test_cache_activations_runner.py --profile-svg -s +def test_cache_activations_runner_with_valid_seqpos(tmp_path: Path): + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + # total_training_steps = 20_000 + context_size = 1024 + seqpos_slice = (12, -12) + training_context_size = len(range(context_size)[slice(*seqpos_slice)]) + n_batches_in_buffer = 32 + store_batch_size = 1 + n_buffers = 3 + + tokens_in_buffer = n_batches_in_buffer * store_batch_size * 
training_context_size + total_training_tokens = n_buffers * tokens_in_buffer + + # better if we can look at the files (change tmp_path to a real path to look at the files) + # tmp_path = os.path.join(os.path.dirname(__file__), "tmp") + # tmp_path = Path("/Volumes/T7 Shield/activations/gelu_1l") + # if os.path.exists(tmp_path): + # shutil.rmtree(tmp_path) + + cfg = CacheActivationsRunnerConfig( + new_cached_activations_path=str(tmp_path), + # Pick a tiny model to make this easier. + model_name="gelu-1l", + ## MLP Layer 0 ## + hook_name="blocks.0.hook_mlp_out", + hook_layer=0, + d_in=512, + dataset_path="NeelNanda/c4-tokenized-2b", + context_size=context_size, # Speed things up. + is_dataset_tokenized=True, + prepend_bos=True, # I used to train GPT2 SAEs with a prepended-bos but no longer think we should do this. + training_tokens=total_training_tokens, # For initial testing I think this is a good number. + train_batch_size_tokens=4096, + # Test the sequence slicing + seqpos_slice=seqpos_slice, + # Loss Function + ## Reconstruction Coefficient. + # Buffer details won't matter in we cache / shuffle our activations ahead of time. + n_batches_in_buffer=n_batches_in_buffer, + store_batch_size_prompts=store_batch_size, + normalize_activations="none", + # + shuffle_every_n_buffers=2, + n_shuffles_with_last_section=1, + n_shuffles_in_entire_dir=1, + n_shuffles_final=1, + # Misc + device=device, + seed=42, + dtype="float16", + ) + + # look at the next cell to see some instruction for what to do while this is running. + CacheActivationsRunner(cfg).run() + + assert os.path.exists(tmp_path) + + # assert that there are n_buffer files in the directory. + assert len(os.listdir(tmp_path)) == n_buffers + + for _, buffer_file in enumerate(os.listdir(tmp_path)): + path_to_file = Path(tmp_path) / buffer_file + with safe_open(path_to_file, framework="pt", device=str(device)) as f: # type: ignore + buffer = f.get_tensor("activations") + assert buffer.shape == ( + tokens_in_buffer, + 1, + cfg.d_in, + ) diff --git a/tests/unit/training/test_config.py b/tests/unit/training/test_config.py index ca0f154a..6b1a93d9 100644 --- a/tests/unit/training/test_config.py +++ b/tests/unit/training/test_config.py @@ -1,7 +1,9 @@ +from typing import Optional + import pytest from sae_lens import __version__ -from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.config import CacheActivationsRunnerConfig, LanguageModelSAERunnerConfig TINYSTORIES_MODEL = "tiny-stories-1M" TINYSTORIES_DATASET = "roneneldan/TinyStories" @@ -67,6 +69,7 @@ def test_sae_training_runner_config_get_sae_base_parameters(): "model_from_pretrained_kwargs": { "center_writing_weights": False, }, + "seqpos_slice": (None,), } assert expected_config == cfg.get_base_sae_cfg_dict() @@ -91,3 +94,47 @@ def test_sae_training_runner_config_expansion_factor(): cfg = LanguageModelSAERunnerConfig() assert cfg.expansion_factor == 4 + + +test_cases_for_seqpos = [ + ((None, 10, -1), AssertionError), + ((None, 10, 0), AssertionError), + ((5, 5, None), AssertionError), + ((6, 3, None), AssertionError), +] + + +@pytest.mark.parametrize("seqpos_slice, expected_error", test_cases_for_seqpos) +def test_sae_training_runner_config_seqpos( + seqpos_slice: tuple[int, int], expected_error: Optional[AssertionError] +): + context_size = 10 + if expected_error is AssertionError: + with pytest.raises(expected_error): + LanguageModelSAERunnerConfig( + seqpos_slice=seqpos_slice, + context_size=context_size, + ) + else: + LanguageModelSAERunnerConfig( + 
seqpos_slice=seqpos_slice, + context_size=context_size, + ) + + +@pytest.mark.parametrize("seqpos_slice, expected_error", test_cases_for_seqpos) +def test_cache_activations_runner_config_seqpos( + seqpos_slice: tuple[int, int], expected_error: Optional[AssertionError] +): + context_size = 10 + if expected_error is AssertionError: + with pytest.raises(expected_error): + CacheActivationsRunnerConfig( + seqpos_slice=seqpos_slice, + context_size=context_size, + ) + else: + CacheActivationsRunnerConfig( + seqpos_slice=seqpos_slice, + context_size=context_size, + ) diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py index 55428bb9..530fca2c 100644 --- a/tests/unit/training/test_sae_basic.py +++ b/tests/unit/training/test_sae_basic.py @@ -225,6 +225,23 @@ def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None: assert torch.allclose(sae_out_1, sae_out_2) +def test_sae_seqpos(tmp_path: Path) -> None: + cfg = build_sae_cfg( + seqpos_slice=(1, 3), + device="cpu", + ) + model_path = str(tmp_path) + sae = SAE.from_dict(cfg.get_base_sae_cfg_dict()) + + assert sae.cfg.seqpos_slice == (1, 3) + + sae.save_model(model_path) + + sae_loaded = SAE.load_from_pretrained(model_path, device="cpu") + + assert sae_loaded.cfg.seqpos_slice == (1, 3) + + # TODO: Handle scaling factor in saeBase # def test_sae_save_and_load_from_pretrained_lacks_scaling_factor( # tmp_path: Path,
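Finally, to make the validation and serialisation rules exercised by the new tests concrete, here is a small standalone sketch (not part of the diff; demo_validate_seqpos is a hypothetical helper that re-implements the checks locally rather than importing sae_lens). A seqpos_slice must describe a non-empty forward slice of range(context_size), and because JSON has no tuple type, values loaded from disk arrive as lists and are coerced back to tuples, as from_json and from_dict do above:

import json


def demo_validate_seqpos(seqpos: tuple, context_size: int) -> None:
    # An explicitly provided step size must be positive (step_size > 0, per the config docstring).
    if len(seqpos) == 3 and seqpos[2] is not None:
        assert seqpos[2] >= 1, f"step_size must be positive, got {seqpos[2]}"
    # The slice must select at least one position of the context.
    assert len(range(context_size)[slice(*seqpos)]) > 0, f"{seqpos} selects no positions"


demo_validate_seqpos((5, -5), context_size=59)  # OK: keeps positions 5..53
for bad in [(None, 10, -1), (None, 10, 0), (5, 5, None), (6, 3, None)]:
    try:
        demo_validate_seqpos(bad, context_size=10)
    except AssertionError as err:
        print(f"{bad}: rejected ({err})")

# JSON round-trip: tuples are written as lists, so they are coerced back on load.
cfg = {"seqpos_slice": (5, -5)}
loaded = json.loads(json.dumps(cfg))
if isinstance(loaded["seqpos_slice"], list):
    loaded["seqpos_slice"] = tuple(loaded["seqpos_slice"])
assert loaded["seqpos_slice"] == (5, -5)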