From ed5d79148fa161c2e69e959f4ea7a8d9d3a87290 Mon Sep 17 00:00:00 2001 From: Anthony Duong <42191920+anthonyduong9@users.noreply.github.com> Date: Mon, 7 Oct 2024 22:51:01 +0100 Subject: [PATCH] chore: adds black-jupyter to dependencies (#318) * adds black-jupyter to dependencies * formats notebooks --- check_open_ai_sae_metrics.ipynb | 36 +- pyproject.toml | 2 +- ...joseph_curt_pairing_gemma_scope_saes.ipynb | 128 +- scripts/wandb_to_hf.ipynb | 21 +- tutorials/Hooked_SAE_Transformer_Demo.ipynb | 360 +- tutorials/basic_loading_and_analysing.ipynb | 989 ++-- tutorials/loading_tanh_sae.ipynb | 113 +- tutorials/logits_lens_with_features.ipynb | 1781 +++---- tutorials/pretokenizing_datasets.ipynb | 9 +- tutorials/training_a_gated_sae.ipynb | 1338 +++--- tutorials/training_a_sparse_autoencoder.ipynb | 1664 +++---- tutorials/tutorial_2_0.ipynb | 441 +- tutorials/uploading_saes_to_huggingface.ipynb | 4 +- .../using_an_sae_as_a_steering_vector.ipynb | 4235 +++++++++-------- 14 files changed, 5725 insertions(+), 5396 deletions(-) diff --git a/check_open_ai_sae_metrics.ipynb b/check_open_ai_sae_metrics.ipynb index 9612944e..8b7fef44 100644 --- a/check_open_ai_sae_metrics.ipynb +++ b/check_open_ai_sae_metrics.ipynb @@ -7,12 +7,12 @@ "outputs": [], "source": [ "from sae_lens.toolkit.pretrained_saes import load_sparsity\n", - "import plotly.express as px \n", + "import plotly.express as px\n", "\n", "path = \"open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_0\"\n", "\n", "\n", - "sparsity = load_sparsity(path)#[\"sparsity\"]\n", + "sparsity = load_sparsity(path) # [\"sparsity\"]\n", "\n", "\n", "px.histogram(sparsity.cpu(), nbins=100).write_html(\"sparsity_histogram.html\")" @@ -4226,23 +4226,28 @@ } ], "source": [ - "import pandas as pd \n", + "import pandas as pd\n", + "\n", "# get all json files in all subfolders of the mother path\n", "import os\n", "import json\n", "from IPython.display import display\n", "import imgkit\n", "\n", + "\n", "def 
get_all_json_files(mother_path):\n", " json_files = []\n", - " \n", + "\n", " for root, dirs, files in os.walk(mother_path):\n", " for file in files:\n", " if file.endswith(\"metrics.json\"):\n", " json_files.append(os.path.join(root, file))\n", " return json_files\n", "\n", - "def get_benchmark_stats_csv(mother_path = \"open_ai_sae_weights_resid_post_attn_reformatted\"):\n", + "\n", + "def get_benchmark_stats_csv(\n", + " mother_path=\"open_ai_sae_weights_resid_post_attn_reformatted\",\n", + "):\n", " json_files = get_all_json_files(mother_path)\n", " eval_metrics = {}\n", "\n", @@ -4250,33 +4255,34 @@ " with open(file, \"r\") as f:\n", " data = json.load(f)\n", " eval_metrics[file] = data\n", - " \n", - " \n", + "\n", " df = pd.DataFrame(eval_metrics).T\n", - " df[\"filepath\"]=df.index\n", + " df[\"filepath\"] = df.index\n", " df.head()\n", " pattern = r\".*/v(\\d+)_(\\d+)k_layer_(\\d+)/metrics\\.json\"\n", "\n", - " df[['version', 'd_sae', 'layer']] = df.filepath.str.extract(pattern)\n", + " df[[\"version\", \"d_sae\", \"layer\"]] = df.filepath.str.extract(pattern)\n", " # move these columns to the start\n", " cols = df.columns.tolist()\n", " cols = cols[-3:] + cols[:-3]\n", " df = df[cols]\n", " df[\"layer\"] = df[\"layer\"].astype(int)\n", - " \n", + "\n", " # remove \"metrics\" prefix from the columns\n", " df.columns = [i.replace(\"metrics/\", \"\") for i in df.columns]\n", " df.sort_values(by=[\"version\", \"d_sae\", \"layer\"], inplace=True)\n", " df.to_csv(os.path.join(mother_path, \"benchmark_stats.csv\"))\n", - " df.style.background_gradient(cmap='viridis', axis=0).to_html(os.path.join(mother_path, \"benchmark_stats.html\"))\n", - " \n", + " df.style.background_gradient(cmap=\"viridis\", axis=0).to_html(\n", + " os.path.join(mother_path, \"benchmark_stats.html\")\n", + " )\n", + "\n", " # read the html\n", " with open(os.path.join(mother_path, \"benchmark_stats.html\"), \"r\") as f:\n", " html = f.read()\n", " imgkit.from_string(html, 
os.path.join(mother_path, \"benchmark_stats.png\"))\n", - " \n", - " return df.style.background_gradient(cmap='viridis', axis=0)\n", - " \n", + "\n", + " return df.style.background_gradient(cmap=\"viridis\", axis=0)\n", + "\n", "\n", "# list all paths that start with OAI in the current fold\n", "paths = [i for i in os.listdir(\".\") if i.startswith(\"OAI\")]\n", diff --git a/pyproject.toml b/pyproject.toml index f79ed43a..5cd650e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ typing-extensions = "^4.10.0" [tool.poetry.group.dev.dependencies] -black = "24.4.0" +black = { version = "24.4.0", extras = ["jupyter"] } pytest = "^8.0.2" pytest-cov = "^4.1.0" pre-commit = "^3.6.2" diff --git a/scripts/joseph_curt_pairing_gemma_scope_saes.ipynb b/scripts/joseph_curt_pairing_gemma_scope_saes.ipynb index f0817815..df4f6421 100644 --- a/scripts/joseph_curt_pairing_gemma_scope_saes.ipynb +++ b/scripts/joseph_curt_pairing_gemma_scope_saes.ipynb @@ -5,9 +5,7 @@ "execution_count": 6, "metadata": {}, "outputs": [], - "source": [ - "\n" - ] + "source": [] }, { "cell_type": "code", @@ -27,57 +25,60 @@ } ], "source": [ - "import re \n", + "import re\n", "import pandas as pd\n", "\n", "\n", "from huggingface_hub import HfApi\n", "import os\n", "\n", + "\n", "def list_repo_files(repo_id):\n", " api = HfApi()\n", " repo_files = api.list_repo_files(repo_id)\n", " return repo_files\n", "\n", + "\n", "files = list_repo_files(repo_id)\n", "\n", "# print(f\"Files in the repository '{repo_id}':\")\n", "# for file in files:\n", "# print(file)\n", "\n", + "\n", "def get_details_from_file_path(file_path):\n", - " '''\n", + " \"\"\"\n", " eg: layer_11/width_16k/average_l0_79\n", - " \n", + "\n", " layer = 11\n", " width = 16k\n", " l0_or_canonical = \"79\"\n", - " \n", + "\n", " or if layer_11/width_16k/canonical\n", - " \n", + "\n", " layer = 11\n", " width = 16k\n", " l0_or_canonical = \"canonical\"\n", - " \n", + "\n", " or if layer_11/width_1m/average_l0_79\n", - " \n", + 
"\n", " layer = 11\n", " width = 1m\n", " l0_or_canonical = \"79\"\n", - " '''\n", - " \n", - " layer = re.search(r'layer_(\\d+)', file_path).group(1)\n", - " width = re.search(r'width_(\\d+[k|m])', file_path).group(1)\n", - " l0 = re.search(r'average_l0_(\\d+)', file_path)\n", + " \"\"\"\n", + "\n", + " layer = re.search(r\"layer_(\\d+)\", file_path).group(1)\n", + " width = re.search(r\"width_(\\d+[k|m])\", file_path).group(1)\n", + " l0 = re.search(r\"average_l0_(\\d+)\", file_path)\n", " if l0:\n", " l0 = l0.group(1)\n", " else:\n", - " l0 = re.search(r'(canonical)', file_path).group(1)\n", - " \n", - " \n", + " l0 = re.search(r\"(canonical)\", file_path).group(1)\n", + "\n", " return layer, width, l0\n", "\n", - "# # test it \n", + "\n", + "# # test it\n", "# file_path = 'layer_11/width_16k/average_l0_79'\n", "# layer, width, l0 = get_details_from_file_path(file_path)\n", "# print(f\"layer: {layer}, width: {width}, l0: {l0}\")\n", @@ -93,28 +94,28 @@ "# print(f\"layer: {layer}, width: {width}, l0: {l0}\")\n", "\n", "\n", - "\n", "def generate_entries(repo_id):\n", " entries = []\n", " files = list_repo_files(repo_id)\n", " for file in files:\n", - " if 'params.npz' in file:\n", + " if \"params.npz\" in file:\n", " entry = {}\n", " # print(file)\n", " layer, width, l0 = get_details_from_file_path(file)\n", " folder_path = os.path.dirname(file)\n", - " entry['repo_id'] = repo_id\n", - " entry['id'] = folder_path\n", - " entry['path'] = folder_path\n", - " entry['l0'] = l0\n", - " entry['layer'] = layer\n", - " entry['width'] = width\n", - " \n", + " entry[\"repo_id\"] = repo_id\n", + " entry[\"id\"] = folder_path\n", + " entry[\"path\"] = folder_path\n", + " entry[\"l0\"] = l0\n", + " entry[\"layer\"] = layer\n", + " entry[\"width\"] = width\n", + "\n", " entries.append(entry)\n", " return entries\n", "\n", - "def df_to_yaml(df, file_path, canonical = False):\n", - " '''\n", + "\n", + "def df_to_yaml(df, file_path, canonical=False):\n", + " \"\"\"\n", " EXAMPLE 
STRUCTURE:\n", "\n", " gemma-scope-2b-pt-res:\n", @@ -126,11 +127,13 @@ " path: layer_11/width_16k/average_l0_79\n", " l0: 79.0\n", "\n", - " '''\n", - " repo_id = df.iloc[0]['repo_id']\n", - " release_id = repo_id.split('/')[1] + '-canonical' if canonical else repo_id.split('/')[1]\n", - " with open(file_path, 'w') as f:\n", - " \n", + " \"\"\"\n", + " repo_id = df.iloc[0][\"repo_id\"]\n", + " release_id = (\n", + " repo_id.split(\"/\")[1] + \"-canonical\" if canonical else repo_id.split(\"/\")[1]\n", + " )\n", + " with open(file_path, \"w\") as f:\n", + "\n", " f.write(f\"{release_id}:\\n\")\n", " f.write(f\" repo_id: {repo_id}\\n\")\n", " f.write(f\" model: gemma-2-2b\\n\")\n", @@ -139,7 +142,7 @@ " for index, row in df.iterrows():\n", " f.write(f\" - id: {row['id']}\\n\")\n", " f.write(f\" path: {row['path']}\\n\")\n", - " if row['l0'] != 'canonical':\n", + " if row[\"l0\"] != \"canonical\":\n", " f.write(f\" l0: {row['l0']}\\n\")\n", " # f.write(f\" l0: {row['l0']}\\n\")\n", " # f.write(f\" layer: {row['layer']}\\n\")\n", @@ -154,7 +157,7 @@ " \"google/gemma-scope-9b-pt-res\",\n", " \"google/gemma-scope-9b-pt-mlp\",\n", " \"google/gemma-scope-9b-pt-att\",\n", - " \"google/gemma-scope-27b-pt-res\"\n", + " \"google/gemma-scope-27b-pt-res\",\n", "]\n", "\n", "for repo_id in repo_ids:\n", @@ -162,21 +165,26 @@ " entries = generate_entries(repo_id)\n", "\n", " df = pd.DataFrame(entries)\n", - " df[\"layer\"]= pd.to_numeric(df[\"layer\"])\n", - " df.sort_values(by=['width', 'layer', 'l0'], inplace=True)\n", + " df[\"layer\"] = pd.to_numeric(df[\"layer\"])\n", + " df.sort_values(by=[\"width\", \"layer\", \"l0\"], inplace=True)\n", " df.head(30)\n", "\n", - " canonical_only_df = df[df['l0'] == 'canonical']\n", - " non_canonical_df = df[df['l0'] != 'canonical']\n", + " canonical_only_df = df[df[\"l0\"] == \"canonical\"]\n", + " non_canonical_df = df[df[\"l0\"] != \"canonical\"]\n", "\n", - " df_to_yaml(non_canonical_df, f'{repo_id.split(\"/\")[1]}_not_canonical.yaml', 
canonical=False)\n", + " df_to_yaml(\n", + " non_canonical_df, f'{repo_id.split(\"/\")[1]}_not_canonical.yaml', canonical=False\n", + " )\n", " if canonical_only_df.shape[0] == 0:\n", " print(f\"No canonical entries found in {repo_id.split('/')[1]}\")\n", " continue\n", " else:\n", - " df_to_yaml(canonical_only_df, f'{repo_id.split(\"/\")[1]}_canonical_only.yaml', canonical=True) \n", + " df_to_yaml(\n", + " canonical_only_df,\n", + " f'{repo_id.split(\"/\")[1]}_canonical_only.yaml',\n", + " canonical=True,\n", + " )\n", "\n", - " \n", " # !cat canonical_only.yaml" ] }, @@ -229,7 +237,7 @@ "\n", "\n", "# Path to the YAML file\n", - "yaml_file = 'pretrained_saes.yaml'\n", + "yaml_file = \"pretrained_saes.yaml\"\n", "\n", "# Initialize yamel.yaml\n", "yaml = YAML()\n", @@ -237,7 +245,7 @@ "yaml.indent(mapping=2, sequence=4, offset=2)\n", "\n", "# Read the existing YAML file\n", - "with open(yaml_file, 'r') as file:\n", + "with open(yaml_file, \"r\") as file:\n", " data = yaml.load(file)\n", "\n", "# Generate new entries\n", @@ -245,20 +253,20 @@ "\n", "# Create a CommentedMap for gemmascope-2b-pt-res\n", "gemmascope_data = CommentedMap()\n", - "gemmascope_data['repo_id'] = \"gg-hf/gemmascope-2b-pt-res\"\n", - "gemmascope_data['model'] = \"gemma-2-2b\"\n", - "gemmascope_data['conversion_func'] = \"gemma_2\"\n", - "gemmascope_data['saes'] = new_entries\n", + "gemmascope_data[\"repo_id\"] = \"gg-hf/gemmascope-2b-pt-res\"\n", + "gemmascope_data[\"model\"] = \"gemma-2-2b\"\n", + "gemmascope_data[\"conversion_func\"] = \"gemma_2\"\n", + "gemmascope_data[\"saes\"] = new_entries\n", "\n", "# Remove the existing gemmascope-2b-pt-res entry if it exists\n", - "if 'SAE_LOOKUP' in data and 'gemmascope-2b-pt-res' in data['SAE_LOOKUP']:\n", - " del data['SAE_LOOKUP']['gemmascope-2b-pt-res']\n", + "if \"SAE_LOOKUP\" in data and \"gemmascope-2b-pt-res\" in data[\"SAE_LOOKUP\"]:\n", + " del data[\"SAE_LOOKUP\"][\"gemmascope-2b-pt-res\"]\n", "\n", "# Add gemmascope-2b-pt-res at the 
end\n", - "data['SAE_LOOKUP']['gemmascope-2b-pt-res'] = gemmascope_data\n", + "data[\"SAE_LOOKUP\"][\"gemmascope-2b-pt-res\"] = gemmascope_data\n", "\n", "# Write the updated YAML file\n", - "with open(yaml_file, 'w') as file:\n", + "with open(yaml_file, \"w\") as file:\n", " yaml.dump(data, file)\n", "\n", "print(f\"YAML file updated: {yaml_file}\")" @@ -293,19 +301,19 @@ } ], "source": [ - "from sae_lens import HookedSAETransformer, SAE \n", + "from sae_lens import HookedSAETransformer, SAE\n", "\n", "device = \"cuda\"\n", "\n", - "model = HookedSAETransformer.from_pretrained(\"gemma-2-2b\", device = device)\n", + "model = HookedSAETransformer.from_pretrained(\"gemma-2-2b\", device=device)\n", "\n", "# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)\n", "# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict\n", - "# We also return the feature sparsities which are stored in HF for convenience. 
\n", + "# We also return the feature sparsities which are stored in HF for convenience.\n", "sae, cfg_dict, sparsity = SAE.from_pretrained(\n", - " release = \"gemma-scope-9b-pt-mlp\", # <- Release name \n", - " sae_id = \"layer_2/width_131k/average_l0_12\", # <- SAE id (not always a hook point!)\n", - " device = device\n", + " release=\"gemma-scope-9b-pt-mlp\", # <- Release name\n", + " sae_id=\"layer_2/width_131k/average_l0_12\", # <- SAE id (not always a hook point!)\n", + " device=device,\n", ")" ] }, diff --git a/scripts/wandb_to_hf.ipynb b/scripts/wandb_to_hf.ipynb index a66dcf6d..fc9c3735 100644 --- a/scripts/wandb_to_hf.ipynb +++ b/scripts/wandb_to_hf.ipynb @@ -68,8 +68,7 @@ "metadata": {}, "outputs": [], "source": [ - "api = wandb.Api()\n", - "\n" + "api = wandb.Api()" ] }, { @@ -145,10 +144,9 @@ " for a in run.logged_artifacts():\n", " if is_model(a, model_endstr):\n", " return a\n", - " \n", + "\n", " else:\n", - " continue\n", - " " + " continue" ] }, { @@ -162,7 +160,7 @@ " for a in run.logged_artifacts():\n", " if is_sparsity(a, sparsity_endstr):\n", " return a\n", - " \n", + "\n", " else:\n", " continue" ] @@ -212,6 +210,7 @@ "outputs": [], "source": [ "from huggingface_hub import HfApi\n", + "\n", "api = HfApi()" ] }, @@ -222,12 +221,12 @@ "outputs": [], "source": [ "def upload_to_hf(model_path):\n", - " repo_path = model_path.split('/')[-1]\n", + " repo_path = model_path.split(\"/\")[-1]\n", " api.upload_folder(\n", - " folder_path=model_path,\n", - " repo_id=hf_repo_id,\n", - " path_in_repo=repo_path,\n", - " token=hf_token,\n", + " folder_path=model_path,\n", + " repo_id=hf_repo_id,\n", + " path_in_repo=repo_path,\n", + " token=hf_token,\n", " )" ] }, diff --git a/tutorials/Hooked_SAE_Transformer_Demo.ipynb b/tutorials/Hooked_SAE_Transformer_Demo.ipynb index e5ef35cc..1771410c 100644 --- a/tutorials/Hooked_SAE_Transformer_Demo.ipynb +++ b/tutorials/Hooked_SAE_Transformer_Demo.ipynb @@ -42,10 +42,11 @@ "DEVELOPMENT_MODE = False\n", "try:\n", " 
import google.colab\n", + "\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", " %pip install git+https://github.com/jbloomAus/SAELens\n", - " \n", + "\n", "except:\n", " IN_COLAB = False\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", @@ -73,11 +74,29 @@ "import plotly.graph_objects as go\n", "\n", "update_layout_set = {\n", - " \"xaxis_range\", \"yaxis_range\", \"hovermode\", \"xaxis_title\", \"yaxis_title\", \"colorbar\", \"colorscale\", \"coloraxis\",\n", - " \"title_x\", \"bargap\", \"bargroupgap\", \"xaxis_tickformat\", \"yaxis_tickformat\", \"title_y\", \"legend_title_text\", \"xaxis_showgrid\",\n", - " \"xaxis_gridwidth\", \"xaxis_gridcolor\", \"yaxis_showgrid\", \"yaxis_gridwidth\"\n", + " \"xaxis_range\",\n", + " \"yaxis_range\",\n", + " \"hovermode\",\n", + " \"xaxis_title\",\n", + " \"yaxis_title\",\n", + " \"colorbar\",\n", + " \"colorscale\",\n", + " \"coloraxis\",\n", + " \"title_x\",\n", + " \"bargap\",\n", + " \"bargroupgap\",\n", + " \"xaxis_tickformat\",\n", + " \"yaxis_tickformat\",\n", + " \"title_y\",\n", + " \"legend_title_text\",\n", + " \"xaxis_showgrid\",\n", + " \"xaxis_gridwidth\",\n", + " \"xaxis_gridcolor\",\n", + " \"yaxis_showgrid\",\n", + " \"yaxis_gridwidth\",\n", "}\n", "\n", + "\n", "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n", " if isinstance(tensor, list):\n", " tensor = torch.stack(tensor)\n", @@ -89,43 +108,67 @@ " facet_labels = None\n", " if \"color_continuous_scale\" not in kwargs_pre:\n", " kwargs_pre[\"color_continuous_scale\"] = \"RdBu\"\n", - " fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={\"x\":xaxis, \"y\":yaxis}, **kwargs_pre).update_layout(**kwargs_post)\n", + " fig = px.imshow(\n", + " utils.to_numpy(tensor),\n", + " color_continuous_midpoint=0.0,\n", + " labels={\"x\": xaxis, \"y\": yaxis},\n", + " **kwargs_pre,\n", + " ).update_layout(**kwargs_post)\n", " if facet_labels:\n", " for i, label in 
enumerate(facet_labels):\n", - " fig.layout.annotations[i]['text'] = label\n", + " fig.layout.annotations[i][\"text\"] = label\n", "\n", " fig.show(renderer)\n", "\n", - "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", renderer=None, return_fig=False, **kwargs):\n", + "\n", + "def scatter(\n", + " x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", renderer=None, return_fig=False, **kwargs\n", + "):\n", " x = utils.to_numpy(x)\n", " y = utils.to_numpy(y)\n", - " fig = px.scatter(y=y, x=x, labels={\"x\":xaxis, \"y\":yaxis, \"color\":caxis}, **kwargs)\n", + " fig = px.scatter(\n", + " y=y, x=x, labels={\"x\": xaxis, \"y\": yaxis, \"color\": caxis}, **kwargs\n", + " )\n", " if return_fig:\n", " return fig\n", " fig.show(renderer)\n", "\n", + "\n", "from typing import List\n", - "def show_avg_logit_diffs(x_axis: List[str], per_prompt_logit_diffs: List[torch.tensor]):\n", "\n", "\n", - " y_data = [per_prompt_logit_diff.mean().item() for per_prompt_logit_diff in per_prompt_logit_diffs]\n", - " error_y_data = [per_prompt_logit_diff.std().item() for per_prompt_logit_diff in per_prompt_logit_diffs] \n", + "def show_avg_logit_diffs(x_axis: List[str], per_prompt_logit_diffs: List[torch.tensor]):\n", "\n", - " fig = go.Figure(data=[go.Bar(\n", - " x=x_axis,\n", - " y=y_data,\n", - " error_y=dict(\n", - " type='data', # specifies that the actual values are given\n", - " array=error_y_data, # the magnitudes of the errors\n", - " visible=True # make error bars visible\n", - " ),\n", - " )])\n", + " y_data = [\n", + " per_prompt_logit_diff.mean().item()\n", + " for per_prompt_logit_diff in per_prompt_logit_diffs\n", + " ]\n", + " error_y_data = [\n", + " per_prompt_logit_diff.std().item()\n", + " for per_prompt_logit_diff in per_prompt_logit_diffs\n", + " ]\n", + "\n", + " fig = go.Figure(\n", + " data=[\n", + " go.Bar(\n", + " x=x_axis,\n", + " y=y_data,\n", + " error_y=dict(\n", + " type=\"data\", # specifies that the actual values are given\n", + " array=error_y_data, # the 
magnitudes of the errors\n", + " visible=True, # make error bars visible\n", + " ),\n", + " )\n", + " ]\n", + " )\n", "\n", " # Customize layout\n", - " fig.update_layout(title_text=f'Logit Diff after Interventions',\n", - " xaxis_title_text='Intervention',\n", - " yaxis_title_text='Logit diff',\n", - " plot_bgcolor='white')\n", + " fig.update_layout(\n", + " title_text=f\"Logit Diff after Interventions\",\n", + " xaxis_title_text=\"Intervention\",\n", + " yaxis_title_text=\"Logit diff\",\n", + " plot_bgcolor=\"white\",\n", + " )\n", "\n", " # Show the figure\n", " fig.show()" @@ -141,7 +184,7 @@ " device = \"cuda\"\n", "# elif torch.backends.mps.is_available():\n", "# device = \"mps\"\n", - "else: \n", + "else:\n", " device = \"cpu\"\n", "torch.set_grad_enabled(False)" ] @@ -167,7 +210,10 @@ "outputs": [], "source": [ "from sae_lens import HookedSAETransformer\n", - "model: HookedSAETransformer = HookedSAETransformer.from_pretrained(\"gpt2-small\").to(device)" + "\n", + "model: HookedSAETransformer = HookedSAETransformer.from_pretrained(\"gpt2-small\").to(\n", + " device\n", + ")" ] }, { @@ -190,7 +236,10 @@ " \"After Martin and Amy went to the park,{} gave a drink to\",\n", "]\n", "names = [\n", - " (\" John\", \" Mary\",),\n", + " (\n", + " \" John\",\n", + " \" Mary\",\n", + " ),\n", " (\" Tom\", \" James\"),\n", " (\" Dan\", \" Sid\"),\n", " (\" Martin\", \" Amy\"),\n", @@ -232,12 +281,15 @@ " return answer_logit_diff\n", " else:\n", " return answer_logit_diff.mean()\n", - " \n", + "\n", + "\n", "tokens = model.to_tokens(prompts, prepend_bos=True)\n", "original_logits, cache = model.run_with_cache(tokens)\n", "original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)\n", "print(f\"Original average logit diff: {original_average_logit_diff}\")\n", - "original_per_prompt_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)\n", + "original_per_prompt_logit_diff = logits_to_ave_logit_diff(\n", + " 
original_logits, answer_tokens, per_prompt=True\n", + ")\n", "print(f\"Original per prompt logit diff: {original_per_prompt_logit_diff}\")" ] }, @@ -289,7 +341,7 @@ " device=device,\n", " )\n", " hook_name_to_sae[sae.cfg.hook_name] = sae\n", - " \n", + "\n", "\n", "print(hook_name_to_sae.keys())" ] @@ -328,18 +380,24 @@ "metadata": {}, "outputs": [], "source": [ - "all_layers = [[0, 3], [2, 4], [5,6], [7, 8], [9, 10, 11]]\n", - "x_axis = ['Clean Baseline']\n", + "all_layers = [[0, 3], [2, 4], [5, 6], [7, 8], [9, 10, 11]]\n", + "x_axis = [\"Clean Baseline\"]\n", "per_prompt_logit_diffs = [\n", - " original_per_prompt_logit_diff, \n", + " original_per_prompt_logit_diff,\n", "]\n", "\n", "for layers in all_layers:\n", - " hooked_saes = [hook_name_to_sae[utils.get_act_name('z', layer)] for layer in layers]\n", - " logits_with_saes = model.run_with_saes(tokens, saes=hooked_saes, use_error_term=None)\n", - " average_logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)\n", - " per_prompt_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)\n", - " \n", + " hooked_saes = [hook_name_to_sae[utils.get_act_name(\"z\", layer)] for layer in layers]\n", + " logits_with_saes = model.run_with_saes(\n", + " tokens, saes=hooked_saes, use_error_term=None\n", + " )\n", + " average_logit_diff_with_saes = logits_to_ave_logit_diff(\n", + " logits_with_saes, answer_tokens\n", + " )\n", + " per_prompt_diff_with_saes = logits_to_ave_logit_diff(\n", + " logits_with_saes, answer_tokens, per_prompt=True\n", + " )\n", + "\n", " x_axis.append(f\"With SAEs L{layers}\")\n", " per_prompt_logit_diffs.append(per_prompt_diff_with_saes)\n", "\n", @@ -371,16 +429,17 @@ "outputs": [], "source": [ "layer, s2_pos = 5, 10\n", - "saes = [hook_name_to_sae[utils.get_act_name('z', layer)]]\n", + "saes = [hook_name_to_sae[utils.get_act_name(\"z\", layer)]]\n", "_, cache = model.run_with_cache_with_saes(tokens, saes=saes)\n", - "sae_acts = 
cache[utils.get_act_name('z', layer) + \".hook_sae_acts_post\"][:, s2_pos, :]\n", + "sae_acts = cache[utils.get_act_name(\"z\", layer) + \".hook_sae_acts_post\"][:, s2_pos, :]\n", "live_feature_mask = sae_acts > 0\n", "live_feature_union = live_feature_mask.any(dim=0)\n", "\n", "imshow(\n", " sae_acts[:, live_feature_union],\n", - " title = \"Activations of Live SAE features at L5 S2 position per prompt\",\n", - " xaxis=\"Feature Id\", yaxis=\"Prompt\",\n", + " title=\"Activations of Live SAE features at L5 S2 position per prompt\",\n", + " xaxis=\"Feature Id\",\n", + " yaxis=\"Prompt\",\n", " x=list(map(str, live_feature_union.nonzero().flatten().tolist())),\n", ")" ] @@ -416,16 +475,19 @@ "source": [ "def ablate_sae_feature(sae_acts, hook, pos, feature_id):\n", " if pos is None:\n", - " sae_acts[:, :, feature_id] = 0.\n", + " sae_acts[:, :, feature_id] = 0.0\n", " else:\n", - " sae_acts[:, pos, feature_id] = 0.\n", + " sae_acts[:, pos, feature_id] = 0.0\n", " return sae_acts\n", "\n", + "\n", "layer = 5\n", - "sae = hook_name_to_sae[utils.get_act_name('z', layer)]\n", + "sae = hook_name_to_sae[utils.get_act_name(\"z\", layer)]\n", "\n", "logits_with_saes = model.run_with_saes(tokens, saes=sae)\n", - "clean_sae_baseline_per_prompt = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)\n", + "clean_sae_baseline_per_prompt = logits_to_ave_logit_diff(\n", + " logits_with_saes, answer_tokens, per_prompt=True\n", + ")\n", "\n", "all_live_features = torch.arange(sae.cfg.d_sae)[live_feature_union.cpu()]\n", "\n", @@ -433,20 +495,35 @@ "fid_to_idx = {fid.item(): idx for idx, fid in enumerate(all_live_features)}\n", "\n", "\n", - "abl_layer, abl_pos = 5, 10\n", + "abl_layer, abl_pos = 5, 10\n", "for feature_id in tqdm.tqdm(all_live_features):\n", " feature_id = feature_id.item()\n", " abl_feature_logits = model.run_with_hooks_with_saes(\n", " tokens,\n", " saes=sae,\n", - " fwd_hooks=[(utils.get_act_name('z', abl_layer) + \".hook_sae_acts_post\", 
partial(ablate_sae_feature, pos=abl_pos, feature_id=feature_id))]\n", - " ) # [batch, seq, vocab]\n", - " \n", - " abl_feature_logit_diff = logits_to_ave_logit_diff(abl_feature_logits, answer_tokens, per_prompt=True) # [batch]\n", - " causal_effects[:, fid_to_idx[feature_id]] = abl_feature_logit_diff - clean_sae_baseline_per_prompt\n", + " fwd_hooks=[\n", + " (\n", + " utils.get_act_name(\"z\", abl_layer) + \".hook_sae_acts_post\",\n", + " partial(ablate_sae_feature, pos=abl_pos, feature_id=feature_id),\n", + " )\n", + " ],\n", + " ) # [batch, seq, vocab]\n", + "\n", + " abl_feature_logit_diff = logits_to_ave_logit_diff(\n", + " abl_feature_logits, answer_tokens, per_prompt=True\n", + " ) # [batch]\n", + " causal_effects[:, fid_to_idx[feature_id]] = (\n", + " abl_feature_logit_diff - clean_sae_baseline_per_prompt\n", + " )\n", "\n", "\n", - "imshow(causal_effects, title=f\"Change in logit diff when ablating L{abl_layer} SAE features for all prompts at pos {abl_pos}\", xaxis=\"Feature Idx\", yaxis=\"Prompt Idx\", x=list(map(str, all_live_features.tolist())))" + "imshow(\n", + " causal_effects,\n", + " title=f\"Change in logit diff when ablating L{abl_layer} SAE features for all prompts at pos {abl_pos}\",\n", + " xaxis=\"Feature Idx\",\n", + " yaxis=\"Prompt Idx\",\n", + " x=list(map(str, all_live_features.tolist())),\n", + ")" ] }, { @@ -482,7 +559,7 @@ "source": [ "print(\"Attached SAEs before add_sae\", model.acts_to_saes)\n", "layer = 5\n", - "sae = hook_name_to_sae[utils.get_act_name('z', layer)]\n", + "sae = hook_name_to_sae[utils.get_act_name(\"z\", layer)]\n", "model.add_sae(sae)\n", "print(\"Attached SAEs after add_sae\", model.acts_to_saes)" ] @@ -505,7 +582,9 @@ "\n", "average_logit_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens)\n", "print(f\"Average logit diff with SAEs: {average_logit_diff_with_saes}\")\n", - "per_prompt_diff_with_saes = logits_to_ave_logit_diff(logits_with_saes, answer_tokens, per_prompt=True)" + 
"per_prompt_diff_with_saes = logits_to_ave_logit_diff(\n", + " logits_with_saes, answer_tokens, per_prompt=True\n", + ")" ] }, { @@ -531,15 +610,16 @@ "layer = 5\n", "_, cache = model.run_with_cache(tokens)\n", "s2_pos = 10\n", - "sae_acts = cache[utils.get_act_name('z', layer) + \".hook_sae_acts_post\"][:, s2_pos, :]\n", + "sae_acts = cache[utils.get_act_name(\"z\", layer) + \".hook_sae_acts_post\"][:, s2_pos, :]\n", "\n", "live_feature_mask = sae_acts > 0\n", "live_feature_union = live_feature_mask.any(dim=0)\n", "\n", "imshow(\n", " sae_acts[:, live_feature_union],\n", - " title = \"Activations of Live SAE features at L5 S2 position per prompt\",\n", - " xaxis=\"Feature Id\", yaxis=\"Prompt\",\n", + " title=\"Activations of Live SAE features at L5 S2 position per prompt\",\n", + " xaxis=\"Feature Id\",\n", + " yaxis=\"Prompt\",\n", " x=list(map(str, live_feature_union.nonzero().flatten().tolist())),\n", ")" ] @@ -578,35 +658,51 @@ " return v_input\n", "\n", "\n", - "s_inhib_heads = [(7, 3), (7, 9), (8,6), (8,10)]\n", + "s_inhib_heads = [(7, 3), (7, 9), (8, 6), (8, 10)]\n", "\n", "results = torch.zeros(tokens.shape[0], all_live_features.shape[0])\n", "\n", "W_O_cat = einops.rearrange(\n", - " model.W_O,\n", - " \"n_layers n_heads d_head d_model -> n_layers (n_heads d_head) d_model\"\n", + " model.W_O, \"n_layers n_heads d_head d_model -> n_layers (n_heads d_head) d_model\"\n", ")\n", "\n", "for feature_id in tqdm.tqdm(all_live_features):\n", " feature_id = feature_id.item()\n", - " feature_acts = cache[utils.get_act_name('z', abl_layer) + \".hook_sae_acts_post\"][:, abl_pos, feature_id] # [batch]\n", - " feature_dirs = (feature_acts.unsqueeze(-1) * sae.W_dec[feature_id]) @ W_O_cat[abl_layer]\n", + " feature_acts = cache[utils.get_act_name(\"z\", abl_layer) + \".hook_sae_acts_post\"][\n", + " :, abl_pos, feature_id\n", + " ] # [batch]\n", + " feature_dirs = (feature_acts.unsqueeze(-1) * sae.W_dec[feature_id]) @ W_O_cat[\n", + " abl_layer\n", + " ]\n", " hook_fns = 
[\n", - " (utils.get_act_name('v_input', layer), partial(path_patch_v_input, feature_dirs=feature_dirs, pos=abl_pos, head_index=head)) for (layer, head) in s_inhib_heads\n", + " (\n", + " utils.get_act_name(\"v_input\", layer),\n", + " partial(\n", + " path_patch_v_input,\n", + " feature_dirs=feature_dirs,\n", + " pos=abl_pos,\n", + " head_index=head,\n", + " ),\n", + " )\n", + " for (layer, head) in s_inhib_heads\n", " ]\n", " path_patched_logits = model.run_with_hooks(\n", - " tokens,\n", - " return_type=\"logits\",\n", - " fwd_hooks=hook_fns\n", + " tokens, return_type=\"logits\", fwd_hooks=hook_fns\n", " )\n", "\n", - " path_patched_logit_diff = logits_to_ave_logit_diff(path_patched_logits, answer_tokens, per_prompt=True)\n", - " results[:, fid_to_idx[feature_id]] = path_patched_logit_diff - clean_sae_baseline_per_prompt\n", + " path_patched_logit_diff = logits_to_ave_logit_diff(\n", + " path_patched_logits, answer_tokens, per_prompt=True\n", + " )\n", + " results[:, fid_to_idx[feature_id]] = (\n", + " path_patched_logit_diff - clean_sae_baseline_per_prompt\n", + " )\n", "\n", "imshow(\n", - " results, \n", + " results,\n", " title=f\"Change in logit diff when path patching features from S_inhibition heads values per prompts\",\n", - " xaxis=\"Feature Id\", yaxis=\"Prompt Idx\", x=list(map(str, all_live_features.tolist()))\n", + " xaxis=\"Feature Id\",\n", + " yaxis=\"Prompt Idx\",\n", + " x=list(map(str, all_live_features.tolist())),\n", ")" ] }, @@ -669,9 +765,10 @@ "outputs": [], "source": [ "import copy\n", - "l5_sae = hook_name_to_sae[utils.get_act_name('z', 5)]\n", + "\n", + "l5_sae = hook_name_to_sae[utils.get_act_name(\"z\", 5)]\n", "l5_sae_with_error = copy.deepcopy(l5_sae)\n", - "l5_sae_with_error.use_error_term=True\n", + "l5_sae_with_error.use_error_term = True\n", "model.add_sae(l5_sae_with_error)\n", "print(\"Attached SAEs after adding l5_sae_with_error:\", model.acts_to_saes)" ] @@ -710,51 +807,77 @@ "source": [ "def ablate_sae_feature(sae_acts, 
hook, pos, feature_id):\n", " if pos is None:\n", - " sae_acts[:, :, feature_id] = 0.\n", + " sae_acts[:, :, feature_id] = 0.0\n", " else:\n", - " sae_acts[:, pos, feature_id] = 0.\n", + " sae_acts[:, pos, feature_id] = 0.0\n", " return sae_acts\n", "\n", + "\n", "layer = 5\n", - "hooked_encoder = model.acts_to_saes[utils.get_act_name('z', layer)]\n", + "hooked_encoder = model.acts_to_saes[utils.get_act_name(\"z\", layer)]\n", "all_live_features = torch.arange(hooked_encoder.cfg.d_sae)[live_feature_union.cpu()]\n", "\n", "causal_effects = torch.zeros((len(prompts), all_live_features.shape[0]))\n", "fid_to_idx = {fid.item(): idx for idx, fid in enumerate(all_live_features)}\n", "\n", "\n", - "abl_layer, abl_pos = 5, 10\n", + "abl_layer, abl_pos = 5, 10\n", "for feature_id in tqdm.tqdm(all_live_features):\n", " feature_id = feature_id.item()\n", " abl_feature_logits = model.run_with_hooks(\n", " tokens,\n", " return_type=\"logits\",\n", - " fwd_hooks=[(utils.get_act_name('z', abl_layer) + \".hook_sae_acts_post\", partial(ablate_sae_feature, pos=abl_pos, feature_id=feature_id))]\n", - " ) # [batch, seq, vocab]\n", - " \n", - " abl_feature_logit_diff = logits_to_ave_logit_diff(abl_feature_logits, answer_tokens, per_prompt=True) # [batch]\n", - " causal_effects[:, fid_to_idx[feature_id]] = abl_feature_logit_diff - original_per_prompt_logit_diff\n", + " fwd_hooks=[\n", + " (\n", + " utils.get_act_name(\"z\", abl_layer) + \".hook_sae_acts_post\",\n", + " partial(ablate_sae_feature, pos=abl_pos, feature_id=feature_id),\n", + " )\n", + " ],\n", + " ) # [batch, seq, vocab]\n", + "\n", + " abl_feature_logit_diff = logits_to_ave_logit_diff(\n", + " abl_feature_logits, answer_tokens, per_prompt=True\n", + " ) # [batch]\n", + " causal_effects[:, fid_to_idx[feature_id]] = (\n", + " abl_feature_logit_diff - original_per_prompt_logit_diff\n", + " )\n", + "\n", "\n", "def able_sae_error(sae_error, hook, pos):\n", " if pos is None:\n", - " sae_error = 0.\n", + " sae_error = 0.0\n", " 
else:\n", - " sae_error[:, pos, ...] = 0.\n", + " sae_error[:, pos, ...] = 0.0\n", " return sae_error\n", "\n", "\n", "abl_error_logits = model.run_with_hooks(\n", " tokens,\n", " return_type=\"logits\",\n", - " fwd_hooks=[(utils.get_act_name('z', abl_layer) + \".hook_sae_error\", partial(able_sae_error, pos=abl_pos))]\n", - ") # [batch, seq, vocab]\n", + " fwd_hooks=[\n", + " (\n", + " utils.get_act_name(\"z\", abl_layer) + \".hook_sae_error\",\n", + " partial(able_sae_error, pos=abl_pos),\n", + " )\n", + " ],\n", + ") # [batch, seq, vocab]\n", "\n", - "abl_error_logit_diff = logits_to_ave_logit_diff(abl_error_logits, answer_tokens, per_prompt=True) # [batch]\n", + "abl_error_logit_diff = logits_to_ave_logit_diff(\n", + " abl_error_logits, answer_tokens, per_prompt=True\n", + ") # [batch]\n", "error_abl_effect = abl_error_logit_diff - original_per_prompt_logit_diff\n", "\n", "\n", - "causal_effects_with_error = torch.cat([causal_effects, error_abl_effect.unsqueeze(-1).cpu()], dim=-1)\n", - "imshow(causal_effects_with_error, title=f\"Change in logit diff when ablating L{abl_layer} SAE features for all prompts at pos {abl_pos}\", xaxis=\"Feature Idx\", yaxis=\"Prompt Idx\", x=list(map(str, all_live_features.tolist()))+[\"error\"])" + "causal_effects_with_error = torch.cat(\n", + " [causal_effects, error_abl_effect.unsqueeze(-1).cpu()], dim=-1\n", + ")\n", + "imshow(\n", + " causal_effects_with_error,\n", + " title=f\"Change in logit diff when ablating L{abl_layer} SAE features for all prompts at pos {abl_pos}\",\n", + " xaxis=\"Feature Idx\",\n", + " yaxis=\"Prompt Idx\",\n", + " x=list(map(str, all_live_features.tolist())) + [\"error\"],\n", + ")" ] }, { @@ -797,32 +920,50 @@ "outputs": [], "source": [ "from transformer_lens import ActivationCache\n", + "\n", "filter_sae_acts = lambda name: (\"hook_sae_acts_post\" in name)\n", + "\n", + "\n", "def get_cache_fwd_and_bwd(model, tokens, metric):\n", " model.reset_hooks()\n", " cache = {}\n", + "\n", " def 
forward_cache_hook(act, hook):\n", " cache[hook.name] = act.detach()\n", + "\n", " model.add_hook(filter_sae_acts, forward_cache_hook, \"fwd\")\n", "\n", " grad_cache = {}\n", + "\n", " def backward_cache_hook(act, hook):\n", " grad_cache[hook.name] = act.detach()\n", + "\n", " model.add_hook(filter_sae_acts, backward_cache_hook, \"bwd\")\n", "\n", " value = metric(model(tokens))\n", " print(value)\n", " value.backward()\n", " model.reset_hooks()\n", - " return value.item(), ActivationCache(cache, model), ActivationCache(grad_cache, model)\n", + " return (\n", + " value.item(),\n", + " ActivationCache(cache, model),\n", + " ActivationCache(grad_cache, model),\n", + " )\n", "\n", "\n", "BASELINE = original_per_prompt_logit_diff\n", + "\n", + "\n", "def ioi_metric(logits, answer_tokens=answer_tokens):\n", - " return (logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=True) - BASELINE).sum()\n", + " return (\n", + " logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=True) - BASELINE\n", + " ).sum()\n", + "\n", "\n", "clean_tokens = tokens.clone()\n", - "clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)\n", + "clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(\n", + " model, clean_tokens, ioi_metric\n", + ")\n", "print(\"Clean Value:\", clean_value)\n", "print(\"Clean Activations Cached:\", len(clean_cache))\n", "print(\"Clean Gradients Cached:\", len(clean_grad_cache))" @@ -835,15 +976,21 @@ "outputs": [], "source": [ "def attr_patch_sae_acts(\n", - " clean_cache: ActivationCache, \n", - " clean_grad_cache: ActivationCache,\n", - " site: str, layer: int\n", - " ):\n", - " clean_sae_acts_post = clean_cache[utils.get_act_name(site, layer) + \".hook_sae_acts_post\"] \n", - " clean_grad_sae_acts_post = clean_grad_cache[utils.get_act_name(site, layer) + \".hook_sae_acts_post\"] \n", + " clean_cache: ActivationCache,\n", + " clean_grad_cache: ActivationCache,\n", + " site: str,\n", + " layer: 
int,\n", + "):\n", + " clean_sae_acts_post = clean_cache[\n", + " utils.get_act_name(site, layer) + \".hook_sae_acts_post\"\n", + " ]\n", + " clean_grad_sae_acts_post = clean_grad_cache[\n", + " utils.get_act_name(site, layer) + \".hook_sae_acts_post\"\n", + " ]\n", " sae_act_attr = clean_grad_sae_acts_post * (0 - clean_sae_acts_post)\n", " return sae_act_attr\n", "\n", + "\n", "site = \"z\"\n", "layer = 5\n", "sae_act_attr = attr_patch_sae_acts(clean_cache, clean_grad_cache, site, layer)\n", @@ -851,7 +998,10 @@ "imshow(\n", " sae_act_attr[:, s2_pos, all_live_features],\n", " title=\"attribution patching\",\n", - " xaxis=\"Feature Idx\", yaxis=\"Prompt Idx\", x=list(map(str, all_live_features.tolist())))" + " xaxis=\"Feature Idx\",\n", + " yaxis=\"Prompt Idx\",\n", + " x=list(map(str, all_live_features.tolist())),\n", + ")" ] }, { @@ -861,24 +1011,20 @@ "outputs": [], "source": [ "fig = scatter(\n", - " y=sae_act_attr[:, s2_pos, all_live_features].flatten(), \n", + " y=sae_act_attr[:, s2_pos, all_live_features].flatten(),\n", " x=causal_effects.flatten(),\n", " title=\"Attribution vs Activation Patching Per SAE feature (L5 S2 Pos, all prompts)\",\n", " xaxis=\"Activation Patch\",\n", " yaxis=\"Attribution Patch\",\n", - " return_fig=True\n", + " return_fig=True,\n", ")\n", "fig.add_shape(\n", - " type='line',\n", + " type=\"line\",\n", " x0=causal_effects.min(),\n", " y0=causal_effects.min(),\n", " x1=causal_effects.max(),\n", " y1=causal_effects.max(),\n", - " line=dict(\n", - " color='gray',\n", - " width=1,\n", - " dash='dot'\n", - " )\n", + " line=dict(color=\"gray\", width=1, dash=\"dot\"),\n", ")\n", "fig.show()" ] diff --git a/tutorials/basic_loading_and_analysing.ipynb b/tutorials/basic_loading_and_analysing.ipynb index b6f98c2b..422045d1 100644 --- a/tutorials/basic_loading_and_analysing.ipynb +++ b/tutorials/basic_loading_and_analysing.ipynb @@ -1,496 +1,503 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "MNk7IylTv610" - }, - 
"source": [ - "# Loading and Analysing Pre-Trained Sparse Autoencoders" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "i_DusoOvwV0M" - }, - "source": [ - "## Imports & Installs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yfDUxRx0wSRl" - }, - "outputs": [], - "source": [ - "try:\n", - " import google.colab # type: ignore\n", - " from google.colab import output\n", - " COLAB = True\n", - " %pip install sae-lens transformer-lens sae-dashboard\n", - "except:\n", - " COLAB = False\n", - " from IPython import get_ipython # type: ignore\n", - " ipython = get_ipython(); assert ipython is not None\n", - " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", - " ipython.run_line_magic(\"autoreload\", \"2\")\n", - "\n", - "# Standard imports\n", - "import os\n", - "import torch\n", - "from tqdm import tqdm\n", - "import plotly.express as px\n", - "\n", - "# Imports for displaying vis in Colab / notebook\n", - "import webbrowser\n", - "import http.server\n", - "import socketserver\n", - "import threading\n", - "PORT = 8000\n", - "\n", - "torch.set_grad_enabled(False);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7aGgWkbav610" - }, - "source": [ - "## Set Up" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "rQSD7trbv610", - "outputId": "222a40c4-75d4-46e2-ed3f-991841144926" - }, - "outputs": [], - "source": [ - "# For the most part I'll try to import functions and classes near where they are used\n", - "# to make it clear where they come from.\n", - "\n", - "if torch.backends.mps.is_available():\n", - " device = \"mps\"\n", - "else:\n", - " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "\n", - "print(f\"Device: {device}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cPUq_bdW8mcp" - }, - "outputs": [], - "source": [ - "def 
display_vis_inline(filename: str, height: int = 850):\n", - " '''\n", - " Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each\n", - " vis has a unique port without having to define a port within the function.\n", - " '''\n", - " if not(COLAB):\n", - " webbrowser.open(filename);\n", - "\n", - " else:\n", - " global PORT\n", - "\n", - " def serve(directory):\n", - " os.chdir(directory)\n", - "\n", - " # Create a handler for serving files\n", - " handler = http.server.SimpleHTTPRequestHandler\n", - "\n", - " # Create a socket server with the handler\n", - " with socketserver.TCPServer((\"\", PORT), handler) as httpd:\n", - " print(f\"Serving files from {directory} on port {PORT}\")\n", - " httpd.serve_forever()\n", - "\n", - " thread = threading.Thread(target=serve, args=(\"/content\",))\n", - " thread.start()\n", - "\n", - " output.serve_kernel_port_as_iframe(PORT, path=f\"/{filename}\", height=height, cache_in_notebook=True)\n", - "\n", - " PORT += 1" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XoMx3VZpv611" - }, - "source": [ - "# Loading a pretrained Sparse Autoencoder\n", - "\n", - "Below we load a Transformerlens model, a pretrained SAE and a dataset from huggingface." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sNSfL80Uv611" - }, - "outputs": [], - "source": [ - "from datasets import load_dataset \n", - "from transformer_lens import HookedTransformer\n", - "from sae_lens import SAE\n", - "\n", - "model = HookedTransformer.from_pretrained(\"gpt2-small\", device = device)\n", - "\n", - "# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)\n", - "# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict\n", - "# We also return the feature sparsities which are stored in HF for convenience. \n", - "sae, cfg_dict, sparsity = SAE.from_pretrained(\n", - " release = \"gpt2-small-res-jb\", # see other options in sae_lens/pretrained_saes.yaml\n", - " sae_id = \"blocks.8.hook_resid_pre\", # won't always be a hook point\n", - " device = device\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from transformer_lens.utils import tokenize_and_concatenate\n", - "\n", - "dataset = load_dataset(\n", - " path = \"NeelNanda/pile-10k\",\n", - " split=\"train\",\n", - " streaming=False,\n", - ")\n", - "\n", - "token_dataset = tokenize_and_concatenate(\n", - " dataset= dataset,# type: ignore\n", - " tokenizer = model.tokenizer, # type: ignore\n", - " streaming=True,\n", - " max_length=sae.cfg.context_size,\n", - " add_bos_token=sae.cfg.prepend_bos,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gy2uUl38v611" - }, - "source": [ - "## Basic Analysis\n", - "\n", - "Let's check some basic stats on this SAE in order to see how some basic functionality in the codebase works.\n", - "\n", - "We'll calculate:\n", - "- L0 (the number of features that fire per activation)\n", - "- The cross entropy loss when the output of the SAE is used 
in place of the activations" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xOcubgsRv611" - }, - "source": [ - "### L0 Test and Reconstruction Test" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gAUR5CRBv611" - }, - "outputs": [], - "source": [ - "sae.eval() # prevents error if we're expecting a dead neuron mask for who grads\n", - "\n", - "with torch.no_grad():\n", - " # activation store can give us tokens.\n", - " batch_tokens = token_dataset[:32][\"tokens\"]\n", - " _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)\n", - "\n", - " # Use the SAE\n", - " feature_acts = sae.encode(cache[sae.cfg.hook_name])\n", - " sae_out = sae.decode(feature_acts)\n", - "\n", - " # save some room\n", - " del cache\n", - "\n", - " # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position\n", - " l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()\n", - " print(\"average l0\", l0.mean().item())\n", - " px.histogram(l0.flatten().cpu().numpy()).show()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ijoelLtdv611" - }, - "source": [ - "Note that while the mean L0 is 64, it varies with the specific activation.\n", - "\n", - "To estimate reconstruction performance, we calculate the CE loss of the model with and without the SAE being used in place of the activations. This will vary depending on the tokens." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "fwrSvREJv612" - }, - "outputs": [], - "source": [ - "from transformer_lens import utils\n", - "from functools import partial\n", - "\n", - "# next we want to do a reconstruction test.\n", - "def reconstr_hook(activation, hook, sae_out):\n", - " return sae_out\n", - "\n", - "\n", - "def zero_abl_hook(activation, hook):\n", - " return torch.zeros_like(activation)\n", - "\n", - "\n", - "print(\"Orig\", model(batch_tokens, return_type=\"loss\").item())\n", - "print(\n", - " \"reconstr\",\n", - " model.run_with_hooks(\n", - " batch_tokens,\n", - " fwd_hooks=[\n", - " (\n", - " sae.cfg.hook_name,\n", - " partial(reconstr_hook, sae_out=sae_out),\n", - " )\n", - " ],\n", - " return_type=\"loss\",\n", - " ).item(),\n", - ")\n", - "print(\n", - " \"Zero\",\n", - " model.run_with_hooks(\n", - " batch_tokens,\n", - " return_type=\"loss\",\n", - " fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],\n", - " ).item(),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "B_TRq_lFv612" - }, - "source": [ - "## Specific Capability Test\n", - "\n", - "Validating model performance on specific tasks when using the reconstructed activation is quite important when studying specific tasks." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "npxKip_Qv612" - }, - "outputs": [], - "source": [ - "example_prompt = \"When John and Mary went to the shops, John gave the bag to\"\n", - "example_answer = \" Mary\"\n", - "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)\n", - "\n", - "logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)\n", - "tokens = model.to_tokens(example_prompt)\n", - "sae_out = sae(cache[sae.cfg.hook_name])\n", - "\n", - "\n", - "def reconstr_hook(activations, hook, sae_out):\n", - " return sae_out\n", - "\n", - "\n", - "def zero_abl_hook(mlp_out, hook):\n", - " return torch.zeros_like(mlp_out)\n", - "\n", - "\n", - "hook_name = sae.cfg.hook_name\n", - "\n", - "print(\"Orig\", model(tokens, return_type=\"loss\").item())\n", - "print(\n", - " \"reconstr\",\n", - " model.run_with_hooks(\n", - " tokens,\n", - " fwd_hooks=[\n", - " (\n", - " hook_name,\n", - " partial(reconstr_hook, sae_out=sae_out),\n", - " )\n", - " ],\n", - " return_type=\"loss\",\n", - " ).item(),\n", - ")\n", - "print(\n", - " \"Zero\",\n", - " model.run_with_hooks(\n", - " tokens,\n", - " return_type=\"loss\",\n", - " fwd_hooks=[(hook_name, zero_abl_hook)],\n", - " ).item(),\n", - ")\n", - "\n", - "\n", - "with model.hooks(\n", - " fwd_hooks=[\n", - " (\n", - " hook_name,\n", - " partial(reconstr_hook, sae_out=sae_out),\n", - " )\n", - " ]\n", - "):\n", - " utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1swj9KA7v612" - }, - "source": [ - "# Generating Feature Interfaces\n", - "\n", - "Feature dashboards are an important part of SAE Evaluation. They work by:\n", - "- 1. Collecting feature activations over a larger number of examples.\n", - "- 2. Aggregating feature specific statistics (such as max activating examples).\n", - "- 3. 
Representing that information in a standardized way\n", - "\n", - "For our feature visualizations, we will use a separate library called SAEDashboard." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "edt8ag4fv612" - }, - "outputs": [], - "source": [ - "from sae_dashboard.sae_vis_data import SaeVisConfig\n", - "from sae_dashboard.sae_vis_runner import SaeVisRunner\n", - "\n", - "test_feature_idx_gpt = list(range(10)) + [14057]\n", - "\n", - "feature_vis_config_gpt = SaeVisConfig(\n", - " hook_point=hook_name,\n", - " features=test_feature_idx_gpt,\n", - " minibatch_size_features=64,\n", - " minibatch_size_tokens=256,\n", - " verbose=True,\n", - " device=device,\n", - ")\n", - "\n", - "visualization_data_gpt = SaeVisRunner(feature_vis_config_gpt).run(\n", - " encoder=sae, # type: ignore\n", - " model=model,\n", - " tokens=token_dataset[:10000][\"tokens\"], # type: ignore\n", - ")\n", - "# SaeVisData.create(\n", - "# encoder=sae,\n", - "# model=model, # type: ignore\n", - "# tokens=token_dataset[:10000][\"tokens\"], # type: ignore\n", - "# cfg=feature_vis_config_gpt,\n", - "# )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yQ94Frzbv612" - }, - "outputs": [], - "source": [ - "from sae_dashboard.data_writing_fns import save_feature_centric_vis\n", - "\n", - "filename = f\"demo_feature_dashboards.html\"\n", - "save_feature_centric_vis(sae_vis_data=visualization_data_gpt, filename=filename)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AUaD6CFDv612" - }, - "source": [ - "Now, since generating feature dashboards can be done once per sparse autoencoder, for pre-trained SAEs in the public domain, everyone can use the same dashboards. Neuronpedia hosts dashboards which we can load via the integration." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BxluyNRBv612" - }, - "outputs": [], - "source": [ - "from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list\n", - "\n", - "# this function should open\n", - "neuronpedia_quick_list = get_neuronpedia_quick_list(\n", - " test_feature_idx_gpt,\n", - " layer=sae.cfg.hook_layer,\n", - " model=\"gpt2-small\",\n", - " dataset=\"res-jb\",\n", - " name=\"A quick list we made\",\n", - ")\n", - "\n", - "if COLAB:\n", - " # If you're on colab, click the link below\n", - " print(neuronpedia_quick_list)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "accelerator": "GPU", + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "MNk7IylTv610" + }, + "source": [ + "# Loading and Analysing Pre-Trained Sparse Autoencoders" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i_DusoOvwV0M" + }, + "source": [ + "## Imports & Installs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yfDUxRx0wSRl" + }, + "outputs": [], + "source": [ + "try:\n", + " import google.colab # type: ignore\n", + " from google.colab import output\n", + "\n", + " COLAB = True\n", + " %pip install sae-lens transformer-lens sae-dashboard\n", + "except:\n", + " COLAB = False\n", + " from IPython import get_ipython # type: ignore\n", + "\n", + " ipython = get_ipython()\n", + " assert ipython is not None\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")\n", + "\n", + "# Standard imports\n", + "import os\n", + "import torch\n", + "from tqdm import tqdm\n", + "import plotly.express as px\n", + "\n", + "# Imports for displaying vis in Colab / notebook\n", + "import webbrowser\n", + "import http.server\n", + "import socketserver\n", + "import threading\n", + "\n", + "PORT = 8000\n", + "\n", 
+ "torch.set_grad_enabled(False);" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7aGgWkbav610" + }, + "source": [ + "## Set Up" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { "colab": { - "gpuType": "T4", - "provenance": [] + "base_uri": "https://localhost:8080/" }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } + "id": "rQSD7trbv610", + "outputId": "222a40c4-75d4-46e2-ed3f-991841144926" + }, + "outputs": [], + "source": [ + "# For the most part I'll try to import functions and classes near where they are used\n", + "# to make it clear where they come from.\n", + "\n", + "if torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "else:\n", + " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "print(f\"Device: {device}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cPUq_bdW8mcp" + }, + "outputs": [], + "source": [ + "def display_vis_inline(filename: str, height: int = 850):\n", + " \"\"\"\n", + " Displays the HTML files in Colab. 
Uses global `PORT` variable defined in prev cell, so that each\n", + " vis has a unique port without having to define a port within the function.\n", + " \"\"\"\n", + " if not (COLAB):\n", + " webbrowser.open(filename)\n", + "\n", + " else:\n", + " global PORT\n", + "\n", + " def serve(directory):\n", + " os.chdir(directory)\n", + "\n", + " # Create a handler for serving files\n", + " handler = http.server.SimpleHTTPRequestHandler\n", + "\n", + " # Create a socket server with the handler\n", + " with socketserver.TCPServer((\"\", PORT), handler) as httpd:\n", + " print(f\"Serving files from {directory} on port {PORT}\")\n", + " httpd.serve_forever()\n", + "\n", + " thread = threading.Thread(target=serve, args=(\"/content\",))\n", + " thread.start()\n", + "\n", + " output.serve_kernel_port_as_iframe(\n", + " PORT, path=f\"/{filename}\", height=height, cache_in_notebook=True\n", + " )\n", + "\n", + " PORT += 1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XoMx3VZpv611" + }, + "source": [ + "# Loading a pretrained Sparse Autoencoder\n", + "\n", + "Below we load a Transformerlens model, a pretrained SAE and a dataset from huggingface." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sNSfL80Uv611" + }, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "from transformer_lens import HookedTransformer\n", + "from sae_lens import SAE\n", + "\n", + "model = HookedTransformer.from_pretrained(\"gpt2-small\", device=device)\n", + "\n", + "# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)\n", + "# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict\n", + "# We also return the feature sparsities which are stored in HF for convenience.\n", + "sae, cfg_dict, sparsity = SAE.from_pretrained(\n", + " release=\"gpt2-small-res-jb\", # see other options in sae_lens/pretrained_saes.yaml\n", + " sae_id=\"blocks.8.hook_resid_pre\", # won't always be a hook point\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformer_lens.utils import tokenize_and_concatenate\n", + "\n", + "dataset = load_dataset(\n", + " path=\"NeelNanda/pile-10k\",\n", + " split=\"train\",\n", + " streaming=False,\n", + ")\n", + "\n", + "token_dataset = tokenize_and_concatenate(\n", + " dataset=dataset, # type: ignore\n", + " tokenizer=model.tokenizer, # type: ignore\n", + " streaming=True,\n", + " max_length=sae.cfg.context_size,\n", + " add_bos_token=sae.cfg.prepend_bos,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gy2uUl38v611" + }, + "source": [ + "## Basic Analysis\n", + "\n", + "Let's check some basic stats on this SAE in order to see how some basic functionality in the codebase works.\n", + "\n", + "We'll calculate:\n", + "- L0 (the number of features that fire per activation)\n", + "- The cross entropy loss when the output of the SAE is used in place of the 
activations" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xOcubgsRv611" + }, + "source": [ + "### L0 Test and Reconstruction Test" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gAUR5CRBv611" + }, + "outputs": [], + "source": [ + "sae.eval() # prevents error if we're expecting a dead neuron mask for who grads\n", + "\n", + "with torch.no_grad():\n", + " # activation store can give us tokens.\n", + " batch_tokens = token_dataset[:32][\"tokens\"]\n", + " _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)\n", + "\n", + " # Use the SAE\n", + " feature_acts = sae.encode(cache[sae.cfg.hook_name])\n", + " sae_out = sae.decode(feature_acts)\n", + "\n", + " # save some room\n", + " del cache\n", + "\n", + " # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position\n", + " l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()\n", + " print(\"average l0\", l0.mean().item())\n", + " px.histogram(l0.flatten().cpu().numpy()).show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ijoelLtdv611" + }, + "source": [ + "Note that while the mean L0 is 64, it varies with the specific activation.\n", + "\n", + "To estimate reconstruction performance, we calculate the CE loss of the model with and without the SAE being used in place of the activations. This will vary depending on the tokens." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fwrSvREJv612" + }, + "outputs": [], + "source": [ + "from transformer_lens import utils\n", + "from functools import partial\n", + "\n", + "\n", + "# next we want to do a reconstruction test.\n", + "def reconstr_hook(activation, hook, sae_out):\n", + " return sae_out\n", + "\n", + "\n", + "def zero_abl_hook(activation, hook):\n", + " return torch.zeros_like(activation)\n", + "\n", + "\n", + "print(\"Orig\", model(batch_tokens, return_type=\"loss\").item())\n", + "print(\n", + " \"reconstr\",\n", + " model.run_with_hooks(\n", + " batch_tokens,\n", + " fwd_hooks=[\n", + " (\n", + " sae.cfg.hook_name,\n", + " partial(reconstr_hook, sae_out=sae_out),\n", + " )\n", + " ],\n", + " return_type=\"loss\",\n", + " ).item(),\n", + ")\n", + "print(\n", + " \"Zero\",\n", + " model.run_with_hooks(\n", + " batch_tokens,\n", + " return_type=\"loss\",\n", + " fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],\n", + " ).item(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "B_TRq_lFv612" + }, + "source": [ + "## Specific Capability Test\n", + "\n", + "Validating model performance on specific tasks when using the reconstructed activation is quite important when studying specific tasks." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "npxKip_Qv612" + }, + "outputs": [], + "source": [ + "example_prompt = \"When John and Mary went to the shops, John gave the bag to\"\n", + "example_answer = \" Mary\"\n", + "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)\n", + "\n", + "logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)\n", + "tokens = model.to_tokens(example_prompt)\n", + "sae_out = sae(cache[sae.cfg.hook_name])\n", + "\n", + "\n", + "def reconstr_hook(activations, hook, sae_out):\n", + " return sae_out\n", + "\n", + "\n", + "def zero_abl_hook(mlp_out, hook):\n", + " return torch.zeros_like(mlp_out)\n", + "\n", + "\n", + "hook_name = sae.cfg.hook_name\n", + "\n", + "print(\"Orig\", model(tokens, return_type=\"loss\").item())\n", + "print(\n", + " \"reconstr\",\n", + " model.run_with_hooks(\n", + " tokens,\n", + " fwd_hooks=[\n", + " (\n", + " hook_name,\n", + " partial(reconstr_hook, sae_out=sae_out),\n", + " )\n", + " ],\n", + " return_type=\"loss\",\n", + " ).item(),\n", + ")\n", + "print(\n", + " \"Zero\",\n", + " model.run_with_hooks(\n", + " tokens,\n", + " return_type=\"loss\",\n", + " fwd_hooks=[(hook_name, zero_abl_hook)],\n", + " ).item(),\n", + ")\n", + "\n", + "\n", + "with model.hooks(\n", + " fwd_hooks=[\n", + " (\n", + " hook_name,\n", + " partial(reconstr_hook, sae_out=sae_out),\n", + " )\n", + " ]\n", + "):\n", + " utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1swj9KA7v612" + }, + "source": [ + "# Generating Feature Interfaces\n", + "\n", + "Feature dashboards are an important part of SAE Evaluation. They work by:\n", + "- 1. Collecting feature activations over a larger number of examples.\n", + "- 2. Aggregating feature specific statistics (such as max activating examples).\n", + "- 3. 
Representing that information in a standardized way\n", + "\n", + "For our feature visualizations, we will use a separate library called SAEDashboard." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "edt8ag4fv612" + }, + "outputs": [], + "source": [ + "from sae_dashboard.sae_vis_data import SaeVisConfig\n", + "from sae_dashboard.sae_vis_runner import SaeVisRunner\n", + "\n", + "test_feature_idx_gpt = list(range(10)) + [14057]\n", + "\n", + "feature_vis_config_gpt = SaeVisConfig(\n", + " hook_point=hook_name,\n", + " features=test_feature_idx_gpt,\n", + " minibatch_size_features=64,\n", + " minibatch_size_tokens=256,\n", + " verbose=True,\n", + " device=device,\n", + ")\n", + "\n", + "visualization_data_gpt = SaeVisRunner(feature_vis_config_gpt).run(\n", + " encoder=sae, # type: ignore\n", + " model=model,\n", + " tokens=token_dataset[:10000][\"tokens\"], # type: ignore\n", + ")\n", + "# SaeVisData.create(\n", + "# encoder=sae,\n", + "# model=model, # type: ignore\n", + "# tokens=token_dataset[:10000][\"tokens\"], # type: ignore\n", + "# cfg=feature_vis_config_gpt,\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yQ94Frzbv612" + }, + "outputs": [], + "source": [ + "from sae_dashboard.data_writing_fns import save_feature_centric_vis\n", + "\n", + "filename = f\"demo_feature_dashboards.html\"\n", + "save_feature_centric_vis(sae_vis_data=visualization_data_gpt, filename=filename)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AUaD6CFDv612" + }, + "source": [ + "Now, since generating feature dashboards can be done once per sparse autoencoder, for pre-trained SAEs in the public domain, everyone can use the same dashboards. Neuronpedia hosts dashboards which we can load via the integration." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BxluyNRBv612" + }, + "outputs": [], + "source": [ + "from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list\n", + "\n", + "# this function should open\n", + "neuronpedia_quick_list = get_neuronpedia_quick_list(\n", + " test_feature_idx_gpt,\n", + " layer=sae.cfg.hook_layer,\n", + " model=\"gpt2-small\",\n", + " dataset=\"res-jb\",\n", + " name=\"A quick list we made\",\n", + ")\n", + "\n", + "if COLAB:\n", + " # If you're on colab, click the link below\n", + " print(neuronpedia_quick_list)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 0 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/tutorials/loading_tanh_sae.ipynb b/tutorials/loading_tanh_sae.ipynb index 57db9eb3..8365a699 100644 --- a/tutorials/loading_tanh_sae.ipynb +++ b/tutorials/loading_tanh_sae.ipynb @@ -49,28 +49,30 @@ " for k in f.keys():\n", " tensors[k] = f.get_tensor(k)\n", "\n", - "d_in = config_dict['n_input_features']\n", - "d_sae = config_dict['n_learned_features']\n", + "d_in = config_dict[\"n_input_features\"]\n", + "d_sae = config_dict[\"n_learned_features\"]\n", "cfg = LanguageModelSAERunnerConfig(\n", - " d_in=d_in,\n", - " expansion_factor=d_sae//d_in,\n", - " normalize_sae_decoder=False,\n", - " noise_scale=config_dict['noise_scale'],\n", - " model_name=\"gpt2\",\n", - " activation_fn=\"tanh-relu\",\n", - " 
hook_name=\"blocks.{layer}.hook_mlp_out\",\n", - " hook_layer=list(range(config_dict['n_components'])), # type: ignore\n", - " dtype=\"torch.float32\",\n", - " device=device,\n", - " )\n", + " d_in=d_in,\n", + " expansion_factor=d_sae // d_in,\n", + " normalize_sae_decoder=False,\n", + " noise_scale=config_dict[\"noise_scale\"],\n", + " model_name=\"gpt2\",\n", + " activation_fn=\"tanh-relu\",\n", + " hook_name=\"blocks.{layer}.hook_mlp_out\",\n", + " hook_layer=list(range(config_dict[\"n_components\"])), # type: ignore\n", + " dtype=\"torch.float32\",\n", + " device=device,\n", + ")\n", "\n", "single_sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())\n", "with torch.no_grad():\n", " layer = single_sae.cfg.hook_layer\n", - " single_sae.W_enc.data = tensors['encoder.weight'].data[layer].T.clone()\n", - " single_sae.b_enc.data = tensors['encoder.bias'].data[layer].clone()\n", - " single_sae.W_dec.data = tensors['decoder.weight'].data[layer].T.clone()\n", - " single_sae.b_dec.data = tensors['post_decoder_bias._bias_reference'].data[layer].clone()" + " single_sae.W_enc.data = tensors[\"encoder.weight\"].data[layer].T.clone()\n", + " single_sae.b_enc.data = tensors[\"encoder.bias\"].data[layer].clone()\n", + " single_sae.W_dec.data = tensors[\"decoder.weight\"].data[layer].T.clone()\n", + " single_sae.b_dec.data = (\n", + " tensors[\"post_decoder_bias._bias_reference\"].data[layer].clone()\n", + " )" ] }, { @@ -120,7 +122,9 @@ "outputs": [], "source": [ "dataset_path = \"alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2\"\n", - "torch_dataset = load_dataset(dataset_path, split=\"train\", streaming=True).with_format(\"torch\")" + "torch_dataset = load_dataset(dataset_path, split=\"train\", streaming=True).with_format(\n", + " \"torch\"\n", + ")" ] }, { @@ -137,20 +141,22 @@ " \"\"\"Tokenized prompts.\"\"\"\n", "\n", " input_ids: list[TokenizedPrompt]\n", - " \n", + "\n", + "\n", "class TorchTokenizedPrompts(TypedDict):\n", " \"\"\"Tokenized prompts prepared for 
PyTorch.\"\"\"\n", "\n", " input_ids: torch.Tensor\n", "\n", + "\n", "dl = DataLoader[TorchTokenizedPrompts](\n", - " torch_dataset,\n", - " batch_size=16,\n", - " # Shuffle is most efficiently done with the `shuffle` method on the dataset itself, not\n", - " # here.\n", - " shuffle=False,\n", - " num_workers=1,\n", - " )" + " torch_dataset,\n", + " batch_size=16,\n", + " # Shuffle is most efficiently done with the `shuffle` method on the dataset itself, not\n", + " # here.\n", + " shuffle=False,\n", + " num_workers=1,\n", + ")" ] }, { @@ -161,32 +167,34 @@ "source": [ "saes_by_layer = {}\n", "hooked_layers = []\n", - "for layer in list(range(config_dict['n_components'])):\n", + "for layer in list(range(config_dict[\"n_components\"])):\n", " cfg = LanguageModelSAERunnerConfig(\n", - " d_in=d_in,\n", - " expansion_factor=d_sae//d_in,\n", - " normalize_sae_decoder=False,\n", - " noise_scale=config_dict['noise_scale'],\n", - " model_name=\"gpt2\",\n", - " activation_fn=\"tanh-relu\",\n", - " hook_name=f\"blocks.{layer}.hook_mlp_out\",\n", - " hook_layer=layer, # type: ignore\n", - " dtype=\"torch.float32\",\n", - " device=device,\n", - " verbose=False,\n", - " )\n", + " d_in=d_in,\n", + " expansion_factor=d_sae // d_in,\n", + " normalize_sae_decoder=False,\n", + " noise_scale=config_dict[\"noise_scale\"],\n", + " model_name=\"gpt2\",\n", + " activation_fn=\"tanh-relu\",\n", + " hook_name=f\"blocks.{layer}.hook_mlp_out\",\n", + " hook_layer=layer, # type: ignore\n", + " dtype=\"torch.float32\",\n", + " device=device,\n", + " verbose=False,\n", + " )\n", "\n", " single_sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())\n", " with torch.no_grad():\n", " layer = single_sae.cfg.hook_layer\n", - " single_sae.W_enc.data = tensors['encoder.weight'].data[layer].T.clone()\n", - " single_sae.b_enc.data = tensors['encoder.bias'].data[layer].clone()\n", - " single_sae.W_dec.data = tensors['decoder.weight'].data[layer].T.clone()\n", - " single_sae.b_dec.data = 
tensors['post_decoder_bias._bias_reference'].data[layer].clone()\n", + " single_sae.W_enc.data = tensors[\"encoder.weight\"].data[layer].T.clone()\n", + " single_sae.b_enc.data = tensors[\"encoder.bias\"].data[layer].clone()\n", + " single_sae.W_dec.data = tensors[\"decoder.weight\"].data[layer].T.clone()\n", + " single_sae.b_dec.data = (\n", + " tensors[\"post_decoder_bias._bias_reference\"].data[layer].clone()\n", + " )\n", "\n", " saes_by_layer[layer] = single_sae\n", " hooked_layers.append(single_sae.cfg.hook_name)\n", - " \n", + "\n", "hooked_layers" ] }, @@ -201,7 +209,9 @@ " if i >= 1:\n", " break\n", " batch_tokens = batch[\"input_ids\"]\n", - " _, cache = model.run_with_cache(batch_tokens, prepend_bos=True, names_filter=hooked_layers)\n", + " _, cache = model.run_with_cache(\n", + " batch_tokens, prepend_bos=True, names_filter=hooked_layers\n", + " )\n", " residuals = [cache[layer] for layer in hooked_layers]\n", " del cache" ] @@ -215,11 +225,18 @@ "sae_hooks = [\"hook_sae_acts_post\", \"hook_sae_output\"]\n", "for i in range(len(residuals)):\n", " autoencoder = saes_by_layer[i]\n", - " _, cache = autoencoder.run_with_cache(residuals[i].to(autoencoder.device), names_filter=sae_hooks)\n", + " _, cache = autoencoder.run_with_cache(\n", + " residuals[i].to(autoencoder.device), names_filter=sae_hooks\n", + " )\n", " reconstructed = cache[\"hook_sae_output\"]\n", " feature_act = cache[\"hook_sae_acts_post\"]\n", - " l2_loss = torch.nn.functional.mse_loss(residuals[i].to(autoencoder.device), reconstructed)\n", - " l1_loss = torch.nn.functional.l1_loss(feature_act, torch.zeros_like(feature_act)) * autoencoder.cfg.d_sae\n", + " l2_loss = torch.nn.functional.mse_loss(\n", + " residuals[i].to(autoencoder.device), reconstructed\n", + " )\n", + " l1_loss = (\n", + " torch.nn.functional.l1_loss(feature_act, torch.zeros_like(feature_act))\n", + " * autoencoder.cfg.d_sae\n", + " )\n", " print(f\"Layer {i}: L2 loss: {l2_loss}, L1 loss: {l1_loss}\")\n", " del cache" ] 
diff --git a/tutorials/logits_lens_with_features.ipynb b/tutorials/logits_lens_with_features.ipynb index a034a7ee..3ad77a5d 100644 --- a/tutorials/logits_lens_with_features.ipynb +++ b/tutorials/logits_lens_with_features.ipynb @@ -1,891 +1,894 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "W6UyzFGfBSdA" - }, - "source": [ - "# Understanding SAE Features with the Logit Lens\n", - "\n", - "This notebook demonstrates how to use the mats_sae_training library to perform the analysis documented the post \"[Understanding SAE Features with the Logit Lens](https://www.alignmentforum.org/posts/qykrYY6rXXM7EEs8Q/understanding-sae-features-with-the-logit-lens)\".\n", - "\n", - "As such, the notebook will include sections for:\n", - "- Loading in GPT2-Small Residual Stream SAEs from Huggingface.\n", - "- Performing Virtual Weight Based Analysis of features (specifically looking at the logit weight distributions).\n", - "- Programmatically opening neuronpedia tabs to engage with public dashboards on [neuronpedia](https://www.neuronpedia.org/).\n", - "- Performing Token Set Enrichment Analysis (based on Gene Set Enrichment Analysis)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ff-PciWXBSdB" - }, - "source": [ - "## Set Up\n", - "\n", - "Here we'll load various functions for things like:\n", - "- downloading and loading our SAEs from huggingface.\n", - "- opening neuronpedia from a jupyter cell.\n", - "- calculating statistics of the logit weight distributions.\n", - "- performing Token Set Enrichment Analysis (TSEA) and plotting the results." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vImmTg-8BSdC" - }, - "source": [ - "### Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sOd2C0e1BfN1" - }, - "outputs": [], - "source": [ - "try:\n", - " # For Google Colab, a high RAM instance is needed\n", - " import google.colab # type: ignore\n", - " from google.colab import output\n", - " %pip install sae-lens transformer-lens\n", - "except:\n", - " from IPython import get_ipython # type: ignore\n", - " ipython = get_ipython(); assert ipython is not None\n", - " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", - " ipython.run_line_magic(\"autoreload\", \"2\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "QmdAd_25BSdC" - }, - "outputs": [], - "source": [ - "import numpy as np\n", - "import torch\n", - "import plotly_express as px\n", - "\n", - "from transformer_lens import HookedTransformer\n", - "\n", - "# Model Loading\n", - "from sae_lens import SAE\n", - "from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list\n", - "\n", - "# Virtual Weight / Feature Statistics Functions\n", - "from sae_lens.analysis.feature_statistics import (\n", - " get_all_stats_dfs,\n", - " get_W_U_W_dec_stats_df,\n", - ")\n", - "\n", - "# Enrichment Analysis Functions\n", - "from sae_lens.analysis.tsea import (\n", - " get_enrichment_df,\n", - " manhattan_plot_enrichment_scores,\n", - " plot_top_k_feature_projections_by_token_and_category,\n", - ")\n", - "from sae_lens.analysis.tsea import (\n", - " get_baby_name_sets,\n", - " get_letter_gene_sets,\n", - " generate_pos_sets,\n", - " get_test_gene_sets,\n", - " get_gene_set_from_regex,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0VEAe5FjBSdD" - }, - "source": [ - "### Loading GPT2 Small and SAE Weights\n", - "\n", - "This will take a while the first time you run it, but will be quick thereafter." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "HQ904zDOBSdD" - }, - "outputs": [], - "source": [ - "model = HookedTransformer.from_pretrained(\"gpt2-small\")\n", - "# this is an outdated way to load the SAE. We need to have feature spartisity loadable through the new interface to remove it.\n", - "gpt2_small_sparse_autoencoders = {}\n", - "gpt2_small_sae_sparsities = {}\n", - "\n", - "for layer in range(12):\n", - " sae, original_cfg_dict, sparsity = SAE.from_pretrained(\n", - " release=\"gpt2-small-res-jb\",\n", - " sae_id=\"blocks.0.hook_resid_pre\",\n", - " device=\"cpu\",\n", - " )\n", - " gpt2_small_sparse_autoencoders[f\"blocks.{layer}.hook_resid_pre\"] = sae\n", - " gpt2_small_sae_sparsities[f\"blocks.{layer}.hook_resid_pre\"] = sparsity\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4pFYJKeNBSdD" - }, - "source": [ - "# Statistical Properties of Feature Logit Distributions\n", - "\n", - "In the post I study layer 8 (for no particular reason). At the end of this notebook is code for visualizing these statistics across all layers. Feel free to change the layer here and explore different layers." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "LEK1jEpEBSdD" - }, - "outputs": [], - "source": [ - "# In the post, I focus on layer 8\n", - "layer = 8\n", - "\n", - "# get the corresponding SAE and feature sparsities.\n", - "sparse_autoencoder = gpt2_small_sparse_autoencoders[f\"blocks.{layer}.hook_resid_pre\"]\n", - "log_feature_sparsity = gpt2_small_sae_sparsities[f\"blocks.{layer}.hook_resid_pre\"].cpu()\n", - "\n", - "W_dec = sparse_autoencoder.W_dec.detach().cpu()\n", - "\n", - "# calculate the statistics of the logit weight distributions\n", - "W_U_stats_df_dec, dec_projection_onto_W_U = get_W_U_W_dec_stats_df(\n", - " W_dec, model, cosine_sim=False\n", - ")\n", - "W_U_stats_df_dec[\"sparsity\"] = (\n", - " log_feature_sparsity # add feature sparsity since it is often interesting.\n", - ")\n", - "display(W_U_stats_df_dec)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GK56PDl3BSdE" - }, - "outputs": [], - "source": [ - "# Let's look at the distribution of the 3rd / 4th moments. 
I found these aren't as useful on their own as joint distributions can be.\n", - "px.histogram(\n", - " W_U_stats_df_dec,\n", - " x=\"skewness\",\n", - " width=800,\n", - " height=300,\n", - " nbins=1000,\n", - " title=\"Skewness of the Logit Weight Distributions\",\n", - ").show()\n", - "\n", - "px.histogram(\n", - " W_U_stats_df_dec,\n", - " x=np.log10(W_U_stats_df_dec[\"kurtosis\"]),\n", - " width=800,\n", - " height=300,\n", - " nbins=1000,\n", - " title=\"Kurtosis of the Logit Weight Distributions\",\n", - ").show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4_6c12ftBSdE" - }, - "outputs": [], - "source": [ - "fig = px.scatter(\n", - " W_U_stats_df_dec,\n", - " x=\"skewness\",\n", - " y=\"kurtosis\",\n", - " color=\"std\",\n", - " color_continuous_scale=\"Portland\",\n", - " hover_name=\"feature\",\n", - " width=800,\n", - " height=500,\n", - " log_y=True, # Kurtosis has larger outliers so logging creates a nicer scale.\n", - " labels={\"x\": \"Skewness\", \"y\": \"Kurtosis\", \"color\": \"Standard Deviation\"},\n", - " title=f\"Layer {8}: Skewness vs Kurtosis of the Logit Weight Distributions\",\n", - ")\n", - "\n", - "# decrease point size\n", - "fig.update_traces(marker=dict(size=3))\n", - "\n", - "\n", - "fig.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KuWSPqqjBSdE" - }, - "outputs": [], - "source": [ - "# then you can query accross combinations of the statistics to find features of interest and open them in neuronpedia.\n", - "tmp_df = W_U_stats_df_dec[[\"feature\", \"skewness\", \"kurtosis\", \"std\"]]\n", - "# tmp_df = tmp_df[(tmp_df[\"std\"] > 0.04)]\n", - "# tmp_df = tmp_df[(tmp_df[\"skewness\"] > 0.65)]\n", - "tmp_df = tmp_df[(tmp_df[\"skewness\"] > 3)]\n", - "tmp_df = tmp_df.sort_values(\"skewness\", ascending=False).head(10)\n", - "display(tmp_df)\n", - "\n", - "# if desired, open the features in neuronpedia\n", - 
"get_neuronpedia_quick_list(sparse_autoencoder, list(tmp_df.feature))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AHGG86wkBSdE" - }, - "source": [ - "# Token Set Enrichment Analysis\n", - "\n", - "We now proceed to token set enrichment analysis. I highly recommend reading my AlignmentForum post (espeically the case studies) before reading too much into any of these results.\n", - "Also read this [post](https://transformer-circuits.pub/2024/qualitative-essay/index.html) for good general perspectives on statistics here." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "J4tuqT9SBSdE" - }, - "source": [ - "## Defining Our Token Sets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "7mUxqd6mBSdE" - }, - "outputs": [], - "source": [ - "import nltk\n", - "\n", - "nltk.download(\"averaged_perceptron_tagger\")\n", - "# get the vocab we need to filter to formulate token sets.\n", - "vocab = model.tokenizer.get_vocab() # type: ignore\n", - "\n", - "# make a regex dictionary to specify more sets.\n", - "regex_dict = {\n", - " \"starts_with_space\": r\"Ġ.*\",\n", - " \"starts_with_capital\": r\"^Ġ*[A-Z].*\",\n", - " \"starts_with_lower\": r\"^Ġ*[a-z].*\",\n", - " \"all_digits\": r\"^Ġ*\\d+$\",\n", - " \"is_punctuation\": r\"^[^\\w\\s]+$\",\n", - " \"contains_close_bracket\": r\".*\\).*\",\n", - " \"contains_open_bracket\": r\".*\\(.*\",\n", - " \"all_caps\": r\"Ġ*[A-Z]+$\",\n", - " \"1 digit\": r\"Ġ*\\d{1}$\",\n", - " \"2 digits\": r\"Ġ*\\d{2}$\",\n", - " \"3 digits\": r\"Ġ*\\d{3}$\",\n", - " \"4 digits\": r\"Ġ*\\d{4}$\",\n", - " \"length_1\": r\"^Ġ*\\w{1}$\",\n", - " \"length_2\": r\"^Ġ*\\w{2}$\",\n", - " \"length_3\": r\"^Ġ*\\w{3}$\",\n", - " \"length_4\": r\"^Ġ*\\w{4}$\",\n", - " \"length_5\": r\"^Ġ*\\w{5}$\",\n", - "}\n", - "\n", - "# print size of gene sets\n", - "all_token_sets = get_letter_gene_sets(vocab)\n", - "for key, value in regex_dict.items():\n", - " gene_set = 
get_gene_set_from_regex(vocab, value)\n", - " all_token_sets[key] = gene_set\n", - "\n", - "# some other sets that can be interesting\n", - "baby_name_sets = get_baby_name_sets(vocab)\n", - "pos_sets = generate_pos_sets(vocab)\n", - "arbitrary_sets = get_test_gene_sets(model)\n", - "\n", - "all_token_sets = {**all_token_sets, **pos_sets}\n", - "all_token_sets = {**all_token_sets, **arbitrary_sets}\n", - "all_token_sets = {**all_token_sets, **baby_name_sets}\n", - "\n", - "# for each gene set, convert to string and print the first 5 tokens\n", - "for token_set_name, gene_set in sorted(\n", - " all_token_sets.items(), key=lambda x: len(x[1]), reverse=True\n", - "):\n", - " tokens = [model.to_string(id) for id in list(gene_set)][:10] # type: ignore\n", - " print(f\"{token_set_name}, has {len(gene_set)} genes\")\n", - " print(tokens)\n", - " print(\"----\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vxctX05KBSdE" - }, - "source": [ - "## Performing Token Set Enrichment Analysis\n", - "\n", - "Below we perform token set enrichment analysis on various token sets. In practice, we'd likely perform tests accross all tokens and large libraries of sets simultaneously but to make it easier to run, we look at features with higher skew and select of a few token sets at a time to consider." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "eHwM7qVlBSdF" - }, - "outputs": [], - "source": [ - "features_ordered_by_skew = (\n", - " W_U_stats_df_dec[\"skewness\"].sort_values(ascending=False).head(5000).index.to_list()\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "QxuiBx4NBSdF" - }, - "outputs": [], - "source": [ - "# filter our list.\n", - "token_sets_index = [\n", - " \"starts_with_space\",\n", - " \"starts_with_capital\",\n", - " \"all_digits\",\n", - " \"is_punctuation\",\n", - " \"all_caps\",\n", - "]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "\n", - "# calculate the enrichment scores\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, # use the logit weight values as our rankings over tokens.\n", - " features_ordered_by_skew, # subset by these features\n", - " token_set_selected, # use token_sets\n", - ")\n", - "\n", - "manhattan_plot_enrichment_scores(\n", - " df_enrichment_scores, label_threshold=0, top_n=3 # use our enrichment scores\n", - ").show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yQ0n0aKTBSdF" - }, - "outputs": [], - "source": [ - "fig = px.scatter(\n", - " df_enrichment_scores.apply(lambda x: -1 * np.log(1 - x)).T,\n", - " x=\"starts_with_space\",\n", - " y=\"starts_with_capital\",\n", - " marginal_x=\"histogram\",\n", - " marginal_y=\"histogram\",\n", - " labels={\n", - " \"starts_with_space\": \"Starts with Space\",\n", - " \"starts_with_capital\": \"Starts with Capital\",\n", - " },\n", - " title=\"Enrichment Scores for Starts with Space vs Starts with Capital\",\n", - " height=800,\n", - " width=800,\n", - ")\n", - "# reduce point size on the scatter only\n", - "fig.update_traces(marker=dict(size=2), selector=dict(mode=\"markers\"))\n", - "fig.show()" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "metadata": { - "id": "RcmU_6I9BSdF" - }, - "outputs": [], - "source": [ - "token_sets_index = [\"1 digit\", \"2 digits\", \"3 digits\", \"4 digits\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BboYni5ZBSdF" - }, - "outputs": [], - "source": [ - "token_sets_index = [\"nltk_pos_PRP\", \"nltk_pos_VBZ\", \"nltk_pos_NNP\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "xKrcnE7GBSdF" - }, - "outputs": [], - "source": [ - "token_sets_index = [\"nltk_pos_VBN\", \"nltk_pos_VBG\", \"nltk_pos_VB\", \"nltk_pos_VBD\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sKVdyyxoBSdF" - }, - "outputs": [], - "source": [ - "token_sets_index = [\"nltk_pos_WP\", \"nltk_pos_RBR\", \"nltk_pos_WDT\", \"nltk_pos_RB\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " 
dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "YYC3GFscBSdF" - }, - "outputs": [], - "source": [ - "token_sets_index = [\"a\", \"e\", \"i\", \"o\", \"u\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JvmgQ_YmBSdF" - }, - "outputs": [], - "source": [ - "token_sets_index = [\"negative_words\", \"positive_words\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ltYmy5lMBSdF" - }, - "outputs": [], - "source": [ - "fig = px.scatter(\n", - " df_enrichment_scores.apply(lambda x: -1 * np.log(1 - x))\n", - " .T.reset_index()\n", - " .rename(columns={\"index\": \"feature\"}),\n", - " x=\"negative_words\",\n", - " y=\"positive_words\",\n", - " marginal_x=\"histogram\",\n", - " marginal_y=\"histogram\",\n", - " labels={\n", - " \"starts_with_space\": \"Starts with Space\",\n", - " \"starts_with_capital\": \"Starts with Capital\",\n", - " },\n", - " title=\"Enrichment Scores for Starts with Space vs Starts with Capital\",\n", - " height=800,\n", - " width=800,\n", - " hover_name=\"feature\",\n", - ")\n", - "# reduce point size on the scatter only\n", - 
"fig.update_traces(marker=dict(size=2), selector=dict(mode=\"markers\"))\n", - "fig.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "5otoyu2SBSdF" - }, - "outputs": [], - "source": [ - "token_sets_index = [\"contains_close_bracket\", \"contains_open_bracket\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JlN-ScWhBSdF" - }, - "outputs": [], - "source": [ - "token_sets_index = [\n", - " \"1910's\",\n", - " \"1920's\",\n", - " \"1930's\",\n", - " \"1940's\",\n", - " \"1950's\",\n", - " \"1960's\",\n", - " \"1970's\",\n", - " \"1980's\",\n", - " \"1990's\",\n", - " \"2000's\",\n", - " \"2010's\",\n", - "]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "SDmCs9OhBSdF" - }, - "outputs": [], - "source": [ - "token_sets_index = [\"positive_words\", \"negative_words\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores, label_threshold=0.98).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ilzC33VlBSdF" - }, - 
"outputs": [], - "source": [ - "token_sets_index = [\"boys_names\", \"girls_names\"]\n", - "token_set_selected = {\n", - " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", - "}\n", - "df_enrichment_scores = get_enrichment_df(\n", - " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", - ")\n", - "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "D0mTCic7BSdG" - }, - "outputs": [], - "source": [ - "tmp_df = df_enrichment_scores.apply(lambda x: -1 * np.log(1 - x)).T\n", - "color = (\n", - " W_U_stats_df_dec.sort_values(\"skewness\", ascending=False)\n", - " .head(5000)[\"skewness\"]\n", - " .values\n", - ")\n", - "fig = px.scatter(\n", - " tmp_df.reset_index().rename(columns={\"index\": \"feature\"}),\n", - " x=\"boys_names\",\n", - " y=\"girls_names\",\n", - " marginal_x=\"histogram\",\n", - " marginal_y=\"histogram\",\n", - " # color = color,\n", - " labels={\n", - " \"boys_names\": \"Enrichment Score (Boys Names)\",\n", - " \"girls_names\": \"Enrichment Score (Girls Names)\",\n", - " },\n", - " height=600,\n", - " width=800,\n", - " hover_name=\"feature\",\n", - ")\n", - "# reduce point size on the scatter only\n", - "fig.update_traces(marker=dict(size=3), selector=dict(mode=\"markers\"))\n", - "# annotate any features where the absolute distance between boys names and girls names > 3\n", - "for feature in df_enrichment_scores.columns:\n", - " if abs(tmp_df[\"boys_names\"][feature] - tmp_df[\"girls_names\"][feature]) > 2.9:\n", - " fig.add_annotation(\n", - " x=tmp_df[\"boys_names\"][feature] - 0.4,\n", - " y=tmp_df[\"girls_names\"][feature] + 0.1,\n", - " text=f\"{feature}\",\n", - " showarrow=False,\n", - " )\n", - "\n", - "\n", - "fig.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VwVcsxnkBSdG" - }, - "source": [ - "## Digging into Particular Features\n", - "\n", - "When we do these 
enrichments, I generate the logit weight histograms by category using the following function. It's important to make sure the categories you group by are in the columns of df_enrichment_scores." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GgndTFdFBSdG" - }, - "outputs": [], - "source": [ - "for category in [\"boys_names\"]:\n", - " plot_top_k_feature_projections_by_token_and_category(\n", - " token_set_selected,\n", - " df_enrichment_scores,\n", - " category=category,\n", - " dec_projection_onto_W_U=dec_projection_onto_W_U,\n", - " model=model,\n", - " log_y=False,\n", - " histnorm=None,\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AKP3u0D2BSdG" - }, - "source": [ - "# Appendix Results: Logit Weight distribution Statistics Accross All Layers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Vhht2iCQBSdG" - }, - "outputs": [], - "source": [ - "W_U_stats_df_dec_all_layers = get_all_stats_dfs(\n", - " gpt2_small_sparse_autoencoders, gpt2_small_sae_sparsities, model, cosine_sim=True\n", - ")\n", - "\n", - "display(W_U_stats_df_dec_all_layers.shape)\n", - "display(W_U_stats_df_dec_all_layers.head())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ok3DgORLBSdG" - }, - "outputs": [], - "source": [ - "# Let's plot the percentiles of the skewness and kurtosis by layer\n", - "tmp_df = W_U_stats_df_dec_all_layers.groupby(\"layer\")[\"skewness\"].describe(\n", - " percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]\n", - ")\n", - "tmp_df = tmp_df[[\"1%\", \"5%\", \"10%\", \"25%\", \"50%\", \"75%\", \"90%\", \"95%\", \"99%\"]]\n", - "\n", - "fig = px.area(\n", - " tmp_df,\n", - " title=\"Skewness by Layer\",\n", - " width=800,\n", - " height=600,\n", - " color_discrete_sequence=px.colors.sequential.Turbo,\n", - ").show()\n", - "\n", - "\n", - "tmp_df = 
W_U_stats_df_dec_all_layers.groupby(\"layer\")[\"kurtosis\"].describe(\n", - " percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]\n", - ")\n", - "tmp_df = tmp_df[[\"1%\", \"5%\", \"10%\", \"25%\", \"50%\", \"75%\", \"90%\", \"95%\", \"99%\"]]\n", - "\n", - "fig = px.area(\n", - " tmp_df,\n", - " title=\"Kurtosis by Layer\",\n", - " width=800,\n", - " height=600,\n", - " color_discrete_sequence=px.colors.sequential.Turbo,\n", - ")\n", - "\n", - "fig.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "UOkRZiunBSdG" - }, - "outputs": [], - "source": [ - "# let's make a pretty color scheme\n", - "from plotly.colors import n_colors\n", - "\n", - "colors = n_colors(\"rgb(5, 200, 200)\", \"rgb(200, 10, 10)\", 13, colortype=\"rgb\")\n", - "\n", - "# Make a box plot of the skewness by layer\n", - "fig = px.box(\n", - " W_U_stats_df_dec_all_layers,\n", - " x=\"layer\",\n", - " y=\"skewness\",\n", - " color=\"layer\",\n", - " color_discrete_sequence=colors,\n", - " height=600,\n", - " width=1200,\n", - " title=\"Skewness cos(W_U,W_dec) by Layer in GPT2 Small Residual Stream SAEs\",\n", - " labels={\"layer\": \"Layer\", \"skewnss\": \"Skewness\"},\n", - ")\n", - "fig.update_xaxes(showticklabels=True, dtick=1)\n", - "\n", - "# increase font size\n", - "fig.update_layout(font=dict(size=16))\n", - "fig.show()\n", - "\n", - "# Make a box plot of the skewness by layer\n", - "fig = px.box(\n", - " W_U_stats_df_dec_all_layers,\n", - " x=\"layer\",\n", - " y=\"kurtosis\",\n", - " color=\"layer\",\n", - " color_discrete_sequence=colors,\n", - " height=600,\n", - " width=1200,\n", - " log_y=True,\n", - " title=\"log kurtosis cos(W_U,W_dec) by Layer in GPT2 Small Residual Stream SAEs\",\n", - " labels={\"layer\": \"Layer\", \"kurtosis\": \"Log Kurtosis\"},\n", - ")\n", - "fig.update_xaxes(showticklabels=True, dtick=1)\n", - "\n", - "# increase font size\n", - "fig.update_layout(font=dict(size=16))\n", - "fig.show()" - ] - }, - { - 
"cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hYNYdY3wBSdG" - }, - "outputs": [], - "source": [ - "# scatter\n", - "fig = px.scatter(\n", - " W_U_stats_df_dec_all_layers[W_U_stats_df_dec_all_layers.log_feature_sparsity >= -9],\n", - " # W_U_stats_df_dec_all_layers[W_U_stats_df_dec_all_layers.layer == 8],\n", - " x=\"skewness\",\n", - " y=\"kurtosis\",\n", - " color=\"std\",\n", - " color_continuous_scale=\"Portland\",\n", - " hover_name=\"feature\",\n", - " # color_continuous_midpoint = 0,\n", - " # range_color = [-4,-1],\n", - " log_y=True,\n", - " height=800,\n", - " # width = 2000,\n", - " # facet_col=\"layer\",\n", - " # facet_col_wrap=5,\n", - " animation_frame=\"layer\",\n", - ")\n", - "fig.update_yaxes(matches=None)\n", - "fig.for_each_yaxis(lambda yaxis: yaxis.update(showticklabels=True))\n", - "\n", - "# decrease point size\n", - "fig.update_traces(marker=dict(size=5))\n", - "fig.show()\n", - "fig.write_html(\"skewness_kurtosis_scatter_all_layers.html\")" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "machine_shape": "hm", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.7" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "W6UyzFGfBSdA" + }, + "source": [ + "# Understanding SAE Features with the Logit Lens\n", + "\n", + "This notebook demonstrates how to use the mats_sae_training library to perform the analysis documented the post \"[Understanding SAE Features with the Logit Lens](https://www.alignmentforum.org/posts/qykrYY6rXXM7EEs8Q/understanding-sae-features-with-the-logit-lens)\".\n", + "\n", + "As such, the notebook will include 
sections for:\n", + "- Loading in GPT2-Small Residual Stream SAEs from Huggingface.\n", + "- Performing Virtual Weight Based Analysis of features (specifically looking at the logit weight distributions).\n", + "- Programmatically opening neuronpedia tabs to engage with public dashboards on [neuronpedia](https://www.neuronpedia.org/).\n", + "- Performing Token Set Enrichment Analysis (based on Gene Set Enrichment Analysis)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ff-PciWXBSdB" + }, + "source": [ + "## Set Up\n", + "\n", + "Here we'll load various functions for things like:\n", + "- downloading and loading our SAEs from huggingface.\n", + "- opening neuronpedia from a jupyter cell.\n", + "- calculating statistics of the logit weight distributions.\n", + "- performing Token Set Enrichment Analysis (TSEA) and plotting the results." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vImmTg-8BSdC" + }, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sOd2C0e1BfN1" + }, + "outputs": [], + "source": [ + "try:\n", + " # For Google Colab, a high RAM instance is needed\n", + " import google.colab # type: ignore\n", + " from google.colab import output\n", + "\n", + " %pip install sae-lens transformer-lens\n", + "except:\n", + " from IPython import get_ipython # type: ignore\n", + "\n", + " ipython = get_ipython()\n", + " assert ipython is not None\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QmdAd_25BSdC" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "import plotly_express as px\n", + "\n", + "from transformer_lens import HookedTransformer\n", + "\n", + "# Model Loading\n", + "from sae_lens import SAE\n", + "from sae_lens.analysis.neuronpedia_integration import 
get_neuronpedia_quick_list\n", + "\n", + "# Virtual Weight / Feature Statistics Functions\n", + "from sae_lens.analysis.feature_statistics import (\n", + " get_all_stats_dfs,\n", + " get_W_U_W_dec_stats_df,\n", + ")\n", + "\n", + "# Enrichment Analysis Functions\n", + "from sae_lens.analysis.tsea import (\n", + " get_enrichment_df,\n", + " manhattan_plot_enrichment_scores,\n", + " plot_top_k_feature_projections_by_token_and_category,\n", + ")\n", + "from sae_lens.analysis.tsea import (\n", + " get_baby_name_sets,\n", + " get_letter_gene_sets,\n", + " generate_pos_sets,\n", + " get_test_gene_sets,\n", + " get_gene_set_from_regex,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0VEAe5FjBSdD" + }, + "source": [ + "### Loading GPT2 Small and SAE Weights\n", + "\n", + "This will take a while the first time you run it, but will be quick thereafter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HQ904zDOBSdD" + }, + "outputs": [], + "source": [ + "model = HookedTransformer.from_pretrained(\"gpt2-small\")\n", + "# this is an outdated way to load the SAE. We need to have feature spartisity loadable through the new interface to remove it.\n", + "gpt2_small_sparse_autoencoders = {}\n", + "gpt2_small_sae_sparsities = {}\n", + "\n", + "for layer in range(12):\n", + " sae, original_cfg_dict, sparsity = SAE.from_pretrained(\n", + " release=\"gpt2-small-res-jb\",\n", + " sae_id=\"blocks.0.hook_resid_pre\",\n", + " device=\"cpu\",\n", + " )\n", + " gpt2_small_sparse_autoencoders[f\"blocks.{layer}.hook_resid_pre\"] = sae\n", + " gpt2_small_sae_sparsities[f\"blocks.{layer}.hook_resid_pre\"] = sparsity" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4pFYJKeNBSdD" + }, + "source": [ + "# Statistical Properties of Feature Logit Distributions\n", + "\n", + "In the post I study layer 8 (for no particular reason). At the end of this notebook is code for visualizing these statistics across all layers. 
Feel free to change the layer here and explore different layers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LEK1jEpEBSdD" + }, + "outputs": [], + "source": [ + "# In the post, I focus on layer 8\n", + "layer = 8\n", + "\n", + "# get the corresponding SAE and feature sparsities.\n", + "sparse_autoencoder = gpt2_small_sparse_autoencoders[f\"blocks.{layer}.hook_resid_pre\"]\n", + "log_feature_sparsity = gpt2_small_sae_sparsities[f\"blocks.{layer}.hook_resid_pre\"].cpu()\n", + "\n", + "W_dec = sparse_autoencoder.W_dec.detach().cpu()\n", + "\n", + "# calculate the statistics of the logit weight distributions\n", + "W_U_stats_df_dec, dec_projection_onto_W_U = get_W_U_W_dec_stats_df(\n", + " W_dec, model, cosine_sim=False\n", + ")\n", + "W_U_stats_df_dec[\"sparsity\"] = (\n", + " log_feature_sparsity # add feature sparsity since it is often interesting.\n", + ")\n", + "display(W_U_stats_df_dec)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GK56PDl3BSdE" + }, + "outputs": [], + "source": [ + "# Let's look at the distribution of the 3rd / 4th moments. 
I found these aren't as useful on their own as joint distributions can be.\n", + "px.histogram(\n", + " W_U_stats_df_dec,\n", + " x=\"skewness\",\n", + " width=800,\n", + " height=300,\n", + " nbins=1000,\n", + " title=\"Skewness of the Logit Weight Distributions\",\n", + ").show()\n", + "\n", + "px.histogram(\n", + " W_U_stats_df_dec,\n", + " x=np.log10(W_U_stats_df_dec[\"kurtosis\"]),\n", + " width=800,\n", + " height=300,\n", + " nbins=1000,\n", + " title=\"Kurtosis of the Logit Weight Distributions\",\n", + ").show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4_6c12ftBSdE" + }, + "outputs": [], + "source": [ + "fig = px.scatter(\n", + " W_U_stats_df_dec,\n", + " x=\"skewness\",\n", + " y=\"kurtosis\",\n", + " color=\"std\",\n", + " color_continuous_scale=\"Portland\",\n", + " hover_name=\"feature\",\n", + " width=800,\n", + " height=500,\n", + " log_y=True, # Kurtosis has larger outliers so logging creates a nicer scale.\n", + " labels={\"x\": \"Skewness\", \"y\": \"Kurtosis\", \"color\": \"Standard Deviation\"},\n", + " title=f\"Layer {8}: Skewness vs Kurtosis of the Logit Weight Distributions\",\n", + ")\n", + "\n", + "# decrease point size\n", + "fig.update_traces(marker=dict(size=3))\n", + "\n", + "\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KuWSPqqjBSdE" + }, + "outputs": [], + "source": [ + "# then you can query accross combinations of the statistics to find features of interest and open them in neuronpedia.\n", + "tmp_df = W_U_stats_df_dec[[\"feature\", \"skewness\", \"kurtosis\", \"std\"]]\n", + "# tmp_df = tmp_df[(tmp_df[\"std\"] > 0.04)]\n", + "# tmp_df = tmp_df[(tmp_df[\"skewness\"] > 0.65)]\n", + "tmp_df = tmp_df[(tmp_df[\"skewness\"] > 3)]\n", + "tmp_df = tmp_df.sort_values(\"skewness\", ascending=False).head(10)\n", + "display(tmp_df)\n", + "\n", + "# if desired, open the features in neuronpedia\n", + 
"get_neuronpedia_quick_list(sparse_autoencoder, list(tmp_df.feature))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AHGG86wkBSdE" + }, + "source": [ + "# Token Set Enrichment Analysis\n", + "\n", + "We now proceed to token set enrichment analysis. I highly recommend reading my AlignmentForum post (espeically the case studies) before reading too much into any of these results.\n", + "Also read this [post](https://transformer-circuits.pub/2024/qualitative-essay/index.html) for good general perspectives on statistics here." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "J4tuqT9SBSdE" + }, + "source": [ + "## Defining Our Token Sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7mUxqd6mBSdE" + }, + "outputs": [], + "source": [ + "import nltk\n", + "\n", + "nltk.download(\"averaged_perceptron_tagger\")\n", + "# get the vocab we need to filter to formulate token sets.\n", + "vocab = model.tokenizer.get_vocab() # type: ignore\n", + "\n", + "# make a regex dictionary to specify more sets.\n", + "regex_dict = {\n", + " \"starts_with_space\": r\"Ġ.*\",\n", + " \"starts_with_capital\": r\"^Ġ*[A-Z].*\",\n", + " \"starts_with_lower\": r\"^Ġ*[a-z].*\",\n", + " \"all_digits\": r\"^Ġ*\\d+$\",\n", + " \"is_punctuation\": r\"^[^\\w\\s]+$\",\n", + " \"contains_close_bracket\": r\".*\\).*\",\n", + " \"contains_open_bracket\": r\".*\\(.*\",\n", + " \"all_caps\": r\"Ġ*[A-Z]+$\",\n", + " \"1 digit\": r\"Ġ*\\d{1}$\",\n", + " \"2 digits\": r\"Ġ*\\d{2}$\",\n", + " \"3 digits\": r\"Ġ*\\d{3}$\",\n", + " \"4 digits\": r\"Ġ*\\d{4}$\",\n", + " \"length_1\": r\"^Ġ*\\w{1}$\",\n", + " \"length_2\": r\"^Ġ*\\w{2}$\",\n", + " \"length_3\": r\"^Ġ*\\w{3}$\",\n", + " \"length_4\": r\"^Ġ*\\w{4}$\",\n", + " \"length_5\": r\"^Ġ*\\w{5}$\",\n", + "}\n", + "\n", + "# print size of gene sets\n", + "all_token_sets = get_letter_gene_sets(vocab)\n", + "for key, value in regex_dict.items():\n", + " gene_set = 
get_gene_set_from_regex(vocab, value)\n", + " all_token_sets[key] = gene_set\n", + "\n", + "# some other sets that can be interesting\n", + "baby_name_sets = get_baby_name_sets(vocab)\n", + "pos_sets = generate_pos_sets(vocab)\n", + "arbitrary_sets = get_test_gene_sets(model)\n", + "\n", + "all_token_sets = {**all_token_sets, **pos_sets}\n", + "all_token_sets = {**all_token_sets, **arbitrary_sets}\n", + "all_token_sets = {**all_token_sets, **baby_name_sets}\n", + "\n", + "# for each gene set, convert to string and print the first 5 tokens\n", + "for token_set_name, gene_set in sorted(\n", + " all_token_sets.items(), key=lambda x: len(x[1]), reverse=True\n", + "):\n", + " tokens = [model.to_string(id) for id in list(gene_set)][:10] # type: ignore\n", + " print(f\"{token_set_name}, has {len(gene_set)} genes\")\n", + " print(tokens)\n", + " print(\"----\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vxctX05KBSdE" + }, + "source": [ + "## Performing Token Set Enrichment Analysis\n", + "\n", + "Below we perform token set enrichment analysis on various token sets. In practice, we'd likely perform tests accross all tokens and large libraries of sets simultaneously but to make it easier to run, we look at features with higher skew and select of a few token sets at a time to consider." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eHwM7qVlBSdF" + }, + "outputs": [], + "source": [ + "features_ordered_by_skew = (\n", + " W_U_stats_df_dec[\"skewness\"].sort_values(ascending=False).head(5000).index.to_list()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QxuiBx4NBSdF" + }, + "outputs": [], + "source": [ + "# filter our list.\n", + "token_sets_index = [\n", + " \"starts_with_space\",\n", + " \"starts_with_capital\",\n", + " \"all_digits\",\n", + " \"is_punctuation\",\n", + " \"all_caps\",\n", + "]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "\n", + "# calculate the enrichment scores\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, # use the logit weight values as our rankings over tokens.\n", + " features_ordered_by_skew, # subset by these features\n", + " token_set_selected, # use token_sets\n", + ")\n", + "\n", + "manhattan_plot_enrichment_scores(\n", + " df_enrichment_scores, label_threshold=0, top_n=3 # use our enrichment scores\n", + ").show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yQ0n0aKTBSdF" + }, + "outputs": [], + "source": [ + "fig = px.scatter(\n", + " df_enrichment_scores.apply(lambda x: -1 * np.log(1 - x)).T,\n", + " x=\"starts_with_space\",\n", + " y=\"starts_with_capital\",\n", + " marginal_x=\"histogram\",\n", + " marginal_y=\"histogram\",\n", + " labels={\n", + " \"starts_with_space\": \"Starts with Space\",\n", + " \"starts_with_capital\": \"Starts with Capital\",\n", + " },\n", + " title=\"Enrichment Scores for Starts with Space vs Starts with Capital\",\n", + " height=800,\n", + " width=800,\n", + ")\n", + "# reduce point size on the scatter only\n", + "fig.update_traces(marker=dict(size=2), selector=dict(mode=\"markers\"))\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": { + "id": "RcmU_6I9BSdF" + }, + "outputs": [], + "source": [ + "token_sets_index = [\"1 digit\", \"2 digits\", \"3 digits\", \"4 digits\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BboYni5ZBSdF" + }, + "outputs": [], + "source": [ + "token_sets_index = [\"nltk_pos_PRP\", \"nltk_pos_VBZ\", \"nltk_pos_NNP\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xKrcnE7GBSdF" + }, + "outputs": [], + "source": [ + "token_sets_index = [\"nltk_pos_VBN\", \"nltk_pos_VBG\", \"nltk_pos_VB\", \"nltk_pos_VBD\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sKVdyyxoBSdF" + }, + "outputs": [], + "source": [ + "token_sets_index = [\"nltk_pos_WP\", \"nltk_pos_RBR\", \"nltk_pos_WDT\", \"nltk_pos_RB\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " 
dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YYC3GFscBSdF" + }, + "outputs": [], + "source": [ + "token_sets_index = [\"a\", \"e\", \"i\", \"o\", \"u\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JvmgQ_YmBSdF" + }, + "outputs": [], + "source": [ + "token_sets_index = [\"negative_words\", \"positive_words\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ltYmy5lMBSdF" + }, + "outputs": [], + "source": [ + "fig = px.scatter(\n", + " df_enrichment_scores.apply(lambda x: -1 * np.log(1 - x))\n", + " .T.reset_index()\n", + " .rename(columns={\"index\": \"feature\"}),\n", + " x=\"negative_words\",\n", + " y=\"positive_words\",\n", + " marginal_x=\"histogram\",\n", + " marginal_y=\"histogram\",\n", + " labels={\n", + " \"starts_with_space\": \"Starts with Space\",\n", + " \"starts_with_capital\": \"Starts with Capital\",\n", + " },\n", + " title=\"Enrichment Scores for Starts with Space vs Starts with Capital\",\n", + " height=800,\n", + " width=800,\n", + " hover_name=\"feature\",\n", + ")\n", + "# reduce point size on the scatter only\n", + 
"fig.update_traces(marker=dict(size=2), selector=dict(mode=\"markers\"))\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5otoyu2SBSdF" + }, + "outputs": [], + "source": [ + "token_sets_index = [\"contains_close_bracket\", \"contains_open_bracket\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JlN-ScWhBSdF" + }, + "outputs": [], + "source": [ + "token_sets_index = [\n", + " \"1910's\",\n", + " \"1920's\",\n", + " \"1930's\",\n", + " \"1940's\",\n", + " \"1950's\",\n", + " \"1960's\",\n", + " \"1970's\",\n", + " \"1980's\",\n", + " \"1990's\",\n", + " \"2000's\",\n", + " \"2010's\",\n", + "]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "SDmCs9OhBSdF" + }, + "outputs": [], + "source": [ + "token_sets_index = [\"positive_words\", \"negative_words\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores, label_threshold=0.98).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ilzC33VlBSdF" + }, + 
"outputs": [], + "source": [ + "token_sets_index = [\"boys_names\", \"girls_names\"]\n", + "token_set_selected = {\n", + " k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n", + "}\n", + "df_enrichment_scores = get_enrichment_df(\n", + " dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n", + ")\n", + "manhattan_plot_enrichment_scores(df_enrichment_scores).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "D0mTCic7BSdG" + }, + "outputs": [], + "source": [ + "tmp_df = df_enrichment_scores.apply(lambda x: -1 * np.log(1 - x)).T\n", + "color = (\n", + " W_U_stats_df_dec.sort_values(\"skewness\", ascending=False)\n", + " .head(5000)[\"skewness\"]\n", + " .values\n", + ")\n", + "fig = px.scatter(\n", + " tmp_df.reset_index().rename(columns={\"index\": \"feature\"}),\n", + " x=\"boys_names\",\n", + " y=\"girls_names\",\n", + " marginal_x=\"histogram\",\n", + " marginal_y=\"histogram\",\n", + " # color = color,\n", + " labels={\n", + " \"boys_names\": \"Enrichment Score (Boys Names)\",\n", + " \"girls_names\": \"Enrichment Score (Girls Names)\",\n", + " },\n", + " height=600,\n", + " width=800,\n", + " hover_name=\"feature\",\n", + ")\n", + "# reduce point size on the scatter only\n", + "fig.update_traces(marker=dict(size=3), selector=dict(mode=\"markers\"))\n", + "# annotate any features where the absolute distance between boys names and girls names > 3\n", + "for feature in df_enrichment_scores.columns:\n", + " if abs(tmp_df[\"boys_names\"][feature] - tmp_df[\"girls_names\"][feature]) > 2.9:\n", + " fig.add_annotation(\n", + " x=tmp_df[\"boys_names\"][feature] - 0.4,\n", + " y=tmp_df[\"girls_names\"][feature] + 0.1,\n", + " text=f\"{feature}\",\n", + " showarrow=False,\n", + " )\n", + "\n", + "\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VwVcsxnkBSdG" + }, + "source": [ + "## Digging into Particular Features\n", + "\n", + "When we do these 
enrichments, I generate the logit weight histograms by category using the following function. It's important to make sure the categories you group by are in the columns of df_enrichment_scores." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GgndTFdFBSdG" + }, + "outputs": [], + "source": [ + "for category in [\"boys_names\"]:\n", + " plot_top_k_feature_projections_by_token_and_category(\n", + " token_set_selected,\n", + " df_enrichment_scores,\n", + " category=category,\n", + " dec_projection_onto_W_U=dec_projection_onto_W_U,\n", + " model=model,\n", + " log_y=False,\n", + " histnorm=None,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AKP3u0D2BSdG" + }, + "source": [ + "# Appendix Results: Logit Weight distribution Statistics Accross All Layers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Vhht2iCQBSdG" + }, + "outputs": [], + "source": [ + "W_U_stats_df_dec_all_layers = get_all_stats_dfs(\n", + " gpt2_small_sparse_autoencoders, gpt2_small_sae_sparsities, model, cosine_sim=True\n", + ")\n", + "\n", + "display(W_U_stats_df_dec_all_layers.shape)\n", + "display(W_U_stats_df_dec_all_layers.head())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ok3DgORLBSdG" + }, + "outputs": [], + "source": [ + "# Let's plot the percentiles of the skewness and kurtosis by layer\n", + "tmp_df = W_U_stats_df_dec_all_layers.groupby(\"layer\")[\"skewness\"].describe(\n", + " percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]\n", + ")\n", + "tmp_df = tmp_df[[\"1%\", \"5%\", \"10%\", \"25%\", \"50%\", \"75%\", \"90%\", \"95%\", \"99%\"]]\n", + "\n", + "fig = px.area(\n", + " tmp_df,\n", + " title=\"Skewness by Layer\",\n", + " width=800,\n", + " height=600,\n", + " color_discrete_sequence=px.colors.sequential.Turbo,\n", + ").show()\n", + "\n", + "\n", + "tmp_df = 
W_U_stats_df_dec_all_layers.groupby(\"layer\")[\"kurtosis\"].describe(\n", + " percentiles=[0.01, 0.05, 0.10, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]\n", + ")\n", + "tmp_df = tmp_df[[\"1%\", \"5%\", \"10%\", \"25%\", \"50%\", \"75%\", \"90%\", \"95%\", \"99%\"]]\n", + "\n", + "fig = px.area(\n", + " tmp_df,\n", + " title=\"Kurtosis by Layer\",\n", + " width=800,\n", + " height=600,\n", + " color_discrete_sequence=px.colors.sequential.Turbo,\n", + ")\n", + "\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UOkRZiunBSdG" + }, + "outputs": [], + "source": [ + "# let's make a pretty color scheme\n", + "from plotly.colors import n_colors\n", + "\n", + "colors = n_colors(\"rgb(5, 200, 200)\", \"rgb(200, 10, 10)\", 13, colortype=\"rgb\")\n", + "\n", + "# Make a box plot of the skewness by layer\n", + "fig = px.box(\n", + " W_U_stats_df_dec_all_layers,\n", + " x=\"layer\",\n", + " y=\"skewness\",\n", + " color=\"layer\",\n", + " color_discrete_sequence=colors,\n", + " height=600,\n", + " width=1200,\n", + " title=\"Skewness cos(W_U,W_dec) by Layer in GPT2 Small Residual Stream SAEs\",\n", + " labels={\"layer\": \"Layer\", \"skewnss\": \"Skewness\"},\n", + ")\n", + "fig.update_xaxes(showticklabels=True, dtick=1)\n", + "\n", + "# increase font size\n", + "fig.update_layout(font=dict(size=16))\n", + "fig.show()\n", + "\n", + "# Make a box plot of the skewness by layer\n", + "fig = px.box(\n", + " W_U_stats_df_dec_all_layers,\n", + " x=\"layer\",\n", + " y=\"kurtosis\",\n", + " color=\"layer\",\n", + " color_discrete_sequence=colors,\n", + " height=600,\n", + " width=1200,\n", + " log_y=True,\n", + " title=\"log kurtosis cos(W_U,W_dec) by Layer in GPT2 Small Residual Stream SAEs\",\n", + " labels={\"layer\": \"Layer\", \"kurtosis\": \"Log Kurtosis\"},\n", + ")\n", + "fig.update_xaxes(showticklabels=True, dtick=1)\n", + "\n", + "# increase font size\n", + "fig.update_layout(font=dict(size=16))\n", + "fig.show()" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hYNYdY3wBSdG" + }, + "outputs": [], + "source": [ + "# scatter\n", + "fig = px.scatter(\n", + " W_U_stats_df_dec_all_layers[W_U_stats_df_dec_all_layers.log_feature_sparsity >= -9],\n", + " # W_U_stats_df_dec_all_layers[W_U_stats_df_dec_all_layers.layer == 8],\n", + " x=\"skewness\",\n", + " y=\"kurtosis\",\n", + " color=\"std\",\n", + " color_continuous_scale=\"Portland\",\n", + " hover_name=\"feature\",\n", + " # color_continuous_midpoint = 0,\n", + " # range_color = [-4,-1],\n", + " log_y=True,\n", + " height=800,\n", + " # width = 2000,\n", + " # facet_col=\"layer\",\n", + " # facet_col_wrap=5,\n", + " animation_frame=\"layer\",\n", + ")\n", + "fig.update_yaxes(matches=None)\n", + "fig.for_each_yaxis(lambda yaxis: yaxis.update(showticklabels=True))\n", + "\n", + "# decrease point size\n", + "fig.update_traces(marker=dict(size=5))\n", + "fig.show()\n", + "fig.write_html(\"skewness_kurtosis_scatter_all_layers.html\")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/tutorials/pretokenizing_datasets.ipynb b/tutorials/pretokenizing_datasets.ipynb index 60fa649f..5483a873 100644 --- a/tutorials/pretokenizing_datasets.ipynb +++ b/tutorials/pretokenizing_datasets.ipynb @@ -52,19 +52,16 @@ "\n", "cfg = PretokenizeRunnerConfig(\n", " tokenizer_name=\"gpt2\",\n", - " dataset_path=\"NeelNanda/c4-10k\", # this is just a tiny test dataset\n", + " dataset_path=\"NeelNanda/c4-10k\", # this is just a tiny test dataset\n", " shuffle=True,\n", - 
" num_proc=4, # increase this number depending on how many CPUs you have\n", - "\n", + " num_proc=4, # increase this number depending on how many CPUs you have\n", " # tweak these settings depending on the model\n", " context_size=128,\n", " begin_batch_token=\"bos\",\n", " begin_sequence_token=None,\n", " sequence_separator_token=\"eos\",\n", - "\n", " # uncomment to upload to huggingface\n", " # hf_repo_id=\"your-username/c4-10k-tokenized-gpt2\"\n", - "\n", " # uncomment to save the dataset locally\n", " # save_path=\"./c4-10k-tokenized-gpt2\"\n", ")\n", @@ -91,7 +88,7 @@ "\n", "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n", "\n", - "tokenized_row = dataset['input_ids'][5]\n", + "tokenized_row = dataset[\"input_ids\"][5]\n", "\n", "print(f\"Row has {len(tokenized_row)} tokens\")\n", "print(f\"Decoded: {tokenizer.decode(tokenized_row)}\")" diff --git a/tutorials/training_a_gated_sae.ipynb b/tutorials/training_a_gated_sae.ipynb index ff319f9f..878e4a98 100644 --- a/tutorials/training_a_gated_sae.ipynb +++ b/tutorials/training_a_gated_sae.ipynb @@ -1,700 +1,702 @@ { - "cells": [ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "5O8tQblzOVHu" + }, + "source": [ + "# A Very Basic Gated SAE Training Run" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "shAFb9-lOVHu" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "LeRi_tw2dhae" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: sae-lens in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (3.3.0)\n", + "Requirement already satisfied: transformer-lens in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (1.19.0)\n", + "Requirement already satisfied: circuitsvis in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (1.43.2)\n", + "Requirement already satisfied: 
automated-interpretability<0.0.4,>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.0.3)\n", + "Requirement already satisfied: babe<0.0.8,>=0.0.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.0.7)\n", + "Requirement already satisfied: datasets<3.0.0,>=2.17.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (2.19.2)\n", + "Requirement already satisfied: matplotlib<4.0.0,>=3.8.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (3.9.0)\n", + "Requirement already satisfied: matplotlib-inline<0.2.0,>=0.1.6 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.1.7)\n", + "Requirement already satisfied: nltk<4.0.0,>=3.8.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (3.8.1)\n", + "Requirement already satisfied: plotly<6.0.0,>=5.19.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (5.22.0)\n", + "Requirement already satisfied: plotly-express<0.5.0,>=0.4.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.4.1)\n", + "Requirement already satisfied: pytest-profiling<2.0.0,>=1.7.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (1.7.0)\n", + "Requirement already satisfied: python-dotenv<2.0.0,>=1.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (1.0.1)\n", + "Requirement already satisfied: pyyaml<7.0.0,>=6.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (6.0.1)\n", + "Requirement already satisfied: pyzmq==26.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (26.0.0)\n", + "Requirement already satisfied: sae-vis<0.3.0,>=0.2.18 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.2.18)\n", + "Requirement already satisfied: safetensors<0.5.0,>=0.4.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.4.3)\n", + "Requirement already satisfied: transformers<5.0.0,>=4.38.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (4.41.2)\n", + "Requirement already satisfied: typer<0.13.0,>=0.12.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.12.3)\n", + "Requirement already satisfied: accelerate>=0.23.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.31.0)\n", + "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.14.1)\n", + "Requirement already satisfied: better-abc<0.0.4,>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.0.3)\n", + "Requirement already satisfied: einops>=0.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.7.0)\n", + "Requirement already satisfied: fancy-einsum>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.0.3)\n", + "Requirement already satisfied: jaxtyping>=0.2.11 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.2.29)\n", + "Requirement already satisfied: numpy>=1.24 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (1.26.4)\n", + "Requirement already satisfied: pandas>=1.1.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (2.2.2)\n", + "Requirement already satisfied: rich>=12.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages 
(from transformer-lens) (13.7.1)\n", + "Requirement already satisfied: sentencepiece in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.2.0)\n", + "Requirement already satisfied: torch>=1.10 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (2.1.2)\n", + "Requirement already satisfied: tqdm>=4.64.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (4.66.4)\n", + "Requirement already satisfied: typing-extensions in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (4.12.2)\n", + "Requirement already satisfied: wandb>=0.13.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.17.1)\n", + "Requirement already satisfied: importlib-metadata>=5.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (7.1.0)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from 
circuitsvis) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (2.18.1)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: triton==2.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (2.1.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->circuitsvis) (12.5.40)\n", + "Requirement already satisfied: filelock in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from triton==2.1.0->circuitsvis) (3.14.0)\n", + "Requirement already satisfied: packaging>=20.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (24.0)\n", + "Requirement already satisfied: psutil in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (5.9.0)\n", + "Requirement already satisfied: huggingface-hub in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (0.23.3)\n", + "Requirement already satisfied: blobfile<3.0.0,>=2.1.1 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.1.1)\n", + "Requirement already satisfied: boostedblob<0.16.0,>=0.15.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.15.3)\n", + "Requirement already satisfied: httpx<0.28.0,>=0.27.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.27.0)\n", + "Requirement already satisfied: orjson<4.0.0,>=3.10.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.10.4)\n", + "Requirement already satisfied: pytest<9.0.0,>=8.1.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (8.2.2)\n", + "Requirement already satisfied: scikit-learn<2.0.0,>=1.4.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.5.0)\n", + "Requirement already satisfied: tiktoken<0.7.0,>=0.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.6.0)\n", + "Requirement already satisfied: py2store in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from babe<0.0.8,>=0.0.7->sae-lens) (0.1.20)\n", + "Requirement already satisfied: graze in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from babe<0.0.8,>=0.0.7->sae-lens) (0.1.17)\n", + "Requirement already satisfied: pyarrow>=12.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (16.1.0)\n", + "Requirement already satisfied: pyarrow-hotfix in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from 
datasets<3.0.0,>=2.17.1->sae-lens) (0.6)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.3.8)\n", + "Requirement already satisfied: requests>=2.32.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (2.32.3)\n", + "Requirement already satisfied: xxhash in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets<3.0.0,>=2.17.1->sae-lens) (2024.3.1)\n", + "Requirement already satisfied: aiohttp in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.9.5)\n", + "Requirement already satisfied: zipp>=0.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from importlib-metadata>=5.1.0->circuitsvis) (3.19.2)\n", + "Requirement already satisfied: typeguard==2.13.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from jaxtyping>=0.2.11->transformer-lens) (2.13.3)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.2.1)\n", + "Requirement already satisfied: cycler>=0.10 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from 
matplotlib<4.0.0,>=3.8.3->sae-lens) (4.53.0)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.4.5)\n", + "Requirement already satisfied: pillow>=8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (10.3.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (3.1.2)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (2.9.0)\n", + "Requirement already satisfied: traitlets in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib-inline<0.2.0,>=0.1.6->sae-lens) (5.14.3)\n", + "Requirement already satisfied: click in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (8.1.7)\n", + "Requirement already satisfied: joblib in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (1.4.2)\n", + "Requirement already satisfied: regex>=2021.8.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (2024.5.15)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pandas>=1.1.5->transformer-lens) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pandas>=1.1.5->transformer-lens) (2024.1)\n", + "Requirement already satisfied: tenacity>=6.2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly<6.0.0,>=5.19.0->sae-lens) (8.3.0)\n", + "Requirement already satisfied: 
statsmodels>=0.9.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (0.14.2)\n", + "Requirement already satisfied: scipy>=0.18 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (1.13.1)\n", + "Requirement already satisfied: patsy>=0.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (0.5.6)\n", + "Requirement already satisfied: six in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (1.16.0)\n", + "Requirement already satisfied: gprof2dot in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (2024.6.6)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from rich>=12.6.0->transformer-lens) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from rich>=12.6.0->transformer-lens) (2.18.0)\n", + "Requirement already satisfied: dataclasses-json<0.7.0,>=0.6.4 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-vis<0.3.0,>=0.2.18->sae-lens) (0.6.7)\n", + "Requirement already satisfied: eindex-callum<0.2.0,>=0.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-vis<0.3.0,>=0.2.18->sae-lens) (0.1.1)\n", + "Requirement already satisfied: sympy in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (1.12.1)\n", + "Requirement already satisfied: networkx in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (3.3)\n", + "Requirement already satisfied: jinja2 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (3.1.4)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformers<5.0.0,>=4.38.1->sae-lens) (0.19.1)\n", + "Requirement already satisfied: shellingham>=1.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from typer<0.13.0,>=0.12.3->sae-lens) (1.5.4)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (0.4.0)\n", + "Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (3.1.43)\n", + "Requirement already satisfied: platformdirs in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (4.2.2)\n", + "Requirement already satisfied: protobuf!=4.21.0,<6,>=3.19.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (5.27.1)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (2.5.1)\n", + "Requirement already satisfied: setproctitle in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (1.3.3)\n", + "Requirement already satisfied: setuptools in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (70.0.0)\n", + "Requirement already satisfied: pycryptodomex~=3.8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.20.0)\n", + "Requirement already satisfied: urllib3<3,>=1.25.3 
in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.2.1)\n", + "Requirement already satisfied: lxml~=4.9 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (4.9.4)\n", + "Requirement already satisfied: uvloop>=0.16.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from boostedblob<0.16.0,>=0.15.3->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.19.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (4.0.3)\n", + "Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (3.21.3)\n", + "Requirement already satisfied: 
typing-inspect<1,>=0.4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (0.9.0)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens) (4.0.11)\n", + "Requirement already satisfied: anyio in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (4.4.0)\n", + "Requirement already satisfied: certifi in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2024.6.2)\n", + "Requirement already satisfied: httpcore==1.* in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.0.5)\n", + "Requirement already satisfied: idna in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.7)\n", + "Requirement already satisfied: sniffio in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.3.1)\n", + "Requirement already satisfied: h11<0.15,>=0.13 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpcore==1.*->httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.14.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer-lens) (0.1.2)\n", + "Requirement already satisfied: iniconfig in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from 
pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.0.0)\n", + "Requirement already satisfied: pluggy<2.0,>=1.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.5.0)\n", + "Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.2.0)\n", + "Requirement already satisfied: tomli>=1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.0.1)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from requests>=2.32.1->datasets<3.0.0,>=2.17.1->sae-lens) (3.3.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from scikit-learn<2.0.0,>=1.4.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.5.0)\n", + "Requirement already satisfied: dol in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from graze->babe<0.0.8,>=0.0.7->sae-lens) (0.2.47)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from jinja2->torch>=1.10->transformer-lens) (2.1.5)\n", + "Requirement already satisfied: config2py in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from py2store->babe<0.0.8,>=0.0.7->sae-lens) (0.1.33)\n", + "Requirement already satisfied: importlib-resources in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from py2store->babe<0.0.8,>=0.0.7->sae-lens) (6.4.0)\n", + "Requirement already satisfied: mpmath<1.4.0,>=1.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages 
(from sympy->torch>=1.10->transformer-lens) (1.3.0)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens) (5.0.1)\n", + "Requirement already satisfied: mypy-extensions>=0.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (1.0.0)\n", + "Requirement already satisfied: i2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from config2py->py2store->babe<0.0.8,>=0.0.7->sae-lens) (0.1.17)\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "try:\n", + " # import google.colab # type: ignore\n", + " # from google.colab import output\n", + " %pip install sae-lens transformer-lens circuitsvis\n", + "except:\n", + " from IPython import get_ipython # type: ignore\n", + "\n", + " ipython = get_ipython()\n", + " assert ipython is not None\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uy-b3CcSOVHu", + "outputId": "58ce28d0-f91f-436d-cf87-76bb26e2ecaf" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cuda\n" + ] + } + ], + "source": [ + "import torch\n", + "import os\n", + "\n", + "from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner\n", + "\n", + "if torch.cuda.is_available():\n", + " device = \"cuda\"\n", + "elif torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "else:\n", + " device = \"cpu\"\n", + "\n", + "print(\"Using device:\", device)\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jCHtPycOOVHw" + }, + "source": [ + "## Training on MLP Out" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "oAsZCAdJOVHw" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Run name: 16384-L1-5-LR-5e-05-Tokens-1.229e+08\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.001024\n", + "Total training steps: 30000\n", + "Total wandb updates: 1000\n", + "n_tokens_per_feature_sampling_window (millions): 1048.576\n", + "n_tokens_per_dead_feature_window (millions): 1048.576\n", + "We will reset the sparsity calculation 30 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. 
It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", + " return self.fget.__get__(instance, owner)()\n" + ] + }, { - "cell_type": "markdown", - "metadata": { - "id": "5O8tQblzOVHu" - }, - "source": [ - "# A Very Basic Gated SAE Training Run" + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mcurt-tigges\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.17.1" + ], + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "shAFb9-lOVHu" - }, - "source": [ - "## Setup" + "data": { + "text/html": [ + "Run data is saved locally in /home/curttigges/projects/SAELens/tutorials/wandb/run-20240611_143204-n7cy5v24" + ], + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "LeRi_tw2dhae" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: sae-lens in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (3.3.0)\n", - "Requirement already satisfied: transformer-lens in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (1.19.0)\n", - "Requirement already satisfied: circuitsvis in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (1.43.2)\n", - "Requirement already satisfied: automated-interpretability<0.0.4,>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.0.3)\n", - "Requirement already satisfied: babe<0.0.8,>=0.0.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.0.7)\n", - "Requirement already satisfied: datasets<3.0.0,>=2.17.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (2.19.2)\n", - "Requirement already satisfied: matplotlib<4.0.0,>=3.8.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (3.9.0)\n", - "Requirement already satisfied: matplotlib-inline<0.2.0,>=0.1.6 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.1.7)\n", - "Requirement already satisfied: nltk<4.0.0,>=3.8.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (3.8.1)\n", - "Requirement already satisfied: plotly<6.0.0,>=5.19.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (5.22.0)\n", - "Requirement already satisfied: plotly-express<0.5.0,>=0.4.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.4.1)\n", - "Requirement already satisfied: pytest-profiling<2.0.0,>=1.7.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (1.7.0)\n", - "Requirement already satisfied: python-dotenv<2.0.0,>=1.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (1.0.1)\n", - "Requirement already satisfied: pyyaml<7.0.0,>=6.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (6.0.1)\n", - "Requirement already satisfied: pyzmq==26.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) 
(26.0.0)\n", - "Requirement already satisfied: sae-vis<0.3.0,>=0.2.18 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.2.18)\n", - "Requirement already satisfied: safetensors<0.5.0,>=0.4.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.4.3)\n", - "Requirement already satisfied: transformers<5.0.0,>=4.38.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (4.41.2)\n", - "Requirement already satisfied: typer<0.13.0,>=0.12.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.12.3)\n", - "Requirement already satisfied: accelerate>=0.23.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.31.0)\n", - "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.14.1)\n", - "Requirement already satisfied: better-abc<0.0.4,>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.0.3)\n", - "Requirement already satisfied: einops>=0.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.7.0)\n", - "Requirement already satisfied: fancy-einsum>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.0.3)\n", - "Requirement already satisfied: jaxtyping>=0.2.11 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.2.29)\n", - "Requirement already satisfied: numpy>=1.24 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (1.26.4)\n", - "Requirement already satisfied: pandas>=1.1.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (2.2.2)\n", - "Requirement already satisfied: rich>=12.6.0 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (13.7.1)\n", - "Requirement already satisfied: sentencepiece in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.2.0)\n", - "Requirement already satisfied: torch>=1.10 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (2.1.2)\n", - "Requirement already satisfied: tqdm>=4.64.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (4.66.4)\n", - "Requirement already satisfied: typing-extensions in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (4.12.2)\n", - "Requirement already satisfied: wandb>=0.13.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.17.1)\n", - "Requirement already satisfied: importlib-metadata>=5.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (7.1.0)\n", - "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.3.1)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", - "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (8.9.2.26)\n", - "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (11.0.2.54)\n", - "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (10.3.2.106)\n", - "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (11.4.5.107)\n", - "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.0.106)\n", - "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (2.18.1)\n", - "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", - "Requirement already satisfied: triton==2.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (2.1.0)\n", - "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->circuitsvis) (12.5.40)\n", - "Requirement already satisfied: filelock in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from triton==2.1.0->circuitsvis) (3.14.0)\n", - "Requirement already satisfied: packaging>=20.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (24.0)\n", - "Requirement already satisfied: psutil in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (5.9.0)\n", - "Requirement already satisfied: huggingface-hub in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) 
(0.23.3)\n", - "Requirement already satisfied: blobfile<3.0.0,>=2.1.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.1.1)\n", - "Requirement already satisfied: boostedblob<0.16.0,>=0.15.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.15.3)\n", - "Requirement already satisfied: httpx<0.28.0,>=0.27.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.27.0)\n", - "Requirement already satisfied: orjson<4.0.0,>=3.10.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.10.4)\n", - "Requirement already satisfied: pytest<9.0.0,>=8.1.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (8.2.2)\n", - "Requirement already satisfied: scikit-learn<2.0.0,>=1.4.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.5.0)\n", - "Requirement already satisfied: tiktoken<0.7.0,>=0.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.6.0)\n", - "Requirement already satisfied: py2store in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from babe<0.0.8,>=0.0.7->sae-lens) (0.1.20)\n", - "Requirement already satisfied: graze in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from babe<0.0.8,>=0.0.7->sae-lens) (0.1.17)\n", - "Requirement already satisfied: pyarrow>=12.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (16.1.0)\n", - "Requirement already satisfied: pyarrow-hotfix in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.6)\n", - "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.3.8)\n", - "Requirement already satisfied: requests>=2.32.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (2.32.3)\n", - "Requirement already satisfied: xxhash in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.4.1)\n", - "Requirement already satisfied: multiprocess in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.70.16)\n", - "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets<3.0.0,>=2.17.1->sae-lens) (2024.3.1)\n", - "Requirement already satisfied: aiohttp in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.9.5)\n", - "Requirement already satisfied: zipp>=0.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from importlib-metadata>=5.1.0->circuitsvis) (3.19.2)\n", - "Requirement already satisfied: typeguard==2.13.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from jaxtyping>=0.2.11->transformer-lens) (2.13.3)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.2.1)\n", - "Requirement already satisfied: cycler>=0.10 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (0.12.1)\n", - "Requirement already satisfied: fonttools>=4.22.0 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (4.53.0)\n", - "Requirement already satisfied: kiwisolver>=1.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.4.5)\n", - "Requirement already satisfied: pillow>=8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (10.3.0)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (3.1.2)\n", - "Requirement already satisfied: python-dateutil>=2.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (2.9.0)\n", - "Requirement already satisfied: traitlets in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib-inline<0.2.0,>=0.1.6->sae-lens) (5.14.3)\n", - "Requirement already satisfied: click in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (8.1.7)\n", - "Requirement already satisfied: joblib in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (1.4.2)\n", - "Requirement already satisfied: regex>=2021.8.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (2024.5.15)\n", - "Requirement already satisfied: pytz>=2020.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pandas>=1.1.5->transformer-lens) (2024.1)\n", - "Requirement already satisfied: tzdata>=2022.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pandas>=1.1.5->transformer-lens) (2024.1)\n", - "Requirement already satisfied: tenacity>=6.2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from 
plotly<6.0.0,>=5.19.0->sae-lens) (8.3.0)\n", - "Requirement already satisfied: statsmodels>=0.9.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (0.14.2)\n", - "Requirement already satisfied: scipy>=0.18 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (1.13.1)\n", - "Requirement already satisfied: patsy>=0.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (0.5.6)\n", - "Requirement already satisfied: six in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (1.16.0)\n", - "Requirement already satisfied: gprof2dot in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (2024.6.6)\n", - "Requirement already satisfied: markdown-it-py>=2.2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from rich>=12.6.0->transformer-lens) (3.0.0)\n", - "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from rich>=12.6.0->transformer-lens) (2.18.0)\n", - "Requirement already satisfied: dataclasses-json<0.7.0,>=0.6.4 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-vis<0.3.0,>=0.2.18->sae-lens) (0.6.7)\n", - "Requirement already satisfied: eindex-callum<0.2.0,>=0.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-vis<0.3.0,>=0.2.18->sae-lens) (0.1.1)\n", - "Requirement already satisfied: sympy in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (1.12.1)\n", - "Requirement already satisfied: networkx in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (3.3)\n", 
- "Requirement already satisfied: jinja2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (3.1.4)\n", - "Requirement already satisfied: tokenizers<0.20,>=0.19 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformers<5.0.0,>=4.38.1->sae-lens) (0.19.1)\n", - "Requirement already satisfied: shellingham>=1.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from typer<0.13.0,>=0.12.3->sae-lens) (1.5.4)\n", - "Requirement already satisfied: docker-pycreds>=0.4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (0.4.0)\n", - "Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (3.1.43)\n", - "Requirement already satisfied: platformdirs in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (4.2.2)\n", - "Requirement already satisfied: protobuf!=4.21.0,<6,>=3.19.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (5.27.1)\n", - "Requirement already satisfied: sentry-sdk>=1.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (2.5.1)\n", - "Requirement already satisfied: setproctitle in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (1.3.3)\n", - "Requirement already satisfied: setuptools in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (70.0.0)\n", - "Requirement already satisfied: pycryptodomex~=3.8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.20.0)\n", - 
"Requirement already satisfied: urllib3<3,>=1.25.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.2.1)\n", - "Requirement already satisfied: lxml~=4.9 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (4.9.4)\n", - "Requirement already satisfied: uvloop>=0.16.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from boostedblob<0.16.0,>=0.15.3->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.19.0)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.3.1)\n", - "Requirement already satisfied: attrs>=17.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (23.2.0)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.4.1)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (6.0.5)\n", - "Requirement already satisfied: yarl<2.0,>=1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.9.4)\n", - "Requirement already satisfied: async-timeout<5.0,>=4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (4.0.3)\n", - "Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (3.21.3)\n", - 
"Requirement already satisfied: typing-inspect<1,>=0.4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (0.9.0)\n", - "Requirement already satisfied: gitdb<5,>=4.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens) (4.0.11)\n", - "Requirement already satisfied: anyio in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (4.4.0)\n", - "Requirement already satisfied: certifi in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2024.6.2)\n", - "Requirement already satisfied: httpcore==1.* in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.0.5)\n", - "Requirement already satisfied: idna in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.7)\n", - "Requirement already satisfied: sniffio in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.3.1)\n", - "Requirement already satisfied: h11<0.15,>=0.13 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpcore==1.*->httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.14.0)\n", - "Requirement already satisfied: mdurl~=0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer-lens) (0.1.2)\n", - "Requirement already satisfied: iniconfig in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages 
(from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.0.0)\n", - "Requirement already satisfied: pluggy<2.0,>=1.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.5.0)\n", - "Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.2.0)\n", - "Requirement already satisfied: tomli>=1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.0.1)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from requests>=2.32.1->datasets<3.0.0,>=2.17.1->sae-lens) (3.3.2)\n", - "Requirement already satisfied: threadpoolctl>=3.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from scikit-learn<2.0.0,>=1.4.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.5.0)\n", - "Requirement already satisfied: dol in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from graze->babe<0.0.8,>=0.0.7->sae-lens) (0.2.47)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from jinja2->torch>=1.10->transformer-lens) (2.1.5)\n", - "Requirement already satisfied: config2py in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from py2store->babe<0.0.8,>=0.0.7->sae-lens) (0.1.33)\n", - "Requirement already satisfied: importlib-resources in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from py2store->babe<0.0.8,>=0.0.7->sae-lens) (6.4.0)\n", - "Requirement already satisfied: mpmath<1.4.0,>=1.1.0 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sympy->torch>=1.10->transformer-lens) (1.3.0)\n", - "Requirement already satisfied: smmap<6,>=3.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens) (5.0.1)\n", - "Requirement already satisfied: mypy-extensions>=0.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (1.0.0)\n", - "Requirement already satisfied: i2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from config2py->py2store->babe<0.0.8,>=0.0.7->sae-lens) (0.1.17)\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } + "data": { + "text/html": [ + "Syncing run 16384-L1-5-LR-5e-05-Tokens-1.229e+08 to Weights & Biases (docs)
" ], - "source": [ - "try:\n", - " #import google.colab # type: ignore\n", - " #from google.colab import output\n", - " %pip install sae-lens transformer-lens circuitsvis\n", - "except:\n", - " from IPython import get_ipython # type: ignore\n", - " ipython = get_ipython(); assert ipython is not None\n", - " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", - " ipython.run_line_magic(\"autoreload\", \"2\")" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "uy-b3CcSOVHu", - "outputId": "58ce28d0-f91f-436d-cf87-76bb26e2ecaf" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using device: cuda\n" - ] - } + "data": { + "text/html": [ + " View project at https://wandb.ai/curt-tigges/sae_lens_tutorial" ], - "source": [ - "import torch\n", - "import os\n", - "\n", - "from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner\n", - "\n", - "if torch.cuda.is_available():\n", - " device = \"cuda\"\n", - "elif torch.backends.mps.is_available():\n", - " device = \"mps\"\n", - "else:\n", - " device = \"cpu\"\n", - "\n", - "print(\"Using device:\", device)\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "jCHtPycOOVHw" - }, - "source": [ - "## Training on MLP Out" + "data": { + "text/html": [ + " View run at https://wandb.ai/curt-tigges/sae_lens_tutorial/runs/n7cy5v24" + 
], + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Estimating norm scaling factor: 100%|██████████| 1000/1000 [00:33<00:00, 30.13it/s]\n", + "5500| MSE Loss 208.944 | L1 167.607: 0%| | 225280/122880000 [08:05<71:26:53, 476.86it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "interrupted, saving progress\n", + "done saving\n" + ] + }, + { + "ename": "InterruptedException", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mInterruptedException\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 63\u001b[0m\n\u001b[1;32m 9\u001b[0m cfg \u001b[38;5;241m=\u001b[39m LanguageModelSAERunnerConfig(\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# Data Generating Function (Model + Training Distribution)\u001b[39;00m\n\u001b[1;32m 11\u001b[0m variant\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbaseline\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;66;03m# we'll use the gated variant.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 60\u001b[0m dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 61\u001b[0m )\n\u001b[1;32m 62\u001b[0m \u001b[38;5;66;03m# look at the next cell to see some instruction for what to do while this is running.\u001b[39;00m\n\u001b[0;32m---> 63\u001b[0m sparse_autoencoder \u001b[38;5;241m=\u001b[39m \u001b[43mSAETrainingRunner\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcfg\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/projects/SAELens/sae_lens/sae_training_runner.py:87\u001b[0m, in 
\u001b[0;36mSAETrainingRunner.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 78\u001b[0m trainer \u001b[38;5;241m=\u001b[39m SAETrainer(\n\u001b[1;32m 79\u001b[0m model\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel,\n\u001b[1;32m 80\u001b[0m sae\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msae,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 83\u001b[0m cfg\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg,\n\u001b[1;32m 84\u001b[0m )\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compile_if_needed()\n\u001b[0;32m---> 87\u001b[0m sae \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_trainer_with_interruption_handling\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mlog_to_wandb:\n\u001b[1;32m 90\u001b[0m wandb\u001b[38;5;241m.\u001b[39mfinish()\n", + "File \u001b[0;32m~/projects/SAELens/sae_lens/sae_training_runner.py:130\u001b[0m, in \u001b[0;36mSAETrainingRunner.run_trainer_with_interruption_handling\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 127\u001b[0m signal\u001b[38;5;241m.\u001b[39msignal(signal\u001b[38;5;241m.\u001b[39mSIGTERM, interrupt_callback)\n\u001b[1;32m 129\u001b[0m \u001b[38;5;66;03m# train SAE\u001b[39;00m\n\u001b[0;32m--> 130\u001b[0m sae \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, InterruptedException):\n\u001b[1;32m 133\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minterrupted, saving 
progress\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/projects/SAELens/sae_lens/training/sae_trainer.py:162\u001b[0m, in \u001b[0;36mSAETrainer.fit\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 159\u001b[0m layer_acts \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactivation_store\u001b[38;5;241m.\u001b[39mnext_batch()[:, \u001b[38;5;241m0\u001b[39m, :]\n\u001b[1;32m 160\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_training_tokens \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mtrain_batch_size_tokens\n\u001b[0;32m--> 162\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43msae\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msae\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msae_in\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlayer_acts\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mlog_to_wandb:\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_log_train_step(step_output)\n", + "File \u001b[0;32m~/projects/SAELens/sae_lens/training/sae_trainer.py:216\u001b[0m, in \u001b[0;36mSAETrainer._train_step\u001b[0;34m(self, sae, sae_in)\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[38;5;66;03m# for documentation on autocasting see:\u001b[39;00m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;66;03m# https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html\u001b[39;00m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mautocast_if_enabled:\n\u001b[0;32m--> 216\u001b[0m 
train_step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msae\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_forward_pass\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 217\u001b[0m \u001b[43m \u001b[49m\u001b[43msae_in\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msae_in\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 218\u001b[0m \u001b[43m \u001b[49m\u001b[43mdead_neuron_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdead_neurons\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 219\u001b[0m \u001b[43m \u001b[49m\u001b[43mcurrent_l1_coefficient\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_l1_coefficient\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 220\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m 223\u001b[0m did_fire \u001b[38;5;241m=\u001b[39m (train_step_output\u001b[38;5;241m.\u001b[39mfeature_acts \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n", + "File \u001b[0;32m~/projects/SAELens/sae_lens/training/training_sae.py:303\u001b[0m, in \u001b[0;36mTrainingSAE.training_forward_pass\u001b[0;34m(self, sae_in, current_l1_coefficient, dead_neuron_mask)\u001b[0m\n\u001b[1;32m 295\u001b[0m l1_loss \u001b[38;5;241m=\u001b[39m (current_l1_coefficient \u001b[38;5;241m*\u001b[39m sparsity)\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m 296\u001b[0m loss \u001b[38;5;241m=\u001b[39m mse_loss \u001b[38;5;241m+\u001b[39m l1_loss \u001b[38;5;241m+\u001b[39m ghost_grad_loss\n\u001b[1;32m 298\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m TrainStepOutput(\n\u001b[1;32m 299\u001b[0m sae_in\u001b[38;5;241m=\u001b[39msae_in,\n\u001b[1;32m 300\u001b[0m sae_out\u001b[38;5;241m=\u001b[39msae_out,\n\u001b[1;32m 301\u001b[0m feature_acts\u001b[38;5;241m=\u001b[39mfeature_acts,\n\u001b[1;32m 302\u001b[0m loss\u001b[38;5;241m=\u001b[39mloss,\n\u001b[0;32m--> 303\u001b[0m mse_loss\u001b[38;5;241m=\u001b[39m\u001b[43mmse_loss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 304\u001b[0m l1_loss\u001b[38;5;241m=\u001b[39ml1_loss\u001b[38;5;241m.\u001b[39mitem(),\n\u001b[1;32m 305\u001b[0m ghost_grad_loss\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m 306\u001b[0m ghost_grad_loss\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m 307\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(ghost_grad_loss, torch\u001b[38;5;241m.\u001b[39mTensor)\n\u001b[1;32m 308\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m ghost_grad_loss\n\u001b[1;32m 309\u001b[0m ),\n\u001b[1;32m 310\u001b[0m )\n", + "File \u001b[0;32m~/projects/SAELens/sae_lens/sae_training_runner.py:25\u001b[0m, in \u001b[0;36minterrupt_callback\u001b[0;34m(sig_num, stack_frame)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minterrupt_callback\u001b[39m(sig_num: Any, stack_frame: Any):\n\u001b[0;32m---> 25\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m InterruptedException()\n", + "\u001b[0;31mInterruptedException\u001b[0m: " + ] + } + ], + "source": [ + "total_training_steps = 30_000 # probably we should do more\n", + "batch_size = 4096\n", + "total_training_tokens = total_training_steps * batch_size\n", + "\n", + "lr_warm_up_steps = 0\n", + "lr_decay_steps = total_training_steps // 5 # 20% of training\n", + "l1_warm_up_steps = total_training_steps // 20 # 5% of training\n", + "\n", + "cfg = LanguageModelSAERunnerConfig(\n", + " # Data Generating Function (Model + Training Distribution)\n", + " 
architecture=\"baseline\", # we'll use the gated variant.\n", + " model_name=\"tiny-stories-1L-21M\", # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", + " hook_name=\"blocks.0.hook_mlp_out\", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)\n", + " hook_layer=0, # Only one layer in the model.\n", + " d_in=1024, # the width of the mlp output.\n", + " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\", # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.\n", + " is_dataset_tokenized=True,\n", + " streaming=True, # we could pre-download the token dataset if it was small.\n", + " # SAE Parameters\n", + " mse_loss_normalization=None, # We won't normalize the mse loss,\n", + " expansion_factor=16, # the width of the SAE. Larger will result in better stats but slower training.\n", + " b_dec_init_method=\"zeros\", # The geometric median can be used to initialize the decoder weights.\n", + " apply_b_dec_to_input=False, # We won't apply the decoder weights to the input.\n", + " normalize_sae_decoder=False,\n", + " scale_sparsity_penalty_by_decoder_norm=True,\n", + " decoder_heuristic_init=True,\n", + " init_encoder_as_decoder_transpose=True,\n", + " normalize_activations=True,\n", + " # Training Parameters\n", + " lr=5e-5, # lower the better, we'll go fairly high to speed up the tutorial.\n", + " adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)\n", + " adam_beta2=0.999,\n", + " lr_scheduler_name=\"constant\", # constant learning rate with warmup. 
Could be better schedules out there.\n", + " lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially.\n", + " lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting.\n", + " l1_coefficient=5, # will control how sparse the feature activations are\n", + " l1_warm_up_steps=l1_warm_up_steps, # this can help avoid too many dead features initially.\n", + " lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n", + " train_batch_size_tokens=batch_size,\n", + " context_size=256, # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", + " training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n", + " store_batch_size_prompts=16,\n", + " # Resampling protocol\n", + " use_ghost_grads=False, # we don't use ghost grads anymore.\n", + " feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", + " dead_feature_window=1000, # would effect resampling or ghost grads if we were using it.\n", + " dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it.\n", + " # WANDB\n", + " log_to_wandb=True, # always use wandb unless you are just testing code.\n", + " wandb_project=\"sae_lens_tutorial\",\n", + " wandb_log_frequency=30,\n", + " eval_every_n_wandb_logs=20,\n", + " # Misc\n", + " device=device,\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=\"float32\",\n", + ")\n", + "# look at the next cell to see some instruction for what to do while this is running.\n", + "sparse_autoencoder = SAETrainingRunner(cfg).run()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "Run name: 16384-L1-20-LR-5e-05-Tokens-1.229e+08\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.001024\n", + "Total training steps: 30000\n", + "Total wandb updates: 1000\n", + "n_tokens_per_feature_sampling_window (millions): 1048.576\n", + "n_tokens_per_dead_feature_window (millions): 1048.576\n", + "We will reset the sparsity calculation 30 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n" + ] }, { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "oAsZCAdJOVHw" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Run name: 16384-L1-5-LR-5e-05-Tokens-1.229e+08\n", - "n_tokens_per_buffer (millions): 0.262144\n", - "Lower bound: n_contexts_per_buffer (millions): 0.001024\n", - "Total training steps: 30000\n", - "Total wandb updates: 1000\n", - "n_tokens_per_feature_sampling_window (millions): 1048.576\n", - "n_tokens_per_dead_feature_window (millions): 1048.576\n", - "We will reset the sparsity calculation 30 times.\n", - "Number tokens in sparsity calculation window: 4.10e+06\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", - " warnings.warn(\n", - "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. 
To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", - " return self.fget.__get__(instance, owner)()\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mcurt-tigges\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" - ] - }, - { - "data": { - "text/html": [ - "Tracking run with wandb version 0.17.1" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Run data is saved locally in /home/curttigges/projects/SAELens/tutorials/wandb/run-20240611_143204-n7cy5v24" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Syncing run 16384-L1-5-LR-5e-05-Tokens-1.229e+08 to Weights & Biases (docs)
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View project at https://wandb.ai/curt-tigges/sae_lens_tutorial" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View run at https://wandb.ai/curt-tigges/sae_lens_tutorial/runs/n7cy5v24" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Estimating norm scaling factor: 100%|██████████| 1000/1000 [00:33<00:00, 30.13it/s]\n", - "5500| MSE Loss 208.944 | L1 167.607: 0%| | 225280/122880000 [08:05<71:26:53, 476.86it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "interrupted, saving progress\n", - "done saving\n" - ] - }, - { - "ename": "InterruptedException", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mInterruptedException\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[3], line 63\u001b[0m\n\u001b[1;32m 9\u001b[0m cfg \u001b[38;5;241m=\u001b[39m LanguageModelSAERunnerConfig(\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# Data Generating Function (Model + Training Distribution)\u001b[39;00m\n\u001b[1;32m 11\u001b[0m variant\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbaseline\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;66;03m# we'll use the gated variant.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 60\u001b[0m dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 61\u001b[0m )\n\u001b[1;32m 62\u001b[0m \u001b[38;5;66;03m# look at the next cell to see some instruction for what to do while this is running.\u001b[39;00m\n\u001b[0;32m---> 63\u001b[0m 
sparse_autoencoder \u001b[38;5;241m=\u001b[39m \u001b[43mSAETrainingRunner\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcfg\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/projects/SAELens/sae_lens/sae_training_runner.py:87\u001b[0m, in \u001b[0;36mSAETrainingRunner.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 78\u001b[0m trainer \u001b[38;5;241m=\u001b[39m SAETrainer(\n\u001b[1;32m 79\u001b[0m model\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel,\n\u001b[1;32m 80\u001b[0m sae\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msae,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 83\u001b[0m cfg\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg,\n\u001b[1;32m 84\u001b[0m )\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compile_if_needed()\n\u001b[0;32m---> 87\u001b[0m sae \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_trainer_with_interruption_handling\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mlog_to_wandb:\n\u001b[1;32m 90\u001b[0m wandb\u001b[38;5;241m.\u001b[39mfinish()\n", - "File \u001b[0;32m~/projects/SAELens/sae_lens/sae_training_runner.py:130\u001b[0m, in \u001b[0;36mSAETrainingRunner.run_trainer_with_interruption_handling\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 127\u001b[0m signal\u001b[38;5;241m.\u001b[39msignal(signal\u001b[38;5;241m.\u001b[39mSIGTERM, interrupt_callback)\n\u001b[1;32m 129\u001b[0m \u001b[38;5;66;03m# train SAE\u001b[39;00m\n\u001b[0;32m--> 130\u001b[0m sae \u001b[38;5;241m=\u001b[39m 
\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, InterruptedException):\n\u001b[1;32m 133\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minterrupted, saving progress\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/projects/SAELens/sae_lens/training/sae_trainer.py:162\u001b[0m, in \u001b[0;36mSAETrainer.fit\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 159\u001b[0m layer_acts \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactivation_store\u001b[38;5;241m.\u001b[39mnext_batch()[:, \u001b[38;5;241m0\u001b[39m, :]\n\u001b[1;32m 160\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_training_tokens \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mtrain_batch_size_tokens\n\u001b[0;32m--> 162\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43msae\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msae\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msae_in\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlayer_acts\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mlog_to_wandb:\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_log_train_step(step_output)\n", - "File \u001b[0;32m~/projects/SAELens/sae_lens/training/sae_trainer.py:216\u001b[0m, in \u001b[0;36mSAETrainer._train_step\u001b[0;34m(self, sae, 
sae_in)\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[38;5;66;03m# for documentation on autocasting see:\u001b[39;00m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;66;03m# https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html\u001b[39;00m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mautocast_if_enabled:\n\u001b[0;32m--> 216\u001b[0m train_step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msae\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_forward_pass\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 217\u001b[0m \u001b[43m \u001b[49m\u001b[43msae_in\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msae_in\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 218\u001b[0m \u001b[43m \u001b[49m\u001b[43mdead_neuron_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdead_neurons\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 219\u001b[0m \u001b[43m \u001b[49m\u001b[43mcurrent_l1_coefficient\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_l1_coefficient\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 220\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 222\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m 223\u001b[0m did_fire \u001b[38;5;241m=\u001b[39m (train_step_output\u001b[38;5;241m.\u001b[39mfeature_acts \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n", - "File \u001b[0;32m~/projects/SAELens/sae_lens/training/training_sae.py:303\u001b[0m, in \u001b[0;36mTrainingSAE.training_forward_pass\u001b[0;34m(self, sae_in, 
current_l1_coefficient, dead_neuron_mask)\u001b[0m\n\u001b[1;32m 295\u001b[0m l1_loss \u001b[38;5;241m=\u001b[39m (current_l1_coefficient \u001b[38;5;241m*\u001b[39m sparsity)\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m 296\u001b[0m loss \u001b[38;5;241m=\u001b[39m mse_loss \u001b[38;5;241m+\u001b[39m l1_loss \u001b[38;5;241m+\u001b[39m ghost_grad_loss\n\u001b[1;32m 298\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m TrainStepOutput(\n\u001b[1;32m 299\u001b[0m sae_in\u001b[38;5;241m=\u001b[39msae_in,\n\u001b[1;32m 300\u001b[0m sae_out\u001b[38;5;241m=\u001b[39msae_out,\n\u001b[1;32m 301\u001b[0m feature_acts\u001b[38;5;241m=\u001b[39mfeature_acts,\n\u001b[1;32m 302\u001b[0m loss\u001b[38;5;241m=\u001b[39mloss,\n\u001b[0;32m--> 303\u001b[0m mse_loss\u001b[38;5;241m=\u001b[39m\u001b[43mmse_loss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 304\u001b[0m l1_loss\u001b[38;5;241m=\u001b[39ml1_loss\u001b[38;5;241m.\u001b[39mitem(),\n\u001b[1;32m 305\u001b[0m ghost_grad_loss\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m 306\u001b[0m ghost_grad_loss\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m 307\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(ghost_grad_loss, torch\u001b[38;5;241m.\u001b[39mTensor)\n\u001b[1;32m 308\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m ghost_grad_loss\n\u001b[1;32m 309\u001b[0m ),\n\u001b[1;32m 310\u001b[0m )\n", - "File \u001b[0;32m~/projects/SAELens/sae_lens/sae_training_runner.py:25\u001b[0m, in \u001b[0;36minterrupt_callback\u001b[0;34m(sig_num, stack_frame)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minterrupt_callback\u001b[39m(sig_num: Any, stack_frame: Any):\n\u001b[0;32m---> 25\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m InterruptedException()\n", - "\u001b[0;31mInterruptedException\u001b[0m: " - ] - } + "name": "stderr", + "output_type": "stream", + "text": [ + 
"/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", + " return self.fget.__get__(instance, owner)()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: W&B API key is configured. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.17.1" ], - "source": [ - "total_training_steps = 30_000 # probably we should do more\n", - "batch_size = 4096\n", - "total_training_tokens = total_training_steps * batch_size\n", - "\n", - "lr_warm_up_steps = 0\n", - "lr_decay_steps = total_training_steps // 5 # 20% of training\n", - "l1_warm_up_steps = total_training_steps // 20 # 5% of training\n", - "\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distribution)\n", - " architecture=\"baseline\", # we'll use the gated variant.\n", - " model_name=\"tiny-stories-1L-21M\", # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", - " hook_name=\"blocks.0.hook_mlp_out\", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)\n", - " hook_layer=0, # Only one layer in the model.\n", - " d_in=1024, # the width of the mlp output.\n", - " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\", # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.\n", - " is_dataset_tokenized=True,\n", - " streaming=True, # we could pre-download the token dataset if it was small.\n", - " # SAE Parameters\n", - " mse_loss_normalization=None, # We won't normalize the mse loss,\n", - " expansion_factor=16, # the width of the SAE. 
Larger will result in better stats but slower training.\n", - " b_dec_init_method=\"zeros\", # The geometric median can be used to initialize the decoder weights.\n", - " apply_b_dec_to_input=False, # We won't apply the decoder weights to the input.\n", - " normalize_sae_decoder=False,\n", - " scale_sparsity_penalty_by_decoder_norm=True,\n", - " decoder_heuristic_init=True,\n", - " init_encoder_as_decoder_transpose=True,\n", - " normalize_activations=True,\n", - " # Training Parameters\n", - " lr=5e-5, # lower the better, we'll go fairly high to speed up the tutorial.\n", - " adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)\n", - " adam_beta2=0.999,\n", - " lr_scheduler_name=\"constant\", # constant learning rate with warmup. Could be better schedules out there.\n", - " lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially.\n", - " lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting.\n", - " l1_coefficient=5, # will control how sparse the feature activations are\n", - " l1_warm_up_steps=l1_warm_up_steps, # this can help avoid too many dead features initially.\n", - " lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n", - " train_batch_size_tokens=batch_size,\n", - " context_size=256, # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", - " training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. 
Get a coffee, come back.\n", - " store_batch_size_prompts=16,\n", - " # Resampling protocol\n", - " use_ghost_grads=False, # we don't use ghost grads anymore.\n", - " feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", - " dead_feature_window=1000, # would effect resampling or ghost grads if we were using it.\n", - " dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it.\n", - " # WANDB\n", - " log_to_wandb=True, # always use wandb unless you are just testing code.\n", - " wandb_project=\"sae_lens_tutorial\",\n", - " wandb_log_frequency=30,\n", - " eval_every_n_wandb_logs=20,\n", - " # Misc\n", - " device=device,\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=\"float32\"\n", - ")\n", - "# look at the next cell to see some instruction for what to do while this is running.\n", - "sparse_autoencoder = SAETrainingRunner(cfg).run()" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Run name: 16384-L1-20-LR-5e-05-Tokens-1.229e+08\n", - "n_tokens_per_buffer (millions): 0.262144\n", - "Lower bound: n_contexts_per_buffer (millions): 0.001024\n", - "Total training steps: 30000\n", - "Total wandb updates: 1000\n", - "n_tokens_per_feature_sampling_window (millions): 1048.576\n", - "n_tokens_per_dead_feature_window (millions): 1048.576\n", - "We will reset the sparsity calculation 30 times.\n", - "Number tokens in sparsity calculation window: 4.10e+06\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", - " warnings.warn(\n", - "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", - " return self.fget.__get__(instance, owner)()\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: W&B API key is configured. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" - ] - }, - { - "data": { - "text/html": [ - "Tracking run with wandb version 0.17.1" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Run data is saved locally in /home/curttigges/projects/SAELens/tutorials/wandb/run-20240616_143959-ch6e0a5s" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Syncing run 16384-L1-20-LR-5e-05-Tokens-1.229e+08 to Weights & Biases (docs)
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View project at https://wandb.ai/curt-tigges/gated_sae_testing" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View run at https://wandb.ai/curt-tigges/gated_sae_testing/runs/ch6e0a5s" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "30000| MSE Loss 143.062 | L1 0.000: 1%| | 1228800/122880000 [1:04:38<106:39:53, 316.81it/s]\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "

Run history:


details/current_l1_coefficient▁▂▂▃▃▄▄▅▅▆▆▇████████████████████████████
details/current_learning_rate████████████████████████████████▇▇▅▅▄▃▂▁
details/n_training_tokens▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/auxiliary_reconstruction_loss▁▃▃▄▄▅▅▅▆▆▆▇▇███████████████████████████
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss▁▂▃▄▄▅▅▅▆▆▆▆▆▆▇▆▆▆▆▆▆▇▆▆▇▆▆▆▇▆▇▆▆▆▆▇▆▇▆█
losses/overall_loss▁▃▄▄▅▆▆▆▇▇▇▇██████████████▇████▇▇█▇█▇█▇█
losses/sfn_sparsity_loss▂▃▅▆▆▇████▇▆▅▄▃▃▃▃▃▃▃▃▄▃▃▃▂▃▃▃▃▃▃▂▂▂▁▁▁▁
metrics/CE_loss_score██▇▇▆▆▆▆▆▆▆▆▆▅▆▆▅▆▅▃▅▁▆▅▆▆▅▅▃▆▅▅▆▅▆▆▅▅▃▄
metrics/ce_loss_with_ablation▅▃▁▆▄▃▅▆▅▄▄▁▅▃▄▃▃▄▁▃▄▆▄▄▄▃▆▄▃█▄▄▁▄▄▅█▄▄▃
metrics/ce_loss_with_sae▁▁▂▂▃▂▃▃▃▃▃▃▃▃▃▃▄▃▄▆▄█▃▄▃▃▄▄▆▃▃▄▃▄▃▃▄▄▆▅
metrics/ce_loss_without_sae▄▂▃█▃▁▃▃▄▅▃▄▂▃▄▃▄▃▃▃▃▃▄▁▂▂▆▃▄▅▃▄▃▅▇▅▃▃▃▂
metrics/explained_variance█▇▆▅▅▄▄▄▃▃▃▃▃▃▂▃▃▃▃▃▃▂▃▃▃▃▃▃▂▃▃▃▃▃▃▂▃▂▃▁
metrics/explained_variance_std▁▂▃▃▃▃▃▃▃▃▃▃▃▃▄▃▃▄▄▄▃▆▄▄▄▄▄▄▅▄▄▄▄▄▄▅▄▅▄█
metrics/l0█▅▅▂▃▅▆▄▅█▃▅▃▅▇▄▆▄▆▄▄▁▄▅▅▄▁▃▇▄▄▅▃▆▃▄▄▄▃▁
metrics/l2_norm▇▆▇▁▄▃▆▃▃▂▂▄▂▂▃▂▄▃▂▃▄▅▅▃▃▃▄▄▃▃▂▅▃▃▃▃▃▆▆█
metrics/l2_norm_in▃▃▇▂▄▂▅▅▅▅▄▆▃▃▁▃▄▅▃▂▅▂▅▄▃▄▃▅▂▅▄▃▃▄▁▄▂█▃▅
metrics/l2_ratio█▆▆▁▃▃▅▂▃▁▂▃▂▂▃▂▄▂▂▄▄▅▅▃▃▂▅▃▃▃▂▅▃▂▃▃▃▅▆█
metrics/mean_log10_feature_sparsity█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▅▄▄▂▁▁
sparsity/below_1e-5▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▂▃▆██
sparsity/below_1e-6▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▂▃▆██
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▁▁▁▁▁▁▁▁▁▃▃▄▇██
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▁▁▁▁▁▁▁▁▂▂▂▄▅▇█

Run summary:


details/current_l1_coefficient20
details/current_learning_rate0.0
details/n_training_tokens122880000
losses/auxiliary_reconstruction_loss227.78122
losses/ghost_grad_loss0.0
losses/l1_loss0.0
losses/mse_loss143.06226
losses/overall_loss434.46942
losses/sfn_sparsity_loss63.62593
metrics/CE_loss_score0.59248
metrics/ce_loss_with_ablation8.29373
metrics/ce_loss_with_sae4.50411
metrics/ce_loss_without_sae1.8969
metrics/explained_variance0.15973
metrics/explained_variance_std0.24142
metrics/l07705.52734
metrics/l2_norm14.99578
metrics/l2_norm_in17.58649
metrics/l2_ratio0.8463
metrics/mean_log10_feature_sparsity-0.74933
sparsity/below_1e-5681
sparsity/below_1e-6681
sparsity/dead_features681
sparsity/mean_passes_since_fired138.74988

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View run 16384-L1-20-LR-5e-05-Tokens-1.229e+08 at: https://wandb.ai/curt-tigges/gated_sae_testing/runs/ch6e0a5s
View project at: https://wandb.ai/curt-tigges/gated_sae_testing
Synced 6 W&B file(s), 0 media file(s), 3 artifact file(s) and 0 other file(s)" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Find logs at: ./wandb/run-20240616_143959-ch6e0a5s/logs" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } + "data": { + "text/html": [ + "Run data is saved locally in /home/curttigges/projects/SAELens/tutorials/wandb/run-20240616_143959-ch6e0a5s" ], - "source": [ - "total_training_steps = 30_000 # probably we should do more\n", - "batch_size = 4096\n", - "total_training_tokens = total_training_steps * batch_size\n", - "\n", - "lr_warm_up_steps = 0\n", - "lr_decay_steps = total_training_steps // 5 # 20% of training\n", - "l1_warm_up_steps = 10_000 #total_training_steps // 20 # 5% of training\n", - "\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distribution)\n", - " architecture=\"gated\", # we'll use the gated variant.\n", - " model_name=\"tiny-stories-1L-21M\", # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", - " hook_name=\"blocks.0.hook_mlp_out\", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)\n", - " hook_layer=0, # Only one layer in the model.\n", - " d_in=1024, # the width of the mlp output.\n", - " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\", # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.\n", - " is_dataset_tokenized=True,\n", - " streaming=True, # we could pre-download the token dataset if it was small.\n", - " # SAE Parameters\n", - " mse_loss_normalization=None, # We won't normalize the mse loss,\n", - " expansion_factor=16, # the width of the SAE. 
Larger will result in better stats but slower training.\n", - " b_dec_init_method=\"zeros\", # The geometric median can be used to initialize the decoder weights.\n", - " apply_b_dec_to_input=True, # We won't apply the decoder weights to the input.\n", - " normalize_sae_decoder=False,\n", - " scale_sparsity_penalty_by_decoder_norm=False,\n", - " decoder_heuristic_init=True,\n", - " init_encoder_as_decoder_transpose=True,\n", - " normalize_activations=False,\n", - " # Training Parameters\n", - " lr=5e-5, # lower the better, we'll go fairly high to speed up the tutorial.\n", - " adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)\n", - " adam_beta2=0.999,\n", - " lr_scheduler_name=\"constant\", # constant learning rate with warmup. Could be better schedules out there.\n", - " lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially.\n", - " lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting.\n", - " l1_coefficient=20, # will control how sparse the feature activations are\n", - " l1_warm_up_steps=l1_warm_up_steps, # this can help avoid too many dead features initially.\n", - " lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n", - " train_batch_size_tokens=batch_size,\n", - " context_size=256, # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", - " training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. 
Get a coffee, come back.\n", - " store_batch_size_prompts=16,\n", - " # Resampling protocol\n", - " use_ghost_grads=False, # we don't use ghost grads anymore.\n", - " feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", - " dead_feature_window=1000, # would effect resampling or ghost grads if we were using it.\n", - " dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it.\n", - " # WANDB\n", - " log_to_wandb=True, # always use wandb unless you are just testing code.\n", - " wandb_project=\"gated_sae_testing\",\n", - " wandb_log_frequency=30,\n", - " eval_every_n_wandb_logs=20,\n", - " # Misc\n", - " device=device,\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=\"float32\"\n", - ")\n", - "# look at the next cell to see some instruction for what to do while this is running.\n", - "sparse_autoencoder = SAETrainingRunner(cfg).run()" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [] + "data": { + "text/html": [ + "Syncing run 16384-L1-20-LR-5e-05-Tokens-1.229e+08 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/curt-tigges/gated_sae_testing" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/curt-tigges/gated_sae_testing/runs/ch6e0a5s" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "30000| MSE Loss 143.062 | L1 0.000: 1%| | 1228800/122880000 [1:04:38<106:39:53, 316.81it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "

Run history:


details/current_l1_coefficient▁▂▂▃▃▄▄▅▅▆▆▇████████████████████████████
details/current_learning_rate████████████████████████████████▇▇▅▅▄▃▂▁
details/n_training_tokens▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/auxiliary_reconstruction_loss▁▃▃▄▄▅▅▅▆▆▆▇▇███████████████████████████
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss▁▂▃▄▄▅▅▅▆▆▆▆▆▆▇▆▆▆▆▆▆▇▆▆▇▆▆▆▇▆▇▆▆▆▆▇▆▇▆█
losses/overall_loss▁▃▄▄▅▆▆▆▇▇▇▇██████████████▇████▇▇█▇█▇█▇█
losses/sfn_sparsity_loss▂▃▅▆▆▇████▇▆▅▄▃▃▃▃▃▃▃▃▄▃▃▃▂▃▃▃▃▃▃▂▂▂▁▁▁▁
metrics/CE_loss_score██▇▇▆▆▆▆▆▆▆▆▆▅▆▆▅▆▅▃▅▁▆▅▆▆▅▅▃▆▅▅▆▅▆▆▅▅▃▄
metrics/ce_loss_with_ablation▅▃▁▆▄▃▅▆▅▄▄▁▅▃▄▃▃▄▁▃▄▆▄▄▄▃▆▄▃█▄▄▁▄▄▅█▄▄▃
metrics/ce_loss_with_sae▁▁▂▂▃▂▃▃▃▃▃▃▃▃▃▃▄▃▄▆▄█▃▄▃▃▄▄▆▃▃▄▃▄▃▃▄▄▆▅
metrics/ce_loss_without_sae▄▂▃█▃▁▃▃▄▅▃▄▂▃▄▃▄▃▃▃▃▃▄▁▂▂▆▃▄▅▃▄▃▅▇▅▃▃▃▂
metrics/explained_variance█▇▆▅▅▄▄▄▃▃▃▃▃▃▂▃▃▃▃▃▃▂▃▃▃▃▃▃▂▃▃▃▃▃▃▂▃▂▃▁
metrics/explained_variance_std▁▂▃▃▃▃▃▃▃▃▃▃▃▃▄▃▃▄▄▄▃▆▄▄▄▄▄▄▅▄▄▄▄▄▄▅▄▅▄█
metrics/l0█▅▅▂▃▅▆▄▅█▃▅▃▅▇▄▆▄▆▄▄▁▄▅▅▄▁▃▇▄▄▅▃▆▃▄▄▄▃▁
metrics/l2_norm▇▆▇▁▄▃▆▃▃▂▂▄▂▂▃▂▄▃▂▃▄▅▅▃▃▃▄▄▃▃▂▅▃▃▃▃▃▆▆█
metrics/l2_norm_in▃▃▇▂▄▂▅▅▅▅▄▆▃▃▁▃▄▅▃▂▅▂▅▄▃▄▃▅▂▅▄▃▃▄▁▄▂█▃▅
metrics/l2_ratio█▆▆▁▃▃▅▂▃▁▂▃▂▂▃▂▄▂▂▄▄▅▅▃▃▂▅▃▃▃▂▅▃▂▃▃▃▅▆█
metrics/mean_log10_feature_sparsity█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▅▄▄▂▁▁
sparsity/below_1e-5▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▂▃▆██
sparsity/below_1e-6▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▂▃▆██
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▁▁▁▁▁▁▁▁▁▃▃▄▇██
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▁▁▁▁▁▁▁▁▂▂▂▄▅▇█

Run summary:


details/current_l1_coefficient20
details/current_learning_rate0.0
details/n_training_tokens122880000
losses/auxiliary_reconstruction_loss227.78122
losses/ghost_grad_loss0.0
losses/l1_loss0.0
losses/mse_loss143.06226
losses/overall_loss434.46942
losses/sfn_sparsity_loss63.62593
metrics/CE_loss_score0.59248
metrics/ce_loss_with_ablation8.29373
metrics/ce_loss_with_sae4.50411
metrics/ce_loss_without_sae1.8969
metrics/explained_variance0.15973
metrics/explained_variance_std0.24142
metrics/l07705.52734
metrics/l2_norm14.99578
metrics/l2_norm_in17.58649
metrics/l2_ratio0.8463
metrics/mean_log10_feature_sparsity-0.74933
sparsity/below_1e-5681
sparsity/below_1e-6681
sparsity/dead_features681
sparsity/mean_passes_since_fired138.74988

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" + { + "data": { + "text/html": [ + " View run 16384-L1-20-LR-5e-05-Tokens-1.229e+08 at: https://wandb.ai/curt-tigges/gated_sae_testing/runs/ch6e0a5s
View project at: https://wandb.ai/curt-tigges/gated_sae_testing
Synced 6 W&B file(s), 0 media file(s), 3 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240616_143959-ch6e0a5s/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } + ], + "source": [ + "total_training_steps = 30_000 # probably we should do more\n", + "batch_size = 4096\n", + "total_training_tokens = total_training_steps * batch_size\n", + "\n", + "lr_warm_up_steps = 0\n", + "lr_decay_steps = total_training_steps // 5 # 20% of training\n", + "l1_warm_up_steps = 10_000 # total_training_steps // 20 # 5% of training\n", + "\n", + "cfg = LanguageModelSAERunnerConfig(\n", + " # Data Generating Function (Model + Training Distribution)\n", + " architecture=\"gated\", # we'll use the gated variant.\n", + " model_name=\"tiny-stories-1L-21M\", # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", + " hook_name=\"blocks.0.hook_mlp_out\", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)\n", + " hook_layer=0, # Only one layer in the model.\n", + " d_in=1024, # the width of the mlp output.\n", + " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\", # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.\n", + " is_dataset_tokenized=True,\n", + " streaming=True, # we could pre-download the token dataset if it was small.\n", + " # SAE Parameters\n", + " mse_loss_normalization=None, # We won't normalize the mse loss,\n", + " expansion_factor=16, # the 
width of the SAE. Larger will result in better stats but slower training.\n", + " b_dec_init_method=\"zeros\", # The decoder bias is initialized to zeros (the geometric median is another option).\n", + " apply_b_dec_to_input=True, # We apply the decoder bias to the input.\n", + " normalize_sae_decoder=False,\n", + " scale_sparsity_penalty_by_decoder_norm=False,\n", + " decoder_heuristic_init=True,\n", + " init_encoder_as_decoder_transpose=True,\n", + " normalize_activations=False,\n", + " # Training Parameters\n", + " lr=5e-5, # lower the better, we'll go fairly high to speed up the tutorial.\n", + " adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)\n", + " adam_beta2=0.999,\n", + " lr_scheduler_name=\"constant\", # constant learning rate with warmup. Could be better schedules out there.\n", + " lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially.\n", + " lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting.\n", + " l1_coefficient=20, # will control how sparse the feature activations are\n", + " l1_warm_up_steps=l1_warm_up_steps, # this can help avoid too many dead features initially.\n", + " lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n", + " train_batch_size_tokens=batch_size,\n", + " context_size=256, # will control the length of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", + " training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. 
Get a coffee, come back.\n", + " store_batch_size_prompts=16,\n", + " # Resampling protocol\n", + " use_ghost_grads=False, # we don't use ghost grads anymore.\n", + " feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", + " dead_feature_window=1000, # would affect resampling or ghost grads if we were using it.\n", + " dead_feature_threshold=1e-4, # would affect resampling or ghost grads if we were using it.\n", + " # WANDB\n", + " log_to_wandb=True, # always use wandb unless you are just testing code.\n", + " wandb_project=\"gated_sae_testing\",\n", + " wandb_log_frequency=30,\n", + " eval_every_n_wandb_logs=20,\n", + " # Misc\n", + " device=device,\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=\"float32\",\n", + ")\n", + "# look at the next cell to see some instructions for what to do while this is running.\n", + "sparse_autoencoder = SAETrainingRunner(cfg).run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 0 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/tutorials/training_a_sparse_autoencoder.ipynb b/tutorials/training_a_sparse_autoencoder.ipynb index 42441cf0..3008abe5 100644 --- a/tutorials/training_a_sparse_autoencoder.ipynb +++ b/tutorials/training_a_sparse_autoencoder.ipynb @@ -1,901 +1,903 @@ { - "cells": [ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "5O8tQblzOVHu" + }, + "source": [ + "# Training a basic 
SAE with SAELens\n", + "\n", + "This tutorial demonstrates training a simple, relatively small Sparse Autoencoder, specifically on the tiny-stories-1L-21M model. \n", + "\n", + "As the SAELens library is under active development, please open an issue if this tutorial is stale [here](https://github.com/jbloomAus/SAELens)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "shAFb9-lOVHu" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "LeRi_tw2dhae" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "5O8tQblzOVHu" - }, - "source": [ - "# Training a basic SAE with SAELens\n", - "\n", - "This tutorial demonstrates training a simple, relatively small Sparse Autoencoder, specifically on the tiny-stories-1L-21M model. \n", - "\n", - "As the SAELens library is under active development, please open an issue if this tutorial is stale [here](https://github.com/jbloomAus/SAELens)." - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: sae-lens in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (3.3.0)\n", + "Requirement already satisfied: transformer-lens in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (1.19.0)\n", + "Requirement already satisfied: circuitsvis in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (1.43.2)\n", + "Requirement already satisfied: automated-interpretability<0.0.4,>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.0.3)\n", + "Requirement already satisfied: babe<0.0.8,>=0.0.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.0.7)\n", + "Requirement already satisfied: datasets<3.0.0,>=2.17.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (2.19.2)\n", + "Requirement already satisfied: matplotlib<4.0.0,>=3.8.3 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (3.9.0)\n", + "Requirement already satisfied: matplotlib-inline<0.2.0,>=0.1.6 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.1.7)\n", + "Requirement already satisfied: nltk<4.0.0,>=3.8.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (3.8.1)\n", + "Requirement already satisfied: plotly<6.0.0,>=5.19.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (5.22.0)\n", + "Requirement already satisfied: plotly-express<0.5.0,>=0.4.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.4.1)\n", + "Requirement already satisfied: pytest-profiling<2.0.0,>=1.7.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (1.7.0)\n", + "Requirement already satisfied: python-dotenv<2.0.0,>=1.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (1.0.1)\n", + "Requirement already satisfied: pyyaml<7.0.0,>=6.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (6.0.1)\n", + "Requirement already satisfied: pyzmq==26.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (26.0.0)\n", + "Requirement already satisfied: sae-vis<0.3.0,>=0.2.18 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.2.18)\n", + "Requirement already satisfied: safetensors<0.5.0,>=0.4.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.4.3)\n", + "Requirement already satisfied: transformers<5.0.0,>=4.38.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (4.41.2)\n", + "Requirement already satisfied: typer<0.13.0,>=0.12.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from 
sae-lens) (0.12.3)\n", + "Requirement already satisfied: accelerate>=0.23.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.31.0)\n", + "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.14.1)\n", + "Requirement already satisfied: better-abc<0.0.4,>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.0.3)\n", + "Requirement already satisfied: einops>=0.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.7.0)\n", + "Requirement already satisfied: fancy-einsum>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.0.3)\n", + "Requirement already satisfied: jaxtyping>=0.2.11 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.2.29)\n", + "Requirement already satisfied: numpy>=1.24 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (1.26.4)\n", + "Requirement already satisfied: pandas>=1.1.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (2.2.2)\n", + "Requirement already satisfied: rich>=12.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (13.7.1)\n", + "Requirement already satisfied: sentencepiece in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.2.0)\n", + "Requirement already satisfied: torch>=1.10 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (2.1.2)\n", + "Requirement already satisfied: tqdm>=4.64.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (4.66.4)\n", + "Requirement already satisfied: typing-extensions in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (4.12.2)\n", + "Requirement already satisfied: wandb>=0.13.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.17.1)\n", + "Requirement already satisfied: importlib-metadata>=5.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (7.1.0)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.0.106)\n", + 
"Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (2.18.1)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: triton==2.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (2.1.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->circuitsvis) (12.5.40)\n", + "Requirement already satisfied: filelock in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from triton==2.1.0->circuitsvis) (3.14.0)\n", + "Requirement already satisfied: packaging>=20.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (24.0)\n", + "Requirement already satisfied: psutil in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (5.9.0)\n", + "Requirement already satisfied: huggingface-hub in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (0.23.3)\n", + "Requirement already satisfied: blobfile<3.0.0,>=2.1.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.1.1)\n", + "Requirement already satisfied: boostedblob<0.16.0,>=0.15.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.15.3)\n", + "Requirement already satisfied: httpx<0.28.0,>=0.27.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.27.0)\n", + "Requirement already 
satisfied: orjson<4.0.0,>=3.10.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.10.4)\n", + "Requirement already satisfied: pytest<9.0.0,>=8.1.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (8.2.2)\n", + "Requirement already satisfied: scikit-learn<2.0.0,>=1.4.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.5.0)\n", + "Requirement already satisfied: tiktoken<0.7.0,>=0.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.6.0)\n", + "Requirement already satisfied: py2store in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from babe<0.0.8,>=0.0.7->sae-lens) (0.1.20)\n", + "Requirement already satisfied: graze in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from babe<0.0.8,>=0.0.7->sae-lens) (0.1.17)\n", + "Requirement already satisfied: pyarrow>=12.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (16.1.0)\n", + "Requirement already satisfied: pyarrow-hotfix in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.6)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.3.8)\n", + "Requirement already satisfied: requests>=2.32.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (2.32.3)\n", + "Requirement already satisfied: xxhash in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.4.1)\n", + "Requirement 
already satisfied: multiprocess in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets<3.0.0,>=2.17.1->sae-lens) (2024.3.1)\n", + "Requirement already satisfied: aiohttp in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.9.5)\n", + "Requirement already satisfied: zipp>=0.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from importlib-metadata>=5.1.0->circuitsvis) (3.19.2)\n", + "Requirement already satisfied: typeguard==2.13.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from jaxtyping>=0.2.11->transformer-lens) (2.13.3)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.2.1)\n", + "Requirement already satisfied: cycler>=0.10 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (4.53.0)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.4.5)\n", + "Requirement already satisfied: pillow>=8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (10.3.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (3.1.2)\n", + "Requirement already 
satisfied: python-dateutil>=2.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (2.9.0)\n", + "Requirement already satisfied: traitlets in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib-inline<0.2.0,>=0.1.6->sae-lens) (5.14.3)\n", + "Requirement already satisfied: click in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (8.1.7)\n", + "Requirement already satisfied: joblib in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (1.4.2)\n", + "Requirement already satisfied: regex>=2021.8.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (2024.5.15)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pandas>=1.1.5->transformer-lens) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pandas>=1.1.5->transformer-lens) (2024.1)\n", + "Requirement already satisfied: tenacity>=6.2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly<6.0.0,>=5.19.0->sae-lens) (8.3.0)\n", + "Requirement already satisfied: statsmodels>=0.9.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (0.14.2)\n", + "Requirement already satisfied: scipy>=0.18 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (1.13.1)\n", + "Requirement already satisfied: patsy>=0.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (0.5.6)\n", + "Requirement already satisfied: six in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (1.16.0)\n", + "Requirement already satisfied: gprof2dot in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (2024.6.6)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from rich>=12.6.0->transformer-lens) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from rich>=12.6.0->transformer-lens) (2.18.0)\n", + "Requirement already satisfied: dataclasses-json<0.7.0,>=0.6.4 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-vis<0.3.0,>=0.2.18->sae-lens) (0.6.7)\n", + "Requirement already satisfied: eindex-callum<0.2.0,>=0.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-vis<0.3.0,>=0.2.18->sae-lens) (0.1.1)\n", + "Requirement already satisfied: sympy in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (1.12.1)\n", + "Requirement already satisfied: networkx in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (3.3)\n", + "Requirement already satisfied: jinja2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (3.1.4)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformers<5.0.0,>=4.38.1->sae-lens) (0.19.1)\n", + "Requirement already satisfied: shellingham>=1.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from typer<0.13.0,>=0.12.3->sae-lens) (1.5.4)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (0.4.0)\n", + "Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (3.1.43)\n", + "Requirement already satisfied: platformdirs in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (4.2.2)\n", + "Requirement already satisfied: protobuf!=4.21.0,<6,>=3.19.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (5.27.1)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (2.5.1)\n", + "Requirement already satisfied: setproctitle in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (1.3.3)\n", + "Requirement already satisfied: setuptools in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (70.0.0)\n", + "Requirement already satisfied: pycryptodomex~=3.8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.20.0)\n", + "Requirement already satisfied: urllib3<3,>=1.25.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.2.1)\n", + "Requirement already satisfied: lxml~=4.9 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (4.9.4)\n", + "Requirement already satisfied: uvloop>=0.16.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from 
boostedblob<0.16.0,>=0.15.3->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.19.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (4.0.3)\n", + "Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (3.21.3)\n", + "Requirement already satisfied: typing-inspect<1,>=0.4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (0.9.0)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens) (4.0.11)\n", + "Requirement already satisfied: anyio in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from 
httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (4.4.0)\n", + "Requirement already satisfied: certifi in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2024.6.2)\n", + "Requirement already satisfied: httpcore==1.* in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.0.5)\n", + "Requirement already satisfied: idna in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.7)\n", + "Requirement already satisfied: sniffio in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.3.1)\n", + "Requirement already satisfied: h11<0.15,>=0.13 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpcore==1.*->httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.14.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer-lens) (0.1.2)\n", + "Requirement already satisfied: iniconfig in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.0.0)\n", + "Requirement already satisfied: pluggy<2.0,>=1.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.5.0)\n", + "Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.2.0)\n", + 
"Requirement already satisfied: tomli>=1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.0.1)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from requests>=2.32.1->datasets<3.0.0,>=2.17.1->sae-lens) (3.3.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from scikit-learn<2.0.0,>=1.4.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.5.0)\n", + "Requirement already satisfied: dol in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from graze->babe<0.0.8,>=0.0.7->sae-lens) (0.2.47)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from jinja2->torch>=1.10->transformer-lens) (2.1.5)\n", + "Requirement already satisfied: config2py in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from py2store->babe<0.0.8,>=0.0.7->sae-lens) (0.1.33)\n", + "Requirement already satisfied: importlib-resources in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from py2store->babe<0.0.8,>=0.0.7->sae-lens) (6.4.0)\n", + "Requirement already satisfied: mpmath<1.4.0,>=1.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sympy->torch>=1.10->transformer-lens) (1.3.0)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens) (5.0.1)\n", + "Requirement already satisfied: mypy-extensions>=0.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (1.0.0)\n", + "Requirement 
already satisfied: i2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from config2py->py2store->babe<0.0.8,>=0.0.7->sae-lens) (0.1.17)\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "try:\n", + " # import google.colab # type: ignore\n", + " # from google.colab import output\n", + " %pip install sae-lens transformer-lens circuitsvis\n", + "except:\n", + " from IPython import get_ipython # type: ignore\n", + "\n", + " ipython = get_ipython()\n", + " assert ipython is not None\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "uy-b3CcSOVHu", + "outputId": "58ce28d0-f91f-436d-cf87-76bb26e2ecaf" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "shAFb9-lOVHu" - }, - "source": [ - "## Setup" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] }, { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "LeRi_tw2dhae" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: sae-lens in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (3.3.0)\n", - "Requirement already satisfied: transformer-lens in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (1.19.0)\n", - "Requirement already satisfied: circuitsvis in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (1.43.2)\n", - "Requirement already satisfied: automated-interpretability<0.0.4,>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.0.3)\n", - "Requirement already satisfied: babe<0.0.8,>=0.0.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.0.7)\n", - "Requirement already satisfied: datasets<3.0.0,>=2.17.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (2.19.2)\n", - "Requirement already satisfied: matplotlib<4.0.0,>=3.8.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (3.9.0)\n", - "Requirement already satisfied: matplotlib-inline<0.2.0,>=0.1.6 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.1.7)\n", - "Requirement already satisfied: nltk<4.0.0,>=3.8.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (3.8.1)\n", - "Requirement already satisfied: plotly<6.0.0,>=5.19.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (5.22.0)\n", - "Requirement already satisfied: plotly-express<0.5.0,>=0.4.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.4.1)\n", - 
"Requirement already satisfied: pytest-profiling<2.0.0,>=1.7.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (1.7.0)\n", - "Requirement already satisfied: python-dotenv<2.0.0,>=1.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (1.0.1)\n", - "Requirement already satisfied: pyyaml<7.0.0,>=6.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (6.0.1)\n", - "Requirement already satisfied: pyzmq==26.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (26.0.0)\n", - "Requirement already satisfied: sae-vis<0.3.0,>=0.2.18 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.2.18)\n", - "Requirement already satisfied: safetensors<0.5.0,>=0.4.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.4.3)\n", - "Requirement already satisfied: transformers<5.0.0,>=4.38.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (4.41.2)\n", - "Requirement already satisfied: typer<0.13.0,>=0.12.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.12.3)\n", - "Requirement already satisfied: accelerate>=0.23.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.31.0)\n", - "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.14.1)\n", - "Requirement already satisfied: better-abc<0.0.4,>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.0.3)\n", - "Requirement already satisfied: einops>=0.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.7.0)\n", - "Requirement already satisfied: fancy-einsum>=0.0.3 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.0.3)\n", - "Requirement already satisfied: jaxtyping>=0.2.11 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.2.29)\n", - "Requirement already satisfied: numpy>=1.24 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (1.26.4)\n", - "Requirement already satisfied: pandas>=1.1.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (2.2.2)\n", - "Requirement already satisfied: rich>=12.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (13.7.1)\n", - "Requirement already satisfied: sentencepiece in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.2.0)\n", - "Requirement already satisfied: torch>=1.10 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (2.1.2)\n", - "Requirement already satisfied: tqdm>=4.64.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (4.66.4)\n", - "Requirement already satisfied: typing-extensions in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (4.12.2)\n", - "Requirement already satisfied: wandb>=0.13.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.17.1)\n", - "Requirement already satisfied: importlib-metadata>=5.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (7.1.0)\n", - "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.3.1)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", - "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (8.9.2.26)\n", - "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (11.0.2.54)\n", - "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (10.3.2.106)\n", - "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (11.4.5.107)\n", - "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.0.106)\n", - "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (2.18.1)\n", - "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", - "Requirement already satisfied: triton==2.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (2.1.0)\n", - "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->circuitsvis) (12.5.40)\n", 
- "Requirement already satisfied: filelock in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from triton==2.1.0->circuitsvis) (3.14.0)\n", - "Requirement already satisfied: packaging>=20.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (24.0)\n", - "Requirement already satisfied: psutil in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (5.9.0)\n", - "Requirement already satisfied: huggingface-hub in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (0.23.3)\n", - "Requirement already satisfied: blobfile<3.0.0,>=2.1.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.1.1)\n", - "Requirement already satisfied: boostedblob<0.16.0,>=0.15.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.15.3)\n", - "Requirement already satisfied: httpx<0.28.0,>=0.27.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.27.0)\n", - "Requirement already satisfied: orjson<4.0.0,>=3.10.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.10.4)\n", - "Requirement already satisfied: pytest<9.0.0,>=8.1.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (8.2.2)\n", - "Requirement already satisfied: scikit-learn<2.0.0,>=1.4.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.5.0)\n", - "Requirement already satisfied: tiktoken<0.7.0,>=0.6.0 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.6.0)\n", - "Requirement already satisfied: py2store in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from babe<0.0.8,>=0.0.7->sae-lens) (0.1.20)\n", - "Requirement already satisfied: graze in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from babe<0.0.8,>=0.0.7->sae-lens) (0.1.17)\n", - "Requirement already satisfied: pyarrow>=12.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (16.1.0)\n", - "Requirement already satisfied: pyarrow-hotfix in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.6)\n", - "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.3.8)\n", - "Requirement already satisfied: requests>=2.32.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (2.32.3)\n", - "Requirement already satisfied: xxhash in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.4.1)\n", - "Requirement already satisfied: multiprocess in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.70.16)\n", - "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets<3.0.0,>=2.17.1->sae-lens) (2024.3.1)\n", - "Requirement already satisfied: aiohttp in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.9.5)\n", - "Requirement already satisfied: zipp>=0.5 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from importlib-metadata>=5.1.0->circuitsvis) (3.19.2)\n", - "Requirement already satisfied: typeguard==2.13.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from jaxtyping>=0.2.11->transformer-lens) (2.13.3)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.2.1)\n", - "Requirement already satisfied: cycler>=0.10 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (0.12.1)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (4.53.0)\n", - "Requirement already satisfied: kiwisolver>=1.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.4.5)\n", - "Requirement already satisfied: pillow>=8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (10.3.0)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (3.1.2)\n", - "Requirement already satisfied: python-dateutil>=2.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (2.9.0)\n", - "Requirement already satisfied: traitlets in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib-inline<0.2.0,>=0.1.6->sae-lens) (5.14.3)\n", - "Requirement already satisfied: click in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (8.1.7)\n", - "Requirement already satisfied: joblib in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (1.4.2)\n", - "Requirement already satisfied: regex>=2021.8.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (2024.5.15)\n", - "Requirement already satisfied: pytz>=2020.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pandas>=1.1.5->transformer-lens) (2024.1)\n", - "Requirement already satisfied: tzdata>=2022.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pandas>=1.1.5->transformer-lens) (2024.1)\n", - "Requirement already satisfied: tenacity>=6.2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly<6.0.0,>=5.19.0->sae-lens) (8.3.0)\n", - "Requirement already satisfied: statsmodels>=0.9.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (0.14.2)\n", - "Requirement already satisfied: scipy>=0.18 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (1.13.1)\n", - "Requirement already satisfied: patsy>=0.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (0.5.6)\n", - "Requirement already satisfied: six in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (1.16.0)\n", - "Requirement already satisfied: gprof2dot in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (2024.6.6)\n", - "Requirement already satisfied: markdown-it-py>=2.2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from rich>=12.6.0->transformer-lens) (3.0.0)\n", - "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from rich>=12.6.0->transformer-lens) (2.18.0)\n", - "Requirement already satisfied: dataclasses-json<0.7.0,>=0.6.4 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-vis<0.3.0,>=0.2.18->sae-lens) (0.6.7)\n", - "Requirement already satisfied: eindex-callum<0.2.0,>=0.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-vis<0.3.0,>=0.2.18->sae-lens) (0.1.1)\n", - "Requirement already satisfied: sympy in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (1.12.1)\n", - "Requirement already satisfied: networkx in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (3.3)\n", - "Requirement already satisfied: jinja2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (3.1.4)\n", - "Requirement already satisfied: tokenizers<0.20,>=0.19 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformers<5.0.0,>=4.38.1->sae-lens) (0.19.1)\n", - "Requirement already satisfied: shellingham>=1.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from typer<0.13.0,>=0.12.3->sae-lens) (1.5.4)\n", - "Requirement already satisfied: docker-pycreds>=0.4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (0.4.0)\n", - "Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (3.1.43)\n", - "Requirement already satisfied: platformdirs in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (4.2.2)\n", - "Requirement already satisfied: protobuf!=4.21.0,<6,>=3.19.0 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (5.27.1)\n", - "Requirement already satisfied: sentry-sdk>=1.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (2.5.1)\n", - "Requirement already satisfied: setproctitle in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (1.3.3)\n", - "Requirement already satisfied: setuptools in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (70.0.0)\n", - "Requirement already satisfied: pycryptodomex~=3.8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.20.0)\n", - "Requirement already satisfied: urllib3<3,>=1.25.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.2.1)\n", - "Requirement already satisfied: lxml~=4.9 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (4.9.4)\n", - "Requirement already satisfied: uvloop>=0.16.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from boostedblob<0.16.0,>=0.15.3->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.19.0)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.3.1)\n", - "Requirement already satisfied: attrs>=17.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (23.2.0)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.4.1)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (6.0.5)\n", - "Requirement already satisfied: yarl<2.0,>=1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.9.4)\n", - "Requirement already satisfied: async-timeout<5.0,>=4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (4.0.3)\n", - "Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (3.21.3)\n", - "Requirement already satisfied: typing-inspect<1,>=0.4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (0.9.0)\n", - "Requirement already satisfied: gitdb<5,>=4.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens) (4.0.11)\n", - "Requirement already satisfied: anyio in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (4.4.0)\n", - "Requirement already satisfied: certifi in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2024.6.2)\n", - "Requirement already satisfied: httpcore==1.* in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.0.5)\n", - "Requirement already satisfied: 
idna in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.7)\n", - "Requirement already satisfied: sniffio in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.3.1)\n", - "Requirement already satisfied: h11<0.15,>=0.13 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpcore==1.*->httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.14.0)\n", - "Requirement already satisfied: mdurl~=0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer-lens) (0.1.2)\n", - "Requirement already satisfied: iniconfig in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.0.0)\n", - "Requirement already satisfied: pluggy<2.0,>=1.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.5.0)\n", - "Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.2.0)\n", - "Requirement already satisfied: tomli>=1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.0.1)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from requests>=2.32.1->datasets<3.0.0,>=2.17.1->sae-lens) (3.3.2)\n", - "Requirement already satisfied: threadpoolctl>=3.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from 
scikit-learn<2.0.0,>=1.4.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.5.0)\n", - "Requirement already satisfied: dol in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from graze->babe<0.0.8,>=0.0.7->sae-lens) (0.2.47)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from jinja2->torch>=1.10->transformer-lens) (2.1.5)\n", - "Requirement already satisfied: config2py in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from py2store->babe<0.0.8,>=0.0.7->sae-lens) (0.1.33)\n", - "Requirement already satisfied: importlib-resources in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from py2store->babe<0.0.8,>=0.0.7->sae-lens) (6.4.0)\n", - "Requirement already satisfied: mpmath<1.4.0,>=1.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sympy->torch>=1.10->transformer-lens) (1.3.0)\n", - "Requirement already satisfied: smmap<6,>=3.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens) (5.0.1)\n", - "Requirement already satisfied: mypy-extensions>=0.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (1.0.0)\n", - "Requirement already satisfied: i2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from config2py->py2store->babe<0.0.8,>=0.0.7->sae-lens) (0.1.17)\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], - "source": [ - "try:\n", - " #import google.colab # type: ignore\n", - " #from google.colab import output\n", - " %pip install sae-lens transformer-lens circuitsvis\n", - "except:\n", - " from IPython import get_ipython # type: ignore\n", - " ipython = get_ipython(); assert ipython is not 
None\n", - " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", - " ipython.run_line_magic(\"autoreload\", \"2\")" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cuda\n" + ] + } + ], + "source": [ + "import torch\n", + "import os\n", + "\n", + "from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner\n", + "\n", + "if torch.cuda.is_available():\n", + " device = \"cuda\"\n", + "elif torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "else:\n", + " device = \"cpu\"\n", + "\n", + "print(\"Using device:\", device)\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oe2nlqf-OVHv" + }, + "source": [ + "# Model Selection and Evaluation (Feel Free to Skip)\n", + "\n", + "We'll use the runner to train an SAE on a TinyStories Model. This is a very small model so we can train an SAE on it quite quickly. Before we get started, let's load in the model with `transformer_lens` and see what it can do.\n", + "\n", + "TransformerLens gives us 2 functions that are useful here (and circuits viz provides a third):\n", + "1. `transformer_lens.utils.test_prompt` will help us see when the model can infer one token.\n", + "2. `HookedTransformer.generate` will help us see what happens when we sample from the model.\n", + "3. `circuitsvis.logits.token_log_probs` will help us visualize the log probs of tokens at several positions in a prompt." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "hFz6JUMuOVHv" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", + " return self.fget.__get__(instance, owner)()\n" + ] }, { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "uy-b3CcSOVHu", - "outputId": "58ce28d0-f91f-436d-cf87-76bb26e2ecaf" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using device: cuda\n" - ] - } - ], - "source": [ - "import torch\n", - "import os\n", - "\n", - "from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner\n", - "\n", - "if torch.cuda.is_available():\n", - " device = \"cuda\"\n", - "elif torch.backends.mps.is_available():\n", - " device = \"mps\"\n", - "else:\n", - " device = \"cpu\"\n", - "\n", - "print(\"Using device:\", device)\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" + ] + } + ], + "source": [ + "from transformer_lens import HookedTransformer\n", + "\n", + "model = HookedTransformer.from_pretrained(\n", + " \"tiny-stories-1L-21M\"\n", + ") # This will wrap huggingface models and has lots of nice 
utilities." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aUiXrjdUOVHv" + }, + "source": [ + "### Getting a vibe for a model using `model.generate`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZZfKT5aDOVHv" + }, + "source": [ + "Let's start by generating some stories using the model." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "G4ad4Zz1OVHv" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'Once upon a time, Bobby was hungry and needed something to do. He went to the subway but was far away.\\n\\nThe man wanted to get the hat, so the people wanted it. He found it hard to be a big, powerful bird. It wanted'" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "oe2nlqf-OVHv" - }, - "source": [ - "# Model Selection and Evaluation (Feel Free to Skip)\n", - "\n", - "We'll use the runner to train an SAE on a TinyStories Model. This is a very small model so we can train an SAE on it quite quickly. Before we get started, let's load in the model with `transformer_lens` and see what it can do.\n", - "\n", - "TransformerLens gives us 2 functions that are useful here (and circuits viz provides a third):\n", - "1. `transformer_lens.utils.test_prompt` will help us see when the model can infer one token.\n", - "2. `HookedTransformer.generate` will help us see what happens when we sample from the model.\n", - "3. `circuitsvis.logits.token_log_probs` will help us visualize the log probs of tokens at several positions in a prompt." + "data": { + "text/plain": [ + "'Once upon a time, there was a trunk. The trunk was very rich, and it was a very special trunk. All the animals came across the trunk, and was very colorful. They took turns to fill it up when they bumped into a dragon. 
\\n\\n'" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "hFz6JUMuOVHv" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", - " warnings.warn(\n", - "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", - " return self.fget.__get__(instance, owner)()\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" - ] - } - ], - "source": [ - "from transformer_lens import HookedTransformer\n", - "\n", - "model = HookedTransformer.from_pretrained(\n", - " \"tiny-stories-1L-21M\"\n", - ") # This will wrap huggingface models and has lots of nice utilities." + "data": { + "text/plain": [ + "'Once upon a time, there was a young man. He was three years old. He wanted to learn how to keep the match safe. So he kept checking it every day.\\n\\nOne day a 3 year old girl wanted to learn about the fun trunk of the'" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "aUiXrjdUOVHv" - }, - "source": [ - "### Getting a vibe for a model using `model.generate`" + "data": { + "text/plain": [ + "'Once upon a time, there was a little girl named Sally. She liked to play with her toys. 
One sunny day, Sally found a butterfly. She was so happy! She wanted to play with something new. So, she called her friend Tom.\\n\\nTom'" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "ZZfKT5aDOVHv" - }, - "source": [ - "Let's start by generating some stories using the model." + "data": { + "text/plain": [ + "'Once upon a time, there was a little girl named Lola. She really loved playing with her pet cat, Tom and show them his appreciation.\\n\\nOne day, Lola licked the couch closer and soon found herself in a magical land! As soon as'" ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# here we use generate to get 10 completeions with temperature 1. Feel free to play with the prompt to make it more interesting.\n", + "for i in range(5):\n", + " display(\n", + " model.generate(\n", + " \"Once upon a time\",\n", + " stop_at_eos=False, # avoids a bug on MPS\n", + " temperature=1,\n", + " verbose=False,\n", + " max_new_tokens=50,\n", + " )\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RDKr8o1xOVHv" + }, + "source": [ + "One thing we notice is that the model seems to be able to repeat the name of the main character very consistently. It can output a pronoun intead but in some stories will repeat the protagonists name. This seems like an interesting capability to analyse with SAEs. To better understand the models ability to remember the protagonists name, let's extract a prompt where the next character is determined and use the \"test_prompt\" utility from TransformerLens to check the ranking of the token for that name." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KsfJX-YpOVHv" + }, + "source": [ + "### Spot checking model abilities with `transformer_lens.utils.test_prompt`" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "TpmPoj7uOVHv" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' little', ' girl', ' named', ' Lily', '.', ' She', ' lived', ' in', ' a', ' big', ',', ' happy', ' little', ' girl', '.', ' On', ' her', ' big', ' adventure', ',']\n", + "Tokenized answer: [' Lily']\n" + ] }, { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "G4ad4Zz1OVHv" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "'Once upon a time, Bobby was hungry and needed something to do. He went to the subway but was far away.\\n\\nThe man wanted to get the hat, so the people wanted it. He found it hard to be a big, powerful bird. It wanted'" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "'Once upon a time, there was a trunk. The trunk was very rich, and it was a very special trunk. All the animals came across the trunk, and was very colorful. They took turns to fill it up when they bumped into a dragon. \\n\\n'" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "'Once upon a time, there was a young man. He was three years old. He wanted to learn how to keep the match safe. So he kept checking it every day.\\n\\nOne day a 3 year old girl wanted to learn about the fun trunk of the'" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "'Once upon a time, there was a little girl named Sally. She liked to play with her toys. One sunny day, Sally found a butterfly. She was so happy! She wanted to play with something new. 
So, she called her friend Tom.\\n\\nTom'" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "'Once upon a time, there was a little girl named Lola. She really loved playing with her pet cat, Tom and show them his appreciation.\\n\\nOne day, Lola licked the couch closer and soon found herself in a magical land! As soon as'" - ] - }, - "metadata": {}, - "output_type": "display_data" - } + "data": { + "text/html": [ + "
Performance on answer token:\n",
+       "Rank: 1        Logit: 18.81 Prob: 13.46% Token: | Lily|\n",
+       "
\n" ], - "source": [ - "# here we use generate to get 10 completeions with temperature 1. Feel free to play with the prompt to make it more interesting.\n", - "for i in range(5):\n", - " display(\n", - " model.generate(\n", - " \"Once upon a time\",\n", - " stop_at_eos=False, # avoids a bug on MPS\n", - " temperature=1,\n", - " verbose=False,\n", - " max_new_tokens=50,\n", - " )\n", - " )" + "text/plain": [ + "Performance on answer token:\n", + "\u001b[1mRank: \u001b[0m\u001b[1;36m1\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m18.81\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m13.46\u001b[0m\u001b[1m% Token: | Lily|\u001b[0m\n" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "RDKr8o1xOVHv" - }, - "source": [ - "One thing we notice is that the model seems to be able to repeat the name of the main character very consistently. It can output a pronoun intead but in some stories will repeat the protagonists name. This seems like an interesting capability to analyse with SAEs. To better understand the models ability to remember the protagonists name, let's extract a prompt where the next character is determined and use the \"test_prompt\" utility from TransformerLens to check the ranking of the token for that name." - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 0th token. Logit: 20.48 Prob: 71.06% Token: | she|\n", + "Top 1th token. Logit: 18.81 Prob: 13.46% Token: | Lily|\n", + "Top 2th token. Logit: 17.35 Prob: 3.11% Token: | the|\n", + "Top 3th token. Logit: 17.26 Prob: 2.86% Token: | her|\n", + "Top 4th token. Logit: 16.74 Prob: 1.70% Token: | there|\n", + "Top 5th token. Logit: 16.43 Prob: 1.25% Token: | they|\n", + "Top 6th token. Logit: 15.80 Prob: 0.66% Token: | all|\n", + "Top 7th token. Logit: 15.64 Prob: 0.56% Token: | things|\n", + "Top 8th token. Logit: 15.28 Prob: 0.39% Token: | one|\n", + "Top 9th token. 
Logit: 15.24 Prob: 0.38% Token: | lived|\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "KsfJX-YpOVHv" - }, - "source": [ - "### Spot checking model abilities with `transformer_lens.utils.test_prompt`" + "data": { + "text/html": [ + "
Ranks of the answer tokens: [(' Lily', 1)]\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Lily'\u001b[0m, \u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" ] - }, + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from transformer_lens.utils import test_prompt\n", + "\n", + "# Test the model with a prompt\n", + "test_prompt(\n", + " \"Once upon a time, there was a little girl named Lily. She lived in a big, happy little girl. On her big adventure,\",\n", + " \" Lily\",\n", + " model,\n", + " prepend_space_to_answer=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jGzOvReDOVHv" + }, + "source": [ + "In the output above, we see that the model assigns ~ 70% probability to \"she\" being the next token, and a 13% chance to \" Lily\" being the next token. Other names like Lucy or Anna are not highly ranked." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QH8YOZOzOVHv" + }, + "source": [ + "### Exploring Model Capabilities with Log Probs" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "50mqTBihOVHw" + }, + "source": [ + "Looking at token ranking for a single prompt is interesting, but a much higher through way to understand models is to look at token log probs for all tokens in text. We can use the `circuits_vis` package to get a nice visualization where we can see tokenization, and hover to get the top5 tokens by log probability. Darker tokens are tokens where the model assigned a higher probability to the actual next token." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "Tic0RCUpOVHw" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "TpmPoj7uOVHv" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' little', ' girl', ' named', ' Lily', '.', ' She', ' lived', ' in', ' a', ' big', ',', ' happy', ' little', ' girl', '.', ' On', ' her', ' big', ' adventure', ',']\n", - "Tokenized answer: [' Lily']\n" - ] - }, - { - "data": { - "text/html": [ - "
Performance on answer token:\n",
-              "Rank: 1        Logit: 18.81 Prob: 13.46% Token: | Lily|\n",
-              "
\n" - ], - "text/plain": [ - "Performance on answer token:\n", - "\u001b[1mRank: \u001b[0m\u001b[1;36m1\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m18.81\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m13.46\u001b[0m\u001b[1m% Token: | Lily|\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Top 0th token. Logit: 20.48 Prob: 71.06% Token: | she|\n", - "Top 1th token. Logit: 18.81 Prob: 13.46% Token: | Lily|\n", - "Top 2th token. Logit: 17.35 Prob: 3.11% Token: | the|\n", - "Top 3th token. Logit: 17.26 Prob: 2.86% Token: | her|\n", - "Top 4th token. Logit: 16.74 Prob: 1.70% Token: | there|\n", - "Top 5th token. Logit: 16.43 Prob: 1.25% Token: | they|\n", - "Top 6th token. Logit: 15.80 Prob: 0.66% Token: | all|\n", - "Top 7th token. Logit: 15.64 Prob: 0.56% Token: | things|\n", - "Top 8th token. Logit: 15.28 Prob: 0.39% Token: | one|\n", - "Top 9th token. Logit: 15.24 Prob: 0.38% Token: | lived|\n" - ] - }, - { - "data": { - "text/html": [ - "
Ranks of the answer tokens: [(' Lily', 1)]\n",
-              "
\n" - ], - "text/plain": [ - "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Lily'\u001b[0m, \u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } + "data": { + "text/html": [ + "
\n", + " " ], - "source": [ - "from transformer_lens.utils import test_prompt\n", - "\n", - "# Test the model with a prompt\n", - "test_prompt(\n", - " \"Once upon a time, there was a little girl named Lily. She lived in a big, happy little girl. On her big adventure,\",\n", - " \" Lily\",\n", - " model,\n", - " prepend_space_to_answer=False,\n", - ")" + "text/plain": [ + "" ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import circuitsvis as cv # optional dep, install with pip install circuitsvis\n", + "\n", + "# Let's make a longer prompt and see the log probabilities of the tokens\n", + "example_prompt = \"\"\"Hi, how are you doing this? I'm really enjoying your posts\"\"\"\n", + "logits, cache = model.run_with_cache(example_prompt)\n", + "cv.logits.token_log_probs(\n", + " model.to_tokens(example_prompt),\n", + " model(example_prompt)[0].log_softmax(dim=-1),\n", + " model.to_string,\n", + ")\n", + "# hover on the output to see the result." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lhGIl3YbOVHw" + }, + "source": [ + "Let's combine `model.generate` and the token log probs visualization to see the log probs on text generated by the model. Note that we can play with the temperature and this should sample less likely trajectories according to the model. I've increased the maximum number of tokens in order to get a full story.\n", + "\n", + "Some things to explore:\n", + "- Which tokens does the model assign high probability to? Can you see how the model should know which word comes next?\n", + "- What happens if you increase / decrease the temperature?\n", + "- Do the rankings of tokens seem sensible to you? What about where the model doesn't assign a high probability to the token which came next?" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "Nikp2ASlOVHw" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 200/200 [00:01<00:00, 103.20it/s]\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "jGzOvReDOVHv" - }, - "source": [ - "In the output above, we see that the model assigns ~ 70% probability to \"she\" being the next token, and a 13% chance to \" Lily\" being the next token. Other names like Lucy or Anna are not highly ranked." + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "example_prompt = model.generate(\n", + " \"Once upon a time\",\n", + " stop_at_eos=False, # avoids a bug on MPS\n", + " temperature=1,\n", + " verbose=True,\n", + " max_new_tokens=200,\n", + ")\n", + "logits, cache = model.run_with_cache(example_prompt)\n", + "cv.logits.token_log_probs(\n", + " model.to_tokens(example_prompt),\n", + " model(example_prompt)[0].log_softmax(dim=-1),\n", + " model.to_string,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "er3H1TDoOVHw" + }, + "source": [ + "# Training an SAE\n", + "\n", + "Now we're ready to train out SAE. We'll make a runner config, instantiate the runner and the rest is taken care of for us!\n", + "\n", + "During training, you use weights and biases to check key metrics which indicate how well we are able to optimize the variables we care about.\n", + "\n", + "To get a better sense of which variables to look at, you can read my (Joseph's) post [here](https://www.lesswrong.com/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream) and especially look at my weights and biases report [here](https://links-cdn.wandb.ai/wandb-public-images/links/jbloom/uue9i416.html).\n", + "\n", + "A few tips:\n", + "- Feel free to reorganize your wandb dashboard to put L0, CE_Loss_score, explained variance and other key metrics in one section at the top.\n", + "- Make a [run comparer](https://docs.wandb.ai/guides/app/features/panels/run-comparer) when tuning hyperparameters.\n", + "- You can download the resulting sparse autoencoder / sparsity estimate from wandb and upload them to huggingface if you want to share your SAE with other.\n", + " - cfg.json (training config)\n", + " - sae_weight.safetensors (model weights)\n", + " - sparsity.safetensors (sparsity estimate)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": 
"jCHtPycOOVHw" + }, + "source": [ + "## MLP Out\n", + "\n", + "I've tuned the hyperparameters below for a decent SAE which achieves 86% CE Loss recovered and an L0 of ~85, and runs in about 2 hours on an M3 Max. You can get an SAE that looks better faster if you only consider L0 and CE loss but it will likely have more dense features and more dead features. Here's a link to my output with two runs with two different L1's: https://wandb.ai/jbloom/sae_lens_tutorial ." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "oAsZCAdJOVHw" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Run name: 16384-L1-5-LR-5e-05-Tokens-1.229e+08\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.001024\n", + "Total training steps: 30000\n", + "Total wandb updates: 1000\n", + "n_tokens_per_feature_sampling_window (millions): 1048.576\n", + "n_tokens_per_dead_feature_window (millions): 1048.576\n", + "We will reset the sparsity calculation 30 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n", + "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "QH8YOZOzOVHv" - }, - "source": [ - "### Exploring Model Capabilities with Log Probs" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading readme: 100%|██████████| 415/415 [00:00<00:00, 4.30MB/s]\n", + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mcurt-tigges\u001b[0m. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "50mqTBihOVHw" - }, - "source": [ - "Looking at token ranking for a single prompt is interesting, but a much higher through way to understand models is to look at token log probs for all tokens in text. We can use the `circuits_vis` package to get a nice visualization where we can see tokenization, and hover to get the top5 tokens by log probability. Darker tokens are tokens where the model assigned a higher probability to the actual next token." + "data": { + "text/html": [ + "Tracking run with wandb version 0.17.1" + ], + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "Tic0RCUpOVHw" - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } + "data": { + "text/html": [ + "Run data is saved locally in /home/curttigges/projects/SAELens/tutorials/wandb/run-20240610_114538-yr1gvjdc" ], - "source": [ - "import circuitsvis as cv # optional dep, install with pip install circuitsvis\n", - "\n", - "# Let's make a longer prompt and see the log probabilities of the tokens\n", - "example_prompt = \"\"\"Hi, how are you doing this? I'm really enjoying your posts\"\"\"\n", - "logits, cache = model.run_with_cache(example_prompt)\n", - "cv.logits.token_log_probs(\n", - " model.to_tokens(example_prompt),\n", - " model(example_prompt)[0].log_softmax(dim=-1),\n", - " model.to_string,\n", - ")\n", - "# hover on the output to see the result." + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "lhGIl3YbOVHw" - }, - "source": [ - "Let's combine `model.generate` and the token log probs visualization to see the log probs on text generated by the model. Note that we can play with the temperature and this should sample less likely trajectories according to the model. I've increased the maximum number of tokens in order to get a full story.\n", - "\n", - "Some things to explore:\n", - "- Which tokens does the model assign high probability to? Can you see how the model should know which word comes next?\n", - "- What happens if you increase / decrease the temperature?\n", - "- Do the rankings of tokens seem sensible to you? What about where the model doesn't assign a high probability to the token which came next?" + "data": { + "text/html": [ + "Syncing run 16384-L1-5-LR-5e-05-Tokens-1.229e+08 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "Nikp2ASlOVHw" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 200/200 [00:01<00:00, 103.20it/s]\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } + "data": { + "text/html": [ + " View project at https://wandb.ai/curt-tigges/sae_lens_tutorial" ], - "source": [ - "example_prompt = model.generate(\n", - " \"Once upon a time\",\n", - " stop_at_eos=False, # avoids a bug on MPS\n", - " temperature=1,\n", - " verbose=True,\n", - " max_new_tokens=200,\n", - ")\n", - "logits, cache = model.run_with_cache(example_prompt)\n", - "cv.logits.token_log_probs(\n", - " model.to_tokens(example_prompt),\n", - " model(example_prompt)[0].log_softmax(dim=-1),\n", - " model.to_string,\n", - ")" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "er3H1TDoOVHw" - }, - "source": [ - "# Training an SAE\n", - "\n", - "Now we're ready to train out SAE. We'll make a runner config, instantiate the runner and the rest is taken care of for us!\n", - "\n", - "During training, you use weights and biases to check key metrics which indicate how well we are able to optimize the variables we care about.\n", - "\n", - "To get a better sense of which variables to look at, you can read my (Joseph's) post [here](https://www.lesswrong.com/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream) and especially look at my weights and biases report [here](https://links-cdn.wandb.ai/wandb-public-images/links/jbloom/uue9i416.html).\n", - "\n", - "A few tips:\n", - "- Feel free to reorganize your wandb dashboard to put L0, CE_Loss_score, explained variance and other key metrics in one section at the top.\n", - "- Make a [run comparer](https://docs.wandb.ai/guides/app/features/panels/run-comparer) when tuning hyperparameters.\n", - "- You can download the resulting sparse autoencoder / sparsity estimate from wandb and upload them to huggingface if you want to share your SAE with other.\n", - " - cfg.json (training config)\n", 
- " - sae_weight.safetensors (model weights)\n", - " - sparsity.safetensors (sparsity estimate)" + "data": { + "text/html": [ + " View run at https://wandb.ai/curt-tigges/sae_lens_tutorial/runs/yr1gvjdc" + ], + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "jCHtPycOOVHw" - }, - "source": [ - "## MLP Out\n", - "\n", - "I've tuned the hyperparameters below for a decent SAE which achieves 86% CE Loss recovered and an L0 of ~85, and runs in about 2 hours on an M3 Max. You can get an SAE that looks better faster if you only consider L0 and CE loss but it will likely have more dense features and more dead features. Here's a link to my output with two runs with two different L1's: https://wandb.ai/jbloom/sae_lens_tutorial ." - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "Estimating norm scaling factor: 100%|██████████| 1000/1000 [00:33<00:00, 29.79it/s]\n", + "30000| MSE Loss 187.703 | L1 156.883: 1%| | 1228800/122880000 [42:59<70:56:28, 476.34it/s]\n" + ] }, { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "oAsZCAdJOVHw" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Run name: 16384-L1-5-LR-5e-05-Tokens-1.229e+08\n", - "n_tokens_per_buffer (millions): 0.262144\n", - "Lower bound: n_contexts_per_buffer (millions): 0.001024\n", - "Total training steps: 30000\n", - "Total wandb updates: 1000\n", - "n_tokens_per_feature_sampling_window (millions): 1048.576\n", - "n_tokens_per_dead_feature_window (millions): 1048.576\n", - "We will reset the sparsity calculation 30 times.\n", - "Number tokens in sparsity calculation window: 4.10e+06\n", - "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Downloading readme: 100%|██████████| 415/415 [00:00<00:00, 4.30MB/s]\n", - "Failed to detect the name of this notebook, you can 
set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mcurt-tigges\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" - ] - }, - { - "data": { - "text/html": [ - "Tracking run with wandb version 0.17.1" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Run data is saved locally in /home/curttigges/projects/SAELens/tutorials/wandb/run-20240610_114538-yr1gvjdc" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Syncing run 16384-L1-5-LR-5e-05-Tokens-1.229e+08 to Weights & Biases (docs)
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View project at https://wandb.ai/curt-tigges/sae_lens_tutorial" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View run at https://wandb.ai/curt-tigges/sae_lens_tutorial/runs/yr1gvjdc" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Estimating norm scaling factor: 100%|██████████| 1000/1000 [00:33<00:00, 29.79it/s]\n", - "30000| MSE Loss 187.703 | L1 156.883: 1%| | 1228800/122880000 [42:59<70:56:28, 476.34it/s]\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "

Run history:


details/current_l1_coefficient▁▅██████████████████████████████████████
details/current_learning_rate████████████████████████████████▇▇▅▅▄▃▂▁
details/n_training_tokens▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss▄▅█▇▆▆▅▄▄▄▄▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁
losses/overall_loss▁▅█▇▇▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▄▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
metrics/CE_loss_score█▂▁▂▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▅▆▅▅▅▅▆▅▆▆▆▆▆▆▆▅
metrics/ce_loss_with_ablation█▆▇▃▄▄▃▆▆▅▃▅▅▅▅▃▄▅▅▆▄▅▃▅▃▄▅▁▅▅▆▅▆▄▆▄▆▄▅▄
metrics/ce_loss_with_sae▂▇█▆▅▅▅▅▄▄▅▆▄▄▃▄▄▄▁▃▃▃▁▄▄▂▃▂▄▂▁▃▄▃▄▂▄▂▃▄
metrics/ce_loss_without_sae▇▇▇▅▄▄▆▆▅▄▇█▆▆▄▅▆▇▁▄▄▅▁▇▆▄▅▄▆▄▂▅▇▆▇▄▇▄▅█
metrics/explained_variance▅▄▁▂▃▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇█▇▇▇███████
metrics/explained_variance_std▁▅███▇▇▇▇▇▇▆▆▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▆▆▆▆▆▆▆▅▆
metrics/l0█▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▆▅▁▃▃▂▂▃▁▅▃▃▅▆▂█▄▃▄▆▇▃▆▃▃▁▂▅▅▆▇▅▄▅▃██▃▃▃
metrics/l2_norm_in▅▆▅▆▆▄▅▅▃▆▄▄▆▇▄█▄▃▅▆▇▃▆▃▃▁▂▅▅▆▇▄▅▆▃▇█▄▄▃
metrics/l2_ratio▇▅▁▃▃▃▃▄▃▆▅▅▆▆▄▇▆▆▅▆▇▅▇▆▅▅▅▆▆▇▇▇▆▆▆██▆▆▆
metrics/mean_log10_feature_sparsity█▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity/below_1e-5▁▁▁▅▆███▆▆▆▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅
sparsity/below_1e-6▁▁▁▁▁▃▆▆▆█▆█████▆▆▆▆▆▆▆▆▆▆▆▆▆▆
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▁▁▅█▁▁▁▅▅█▅█▅▅████▅▅▅████
sparsity/mean_passes_since_fired▂▁▁▁▁▁▁▁▂▂▁▁▂▂▂▂▂▂▄▂▂▂▃▃▄▄▄▃▃▃▄▅▆▄▄▅▆▇▇█

Run summary:


details/current_l1_coefficient5
details/current_learning_rate0.0
details/n_training_tokens122880000
losses/ghost_grad_loss0.0
losses/l1_loss31.37665
losses/mse_loss187.70346
losses/overall_loss344.5867
metrics/CE_loss_score0.90369
metrics/ce_loss_with_ablation8.30545
metrics/ce_loss_with_sae2.62867
metrics/ce_loss_without_sae2.02306
metrics/explained_variance0.66377
metrics/explained_variance_std0.13242
metrics/l0192.95703
metrics/l2_norm24.64933
metrics/l2_norm_in31.38967
metrics/l2_ratio0.77361
metrics/mean_log10_feature_sparsity-2.66941
sparsity/below_1e-52
sparsity/below_1e-62
sparsity/dead_features2
sparsity/mean_passes_since_fired0.86823

" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - " View run 16384-L1-5-LR-5e-05-Tokens-1.229e+08 at: https://wandb.ai/curt-tigges/sae_lens_tutorial/runs/yr1gvjdc
View project at: https://wandb.ai/curt-tigges/sae_lens_tutorial
Synced 6 W&B file(s), 0 media file(s), 3 artifact file(s) and 0 other file(s)" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Find logs at: ./wandb/run-20240610_114538-yr1gvjdc/logs" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } + "data": { + "text/html": [ + "\n", + "

Run history:


details/current_l1_coefficient▁▅██████████████████████████████████████
details/current_learning_rate████████████████████████████████▇▇▅▅▄▃▂▁
details/n_training_tokens▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss▄▅█▇▆▆▅▄▄▄▄▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁
losses/overall_loss▁▅█▇▇▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▄▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
metrics/CE_loss_score█▂▁▂▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▅▆▅▅▅▅▆▅▆▆▆▆▆▆▆▅
metrics/ce_loss_with_ablation█▆▇▃▄▄▃▆▆▅▃▅▅▅▅▃▄▅▅▆▄▅▃▅▃▄▅▁▅▅▆▅▆▄▆▄▆▄▅▄
metrics/ce_loss_with_sae▂▇█▆▅▅▅▅▄▄▅▆▄▄▃▄▄▄▁▃▃▃▁▄▄▂▃▂▄▂▁▃▄▃▄▂▄▂▃▄
metrics/ce_loss_without_sae▇▇▇▅▄▄▆▆▅▄▇█▆▆▄▅▆▇▁▄▄▅▁▇▆▄▅▄▆▄▂▅▇▆▇▄▇▄▅█
metrics/explained_variance▅▄▁▂▃▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇█▇▇▇███████
metrics/explained_variance_std▁▅███▇▇▇▇▇▇▆▆▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▆▆▆▆▆▆▆▅▆
metrics/l0█▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▆▅▁▃▃▂▂▃▁▅▃▃▅▆▂█▄▃▄▆▇▃▆▃▃▁▂▅▅▆▇▅▄▅▃██▃▃▃
metrics/l2_norm_in▅▆▅▆▆▄▅▅▃▆▄▄▆▇▄█▄▃▅▆▇▃▆▃▃▁▂▅▅▆▇▄▅▆▃▇█▄▄▃
metrics/l2_ratio▇▅▁▃▃▃▃▄▃▆▅▅▆▆▄▇▆▆▅▆▇▅▇▆▅▅▅▆▆▇▇▇▆▆▆██▆▆▆
metrics/mean_log10_feature_sparsity█▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity/below_1e-5▁▁▁▅▆███▆▆▆▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅
sparsity/below_1e-6▁▁▁▁▁▃▆▆▆█▆█████▆▆▆▆▆▆▆▆▆▆▆▆▆▆
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▁▁▅█▁▁▁▅▅█▅█▅▅████▅▅▅████
sparsity/mean_passes_since_fired▂▁▁▁▁▁▁▁▂▂▁▁▂▂▂▂▂▂▄▂▂▂▃▃▄▄▄▃▃▃▄▅▆▄▄▅▆▇▇█

Run summary:


details/current_l1_coefficient5
details/current_learning_rate0.0
details/n_training_tokens122880000
losses/ghost_grad_loss0.0
losses/l1_loss31.37665
losses/mse_loss187.70346
losses/overall_loss344.5867
metrics/CE_loss_score0.90369
metrics/ce_loss_with_ablation8.30545
metrics/ce_loss_with_sae2.62867
metrics/ce_loss_without_sae2.02306
metrics/explained_variance0.66377
metrics/explained_variance_std0.13242
metrics/l0192.95703
metrics/l2_norm24.64933
metrics/l2_norm_in31.38967
metrics/l2_ratio0.77361
metrics/mean_log10_feature_sparsity-2.66941
sparsity/below_1e-52
sparsity/below_1e-62
sparsity/dead_features2
sparsity/mean_passes_since_fired0.86823

" ], - "source": [ - "total_training_steps = 30_000 # probably we should do more\n", - "batch_size = 4096\n", - "total_training_tokens = total_training_steps * batch_size\n", - "\n", - "lr_warm_up_steps = 0\n", - "lr_decay_steps = total_training_steps // 5 # 20% of training\n", - "l1_warm_up_steps = total_training_steps // 20 # 5% of training\n", - "\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"tiny-stories-1L-21M\", # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", - " hook_name=\"blocks.0.hook_mlp_out\", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)\n", - " hook_layer=0, # Only one layer in the model.\n", - " d_in=1024, # the width of the mlp output.\n", - " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\", # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.\n", - " is_dataset_tokenized=True,\n", - " streaming=True, # we could pre-download the token dataset if it was small.\n", - " # SAE Parameters\n", - " mse_loss_normalization=None, # We won't normalize the mse loss,\n", - " expansion_factor=16, # the width of the SAE. 
Larger will result in better stats but slower training.\n", - " b_dec_init_method=\"zeros\", # The geometric median can be used to initialize the decoder weights.\n", - " apply_b_dec_to_input=False, # We won't apply the decoder weights to the input.\n", - " normalize_sae_decoder=False,\n", - " scale_sparsity_penalty_by_decoder_norm=True,\n", - " decoder_heuristic_init=True,\n", - " init_encoder_as_decoder_transpose=True,\n", - " normalize_activations=\"expected_average_only_in\",\n", - " # Training Parameters\n", - " lr=5e-5, # lower the better, we'll go fairly high to speed up the tutorial.\n", - " adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)\n", - " adam_beta2=0.999,\n", - " lr_scheduler_name=\"constant\", # constant learning rate with warmup. Could be better schedules out there.\n", - " lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially.\n", - " lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting.\n", - " l1_coefficient=5, # will control how sparse the feature activations are\n", - " l1_warm_up_steps=l1_warm_up_steps, # this can help avoid too many dead features initially.\n", - " lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n", - " train_batch_size_tokens=batch_size,\n", - " context_size=512, # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", - " training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. 
Get a coffee, come back.\n", - " store_batch_size_prompts=16,\n", - " # Resampling protocol\n", - " use_ghost_grads=False, # we don't use ghost grads anymore.\n", - " feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", - " dead_feature_window=1000, # would effect resampling or ghost grads if we were using it.\n", - " dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it.\n", - " # WANDB\n", - " log_to_wandb=True, # always use wandb unless you are just testing code.\n", - " wandb_project=\"sae_lens_tutorial\",\n", - " wandb_log_frequency=30,\n", - " eval_every_n_wandb_logs=20,\n", - " # Misc\n", - " device=device,\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=\"float32\"\n", - ")\n", - "# look at the next cell to see some instruction for what to do while this is running.\n", - "sparse_autoencoder = SAETrainingRunner(cfg).run()" + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "markdown", - "metadata": { - "id": "khR_QkAJOVHw" - }, - "source": [ - "# TO DO: Understanding TinyStories-1L with our SAE\n", - "\n", - "I haven't had time yet to complete this section, but I'd love to see a PR where someones uses an SAE they trained in this tutorial to understand this model better." + "data": { + "text/html": [ + " View run 16384-L1-5-LR-5e-05-Tokens-1.229e+08 at: https://wandb.ai/curt-tigges/sae_lens_tutorial/runs/yr1gvjdc
View project at: https://wandb.ai/curt-tigges/sae_lens_tutorial
Synced 6 W&B file(s), 0 media file(s), 3 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "b4sUumxZOVHw" - }, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "# Let's start by getting the top 10 logits for each feature\n", - "projection_onto_unembed = sparse_autoencoder.W_dec @ model.W_U\n", - "\n", - "\n", - "# get the top 10 logits.\n", - "vals, inds = torch.topk(projection_onto_unembed, 10, dim=1)\n", - "\n", - "# get 10 random features\n", - "random_indices = torch.randint(0, projection_onto_unembed.shape[0], (10,))\n", - "\n", - "# Show the top 10 logits promoted by those features\n", - "top_10_logits_df = pd.DataFrame(\n", - " [model.to_str_tokens(i) for i in inds[random_indices]],\n", - " index=random_indices.tolist(),\n", - ").T\n", - "top_10_logits_df" + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240610_114538-yr1gvjdc/logs" + ], + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } + ], + "source": [ + "total_training_steps = 30_000 # probably we should do more\n", + "batch_size = 4096\n", + "total_training_tokens = total_training_steps * batch_size\n", + "\n", + "lr_warm_up_steps = 0\n", + "lr_decay_steps = total_training_steps // 5 # 20% of training\n", + "l1_warm_up_steps = total_training_steps // 20 # 5% of training\n", + "\n", + "cfg = LanguageModelSAERunnerConfig(\n", + " # Data Generating Function (Model + Training 
Distibuion)\n", + " model_name=\"tiny-stories-1L-21M\", # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", + " hook_name=\"blocks.0.hook_mlp_out\", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)\n", + " hook_layer=0, # Only one layer in the model.\n", + " d_in=1024, # the width of the mlp output.\n", + " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\", # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.\n", + " is_dataset_tokenized=True,\n", + " streaming=True, # we could pre-download the token dataset if it was small.\n", + " # SAE Parameters\n", + " mse_loss_normalization=None, # We won't normalize the mse loss,\n", + " expansion_factor=16, # the width of the SAE. Larger will result in better stats but slower training.\n", + " b_dec_init_method=\"zeros\", # The geometric median can be used to initialize the decoder weights.\n", + " apply_b_dec_to_input=False, # We won't apply the decoder weights to the input.\n", + " normalize_sae_decoder=False,\n", + " scale_sparsity_penalty_by_decoder_norm=True,\n", + " decoder_heuristic_init=True,\n", + " init_encoder_as_decoder_transpose=True,\n", + " normalize_activations=\"expected_average_only_in\",\n", + " # Training Parameters\n", + " lr=5e-5, # lower the better, we'll go fairly high to speed up the tutorial.\n", + " adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)\n", + " adam_beta2=0.999,\n", + " lr_scheduler_name=\"constant\", # constant learning rate with warmup. 
Could be better schedules out there.\n", + " lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially.\n", + " lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting.\n", + " l1_coefficient=5, # will control how sparse the feature activations are\n", + " l1_warm_up_steps=l1_warm_up_steps, # this can help avoid too many dead features initially.\n", + " lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n", + " train_batch_size_tokens=batch_size,\n", + " context_size=512, # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", + " training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n", + " store_batch_size_prompts=16,\n", + " # Resampling protocol\n", + " use_ghost_grads=False, # we don't use ghost grads anymore.\n", + " feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", + " dead_feature_window=1000, # would effect resampling or ghost grads if we were using it.\n", + " dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it.\n", + " # WANDB\n", + " log_to_wandb=True, # always use wandb unless you are just testing code.\n", + " wandb_project=\"sae_lens_tutorial\",\n", + " wandb_log_frequency=30,\n", + " eval_every_n_wandb_logs=20,\n", + " # Misc\n", + " device=device,\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=\"float32\",\n", + ")\n", + "# look at the next cell to see some instruction for what to do while this is running.\n", + "sparse_autoencoder = SAETrainingRunner(cfg).run()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "khR_QkAJOVHw" + }, + "source": [ + "# TO DO: Understanding 
TinyStories-1L with our SAE\n", + "\n", + "I haven't had time yet to complete this section, but I'd love to see a PR where someones uses an SAE they trained in this tutorial to understand this model better." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "b4sUumxZOVHw" + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "# Let's start by getting the top 10 logits for each feature\n", + "projection_onto_unembed = sparse_autoencoder.W_dec @ model.W_U\n", + "\n", + "\n", + "# get the top 10 logits.\n", + "vals, inds = torch.topk(projection_onto_unembed, 10, dim=1)\n", + "\n", + "# get 10 random features\n", + "random_indices = torch.randint(0, projection_onto_unembed.shape[0], (10,))\n", + "\n", + "# Show the top 10 logits promoted by those features\n", + "top_10_logits_df = pd.DataFrame(\n", + " [model.to_str_tokens(i) for i in inds[random_indices]],\n", + " index=random_indices.tolist(),\n", + ").T\n", + "top_10_logits_df" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 0 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/tutorials/tutorial_2_0.ipynb b/tutorials/tutorial_2_0.ipynb index fe25d85b..e080e79b 100644 --- a/tutorials/tutorial_2_0.ipynb +++ b/tutorials/tutorial_2_0.ipynb @@ -60,14 +60,17 @@ "outputs": [], "source": [ "try:\n", - " import google.colab # type: ignore\n", + " import google.colab # type: ignore\n", " from google.colab import output\n", + "\n", " COLAB = True\n", " %pip install sae-lens transformer-lens sae-dashboard\n", "except:\n", " COLAB = False\n", - " 
from IPython import get_ipython # type: ignore\n", - " ipython = get_ipython(); assert ipython is not None\n", + " from IPython import get_ipython # type: ignore\n", + "\n", + " ipython = get_ipython()\n", + " assert ipython is not None\n", " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", " ipython.run_line_magic(\"autoreload\", \"2\")\n", "\n", @@ -113,9 +116,19 @@ "from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory\n", "\n", "# TODO: Make this nicer.\n", - "df = pd.DataFrame.from_records({k:v.__dict__ for k,v in get_pretrained_saes_directory().items()}).T\n", - "df.drop(columns=[\"expected_var_explained\", \"expected_l0\", \"config_overrides\", \"conversion_func\"], inplace=True)\n", - "df # Each row is a \"release\" which has multiple SAEs which may have different configs / match different hook points in a model. " + "df = pd.DataFrame.from_records(\n", + " {k: v.__dict__ for k, v in get_pretrained_saes_directory().items()}\n", + ").T\n", + "df.drop(\n", + " columns=[\n", + " \"expected_var_explained\",\n", + " \"expected_l0\",\n", + " \"config_overrides\",\n", + " \"conversion_func\",\n", + " ],\n", + " inplace=True,\n", + ")\n", + "df # Each row is a \"release\" which has multiple SAEs which may have different configs / match different hook points in a model." 
] }, { @@ -140,17 +153,21 @@ "source": [ "# show the contents of the saes_map column for a specific row\n", "print(\"SAEs in the GTP2 Small Resid Pre release\")\n", - "for k,v in df.loc[df.release == \"gpt2-small-res-jb\", \"saes_map\"].values[0].items():\n", + "for k, v in df.loc[df.release == \"gpt2-small-res-jb\", \"saes_map\"].values[0].items():\n", " print(f\"SAE id: {k} for hook point: {v}\")\n", "\n", - "print(\"-\"*50)\n", + "print(\"-\" * 50)\n", "print(\"SAEs in the feature splitting release\")\n", - "for k,v in df.loc[df.release == \"gpt2-small-res-jb-feature-splitting\", \"saes_map\"].values[0].items():\n", + "for k, v in (\n", + " df.loc[df.release == \"gpt2-small-res-jb-feature-splitting\", \"saes_map\"]\n", + " .values[0]\n", + " .items()\n", + "):\n", " print(f\"SAE id: {k} for hook point: {v}\")\n", - " \n", - "print(\"-\"*50)\n", + "\n", + "print(\"-\" * 50)\n", "print(\"SAEs in the Gemma base model release\")\n", - "for k,v in df.loc[df.release == \"gemma-2b-res-jb\", \"saes_map\"].values[0].items():\n", + "for k, v in df.loc[df.release == \"gemma-2b-res-jb\", \"saes_map\"].values[0].items():\n", " print(f\"SAE id: {k} for hook point: {v}\")" ] }, @@ -172,15 +189,15 @@ "# from transformer_lens import HookedTransformer\n", "from sae_lens import SAE, HookedSAETransformer\n", "\n", - "model = HookedSAETransformer.from_pretrained(\"gpt2-small\", device = device)\n", + "model = HookedSAETransformer.from_pretrained(\"gpt2-small\", device=device)\n", "\n", "# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)\n", "# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict\n", - "# We also return the feature sparsities which are stored in HF for convenience. 
\n", + "# We also return the feature sparsities which are stored in HF for convenience.\n", "sae, cfg_dict, sparsity = SAE.from_pretrained(\n", - " release = \"gpt2-small-res-jb\", # <- Release name \n", - " sae_id = \"blocks.7.hook_resid_pre\", # <- SAE id (not always a hook point!)\n", - " device = device\n", + " release=\"gpt2-small-res-jb\", # <- Release name\n", + " sae_id=\"blocks.7.hook_resid_pre\", # <- SAE id (not always a hook point!)\n", + " device=device,\n", ")" ] }, @@ -235,18 +252,18 @@ "metadata": {}, "outputs": [], "source": [ - "from datasets import load_dataset \n", + "from datasets import load_dataset\n", "from transformer_lens.utils import tokenize_and_concatenate\n", "\n", "dataset = load_dataset(\n", - " path = \"NeelNanda/pile-10k\",\n", + " path=\"NeelNanda/pile-10k\",\n", " split=\"train\",\n", " streaming=False,\n", ")\n", "\n", "token_dataset = tokenize_and_concatenate(\n", - " dataset= dataset,# type: ignore\n", - " tokenizer = model.tokenizer, # type: ignore\n", + " dataset=dataset, # type: ignore\n", + " tokenizer=model.tokenizer, # type: ignore\n", " streaming=True,\n", " max_length=sae.cfg.context_size,\n", " add_bos_token=sae.cfg.prepend_bos,\n", @@ -303,10 +320,14 @@ "\n", "html_template = \"https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300\"\n", "\n", - "def get_dashboard_html(sae_release = \"gpt2-small\", sae_id=\"7-res-jb\", feature_idx=0):\n", + "\n", + "def get_dashboard_html(sae_release=\"gpt2-small\", sae_id=\"7-res-jb\", feature_idx=0):\n", " return html_template.format(sae_release, sae_id, feature_idx)\n", "\n", - "html = get_dashboard_html(sae_release = \"gpt2-small\", sae_id=\"7-res-jb\", feature_idx=feature_idx)\n", + "\n", + "html = get_dashboard_html(\n", + " sae_release=\"gpt2-small\", sae_id=\"7-res-jb\", feature_idx=feature_idx\n", + ")\n", "IFrame(html, width=1200, height=600)" ] }, @@ -361,7 +382,9 @@ "# rename index to \"feature\"\n", 
"explanations_df.rename(columns={\"index\": \"feature\"}, inplace=True)\n", "# explanations_df[\"feature\"] = explanations_df[\"feature\"].astype(int)\n", - "explanations_df[\"description\"] = explanations_df[\"description\"].apply(lambda x: x.lower())\n", + "explanations_df[\"description\"] = explanations_df[\"description\"].apply(\n", + " lambda x: x.lower()\n", + ")\n", "explanations_df" ] }, @@ -389,7 +412,11 @@ "outputs": [], "source": [ "# Let's get the dashboard for this feature.\n", - "html = get_dashboard_html(sae_release = \"gpt2-small\", sae_id=\"7-res-jb\", feature_idx=bible_features.feature.values[0])\n", + "html = get_dashboard_html(\n", + " sae_release=\"gpt2-small\",\n", + " sae_id=\"7-res-jb\",\n", + " feature_idx=bible_features.feature.values[0],\n", + ")\n", "IFrame(html, width=1200, height=600)" ] }, @@ -440,7 +467,7 @@ "# This is because the SAE will be used to modify the forward pass, and if it doesn't reconstruct the activations well, the outputs may be effected.\n", "# Good SAEs have small error terms but it's something to be mindful of.\n", "\n", - "sae.use_error_term # If use error term is set to false, we will modify the forward pass by using the sae." + "sae.use_error_term # If use error term is set to false, we will modify the forward pass by using the sae." 
] }, { @@ -459,7 +486,7 @@ "# hooked SAE Transformer will enable us to get the feature activations from the SAE\n", "_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])\n", "\n", - "print([(k, v.shape) for k,v in cache.items() if \"sae\" in k])\n", + "print([(k, v.shape) for k, v in cache.items() if \"sae\" in k])\n", "\n", "# note there were 11 tokens in our prompt, the residual stream dimension is 768, and the number of SAE features is 768" ] @@ -481,16 +508,20 @@ "\n", "# hover over lines to see the Feature ID.\n", "px.line(\n", - " cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][0, -1, :].cpu().numpy(),\n", + " cache[\"blocks.7.hook_resid_pre.hook_sae_acts_post\"][0, -1, :].cpu().numpy(),\n", " title=\"Feature activations at the final token position\",\n", " labels={\"index\": \"Feature\", \"value\": \"Activation\"},\n", ").show()\n", "\n", "# let's print the top 5 features and how much they fired\n", - "vals, inds = torch.topk(cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][0, -1, :], 5)\n", + "vals, inds = torch.topk(\n", + " cache[\"blocks.7.hook_resid_pre.hook_sae_acts_post\"][0, -1, :], 5\n", + ")\n", "for val, ind in zip(vals, inds):\n", " print(f\"Feature {ind} fired {val:.2f}\")\n", - " html = get_dashboard_html(sae_release = \"gpt2-small\", sae_id=\"7-res-jb\", feature_idx=ind)\n", + " html = get_dashboard_html(\n", + " sae_release=\"gpt2-small\", sae_id=\"7-res-jb\", feature_idx=ind\n", + " )\n", " display(IFrame(html, width=1200, height=300))" ] }, @@ -533,16 +564,24 @@ "metadata": {}, "outputs": [], "source": [ - "prompt = [\"In the beginning, God created the heavens and the\", \"In the beginning, God created the cat and the\"]\n", + "prompt = [\n", + " \"In the beginning, God created the heavens and the\",\n", + " \"In the beginning, God created the cat and the\",\n", + "]\n", "_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])\n", - "print([(k, v.shape) for k,v in cache.items() if \"sae\" in k])\n", + "print([(k, 
v.shape) for k, v in cache.items() if \"sae\" in k])\n", "\n", - "feature_activation_df = pd.DataFrame(cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][0, -1, :].cpu().numpy(),\n", - " index = [f\"feature_{i}\" for i in range(sae.cfg.d_sae)],\n", + "feature_activation_df = pd.DataFrame(\n", + " cache[\"blocks.7.hook_resid_pre.hook_sae_acts_post\"][0, -1, :].cpu().numpy(),\n", + " index=[f\"feature_{i}\" for i in range(sae.cfg.d_sae)],\n", ")\n", "feature_activation_df.columns = [\"heavens_and_the\"]\n", - "feature_activation_df[\"cat_and_the\"] = cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][1, -1, :].cpu().numpy()\n", - "feature_activation_df[\"diff\"]= feature_activation_df[\"heavens_and_the\"] - feature_activation_df[\"cat_and_the\"]\n", + "feature_activation_df[\"cat_and_the\"] = (\n", + " cache[\"blocks.7.hook_resid_pre.hook_sae_acts_post\"][1, -1, :].cpu().numpy()\n", + ")\n", + "feature_activation_df[\"diff\"] = (\n", + " feature_activation_df[\"heavens_and_the\"] - feature_activation_df[\"cat_and_the\"]\n", + ")\n", "\n", "fig = px.line(\n", " feature_activation_df,\n", @@ -570,11 +609,16 @@ "source": [ "# let's look at the biggest features in terms of absolute difference\n", "\n", - "diff = cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][1, -1, :].cpu() - cache['blocks.7.hook_resid_pre.hook_sae_acts_post'][0, -1, :].cpu()\n", + "diff = (\n", + " cache[\"blocks.7.hook_resid_pre.hook_sae_acts_post\"][1, -1, :].cpu()\n", + " - cache[\"blocks.7.hook_resid_pre.hook_sae_acts_post\"][0, -1, :].cpu()\n", + ")\n", "vals, inds = torch.topk(torch.abs(diff), 5)\n", "for val, ind in zip(vals, inds):\n", " print(f\"Feature {ind} had a difference of {val:.2f}\")\n", - " html = get_dashboard_html(sae_release = \"gpt2-small\", sae_id=\"7-res-jb\", feature_idx=ind)\n", + " html = get_dashboard_html(\n", + " sae_release=\"gpt2-small\", sae_id=\"7-res-jb\", feature_idx=ind\n", + " )\n", " display(IFrame(html, width=1200, height=300))" ] }, @@ -638,7 +682,7 @@ " 
train_batch_size_tokens=4096,\n", " n_batches_in_buffer=32,\n", " device=device,\n", - ")\n" + ")" ] }, { @@ -650,36 +694,43 @@ "def list_flatten(nested_list):\n", " return [x for y in nested_list for x in y]\n", "\n", + "\n", "# A very handy function Neel wrote to get context around a feature activation\n", - "def make_token_df(tokens, len_prefix=5, len_suffix=3, model = model):\n", + "def make_token_df(tokens, len_prefix=5, len_suffix=3, model=model):\n", " str_tokens = [model.to_str_tokens(t) for t in tokens]\n", - " unique_token = [[f\"{s}/{i}\" for i, s in enumerate(str_tok)] for str_tok in str_tokens]\n", - " \n", + " unique_token = [\n", + " [f\"{s}/{i}\" for i, s in enumerate(str_tok)] for str_tok in str_tokens\n", + " ]\n", + "\n", " context = []\n", " prompt = []\n", " pos = []\n", " label = []\n", " for b in range(tokens.shape[0]):\n", " for p in range(tokens.shape[1]):\n", - " prefix = \"\".join(str_tokens[b][max(0, p-len_prefix):p])\n", - " if p==tokens.shape[1]-1:\n", + " prefix = \"\".join(str_tokens[b][max(0, p - len_prefix) : p])\n", + " if p == tokens.shape[1] - 1:\n", " suffix = \"\"\n", " else:\n", - " suffix = \"\".join(str_tokens[b][p+1:min(tokens.shape[1]-1, p+1+len_suffix)])\n", + " suffix = \"\".join(\n", + " str_tokens[b][p + 1 : min(tokens.shape[1] - 1, p + 1 + len_suffix)]\n", + " )\n", " current = str_tokens[b][p]\n", " context.append(f\"{prefix}|{current}|{suffix}\")\n", " prompt.append(b)\n", " pos.append(p)\n", " label.append(f\"{b}/{p}\")\n", " # print(len(batch), len(pos), len(context), len(label))\n", - " return pd.DataFrame(dict(\n", - " str_tokens=list_flatten(str_tokens),\n", - " unique_token=list_flatten(unique_token),\n", - " context=context,\n", - " prompt=prompt,\n", - " pos=pos,\n", - " label=label,\n", - " ))" + " return pd.DataFrame(\n", + " dict(\n", + " str_tokens=list_flatten(str_tokens),\n", + " unique_token=list_flatten(unique_token),\n", + " context=context,\n", + " prompt=prompt,\n", + " pos=pos,\n", + " 
label=label,\n", + " )\n", + " )" ] }, { @@ -703,7 +754,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "# finding max activating examples is a bit harder. To do this we need to calculate feature activations for a large number of tokens\n", "feature_list = torch.randint(0, sae.cfg.d_sae, (100,))\n", "examples_found = 0\n", @@ -720,14 +770,16 @@ " tokens = activation_store.get_batch_tokens()\n", " tokens_df = make_token_df(tokens)\n", " tokens_df[\"batch\"] = i\n", - " \n", + "\n", " flat_tokens = tokens.flatten()\n", - " \n", - " _, cache = model.run_with_cache(tokens, stop_at_layer = sae.cfg.hook_layer + 1, names_filter = [sae.cfg.hook_name])\n", + "\n", + " _, cache = model.run_with_cache(\n", + " tokens, stop_at_layer=sae.cfg.hook_layer + 1, names_filter=[sae.cfg.hook_name]\n", + " )\n", " sae_in = cache[sae.cfg.hook_name]\n", " feature_acts = sae.encode(sae_in).squeeze()\n", "\n", - " feature_acts = feature_acts.flatten(0,1)\n", + " feature_acts = feature_acts.flatten(0, 1)\n", " fired_mask = (feature_acts[:, feature_list]).sum(dim=-1) > 0\n", " fired_tokens = model.to_str_tokens(flat_tokens[fired_mask])\n", " reconstruction = feature_acts[fired_mask][:, feature_list] @ sae.W_dec[feature_list]\n", @@ -737,12 +789,12 @@ " all_feature_acts.append(feature_acts[fired_mask][:, feature_list])\n", " all_fired_tokens.append(fired_tokens)\n", " all_reconstructions.append(reconstruction)\n", - " \n", + "\n", " examples_found += len(fired_tokens)\n", " # print(f\"Examples found: {examples_found}\")\n", " # update description\n", " pbar.set_description(f\"Examples found: {examples_found}\")\n", - " \n", + "\n", "# flatten the list of lists\n", "all_token_dfs = pd.concat(all_token_dfs)\n", "all_fired_tokens = list_flatten(all_fired_tokens)\n", @@ -770,7 +822,10 @@ "metadata": {}, "outputs": [], "source": [ - "feature_acts_df = pd.DataFrame(all_feature_acts.detach().cpu().numpy(), columns = [f\"feature_{i}\" for i in feature_list])\n", + "feature_acts_df = 
pd.DataFrame(\n", + " all_feature_acts.detach().cpu().numpy(),\n", + " columns=[f\"feature_{i}\" for i in feature_list],\n", + ")\n", "feature_acts_df.shape" ] }, @@ -783,15 +838,20 @@ "feature_idx = 0\n", "# get non-zero activations\n", "\n", - "all_positive_acts = all_feature_acts[all_feature_acts[:, feature_idx] > 0][:, feature_idx].detach()\n", - "prop_positive_activations = 100*len(all_positive_acts) / (total_batches*batch_size_tokens)\n", + "all_positive_acts = all_feature_acts[all_feature_acts[:, feature_idx] > 0][\n", + " :, feature_idx\n", + "].detach()\n", + "prop_positive_activations = (\n", + " 100 * len(all_positive_acts) / (total_batches * batch_size_tokens)\n", + ")\n", "\n", "px.histogram(\n", " all_positive_acts.cpu(),\n", " nbins=50,\n", " title=f\"Histogram of positive activations - {prop_positive_activations:.3f}% of activations were positive\",\n", " labels={\"value\": \"Activation\"},\n", - " width=800,)" + " width=800,\n", + ")" ] }, { @@ -800,8 +860,12 @@ "metadata": {}, "outputs": [], "source": [ - "top_10_activations = feature_acts_df.sort_values(f\"feature_{feature_list[0]}\", ascending=False).head(10)\n", - "all_token_dfs.iloc[top_10_activations.index] # TODO: double check this is working correctly" + "top_10_activations = feature_acts_df.sort_values(\n", + " f\"feature_{feature_list[0]}\", ascending=False\n", + ").head(10)\n", + "all_token_dfs.iloc[\n", + " top_10_activations.index\n", + "] # TODO: double check this is working correctly" ] }, { @@ -837,7 +901,9 @@ "_, top_k_tokens = torch.topk(projection_matrix[feature_list], top_k, dim=1)\n", "\n", "\n", - "feature_df = pd.DataFrame(top_k_tokens.cpu().numpy(), index = [f\"feature_{i}\" for i in feature_list]).T\n", + "feature_df = pd.DataFrame(\n", + " top_k_tokens.cpu().numpy(), index=[f\"feature_{i}\" for i in feature_list]\n", + ").T\n", "feature_df.index = [f\"token_{i}\" for i in range(top_k)]\n", "feature_df.applymap(lambda x: model.tokenizer.decode(x))" ] @@ -862,7 +928,11 @@ 
"metadata": {}, "outputs": [], "source": [ - "html = get_dashboard_html(sae_release = \"gpt2-small\", sae_id=f\"{sae.cfg.hook_layer}-res-jb\", feature_idx=feature_list[0])\n", + "html = get_dashboard_html(\n", + " sae_release=\"gpt2-small\",\n", + " sae_id=f\"{sae.cfg.hook_layer}-res-jb\",\n", + " feature_idx=feature_list[0],\n", + ")\n", "IFrame(html, width=1200, height=600)" ] }, @@ -888,8 +958,8 @@ "metadata": {}, "outputs": [], "source": [ - "# only valid for res-jb resid_pre 7. \n", - "# Josh Engel's emailed us these lists. \n", + "# only valid for res-jb resid_pre 7.\n", + "# Josh Engel's emailed us these lists.\n", "day_of_the_week_features = [2592, 4445, 4663, 4733, 6531, 8179, 9566, 20927, 24185]\n", "# months_of_the_year = [3977, 4140, 5993, 7299, 9104, 9401, 10449, 11196, 12661, 14715, 17068, 17528, 19589, 21033, 22043, 23304]\n", "# years_of_10th_century = [1052, 2753, 4427, 6382, 8314, 9576, 9606, 13551, 19734, 20349]\n", @@ -910,14 +980,16 @@ " tokens = activation_store.get_batch_tokens()\n", " tokens_df = make_token_df(tokens)\n", " tokens_df[\"batch\"] = i\n", - " \n", + "\n", " flat_tokens = tokens.flatten()\n", - " \n", - " _, cache = model.run_with_cache(tokens, stop_at_layer = sae.cfg.hook_layer + 1, names_filter = [sae.cfg.hook_name])\n", + "\n", + " _, cache = model.run_with_cache(\n", + " tokens, stop_at_layer=sae.cfg.hook_layer + 1, names_filter=[sae.cfg.hook_name]\n", + " )\n", " sae_in = cache[sae.cfg.hook_name]\n", " feature_acts = sae.encode(sae_in).squeeze()\n", "\n", - " feature_acts = feature_acts.flatten(0,1)\n", + " feature_acts = feature_acts.flatten(0, 1)\n", " fired_mask = (feature_acts[:, feature_list]).sum(dim=-1) > 0\n", " fired_tokens = model.to_str_tokens(flat_tokens[fired_mask])\n", " reconstruction = feature_acts[fired_mask][:, feature_list] @ sae.W_dec[feature_list]\n", @@ -927,12 +999,12 @@ " all_feature_acts.append(feature_acts[fired_mask][:, feature_list])\n", " all_fired_tokens.append(fired_tokens)\n", " 
all_reconstructions.append(reconstruction)\n", - " \n", + "\n", " examples_found += len(fired_tokens)\n", " # print(f\"Examples found: {examples_found}\")\n", " # update description\n", " pbar.set_description(f\"Examples found: {examples_found}\")\n", - " \n", + "\n", "# flatten the list of lists\n", "all_token_dfs = pd.concat(all_token_dfs)\n", "all_fired_tokens = list_flatten(all_fired_tokens)\n", @@ -955,7 +1027,7 @@ "source": [ "# do PCA on reconstructions\n", "from sklearn.decomposition import PCA\n", - "import plotly.express as px \n", + "import plotly.express as px\n", "\n", "pca = PCA(n_components=3)\n", "pca_embedding = pca.fit_transform(all_reconstructions.detach().cpu().numpy())\n", @@ -966,15 +1038,16 @@ "\n", "\n", "px.scatter(\n", - " pca_df, x=\"PC2\", y=\"PC3\",\n", + " pca_df,\n", + " x=\"PC2\",\n", + " y=\"PC3\",\n", " hover_data=[\"context\"],\n", " hover_name=\"tokens\",\n", - " height = 800,\n", - " width = 1200,\n", - " color = \"tokens\",\n", - " title = \"PCA Subspace Reconstructions\",\n", - ").show()\n", - "\n" + " height=800,\n", + " width=1200,\n", + " color=\"tokens\",\n", + " title=\"PCA Subspace Reconstructions\",\n", + ").show()" ] }, { @@ -1012,23 +1085,24 @@ "outputs": [], "source": [ "from tqdm import tqdm\n", - "from functools import partial \n", + "from functools import partial\n", + "\n", "\n", "def find_max_activation(model, sae, activation_store, feature_idx, num_batches=100):\n", - " '''\n", - " Find the maximum activation for a given feature index. This is useful for \n", + " \"\"\"\n", + " Find the maximum activation for a given feature index. 
This is useful for\n", " calibrating the right amount of the feature to add.\n", - " '''\n", + " \"\"\"\n", " max_activation = 0.0\n", "\n", " pbar = tqdm(range(num_batches))\n", " for _ in pbar:\n", " tokens = activation_store.get_batch_tokens()\n", - " \n", + "\n", " _, cache = model.run_with_cache(\n", - " tokens, \n", - " stop_at_layer=sae.cfg.hook_layer + 1, \n", - " names_filter=[sae.cfg.hook_name]\n", + " tokens,\n", + " stop_at_layer=sae.cfg.hook_layer + 1,\n", + " names_filter=[sae.cfg.hook_name],\n", " )\n", " sae_in = cache[sae.cfg.hook_name]\n", " feature_acts = sae.encode(sae_in).squeeze()\n", @@ -1036,27 +1110,39 @@ " feature_acts = feature_acts.flatten(0, 1)\n", " batch_max_activation = feature_acts[:, feature_idx].max().item()\n", " max_activation = max(max_activation, batch_max_activation)\n", - " \n", + "\n", " pbar.set_description(f\"Max activation: {max_activation:.4f}\")\n", "\n", " return max_activation\n", "\n", - "def steering(activations, hook, steering_strength=1.0, steering_vector=None, max_act=1.0):\n", + "\n", + "def steering(\n", + " activations, hook, steering_strength=1.0, steering_vector=None, max_act=1.0\n", + "):\n", " # Note if the feature fires anyway, we'd be adding to that here.\n", " return activations + max_act * steering_strength * steering_vector\n", "\n", - "def generate_with_steering(model, sae, prompt, steering_feature, max_act, steering_strength=1.0, max_new_tokens=95):\n", + "\n", + "def generate_with_steering(\n", + " model,\n", + " sae,\n", + " prompt,\n", + " steering_feature,\n", + " max_act,\n", + " steering_strength=1.0,\n", + " max_new_tokens=95,\n", + "):\n", " input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)\n", - " \n", + "\n", " steering_vector = sae.W_dec[steering_feature].to(model.cfg.device)\n", - " \n", + "\n", " steering_hook = partial(\n", " steering,\n", " steering_vector=steering_vector,\n", " steering_strength=steering_strength,\n", - " max_act=max_act\n", + " 
max_act=max_act,\n", " )\n", - " \n", + "\n", " # standard transformerlens syntax for a hook context for generation\n", " with model.hooks(fwd_hooks=[(sae.cfg.hook_name, steering_hook)]):\n", " output = model.generate(\n", @@ -1064,12 +1150,13 @@ " max_new_tokens=max_new_tokens,\n", " temperature=0.7,\n", " top_p=0.9,\n", - " stop_at_eos = False if device == \"mps\" else True,\n", - " prepend_bos = sae.cfg.prepend_bos,\n", + " stop_at_eos=False if device == \"mps\" else True,\n", + " prepend_bos=sae.cfg.prepend_bos,\n", " )\n", - " \n", + "\n", " return model.tokenizer.decode(output[0])\n", "\n", + "\n", "# Choose a feature to steer\n", "steering_feature = steering_feature = 20115 # Choose a feature to steer towards\n", "\n", @@ -1083,16 +1170,18 @@ "prompt = \"Once upon a time\"\n", "normal_text = model.generate(\n", " prompt,\n", - " max_new_tokens=95, \n", - " stop_at_eos = False if device == \"mps\" else True,\n", - " prepend_bos = sae.cfg.prepend_bos,\n", + " max_new_tokens=95,\n", + " stop_at_eos=False if device == \"mps\" else True,\n", + " prepend_bos=sae.cfg.prepend_bos,\n", ")\n", "\n", "print(\"\\nNormal text (without steering):\")\n", "print(normal_text)\n", "\n", "# Generate text with steering\n", - "steered_text = generate_with_steering(model, sae, prompt, steering_feature, max_act, steering_strength=2.0)\n", + "steered_text = generate_with_steering(\n", + " model, sae, prompt, steering_feature, max_act, steering_strength=2.0\n", + ")\n", "print(\"Steered text:\")\n", "print(steered_text)" ] @@ -1103,11 +1192,12 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "# Experiment with different steering strengths\n", "print(\"\\nExperimenting with different steering strengths:\")\n", "for strength in [-4.0, -2.0, 0.5, 2.0, 4.0]:\n", - " steered_text = generate_with_steering(model, sae, prompt, steering_feature, max_act, steering_strength=strength)\n", + " steered_text = generate_with_steering(\n", + " model, sae, prompt, steering_feature, max_act, 
steering_strength=strength\n", + " )\n", " print(f\"\\nSteering strength {strength}:\")\n", " print(steered_text)" ] @@ -1132,24 +1222,19 @@ "\n", "payload = {\n", " # \"prompt\": \"A knight in shining\",\n", - " # \"prompt\": \"He had to fight back in self-\", \n", + " # \"prompt\": \"He had to fight back in self-\",\n", " \"prompt\": \"In the middle of the universe is the galactic\",\n", " # \"prompt\": \"Oh no. We're running on empty. Its time to fill up the car with\",\n", " # \"prompt\": \"Sure, I'm happy to pay. I don't have any cash on me but let me write you a\",\n", " \"modelId\": \"gpt2-small\",\n", " \"features\": [\n", - " {\n", - " \"modelId\": \"gpt2-small\",\n", - " \"layer\": \"7-res-jb\",\n", - " \"index\": 6770,\n", - " \"strength\": 8\n", - " }\n", + " {\"modelId\": \"gpt2-small\", \"layer\": \"7-res-jb\", \"index\": 6770, \"strength\": 8}\n", " ],\n", " \"temperature\": 0.2,\n", " \"n_tokens\": 2,\n", " \"freq_penalty\": 1,\n", " \"seed\": np.random.randint(100),\n", - " \"strength_multiplier\": 4\n", + " \"strength_multiplier\": 4,\n", "}\n", "headers = {\"Content-Type\": \"application/json\"}\n", "\n", @@ -1166,24 +1251,20 @@ "source": [ "import requests\n", "import numpy as np\n", + "\n", "url = \"https://www.neuronpedia.org/api/steer\"\n", "\n", "payload = {\n", - " \"prompt\": \"I wrote a letter to my girlfiend. It said \\\"\",\n", + " \"prompt\": 'I wrote a letter to my girlfiend. 
It said \"',\n", " \"modelId\": \"gpt2-small\",\n", " \"features\": [\n", - " {\n", - " \"modelId\": \"gpt2-small\",\n", - " \"layer\": \"7-res-jb\",\n", - " \"index\": 20115,\n", - " \"strength\": 4\n", - " }\n", + " {\"modelId\": \"gpt2-small\", \"layer\": \"7-res-jb\", \"index\": 20115, \"strength\": 4}\n", " ],\n", " \"temperature\": 0.7,\n", " \"n_tokens\": 120,\n", " \"freq_penalty\": 1,\n", " \"seed\": np.random.randint(100),\n", - " \"strength_multiplier\": 4\n", + " \"strength_multiplier\": 4,\n", "}\n", "headers = {\"Content-Type\": \"application/json\"}\n", "\n", @@ -1223,28 +1304,30 @@ "from transformer_lens.utils import test_prompt\n", "from functools import partial\n", "\n", + "\n", "def test_prompt_with_ablation(model, sae, prompt, answer, ablation_features):\n", - " \n", - " def ablate_feature_hook(feature_activations, hook, feature_ids, position = None):\n", - " \n", + "\n", + " def ablate_feature_hook(feature_activations, hook, feature_ids, position=None):\n", + "\n", " if position is None:\n", - " feature_activations[:,:,feature_ids] = 0\n", + " feature_activations[:, :, feature_ids] = 0\n", " else:\n", - " feature_activations[:,position,feature_ids] = 0\n", - " \n", + " feature_activations[:, position, feature_ids] = 0\n", + "\n", " return feature_activations\n", - " \n", - " ablation_hook = partial(ablate_feature_hook, feature_ids = ablation_features)\n", - " \n", + "\n", + " ablation_hook = partial(ablate_feature_hook, feature_ids=ablation_features)\n", + "\n", " model.add_sae(sae)\n", - " hook_point = sae.cfg.hook_name + '.hook_sae_acts_post'\n", + " hook_point = sae.cfg.hook_name + \".hook_sae_acts_post\"\n", " model.add_hook(hook_point, ablation_hook, \"fwd\")\n", - " \n", + "\n", " test_prompt(prompt, answer, model)\n", - " \n", + "\n", " model.reset_hooks()\n", " model.reset_saes()\n", "\n", + "\n", "# Example usage in a notebook:\n", "\n", "# Assume model and sae are already defined\n", @@ -1356,7 +1439,9 @@ "\n", " # this hook just 
track the SAE input, output, features, and error. If `track_grads=True`, it also ensures\n", " # that requires_grad is set to True and retain_grad is called for intermediate values.\n", - " def reconstruction_hook(sae_in: torch.Tensor, hook: HookPoint, hook_point: str): # noqa: ARG001\n", + " def reconstruction_hook(\n", + " sae_in: torch.Tensor, hook: HookPoint, hook_point: str\n", + " ): # noqa: ARG001\n", " sae = saes[hook_point]\n", " feature_acts = sae.encode(sae_in)\n", " sae_out = sae.decode(feature_acts)\n", @@ -1382,7 +1467,9 @@ " return (output_grads,)\n", "\n", " # this hook just records model activations, and ensures that intermediate activations have gradient tracking turned on if needed\n", - " def tracking_hook(hook_input: torch.Tensor, hook: HookPoint, hook_point: str): # noqa: ARG001\n", + " def tracking_hook(\n", + " hook_input: torch.Tensor, hook: HookPoint, hook_point: str\n", + " ): # noqa: ARG001\n", " model_activations[hook_point] = hook_input\n", " if track_grads:\n", " track_grad(hook_input)\n", @@ -1427,6 +1514,8 @@ "EPS = 1e-8\n", "\n", "torch.set_grad_enabled(True)\n", + "\n", + "\n", "@dataclass\n", "class AttributionGrads:\n", " metric: torch.Tensor\n", @@ -1557,24 +1646,31 @@ " sae_feature_grads=sae_feature_grads,\n", " sae_errors_attribution_proportion=sae_error_proportions,\n", " )\n", - " \n", - " \n", + "\n", + "\n", "# prompt = \" Tiger Woods plays the sport of\"\n", "# pos_token = model.tokenizer.encode(\" golf\")[0]\n", "prompt = \"In the beginning, God created the heavens and the\"\n", "pos_token = model.tokenizer.encode(\" earth\")\n", "neg_token = model.tokenizer.encode(\" sky\")\n", - "def metric_fn(logits: torch.tensor, pos_token: torch.tensor =pos_token, neg_token: torch.Tensor=neg_token) -> torch.Tensor:\n", - " return logits[0,-1,pos_token] - logits[0,-1,neg_token]\n", + "\n", + "\n", + "def metric_fn(\n", + " logits: torch.tensor,\n", + " pos_token: torch.tensor = pos_token,\n", + " neg_token: torch.Tensor = 
neg_token,\n", + ") -> torch.Tensor:\n", + " return logits[0, -1, pos_token] - logits[0, -1, neg_token]\n", + "\n", "\n", "feature_attribution_df = calculate_feature_attribution(\n", - " input = prompt,\n", - " model = model,\n", - " metric_fn = metric_fn,\n", + " input=prompt,\n", + " model=model,\n", + " metric_fn=metric_fn,\n", " include_saes={sae.cfg.hook_name: sae},\n", " include_error_term=True,\n", " return_logits=True,\n", - ")\n" + ")" ] }, { @@ -1584,6 +1680,7 @@ "outputs": [], "source": [ "from transformer_lens.utils import test_prompt\n", + "\n", "test_prompt(prompt, model.to_string(pos_token), model)" ] }, @@ -1596,8 +1693,14 @@ "tokens = model.to_str_tokens(prompt)\n", "unique_tokens = [f\"{i}/{t}\" for i, t in enumerate(tokens)]\n", "\n", - "px.bar(x = unique_tokens,\n", - " y = feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0].sum(-1).detach().cpu().numpy())" + "px.bar(\n", + " x=unique_tokens,\n", + " y=feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0]\n", + " .sum(-1)\n", + " .detach()\n", + " .cpu()\n", + " .numpy(),\n", + ")" ] }, { @@ -1611,13 +1714,18 @@ " Convert a sparse tensor to a long format pandas DataFrame.\n", " \"\"\"\n", " df = pd.DataFrame(sparse_tensor.detach().cpu().numpy())\n", - " df_long = df.melt(ignore_index=False, var_name='column', value_name='value')\n", + " df_long = df.melt(ignore_index=False, var_name=\"column\", value_name=\"value\")\n", " df_long.columns = [\"feature\", \"attribution\"]\n", - " df_long_nonzero = df_long[df_long['attribution'] != 0]\n", - " df_long_nonzero = df_long_nonzero.reset_index().rename(columns={'index': 'position'})\n", + " df_long_nonzero = df_long[df_long[\"attribution\"] != 0]\n", + " df_long_nonzero = df_long_nonzero.reset_index().rename(\n", + " columns={\"index\": \"position\"}\n", + " )\n", " return df_long_nonzero\n", "\n", - "df_long_nonzero = 
convert_sparse_feature_to_long_df(feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0])\n", + "\n", + "df_long_nonzero = convert_sparse_feature_to_long_df(\n", + " feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0]\n", + ")\n", "df_long_nonzero.sort_values(\"attribution\", ascending=False)" ] }, @@ -1627,9 +1735,20 @@ "metadata": {}, "outputs": [], "source": [ - "for i, v in df_long_nonzero.query(\"position==8\").groupby(\"feature\").attribution.sum().sort_values(ascending=False).head(5).items():\n", + "for i, v in (\n", + " df_long_nonzero.query(\"position==8\")\n", + " .groupby(\"feature\")\n", + " .attribution.sum()\n", + " .sort_values(ascending=False)\n", + " .head(5)\n", + " .items()\n", + "):\n", " print(f\"Feature {i} had a total attribution of {v:.2f}\")\n", - " html = get_dashboard_html(sae_release = \"gpt2-small\", sae_id=f\"{sae.cfg.hook_layer}-res-jb\", feature_idx=int(i))\n", + " html = get_dashboard_html(\n", + " sae_release=\"gpt2-small\",\n", + " sae_id=f\"{sae.cfg.hook_layer}-res-jb\",\n", + " feature_idx=int(i),\n", + " )\n", " display(IFrame(html, width=1200, height=300))" ] }, @@ -1639,9 +1758,19 @@ "metadata": {}, "outputs": [], "source": [ - "for i, v in df_long_nonzero.groupby(\"feature\").attribution.sum().sort_values(ascending=False).head(5).items():\n", + "for i, v in (\n", + " df_long_nonzero.groupby(\"feature\")\n", + " .attribution.sum()\n", + " .sort_values(ascending=False)\n", + " .head(5)\n", + " .items()\n", + "):\n", " print(f\"Feature {i} had a total attribution of {v:.2f}\")\n", - " html = get_dashboard_html(sae_release = \"gpt2-small\", sae_id=f\"{sae.cfg.hook_layer}-res-jb\", feature_idx=int(i))\n", + " html = get_dashboard_html(\n", + " sae_release=\"gpt2-small\",\n", + " sae_id=f\"{sae.cfg.hook_layer}-res-jb\",\n", + " feature_idx=int(i),\n", + " )\n", " display(IFrame(html, width=1200, height=300))" ] }, diff --git a/tutorials/uploading_saes_to_huggingface.ipynb 
b/tutorials/uploading_saes_to_huggingface.ipynb index 2af4e94b..c5c7448f 100644 --- a/tutorials/uploading_saes_to_huggingface.ipynb +++ b/tutorials/uploading_saes_to_huggingface.ipynb @@ -111,8 +111,8 @@ "layer_1_sae.save_model(layer_1_sae_path)\n", "\n", "saes_dict = {\n", - " \"blocks.0.hook_resid_pre\": layer_0_sae, # values can be an SAE object\n", - " \"blocks.1.hook_resid_pre\": layer_1_sae_path, # or a path to a saved SAE\n", + " \"blocks.0.hook_resid_pre\": layer_0_sae, # values can be an SAE object\n", + " \"blocks.1.hook_resid_pre\": layer_1_sae_path, # or a path to a saved SAE\n", "}\n", "\n", "upload_saes_to_huggingface(\n", diff --git a/tutorials/using_an_sae_as_a_steering_vector.ipynb b/tutorials/using_an_sae_as_a_steering_vector.ipynb index d7751813..b3a04603 100644 --- a/tutorials/using_an_sae_as_a_steering_vector.ipynb +++ b/tutorials/using_an_sae_as_a_steering_vector.ipynb @@ -1,2171 +1,2184 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "GoXn14ltnGh3" - }, - "source": [ - "# Using an SAE as a steering vector\n", - "\n", - "This notebook demonstrates how to use SAE lens to identify a feature on a pretrained model, and then construct a steering vector to affect the models output to various prompts. This notebook will also make use of Neuronpedia for identifying features of interest.\n", - "\n", - "The steps below include:\n", - "\n", - "\n", - "\n", - "* Installing relevant packages (Colab or locally)\n", - "* Load your SAE and the model it used\n", - "* Determining your feature of interest and its index\n", - "* Implementing your steering vector\n", - "\n", - "\n", - "\n" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "GoXn14ltnGh3" + }, + "source": [ + "# Using an SAE as a steering vector\n", + "\n", + "This notebook demonstrates how to use SAE lens to identify a feature on a pretrained model, and then construct a steering vector to affect the models output to various prompts. 
This notebook will also make use of Neuronpedia for identifying features of interest.\n", + "\n", + "The steps below include:\n", + "\n", + "\n", + "\n", + "* Installing relevant packages (Colab or locally)\n", + "* Load your SAE and the model it used\n", + "* Determining your feature of interest and its index\n", + "* Implementing your steering vector\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gf3lJYPEXh0v" + }, + "source": [ + "## Setting up packages and notebook" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l9k5iGyOXtuN" + }, + "source": [ + "### Import and installs" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fapxk8MDrs6R" + }, + "source": [ + "#### Environment Setup\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, - { - "cell_type": "markdown", - "metadata": { - "id": "gf3lJYPEXh0v" - }, - "source": [ - "## Setting up packages and notebook" - ] + "collapsed": true, + "id": "0TwNmRkRUgR7", + "outputId": "ffeb827a-9af2-4b09-b8dd-78e0d594ddf6" + }, + "outputs": [], + "source": [ + "try:\n", + " # for google colab users\n", + " import google.colab # type: ignore\n", + " from google.colab import output\n", + "\n", + " COLAB = True\n", + " %pip install sae-lens transformer-lens\n", + "except:\n", + " # for local setup\n", + " COLAB = False\n", + " from IPython import get_ipython # type: ignore\n", + "\n", + " ipython = get_ipython()\n", + " assert ipython is not None\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")\n", + "\n", + "# Imports for displaying vis in Colab / notebook\n", + "import webbrowser\n", + "import http.server\n", + "import socketserver\n", + "import threading\n", + "\n", + "PORT = 8000\n", + "\n", + "# general imports\n", + "import os\n", + "import torch\n", + "from tqdm import tqdm\n", + "import 
plotly.express as px\n", + "\n", + "torch.set_grad_enabled(False);" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "NGgIu1ZVYDub" + }, + "outputs": [], + "source": [ + "def display_vis_inline(filename: str, height: int = 850):\n", + " \"\"\"\n", + " Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each\n", + " vis has a unique port without having to define a port within the function.\n", + " \"\"\"\n", + " if not (COLAB):\n", + " webbrowser.open(filename)\n", + "\n", + " else:\n", + " global PORT\n", + "\n", + " def serve(directory):\n", + " os.chdir(directory)\n", + "\n", + " # Create a handler for serving files\n", + " handler = http.server.SimpleHTTPRequestHandler\n", + "\n", + " # Create a socket server with the handler\n", + " with socketserver.TCPServer((\"\", PORT), handler) as httpd:\n", + " print(f\"Serving files from {directory} on port {PORT}\")\n", + " httpd.serve_forever()\n", + "\n", + " thread = threading.Thread(target=serve, args=(\"/content\",))\n", + " thread.start()\n", + "\n", + " output.serve_kernel_port_as_iframe(\n", + " PORT, path=f\"/{filename}\", height=height, cache_in_notebook=True\n", + " )\n", + "\n", + " PORT += 1" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CmaPYLpGrxbo" + }, + "source": [ + "#### General Installs and device setup" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "tdUm9rZKr1Qb", + "outputId": "9b73b762-1356-437b-8925-91c514093b43" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "l9k5iGyOXtuN" - }, - "source": [ - "### Import and installs" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Device: mps\n" + ] + } + ], + "source": [ + "# package import\n", + "from torch import Tensor\n", + "from transformer_lens import utils\n", + "from functools import partial\n", + "from jaxtyping import 
Int, Float\n", + "\n", + "# device setup\n", + "if torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "else:\n", + " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "print(f\"Device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lsB0qORUaXiK" + }, + "source": [ + "### Load your model and SAE\n", + "\n", + "We're going to work with a pretrained GPT2-small model, and the RES-JB SAE set which is for the residual stream." + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "collapsed": true, + "id": "bCvNtm1OOhlR", + "outputId": "e6fd27ab-ee94-46ec-a07e-ee48c8f30da3" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "fapxk8MDrs6R" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8607cfc3f17548078c7b3ff7ebcca055", + "version_major": 2, + "version_minor": 0 }, - "source": [ - "#### Environment Setup\n" + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00\", \"J\", and \"edi\".\n", + "\n", + "Our feature activation indexes at sv_feature_acts[2] - for \"edi\" - are of most interest to us.\n", + "\n", + "Because we are using pretrained saes that have published feature maps, you can search on Neuronpedia for a feature of interest." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gFv4iBHFcOmE" + }, + "source": [ + "### Steps for Neuronpedia use\n", + "\n", + "Use the interface to search for a specific concept or item and determine which layer and at what index it is.\n", + "\n", + "1. Open the [Neuronpedia](https://www.neuronpedia.org/) homepage.\n", + "2. Using the \"Models\" dropdown, select your model. Here we are using GPT2-SM (GPT2-small).\n", + "3. The next page will have a search bar, which allows you to enter your index of interest. We're interested in the \"RES-JB\" SAE set, make sure to select it.\n", + "4. 
We found these indices in the previous step: [ 7650, 718, 22372]. Select them in the search to see the feature dashboard for each.\n", + "5. As we'll see, some of the indices may relate to features you don't care about.\n", + "\n", + "From using Neuronpedia, I have determined that my feature of interest is in layer 2, at index 7650: [here](https://www.neuronpedia.org/gpt2-small/2-res-jb/7650) is the feature." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KX0rXziniH9O" + }, + "source": [ + "### Note: 2nd Option - Starting with Neuronpedia\n", + "\n", + "Another option here is that you can start with Neuronpedia to identify features of interest. By using your prompt in the interface you can explore which features were involved and search across all the layers. This allows you to first determine your layer and index of interest in Neuronpedia before focusing them in your code. Start [here](https://www.neuronpedia.org/search) if you want to begin with search." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YACtNFzGcNua" + }, + "source": [ + "## Implement your steering vector and affect the output" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pO8hjg8j5bb-" + }, + "source": [ + "### Define values for your steering vector\n", + "To create our steering vector, we now need to get the decoder weights from our sparse autoencoder found at our index of interest.\n", + "\n", + "Then to use our steering vector, we want a prompt for text generation, as well as a scaling factor coefficent to apply with the steering vector\n", + "\n", + "We also set common sampling kwargs - temperature, top_p and freq_penalty" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "id": "rgYEWGV0t0L2" + }, + "outputs": [], + "source": [ + "steering_vector = sae.W_dec[10200]\n", + "\n", + "example_prompt = \"What is the most iconic structure known to man?\"\n", + "coeff = 300\n", + "sampling_kwargs = 
dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cexaoBR65lIa" + }, + "source": [ + "### Set up hook functions\n", + "\n", + "Finally, we need to create a hook that allows us to apply the steering vector when our model runs generate() on our defined prompt. We have also added a boolean value 'steering_on' that allows us to easily toggle the steering vector on and off for each prompt\n" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "collapsed": true, + "id": "3kcVWeJoIAlC" + }, + "outputs": [], + "source": [ + "def steering_hook(resid_pre, hook):\n", + " if resid_pre.shape[1] == 1:\n", + " return\n", + "\n", + " position = sae_out.shape[1]\n", + " if steering_on:\n", + " # using our steering vector and applying the coefficient\n", + " resid_pre[:, : position - 1, :] += coeff * steering_vector\n", + "\n", + "\n", + "def hooked_generate(prompt_batch, fwd_hooks=[], seed=None, **kwargs):\n", + " if seed is not None:\n", + " torch.manual_seed(seed)\n", + "\n", + " with model.hooks(fwd_hooks=fwd_hooks):\n", + " tokenized = model.to_tokens(prompt_batch)\n", + " result = model.generate(\n", + " stop_at_eos=False, # avoids a bug on MPS\n", + " input=tokenized,\n", + " max_new_tokens=50,\n", + " do_sample=True,\n", + " **kwargs\n", + " )\n", + " return result" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": { + "id": "VcuRkX0yA2WH" + }, + "outputs": [], + "source": [ + "def run_generate(example_prompt):\n", + " model.reset_hooks()\n", + " editing_hooks = [(f\"blocks.{layer}.hook_resid_post\", steering_hook)]\n", + " res = hooked_generate(\n", + " [example_prompt] * 3, editing_hooks, seed=None, **sampling_kwargs\n", + " )\n", + "\n", + " # Print results, removing the ugly beginning of sequence token\n", + " res_str = model.to_string(res[:, 1:])\n", + " print((\"\\n\\n\" + \"-\" * 80 + \"\\n\\n\").join(res_str))" + ] + }, + { + "cell_type": "markdown", 
+ "metadata": { + "id": "XYx--hIn61VQ" + }, + "source": [ + "### Generate text influenced by steering vector\n", + "\n", + "You may want to experiment with the scaling factor coefficient value that you set and see how it affects the generated output." + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 337, + "referenced_widgets": [ + "9f555c5ada38495eb4281cbb49169abe", + "79b59cbde9444bf892931d31afec7f2a", + "a157870318114d459a33d795850967ef", + "635162e10abc441797d4e5b74713bf44", + "720b4d010c364e3fbf72a53b267e8db9", + "d9c33fbfb3164cbbb7b9a4cd172d20ae", + "df53331cce124bd1ada5aa9e9a977015", + "229dad8e29f04c279c5603286e2c0643", + "83d947fc3338491ab4155b87c443884c", + "5e9700580d6b4ad0bfac34bf3b3919fc", + "a2c30462ef8d41fd9158f194a746d5a7" + ] }, + "id": "hN_YOzBE6lz8", + "outputId": "e263b8ff-86ce-439e-81e5-bbecb0d7e187" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 40, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "collapsed": true, - "id": "bCvNtm1OOhlR", - "outputId": "e6fd27ab-ee94-46ec-a07e-ee48c8f30da3" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "634ddfad68cb49208e63733402859842", + "version_major": 2, + "version_minor": 0 }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "8607cfc3f17548078c7b3ff7ebcca055", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00\", \"J\", and \"edi\".\n", - "\n", - "Our feature activation indexes at sv_feature_acts[2] - for \"edi\" - are of most interest to us.\n", - "\n", - "Because we are using pretrained saes that have published feature maps, you can search on Neuronpedia for a feature of interest." 
- ] + "229dad8e29f04c279c5603286e2c0643": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } }, - { - "cell_type": "markdown", - "metadata": { - "id": "gFv4iBHFcOmE" - }, - "source": [ - "### Steps for Neuronpedia use\n", - "\n", - "Use the interface to search for a specific concept or item and determine which layer and at what index it is.\n", - "\n", - "1. Open the [Neuronpedia](https://www.neuronpedia.org/) homepage.\n", - "2. Using the \"Models\" dropdown, select your model. Here we are using GPT2-SM (GPT2-small).\n", - "3. The next page will have a search bar, which allows you to enter your index of interest. We're interested in the \"RES-JB\" SAE set, make sure to select it.\n", - "4. We found these indices in the previous step: [ 7650, 718, 22372]. Select them in the search to see the feature dashboard for each.\n", - "5. 
As we'll see, some of the indices may relate to features you don't care about.\n", - "\n", - "From using Neuronpedia, I have determined that my feature of interest is in layer 2, at index 7650: [here](https://www.neuronpedia.org/gpt2-small/2-res-jb/7650) is the feature." - ] + "25ebd285de2e49c483c3b22b5c8364c0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } }, - { - "cell_type": "markdown", - "metadata": { - "id": "KX0rXziniH9O" - }, - "source": [ - "### Note: 2nd Option - Starting with Neuronpedia\n", - "\n", - "Another option here is that you can start with Neuronpedia to identify features of interest. By using your prompt in the interface you can explore which features were involved and search across all the layers. 
This allows you to first determine your layer and index of interest in Neuronpedia before focusing them in your code. Start [here](https://www.neuronpedia.org/search) if you want to begin with search." - ] + "359ef2b8a4ac4a9c9a91edc4a2dd1326": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } }, - { - "cell_type": "markdown", - "metadata": { - "id": "YACtNFzGcNua" - }, - "source": [ - "## Implement your steering vector and affect the output" - ] + "38341454dd6b4e9ca2fe5b85d2e371e1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": 
"1.2.0", + "_view_name": "StyleView", + "description_width": "" + } }, - { - "cell_type": "markdown", - "metadata": { - "id": "pO8hjg8j5bb-" - }, - "source": [ - "### Define values for your steering vector\n", - "To create our steering vector, we now need to get the decoder weights from our sparse autoencoder found at our index of interest.\n", - "\n", - "Then to use our steering vector, we want a prompt for text generation, as well as a scaling factor coefficent to apply with the steering vector\n", - "\n", - "We also set common sampling kwargs - temperature, top_p and freq_penalty" - ] + "3b74befc8d70471697ce6686ab4ac5c3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": { - "id": "rgYEWGV0t0L2" - }, - "outputs": [], - "source": [ - "steering_vector = 
sae.W_dec[10200]\n", - "\n", - "example_prompt = \"What is the most iconic structure known to man?\"\n", - "coeff = 300\n", - "sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)" - ] + "3d3584d1feec459287ffa24c4ef790c3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a30c82833f55441995744300c2ef538d", + "max": 50, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_4932983d4f1a4199b3d24c730c765a24", + "value": 50 + } }, - { - "cell_type": "markdown", - "metadata": { - "id": "cexaoBR65lIa" - }, - "source": [ - "### Set up hook functions\n", - "\n", - "Finally, we need to create a hook that allows us to apply the steering vector when our model runs generate() on our defined prompt. 
We have also added a boolean value 'steering_on' that allows us to easily toggle the steering vector on and off for each prompt\n" - ] + "3f5f9cad86e24dd489146215c3a208c9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } }, - { - "cell_type": "code", - "execution_count": 47, - "metadata": { - "collapsed": true, - "id": "3kcVWeJoIAlC" - }, - "outputs": [], - "source": [ - "def steering_hook(resid_pre, hook):\n", - " if resid_pre.shape[1] == 1:\n", - " return\n", - "\n", - " position = sae_out.shape[1]\n", - " if steering_on:\n", - " # using our steering vector and applying the coefficient\n", - " resid_pre[:, :position - 1, :] += coeff * steering_vector\n", - "\n", - "\n", - "def hooked_generate(prompt_batch, fwd_hooks=[], seed=None, **kwargs):\n", - " if seed is not None:\n", - " torch.manual_seed(seed)\n", - "\n", - " with model.hooks(fwd_hooks=fwd_hooks):\n", - " tokenized = model.to_tokens(prompt_batch)\n", - " result = model.generate(\n", - " stop_at_eos=False, # avoids a bug on MPS\n", - " input=tokenized,\n", - " max_new_tokens=50,\n", - " do_sample=True,\n", - " **kwargs)\n", - " return result\n" - ] + "3fdf0c5e62f24f30b02bcdc37b17c2e7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + 
"bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } }, - { - "cell_type": "code", - "execution_count": 48, - "metadata": { - "id": "VcuRkX0yA2WH" - }, - "outputs": [], - "source": [ - "def run_generate(example_prompt):\n", - " model.reset_hooks()\n", - " editing_hooks = [(f\"blocks.{layer}.hook_resid_post\", steering_hook)]\n", - " res = hooked_generate([example_prompt] * 3, editing_hooks, seed=None, **sampling_kwargs)\n", - "\n", - " # Print results, removing the ugly beginning of sequence token\n", - " res_str = model.to_string(res[:, 1:])\n", - " print((\"\\n\\n\" + \"-\" * 80 + \"\\n\\n\").join(res_str))" - ] + "4024c181581c485abd3181586afc2574": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3fdf0c5e62f24f30b02bcdc37b17c2e7", + "max": 50, + "min": 0, + "orientation": "horizontal", + "style": 
"IPY_MODEL_07c0dd1a8de149408b981a8892f6e46d", + "value": 50 + } }, - { - "cell_type": "markdown", - "metadata": { - "id": "XYx--hIn61VQ" - }, - "source": [ - "### Generate text influenced by steering vector\n", - "\n", - "You may want to experiment with the scaling factor coefficient value that you set and see how it affects the generated output." - ] + "4932983d4f1a4199b3d24c730c765a24": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } }, - { - "cell_type": "code", - "execution_count": 49, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 337, - "referenced_widgets": [ - "9f555c5ada38495eb4281cbb49169abe", - "79b59cbde9444bf892931d31afec7f2a", - "a157870318114d459a33d795850967ef", - "635162e10abc441797d4e5b74713bf44", - "720b4d010c364e3fbf72a53b267e8db9", - "d9c33fbfb3164cbbb7b9a4cd172d20ae", - "df53331cce124bd1ada5aa9e9a977015", - "229dad8e29f04c279c5603286e2c0643", - "83d947fc3338491ab4155b87c443884c", - "5e9700580d6b4ad0bfac34bf3b3919fc", - "a2c30462ef8d41fd9158f194a746d5a7" - ] - }, - "id": "hN_YOzBE6lz8", - "outputId": "e263b8ff-86ce-439e-81e5-bbecb0d7e187" - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "634ddfad68cb49208e63733402859842", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/50 [00:00