Skip to content

Commit

Permalink
chore: adds black-jupyter to dependencies (#318)
Browse files Browse the repository at this point in the history
* adds black-jupyter to dependencies

* formats notebooks
  • Loading branch information
anthonyduong9 authored Oct 7, 2024
1 parent 3ff9c85 commit ed5d791
Show file tree
Hide file tree
Showing 14 changed files with 5,725 additions and 5,396 deletions.
36 changes: 21 additions & 15 deletions check_open_ai_sae_metrics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
"outputs": [],
"source": [
"from sae_lens.toolkit.pretrained_saes import load_sparsity\n",
"import plotly.express as px \n",
"import plotly.express as px\n",
"\n",
"path = \"open_ai_sae_weights_resid_post_attn_reformatted/v5_32k_layer_0\"\n",
"\n",
"\n",
"sparsity = load_sparsity(path)#[\"sparsity\"]\n",
"sparsity = load_sparsity(path) # [\"sparsity\"]\n",
"\n",
"\n",
"px.histogram(sparsity.cpu(), nbins=100).write_html(\"sparsity_histogram.html\")"
Expand Down Expand Up @@ -4226,57 +4226,63 @@
}
],
"source": [
"import pandas as pd \n",
"import pandas as pd\n",
"\n",
"# get all json files in all subfolders of the mother path\n",
"import os\n",
"import json\n",
"from IPython.display import display\n",
"import imgkit\n",
"\n",
"\n",
"def get_all_json_files(mother_path):\n",
" json_files = []\n",
" \n",
"\n",
" for root, dirs, files in os.walk(mother_path):\n",
" for file in files:\n",
" if file.endswith(\"metrics.json\"):\n",
" json_files.append(os.path.join(root, file))\n",
" return json_files\n",
"\n",
"def get_benchmark_stats_csv(mother_path = \"open_ai_sae_weights_resid_post_attn_reformatted\"):\n",
"\n",
"def get_benchmark_stats_csv(\n",
" mother_path=\"open_ai_sae_weights_resid_post_attn_reformatted\",\n",
"):\n",
" json_files = get_all_json_files(mother_path)\n",
" eval_metrics = {}\n",
"\n",
" for file in json_files:\n",
" with open(file, \"r\") as f:\n",
" data = json.load(f)\n",
" eval_metrics[file] = data\n",
" \n",
" \n",
"\n",
" df = pd.DataFrame(eval_metrics).T\n",
" df[\"filepath\"]=df.index\n",
" df[\"filepath\"] = df.index\n",
" df.head()\n",
" pattern = r\".*/v(\\d+)_(\\d+)k_layer_(\\d+)/metrics\\.json\"\n",
"\n",
" df[['version', 'd_sae', 'layer']] = df.filepath.str.extract(pattern)\n",
" df[[\"version\", \"d_sae\", \"layer\"]] = df.filepath.str.extract(pattern)\n",
" # move these columns to the start\n",
" cols = df.columns.tolist()\n",
" cols = cols[-3:] + cols[:-3]\n",
" df = df[cols]\n",
" df[\"layer\"] = df[\"layer\"].astype(int)\n",
" \n",
"\n",
" # remove \"metrics\" prefix from the columns\n",
" df.columns = [i.replace(\"metrics/\", \"\") for i in df.columns]\n",
" df.sort_values(by=[\"version\", \"d_sae\", \"layer\"], inplace=True)\n",
" df.to_csv(os.path.join(mother_path, \"benchmark_stats.csv\"))\n",
" df.style.background_gradient(cmap='viridis', axis=0).to_html(os.path.join(mother_path, \"benchmark_stats.html\"))\n",
" \n",
" df.style.background_gradient(cmap=\"viridis\", axis=0).to_html(\n",
" os.path.join(mother_path, \"benchmark_stats.html\")\n",
" )\n",
"\n",
" # read the html\n",
" with open(os.path.join(mother_path, \"benchmark_stats.html\"), \"r\") as f:\n",
" html = f.read()\n",
" imgkit.from_string(html, os.path.join(mother_path, \"benchmark_stats.png\"))\n",
" \n",
" return df.style.background_gradient(cmap='viridis', axis=0)\n",
" \n",
"\n",
" return df.style.background_gradient(cmap=\"viridis\", axis=0)\n",
"\n",
"\n",
"# list all paths that start with OAI in the current folder\n",
"paths = [i for i in os.listdir(\".\") if i.startswith(\"OAI\")]\n",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ typing-extensions = "^4.10.0"


[tool.poetry.group.dev.dependencies]
black = "24.4.0"
black = { version = "24.4.0", extras = ["jupyter"] }
pytest = "^8.0.2"
pytest-cov = "^4.1.0"
pre-commit = "^3.6.2"
Expand Down
128 changes: 68 additions & 60 deletions scripts/joseph_curt_pairing_gemma_scope_saes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"\n"
]
"source": []
},
{
"cell_type": "code",
Expand All @@ -27,57 +25,60 @@
}
],
"source": [
"import re \n",
"import re\n",
"import pandas as pd\n",
"\n",
"\n",
"from huggingface_hub import HfApi\n",
"import os\n",
"\n",
"\n",
"def list_repo_files(repo_id):\n",
" api = HfApi()\n",
" repo_files = api.list_repo_files(repo_id)\n",
" return repo_files\n",
"\n",
"\n",
"files = list_repo_files(repo_id)\n",
"\n",
"# print(f\"Files in the repository '{repo_id}':\")\n",
"# for file in files:\n",
"# print(file)\n",
"\n",
"\n",
"def get_details_from_file_path(file_path):\n",
" '''\n",
" \"\"\"\n",
" eg: layer_11/width_16k/average_l0_79\n",
" \n",
"\n",
" layer = 11\n",
" width = 16k\n",
" l0_or_canonical = \"79\"\n",
" \n",
"\n",
" or if layer_11/width_16k/canonical\n",
" \n",
"\n",
" layer = 11\n",
" width = 16k\n",
" l0_or_canonical = \"canonical\"\n",
" \n",
"\n",
" or if layer_11/width_1m/average_l0_79\n",
" \n",
"\n",
" layer = 11\n",
" width = 1m\n",
" l0_or_canonical = \"79\"\n",
" '''\n",
" \n",
" layer = re.search(r'layer_(\\d+)', file_path).group(1)\n",
" width = re.search(r'width_(\\d+[k|m])', file_path).group(1)\n",
" l0 = re.search(r'average_l0_(\\d+)', file_path)\n",
" \"\"\"\n",
"\n",
" layer = re.search(r\"layer_(\\d+)\", file_path).group(1)\n",
" width = re.search(r\"width_(\\d+[k|m])\", file_path).group(1)\n",
" l0 = re.search(r\"average_l0_(\\d+)\", file_path)\n",
" if l0:\n",
" l0 = l0.group(1)\n",
" else:\n",
" l0 = re.search(r'(canonical)', file_path).group(1)\n",
" \n",
" \n",
" l0 = re.search(r\"(canonical)\", file_path).group(1)\n",
"\n",
" return layer, width, l0\n",
"\n",
"# # test it \n",
"\n",
"# # test it\n",
"# file_path = 'layer_11/width_16k/average_l0_79'\n",
"# layer, width, l0 = get_details_from_file_path(file_path)\n",
"# print(f\"layer: {layer}, width: {width}, l0: {l0}\")\n",
Expand All @@ -93,28 +94,28 @@
"# print(f\"layer: {layer}, width: {width}, l0: {l0}\")\n",
"\n",
"\n",
"\n",
"def generate_entries(repo_id):\n",
" entries = []\n",
" files = list_repo_files(repo_id)\n",
" for file in files:\n",
" if 'params.npz' in file:\n",
" if \"params.npz\" in file:\n",
" entry = {}\n",
" # print(file)\n",
" layer, width, l0 = get_details_from_file_path(file)\n",
" folder_path = os.path.dirname(file)\n",
" entry['repo_id'] = repo_id\n",
" entry['id'] = folder_path\n",
" entry['path'] = folder_path\n",
" entry['l0'] = l0\n",
" entry['layer'] = layer\n",
" entry['width'] = width\n",
" \n",
" entry[\"repo_id\"] = repo_id\n",
" entry[\"id\"] = folder_path\n",
" entry[\"path\"] = folder_path\n",
" entry[\"l0\"] = l0\n",
" entry[\"layer\"] = layer\n",
" entry[\"width\"] = width\n",
"\n",
" entries.append(entry)\n",
" return entries\n",
"\n",
"def df_to_yaml(df, file_path, canonical = False):\n",
" '''\n",
"\n",
"def df_to_yaml(df, file_path, canonical=False):\n",
" \"\"\"\n",
" EXAMPLE STRUCTURE:\n",
"\n",
" gemma-scope-2b-pt-res:\n",
Expand All @@ -126,11 +127,13 @@
" path: layer_11/width_16k/average_l0_79\n",
" l0: 79.0\n",
"\n",
" '''\n",
" repo_id = df.iloc[0]['repo_id']\n",
" release_id = repo_id.split('/')[1] + '-canonical' if canonical else repo_id.split('/')[1]\n",
" with open(file_path, 'w') as f:\n",
" \n",
" \"\"\"\n",
" repo_id = df.iloc[0][\"repo_id\"]\n",
" release_id = (\n",
" repo_id.split(\"/\")[1] + \"-canonical\" if canonical else repo_id.split(\"/\")[1]\n",
" )\n",
" with open(file_path, \"w\") as f:\n",
"\n",
" f.write(f\"{release_id}:\\n\")\n",
" f.write(f\" repo_id: {repo_id}\\n\")\n",
" f.write(f\" model: gemma-2-2b\\n\")\n",
Expand All @@ -139,7 +142,7 @@
" for index, row in df.iterrows():\n",
" f.write(f\" - id: {row['id']}\\n\")\n",
" f.write(f\" path: {row['path']}\\n\")\n",
" if row['l0'] != 'canonical':\n",
" if row[\"l0\"] != \"canonical\":\n",
" f.write(f\" l0: {row['l0']}\\n\")\n",
" # f.write(f\" l0: {row['l0']}\\n\")\n",
" # f.write(f\" layer: {row['layer']}\\n\")\n",
Expand All @@ -154,29 +157,34 @@
" \"google/gemma-scope-9b-pt-res\",\n",
" \"google/gemma-scope-9b-pt-mlp\",\n",
" \"google/gemma-scope-9b-pt-att\",\n",
" \"google/gemma-scope-27b-pt-res\"\n",
" \"google/gemma-scope-27b-pt-res\",\n",
"]\n",
"\n",
"for repo_id in repo_ids:\n",
"\n",
" entries = generate_entries(repo_id)\n",
"\n",
" df = pd.DataFrame(entries)\n",
" df[\"layer\"]= pd.to_numeric(df[\"layer\"])\n",
" df.sort_values(by=['width', 'layer', 'l0'], inplace=True)\n",
" df[\"layer\"] = pd.to_numeric(df[\"layer\"])\n",
" df.sort_values(by=[\"width\", \"layer\", \"l0\"], inplace=True)\n",
" df.head(30)\n",
"\n",
" canonical_only_df = df[df['l0'] == 'canonical']\n",
" non_canonical_df = df[df['l0'] != 'canonical']\n",
" canonical_only_df = df[df[\"l0\"] == \"canonical\"]\n",
" non_canonical_df = df[df[\"l0\"] != \"canonical\"]\n",
"\n",
" df_to_yaml(non_canonical_df, f'{repo_id.split(\"/\")[1]}_not_canonical.yaml', canonical=False)\n",
" df_to_yaml(\n",
" non_canonical_df, f'{repo_id.split(\"/\")[1]}_not_canonical.yaml', canonical=False\n",
" )\n",
" if canonical_only_df.shape[0] == 0:\n",
" print(f\"No canonical entries found in {repo_id.split('/')[1]}\")\n",
" continue\n",
" else:\n",
" df_to_yaml(canonical_only_df, f'{repo_id.split(\"/\")[1]}_canonical_only.yaml', canonical=True) \n",
" df_to_yaml(\n",
" canonical_only_df,\n",
" f'{repo_id.split(\"/\")[1]}_canonical_only.yaml',\n",
" canonical=True,\n",
" )\n",
"\n",
" \n",
" # !cat canonical_only.yaml"
]
},
Expand Down Expand Up @@ -229,36 +237,36 @@
"\n",
"\n",
"# Path to the YAML file\n",
"yaml_file = 'pretrained_saes.yaml'\n",
"yaml_file = \"pretrained_saes.yaml\"\n",
"\n",
"# Initialize ruamel.yaml\n",
"yaml = YAML()\n",
"yaml.preserve_quotes = True\n",
"yaml.indent(mapping=2, sequence=4, offset=2)\n",
"\n",
"# Read the existing YAML file\n",
"with open(yaml_file, 'r') as file:\n",
"with open(yaml_file, \"r\") as file:\n",
" data = yaml.load(file)\n",
"\n",
"# Generate new entries\n",
"new_entries = generate_entries(local_dir)\n",
"\n",
"# Create a CommentedMap for gemmascope-2b-pt-res\n",
"gemmascope_data = CommentedMap()\n",
"gemmascope_data['repo_id'] = \"gg-hf/gemmascope-2b-pt-res\"\n",
"gemmascope_data['model'] = \"gemma-2-2b\"\n",
"gemmascope_data['conversion_func'] = \"gemma_2\"\n",
"gemmascope_data['saes'] = new_entries\n",
"gemmascope_data[\"repo_id\"] = \"gg-hf/gemmascope-2b-pt-res\"\n",
"gemmascope_data[\"model\"] = \"gemma-2-2b\"\n",
"gemmascope_data[\"conversion_func\"] = \"gemma_2\"\n",
"gemmascope_data[\"saes\"] = new_entries\n",
"\n",
"# Remove the existing gemmascope-2b-pt-res entry if it exists\n",
"if 'SAE_LOOKUP' in data and 'gemmascope-2b-pt-res' in data['SAE_LOOKUP']:\n",
" del data['SAE_LOOKUP']['gemmascope-2b-pt-res']\n",
"if \"SAE_LOOKUP\" in data and \"gemmascope-2b-pt-res\" in data[\"SAE_LOOKUP\"]:\n",
" del data[\"SAE_LOOKUP\"][\"gemmascope-2b-pt-res\"]\n",
"\n",
"# Add gemmascope-2b-pt-res at the end\n",
"data['SAE_LOOKUP']['gemmascope-2b-pt-res'] = gemmascope_data\n",
"data[\"SAE_LOOKUP\"][\"gemmascope-2b-pt-res\"] = gemmascope_data\n",
"\n",
"# Write the updated YAML file\n",
"with open(yaml_file, 'w') as file:\n",
"with open(yaml_file, \"w\") as file:\n",
" yaml.dump(data, file)\n",
"\n",
"print(f\"YAML file updated: {yaml_file}\")"
Expand Down Expand Up @@ -293,19 +301,19 @@
}
],
"source": [
"from sae_lens import HookedSAETransformer, SAE \n",
"from sae_lens import HookedSAETransformer, SAE\n",
"\n",
"device = \"cuda\"\n",
"\n",
"model = HookedSAETransformer.from_pretrained(\"gemma-2-2b\", device = device)\n",
"model = HookedSAETransformer.from_pretrained(\"gemma-2-2b\", device=device)\n",
"\n",
"# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)\n",
"# Note that this is not the same as the SAE's config dict; rather, it is whatever was in the HF repo, from which we can extract the SAE config dict\n",
"# We also return the feature sparsities which are stored in HF for convenience. \n",
"# We also return the feature sparsities which are stored in HF for convenience.\n",
"sae, cfg_dict, sparsity = SAE.from_pretrained(\n",
" release = \"gemma-scope-9b-pt-mlp\", # <- Release name \n",
" sae_id = \"layer_2/width_131k/average_l0_12\", # <- SAE id (not always a hook point!)\n",
" device = device\n",
" release=\"gemma-scope-9b-pt-mlp\", # <- Release name\n",
" sae_id=\"layer_2/width_131k/average_l0_12\", # <- SAE id (not always a hook point!)\n",
" device=device,\n",
")"
]
},
Expand Down
Loading

0 comments on commit ed5d791

Please sign in to comment.