polishing experiments and expanding README
chanind committed Sep 23, 2024
1 parent 1d3cc80 commit d3631a6
Showing 8 changed files with 94 additions and 88 deletions.
56 changes: 47 additions & 9 deletions README.md
@@ -1,6 +1,6 @@
# SAE Spelling

Shared code for SAE spelling experiments as part of LASR
Code for the paper [A is for Absorption: Studying Feature Splitting and Absorption in Sparse Autoencoders](https://linktr.ee/lasr_2024).

## Installation

@@ -10,16 +10,43 @@ This project uses [Poetry](https://python-poetry.org/) for dependency management
poetry install
```

#### Poetry tips
## Project structure

Below are some helpful tips for working with Poetry:
This project is set up so that code which could be reused in other projects is in the main `sae_spelling` package, and code specific to the experiments in the paper is in `sae_spelling.experiments`. In the future, we may move some of these utilities into their own library. The `sae_spelling` package is structured as follows (a short usage sketch follows the list):

- Install a new main dependency: `poetry add <package>`
- Install a new development dependency: `poetry add --dev <package>`
- Development dependencies are not required for the main code to run, but are needed for things like linting, type-checking, etc.
- Update the lockfile: `poetry lock`
- Run a command using the virtual environment: `poetry run <command>`
- Run a Python file from the CLI as a script (module-style): `poetry run python -m sae_spelling.path.to.file`
- `feature_attribution`: Code for running SAE feature attribution experiments. Attribution tries to estimate how much each SAE latent contributes to a downstream metric, using gradients to approximate the effect of ablating the latent. Main exports include:
  - `calculate_feature_attribution()`
  - `calculate_integrated_gradient_attribution_patching()`
- `feature_ablation`: Code for running SAE feature ablation experiments. This involves ablating each firing SAE feature on a prompt to see how it affects a downstream metric (e.g. whether the model knows the first letter of a token). The main function in this module is:
  - `calculate_individual_feature_ablations()`
- `probing`: Code for training logistic-regression probes in Torch. Some helpful exports from this module are:
  - `train_multi_probe()`: train a multi-class binary probe
  - `train_binary_probe()`: train a binary probe (same as the multi-class probe, but with only one class)
- `prompting`: Code for generating ICL (in-context learning) prompts, mainly focused on spelling. Some helpful exports from this module are:
  - `create_icl_prompt()`
  - `spelling_formatter()`: formatter which outputs the spelling of a token
  - `first_letter_formatter()`: formatter which outputs the first letter of a token
- `vocab`: Helpers for working with token vocabularies. Some helpful exports from this module are:
  - `get_alpha_tokens()`: filter the tokenizer vocab down to alphabetic tokens
- `sae_utils`: Helpers for working with SAEs. Main exports include:
  - `apply_saes_and_run()`: Apply SAEs to a model and run it on a prompt. Allows providing a list of hooks and optionally tracking activation gradients. This is used in the attribution and ablation experiments.
- `spelling_grader`: Code for grading spelling prompts. Some helpful exports from this module are:
  - `SpellingGrader`: Class for grading model performance on spelling prompts
- `feature_absorption_calculator`: Code for calculating feature absorption. Some helpful exports from this module are:
  - `FeatureAbsorptionCalculator`
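
As a quick orientation, here is a minimal sketch showing how a few of these exports could fit together to build a spelling prompt. The imports mirror those used in `sae_spelling/baseline.py`; the keyword arguments, the attributes of the returned prompt, and the tokenizer checkpoint name are assumptions for illustration and may not match the actual API.

```python
# Minimal usage sketch (not from the repo). Keyword arguments and the
# tokenizer name are assumed, not verified against the real API.
from transformers import AutoTokenizer

from sae_spelling.prompting import create_icl_prompt, spelling_formatter
from sae_spelling.vocab import get_alpha_tokens

# The paper's experiments target Gemma-2 models (assumed checkpoint name).
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
alpha_tokens = get_alpha_tokens(tokenizer)  # alphabetic tokens from the vocab

# Build a few-shot (ICL) prompt asking the model to spell "cat".
prompt = create_icl_prompt(
    "cat",
    examples=alpha_tokens[:5],              # in-context example words (assumed kwarg)
    answer_formatter=spelling_formatter(),  # answers shown as spellings (assumed kwarg)
)
print(prompt)
```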

### Experiments

We include the following experiments from the paper in the `sae_spelling.experiments` package:

- `latent_evaluation`: This experiment finds the top SAE latent for each first-letter spelling task and evaluates the latent's performance relative to an LR (logistic-regression) probe.
- `k_sparse_probing`: This experiment trains k-sparse probes on the first-letter task and evaluates their performance as `k` increases. This is used to detect feature splitting.
- `feature_absorption`: This experiment attempts to quantify feature absorption on the first-letter task across SAEs.

Each of these experiments includes a main "runner" function. The runners only create dataframes and save them to disk; they do not generate plots. The experiment packages include helpers for generating the plots in the paper, but those plots require TeX to be installed, so we do not generate them by default.
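
As an illustration, a driver script for the k-sparse probing experiment might look like the sketch below. The keyword arguments match the signature of `run_k_sparse_probing_experiments`; the layer numbers and directory paths are placeholder values, and default behavior may differ.

```python
# Hypothetical driver script (not in the repo): run k-sparse probing for a few
# layers' SAEs and write the result dataframes to disk.
from pathlib import Path

from sae_spelling.experiments.k_sparse_probing import run_k_sparse_probing_experiments

run_k_sparse_probing_experiments(
    layers=[3, 7, 12],                      # which layers' SAEs to evaluate (example values)
    experiment_dir=Path("experiments/k_sparse_probing"),
    probes_dir=Path("probes"),              # pre-trained LR probes (see NOTE below)
    task="first_letter",
    force=False,                            # reuse cached dataframes if they exist
    skip_1m_saes=True,                      # skip the 1m-width SAEs
)
```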

**NOTE**
The experiments all require logistic-regression probes to be trained and data about the train/test split to be saved into dataframes in a specific format. We have not yet moved that code into this repo, but will do that in the next few days.

## Development

@@ -38,3 +65,14 @@ You can install a pre-commit hook to run linting and type-checking before commit
```
poetry run pre-commit install
```

### Poetry tips

Below are some helpful tips for working with Poetry:

- Install a new main dependency: `poetry add <package>`
- Install a new development dependency: `poetry add --dev <package>`
- Development dependencies are not required for the main code to run, but are needed for things like linting, type-checking, etc.
- Update the lockfile: `poetry lock`
- Run a command using the virtual environment: `poetry run <command>`
- Run a Python file from the CLI as a script (module-style): `poetry run python -m sae_spelling.path.to.file`
4 changes: 4 additions & 0 deletions sae_spelling/baseline.py
@@ -10,6 +10,10 @@
from sae_spelling.prompting import create_icl_prompt, spelling_formatter
from sae_spelling.vocab import get_alpha_tokens

###
# NOTE: this is never used in the paper. Just keeping it here for reference.
###


@dataclass
class BaselineResult:
21 changes: 5 additions & 16 deletions sae_spelling/experiments/common.py
@@ -21,16 +21,9 @@
DEFAULT_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

TEAM_DIR = Path("/content/drive/MyDrive/Team_Joseph")
EXPERIMENTS_DIR = TEAM_DIR / "experiments"
PROBES_DIR = (
TEAM_DIR
/ "data"
/ "probing_data"
/ "gemma-2"
/ "verbose_prompts"
/ "no_contamination"
)
EXPERIMENTS_DIR = Path.cwd() / "experiments"
# TODO: add probe training code
PROBES_DIR = Path.cwd() / "probes"


def dtype_to_str(dtype: torch.dtype | str) -> str:
@@ -230,11 +223,7 @@ def humanify_sae_width(width: int) -> str:
"""
A helper to convert SAE width to a nicer human-readable string.
"""
if width == 16_000:
return "16k"
elif width == 65_000:
return "65k"
elif width == 1_000_000:
if width == 1_000_000:
return "1m"
else:
raise ValueError(f"Unknown width: {width}")
return f"{width // 1_000}k"
8 changes: 6 additions & 2 deletions sae_spelling/experiments/feature_absorption.py
@@ -13,6 +13,7 @@

from sae_spelling.experiments.common import (
EXPERIMENTS_DIR,
PROBES_DIR,
SaeInfo,
get_gemmascope_saes_info,
get_task_dir,
@@ -178,10 +179,11 @@ def load_and_run_calculate_ig_ablation_and_cos_sims(
calculator: FeatureAbsorptionCalculator,
auroc_f1_df: pd.DataFrame,
sae_info: SaeInfo,
probes_dir: Path | str,
sparse_probing_task_output_dir: Path,
) -> pd.DataFrame:
layer = sae_info.layer
probe = load_probe(layer=layer)
probe = load_probe(layer=layer, probes_dir=probes_dir)
sae = load_gemmascope_sae(layer, width=sae_info.width, l0=sae_info.l0)
likely_negs = get_stats_and_likely_false_negative_tokens(
auroc_f1_df,
@@ -306,6 +308,7 @@ def run_feature_absortion_experiments(
experiment_dir: Path | str = EXPERIMENTS_DIR / FEATURE_ABSORPTION_EXPERIMENT_NAME,
sparse_probing_experiment_dir: Path | str = EXPERIMENTS_DIR
/ SPARSE_PROBING_EXPERIMENT_NAME,
probes_dir: Path | str = PROBES_DIR,
task: str = "first_letter",
force: bool = False,
skip_1m_saes: bool = True,
@@ -371,7 +374,8 @@ calculator,
calculator,
auroc_f1_df,
sae_info,
sparse_probing_task_output_dir,
probes_dir=probes_dir,
sparse_probing_task_output_dir=sparse_probing_task_output_dir,
),
df_path,
force=force,
28 changes: 9 additions & 19 deletions sae_spelling/experiments/k_sparse_probing.py
@@ -17,6 +17,7 @@

from sae_spelling.experiments.common import (
EXPERIMENTS_DIR,
PROBES_DIR,
SaeInfo,
get_gemmascope_saes_info,
get_task_dir,
@@ -27,7 +28,6 @@
load_probe,
load_probe_data_split,
)
from sae_spelling.experiments.encoder_auroc_and_f1 import find_optimal_f1_threshold
from sae_spelling.probing import LinearProbe, train_multi_probe
from sae_spelling.util import DEFAULT_DEVICE, batchify
from sae_spelling.vocab import LETTERS
@@ -247,6 +247,7 @@ def row_generator():
def load_and_run_eval_probe_and_sae_k_sparse_raw_scores(
sae_info: SaeInfo,
tokenizer: PreTrainedTokenizerFast,
probes_dir: Path | str,
verbose: bool = True,
) -> tuple[pd.DataFrame, pd.DataFrame]:
with torch.no_grad():
@@ -257,9 +258,12 @@ def load_and_run_eval_probe_and_sae_k_sparse_raw_scores(
)
if verbose:
print("Loading probe and training data", flush=True)
probe = load_probe(task="first_letter", layer=sae_info.layer)
probe = load_probe(
task="first_letter", layer=sae_info.layer, probes_dir=probes_dir
)
train_activations, train_data = load_probe_data_split(
tokenizer,
probes_dir=probes_dir,
task="first_letter",
layer=sae_info.layer,
split="train",
@@ -277,6 +281,7 @@ def load_and_run_eval_probe_and_sae_k_sparse_raw_scores(
print("Loading validation data", flush=True)
eval_activations, eval_data = load_probe_data_split(
tokenizer,
probes_dir=probes_dir,
task="first_letter",
layer=sae_info.layer,
split="test",
@@ -311,20 +316,11 @@ def build_f1_and_auroc_df(results_df, metadata_df):
f1_probe = metrics.f1_score(y, pred_probe > 0.0)
recall_probe = metrics.recall_score(y, pred_probe > 0.0)
precision_probe = metrics.precision_score(y, pred_probe > 0.0)
best_f1_bias_probe, f1_probe = find_optimal_f1_threshold(y, pred_probe)
recall_probe_best = metrics.recall_score(y, pred_probe > best_f1_bias_probe)
precision_probe_best = metrics.precision_score(
y, pred_probe > best_f1_bias_probe
)
auc_info = {
"auc_probe": auc_probe,
"f1_probe": f1_probe,
"f1_probe_best": f1_probe,
"recall_probe": recall_probe,
"precision_probe": precision_probe,
"recall_probe_best": recall_probe_best,
"precision_probe_best": precision_probe_best,
"bias_f1_probe_best": best_f1_bias_probe,
"letter": letter,
"layer": metadata_df["layer"].iloc[0],
"sae_width": metadata_df["sae_width"].iloc[0],
Expand All @@ -338,9 +334,6 @@ def build_f1_and_auroc_df(results_df, metadata_df):
recall = metrics.recall_score(y, pred_sae > 0.0)
precision = metrics.precision_score(y, pred_sae > 0.0)
auc_info[f"auc_sparse_sae_{k}"] = auc_sae
best_f1_bias_sae, f1_sae_best = find_optimal_f1_threshold(y, pred_sae)
recall_sae_best = metrics.recall_score(y, pred_sae > best_f1_bias_sae)
precision_sae_best = metrics.precision_score(y, pred_sae > best_f1_bias_sae)
sum_sae_pred = results_df[f"sum_sparse_sae_{letter}_k_{k}"].values
auc_sum_sae = metrics.roc_auc_score(y, sum_sae_pred)
f1_sum_sae = metrics.f1_score(y, sum_sae_pred > EPS)
Expand All @@ -349,11 +342,7 @@ def build_f1_and_auroc_df(results_df, metadata_df):

auc_info[f"f1_sparse_sae_{k}"] = f1
auc_info[f"recall_sparse_sae_{k}"] = recall
auc_info[f"recall_sparse_sae_{k}_best"] = recall_sae_best
auc_info[f"precision_sparse_sae_{k}"] = precision
auc_info[f"precision_sparse_sae_{k}_best"] = precision_sae_best
auc_info[f"f1_sparse_sae_{k}_best"] = f1_sae_best
auc_info[f"bias_f1_sparse_sae_{k}_best"] = best_f1_bias_sae
auc_info[f"auc_sum_sparse_sae_{k}"] = auc_sum_sae
auc_info[f"f1_sum_sparse_sae_{k}"] = f1_sum_sae
auc_info[f"recall_sum_sparse_sae_{k}"] = recall_sum_sae
@@ -549,6 +538,7 @@ def get_sparse_probing_auroc_f1_results_filename(sae_info: SaeInfo) -> str:
def run_k_sparse_probing_experiments(
layers: list[int],
experiment_dir: Path | str = EXPERIMENTS_DIR / SPARSE_PROBING_EXPERIMENT_NAME,
probes_dir: Path | str = PROBES_DIR,
task: str = "first_letter",
force: bool = False,
skip_1m_saes: bool = True,
Expand Down Expand Up @@ -594,7 +584,7 @@ def run_k_sparse_probing_experiments(
def get_raw_results_df():
return load_dfs_or_run(
lambda: load_and_run_eval_probe_and_sae_k_sparse_raw_scores(
sae_info, tokenizer, verbose=verbose
sae_info, tokenizer, probes_dir, verbose=verbose
),
(raw_results_path, metadata_results_path),
force=force,