feat: support othellogpt in SAELens (#317)
* support seqpos slicing

* add basic tests, ensure it's in the SAE config

* format

* fix tests

* fix tests 2

* fix: Changing the activations store to handle context sizes smaller than dataset lengths for tokenized datasets.

* fix: Found bug which allowed for negative context lengths. Removed the bug

* Update pytest to test new logic for context size of tokenized dataset

* Reformat code to pass CI tests

* Add warning for when context_size is smaller than the dataset context_size

* feat: adding support for start and end position offsets for token sequences

* Add start_pos_offset and end_pos_offset to the SAERunnerConfig

* Add tests for start_pos_offset and end_pos_offset in the LanguageModelSAERunnerConfig

* feat: start and end position offset support for SAELens.

* Add test for CacheActivationsRunnerConfig with start and end pos offset

* Test cache activation runner with valid start and end pos offset

* feat: Enable loading of start and end pos offsets from SAEs; add tests for this

* fix: Renaming variables and a test

* adds test for position offsets for SAEs

* reformats files with black

* Add start and end pos offset to the base sae dict

* fix test for sae training runner config with position offsets

* add a benchmark test to train an SAE on OthelloGPT

* Remove double import from typing

* change dead_feature_window to int

* remove print statements from test file

* Rebase on seqpos tuple implementation and remove start/end pos offset

* Reword docstring for seqpos to be clearer.

* Added script to train an SAE on othelloGPT

---------

Co-authored-by: callummcdougall <cal.s.mcdougall@gmail.com>
Co-authored-by: jbloomAus <jbloomaus@gmail.com>
Co-authored-by: liuman <zhenninghimme@gmail.com>
4 people authored Oct 15, 2024
1 parent 42ba557 commit 7047f87
Showing 10 changed files with 438 additions and 22 deletions.
38 changes: 38 additions & 0 deletions sae_lens/config.py
@@ -60,6 +60,7 @@ class LanguageModelSAERunnerConfig:
store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop.
normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).
seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]); e.g. for Othello we sometimes use (5, -5). Note that step_size, if provided, must be positive.
device (str): The device to use. Usually cuda.
act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
seed (int): The seed to use.
@@ -153,6 +154,7 @@ class LanguageModelSAERunnerConfig:
normalize_activations: str = (
"none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
)
seqpos_slice: tuple[int | None, ...] = (None,)

# Misc
device: str = "cpu"
@@ -355,6 +357,13 @@ def __post_init__(self):
if self.use_ghost_grads:
print("Using Ghost Grads.")

        if self.context_size < 0:
            raise ValueError(
                f"The provided context_size ({self.context_size}) is negative. Expecting a positive context_size."
            )

_validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size)

@property
def total_training_tokens(self) -> int:
return self.training_tokens + self.finetuning_tokens
@@ -386,6 +395,7 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]:
"normalize_activations": self.normalize_activations,
"activation_fn_kwargs": self.activation_fn_kwargs,
"model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
"seqpos_slice": self.seqpos_slice,
}

def get_training_sae_cfg_dict(self) -> dict[str, Any]:
@@ -427,6 +437,15 @@ def to_json(self, path: str) -> None:
def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
with open(path + "cfg.json", "r") as f:
cfg = json.load(f)

        # Ensure seqpos_slice is a tuple (JSON round-trips tuples as lists)
if "seqpos_slice" in cfg:
if isinstance(cfg["seqpos_slice"], list):
cfg["seqpos_slice"] = tuple(cfg["seqpos_slice"])
elif not isinstance(cfg["seqpos_slice"], tuple):
cfg["seqpos_slice"] = (cfg["seqpos_slice"],)

return cls(**cfg)


@@ -461,6 +480,7 @@ class CacheActivationsRunnerConfig:
store_batch_size_prompts: int = 32
train_batch_size_tokens: int = 4096
normalize_activations: str = "none" # should always be none for activation caching
seqpos_slice: tuple[int | None, ...] = (None,)

# Misc
device: str = "cpu"
@@ -491,6 +511,13 @@ def __post_init__(self):
if self.act_store_device == "with_model":
self.act_store_device = self.device

        if self.context_size < 0:
            raise ValueError(
                f"The provided context_size ({self.context_size}) is negative. Expecting a positive context_size."
            )

_validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size)


@dataclass
class ToyModelSAERunnerConfig:
@@ -576,6 +603,17 @@ def _default_cached_activations_path(
return path


def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None:
    # Ensure that the step size is positive
    if len(seqpos) == 3:
        step_size = seqpos[2] or 1
        assert (
            step_size > 0
        ), f"Ensure the step_size {seqpos[2]=} for sequence slicing is positive."
    # Ensure that the choice of seqpos doesn't end up with an empty sequence
    assert (
        len(range(context_size)[slice(*seqpos)]) > 0
    ), f"The combination of {seqpos=} and {context_size=} results in an empty sequence."


@dataclass
class PretokenizeRunnerConfig:
tokenizer_name: str = "gpt2"
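A minimal sketch of the seqpos_slice semantics validated above, using only the Python standard library (the variable names are illustrative, not library code):

    context_size = 59          # OthelloGPT games are 59 moves long
    seqpos_slice = (5, -5)     # keep positions 5 .. context_size - 6

    # The config stores a tuple; downstream code applies it as a slice.
    positions = range(context_size)[slice(*seqpos_slice)]
    print(list(positions)[:3], list(positions)[-3:])  # [5, 6, 7] [51, 52, 53]

    # The effective training context is the length of that slice, which is
    # exactly the quantity _validate_seqpos asserts is non-empty.
    training_context_size = len(positions)
    print(training_context_size)  # 49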
6 changes: 6 additions & 0 deletions sae_lens/sae.py
@@ -63,6 +63,7 @@ class SAEConfig:
activation_fn_kwargs: dict[str, Any] = field(default_factory=dict)
neuronpedia_id: Optional[str] = None
model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
seqpos_slice: tuple[int | None, ...] = (None,)

@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
@@ -82,6 +83,10 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
for k, v in config_dict.items()
if k in cls.__dataclass_fields__ # pylint: disable=no-member
}

if "seqpos_slice" in config_dict:
config_dict["seqpos_slice"] = tuple(config_dict["seqpos_slice"])

return cls(**config_dict)

# def __post_init__(self):
@@ -109,6 +114,7 @@ def to_dict(self) -> dict[str, Any]:
"normalize_activations": self.normalize_activations,
"neuronpedia_id": self.neuronpedia_id,
"model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
"seqpos_slice": self.seqpos_slice,
}


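For context on the list-to-tuple coercion in from_dict: JSON has no tuple type, so a config that round-trips through json.dumps/json.loads comes back with a list. A small standard-library illustration with hypothetical values:

    import json

    cfg = {"seqpos_slice": (5, -5)}
    restored = json.loads(json.dumps(cfg))
    print(restored["seqpos_slice"])   # [5, -5] -- a list, not a tuple

    # Loaders therefore coerce back to a tuple before building the dataclass.
    if isinstance(restored["seqpos_slice"], list):
        restored["seqpos_slice"] = tuple(restored["seqpos_slice"])
    print(restored["seqpos_slice"])   # (5, -5)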
41 changes: 22 additions & 19 deletions sae_lens/training/activations_store.py
@@ -88,6 +88,7 @@ def from_config(
model_kwargs=cfg.model_kwargs,
autocast_lm=cfg.autocast_lm,
dataset_trust_remote_code=cfg.dataset_trust_remote_code,
seqpos_slice=cfg.seqpos_slice,
)

@classmethod
@@ -123,6 +124,7 @@ def from_sae(
dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code,
dtype=sae.cfg.dtype,
device=torch.device(device),
seqpos_slice=sae.cfg.seqpos_slice,
)

def __init__(
@@ -147,6 +149,7 @@ def __init__(
model_kwargs: dict[str, Any] | None = None,
autocast_lm: bool = False,
dataset_trust_remote_code: bool | None = None,
seqpos_slice: tuple[int | None, ...] = (None,),
):
self.model = model
if model_kwargs is None:
@@ -188,6 +191,7 @@ def __init__(
self.dtype = DTYPE_MAP[dtype]
self.cached_activations_path = cached_activations_path
self.autocast_lm = autocast_lm
self.seqpos_slice = seqpos_slice

self.n_dataset_processed = 0

@@ -220,10 +224,6 @@ def __init__(
f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}.
The context_size {ds_context_size} is expected to be larger than or equal to the provided context size {self.context_size}."""
)
if self.context_size < 0:
raise ValueError(
f"The provided context_size is {self.context_size} is negative. Expecting positive context_size"
)
if self.context_size != ds_context_size:
warnings.warn(
f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. Some data will be discarded in this case.""",
@@ -441,37 +441,39 @@ def get_activations(self, batch_tokens: torch.Tensor):
autocast_if_enabled = contextlib.nullcontext()

with autocast_if_enabled:
layerwise_activations = self.model.run_with_cache(
layerwise_activations_cache = self.model.run_with_cache(
batch_tokens,
names_filter=[self.hook_name],
stop_at_layer=self.hook_layer + 1,
prepend_bos=False,
**self.model_kwargs,
)[1]

n_batches, n_context = batch_tokens.shape
layerwise_activations = layerwise_activations_cache[self.hook_name][
:, slice(*self.seqpos_slice)
]

n_batches, n_context = layerwise_activations.shape[:2]

stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))

if self.hook_head_index is not None:
stacked_activations[:, :, 0] = layerwise_activations[self.hook_name][
stacked_activations[:, :, 0] = layerwise_activations[
:, :, self.hook_head_index
]
elif (
layerwise_activations[self.hook_name].ndim > 3
): # if we have a head dimension
elif layerwise_activations.ndim > 3: # if we have a head dimension
try:
stacked_activations[:, :, 0] = layerwise_activations[
self.hook_name
].view(n_batches, n_context, -1)
stacked_activations[:, :, 0] = layerwise_activations.view(
n_batches, n_context, -1
)
except RuntimeError as e:
print(f"Error during view operation: {e}")
print("Attempting to use reshape instead...")
stacked_activations[:, :, 0] = layerwise_activations[
self.hook_name
].reshape(n_batches, n_context, -1)
stacked_activations[:, :, 0] = layerwise_activations.reshape(
n_batches, n_context, -1
)
else:
stacked_activations[:, :, 0] = layerwise_activations[self.hook_name]
stacked_activations[:, :, 0] = layerwise_activations

return stacked_activations

@@ -487,14 +489,15 @@ def get_buffer(
If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
"""
context_size = self.context_size
training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
batch_size = self.store_batch_size_prompts
d_in = self.d_in
total_size = batch_size * n_batches_in_buffer
num_layers = 1

if self.cached_activations_path is not None:
# Load the activations from disk
buffer_size = total_size * context_size
buffer_size = total_size * training_context_size
# Initialize an empty tensor with an additional dimension for layers
new_buffer = torch.zeros(
(buffer_size, num_layers, d_in),
@@ -548,7 +551,7 @@ def get_buffer(
refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
# Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
new_buffer = torch.zeros(
(total_size, context_size, num_layers, d_in),
(total_size, training_context_size, num_layers, d_in),
dtype=self.dtype, # type: ignore
device=self.device,
)
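A toy sketch of the slicing that get_activations now performs on the cached hook tensor, and the matching buffer arithmetic in get_buffer (made-up shapes; not library code):

    import torch

    batch, context_size, d_in = 4, 59, 512
    seqpos_slice = (5, -5)

    # Stand-in for layerwise_activations_cache[hook_name]: (batch, pos, d_in)
    acts = torch.randn(batch, context_size, d_in)

    # Keep only the configured positions, as in get_activations
    sliced = acts[:, slice(*seqpos_slice)]
    print(sliced.shape)  # torch.Size([4, 49, 512])

    # get_buffer sizes its tensors with the same arithmetic
    training_context_size = len(range(context_size)[slice(*seqpos_slice)])
    assert sliced.shape[1] == training_context_size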
13 changes: 13 additions & 0 deletions sae_lens/training/training_sae.py
@@ -75,6 +75,7 @@ def from_sae_runner_config(
context_size=cfg.context_size,
dataset_path=cfg.dataset_path,
prepend_bos=cfg.prepend_bos,
seqpos_slice=cfg.seqpos_slice,
# Training cfg
l1_coefficient=cfg.l1_coefficient,
lp_norm=cfg.lp_norm,
@@ -99,6 +100,18 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig":
valid_config_dict = {
key: val for key, val in config_dict.items() if key in valid_field_names
}

        # Ensure seqpos_slice is a tuple (JSON round-trips tuples as lists)
if "seqpos_slice" in valid_config_dict:
if isinstance(valid_config_dict["seqpos_slice"], list):
valid_config_dict["seqpos_slice"] = tuple(
valid_config_dict["seqpos_slice"]
)
elif not isinstance(valid_config_dict["seqpos_slice"], tuple):
valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],)

return TrainingSAEConfig(**valid_config_dict)

def to_dict(self) -> dict[str, Any]:
107 changes: 107 additions & 0 deletions scripts/training_a_sparse_autoencoder_othelloGPT.py
@@ -0,0 +1,107 @@
import os

import torch

from sae_lens import (
SAE,
HookedSAETransformer,
LanguageModelSAERunnerConfig,
SAETrainingRunner,
upload_saes_to_huggingface,
)

if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"


model_name = "othello-gpt"
model = HookedSAETransformer.from_pretrained(model_name)

dataset_path = "taufeeque/othellogpt"
context_size = 59

layer = 5
training_tokens = int(1e3)
train_batch_size_tokens = 2048
n_steps = int(training_tokens / train_batch_size_tokens)

runner_cfg = LanguageModelSAERunnerConfig(
#
# Data generation
model_name=model_name,
hook_name=f"blocks.{layer}.mlp.hook_post",
hook_layer=layer,
d_in=model.cfg.d_mlp,
dataset_path=dataset_path,
is_dataset_tokenized=True,
prepend_bos=False,
streaming=True,
train_batch_size_tokens=train_batch_size_tokens,
context_size=context_size,
seqpos_slice=(5, -5),
#
    # SAE architecture
architecture="gated",
expansion_factor=8,
b_dec_init_method="zeros",
apply_b_dec_to_input=True,
normalize_sae_decoder=False,
scale_sparsity_penalty_by_decoder_norm=True,
decoder_heuristic_init=True,
init_encoder_as_decoder_transpose=True,
#
# Activations store
n_batches_in_buffer=32,
store_batch_size_prompts=16,
training_tokens=training_tokens,
#
# Training hyperparameters (standard)
lr=2e-4,
adam_beta1=0.9,
adam_beta2=0.999,
lr_scheduler_name="constant",
lr_warm_up_steps=int(0.2 * n_steps),
lr_decay_steps=int(0.2 * n_steps),
#
# Training hyperparameters (SAE-specific)
l1_coefficient=5,
l1_warm_up_steps=int(0.2 * n_steps),
use_ghost_grads=False,
feature_sampling_window=1000,
dead_feature_window=500,
dead_feature_threshold=1e-5,
#
# Logging / evals
log_to_wandb=True,
wandb_project=f"othello_gpt_sae_{layer=}",
wandb_log_frequency=30,
eval_every_n_wandb_logs=10,
checkpoint_path="checkpoints",
#
# Misc.
device=str(device),
seed=42,
n_checkpoints=5,
dtype="float32",
)

runner = SAETrainingRunner(runner_cfg)
sae = runner.run()

hf_repo_id = "callummcdougall/arena-demos-othellogpt"
sae_id = "blocks.5.mlp.hook_post-v1"

upload_saes_to_huggingface({sae_id: sae}, hf_repo_id=hf_repo_id)

othellogpt_sae = SAE.from_pretrained(
release=hf_repo_id, sae_id=sae_id, device=str(device)
)[0]
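As a follow-up sanity check (not part of the committed script), one could encode activations from the trained hook point through the downloaded SAE. This sketch continues the script above and uses random Othello token ids purely to exercise the hook; a real evaluation would use games from the dataset:

hook_name = othellogpt_sae.cfg.hook_name  # "blocks.5.mlp.hook_post"

# Dummy token ids in OthelloGPT's small vocabulary, purely illustrative.
tokens = torch.randint(1, 60, (2, context_size), device=device)

_, cache = model.run_with_cache(tokens, names_filter=[hook_name], prepend_bos=False)
acts = cache[hook_name][:, slice(*othellogpt_sae.cfg.seqpos_slice)]

feature_acts = othellogpt_sae.encode(acts)
print(feature_acts.shape)  # (batch, sliced positions, d_sae)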