Rebase on seqpos tuple implementation and remove start/end pos offset
decandido committed Oct 9, 2024
1 parent c0dc5bf commit 9130ff9
Showing 7 changed files with 46 additions and 139 deletions.
56 changes: 15 additions & 41 deletions sae_lens/config.py
@@ -40,8 +40,6 @@ class LanguageModelSAERunnerConfig:
streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
is_dataset_tokenized (bool): NOT IN USE. We used to use this but now automatically detect if the dataset is tokenized.
context_size (int): The context size to use when generating activations on which to train the SAE.
start_pos_offset (int): A positive offset to cut off the start of the sequences used to train the SAE.
end_pos_offset (int): A positive offset to cut off the end of the sequences used to train the SAE.
use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
cached_activations_path (str, optional): The path to the cached activations.
d_in (int): The input dimension of the SAE.
@@ -62,7 +60,7 @@ class LanguageModelSAERunnerConfig:
store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop.
normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).
seqpos_slice (tuple): Determines slicing of (batch, seq, d_in) activations when constructing batches, during training. Example: for Othello we sometimes use (5, -5).
seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, step_size), e.g. for Othello we sometimes use (5, -5).
device (str): The device to use. Usually cuda.
act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
seed (int): The seed to use.
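
For context, a minimal sketch (not part of this diff) of how the seqpos_slice option described above might be set when configuring a training run; the model, hook point, and dataset values below are placeholders, not values taken from this commit:

```python
from sae_lens import LanguageModelSAERunnerConfig

# Illustrative values only; substitute your own model, hook point, and dataset.
cfg = LanguageModelSAERunnerConfig(
    model_name="gpt2",
    hook_name="blocks.5.hook_resid_pre",
    hook_layer=5,
    d_in=768,
    dataset_path="your-org/your-tokenized-dataset",  # placeholder path
    is_dataset_tokenized=True,
    context_size=128,
    # Drop the first 5 and last 5 positions of every sequence when collecting
    # activations, following the (start_pos, end_pos, step_size) convention.
    seqpos_slice=(5, -5),
)
```
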
@@ -124,12 +122,6 @@ class LanguageModelSAERunnerConfig:
streaming: bool = True
is_dataset_tokenized: bool = True
context_size: int = 128
start_pos_offset: int = (
0 # set to n if you want to exclude first n seq positions from sae training
)
end_pos_offset: int = (
0 # set to n if you want to exclude last n seq positions from sae training
)
use_cached_activations: bool = False
cached_activations_path: Optional[str] = (
None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}"
@@ -370,18 +362,7 @@ def __post_init__(self):
f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
)

if (self.start_pos_offset < 0) or (self.start_pos_offset > self.context_size):
raise ValueError(
f"Start position offset {self.start_pos_offset} should be in range [0, {self.context_size}]"
)
if (self.end_pos_offset < 0) or (self.end_pos_offset >= self.context_size):
raise ValueError(
f"End position offset {self.end_pos_offset} should be in range [0, {self.context_size-1}]"
)
if self.start_pos_offset + self.end_pos_offset > self.context_size:
raise ValueError(
f"Choice of {self.start_pos_offset=} and {self.end_pos_offset=} is incompatible with {self.context_size=}. We expect start_pos_offset + end_pos_offset < context_size."
)
_validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size)

@property
def total_training_tokens(self) -> int:
@@ -406,8 +387,6 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]:
"activation_fn_str": self.activation_fn,
"apply_b_dec_to_input": self.apply_b_dec_to_input,
"context_size": self.context_size,
"start_pos_offset": self.start_pos_offset,
"end_pos_offset": self.end_pos_offset,
"prepend_bos": self.prepend_bos,
"dataset_path": self.dataset_path,
"dataset_trust_remote_code": self.dataset_trust_remote_code,
@@ -487,12 +466,6 @@ class CacheActivationsRunnerConfig:
streaming: bool = True
is_dataset_tokenized: bool = True
context_size: int = 128
start_pos_offset: int = (
0 # set to n if you want to exclude first n seq positions from sae training
)
end_pos_offset: int = (
0 # set to n if you want to exclude last n seq positions from sae training
)
new_cached_activations_path: Optional[str] = (
None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}"
)
@@ -542,18 +515,8 @@ def __post_init__(self):
raise ValueError(
f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
)
if (self.start_pos_offset < 0) or (self.start_pos_offset > self.context_size):
raise ValueError(
f"Start position offset {self.start_pos_offset} should be in range [0, {self.context_size}]"
)
if (self.end_pos_offset < 0) or (self.end_pos_offset >= self.context_size):
raise ValueError(
f"End position offset {self.end_pos_offset} should be in range [0, {self.context_size-1}]"
)
if self.start_pos_offset + self.end_pos_offset > self.context_size:
raise ValueError(
f"Choice of {self.start_pos_offset=} and {self.end_pos_offset=} is incompatible with {self.context_size=}. We expect start_pos_offset + end_pos_offset < context_size."
)

_validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size)


@dataclass
@@ -640,6 +603,17 @@ def _default_cached_activations_path(
return path


def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None:
    # Ensure that the step size, if given, is larger than or equal to 1
    if len(seqpos) == 3 and seqpos[2] is not None:
        assert (
            seqpos[2] >= 1
        ), f"Ensure the step_size {seqpos[2]=} for sequence slicing is positive."
    # Ensure that the choice of seqpos doesn't end up with an empty list
    assert (
        len(list(range(context_size))[slice(*seqpos)]) > 0
    ), f"The slice {seqpos=} of a context of size {context_size} selects no positions."
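
For reference, a short sketch of how this helper behaves with the corrected step check; the snippet below is illustrative only and not part of the commit:

```python
import pytest

# Valid: keep positions 5 through 53 of a 59-token context (the Othello case).
_validate_seqpos(seqpos=(5, -5), context_size=59)

# Valid: same range, but only every second position.
_validate_seqpos(seqpos=(5, -5, 2), context_size=59)

# Invalid: a non-positive step size is rejected.
with pytest.raises(AssertionError):
    _validate_seqpos(seqpos=(None, 10, -1), context_size=10)

# Invalid: start >= stop selects no positions to train on.
with pytest.raises(AssertionError):
    _validate_seqpos(seqpos=(6, 3), context_size=10)
```
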


@dataclass
class PretokenizeRunnerConfig:
tokenizer_name: str = "gpt2"
19 changes: 5 additions & 14 deletions sae_lens/training/activations_store.py
@@ -75,8 +75,6 @@ def from_config(
hook_layer=cfg.hook_layer,
hook_head_index=cfg.hook_head_index,
context_size=cfg.context_size,
start_pos_offset=cfg.start_pos_offset,
end_pos_offset=cfg.end_pos_offset,
d_in=cfg.d_in,
n_batches_in_buffer=cfg.n_batches_in_buffer,
total_training_tokens=cfg.training_tokens,
@@ -116,8 +114,6 @@ def from_sae(
hook_layer=sae.cfg.hook_layer,
hook_head_index=sae.cfg.hook_head_index,
context_size=sae.cfg.context_size if context_size is None else context_size,
start_pos_offset=sae.cfg.start_pos_offset,
end_pos_offset=sae.cfg.end_pos_offset,
prepend_bos=sae.cfg.prepend_bos,
streaming=streaming,
store_batch_size_prompts=store_batch_size_prompts,
@@ -140,8 +136,6 @@ def __init__(
hook_layer: int,
hook_head_index: int | None,
context_size: int,
start_pos_offset: int,
end_pos_offset: int,
d_in: int,
n_batches_in_buffer: int,
total_training_tokens: int,
@@ -185,8 +179,6 @@ def __init__(
self.hook_layer = hook_layer
self.hook_head_index = hook_head_index
self.context_size = context_size
self.start_pos_offset = start_pos_offset
self.end_pos_offset = end_pos_offset
self.d_in = d_in
self.n_batches_in_buffer = n_batches_in_buffer
self.half_buffer_size = n_batches_in_buffer // 2
@@ -457,7 +449,11 @@ def get_activations(self, batch_tokens: torch.Tensor):
**self.model_kwargs,
)[1]

n_batches, n_context = batch_tokens.shape
layerwise_activations = layerwise_activations_cache[self.hook_name][
:, slice(*self.seqpos_slice)
]

n_batches, n_context = layerwise_activations.shape[:2]

stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))
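
To illustrate the new slicing step above, here is a small sketch with made-up shapes (not code from this commit) showing how slice(*self.seqpos_slice) trims the sequence dimension before activations are batched:

```python
import torch

batch, seq_len, d_in = 4, 59, 512  # illustrative shapes
seqpos_slice = (5, -5)

# Stand-in for layerwise_activations_cache[self.hook_name]:
# one activation vector per (prompt, position).
acts = torch.randn(batch, seq_len, d_in)

# Keep only the sequence positions selected by seqpos_slice.
sliced = acts[:, slice(*seqpos_slice)]

assert sliced.shape == (4, 49, 512)  # positions 5 through 53 remain
```
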

@@ -498,11 +494,6 @@ def get_buffer(
d_in = self.d_in
total_size = batch_size * n_batches_in_buffer
num_layers = 1
# Calculate the effective context size
training_context_slice = list(
range(self.start_pos_offset, context_size - self.end_pos_offset)
)
training_context_size = len(training_context_slice)

if self.cached_activations_path is not None:
# Load the activations from disk
6 changes: 1 addition & 5 deletions tests/benchmark/test_language_model_sae_runner.py
@@ -306,13 +306,10 @@ def test_language_model_sae_runner_othellogpt():
total_training_steps = 500
batch_size = 4096
total_training_tokens = total_training_steps * batch_size
print(f"Total Training Tokens: {total_training_tokens}")

lr_warm_up_steps = 0
lr_decay_steps = 40_000
print(f"lr_decay_steps: {lr_decay_steps}")
l1_warmup_steps = 10_000
print(f"l1_warmup_steps: {l1_warmup_steps}")

cfg = LanguageModelSAERunnerConfig(
# Data Generating Function (Model + Training Distribution)
@@ -345,8 +342,7 @@ def test_language_model_sae_runner_othellogpt():
lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)
train_batch_size_tokens=batch_size,
context_size=59, # will control the length of the prompts we feed to the model. Larger is better but slower, so for the tutorial we'll use a short one.
start_pos_offset=5,
end_pos_offset=5,
seqpos_slice=(5, -5),
# Activation Store Parameters
n_batches_in_buffer=32, # controls how many activations we store / shuffle.
training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.
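
As a quick sanity check on the test configuration above (a sketch, not part of the test file): with context_size=59 and seqpos_slice=(5, -5), the number of positions per prompt that actually contribute training activations can be computed directly:

```python
context_size = 59
seqpos_slice = (5, -5)

# Positions 5 through 53 survive the slice, i.e. 49 per prompt.
training_positions = len(range(context_size)[slice(*seqpos_slice)])
assert training_positions == 49
```
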
4 changes: 0 additions & 4 deletions tests/unit/helpers.py
@@ -39,8 +39,6 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False):
checkpoint_path: str
dtype: str
prepend_bos: bool
start_pos_offset: int
end_pos_offset: int


def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
@@ -77,8 +75,6 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
"checkpoint_path": "test/checkpoints",
"dtype": "float32",
"prepend_bos": True,
"start_pos_offset": 0,
"end_pos_offset": 0,
}

for key, value in kwargs.items():
12 changes: 5 additions & 7 deletions tests/unit/training/test_cache_activations_runner.py
@@ -219,7 +219,7 @@ def W_E(self) -> torch.Tensor:

# The way to run this test is with the following command:
# poetry run py.test tests/unit/test_cache_activations_runner.py --profile-svg -s
def test_cache_activations_runner_with_valid_start_end_pos_offset(tmp_path: Path):
def test_cache_activations_runner_with_valid_seqpos(tmp_path: Path):
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
@@ -229,9 +229,8 @@ def test_cache_activations_runner_with_valid_start_end_pos_offset(tmp_path: Path

# total_training_steps = 20_000
context_size = 1024
start_pos_offset = 12
end_pos_offset = 12
training_context_size = context_size - start_pos_offset - end_pos_offset
seqpos_slice = (12, -12)
training_context_size = len(range(context_size)[slice(*seqpos_slice)])
n_batches_in_buffer = 32
store_batch_size = 1
n_buffers = 3
@@ -259,9 +258,8 @@ def test_cache_activations_runner_with_valid_start_end_pos_offset(tmp_path: Path
prepend_bos=True, # I used to train GPT2 SAEs with a prepended-bos but no longer think we should do this.
training_tokens=total_training_tokens, # For initial testing I think this is a good number.
train_batch_size_tokens=4096,
# Test the start and end pos offset
start_pos_offset=start_pos_offset,
end_pos_offset=end_pos_offset,
# Test the sequence slicing
seqpos_slice=seqpos_slice,
# Loss Function
## Reconstruction Coefficient.
# Buffer details won't matter if we cache / shuffle our activations ahead of time.
66 changes: 20 additions & 46 deletions tests/unit/training/test_config.py
@@ -60,8 +60,6 @@ def test_sae_training_runner_config_get_sae_base_parameters():
"hook_head_index": None,
"device": "cpu",
"context_size": 128,
"start_pos_offset": 0,
"end_pos_offset": 0,
"prepend_bos": True,
"finetuning_scaling_factor": False,
"dataset_path": "",
@@ -98,69 +96,45 @@ def test_sae_training_runner_config_expansion_factor():
assert cfg.expansion_factor == 4


@pytest.mark.parametrize(
"start_pos_offset, end_pos_offset, expected_error",
[
(-1, 0, ValueError),
(0, 0, None),
(10, 0, None),
(11, 0, ValueError),
(0, -1, ValueError),
(0, 10, ValueError),
(0, 11, ValueError),
(5, 5, None),
(6, 5, ValueError),
(3, 4, None),
],
)
def test_sae_training_runner_config_start_end_pos_offset(
start_pos_offset: int, end_pos_offset: int, expected_error: Optional[ValueError]
test_cases_for_seqpos = [
((None, 10, -1), AssertionError),
((None, 10, 0), AssertionError),
((5, 5, None), AssertionError),
((6, 3, None), AssertionError),
]


@pytest.mark.parametrize("seqpos_slice, expected_error", test_cases_for_seqpos)
def test_sae_training_runner_config_seqpos(
seqpos_slice: tuple[int | None, ...], expected_error: Optional[type[AssertionError]]
):
context_size = 10
if expected_error is ValueError:
if expected_error is AssertionError:
with pytest.raises(expected_error):
LanguageModelSAERunnerConfig(
start_pos_offset=start_pos_offset,
end_pos_offset=end_pos_offset,
seqpos_slice=seqpos_slice,
context_size=context_size,
)
else:
LanguageModelSAERunnerConfig(
start_pos_offset=start_pos_offset,
end_pos_offset=end_pos_offset,
seqpos_slice=seqpos_slice,
context_size=context_size,
)


@pytest.mark.parametrize(
"start_pos_offset, end_pos_offset, expected_error",
[
(-1, 0, ValueError),
(0, 0, None),
(10, 0, None),
(11, 0, ValueError),
(0, -1, ValueError),
(0, 10, ValueError),
(0, 11, ValueError),
(5, 5, None),
(6, 5, ValueError),
(3, 4, None),
],
)
def test_cache_activations_runner_config_start_end_pos_offset(
start_pos_offset: int, end_pos_offset: int, expected_error: Optional[ValueError]
@pytest.mark.parametrize("seqpos_slice, expected_error", test_cases_for_seqpos)
def test_cache_activations_runner_config_seqpos(
seqpos_slice: tuple[int | None, ...], expected_error: Optional[type[AssertionError]]
):
context_size = 10
if expected_error is ValueError:
if expected_error is AssertionError:
with pytest.raises(expected_error):
CacheActivationsRunnerConfig(
start_pos_offset=start_pos_offset,
end_pos_offset=end_pos_offset,
seqpos_slice=seqpos_slice,
context_size=context_size,
)
else:
CacheActivationsRunnerConfig(
start_pos_offset=start_pos_offset,
end_pos_offset=end_pos_offset,
seqpos_slice=seqpos_slice,
context_size=context_size,
)
22 changes: 0 additions & 22 deletions tests/unit/training/test_sae_basic.py
@@ -301,25 +301,3 @@ def test_sae_change_dtype() -> None:
sae.to(dtype=torch.float16)
assert sae.dtype == torch.float16
assert sae.cfg.dtype == "torch.float16"


def test_sae_position_offsets(tmp_path: Path) -> None:
cfg = build_sae_cfg(
device="cpu",
context_size=10,
start_pos_offset=2,
end_pos_offset=2,
dtype="float64",
)
model_path = str(tmp_path)
sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())

assert sae.cfg.start_pos_offset == 2
assert sae.cfg.end_pos_offset == 2

sae.save_model(model_path)

sae_loaded = sae.load_from_pretrained(model_path, device="cpu")

assert sae_loaded.cfg.start_pos_offset == 2
assert sae_loaded.cfg.end_pos_offset == 2
