Unified probe training (#38)

* adding updated lr probing code and integrating into experiments

* skip embedding SAEs

* train probes if needed in latent evaluation experiment

* try more filtering in probe training to save gpu memory

* fixing inference mode nonsense

* fixing save_probe_and_data

* removing note from README
chanind authored Nov 12, 2024
1 parent 9bcada0 commit 03159e4
Showing 7 changed files with 375 additions and 333 deletions.
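
At a high level, this commit replaces the old per-task probe paths with a single per-layer layout and adds train-on-demand helpers to `sae_spelling/experiments/common.py`. A minimal usage sketch based on the signatures added in this diff (the layer value is illustrative; `load_gemma2_model` is the loader the experiments already use):

from sae_spelling.experiments.common import (
    load_gemma2_model,
    load_or_train_probe,
    load_probe_data_split_or_train,
)

model = load_gemma2_model()
# Trains and caches a first-letter probe for this layer if none is saved yet.
probe = load_or_train_probe(model, layer=3)
# Loads (or first trains and saves) the probing activations and labels for a split.
test_acts, test_labels = load_probe_data_split_or_train(model, layer=3, split="test")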
3 changes: 0 additions & 3 deletions README.md
@@ -45,9 +45,6 @@ We include the following experiments from the paper in the `sae_spelling.experiments`

Each of these experiments includes a main "runner" function. The runners only create dataframes and save them to disk; they do not generate plots. Each experiment package includes helpers for generating the plots in the paper, but these require TeX to be installed, so we don't generate plots by default.

**NOTE**
The experiments all require logistic-regression probes to be trained and data about the train/test split to be saved into dataframes in a specific format. We have not yet moved that code into this repo, but will do that in the next few days.

## Development

This project uses [Ruff](https://docs.astral.sh/ruff/) for linting and formatting, [Pyright](https://github.com/microsoft/pyright) for type checking, and [Pytest](https://docs.pytest.org/en/stable/) for testing.
164 changes: 148 additions & 16 deletions sae_spelling/experiments/common.py
@@ -13,10 +13,24 @@
import torch
from sae_lens import SAE
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from tqdm.autonotebook import tqdm
from transformer_lens import HookedTransformer
from transformers import PreTrainedTokenizerFast

from sae_spelling.probing import LinearProbe
from sae_spelling.probing import (
LinearProbe,
create_dataset_probe_training,
gen_and_save_df_acts_probing,
save_probe_and_data,
train_linear_probe_for_task,
)
from sae_spelling.prompting import (
VERBOSE_FIRST_LETTER_TEMPLATE,
VERBOSE_FIRST_LETTER_TOKEN_POS,
Formatter,
first_letter_formatter,
)
from sae_spelling.vocab import get_alpha_tokens

DEFAULT_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -68,35 +82,77 @@ def load_gemmascope_sae(
return sae


def load_or_train_probe(
model: HookedTransformer,
layer: int = 0,
probes_dir: str | Path = PROBES_DIR,
dtype: torch.dtype = DEFAULT_DTYPE,
device: str = DEFAULT_DEVICE,
) -> LinearProbe:
probe_path = Path(probes_dir) / f"layer_{layer}" / "probe.pth"
if not probe_path.exists():
print(f"Probe for layer {layer} not found, training...")
train_and_save_probes(
model,
[layer],
probes_dir,
)
return load_probe(layer, probes_dir, dtype, device)


def load_probe(
task: str = "first_letter",
layer: int = 0,
probes_dir: str | Path = PROBES_DIR,
dtype: torch.dtype = DEFAULT_DTYPE,
device: str = DEFAULT_DEVICE,
) -> LinearProbe:
probe = torch.load(
Path(probes_dir) / task / f"layer_{layer}" / f"{task}_probe.pth",
Path(probes_dir) / f"layer_{layer}" / "probe.pth",
map_location=device,
).to(dtype=dtype)
return probe
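
For reference, the unified per-layer layout used by these loaders (paths taken from this diff, replacing the old `{task}`-prefixed layout on the deleted lines; the `*_path` names are just labels and `split` is "train" or "test"):

probe_path = Path(probes_dir) / f"layer_{layer}" / "probe.pth"
data_path = Path(probes_dir) / f"layer_{layer}" / "data.npz"
df_path = Path(probes_dir) / f"layer_{layer}" / f"{split}_df.csv"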


def load_probe_data_split_or_train(
model: HookedTransformer,
layer: int = 0,
split: Literal["train", "test"] = "test",
probes_dir: str | Path = PROBES_DIR,
dtype: torch.dtype = DEFAULT_DTYPE,
device: str = DEFAULT_DEVICE,
) -> tuple[torch.Tensor, list[tuple[str, int]]]:
probe_path = Path(probes_dir) / f"layer_{layer}" / "probe.pth"
if not probe_path.exists():
print(f"Probe for layer {layer} not found, training...")
train_and_save_probes(
model,
[layer],
probes_dir,
)
return load_probe_data_split(
model.tokenizer, # type: ignore
layer,
split,
probes_dir,
dtype,
device,
)


@torch.inference_mode()
def load_probe_data_split(
tokenizer: PreTrainedTokenizerFast,
task: str = "first_letter",
layer: int = 0,
split: Literal["train", "test"] = "test",
probes_dir: str | Path = PROBES_DIR,
dtype: torch.dtype = DEFAULT_DTYPE,
device: str = DEFAULT_DEVICE,
) -> tuple[torch.Tensor, list[tuple[str, int]]]:
np_data = np.load(
Path(probes_dir) / task / f"layer_{layer}" / f"{task}_data.npz",
Path(probes_dir) / f"layer_{layer}" / "data.npz",
)
df = pd.read_csv(
Path(probes_dir) / task / f"layer_{layer}" / f"{task}_{split}_df.csv",
Path(probes_dir) / f"layer_{layer}" / f"{split}_df.csv",
keep_default_na=False,
na_values=[""],
)
@@ -155,7 +211,9 @@ def get_gemmascope_saes_info(layer: int | None = None) -> list[SaeInfo]:
if width_match.group(2) == "m":
width *= 1000
layer_match = re.search(r"layer_(\d+)", sae_name)
assert layer_match is not None
# new embedding SAEs don't have a layer; we don't care about them, so just skip
if layer_match is None:
continue
sae_layer = int(layer_match.group(1))
# this SAE is missing, see https://github.com/jbloomAus/SAELens/pull/293. Just skip it.
if layer == 11 and l0 == 79:
@@ -165,21 +223,15 @@
return saes
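
Downstream experiments can then filter by layer; a one-line sketch (the layer value is illustrative):

saes = get_gemmascope_saes_info(layer=3)  # embedding SAEs, which have no layer, are skipped inside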


def get_task_dir(
def get_or_make_dir(
experiment_dir: str | Path,
task: str = "first_letter",
) -> Path:
"""
Helper to create a directory for a specific task within an experiment directory.
"""
# TODO: support more tasks
if task != "first_letter":
raise ValueError(f"Unsupported task: {task}")

experiment_dir = Path(experiment_dir)
task_output_dir = experiment_dir / task
task_output_dir.mkdir(parents=True, exist_ok=True)
return task_output_dir
experiment_dir.mkdir(parents=True, exist_ok=True)
return experiment_dir


def load_experiment_df(
@@ -227,3 +279,83 @@ def humanify_sae_width(width: int) -> str:
return "1m"
else:
return f"{width // 1_000}k"


def create_and_train_probe(
model: HookedTransformer,
formatter: Formatter,
hook_point: str,
probes_dir: str | Path,
vocab: list[str],
batch_size: int,
num_epochs: int,
lr: float,
device: torch.device,
base_template: str,
pos_idx: int,
num_prompts_per_token: int = 1,
):
train_dataset, test_dataset = create_dataset_probe_training(
vocab=vocab,
formatter=formatter,
num_prompts_per_token=num_prompts_per_token,
base_template=base_template,
)

layer = int(hook_point.split(".")[1])
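# (hook points passed in here look like "blocks.<layer>.hook_resid_post", so index 1 is the layer number)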

train_df, test_df, train_activations, test_activations = (
gen_and_save_df_acts_probing(
model=model,
train_dataset=train_dataset,
test_dataset=test_dataset,
path=probes_dir,
hook_point=hook_point,
batch_size=batch_size,
layer=layer,
position_idx=pos_idx,
)
)

num_classes = 26
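# (26 classes: one per letter a-z, matching the first-letter task)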
probe, probe_data = train_linear_probe_for_task(
train_df=train_df,
test_df=test_df,
train_activations=train_activations,
test_activations=test_activations,
num_classes=num_classes,
batch_size=32 * batch_size,
num_epochs=num_epochs,
lr=lr,
device=device,
)

save_probe_and_data(probe, probe_data, probes_dir, layer)
print("Probe saved successfully.\n")


def train_and_save_probes(
model: HookedTransformer,
layers: list[int],
probes_dir: str | Path = PROBES_DIR,
batch_size=64,
num_epochs=50,
lr=1e-2,
device=torch.device("cuda"),
):
vocab = get_alpha_tokens(model.tokenizer) # type: ignore
for layer in tqdm(layers):
hook_point = f"blocks.{layer}.hook_resid_post"
create_and_train_probe(
model=model,
hook_point=hook_point,
formatter=first_letter_formatter(),
probes_dir=probes_dir,
vocab=vocab,
batch_size=batch_size,
num_epochs=num_epochs,
lr=lr,
device=device,
base_template=VERBOSE_FIRST_LETTER_TEMPLATE,
pos_idx=VERBOSE_FIRST_LETTER_TOKEN_POS,
)
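
To pre-train probes for several layers up front instead of relying on the lazy `load_or_train_probe` / `load_probe_data_split_or_train` helpers, a sketch (layer choices illustrative):

model = load_gemma2_model()
train_and_save_probes(model, layers=[0, 6, 12])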
15 changes: 5 additions & 10 deletions sae_spelling/experiments/feature_absorption.py
@@ -16,7 +16,7 @@
PROBES_DIR,
SaeInfo,
get_gemmascope_saes_info,
get_task_dir,
get_or_make_dir,
humanify_sae_width,
load_df_or_run,
load_experiment_df,
@@ -231,9 +231,8 @@ def _aggregate_results_df(
def plot_absorption_rate_vs_l0(
results: dict[int, list[tuple[pd.DataFrame, SaeInfo]]],
experiment_dir: Path | str = EXPERIMENTS_DIR / FEATURE_ABSORPTION_EXPERIMENT_NAME,
task: str = "first_letter",
):
task_output_dir = get_task_dir(experiment_dir, task=task)
task_output_dir = get_or_make_dir(experiment_dir)
df = _aggregate_results_df(results)

sns.set_theme()
@@ -263,9 +262,8 @@ def plot_absorption_rate_vs_l0(
def plot_absorption_rate_vs_layer(
results: dict[int, list[tuple[pd.DataFrame, SaeInfo]]],
experiment_dir: Path | str = EXPERIMENTS_DIR / FEATURE_ABSORPTION_EXPERIMENT_NAME,
task: str = "first_letter",
):
task_output_dir = get_task_dir(experiment_dir, task=task)
task_output_dir = get_or_make_dir(experiment_dir)
df = _aggregate_results_df(results)
grouped_df = (
df[["layer", "sae_l0", "sae_width_str", "absorption_rate"]]
@@ -309,7 +307,6 @@ def run_feature_absortion_experiments(
sparse_probing_experiment_dir: Path | str = EXPERIMENTS_DIR
/ SPARSE_PROBING_EXPERIMENT_NAME,
probes_dir: Path | str = PROBES_DIR,
task: str = "first_letter",
force: bool = False,
skip_1m_saes: bool = True,
skip_32k_saes: bool = True,
Expand All @@ -321,10 +318,8 @@ def run_feature_absortion_experiments(
"""
NOTE: this experiment requires the results of the k-sparse probing experiments. Make sure to run them first.
"""
task_output_dir = get_task_dir(experiment_dir, task=task)
sparse_probing_task_output_dir = get_task_dir(
sparse_probing_experiment_dir, task=task
)
task_output_dir = get_or_make_dir(experiment_dir)
sparse_probing_task_output_dir = get_or_make_dir(sparse_probing_experiment_dir)

model = load_gemma2_model()
vocab = get_alpha_tokens(model.tokenizer) # type: ignore