From bb8301f3af3ce93c960e0d58028e6a8f0b4fe45d Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Tue, 16 Apr 2024 01:12:16 +0200 Subject: [PATCH 01/42] revised demo testing to check all demos --- makefile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/makefile b/makefile index b786aa209..643ad1a25 100644 --- a/makefile +++ b/makefile @@ -18,8 +18,7 @@ docstring-test: poetry run pytest transformer_lens/ notebook-test: - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Exploratory_Analysis_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/ test: make unit-test From 4fba5587229174db7a584164a3d6eacb527d6ed1 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 24 Apr 2024 01:25:26 +0200 Subject: [PATCH 02/42] separated demos --- makefile | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/makefile b/makefile index 643ad1a25..2baefc24c 100644 --- a/makefile +++ b/makefile @@ -18,7 +18,14 @@ docstring-test: poetry run pytest transformer_lens/ notebook-test: - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/ + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Exploratory_Analysis_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/BERT.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Head_Detector_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/No_Position_Experiment.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Othello_GPT.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Activation_Patching_in_TL_Demo.ipynb test: make unit-test From e73f290ffc87c202c985f8b8acd3c2d901346926 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 24 Apr 2024 03:08:07 +0200 Subject: [PATCH 03/42] changed demo test order --- makefile | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/makefile b/makefile index 2baefc24c..42a0a22a2 100644 --- a/makefile +++ b/makefile @@ -18,14 +18,24 @@ docstring-test: poetry run pytest transformer_lens/ notebook-test: - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Exploratory_Analysis_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb + + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Activation_Patching_in_TL_Demo.ipynb + + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Attribution_Patching_Demo.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/BERT.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Head_Detector_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Interactive_Neuroscope.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/LLaMA.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/No_Position_Experiment.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg 
demos/Othello_GPT.ipynb - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Activation_Patching_in_TL_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Qwen.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Santa_Coder.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Stable_Lm.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/SVD_Interpreter_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Tracr_to_Transformer_Lens_Demo.ipynb test: make unit-test From 99ba5db8fec627ca76c48c6d78dd241d6b3f30ce Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 24 Apr 2024 03:18:05 +0200 Subject: [PATCH 04/42] rearranged test order --- makefile | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/makefile b/makefile index 42a0a22a2..32a11f52f 100644 --- a/makefile +++ b/makefile @@ -18,13 +18,10 @@ docstring-test: poetry run pytest transformer_lens/ notebook-test: + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Attribution_Patching_Demo.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Exploratory_Analysis_Demo.ipynb - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb - - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Activation_Patching_in_TL_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Attribution_Patching_Demo.ipynb - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/BERT.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Head_Detector_Demo.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Interactive_Neuroscope.ipynb @@ -37,6 +34,12 @@ notebook-test: poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/SVD_Interpreter_Demo.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Tracr_to_Transformer_Lens_Demo.ipynb + # Contains failing cells + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/BERT.ipynb + + # Causes CI to hang + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Activation_Patching_in_TL_Demo.ipynb + test: make unit-test make acceptance-test From 37dca9a823e67068dc706d853334fc62db1aef09 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Wed, 24 Apr 2024 03:24:43 +0200 Subject: [PATCH 05/42] updated attribution patching to run differnt code in github --- demos/Attribution_Patching_Demo.ipynb | 2 +- demos/Main_Demo.ipynb | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/demos/Attribution_Patching_Demo.ipynb b/demos/Attribution_Patching_Demo.ipynb index cef67eb8b..8d8796629 100644 --- a/demos/Attribution_Patching_Demo.ipynb +++ b/demos/Attribution_Patching_Demo.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"markdown","metadata":{},"source":["\n"," \"Open\n",""]},{"cell_type":"markdown","metadata":{},"source":[" # Attribution Patching Demo\n"," **Read [the accompanying blog post here](https://neelnanda.io/attribution-patching) for more context**\n"," This is an interim research report, giving a whirlwind tour of some unpublished work I did at Anthropic (credit to the then team - Chris Olah, Catherine Olsson, Nelson Elhage and Tristan Hume 
for help, support, and mentorship!)\n","\n"," The goal of this work is run activation patching at an industrial scale, by using gradient based attribution to approximate the technique - allow an arbitrary number of patches to be made on two forwards and a single backward pass\n","\n"," I have had less time than hoped to flesh out this investigation, but am writing up a rough investigation and comparison to standard activation patching on a few tasks to give a sense of the potential of this approach, and where it works vs falls down."]},{"cell_type":"markdown","metadata":{},"source":[" To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n","\n"," **Tips for reading this Colab:**\n"," * You can run all this code for yourself!\n"," * The graphs are interactive!\n"," * Use the table of contents pane in the sidebar to navigate\n"," * Collapse irrelevant sections with the dropdown arrows\n"," * Search the page using the search in the sidebar, not CTRL+F"]},{"cell_type":"markdown","metadata":{},"source":[" ## Setup (Ignore)"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Running as a Jupyter notebook - intended for development only!\n"]}],"source":["# Janky code to do different setup when run in a Colab notebook vs VSCode\n","DEBUG_MODE = False\n","try:\n"," import google.colab\n"," IN_COLAB = True\n"," print(\"Running as a Colab notebook\")\n"," %pip install transformer_lens==1.1.1\n"," %pip install torchtyping\n"," # Install my janky personal plotting utils\n"," %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n"," # Install another version of node that makes PySvelte work way faster\n"," !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n"," %pip install git+https://github.com/neelnanda-io/PySvelte.git\n"," # Needed for PySvelte to work, v3 came out and broke things...\n"," %pip install typeguard==2.13.3\n","except:\n"," IN_COLAB = False\n"," print(\"Running as a Jupyter notebook - intended for development only!\")\n"," from IPython import get_ipython\n","\n"," ipython = get_ipython()\n"," # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n"," ipython.magic(\"load_ext autoreload\")\n"," ipython.magic(\"autoreload 2\")"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":["# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n","import plotly.io as pio\n","\n","if IN_COLAB or not DEBUG_MODE:\n"," # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! 
Thus creating a debug mode.\n"," pio.renderers.default = \"colab\"\n","else:\n"," pio.renderers.default = \"notebook_connected\""]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import numpy as np\n","import einops\n","from fancy_einsum import einsum\n","import tqdm.notebook as tqdm\n","import random\n","from pathlib import Path\n","import plotly.express as px\n","from torch.utils.data import DataLoader\n","\n","from torchtyping import TensorType as TT\n","from typing import List, Union, Optional, Callable\n","from functools import partial\n","import copy\n","import itertools\n","import json\n","\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","from IPython.display import HTML, Markdown"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":["import pysvelte\n","\n","import transformer_lens\n","import transformer_lens.utils as utils\n","from transformer_lens.hook_points import (\n"," HookedRootModule,\n"," HookPoint,\n",") # Hooking utilities\n","from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache"]},{"cell_type":"markdown","metadata":{},"source":[" Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["from neel_plotly import line, imshow, scatter"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["import transformer_lens.patching as patching"]},{"cell_type":"markdown","metadata":{},"source":[" ## IOI Patching Setup\n"," This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-small into HookedTransformer\n"]}],"source":["model = HookedTransformer.from_pretrained(\"gpt2-small\")\n","model.set_use_attn_result(True)"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n","Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n","Answer token indices tensor([[ 5335, 1757],\n"," [ 1757, 5335],\n"," [ 4186, 3700],\n"," [ 3700, 4186],\n"," [ 6035, 15686],\n"," [15686, 6035],\n"," [ 5780, 14235],\n"," [14235, 5780]], device='cuda:0')\n"]}],"source":["prompts = ['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']\n","answers = [(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' 
Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]\n","\n","clean_tokens = model.to_tokens(prompts)\n","# Swap each adjacent pair, with a hacky list comprehension\n","corrupted_tokens = clean_tokens[\n"," [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]\n"," ]\n","print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n","print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n","\n","answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)\n","print(\"Answer token indices\", answer_token_indices)"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 3.5519\n","Corrupted logit diff: -3.5519\n"]}],"source":["def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n"," if len(logits.shape)==3:\n"," # Get final logits only\n"," logits = logits[:, -1, :]\n"," correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n"," incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n"," return (correct_logits - incorrect_logits).mean()\n","\n","clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n","corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n","\n","clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n","print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n","\n","corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n","print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Baseline is 1: 1.0000\n","Corrupted Baseline is 0: 0.0000\n"]}],"source":["CLEAN_BASELINE = clean_logit_diff\n","CORRUPTED_BASELINE = corrupted_logit_diff\n","def ioi_metric(logits, answer_token_indices=answer_token_indices):\n"," return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE - CORRUPTED_BASELINE)\n","\n","print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n","print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Patching\n"," In the following cells, we define attribution patching and use it in various ways on the model."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["Metric = Callable[[TT[\"batch_and_pos_dims\", \"d_model\"]], float]"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Value: 1.0\n","Clean Activations Cached: 220\n","Clean Gradients Cached: 220\n","Corrupted Value: 0.0\n","Corrupted Activations Cached: 220\n","Corrupted Gradients Cached: 220\n"]}],"source":["filter_not_qkv_input = lambda name: \"_input\" not in name\n","def get_cache_fwd_and_bwd(model, tokens, metric):\n"," model.reset_hooks()\n"," cache = {}\n"," def forward_cache_hook(act, hook):\n"," cache[hook.name] = act.detach()\n"," model.add_hook(filter_not_qkv_input, forward_cache_hook, \"fwd\")\n","\n"," grad_cache = {}\n"," def backward_cache_hook(act, hook):\n"," grad_cache[hook.name] = act.detach()\n"," model.add_hook(filter_not_qkv_input, backward_cache_hook, \"bwd\")\n","\n"," value = metric(model(tokens))\n"," value.backward()\n"," model.reset_hooks()\n"," return value.item(), ActivationCache(cache, 
model), ActivationCache(grad_cache, model)\n","\n","clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(model, clean_tokens, ioi_metric)\n","print(\"Clean Value:\", clean_value)\n","print(\"Clean Activations Cached:\", len(clean_cache))\n","print(\"Clean Gradients Cached:\", len(clean_grad_cache))\n","corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(model, corrupted_tokens, ioi_metric)\n","print(\"Corrupted Value:\", corrupted_value)\n","print(\"Corrupted Activations Cached:\", len(corrupted_cache))\n","print(\"Corrupted Gradients Cached:\", len(corrupted_grad_cache))"]},{"cell_type":"markdown","metadata":{},"source":[" ### Attention Attribution\n"," The easiest thing to start with is to not even engage with the corrupted tokens/patching, but to look at the attribution of the attention patterns - that is, the linear approximation to what happens if you set each element of the attention pattern to zero. This, as it turns out, is a good proxy to what is going on with each head!\n"," Note that this is *not* the same as what we will later do with patching. In particular, this does not set up a careful counterfactual! It's a good tool for what's generally going on in this problem, but does not control for eg stuff that systematically boosts John > Mary in general, stuff that says \"I should activate the IOI circuit\", etc. Though using logit diff as our metric *does*\n"," Each element of the batch is independent and the metric is an average logit diff, so we can analyse each batch element independently here. We'll look at the first one, and then at the average across the whole batch (note - 4 prompts have indirect object before subject, 4 prompts have it the other way round, making the average pattern harder to interpret - I plot it over the first sequence of tokens as a mildly misleading reference).\n"," We can compare it to the interpretability in the wild diagram, and basically instantly recover most of the circuit!"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[],"source":["def create_attention_attr(clean_cache, clean_grad_cache) -> TT[\"batch\", \"layer\", \"head_index\", \"dest\", \"src\"]:\n"," attention_stack = torch.stack([clean_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," attention_grad_stack = torch.stack([clean_grad_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," attention_attr = attention_grad_stack * attention_stack\n"," attention_attr = einops.rearrange(attention_attr, \"layer batch head_index dest src -> batch layer head_index dest src\")\n"," return attention_attr\n","\n","attention_attr = create_attention_attr(clean_cache, clean_grad_cache)"]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4']\n","['L0H0+', 'L0H0-', 'L0H1+', 'L0H1-', 'L0H2+']\n","['L0H0Q', 'L0H0K', 'L0H0V', 'L0H1Q', 'L0H1K']\n"]}],"source":["HEAD_NAMES = [f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]\n","HEAD_NAMES_SIGNED = [f\"{name}{sign}\" for name in HEAD_NAMES for sign in [\"+\", \"-\"]]\n","HEAD_NAMES_QKV = [f\"{name}{act_name}\" for name in HEAD_NAMES for act_name in [\"Q\", \"K\", \"V\"]]\n","print(HEAD_NAMES[:5])\n","print(HEAD_NAMES_SIGNED[:5])\n","print(HEAD_NAMES_QKV[:5])"]},{"cell_type":"markdown","metadata":{},"source":[" An extremely janky way to plot the attention attribution patterns. 
We scale them to be in [-1, 1], split each head into a positive and negative part (so all of it is in [0, 1]), and then plot the top 20 head-halves (a head can appear twice!) by the max value of the attribution pattern."]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"text/markdown":["### Attention Attribution for first sequence"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["### Summed Attention Attribution for all sequences"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\n"]}],"source":["def plot_attention_attr(attention_attr, tokens, top_k=20, index=0, title=\"\"):\n"," if len(tokens.shape)==2:\n"," tokens = tokens[index]\n"," if len(attention_attr.shape)==5:\n"," attention_attr = attention_attr[index]\n"," attention_attr_pos = attention_attr.clamp(min=-1e-5)\n"," attention_attr_neg = - attention_attr.clamp(max=1e-5)\n"," attention_attr_signed = torch.stack([attention_attr_pos, attention_attr_neg], dim=0)\n"," attention_attr_signed = einops.rearrange(attention_attr_signed, \"sign layer head_index dest src -> (layer head_index sign) dest src\")\n"," attention_attr_signed = attention_attr_signed / attention_attr_signed.max()\n"," attention_attr_indices = attention_attr_signed.max(-1).values.max(-1).values.argsort(descending=True)\n"," # print(attention_attr_indices.shape)\n"," # print(attention_attr_indices)\n"," attention_attr_signed = attention_attr_signed[attention_attr_indices, :, :]\n"," head_labels = [HEAD_NAMES_SIGNED[i.item()] for i in attention_attr_indices]\n","\n"," if title: display(Markdown(\"### \"+title))\n"," display(pysvelte.AttentionMulti(tokens=model.to_str_tokens(tokens), attention=attention_attr_signed.permute(1, 2, 0)[:, :, :top_k], head_labels=head_labels[:top_k]))\n","\n","plot_attention_attr(attention_attr, clean_tokens, index=0, title=\"Attention Attribution for first sequence\")\n","\n","plot_attention_attr(attention_attr.sum(0), clean_tokens[0], title=\"Summed Attention Attribution for all sequences\")\n","print(\"Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Attribution Patching\n"," In the following sections, I will implement various kinds of attribution patching, and then compare them to the activation patching patterns (activation patching code copied from [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo))\n"," ### Residual Stream Patching\n","
Note: We add up across both d_model and batch (explanation below).\n"," We add up along d_model because we're taking the dot product - the derivative *is* the linear map that locally linearly approximates the metric, and so we take the dot product of our change vector with the derivative vector. Equivalently, we look at the effect of changing each coordinate independently, and then combine them by adding it up - it's linear, so this totally works.\n"," We add up across batch because we're taking the average of the metric, so each individual batch element provides `1/batch_size` of the overall effect. Because each batch element is independent of the others and no information moves between activations for different inputs, the batched version is equivalent to doing attribution patching separately for each input, and then averaging - in this second version the metric per input is *not* divided by batch_size because we don't average.
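\n","\n"," As a worked sketch of the two sums above (ad-hoc notation, not code from this demo): write $\\Delta a = a_{\\text{clean}} - a_{\\text{corrupted}}$ for the change in an activation and $g = \\partial\\,\\mathrm{metric}/\\partial a$ for its gradient on the corrupted run. The attribution patch estimate is then $\\Delta\\,\\mathrm{metric} \\approx \\sum_{\\text{batch}} \\sum_{\\text{d\\_model}} g \\odot \\Delta a$ - schematically, `einops.reduce(corrupted_grad * (clean_act - corrupted_act), \"component batch pos d_model -> component pos\", \"sum\")`, which is the pattern `attr_patch_residual` below implements.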
"]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_residual(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," clean_residual, residual_labels = clean_cache.accumulated_resid(-1, incl_mid=True, return_labels=True)\n"," corrupted_residual = corrupted_cache.accumulated_resid(-1, incl_mid=True, return_labels=False)\n"," corrupted_grad_residual = corrupted_grad_cache.accumulated_resid(-1, incl_mid=True, return_labels=False)\n"," residual_attr = einops.reduce(\n"," corrupted_grad_residual * (clean_residual - corrupted_residual),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return residual_attr, residual_labels\n","\n","residual_attr, residual_labels = attr_patch_residual(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(residual_attr, y=residual_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Residual Attribution Patching\")\n","\n","# ### Layer Output Patching"]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_layer_out(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)\n"," corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)\n"," corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(-1, return_labels=False)\n"," layer_out_attr = einops.reduce(\n"," corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return layer_out_attr, labels\n","\n","layer_out_attr, layer_out_labels = attr_patch_layer_out(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(layer_out_attr, y=layer_out_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Layer Output Attribution Patching\")"]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_head_out(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_out = clean_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_head_out = corrupted_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_grad_head_out = corrupted_grad_cache.stack_head_results(-1, return_labels=False)\n"," head_out_attr = einops.reduce(\n"," corrupted_grad_head_out * (clean_head_out - corrupted_head_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\"\n"," )\n"," return head_out_attr, labels\n","\n","head_out_attr, head_out_labels = attr_patch_head_out(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(head_out_attr, y=head_out_labels, yaxis=\"Component\", xaxis=\"Position\", title=\"Head Output Attribution Patching\")\n","sum_head_out_attr = einops.reduce(head_out_attr, \"(layer head) pos -> layer head\", \"sum\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n","imshow(sum_head_out_attr, yaxis=\"Layer\", xaxis=\"Head Index\", title=\"Head Output Attribution Patching Sum Over Pos\")"]},{"cell_type":"markdown","metadata":{},"source":[" ### Head Activation Patching\n"," Intuitively, a head has three inputs, keys, queries and values. We can patch each of these individually to get a sense for where the important part of each head's input comes from!\n"," As a sanity check, we also do this for the mixed value. The result is a linear map of this (`z @ W_O == result`), so this is the same as patching the output of the head.\n"," We plot both the patch for each head over each position, and summed over position (it tends to be pretty sparse, so the latter is the same)"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[{"data":{"text/markdown":["#### Key Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Query Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Mixed Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","def stack_head_vector_from_cache(\n"," cache, \n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"]\n"," ) -> TT[\"layer_and_head_index\", \"batch\", \"pos\", \"d_head\"]:\n"," \"\"\"Stacks the head vectors from the cache from a specific activation (key, query, value or mixed_value (z)) into a single tensor.\"\"\"\n"," stacked_head_vectors = torch.stack([cache[activation_name, l] for l in range(model.cfg.n_layers)], dim=0)\n"," stacked_head_vectors = einops.rearrange(\n"," stacked_head_vectors,\n"," \"layer batch pos head_index d_head -> (layer head_index) batch pos d_head\"\n"," )\n"," return stacked_head_vectors\n","\n","def attr_patch_head_vector(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"],\n"," ) -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_vector = stack_head_vector_from_cache(clean_cache, activation_name)\n"," corrupted_head_vector = stack_head_vector_from_cache(corrupted_cache, activation_name)\n"," corrupted_grad_head_vector = stack_head_vector_from_cache(corrupted_grad_cache, activation_name)\n"," head_vector_attr = einops.reduce(\n"," corrupted_grad_head_vector * (clean_head_vector - corrupted_head_vector),\n"," \"component batch pos d_head -> component pos\",\n"," \"sum\"\n"," )\n"," return head_vector_attr, labels\n","\n","head_vector_attr_dict = {}\n","for activation_name, activation_name_full in [(\"k\", \"Key\"), (\"q\", \"Query\"), (\"v\", \"Value\"), (\"z\", \"Mixed Value\")]:\n"," display(Markdown(f\"#### {activation_name_full} Head Vector Attribution Patching\"))\n"," head_vector_attr_dict[activation_name], head_vector_labels = attr_patch_head_vector(clean_cache, corrupted_cache, corrupted_grad_cache, activation_name)\n"," imshow(head_vector_attr_dict[activation_name], y=head_vector_labels, yaxis=\"Component\", xaxis=\"Position\", title=f\"{activation_name_full} Attribution Patching\")\n"," sum_head_vector_attr = einops.reduce(head_vector_attr_dict[activation_name], \"(layer head) pos -> layer head\", \"sum\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n"," imshow(sum_head_vector_attr, yaxis=\"Layer\", xaxis=\"Head Index\", title=f\"{activation_name_full} Attribution Patching Sum Over Pos\")"]},{"cell_type":"code","execution_count":21,"metadata":{},"outputs":[{"data":{"text/markdown":["### Head Pattern Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","def stack_head_pattern_from_cache(\n"," cache, \n"," ) -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n"," \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n"," stacked_head_pattern = torch.stack([cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0)\n"," stacked_head_pattern = einops.rearrange(\n"," stacked_head_pattern,\n"," \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\"\n"," )\n"," return stacked_head_pattern\n","\n","def attr_patch_head_pattern(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache,\n"," ) -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n"," corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n"," corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n"," head_pattern_attr = einops.reduce(\n"," corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n"," \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n"," \"sum\"\n"," )\n"," return head_pattern_attr, labels\n","\n","head_pattern_attr, labels = attr_patch_head_pattern(clean_cache, corrupted_cache, corrupted_grad_cache)\n","\n","plot_attention_attr(einops.rearrange(head_pattern_attr, \"(layer head) dest src -> layer head dest src\", layer=model.cfg.n_layers, head=model.cfg.n_heads), clean_tokens, index=0, title=\"Head Pattern Attribution Patching\")"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_head_vector_grad_input_from_grad_cache(\n"," grad_cache: ActivationCache, \n"," activation_name: Literal[\"q\", \"k\", \"v\"],\n"," layer: int\n"," ) -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," vector_grad = grad_cache[activation_name, layer]\n"," ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n"," attn_layer_object = model.blocks[layer].attn\n"," if activation_name == \"q\":\n"," W = attn_layer_object.W_Q\n"," elif activation_name == \"k\":\n"," W = attn_layer_object.W_K\n"," elif activation_name == \"v\":\n"," W = attn_layer_object.W_V\n"," else:\n"," raise ValueError(\"Invalid activation name\")\n","\n"," return einsum(\"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\", vector_grad, ln_scales.squeeze(-1), W)\n","\n","def get_stacked_head_vector_grad_input(grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]) -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack([get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l) for l in range(model.cfg.n_layers)], dim=0)\n","\n","def get_full_vector_grad_input(grad_cache) -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack([get_stacked_head_vector_grad_input(grad_cache, activation_name) for activation_name in ['q', 'k', 'v']], dim=0)\n","\n","def attr_patch_head_path(\n"," clean_cache: ActivationCache, \n"," corrupted_cache: ActivationCache, \n"," corrupted_grad_cache: ActivationCache\n"," ) -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n"," \"\"\"\n"," Computes the attribution patch along the path between each pair of heads.\n","\n"," Sets this to zero for the path from any late head to any early head\n","\n"," \"\"\"\n"," start_labels = HEAD_NAMES\n"," end_labels = HEAD_NAMES_QKV\n"," full_vector_grad_input = get_full_vector_grad_input(corrupted_grad_cache)\n"," clean_head_result_stack = clean_cache.stack_head_results(-1)\n"," corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n"," diff_head_result = einops.rearrange(\n"," clean_head_result_stack - corrupted_head_result_stack,\n"," \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n"," layer = model.cfg.n_layers,\n"," head_index = model.cfg.n_heads,\n"," )\n"," path_attr = einsum(\n"," \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\", \n"," full_vector_grad_input, \n"," diff_head_result)\n"," correct_layer_order_mask = (\n"," torch.arange(model.cfg.n_layers)[None, :, None, None, None, None] > \n"," torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]).to(path_attr.device)\n"," zero = torch.zeros(1, device=path_attr.device)\n"," path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n","\n"," path_attr = einops.rearrange(\n"," path_attr,\n"," \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n"," )\n"," return path_attr, end_labels, start_labels\n","\n","head_path_attr, end_labels, start_labels = attr_patch_head_path(clean_cache, corrupted_cache, corrupted_grad_cache)\n","imshow(head_path_attr.sum(-1), y=end_labels, yaxis=\"Path End (Head Input)\", x=start_labels, xaxis=\"Path Start (Head Output)\", title=\"Head Path Attribution Patching\")"]},{"cell_type":"markdown","metadata":{},"source":[" This is 
hard to parse. Here's an experiment with filtering for the most important heads and showing their paths."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[{"data":{"text/html":["
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n","line(head_out_values)\n","top_head_indices = head_out_indices[:22].sort().values\n","top_end_indices = []\n","top_end_labels = []\n","top_start_indices = []\n","top_start_labels = []\n","for i in top_head_indices:\n"," i = i.item()\n"," top_start_indices.append(i)\n"," top_start_labels.append(start_labels[i])\n"," for j in range(3):\n"," top_end_indices.append(3*i+j)\n"," top_end_labels.append(end_labels[3*i+j])\n","\n","imshow(head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1), y=top_end_labels, yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=\"Head Path Attribution Patching (Filtered for Top Heads)\")"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n"," imshow(head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1), y=top_end_labels[j::3], yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\")"]},{"cell_type":"code","execution_count":25,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
"]},"metadata":{},"output_type":"display_data"}],"source":["top_head_path_attr = einops.rearrange(head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1), \"(head_end qkv) head_start -> qkv head_end head_start\", qkv=3)\n","imshow(top_head_path_attr, y=[i[:-1] for i in top_end_labels[::3]], yaxis=\"Path End (Head Input)\", x=top_start_labels, xaxis=\"Path Start (Head Output)\", title=f\"Head Path Attribution Patching (Filtered for Top Heads)\", facet_col=0, facet_labels=[\"Query\", \"Key\", \"Value\"])"]},{"cell_type":"markdown","metadata":{},"source":[" Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - \"Q input\" means the path ending at that head's query vector, and similarly for K and V)"]},{"cell_type":"code","execution_count":26,"metadata":{},"outputs":[{"data":{"text/html":["
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["interesting_heads = [5 * model.cfg.n_heads + 5, 8 * model.cfg.n_heads + 6, 9 * model.cfg.n_heads + 9]\n","interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n","for head_index, label in zip(interesting_heads, interesting_head_labels):\n"," in_paths = head_path_attr[3*head_index:3*head_index+3].sum(-1)\n"," out_paths = head_path_attr[:, head_index].sum(-1)\n"," out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n"," all_paths = torch.cat([in_paths, out_paths], dim=0)\n"," all_paths = einops.rearrange(all_paths, \"path_type (layer head) -> path_type layer head\", layer=model.cfg.n_layers, head=model.cfg.n_heads)\n"," imshow(all_paths, facet_col=0, facet_labels=[\"Query (In)\", \"Key (In)\", \"Value (In)\", \"Query (Out)\", \"Key (Out)\", \"Value (Out)\"], title=f\"Input and Output Paths for head {label}\", yaxis=\"Layer\", xaxis=\"Head\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Validating Attribution vs Activation Patching\n"," Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n"," My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n"," See more discussion in the accompanying blog post!\n"]},{"cell_type":"markdown","metadata":{},"source":[" First do some refactoring to make attribution patching more generic. We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!"]},{"cell_type":"code","execution_count":27,"metadata":{},"outputs":[],"source":["attribution_cache_dict = {}\n","for key in corrupted_grad_cache.cache_dict.keys():\n"," attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key])\n","attr_cache = ActivationCache(attribution_cache_dict, model)"]},{"cell_type":"markdown","metadata":{},"source":[" By block: For each head we patch the starting residual stream, attention output + MLP output"]},{"cell_type":"code","execution_count":28,"metadata":{},"outputs":[],"source":["str_tokens = model.to_str_tokens(clean_tokens[0])\n","context_length = len(str_tokens)"]},{"cell_type":"code","execution_count":29,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"95a5290e11b64b6a95ef5dd37d027c7a","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_act_patch_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_block_act_patch_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Activation Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"]},{"cell_type":"code","execution_count":30,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_block_every(attr_cache):\n"," resid_pre_attr = einops.reduce(\n"," attr_cache.stack_activation(\"resid_pre\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," attn_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"attn_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," mlp_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"mlp_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n","\n"," every_block_attr_patch_result = torch.stack([resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0)\n"," return every_block_attr_patch_result\n","every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n","imshow(every_block_attr_patch_result, facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Attribution Patching Per Block\", xaxis=\"Position\", yaxis=\"Layer\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])"]},{"cell_type":"code","execution_count":31,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_block_attr_patch_result.reshape(3, -1), x=every_block_act_patch_result.reshape(3, -1), facet_col=0, facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], title=\"Attribution vs Activation Patching Per Block\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", hover=[f\"Layer {l}, Position {p}, |{str_tokens[p]}|\" for l in range(model.cfg.n_layers) for p in range(context_length)], color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":32,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"18b2e6b0985b40cd8c0cd1a16ba62975","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)"]},{"cell_type":"code","execution_count":33,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_all_pos_every(attr_cache):\n"," head_out_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_q_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_k_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_v_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack([head_out_all_pos_attr, head_q_all_pos_attr, head_k_all_pos_attr, head_v_all_pos_attr, head_pattern_all_pos_attr])\n"," \n","every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(attr_cache)\n","imshow(every_head_all_pos_attr_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution Patching Per Head (All Pos)\", xaxis=\"Head\", yaxis=\"Layer\", zmax=1, zmin=-1)"]},{"cell_type":"code","execution_count":34,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_head_all_pos_attr_patch_result.reshape(5, -1), x=every_head_all_pos_act_patch_result.reshape(5, -1), facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution vs Activation Patching Per Head (All Pos)\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", include_diag=True, hover=head_out_labels, color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer head)\", head=model.cfg.n_heads), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n"," Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
"]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["graph_tok_labels = [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]\n","imshow(clean_cache[\"pattern\", 5][:, 5], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L5H5\", facet_name=\"Prompt\")\n","imshow(clean_cache[\"pattern\", 10][:, 7], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L10H7\", facet_name=\"Prompt\")\n","imshow(clean_cache[\"pattern\", 11][:, 10], x= graph_tok_labels, y=graph_tok_labels, facet_col=0, title=\"Attention for Head L11H10\", facet_name=\"Prompt\")\n","\n","\n","# [markdown]"]},{"cell_type":"code","execution_count":36,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"06f39489001845849fbc7446a07066f4","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00\n","\n","\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)\n","every_head_by_pos_act_patch_result = einops.rearrange(every_head_by_pos_act_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n","imshow(every_head_by_pos_act_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Activation Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=head_out_labels)"]},{"cell_type":"code","execution_count":37,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_by_pos_every(attr_cache):\n"," head_out_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_q_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_k_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_v_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack([head_out_by_pos_attr, head_q_by_pos_attr, head_k_by_pos_attr, head_v_by_pos_attr, head_pattern_by_pos_attr])\n","every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n","every_head_by_pos_attr_patch_result = einops.rearrange(every_head_by_pos_attr_patch_result, \"act_type layer pos head -> act_type (layer head) pos\")\n","imshow(every_head_by_pos_attr_patch_result, facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution Patching Per Head (By Pos)\", xaxis=\"Position\", yaxis=\"Layer & Head\", zmax=1, zmin=-1, x= [f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))], y=head_out_labels)"]},{"cell_type":"code","execution_count":38,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(y=every_head_by_pos_attr_patch_result.reshape(5, -1), x=every_head_by_pos_act_patch_result.reshape(5, -1), facet_col=0, facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"], title=\"Attribution vs Activation Patching Per Head (by Pos)\", xaxis=\"Activation Patch\", yaxis=\"Attribution Patch\", include_diag=True, hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels], color=einops.repeat(torch.arange(model.cfg.n_layers), \"layer -> (layer head pos)\", head=model.cfg.n_heads, pos = 15), color_continuous_scale=\"Portland\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Factual Knowledge Patching Example\n"," Incomplete, but maybe of interest!\n"," Note that I have better results with the corrupted prompt as having random words rather than Colosseum."]},{"cell_type":"code","execution_count":39,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-xl into HookedTransformer\n","Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Paris']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.73 Prob: 95.80% Token: | Paris|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.73\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m95.80\u001b[0m\u001b[1m% Token: | Paris|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n","Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n","Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n","Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n","Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n","Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n","Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n","Top 7th token. Logit: 13.66 Prob: 0.08% Token: | Ang|\n","Top 8th token. Logit: 13.43 Prob: 0.06% Token: | V|\n","Top 9th token. Logit: 13.42 Prob: 0.06% Token: | Stras|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Paris', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Paris'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Tokenized prompt: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Rome']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.02 Prob: 83.70% Token: | Rome|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.02\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m83.70\u001b[0m\u001b[1m% Token: | Rome|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.02 Prob: 83.70% Token: | Rome|\n","Top 1th token. Logit: 17.03 Prob: 4.23% Token: | Naples|\n","Top 2th token. Logit: 16.85 Prob: 3.51% Token: | Pompe|\n","Top 3th token. Logit: 16.14 Prob: 1.73% Token: | Ver|\n","Top 4th token. Logit: 15.87 Prob: 1.32% Token: | Florence|\n","Top 5th token. Logit: 14.77 Prob: 0.44% Token: | Roma|\n","Top 6th token. Logit: 14.68 Prob: 0.40% Token: | Milan|\n","Top 7th token. Logit: 14.66 Prob: 0.39% Token: | ancient|\n","Top 8th token. Logit: 14.37 Prob: 0.29% Token: | Pal|\n","Top 9th token. Logit: 14.30 Prob: 0.27% Token: | Constantinople|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Rome', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Rome'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"}],"source":["gpt2_xl = HookedTransformer.from_pretrained(\"gpt2-xl\")\n","clean_prompt = \"The Eiffel Tower is located in the city of\"\n","clean_answer = \" Paris\"\n","# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n","corrupted_prompt = \"The Colosseum is located in the city of\"\n","corrupted_answer = \" Rome\"\n","utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n","utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)"]},{"cell_type":"code","execution_count":40,"metadata":{},"outputs":[],"source":["clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n","corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n","def factual_logit_diff(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]"]},{"cell_type":"code","execution_count":41,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 10.634519577026367\n","Corrupted logit diff: -8.988396644592285\n","Clean Metric: tensor(1., device='cuda:0', grad_fn=)\n","Corrupted Metric: tensor(0., device='cuda:0', grad_fn=)\n"]}],"source":["clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n","CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n","corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n","CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n","\n","def factual_metric(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean logit diff:\", CLEAN_LOGIT_DIFF_FACTUAL)\n","print(\"Corrupted logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean Metric:\", factual_metric(clean_logits))\n","print(\"Corrupted Metric:\", factual_metric(corrupted_logits))"]},{"cell_type":"code","execution_count":42,"metadata":{},"outputs":[],"source":["# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)"]},{"cell_type":"code","execution_count":43,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Corrupted: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n"]}],"source":["clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n","clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n","corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n","corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)\n","print(\"Clean:\", clean_str_tokens)\n","print(\"Corrupted:\", corrupted_str_tokens)"]},{"cell_type":"code","execution_count":44,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"b767eef7a3cd49b9b3cb6e5301463f08","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/48 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):\n"," if len(corrupted_tokens.shape)==2:\n"," corrupted_tokens = corrupted_tokens[0]\n"," residual_patches = torch.zeros((model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device)\n"," def residual_hook(resid_pre, hook, layer, pos):\n"," resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n"," return resid_pre\n"," for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n"," for pos in range(len(corrupted_tokens)):\n"," patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[(f\"blocks.{layer}.hook_resid_pre\", partial(residual_hook, layer=layer, pos=pos))])\n"," residual_patches[layer, pos] = metric(patched_logits).item()\n"," return residual_patches\n","\n","residual_act_patch = act_patch_residual(clean_cache, corrupted_tokens, gpt2_xl, factual_metric)\n","\n","imshow(residual_act_patch, title=\"Factual Recall Patching (Residual)\", xaxis=\"Position\", yaxis=\"Layer\", x=clean_str_tokens)"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.13"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2} +{"cells":[{"cell_type":"markdown","metadata":{},"source":["\n"," \"Open\n",""]},{"cell_type":"markdown","metadata":{},"source":[" # Attribution Patching Demo\n"," **Read [the accompanying blog post here](https://neelnanda.io/attribution-patching) for more context**\n"," This is an interim research report, giving a whirlwind tour of some unpublished work I did at Anthropic (credit to the then team - Chris Olah, Catherine Olsson, Nelson Elhage and Tristan Hume for help, support, and mentorship!)\n","\n"," The goal of this work is run activation patching at an industrial scale, by using gradient based attribution to approximate the technique - allow an arbitrary number of patches to be made on two forwards and a single backward pass\n","\n"," I have had less time than hoped to flesh out this investigation, but am writing up a rough investigation and comparison to standard activation patching on a few tasks to give a sense of the potential of this approach, and where it works vs falls down."]},{"cell_type":"markdown","metadata":{},"source":[" To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n","\n"," **Tips for reading this Colab:**\n"," * You can run all this code for yourself!\n"," * The graphs are interactive!\n"," * Use the table of contents pane in the sidebar to navigate\n"," * Collapse irrelevant sections with the dropdown arrows\n"," * Search the page using the search in the sidebar, not CTRL+F"]},{"cell_type":"markdown","metadata":{},"source":[" ## Setup (Ignore)"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Running as a Jupyter notebook - intended for development only!\n"]},{"name":"stderr","output_type":"stream","text":["/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_25358/2480103146.py:24: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n"," 
ipython.magic(\"load_ext autoreload\")\n","/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_25358/2480103146.py:25: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n"," ipython.magic(\"autoreload 2\")\n"]}],"source":["# Janky code to do different setup when run in a Colab notebook vs VSCode\n","import os\n","\n","DEBUG_MODE = False\n","IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n","try:\n"," import google.colab\n","\n"," IN_COLAB = True\n"," print(\"Running as a Colab notebook\")\n","except:\n"," IN_COLAB = False\n"," print(\"Running as a Jupyter notebook - intended for development only!\")\n"," from IPython import get_ipython\n","\n"," ipython = get_ipython()\n"," # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n"," ipython.magic(\"load_ext autoreload\")\n"," ipython.magic(\"autoreload 2\")\n","\n","if IN_COLAB or IN_GITHUB:\n"," %pip install transformer_lens\n"," %pip install torchtyping\n"," # Install my janky personal plotting utils\n"," %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n"," # Install another version of node that makes PySvelte work way faster\n"," %pip install circuitsvis\n"," # Needed for PySvelte to work, v3 came out and broke things...\n"," %pip install typeguard==2.13.3"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n","import plotly.io as pio\n","\n","if IN_COLAB or not DEBUG_MODE:\n"," # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.\n"," pio.renderers.default = \"colab\"\n","else:\n"," pio.renderers.default = \"notebook_connected\""]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"ename":"ModuleNotFoundError","evalue":"No module named 'torchtyping'","output_type":"error","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)","Cell \u001b[0;32mIn[3], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mplotly\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mexpress\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpx\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DataLoader\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorchtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m TensorType \u001b[38;5;28;01mas\u001b[39;00m TT\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m List, Union, Optional, Callable\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfunctools\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m partial\n","\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torchtyping'"]}],"source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as 
optim\n","import numpy as np\n","import einops\n","from fancy_einsum import einsum\n","import tqdm.notebook as tqdm\n","import random\n","from pathlib import Path\n","import plotly.express as px\n","from torch.utils.data import DataLoader\n","\n","from torchtyping import TensorType as TT\n","from typing import List, Union, Optional, Callable\n","from functools import partial\n","import copy\n","import itertools\n","import json\n","\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","from IPython.display import HTML, Markdown"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":["import transformer_lens\n","import transformer_lens.utils as utils\n","from transformer_lens.hook_points import (\n"," HookedRootModule,\n"," HookPoint,\n",") # Hooking utilities\n","from transformer_lens import (\n"," HookedTransformer,\n"," HookedTransformerConfig,\n"," FactoredMatrix,\n"," ActivationCache,\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["from neel_plotly import line, imshow, scatter"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["import transformer_lens.patching as patching"]},{"cell_type":"markdown","metadata":{},"source":[" ## IOI Patching Setup\n"," This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-small into HookedTransformer\n"]}],"source":["model = HookedTransformer.from_pretrained(\"gpt2-small\")\n","model.set_use_attn_result(True)"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n","Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n","Answer token indices tensor([[ 5335, 1757],\n"," [ 1757, 5335],\n"," [ 4186, 3700],\n"," [ 3700, 4186],\n"," [ 6035, 15686],\n"," [15686, 6035],\n"," [ 5780, 14235],\n"," [14235, 5780]], device='cuda:0')\n"]}],"source":["prompts = [\n"," \"When John and Mary went to the shops, John gave the bag to\",\n"," \"When John and Mary went to the shops, Mary gave the bag to\",\n"," \"When Tom and James went to the park, James gave the ball to\",\n"," \"When Tom and James went to the park, Tom gave the ball to\",\n"," \"When Dan and Sid went to the shops, Sid gave an apple to\",\n"," \"When Dan and Sid went to the shops, Dan gave an apple to\",\n"," \"After Martin and Amy went to the park, Amy gave a drink to\",\n"," \"After Martin and Amy went to the park, Martin gave a drink to\",\n","]\n","answers = [\n"," (\" Mary\", \" John\"),\n"," (\" John\", \" Mary\"),\n"," (\" Tom\", \" James\"),\n"," (\" James\", \" Tom\"),\n"," (\" Dan\", \" Sid\"),\n"," (\" Sid\", \" Dan\"),\n"," (\" Martin\", \" Amy\"),\n"," (\" Amy\", \" Martin\"),\n","]\n","\n","clean_tokens = model.to_tokens(prompts)\n","# Swap each adjacent pair, with a hacky list 
comprehension\n","corrupted_tokens = clean_tokens[\n"," [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]\n","]\n","print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n","print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n","\n","answer_token_indices = torch.tensor(\n"," [\n"," [model.to_single_token(answers[i][j]) for j in range(2)]\n"," for i in range(len(answers))\n"," ],\n"," device=model.cfg.device,\n",")\n","print(\"Answer token indices\", answer_token_indices)"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 3.5519\n","Corrupted logit diff: -3.5519\n"]}],"source":["def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n"," if len(logits.shape) == 3:\n"," # Get final logits only\n"," logits = logits[:, -1, :]\n"," correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n"," incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n"," return (correct_logits - incorrect_logits).mean()\n","\n","\n","clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n","corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n","\n","clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n","print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n","\n","corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n","print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Baseline is 1: 1.0000\n","Corrupted Baseline is 0: 0.0000\n"]}],"source":["CLEAN_BASELINE = clean_logit_diff\n","CORRUPTED_BASELINE = corrupted_logit_diff\n","\n","\n","def ioi_metric(logits, answer_token_indices=answer_token_indices):\n"," return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (\n"," CLEAN_BASELINE - CORRUPTED_BASELINE\n"," )\n","\n","\n","print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n","print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Patching\n"," In the following cells, we define attribution patching and use it in various ways on the model."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["Metric = Callable[[TT[\"batch_and_pos_dims\", \"d_model\"]], float]"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Value: 1.0\n","Clean Activations Cached: 220\n","Clean Gradients Cached: 220\n","Corrupted Value: 0.0\n","Corrupted Activations Cached: 220\n","Corrupted Gradients Cached: 220\n"]}],"source":["filter_not_qkv_input = lambda name: \"_input\" not in name\n","\n","\n","def get_cache_fwd_and_bwd(model, tokens, metric):\n"," model.reset_hooks()\n"," cache = {}\n","\n"," def forward_cache_hook(act, hook):\n"," cache[hook.name] = act.detach()\n","\n"," model.add_hook(filter_not_qkv_input, forward_cache_hook, \"fwd\")\n","\n"," grad_cache = {}\n","\n"," def backward_cache_hook(act, hook):\n"," grad_cache[hook.name] = act.detach()\n","\n"," model.add_hook(filter_not_qkv_input, backward_cache_hook, \"bwd\")\n","\n"," value = metric(model(tokens))\n"," value.backward()\n"," model.reset_hooks()\n"," return (\n"," value.item(),\n"," ActivationCache(cache, model),\n"," ActivationCache(grad_cache, 
model),\n"," )\n","\n","\n","clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(\n"," model, clean_tokens, ioi_metric\n",")\n","print(\"Clean Value:\", clean_value)\n","print(\"Clean Activations Cached:\", len(clean_cache))\n","print(\"Clean Gradients Cached:\", len(clean_grad_cache))\n","corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(\n"," model, corrupted_tokens, ioi_metric\n",")\n","print(\"Corrupted Value:\", corrupted_value)\n","print(\"Corrupted Activations Cached:\", len(corrupted_cache))\n","print(\"Corrupted Gradients Cached:\", len(corrupted_grad_cache))"]},{"cell_type":"markdown","metadata":{},"source":[" ### Attention Attribution\n"," The easiest thing to start with is to not even engage with the corrupted tokens/patching, but to look at the attribution of the attention patterns - that is, the linear approximation to what happens if you set each element of the attention pattern to zero. This, as it turns out, is a good proxy to what is going on with each head!\n"," Note that this is *not* the same as what we will later do with patching. In particular, this does not set up a careful counterfactual! It's a good tool for what's generally going on in this problem, but does not control for eg stuff that systematically boosts John > Mary in general, stuff that says \"I should activate the IOI circuit\", etc. Though using logit diff as our metric *does*\n"," Each element of the batch is independent and the metric is an average logit diff, so we can analyse each batch element independently here. We'll look at the first one, and then at the average across the whole batch (note - 4 prompts have indirect object before subject, 4 prompts have it the other way round, making the average pattern harder to interpret - I plot it over the first sequence of tokens as a mildly misleading reference).\n"," We can compare it to the interpretability in the wild diagram, and basically instantly recover most of the circuit!"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[],"source":["def create_attention_attr(\n"," clean_cache, clean_grad_cache\n",") -> TT[\"batch\", \"layer\", \"head_index\", \"dest\", \"src\"]:\n"," attention_stack = torch.stack(\n"," [clean_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," attention_grad_stack = torch.stack(\n"," [clean_grad_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," attention_attr = attention_grad_stack * attention_stack\n"," attention_attr = einops.rearrange(\n"," attention_attr,\n"," \"layer batch head_index dest src -> batch layer head_index dest src\",\n"," )\n"," return attention_attr\n","\n","\n","attention_attr = create_attention_attr(clean_cache, clean_grad_cache)"]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4']\n","['L0H0+', 'L0H0-', 'L0H1+', 'L0H1-', 'L0H2+']\n","['L0H0Q', 'L0H0K', 'L0H0V', 'L0H1Q', 'L0H1K']\n"]}],"source":["HEAD_NAMES = [\n"," f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n","]\n","HEAD_NAMES_SIGNED = [f\"{name}{sign}\" for name in HEAD_NAMES for sign in [\"+\", \"-\"]]\n","HEAD_NAMES_QKV = [\n"," f\"{name}{act_name}\" for name in HEAD_NAMES for act_name in [\"Q\", \"K\", \"V\"]\n","]\n","print(HEAD_NAMES[:5])\n","print(HEAD_NAMES_SIGNED[:5])\n","print(HEAD_NAMES_QKV[:5])"]},{"cell_type":"markdown","metadata":{},"source":[" An extremely janky way to plot the 
attention attribution patterns. We scale them to be in [-1, 1], split each head into a positive and negative part (so all of it is in [0, 1]), and then plot the top 20 head-halves (a head can appear twice!) by the max value of the attribution pattern."]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"text/markdown":["### Attention Attribution for first sequence"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["### Summed Attention Attribution for all sequences"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\n"]}],"source":["def plot_attention_attr(attention_attr, tokens, top_k=20, index=0, title=\"\"):\n"," if len(tokens.shape) == 2:\n"," tokens = tokens[index]\n"," if len(attention_attr.shape) == 5:\n"," attention_attr = attention_attr[index]\n"," attention_attr_pos = attention_attr.clamp(min=-1e-5)\n"," attention_attr_neg = -attention_attr.clamp(max=1e-5)\n"," attention_attr_signed = torch.stack([attention_attr_pos, attention_attr_neg], dim=0)\n"," attention_attr_signed = einops.rearrange(\n"," attention_attr_signed,\n"," \"sign layer head_index dest src -> (layer head_index sign) dest src\",\n"," )\n"," attention_attr_signed = attention_attr_signed / attention_attr_signed.max()\n"," attention_attr_indices = (\n"," attention_attr_signed.max(-1).values.max(-1).values.argsort(descending=True)\n"," )\n"," # print(attention_attr_indices.shape)\n"," # print(attention_attr_indices)\n"," attention_attr_signed = attention_attr_signed[attention_attr_indices, :, :]\n"," head_labels = [HEAD_NAMES_SIGNED[i.item()] for i in attention_attr_indices]\n","\n"," if title:\n"," display(Markdown(\"### \" + title))\n"," display(\n"," pysvelte.AttentionMulti(\n"," tokens=model.to_str_tokens(tokens),\n"," attention=attention_attr_signed.permute(1, 2, 0)[:, :, :top_k],\n"," head_labels=head_labels[:top_k],\n"," )\n"," )\n","\n","\n","plot_attention_attr(\n"," attention_attr,\n"," clean_tokens,\n"," index=0,\n"," title=\"Attention Attribution for first sequence\",\n",")\n","\n","plot_attention_attr(\n"," attention_attr.sum(0),\n"," clean_tokens[0],\n"," title=\"Summed Attention Attribution for all sequences\",\n",")\n","print(\n"," \"Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\"\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Attribution Patching\n"," In the following sections, I will implement various kinds of attribution patching, and then compare them to the activation patching patterns (activation patching code copied from [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo))\n"," ### Residual Stream Patching\n","
Note: We add up across both d_model and batch (Explanation).\n"," We add up along d_model because we're taking the dot product - the derivative *is* the linear map that locally linearly approximates the metric, and so we take the dot product of our change vector with the derivative vector. Equivalently, we look at the effect of changing each coordinate independently, and then combine them by adding them up - it's linear, so this totally works.\n"," We add up across batch because we're taking the average of the metric, so each individual batch element provides `1/batch_size` of the overall effect. Because each batch element is independent of the others and no information moves between activations for different inputs, the batched version is equivalent to doing attribution patching separately for each input, and then averaging - in this second version the metric per input is *not* divided by batch_size because we don't average.
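 As a minimal sketch of that reduction for a single activation (assuming hypothetical tensors `clean_act`, `corrupted_act`, and `corrupted_grad`, each of shape `[batch, pos, d_model]`; the `attr_patch_residual` function below applies the same pattern across all accumulated residual stream components at once):

     # Dot the gradient with the activation difference - the local linear
     # approximation to patching - summing over d_model, then over batch.
     attribution = einops.reduce(
         corrupted_grad * (clean_act - corrupted_act),
         "batch pos d_model -> pos",
         "sum",
     )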
"]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_residual(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," clean_residual, residual_labels = clean_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=True\n"," )\n"," corrupted_residual = corrupted_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=False\n"," )\n"," corrupted_grad_residual = corrupted_grad_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=False\n"," )\n"," residual_attr = einops.reduce(\n"," corrupted_grad_residual * (clean_residual - corrupted_residual),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return residual_attr, residual_labels\n","\n","\n","residual_attr, residual_labels = attr_patch_residual(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," residual_attr,\n"," y=residual_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Residual Attribution Patching\",\n",")\n","\n","# ### Layer Output Patching"]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_layer_out(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)\n"," corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)\n"," corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(\n"," -1, return_labels=False\n"," )\n"," layer_out_attr = einops.reduce(\n"," corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return layer_out_attr, labels\n","\n","\n","layer_out_attr, layer_out_labels = attr_patch_layer_out(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," layer_out_attr,\n"," y=layer_out_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Layer Output Attribution Patching\",\n",")"]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_head_out(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_out = clean_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_head_out = corrupted_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_grad_head_out = corrupted_grad_cache.stack_head_results(\n"," -1, return_labels=False\n"," )\n"," head_out_attr = einops.reduce(\n"," corrupted_grad_head_out * (clean_head_out - corrupted_head_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return head_out_attr, labels\n","\n","\n","head_out_attr, head_out_labels = attr_patch_head_out(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," head_out_attr,\n"," y=head_out_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Head Output Attribution Patching\",\n",")\n","sum_head_out_attr = einops.reduce(\n"," head_out_attr,\n"," \"(layer head) pos -> layer head\",\n"," \"sum\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n",")\n","imshow(\n"," sum_head_out_attr,\n"," yaxis=\"Layer\",\n"," xaxis=\"Head Index\",\n"," title=\"Head Output Attribution Patching Sum Over Pos\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ### Head Activation Patching\n"," Intuitively, a head has three inputs, keys, queries and values. We can patch each of these individually to get a sense for where the important part of each head's input comes from!\n"," As a sanity check, we also do this for the mixed value. The result is a linear map of this (`z @ W_O == result`), so this is the same as patching the output of the head.\n"," We plot both the patch for each head over each position, and summed over position (it tends to be pretty sparse, so the latter is the same)"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[{"data":{"text/markdown":["#### Key Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Query Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Mixed Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","\n","\n","def stack_head_vector_from_cache(\n"," cache, activation_name: Literal[\"q\", \"k\", \"v\", \"z\"]\n",") -> TT[\"layer_and_head_index\", \"batch\", \"pos\", \"d_head\"]:\n"," \"\"\"Stacks the head vectors from the cache from a specific activation (key, query, value or mixed_value (z)) into a single tensor.\"\"\"\n"," stacked_head_vectors = torch.stack(\n"," [cache[activation_name, l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," stacked_head_vectors = einops.rearrange(\n"," stacked_head_vectors,\n"," \"layer batch pos head_index d_head -> (layer head_index) batch pos d_head\",\n"," )\n"," return stacked_head_vectors\n","\n","\n","def attr_patch_head_vector(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"],\n",") -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_vector = stack_head_vector_from_cache(clean_cache, activation_name)\n"," corrupted_head_vector = stack_head_vector_from_cache(\n"," corrupted_cache, activation_name\n"," )\n"," corrupted_grad_head_vector = stack_head_vector_from_cache(\n"," corrupted_grad_cache, activation_name\n"," )\n"," head_vector_attr = einops.reduce(\n"," corrupted_grad_head_vector * (clean_head_vector - corrupted_head_vector),\n"," \"component batch pos d_head -> component pos\",\n"," \"sum\",\n"," )\n"," return head_vector_attr, labels\n","\n","\n","head_vector_attr_dict = {}\n","for activation_name, activation_name_full in [\n"," (\"k\", \"Key\"),\n"," (\"q\", \"Query\"),\n"," (\"v\", \"Value\"),\n"," (\"z\", \"Mixed Value\"),\n","]:\n"," display(Markdown(f\"#### {activation_name_full} Head Vector Attribution Patching\"))\n"," head_vector_attr_dict[activation_name], head_vector_labels = attr_patch_head_vector(\n"," clean_cache, corrupted_cache, corrupted_grad_cache, activation_name\n"," )\n"," imshow(\n"," head_vector_attr_dict[activation_name],\n"," y=head_vector_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=f\"{activation_name_full} Attribution Patching\",\n"," )\n"," sum_head_vector_attr = einops.reduce(\n"," head_vector_attr_dict[activation_name],\n"," \"(layer head) pos -> layer head\",\n"," \"sum\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," )\n"," imshow(\n"," sum_head_vector_attr,\n"," yaxis=\"Layer\",\n"," xaxis=\"Head Index\",\n"," title=f\"{activation_name_full} Attribution Patching Sum Over Pos\",\n"," )"]},{"cell_type":"code","execution_count":21,"metadata":{},"outputs":[{"data":{"text/markdown":["### Head Pattern Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","\n","\n","def stack_head_pattern_from_cache(\n"," cache,\n",") -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n"," \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n"," stacked_head_pattern = torch.stack(\n"," [cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," stacked_head_pattern = einops.rearrange(\n"," stacked_head_pattern,\n"," \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\",\n"," )\n"," return stacked_head_pattern\n","\n","\n","def attr_patch_head_pattern(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n"," corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n"," corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n"," head_pattern_attr = einops.reduce(\n"," corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n"," \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n"," \"sum\",\n"," )\n"," return head_pattern_attr, labels\n","\n","\n","head_pattern_attr, labels = attr_patch_head_pattern(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","\n","plot_attention_attr(\n"," einops.rearrange(\n"," head_pattern_attr,\n"," \"(layer head) dest src -> layer head dest src\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," ),\n"," clean_tokens,\n"," index=0,\n"," title=\"Head Pattern Attribution Patching\",\n",")"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_head_vector_grad_input_from_grad_cache(\n"," grad_cache: ActivationCache, activation_name: Literal[\"q\", \"k\", \"v\"], layer: int\n",") -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," vector_grad = grad_cache[activation_name, layer]\n"," ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n"," attn_layer_object = model.blocks[layer].attn\n"," if activation_name == \"q\":\n"," W = attn_layer_object.W_Q\n"," elif activation_name == \"k\":\n"," W = attn_layer_object.W_K\n"," elif activation_name == \"v\":\n"," W = attn_layer_object.W_V\n"," else:\n"," raise ValueError(\"Invalid activation name\")\n","\n"," return einsum(\n"," \"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\",\n"," vector_grad,\n"," ln_scales.squeeze(-1),\n"," W,\n"," )\n","\n","\n","def get_stacked_head_vector_grad_input(\n"," grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]\n",") -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack(\n"," [\n"," get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l)\n"," for l in range(model.cfg.n_layers)\n"," ],\n"," dim=0,\n"," )\n","\n","\n","def get_full_vector_grad_input(\n"," grad_cache,\n",") -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack(\n"," [\n"," get_stacked_head_vector_grad_input(grad_cache, activation_name)\n"," for activation_name in [\"q\", \"k\", \"v\"]\n"," ],\n"," dim=0,\n"," )\n","\n","\n","def attr_patch_head_path(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n"," \"\"\"\n"," Computes the attribution patch along the path between each pair of heads.\n","\n"," Sets this to zero for the path from any late head to any early head\n","\n"," \"\"\"\n"," start_labels = HEAD_NAMES\n"," end_labels = HEAD_NAMES_QKV\n"," full_vector_grad_input = get_full_vector_grad_input(corrupted_grad_cache)\n"," clean_head_result_stack = clean_cache.stack_head_results(-1)\n"," corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n"," diff_head_result = einops.rearrange(\n"," clean_head_result_stack - corrupted_head_result_stack,\n"," \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n"," layer=model.cfg.n_layers,\n"," head_index=model.cfg.n_heads,\n"," )\n"," path_attr = einsum(\n"," \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\",\n"," full_vector_grad_input,\n"," diff_head_result,\n"," )\n"," correct_layer_order_mask = (\n"," torch.arange(model.cfg.n_layers)[None, :, None, None, None, None]\n"," > torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]\n"," ).to(path_attr.device)\n"," zero = torch.zeros(1, device=path_attr.device)\n"," path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n","\n"," path_attr = einops.rearrange(\n"," path_attr,\n"," \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n"," )\n"," return path_attr, end_labels, start_labels\n","\n","\n","head_path_attr, end_labels, start_labels = attr_patch_head_path(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," head_path_attr.sum(-1),\n"," y=end_labels,\n"," yaxis=\"Path End (Head 
Input)\",\n"," x=start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=\"Head Path Attribution Patching\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" This is hard to parse. Here's an experiment with filtering for the most important heads and showing their paths."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n","line(head_out_values)\n","top_head_indices = head_out_indices[:22].sort().values\n","top_end_indices = []\n","top_end_labels = []\n","top_start_indices = []\n","top_start_labels = []\n","for i in top_head_indices:\n"," i = i.item()\n"," top_start_indices.append(i)\n"," top_start_labels.append(start_labels[i])\n"," for j in range(3):\n"," top_end_indices.append(3 * i + j)\n"," top_end_labels.append(end_labels[3 * i + j])\n","\n","imshow(\n"," head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n"," y=top_end_labels,\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=\"Head Path Attribution Patching (Filtered for Top Heads)\",\n",")"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n"," imshow(\n"," head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1),\n"," y=top_end_labels[j::3],\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\",\n"," )"]},{"cell_type":"code","execution_count":25,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["top_head_path_attr = einops.rearrange(\n"," head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n"," \"(head_end qkv) head_start -> qkv head_end head_start\",\n"," qkv=3,\n",")\n","imshow(\n"," top_head_path_attr,\n"," y=[i[:-1] for i in top_end_labels[::3]],\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=f\"Head Path Attribution Patching (Filtered for Top Heads)\",\n"," facet_col=0,\n"," facet_labels=[\"Query\", \"Key\", \"Value\"],\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - Q input means )"]},{"cell_type":"code","execution_count":26,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["interesting_heads = [\n"," 5 * model.cfg.n_heads + 5,\n"," 8 * model.cfg.n_heads + 6,\n"," 9 * model.cfg.n_heads + 9,\n","]\n","interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n","for head_index, label in zip(interesting_heads, interesting_head_labels):\n"," in_paths = head_path_attr[3 * head_index : 3 * head_index + 3].sum(-1)\n"," out_paths = head_path_attr[:, head_index].sum(-1)\n"," out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n"," all_paths = torch.cat([in_paths, out_paths], dim=0)\n"," all_paths = einops.rearrange(\n"," all_paths,\n"," \"path_type (layer head) -> path_type layer head\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," )\n"," imshow(\n"," all_paths,\n"," facet_col=0,\n"," facet_labels=[\n"," \"Query (In)\",\n"," \"Key (In)\",\n"," \"Value (In)\",\n"," \"Query (Out)\",\n"," \"Key (Out)\",\n"," \"Value (Out)\",\n"," ],\n"," title=f\"Input and Output Paths for head {label}\",\n"," yaxis=\"Layer\",\n"," xaxis=\"Head\",\n"," )"]},{"cell_type":"markdown","metadata":{},"source":[" ## Validating Attribution vs Activation Patching\n"," Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n"," My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n"," See more discussion in the accompanying blog post!\n"]},{"cell_type":"markdown","metadata":{},"source":[" First do some refactoring to make attribution patching more generic. We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!"]},{"cell_type":"code","execution_count":27,"metadata":{},"outputs":[],"source":["attribution_cache_dict = {}\n","for key in corrupted_grad_cache.cache_dict.keys():\n"," attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (\n"," clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key]\n"," )\n","attr_cache = ActivationCache(attribution_cache_dict, model)"]},{"cell_type":"markdown","metadata":{},"source":[" By block: For each head we patch the starting residual stream, attention output + MLP output"]},{"cell_type":"code","execution_count":28,"metadata":{},"outputs":[],"source":["str_tokens = model.to_str_tokens(clean_tokens[0])\n","context_length = len(str_tokens)"]},{"cell_type":"code","execution_count":29,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"95a5290e11b64b6a95ef5dd37d027c7a","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_act_patch_result = patching.get_act_patch_block_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","imshow(\n"," every_block_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Activation Patching Per Block\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",")"]},{"cell_type":"code","execution_count":30,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_block_every(attr_cache):\n"," resid_pre_attr = einops.reduce(\n"," attr_cache.stack_activation(\"resid_pre\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," attn_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"attn_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," mlp_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"mlp_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n","\n"," every_block_attr_patch_result = torch.stack(\n"," [resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0\n"," )\n"," return every_block_attr_patch_result\n","\n","\n","every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n","imshow(\n"," every_block_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Attribution Patching Per Block\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",")"]},{"cell_type":"code","execution_count":31,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_block_attr_patch_result.reshape(3, -1),\n"," x=every_block_act_patch_result.reshape(3, -1),\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Attribution vs Activation Patching Per Block\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," hover=[\n"," f\"Layer {l}, Position {p}, |{str_tokens[p]}|\"\n"," for l in range(model.cfg.n_layers)\n"," for p in range(context_length)\n"," ],\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":32,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"18b2e6b0985b40cd8c0cd1a16ba62975","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","imshow(\n"," every_head_all_pos_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Activation Patching Per Head (All Pos)\",\n"," xaxis=\"Head\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n",")"]},{"cell_type":"code","execution_count":33,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_all_pos_every(attr_cache):\n"," head_out_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_q_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_k_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_v_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack(\n"," [\n"," head_out_all_pos_attr,\n"," head_q_all_pos_attr,\n"," head_k_all_pos_attr,\n"," head_v_all_pos_attr,\n"," head_pattern_all_pos_attr,\n"," ]\n"," )\n","\n","\n","every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(\n"," attr_cache\n",")\n","imshow(\n"," every_head_all_pos_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution Patching Per Head (All Pos)\",\n"," xaxis=\"Head\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n",")"]},{"cell_type":"code","execution_count":34,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_head_all_pos_attr_patch_result.reshape(5, -1),\n"," x=every_head_all_pos_act_patch_result.reshape(5, -1),\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution vs Activation Patching Per Head (All Pos)\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," include_diag=True,\n"," hover=head_out_labels,\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers),\n"," \"layer -> (layer head)\",\n"," head=model.cfg.n_heads,\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n"," Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["graph_tok_labels = [\n"," f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))\n","]\n","imshow(\n"," clean_cache[\"pattern\", 5][:, 5],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L5H5\",\n"," facet_name=\"Prompt\",\n",")\n","imshow(\n"," clean_cache[\"pattern\", 10][:, 7],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L10H7\",\n"," facet_name=\"Prompt\",\n",")\n","imshow(\n"," clean_cache[\"pattern\", 11][:, 10],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L11H10\",\n"," facet_name=\"Prompt\",\n",")\n","\n","\n","# [markdown]"]},{"cell_type":"code","execution_count":36,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"06f39489001845849fbc7446a07066f4","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","every_head_by_pos_act_patch_result = einops.rearrange(\n"," every_head_by_pos_act_patch_result,\n"," \"act_type layer pos head -> act_type (layer head) pos\",\n",")\n","imshow(\n"," every_head_by_pos_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Activation Patching Per Head (By Pos)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer & Head\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=head_out_labels,\n",")"]},{"cell_type":"code","execution_count":37,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_by_pos_every(attr_cache):\n"," head_out_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_q_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_k_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_v_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack(\n"," [\n"," head_out_by_pos_attr,\n"," head_q_by_pos_attr,\n"," head_k_by_pos_attr,\n"," head_v_by_pos_attr,\n"," head_pattern_by_pos_attr,\n"," ]\n"," )\n","\n","\n","every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n","every_head_by_pos_attr_patch_result = einops.rearrange(\n"," every_head_by_pos_attr_patch_result,\n"," \"act_type layer pos head -> act_type (layer head) pos\",\n",")\n","imshow(\n"," every_head_by_pos_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution Patching Per Head (By Pos)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer & Head\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=head_out_labels,\n",")"]},{"cell_type":"code","execution_count":38,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_head_by_pos_attr_patch_result.reshape(5, -1),\n"," x=every_head_by_pos_act_patch_result.reshape(5, -1),\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution vs Activation Patching Per Head (by Pos)\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," include_diag=True,\n"," hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels],\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers),\n"," \"layer -> (layer head pos)\",\n"," head=model.cfg.n_heads,\n"," pos=15,\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Factual Knowledge Patching Example\n"," Incomplete, but maybe of interest!\n"," Note that I have better results with the corrupted prompt as having random words rather than Colosseum."]},{"cell_type":"code","execution_count":39,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-xl into HookedTransformer\n","Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Paris']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.73 Prob: 95.80% Token: | Paris|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.73\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m95.80\u001b[0m\u001b[1m% Token: | Paris|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n","Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n","Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n","Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n","Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n","Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n","Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n","Top 7th token. Logit: 13.66 Prob: 0.08% Token: | Ang|\n","Top 8th token. Logit: 13.43 Prob: 0.06% Token: | V|\n","Top 9th token. Logit: 13.42 Prob: 0.06% Token: | Stras|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Paris', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Paris'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Tokenized prompt: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Rome']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.02 Prob: 83.70% Token: | Rome|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.02\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m83.70\u001b[0m\u001b[1m% Token: | Rome|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.02 Prob: 83.70% Token: | Rome|\n","Top 1th token. Logit: 17.03 Prob: 4.23% Token: | Naples|\n","Top 2th token. Logit: 16.85 Prob: 3.51% Token: | Pompe|\n","Top 3th token. Logit: 16.14 Prob: 1.73% Token: | Ver|\n","Top 4th token. Logit: 15.87 Prob: 1.32% Token: | Florence|\n","Top 5th token. Logit: 14.77 Prob: 0.44% Token: | Roma|\n","Top 6th token. Logit: 14.68 Prob: 0.40% Token: | Milan|\n","Top 7th token. Logit: 14.66 Prob: 0.39% Token: | ancient|\n","Top 8th token. Logit: 14.37 Prob: 0.29% Token: | Pal|\n","Top 9th token. Logit: 14.30 Prob: 0.27% Token: | Constantinople|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Rome', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Rome'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"}],"source":["gpt2_xl = HookedTransformer.from_pretrained(\"gpt2-xl\")\n","clean_prompt = \"The Eiffel Tower is located in the city of\"\n","clean_answer = \" Paris\"\n","# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n","corrupted_prompt = \"The Colosseum is located in the city of\"\n","corrupted_answer = \" Rome\"\n","utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n","utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)"]},{"cell_type":"code","execution_count":40,"metadata":{},"outputs":[],"source":["clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n","corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n","\n","\n","def factual_logit_diff(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]"]},{"cell_type":"code","execution_count":41,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 10.634519577026367\n","Corrupted logit diff: -8.988396644592285\n","Clean Metric: tensor(1., device='cuda:0', grad_fn=)\n","Corrupted Metric: tensor(0., device='cuda:0', grad_fn=)\n"]}],"source":["clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n","CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n","corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n","CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n","\n","\n","def factual_metric(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (\n"," CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL\n"," )\n","\n","\n","print(\"Clean logit diff:\", CLEAN_LOGIT_DIFF_FACTUAL)\n","print(\"Corrupted logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean Metric:\", factual_metric(clean_logits))\n","print(\"Corrupted Metric:\", factual_metric(corrupted_logits))"]},{"cell_type":"code","execution_count":42,"metadata":{},"outputs":[],"source":["# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)"]},{"cell_type":"code","execution_count":43,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Corrupted: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n"]}],"source":["clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n","clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n","corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n","corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)\n","print(\"Clean:\", clean_str_tokens)\n","print(\"Corrupted:\", corrupted_str_tokens)"]},{"cell_type":"code","execution_count":44,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"b767eef7a3cd49b9b3cb6e5301463f08","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/48 [00:00\n","\n","\n","
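[Editor's aside] To make the metric's scale concrete: `factual_metric` is the standard baseline-normalised logit difference, pinned to 0 on the corrupted run and 1 on the clean run. A tiny worked example using numbers rounded from the printout above:

```python
CLEAN_DIFF, CORRUPTED_DIFF = 10.63, -8.99  # rounded from the cell output above

def normalised(diff: float) -> float:
    return (diff - CORRUPTED_DIFF) / (CLEAN_DIFF - CORRUPTED_DIFF)

print(normalised(CLEAN_DIFF))      # 1.0 -- full recovery of clean behaviour
print(normalised(CORRUPTED_DIFF))  # 0.0 -- no recovery
print(normalised(0.0))             # ~0.46 -- a patch leaving the model indifferent
```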
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):\n"," if len(corrupted_tokens.shape) == 2:\n"," corrupted_tokens = corrupted_tokens[0]\n"," residual_patches = torch.zeros(\n"," (model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device\n"," )\n","\n"," def residual_hook(resid_pre, hook, layer, pos):\n"," resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n"," return resid_pre\n","\n"," for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n"," for pos in range(len(corrupted_tokens)):\n"," patched_logits = model.run_with_hooks(\n"," corrupted_tokens,\n"," fwd_hooks=[\n"," (\n"," f\"blocks.{layer}.hook_resid_pre\",\n"," partial(residual_hook, layer=layer, pos=pos),\n"," )\n"," ],\n"," )\n"," residual_patches[layer, pos] = metric(patched_logits).item()\n"," return residual_patches\n","\n","\n","residual_act_patch = act_patch_residual(\n"," clean_cache, corrupted_tokens, gpt2_xl, factual_metric\n",")\n","\n","imshow(\n"," residual_act_patch,\n"," title=\"Factual Recall Patching (Residual)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," x=clean_str_tokens,\n",")"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.11.8"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2} diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb index c871a6bd8..b2f89b695 100644 --- a/demos/Main_Demo.ipynb +++ b/demos/Main_Demo.ipynb @@ -74,8 +74,7 @@ " ip.extension_manager.load('autoreload')\n", " %autoreload 2\n", " \n", - "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", - "IN_GITHUB = True\n" + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n" ] }, { From 9397815b371340de9ffd8fe3bd6faf9087630e4b Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Thu, 25 Apr 2024 01:15:09 +0200 Subject: [PATCH 06/42] rearranged tests --- makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/makefile b/makefile index 32a11f52f..92fa6c559 100644 --- a/makefile +++ b/makefile @@ -18,11 +18,10 @@ docstring-test: poetry run pytest transformer_lens/ notebook-test: - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Attribution_Patching_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Exploratory_Analysis_Demo.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Head_Detector_Demo.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Interactive_Neuroscope.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/LLaMA.ipynb @@ -39,6 +38,7 @@ notebook-test: # Causes CI to hang poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Activation_Patching_in_TL_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Attribution_Patching_Demo.ipynb test: make unit-test From 430684e1dd3dbf9c591f1f1e85c99e9b6201b5d0 
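[Editor's aside] Worth noting for anyone re-running that last section: the residual patching loop costs one full forward pass per (layer, position) pair, which is exactly the cost attribution patching avoids (one forward plus one backward pass in total). Back-of-envelope, under the sizes used there:

```python
n_layers = 48     # gpt2-xl
n_positions = 12  # the tokenized Eiffel Tower prompt
print(n_layers * n_positions)  # 576 forward passes for a single heatmap
```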
Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Thu, 25 Apr 2024 01:29:00 +0200 Subject: [PATCH 07/42] updated header --- demos/Grokking_Demo.ipynb | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/demos/Grokking_Demo.ipynb b/demos/Grokking_Demo.ipynb index 7e3792095..87733d159 100644 --- a/demos/Grokking_Demo.ipynb +++ b/demos/Grokking_Demo.ipynb @@ -53,13 +53,14 @@ ], "source": [ "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", "DEVELOPMENT_MODE = True\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", "try:\n", " import google.colab\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", - " %pip install transformer-lens\n", - " %pip install circuitsvis\n", " \n", " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", @@ -73,7 +74,11 @@ " ipython = get_ipython()\n", " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", " ipython.magic(\"load_ext autoreload\")\n", - " ipython.magic(\"autoreload 2\")" + " ipython.magic(\"autoreload 2\")\n", + " \n", + "if IN_COLAB or IN_GITHUB:\n", + " %pip install transformer-lens\n", + " %pip install circuitsvis" ] }, { From 316cb45e6df72eda400ef28e8b4c27fe523bb7e4 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Sat, 27 Apr 2024 00:03:40 +0200 Subject: [PATCH 08/42] updated grokking demo --- demos/Grokking_Demo.ipynb | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/demos/Grokking_Demo.ipynb b/demos/Grokking_Demo.ipynb index 87733d159..473d7ca82 100644 --- a/demos/Grokking_Demo.ipynb +++ b/demos/Grokking_Demo.ipynb @@ -77,7 +77,7 @@ " ipython.magic(\"autoreload 2\")\n", " \n", "if IN_COLAB or IN_GITHUB:\n", - " %pip install transformer-lens\n", + " %pip install transformer_lens\n", " %pip install circuitsvis" ] }, @@ -159,7 +159,10 @@ " HookedRootModule,\n", " HookPoint,\n", ") # Hooking utilities\n", - "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache" + "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n", + "\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" ] }, { @@ -286,7 +289,7 @@ } ], "source": [ - "dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).cuda()\n", + "dataset = torch.stack([a_vector, b_vector, equals_vector], dim=1).to(device)\n", "print(dataset[:5])\n", "print(dataset.shape)" ] @@ -391,7 +394,7 @@ " d_vocab_out=p,\n", " n_ctx=3,\n", " init_weights=True,\n", - " device=\"cuda\",\n", + " device=device,\n", " seed = 999,\n", ")" ] @@ -1650,7 +1653,7 @@ " fourier_basis_names.append(f\"Sin {freq}\")\n", " fourier_basis.append(torch.cos(torch.arange(p)*2 * torch.pi * freq / p))\n", " fourier_basis_names.append(f\"Cos {freq}\")\n", - "fourier_basis = torch.stack(fourier_basis, dim=0).cuda()\n", + "fourier_basis = torch.stack(fourier_basis, dim=0).to(device)\n", "fourier_basis = fourier_basis/fourier_basis.norm(dim=-1, keepdim=True)\n", "imshow(fourier_basis, xaxis=\"Input\", yaxis=\"Component\", y=fourier_basis_names)" ] @@ -2399,7 +2402,7 @@ } ], "source": [ - "neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).cuda()\n", + "neuron_freq_norm = torch.zeros(p//2, model.cfg.d_mlp).to(device)\n", "for freq in range(0, p//2):\n", " for x in [0, 
2*(freq+1) - 1, 2*(freq+1)]:\n", " for y in [0, 2*(freq+1) - 1, 2*(freq+1)]:\n", @@ -2998,7 +3001,7 @@ " a = torch.arange(p)[:, None, None]\n", " b = torch.arange(p)[None, :, None]\n", " c = torch.arange(p)[None, None, :]\n", - " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()\n", + " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).to(device)\n", " cube_predicted_logits /= cube_predicted_logits.norm()\n", " coses[freq] = cube_predicted_logits" ] @@ -3129,7 +3132,7 @@ " a = torch.arange(p)[:, None, None]\n", " b = torch.arange(p)[None, :, None]\n", " c = torch.arange(p)[None, None, :]\n", - " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).cuda()\n", + " cube_predicted_logits = torch.cos(freq * 2 * torch.pi / p * (a + b - c)).to(device)\n", " cube_predicted_logits /= cube_predicted_logits.norm()\n", " cos_cube.append(cube_predicted_logits)\n", "cos_cube = torch.stack(cos_cube, dim=0)\n", @@ -3491,11 +3494,11 @@ "a = torch.arange(p)[:, None]\n", "b = torch.arange(p)[None, :]\n", "for freq in key_freqs:\n", - " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " cos_apb_vec /= cos_apb_vec.norm()\n", " cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n", - " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " sin_apb_vec /= sin_apb_vec.norm()\n", " sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n", @@ -3560,11 +3563,11 @@ " a = torch.arange(p)[:, None]\n", " b = torch.arange(p)[None, :]\n", " for freq in key_freqs:\n", - " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " cos_apb_vec /= cos_apb_vec.norm()\n", " cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n", - " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " sin_apb_vec /= sin_apb_vec.norm()\n", " sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n", @@ -3723,11 +3726,11 @@ "a = torch.arange(p)[:, None]\n", "b = torch.arange(p)[None, :]\n", "for freq in key_freqs:\n", - " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " cos_apb_vec /= cos_apb_vec.norm()\n", " cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n", - " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " sin_apb_vec /= sin_apb_vec.norm()\n", " sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n", @@ -3770,11 +3773,11 @@ " a = torch.arange(p)[:, None]\n", " b = torch.arange(p)[None, :]\n", " for freq in key_freqs:\n", - " cos_apb_vec = 
torch.cos(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " cos_apb_vec = torch.cos(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " cos_apb_vec /= cos_apb_vec.norm()\n", " cos_apb_vec = einops.rearrange(cos_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * cos_apb_vec).sum(dim=0) * cos_apb_vec\n", - " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).cuda()\n", + " sin_apb_vec = torch.sin(freq * 2 * torch.pi / p * (a + b)).to(device)\n", " sin_apb_vec /= sin_apb_vec.norm()\n", " sin_apb_vec = einops.rearrange(sin_apb_vec, \"a b -> (a b) 1\")\n", " approx_neuron_acts += (neuron_acts * sin_apb_vec).sum(dim=0) * sin_apb_vec\n", From 81a27f730b1516f213a5c15c1d5919856e2129cc Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Sat, 27 Apr 2024 00:27:42 +0200 Subject: [PATCH 09/42] updated bert for testing --- demos/BERT.ipynb | 10 +++++++--- makefile | 4 ++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb index 581a6365d..20b9c0ccc 100644 --- a/demos/BERT.ipynb +++ b/demos/BERT.ipynb @@ -41,14 +41,14 @@ } ], "source": [ + "import os\n", "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", "DEVELOPMENT_MODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", "try:\n", " import google.colab\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", - " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n", - " %pip install circuitsvis\n", " \n", " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", @@ -62,7 +62,11 @@ " ipython = get_ipython()\n", " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", " ipython.magic(\"load_ext autoreload\")\n", - " ipython.magic(\"autoreload 2\")" + " ipython.magic(\"autoreload 2\")\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", + " %pip install transformer_lens\n", + " %pip install circuitsvis" ] }, { diff --git a/makefile b/makefile index 92fa6c559..a28c07f54 100644 --- a/makefile +++ b/makefile @@ -18,7 +18,7 @@ docstring-test: poetry run pytest transformer_lens/ notebook-test: - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/BERT.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Exploratory_Analysis_Demo.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Main_Demo.ipynb @@ -34,11 +34,11 @@ notebook-test: poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Tracr_to_Transformer_Lens_Demo.ipynb # Contains failing cells - poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/BERT.ipynb # Causes CI to hang poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Activation_Patching_in_TL_Demo.ipynb poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Attribution_Patching_Demo.ipynb + poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb test: make unit-test From 7d64be0c042d6738eed6ce80fab3a79d1ed77e3e Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Sat, 27 Apr 2024 00:36:50 +0200 Subject: [PATCH 10/42] updated bert demo --- demos/BERT.ipynb | 209 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 184 insertions(+), 25 deletions(-) diff --git a/demos/BERT.ipynb 
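[Editor's aside] The repeated `.cuda()` → `.to(device)` substitutions in this patch all instantiate one pattern, which is what lets the Grokking demo run on CPU-only CI runners. A reduced sketch (toy tensor, not from the notebook):

```python
import torch

# Pick the device once, up front, instead of hard-coding CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Every freshly constructed tensor is then moved explicitly:
freqs = torch.arange(10).to(device)
print(freqs.device)  # cpu on a GPU-less runner, cuda:0 where available
```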
b/demos/BERT.ipynb index 20b9c0ccc..37b5a78a5 100644 --- a/demos/BERT.ipynb +++ b/demos/BERT.ipynb @@ -29,33 +29,122 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Running as a Jupyter notebook - intended for development only!\n" + "Requirement already satisfied: transformer_lens in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (0.0.0)\n", + "Requirement already satisfied: accelerate>=0.23.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.29.1)\n", + "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.14.1)\n", + "Requirement already satisfied: better-abc<0.0.4,>=0.0.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.0.3)\n", + "Requirement already satisfied: datasets>=2.7.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.18.0)\n", + "Requirement already satisfied: einops>=0.6.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.7.0)\n", + "Requirement already satisfied: fancy-einsum>=0.0.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.0.3)\n", + "Requirement already satisfied: jaxtyping>=0.2.11 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.2.19)\n", + "Requirement already satisfied: numpy>=1.24 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (1.26.4)\n", + "Requirement already satisfied: pandas>=1.1.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.0.3)\n", + "Requirement already satisfied: rich>=12.6.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (13.7.1)\n", + "Requirement already satisfied: sentencepiece in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.2.0)\n", + "Requirement already satisfied: torch!=2.0,!=2.1.0,>=1.10 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.1.2)\n", + "Requirement already satisfied: tqdm>=4.64.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.66.2)\n", + "Requirement already satisfied: transformers>=4.37.2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.39.3)\n", + "Requirement already satisfied: typing-extensions in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.11.0)\n", + "Requirement already satisfied: wandb>=0.13.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.16.6)\n", + "Requirement already satisfied: packaging>=20.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (24.0)\n", + "Requirement already satisfied: psutil in 
/Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (5.9.8)\n",
+    "... (further 'Requirement already satisfied' lines trimmed) ...\n",
+    "Requirement 
already satisfied: typing-extensions in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (4.11.0)\n", + "Requirement already satisfied: sympy in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (1.12)\n", + "Requirement already satisfied: networkx in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (3.1)\n", + "Requirement already satisfied: jinja2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (3.1.3)\n", + "Requirement already satisfied: fsspec in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (2024.2.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from jinja2->torch>=1.10->circuitsvis) (2.1.5)\n", + "Requirement already satisfied: mpmath>=0.19 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from sympy->torch>=1.10->circuitsvis) (1.3.0)\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "import os\n", + "\n", "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", "DEVELOPMENT_MODE = False\n", "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", "try:\n", " import google.colab\n", + "\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", - " \n", + "\n", " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", "except:\n", " IN_COLAB = False\n", + "\n", + "if not IN_GITHUB and not IN_COLAB:\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", " from IPython import get_ipython\n", "\n", @@ -71,7 +160,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -85,6 +174,7 @@ "source": [ "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", "import plotly.io as pio\n", + "\n", "if IN_COLAB or not DEVELOPMENT_MODE:\n", " pio.renderers.default = \"colab\"\n", "else:\n", @@ -94,40 +184,41 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
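[Editor's aside] Patches 07, 09 and 10 converge on the same setup-cell shape: detect GitHub Actions and Colab up front, and only pip-install inside those throwaway environments. A condensed sketch (simplified — the real cells also keep the autoreload and PySvelte-fallback logic):

```python
import os

IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"
try:
    import google.colab  # noqa: F401 -- only importable inside Colab

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB or IN_GITHUB:
    # In a notebook this is the %pip magic:
    # %pip install transformer_lens circuitsvis
    pass
```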
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import circuitsvis as cv\n", + "\n", "# Testing that the library works\n", "cv.examples.hello(\"Neel\")" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -141,16 +232,16 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -171,22 +262,90 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "WARNING:root:HookedEncoder is still in beta. Please be aware that model preprocessing (e.g. LayerNorm folding) is not yet supported and backward compatibility is not guaranteed.\n" + "WARNING:root:Support for BERT in TransformerLens is currently experimental, until such a time when it has feature parity with HookedTransformer and has been tested on real research tasks. Until then, backward compatibility is not guaranteed. Please see the docs for information on the limitations of the current implementation.\n", + "If using BERT for interpretability research, keep in mind that BERT has some significant architectural differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.\n" ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Moving model to device: cpu\n", - "Loaded pretrained model bert-base-cased into HookedTransformer\n" + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "30ca8295e43142db88991fcd24d522d5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "model.safetensors: 0%| | 0.00/436M [00:00 1\u001b[0m bert \u001b[38;5;241m=\u001b[39m \u001b[43mHookedEncoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbert-base-cased\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m AutoTokenizer\u001b[38;5;241m.\u001b[39mfrom_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbert-base-cased\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/transformer_lens/HookedEncoder.py:251\u001b[0m, in \u001b[0;36mHookedEncoder.from_pretrained\u001b[0;34m(cls, model_name, checkpoint_index, checkpoint_value, hf_model, device, tokenizer, move_to_device, dtype, **from_pretrained_kwargs)\u001b[0m\n\u001b[1;32m 236\u001b[0m cfg \u001b[38;5;241m=\u001b[39m loading\u001b[38;5;241m.\u001b[39mget_pretrained_model_config(\n\u001b[1;32m 237\u001b[0m official_model_name,\n\u001b[1;32m 238\u001b[0m checkpoint_index\u001b[38;5;241m=\u001b[39mcheckpoint_index,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 244\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mfrom_pretrained_kwargs,\n\u001b[1;32m 245\u001b[0m )\n\u001b[1;32m 247\u001b[0m state_dict \u001b[38;5;241m=\u001b[39m loading\u001b[38;5;241m.\u001b[39mget_pretrained_state_dict(\n\u001b[1;32m 248\u001b[0m official_model_name, cfg, hf_model, dtype\u001b[38;5;241m=\u001b[39mdtype, 
\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mfrom_pretrained_kwargs\n\u001b[1;32m 249\u001b[0m )\n\u001b[0;32m--> 251\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcfg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokenizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmove_to_device\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 253\u001b[0m model\u001b[38;5;241m.\u001b[39mload_state_dict(state_dict, strict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 255\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m move_to_device:\n", + "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/transformer_lens/HookedEncoder.py:57\u001b[0m, in \u001b[0;36mHookedEncoder.__init__\u001b[0;34m(self, cfg, tokenizer, move_to_device, **kwargs)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mtokenizer_name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 56\u001b[0m huggingface_token \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39menviron\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHF_TOKEN\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m---> 57\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtokenizer \u001b[38;5;241m=\u001b[39m \u001b[43mAutoTokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 58\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcfg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtokenizer_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhuggingface_token\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtokenizer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/auto/tokenization_auto.py:855\u001b[0m, in \u001b[0;36mAutoTokenizer.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 853\u001b[0m tokenizer_class_py, tokenizer_class_fast \u001b[38;5;241m=\u001b[39m TOKENIZER_MAPPING[\u001b[38;5;28mtype\u001b[39m(config)]\n\u001b[1;32m 854\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tokenizer_class_fast \u001b[38;5;129;01mand\u001b[39;00m (use_fast \u001b[38;5;129;01mor\u001b[39;00m tokenizer_class_py \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m--> 855\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtokenizer_class_fast\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 856\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 857\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m tokenizer_class_py \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "File \u001b[0;32m~/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:2044\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, trust_remote_code, *init_inputs, **kwargs)\u001b[0m\n\u001b[1;32m 2042\u001b[0m resolved_vocab_files[file_id] \u001b[38;5;241m=\u001b[39m download_url(file_path, proxies\u001b[38;5;241m=\u001b[39mproxies)\n\u001b[1;32m 2043\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2044\u001b[0m resolved_vocab_files[file_id] \u001b[38;5;241m=\u001b[39m \u001b[43mcached_file\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2045\u001b[0m \u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2046\u001b[0m \u001b[43m \u001b[49m\u001b[43mfile_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2047\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2048\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2049\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2050\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2051\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2052\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2053\u001b[0m \u001b[43m \u001b[49m\u001b[43muser_agent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muser_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2054\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2055\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2056\u001b[0m \u001b[43m \u001b[49m\u001b[43m_raise_exceptions_for_gated_repo\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 2057\u001b[0m \u001b[43m \u001b[49m\u001b[43m_raise_exceptions_for_missing_entries\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 2058\u001b[0m \u001b[43m \u001b[49m\u001b[43m_raise_exceptions_for_connection_errors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 2059\u001b[0m \u001b[43m 
+    [... ~18 more ANSI-escaped traceback frames omitted: the interrupted Hugging Face Hub download unwinds through transformers cached_file -> huggingface_hub hf_hub_download / http_get -> requests -> urllib3 -> http.client -> socket/ssl read ...]
+    "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
   ]
  }
 ],
@@ -205,7 +364,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 7,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -217,7 +376,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 8,
+  "execution_count": null,
   "metadata": {},
   "outputs": [
    {
@@ -234,7 +393,7 @@
   "prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n",
   "\n",
   "print(f\"Prompt: {prompt}\")\n",
-  "print(f\"Prediction: \\\"{prediction}\\\"\")"
+  "print(f'Prediction: \"{prediction}\"')"
   ]
  },
  {
@@ -262,7 +421,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.10.10"
+  "version": "3.11.8"
  },
  "orig_nbformat": 4
 },

From 027e2bc899a629dc6e71cd0c2f59f3df5957f02d Mon Sep 17 00:00:00 2001
From: Bryce Meyer
Date: Sat, 27 Apr 2024 00:37:06 +0200
Subject: [PATCH 11/42] ran cells

---
 demos/BERT.ipynb | 73 +++++++-----------------------------------
 1 file changed, 10 insertions(+), 63 deletions(-)

diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb
index 37b5a78a5..09b8e90f8 100644
--- a/demos/BERT.ipynb
+++ b/demos/BERT.ipynb
@@ -262,7 +262,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 7,
+  "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
@@ -276,76 +276,23 @@
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
-      "model_id": "30ca8295e43142db88991fcd24d522d5",
+      "model_id": "407745de57a144f09c2c0b7345c31edc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
-      [... removed outputs omitted: a truncated "model.safetensors: 0%| | 0.00/436M" download progress line, followed by the ANSI-escaped KeyboardInterrupt traceback from interrupting HookedEncoder.from_pretrained("bert-base-cased") (HookedEncoder -> AutoTokenizer.from_pretrained -> transformers cached_file -> huggingface_hub hf_hub_download / http_get -> requests -> urllib3 -> socket/ssl read) ...]
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Moving model to device: mps\n",
+      "Loaded pretrained model bert-base-cased into HookedTransformer\n"
+     ]
    }
   ],
@@ -364,7 +311,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": null,
+  "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -376,7 +323,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": null,
+  "execution_count": 12,
   "metadata": {},
   "outputs": [
    {

From f06e2ec0d454306fe80e38b685bdc3915e13c252 Mon Sep 17 00:00:00 2001
From: Bryce Meyer
Date: Sat, 27 Apr 2024 00:50:33 +0200
Subject: [PATCH 12/42] removed github check

---
 demos/BERT.ipynb | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb
index 09b8e90f8..f619c7cc5 100644
--- a/demos/BERT.ipynb
+++ b/demos/BERT.ipynb
@@ -153,7 +153,7 @@
    " ipython.magic(\"load_ext autoreload\")\n",
    " ipython.magic(\"autoreload 2\")\n",
    "\n",
-   "if IN_COLAB or IN_GITHUB:\n",
+   "if IN_COLAB:\n",
    " %pip install transformer_lens\n",
    " %pip install circuitsvis"
   ]
  },

From 1f0197b34ff305c5c5c7e16bb44b9fc5b4c23c15 Mon Sep 17 00:00:00 2001
From: Bryce Meyer
Date: Sat, 27 Apr 2024 00:59:35 +0200
Subject: [PATCH 13/42] removed cells to skip

---
 demos/BERT.ipynb | 99 ++----------------------------------------------
 1 file changed, 3 insertions(+), 96 deletions(-)
++---------------------------------------------- 1 file changed, 3 insertions(+), 96 deletions(-) diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb index f619c7cc5..976cd5b6d 100644 --- a/demos/BERT.ipynb +++ b/demos/BERT.ipynb @@ -29,102 +29,9 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: transformer_lens in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (0.0.0)\n", - "Requirement already satisfied: accelerate>=0.23.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.29.1)\n", - "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.14.1)\n", - "Requirement already satisfied: better-abc<0.0.4,>=0.0.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.0.3)\n", - "Requirement already satisfied: datasets>=2.7.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.18.0)\n", - "Requirement already satisfied: einops>=0.6.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.7.0)\n", - "Requirement already satisfied: fancy-einsum>=0.0.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.0.3)\n", - "Requirement already satisfied: jaxtyping>=0.2.11 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.2.19)\n", - "Requirement already satisfied: numpy>=1.24 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (1.26.4)\n", - "Requirement already satisfied: pandas>=1.1.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.0.3)\n", - "Requirement already satisfied: rich>=12.6.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (13.7.1)\n", - "Requirement already satisfied: sentencepiece in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.2.0)\n", - "Requirement already satisfied: torch!=2.0,!=2.1.0,>=1.10 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.1.2)\n", - "Requirement already satisfied: tqdm>=4.64.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.66.2)\n", - "Requirement already satisfied: transformers>=4.37.2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.39.3)\n", - "Requirement already satisfied: typing-extensions in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.11.0)\n", - "Requirement already satisfied: wandb>=0.13.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.16.6)\n", - "Requirement already satisfied: packaging>=20.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (24.0)\n", - "Requirement already 
satisfied: psutil in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (5.9.8)\n", - "Requirement already satisfied: pyyaml in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (6.0.1)\n", - "Requirement already satisfied: huggingface-hub in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (0.22.2)\n", - "Requirement already satisfied: safetensors>=0.3.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (0.4.2)\n", - "Requirement already satisfied: filelock in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (3.13.3)\n", - "Requirement already satisfied: pyarrow>=12.0.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (15.0.2)\n", - "Requirement already satisfied: pyarrow-hotfix in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (0.6)\n", - "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (0.3.8)\n", - "Requirement already satisfied: requests>=2.19.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (2.31.0)\n", - "Requirement already satisfied: xxhash in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (3.4.1)\n", - "Requirement already satisfied: multiprocess in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (0.70.16)\n", - "Requirement already satisfied: fsspec<=2024.2.0,>=2023.1.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets>=2.7.1->transformer_lens) (2024.2.0)\n", - "Requirement already satisfied: aiohttp in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (3.9.3)\n", - "Requirement already satisfied: typeguard>=2.13.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from jaxtyping>=0.2.11->transformer_lens) (4.2.1)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from pandas>=1.1.5->transformer_lens) (2.9.0.post0)\n", - "Requirement already satisfied: pytz>=2020.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from pandas>=1.1.5->transformer_lens) (2024.1)\n", - "Requirement already satisfied: tzdata>=2022.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from pandas>=1.1.5->transformer_lens) (2024.1)\n", - "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from rich>=12.6.0->transformer_lens) (2.2.0)\n", - "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from 
rich>=12.6.0->transformer_lens) (2.17.2)\n", - "Requirement already satisfied: sympy in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (1.12)\n", - "Requirement already satisfied: networkx in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (3.1)\n", - "Requirement already satisfied: jinja2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (3.1.3)\n", - "Requirement already satisfied: regex!=2019.12.17 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformers>=4.37.2->transformer_lens) (2023.12.25)\n", - "Requirement already satisfied: tokenizers<0.19,>=0.14 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformers>=4.37.2->transformer_lens) (0.15.2)\n", - "Requirement already satisfied: Click!=8.0.0,>=7.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (8.1.7)\n", - "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (3.1.43)\n", - "Requirement already satisfied: sentry-sdk>=1.0.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (1.44.1)\n", - "Requirement already satisfied: docker-pycreds>=0.4.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (0.4.0)\n", - "Requirement already satisfied: setproctitle in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (1.3.3)\n", - "Requirement already satisfied: setuptools in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (69.2.0)\n", - "Requirement already satisfied: appdirs>=1.4.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (1.4.4)\n", - "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (4.25.3)\n", - "Requirement already satisfied: six>=1.4.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from docker-pycreds>=0.4.0->wandb>=0.13.5->transformer_lens) (1.16.0)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.3.1)\n", - "Requirement already satisfied: attrs>=17.3.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (23.2.0)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.4.1)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (6.0.5)\n", - "Requirement already 
satisfied: yarl<2.0,>=1.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.9.4)\n", - "Requirement already satisfied: gitdb<5,>=4.0.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (4.0.11)\n", - "Requirement already satisfied: mdurl~=0.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer_lens) (0.1.2)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (3.3.2)\n", - "Requirement already satisfied: idna<4,>=2.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (3.6)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2.2.1)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2024.2.2)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from jinja2->torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (2.1.5)\n", - "Requirement already satisfied: mpmath>=0.19 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from sympy->torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (1.3.0)\n", - "Requirement already satisfied: smmap<6,>=3.0.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (5.0.1)\n", - "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", - "Note: you may need to restart the kernel to use updated packages.\n", - "Requirement already satisfied: circuitsvis in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (1.43.2)\n", - "Requirement already satisfied: importlib-metadata>=5.1.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from circuitsvis) (7.1.0)\n", - "Requirement already satisfied: numpy>=1.24 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from circuitsvis) (1.26.4)\n", - "Requirement already satisfied: torch>=1.10 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from circuitsvis) (2.1.2)\n", - "Requirement already satisfied: zipp>=0.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from importlib-metadata>=5.1.0->circuitsvis) (3.18.1)\n", - "Requirement already satisfied: filelock in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (3.13.3)\n", - "Requirement 
already satisfied: typing-extensions in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (4.11.0)\n", - "Requirement already satisfied: sympy in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (1.12)\n", - "Requirement already satisfied: networkx in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (3.1)\n", - "Requirement already satisfied: jinja2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (3.1.3)\n", - "Requirement already satisfied: fsspec in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (2024.2.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from jinja2->torch>=1.10->circuitsvis) (2.1.5)\n", - "Requirement already satisfied: mpmath>=0.19 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from sympy->torch>=1.10->circuitsvis) (1.3.0)\n", - "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "\n", @@ -291,7 +198,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Moving model to device: mps\n", + "Moving model to device: cpu\n", "Loaded pretrained model bert-base-cased into HookedTransformer\n" ] } From b1d54162e9fd5d457c9003b88a8d33813c1e18c7 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Sat, 27 Apr 2024 01:09:19 +0200 Subject: [PATCH 14/42] ignored output of loading cells --- demos/BERT.ipynb | 35 ++--------------------------------- 1 file changed, 2 insertions(+), 33 deletions(-) diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb index 976cd5b6d..9d01678a1 100644 --- a/demos/BERT.ipynb +++ b/demos/BERT.ipynb @@ -169,40 +169,9 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:Support for BERT in TransformerLens is currently experimental, until such a time when it has feature parity with HookedTransformer and has been tested on real research tasks. Until then, backward compatibility is not guaranteed. Please see the docs for information on the limitations of the current implementation.\n", - "If using BERT for interpretability research, keep in mind that BERT has some significant architectural differences to GPT. 
For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "407745de57a144f09c2c0b7345c31edc", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "tokenizer.json: 0%| | 0.00/436k [00:00 Date: Fri, 3 May 2024 00:20:23 +0200 Subject: [PATCH 15/42] changed notebook tests to run on separate jobs --- .github/workflows/checks.yml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 819cb6539..e8142bcd1 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -121,6 +121,12 @@ jobs: notebook-checks: name: Notebook Checks runs-on: ubuntu-latest + strategy: + matrix: + notebook: + - "BERT" + - "Exploratory_Analysis_Demo" + - "Main_Demo" steps: - uses: actions/checkout@v3 - name: Install Poetry @@ -146,7 +152,7 @@ jobs: version: 1.0 - name: Check Notebook Output Consistency # Note: currently only checks notebooks we have specifically setup for this - run: make notebook-test + run: poetry run pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/${{ matrix.notebook }}.ipynb build-docs: From 5b63c463a78b97e35a6791504af5d033c79f8d03 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Fri, 3 May 2024 00:27:29 +0200 Subject: [PATCH 16/42] added all notebooks to CI --- .github/workflows/checks.yml | 10 ++++++++++ ...merDemo.ipynb => Hooked_SAE_Transformer_Demo.ipynb} | 0 2 files changed, 10 insertions(+) rename demos/{HookedSAETransformerDemo.ipynb => Hooked_SAE_Transformer_Demo.ipynb} (100%) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index e8142bcd1..dc5517c36 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -124,9 +124,19 @@ jobs: strategy: matrix: notebook: + - "Activation_Patching_in_TL_Demo" + - "Attribution_Patching_Demo" - "BERT" - "Exploratory_Analysis_Demo" + - "Grokking_Demo" + - "Head_Detector_Demo" + - "Hooked_SAE_Transformer_Demo" + - "Interactive_Neuroscope" + - "LLaMA" + - "LLaMA2_GPU_Quantized" - "Main_Demo" + - "No_Position_Experiment" + - "Othello_GPT" steps: - uses: actions/checkout@v3 - name: Install Poetry diff --git a/demos/HookedSAETransformerDemo.ipynb b/demos/Hooked_SAE_Transformer_Demo.ipynb similarity index 100% rename from demos/HookedSAETransformerDemo.ipynb rename to demos/Hooked_SAE_Transformer_Demo.ipynb From efcdfa890fb0c98dd8d4ee2c794bae7dc596fcc2 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Fri, 3 May 2024 00:32:05 +0200 Subject: [PATCH 17/42] renamed file temporarily --- demos/{LLaMA2_GPU_quantized.ipynb => LLaMA2.ipynb} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename demos/{LLaMA2_GPU_quantized.ipynb => LLaMA2.ipynb} (100%) diff --git a/demos/LLaMA2_GPU_quantized.ipynb b/demos/LLaMA2.ipynb similarity index 100% rename from demos/LLaMA2_GPU_quantized.ipynb rename to demos/LLaMA2.ipynb From 35f74c265f33ec4898ad8fb677c0f775a7bc4b85 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Fri, 3 May 2024 00:32:41 +0200 Subject: [PATCH 18/42] fixed file case --- demos/{LLaMA2.ipynb => LLaMA2_GPU_Quantized.ipynb} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename demos/{LLaMA2.ipynb => LLaMA2_GPU_Quantized.ipynb} (100%) diff --git a/demos/LLaMA2.ipynb b/demos/LLaMA2_GPU_Quantized.ipynb similarity index 100% rename from demos/LLaMA2.ipynb rename to demos/LLaMA2_GPU_Quantized.ipynb From 
eaba4b77d008b6482e5a57f69a811f2d58740b16 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Fri, 3 May 2024 00:48:21 +0200 Subject: [PATCH 19/42] reorganized setup --- demos/Interactive_Neuroscope.ipynb | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/demos/Interactive_Neuroscope.ipynb b/demos/Interactive_Neuroscope.ipynb index 843de35c8..1ab92567f 100644 --- a/demos/Interactive_Neuroscope.ipynb +++ b/demos/Interactive_Neuroscope.ipynb @@ -48,14 +48,16 @@ } ], "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", "import os\n", "\n", + "DEVELOPMENT_MODE = True\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", "try:\n", " import google.colab\n", - "\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", - "\n", "except:\n", " IN_COLAB = False\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", @@ -64,20 +66,11 @@ " ipython = get_ipython()\n", " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", " ipython.magic(\"load_ext autoreload\")\n", - " ipython.magic(\"autoreload 2\")" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", + " ipython.magic(\"autoreload 2\")\n", "\n", - "if IN_COLAB:\n", - " os.system(\"pip install git+https://github.com/neelnanda-io/TransformerLens.git\")\n", - " os.system(\"pip install gradio\")" + "if IN_COLAB or IN_GITHUB:\n", + " %pip install transformer_lens\n", + " %pip install gradio" ] }, { From 9f6e4dfdc0a68beecfae5c50c8140ff8f6c52d30 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Fri, 3 May 2024 01:10:01 +0200 Subject: [PATCH 20/42] updated head detector demo --- demos/Head_Detector_Demo.ipynb | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/demos/Head_Detector_Demo.ipynb b/demos/Head_Detector_Demo.ipynb index 3b4c341b5..cf19d95bb 100644 --- a/demos/Head_Detector_Demo.ipynb +++ b/demos/Head_Detector_Demo.ipynb @@ -295,21 +295,16 @@ } ], "source": [ + "# NBVAL_IGNORE_OUTPUT\n", "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", - "DEBUG_MODE = False\n", + "import os\n", + "\n", + "DEVELOPMENT_MODE = True\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", "try:\n", " import google.colab\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", - " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n", - " # Install Neel's personal plotting utils\n", - " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", - " # Install another version of node that makes PySvelte work way faster\n", - " !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", - " %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", - " # Needed for PySvelte to work, v3 came out and broke things...\n", - " %pip install typeguard==2.13.3\n", - " %pip install typing-extensions\n", "except:\n", " IN_COLAB = False\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", @@ -318,7 +313,18 @@ " ipython = get_ipython()\n", " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", " ipython.magic(\"load_ext autoreload\")\n", - " ipython.magic(\"autoreload 2\")" + " ipython.magic(\"autoreload 2\")\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", 
+ " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n", + " # Install Neel's personal plotting utils\n", + " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", + " # Install another version of node that makes PySvelte work way faster\n", + " !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", + " %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", + " # Needed for PySvelte to work, v3 came out and broke things...\n", + " %pip install typeguard==2.13.3\n", + " %pip install typing-extensions" ] }, { From a4e4c90f33824b87b99f73a0f967a13dd3680821 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Fri, 3 May 2024 01:16:25 +0200 Subject: [PATCH 21/42] reran othello --- demos/Othello_GPT.ipynb | 5317 +++------------------------------------ 1 file changed, 387 insertions(+), 4930 deletions(-) diff --git a/demos/Othello_GPT.ipynb b/demos/Othello_GPT.ipynb index 1b4400bc7..7f11a494d 100644 --- a/demos/Othello_GPT.ipynb +++ b/demos/Othello_GPT.ipynb @@ -49,28 +49,158 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Running as a Jupyter notebook - intended for development only!\n" + "Running as a Jupyter notebook - intended for development only!\n", + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_72266/3797162913.py:25: DeprecationWarning:\n", + "\n", + "`magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "\n", + "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_72266/3797162913.py:26: DeprecationWarning:\n", + "\n", + "`magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: transformer_lens in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (0.0.0)\n", + "Requirement already satisfied: accelerate>=0.23.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.29.1)\n", + "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.14.1)\n", + "Requirement already satisfied: better-abc<0.0.4,>=0.0.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.0.3)\n", + "Requirement already satisfied: datasets>=2.7.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.18.0)\n", + "Requirement already satisfied: einops>=0.6.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.7.0)\n", + "Requirement already satisfied: fancy-einsum>=0.0.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.0.3)\n", + "Requirement already satisfied: jaxtyping>=0.2.11 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.2.19)\n", + "Requirement already satisfied: numpy>=1.24 in 
/Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (1.26.4)\n", + "Requirement already satisfied: pandas>=1.1.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.0.3)\n", + "Requirement already satisfied: rich>=12.6.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (13.7.1)\n", + "Requirement already satisfied: sentencepiece in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.2.0)\n", + "Requirement already satisfied: torch!=2.0,!=2.1.0,>=1.10 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.1.2)\n", + "Requirement already satisfied: tqdm>=4.64.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.66.2)\n", + "Requirement already satisfied: transformers>=4.37.2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.39.3)\n", + "Requirement already satisfied: typing-extensions in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.11.0)\n", + "Requirement already satisfied: wandb>=0.13.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.16.6)\n", + "Requirement already satisfied: packaging>=20.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (24.0)\n", + "Requirement already satisfied: psutil in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (5.9.8)\n", + "Requirement already satisfied: pyyaml in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (6.0.1)\n", + "Requirement already satisfied: huggingface-hub in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (0.22.2)\n", + "Requirement already satisfied: safetensors>=0.3.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (0.4.2)\n", + "Requirement already satisfied: filelock in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (3.13.3)\n", + "Requirement already satisfied: pyarrow>=12.0.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (15.0.2)\n", + "Requirement already satisfied: pyarrow-hotfix in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (0.6)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (0.3.8)\n", + "Requirement already satisfied: requests>=2.19.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (2.31.0)\n", + "Requirement already satisfied: xxhash in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (3.4.1)\n", + 
"Requirement already satisfied: multiprocess in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.2.0,>=2023.1.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets>=2.7.1->transformer_lens) (2024.2.0)\n", + "Requirement already satisfied: aiohttp in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (3.9.3)\n", + "Requirement already satisfied: typeguard>=2.13.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from jaxtyping>=0.2.11->transformer_lens) (4.2.1)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from pandas>=1.1.5->transformer_lens) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from pandas>=1.1.5->transformer_lens) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from pandas>=1.1.5->transformer_lens) (2024.1)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from rich>=12.6.0->transformer_lens) (2.2.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from rich>=12.6.0->transformer_lens) (2.17.2)\n", + "Requirement already satisfied: sympy in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (1.12)\n", + "Requirement already satisfied: networkx in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (3.1)\n", + "Requirement already satisfied: jinja2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (3.1.3)\n", + "Requirement already satisfied: regex!=2019.12.17 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformers>=4.37.2->transformer_lens) (2023.12.25)\n", + "Requirement already satisfied: tokenizers<0.19,>=0.14 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformers>=4.37.2->transformer_lens) (0.15.2)\n", + "Requirement already satisfied: Click!=8.0.0,>=7.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (8.1.7)\n", + "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (3.1.43)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (1.44.1)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (0.4.0)\n", + "Requirement already satisfied: setproctitle in 
/Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (1.3.3)\n", + "Requirement already satisfied: setuptools in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (69.2.0)\n", + "Requirement already satisfied: appdirs>=1.4.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (1.4.4)\n", + "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (4.25.3)\n", + "Requirement already satisfied: six>=1.4.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from docker-pycreds>=0.4.0->wandb>=0.13.5->transformer_lens) (1.16.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.9.4)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (4.0.11)\n", + "Requirement already satisfied: mdurl~=0.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer_lens) (0.1.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2.2.1)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2024.2.2)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from jinja2->torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (2.1.5)\n", + "Requirement already satisfied: mpmath>=0.19 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from sympy->torch!=2.0,!=2.1.0,>=1.10->transformer_lens) 
(1.3.0)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (5.0.1)\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n", + "Requirement already satisfied: circuitsvis in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (1.43.2)\n", + "Requirement already satisfied: importlib-metadata>=5.1.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from circuitsvis) (7.1.0)\n", + "Requirement already satisfied: numpy>=1.24 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from circuitsvis) (1.26.4)\n", + "Requirement already satisfied: torch>=1.10 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from circuitsvis) (2.1.2)\n", + "Requirement already satisfied: zipp>=0.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from importlib-metadata>=5.1.0->circuitsvis) (3.18.1)\n", + "Requirement already satisfied: filelock in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (3.13.3)\n", + "Requirement already satisfied: typing-extensions in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (4.11.0)\n", + "Requirement already satisfied: sympy in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (1.12)\n", + "Requirement already satisfied: networkx in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (3.1)\n", + "Requirement already satisfied: jinja2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (3.1.3)\n", + "Requirement already satisfied: fsspec in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.10->circuitsvis) (2024.2.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from jinja2->torch>=1.10->circuitsvis) (2.1.5)\n", + "Requirement already satisfied: mpmath>=0.19 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from sympy->torch>=1.10->circuitsvis) (1.3.0)\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n", + "Collecting torchtyping\n", + " Using cached torchtyping-0.1.4-py3-none-any.whl.metadata (9.2 kB)\n", + "Requirement already satisfied: 
torch>=1.7.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torchtyping) (2.1.2)\n", + "Requirement already satisfied: typeguard>=2.11.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torchtyping) (4.2.1)\n", + "Requirement already satisfied: filelock in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.7.0->torchtyping) (3.13.3)\n", + "Requirement already satisfied: typing-extensions in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.7.0->torchtyping) (4.11.0)\n", + "Requirement already satisfied: sympy in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.7.0->torchtyping) (1.12)\n", + "Requirement already satisfied: networkx in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.7.0->torchtyping) (3.1)\n", + "Requirement already satisfied: jinja2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.7.0->torchtyping) (3.1.3)\n", + "Requirement already satisfied: fsspec in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch>=1.7.0->torchtyping) (2024.2.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from jinja2->torch>=1.7.0->torchtyping) (2.1.5)\n", + "Requirement already satisfied: mpmath>=0.19 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from sympy->torch>=1.7.0->torchtyping) (1.3.0)\n", + "Using cached torchtyping-0.1.4-py3-none-any.whl (17 kB)\n", + "Installing collected packages: torchtyping\n", + "Successfully installed torchtyping-0.1.4\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "import os\n", + "\n", "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", "DEVELOPMENT_MODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "IN_GITHUB = True\n", "try:\n", " import google.colab\n", + "\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", - " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n", - " %pip install circuitsvis\n", - " %pip install torchtyping\n", - " \n", + "\n", " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", @@ -83,7 +213,12 @@ " ipython = get_ipython()\n", " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", " ipython.magic(\"load_ext autoreload\")\n", - " ipython.magic(\"autoreload 2\")" + " ipython.magic(\"autoreload 2\")\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", + " %pip install transformer_lens\n", + " %pip install 
circuitsvis\n", + " %pip install torchtyping" ] }, { @@ -102,6 +237,7 @@ "source": [ "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", "import plotly.io as pio\n", + "\n", "if IN_COLAB or not DEVELOPMENT_MODE:\n", " pio.renderers.default = \"colab\"\n", "else:\n", @@ -117,4882 +253,18 @@ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -5002,13 +274,14 @@ ], "source": [ "import circuitsvis as cv\n", + "\n", "# Testing that the library works\n", "cv.examples.hello(\"Neel\")" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -5040,7 +313,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -5050,7 +323,12 @@ " HookedRootModule,\n", " HookPoint,\n", ") # Hooking utilities\n", - "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache" + "from transformer_lens import (\n", + " HookedTransformer,\n", + " HookedTransformerConfig,\n", + " FactoredMatrix,\n", + " ActivationCache,\n", + ")" ] }, { @@ -5062,16 +340,16 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 6, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -5089,20 +367,32 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n", - " px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n", + " px.imshow(\n", + " utils.to_numpy(tensor),\n", + " color_continuous_midpoint=0.0,\n", + " color_continuous_scale=\"RdBu\",\n", + " labels={\"x\": xaxis, \"y\": yaxis},\n", + " **kwargs,\n", + " ).show(renderer)\n", + "\n", "\n", "def line(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n", - " px.line(utils.to_numpy(tensor), labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n", + " px.line(utils.to_numpy(tensor), labels={\"x\": xaxis, \"y\": yaxis}, **kwargs).show(\n", + " renderer\n", + " )\n", + "\n", "\n", "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", renderer=None, **kwargs):\n", " x = utils.to_numpy(x)\n", " y = utils.to_numpy(y)\n", - " px.scatter(y=y, x=x, labels={\"x\":xaxis, \"y\":yaxis, \"color\":caxis}, **kwargs).show(renderer)" + " px.scatter(\n", + " y=y, x=x, labels={\"x\": xaxis, \"y\": yaxis, \"color\": caxis}, **kwargs\n", + " ).show(renderer)" ] }, { @@ -5115,7 +405,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -5124,33 +414,60 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import transformer_lens.utils as utils\n", + "\n", "cfg = HookedTransformerConfig(\n", - " n_layers = 8,\n", - " d_model = 512,\n", - " d_head = 64,\n", - " n_heads = 8,\n", - " d_mlp = 2048,\n", - " d_vocab = 61,\n", - " n_ctx = 59,\n", + " n_layers=8,\n", + " d_model=512,\n", + " d_head=64,\n", + " n_heads=8,\n", + " d_mlp=2048,\n", + " d_vocab=61,\n", + " n_ctx=59,\n", " act_fn=\"gelu\",\n", - " normalization_type=\"LNPre\"\n", + " normalization_type=\"LNPre\",\n", ")\n", "model = HookedTransformer(cfg)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "566376c7d21e4645958b59f758d2a971", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "./synthetic_model.pth: 0%| | 0.00/101M [00:00" + ] + }, + 
"execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "\n", - "sd = utils.download_file_from_hf(\"NeelNanda/Othello-GPT-Transformer-Lens\", \"synthetic_model.pth\")\n", + "sd = utils.download_file_from_hf(\n", + " \"NeelNanda/Othello-GPT-Transformer-Lens\", \"synthetic_model.pth\"\n", + ")\n", "# champion_ship_sd = utils.download_file_from_hf(\"NeelNanda/Othello-GPT-Transformer-Lens\", \"championship_model.pth\")\n", "model.load_state_dict(sd)" ] @@ -5165,7 +482,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -5185,25 +502,39 @@ " out_sd[f\"blocks.{layer}.ln2.b\"] = in_sd[f\"blocks.{layer}.ln2.bias\"]\n", "\n", " out_sd[f\"blocks.{layer}.attn.W_Q\"] = einops.rearrange(\n", - " in_sd[f\"blocks.{layer}.attn.query.weight\"], \"(head d_head) d_model -> head d_model d_head\", head=n_heads\n", + " in_sd[f\"blocks.{layer}.attn.query.weight\"],\n", + " \"(head d_head) d_model -> head d_model d_head\",\n", + " head=n_heads,\n", " )\n", " out_sd[f\"blocks.{layer}.attn.b_Q\"] = einops.rearrange(\n", - " in_sd[f\"blocks.{layer}.attn.query.bias\"], \"(head d_head) -> head d_head\", head=n_heads\n", + " in_sd[f\"blocks.{layer}.attn.query.bias\"],\n", + " \"(head d_head) -> head d_head\",\n", + " head=n_heads,\n", " )\n", " out_sd[f\"blocks.{layer}.attn.W_K\"] = einops.rearrange(\n", - " in_sd[f\"blocks.{layer}.attn.key.weight\"], \"(head d_head) d_model -> head d_model d_head\", head=n_heads\n", + " in_sd[f\"blocks.{layer}.attn.key.weight\"],\n", + " \"(head d_head) d_model -> head d_model d_head\",\n", + " head=n_heads,\n", " )\n", " out_sd[f\"blocks.{layer}.attn.b_K\"] = einops.rearrange(\n", - " in_sd[f\"blocks.{layer}.attn.key.bias\"], \"(head d_head) -> head d_head\", head=n_heads\n", + " in_sd[f\"blocks.{layer}.attn.key.bias\"],\n", + " \"(head d_head) -> head d_head\",\n", + " head=n_heads,\n", " )\n", " out_sd[f\"blocks.{layer}.attn.W_V\"] = einops.rearrange(\n", - " in_sd[f\"blocks.{layer}.attn.value.weight\"], \"(head d_head) d_model -> head d_model d_head\", head=n_heads\n", + " in_sd[f\"blocks.{layer}.attn.value.weight\"],\n", + " \"(head d_head) d_model -> head d_model d_head\",\n", + " head=n_heads,\n", " )\n", " out_sd[f\"blocks.{layer}.attn.b_V\"] = einops.rearrange(\n", - " in_sd[f\"blocks.{layer}.attn.value.bias\"], \"(head d_head) -> head d_head\", head=n_heads\n", + " in_sd[f\"blocks.{layer}.attn.value.bias\"],\n", + " \"(head d_head) -> head d_head\",\n", + " head=n_heads,\n", " )\n", " out_sd[f\"blocks.{layer}.attn.W_O\"] = einops.rearrange(\n", - " in_sd[f\"blocks.{layer}.attn.proj.weight\"], \"d_model (head d_head) -> head d_head d_model\", head=n_heads\n", + " in_sd[f\"blocks.{layer}.attn.proj.weight\"],\n", + " \"d_model (head d_head) -> head d_head d_model\",\n", + " head=n_heads,\n", " )\n", " out_sd[f\"blocks.{layer}.attn.b_O\"] = in_sd[f\"blocks.{layer}.attn.proj.bias\"]\n", "\n", @@ -5211,31 +542,32 @@ " out_sd[f\"blocks.{layer}.mlp.W_in\"] = in_sd[f\"blocks.{layer}.mlp.0.weight\"].T\n", " out_sd[f\"blocks.{layer}.mlp.b_out\"] = in_sd[f\"blocks.{layer}.mlp.2.bias\"]\n", " out_sd[f\"blocks.{layer}.mlp.W_out\"] = in_sd[f\"blocks.{layer}.mlp.2.weight\"].T\n", - " \n", + "\n", " return out_sd\n", "\n", - "if LOAD_AND_CONVERT_CHECKPOINT:\n", "\n", + "if LOAD_AND_CONVERT_CHECKPOINT:\n", " synthetic_checkpoint = torch.load(\"/workspace/othello_world/gpt_synthetic.ckpt\")\n", " for name, param in synthetic_checkpoint.items():\n", " if 
name.startswith(\"blocks.0\") or not name.startswith(\"blocks\"):\n", " print(name, param.shape)\n", "\n", " cfg = HookedTransformerConfig(\n", - " n_layers = 8,\n", - " d_model = 512,\n", - " d_head = 64,\n", - " n_heads = 8,\n", - " d_mlp = 2048,\n", - " d_vocab = 61,\n", - " n_ctx = 59,\n", + " n_layers=8,\n", + " d_model=512,\n", + " d_head=64,\n", + " n_heads=8,\n", + " d_mlp=2048,\n", + " d_vocab=61,\n", + " n_ctx=59,\n", " act_fn=\"gelu\",\n", - " normalization_type=\"LNPre\"\n", + " normalization_type=\"LNPre\",\n", " )\n", " model = HookedTransformer(cfg)\n", "\n", - "\n", - " model.load_and_process_state_dict(convert_to_transformer_lens_format(synthetic_checkpoint))\n" + " model.load_and_process_state_dict(\n", + " convert_to_transformer_lens_format(synthetic_checkpoint)\n", + " )" ] }, { @@ -5248,7 +580,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -5257,22 +589,147 @@ "tensor([[21, 41, 40, 34, 40, 41, 3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33, 5,\n", " 33, 5, 52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59,\n", " 50, 28, 14, 28, 28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15,\n", - " 14, 15, 8, 7, 8]], device='cuda:0')" + " 14, 15, 8, 7, 8]])" ] }, - "execution_count": 36, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# An example input\n", - "sample_input = torch.tensor([[20, 19, 18, 10, 2, 1, 27, 3, 41, 42, 34, 12, 4, 40, 11, 29, 43, 13, 48, 56, 33, 39, 22, 44, 24, 5, 46, 6, 32, 36, 51, 58, 52, 60, 21, 53, 26, 31, 37, 9, 25, 38, 23, 50, 45, 17, 47, 28, 35, 30, 54, 16, 59, 49, 57, 14, 15, 55, 7]])\n", + "sample_input = torch.tensor(\n", + " [\n", + " [\n", + " 20,\n", + " 19,\n", + " 18,\n", + " 10,\n", + " 2,\n", + " 1,\n", + " 27,\n", + " 3,\n", + " 41,\n", + " 42,\n", + " 34,\n", + " 12,\n", + " 4,\n", + " 40,\n", + " 11,\n", + " 29,\n", + " 43,\n", + " 13,\n", + " 48,\n", + " 56,\n", + " 33,\n", + " 39,\n", + " 22,\n", + " 44,\n", + " 24,\n", + " 5,\n", + " 46,\n", + " 6,\n", + " 32,\n", + " 36,\n", + " 51,\n", + " 58,\n", + " 52,\n", + " 60,\n", + " 21,\n", + " 53,\n", + " 26,\n", + " 31,\n", + " 37,\n", + " 9,\n", + " 25,\n", + " 38,\n", + " 23,\n", + " 50,\n", + " 45,\n", + " 17,\n", + " 47,\n", + " 28,\n", + " 35,\n", + " 30,\n", + " 54,\n", + " 16,\n", + " 59,\n", + " 49,\n", + " 57,\n", + " 14,\n", + " 15,\n", + " 55,\n", + " 7,\n", + " ]\n", + " ]\n", + ")\n", "# The argmax of the output (ie the most likely next move from each position)\n", - "sample_output = torch.tensor([[21, 41, 40, 34, 40, 41, 3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33, 5,\n", - " 33, 5, 52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59,\n", - " 50, 28, 14, 28, 28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15,\n", - " 14, 15, 8, 7, 8]])\n", + "sample_output = torch.tensor(\n", + " [\n", + " [\n", + " 21,\n", + " 41,\n", + " 40,\n", + " 34,\n", + " 40,\n", + " 41,\n", + " 3,\n", + " 11,\n", + " 21,\n", + " 43,\n", + " 40,\n", + " 21,\n", + " 28,\n", + " 50,\n", + " 33,\n", + " 50,\n", + " 33,\n", + " 5,\n", + " 33,\n", + " 5,\n", + " 52,\n", + " 46,\n", + " 14,\n", + " 46,\n", + " 14,\n", + " 47,\n", + " 38,\n", + " 57,\n", + " 36,\n", + " 50,\n", + " 38,\n", + " 15,\n", + " 28,\n", + " 26,\n", + " 28,\n", + " 59,\n", + " 50,\n", + " 28,\n", + " 14,\n", + " 28,\n", + " 28,\n", + " 28,\n", + " 28,\n", + " 45,\n", + " 28,\n", + " 35,\n", + " 15,\n", + " 14,\n", + " 30,\n", + " 59,\n", + " 49,\n", + " 59,\n", + " 15,\n", + " 15,\n", + " 14,\n", + 
" 15,\n", + " 8,\n", + " 7,\n", + " 8,\n", + " ]\n", + " ]\n", + ")\n", "model(sample_input).argmax(dim=-1)" ] } @@ -5293,7 +750,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.13" + "version": "3.11.8" }, "orig_nbformat": 4, "vscode": { From cda93e77f9e9a4ff75259547376e65298d532ea5 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Fri, 3 May 2024 01:30:31 +0200 Subject: [PATCH 22/42] updated installation section --- demos/LLaMA2_GPU_Quantized.ipynb | 46 +++++++++----------------------- 1 file changed, 13 insertions(+), 33 deletions(-) diff --git a/demos/LLaMA2_GPU_Quantized.ipynb b/demos/LLaMA2_GPU_Quantized.ipynb index 58631a21e..c6d739baf 100644 --- a/demos/LLaMA2_GPU_Quantized.ipynb +++ b/demos/LLaMA2_GPU_Quantized.ipynb @@ -18,30 +18,6 @@ "## Setup (skip)" ] }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "HssVtL08CUsP", - "outputId": "5ad91c32-95e8-4970-99ec-242f9e2ebab2" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.1.99)\n" - ] - } - ], - "source": [ - "%pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8\n", - "%pip install sentencepiece # Llama tokenizer requires sentencepiece" - ] - }, { "cell_type": "code", "execution_count": 2, @@ -173,21 +149,18 @@ } ], "source": [ + "# NBVAL_IGNORE_OUTPUT\n", "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", "DEVELOPMENT_MODE = False\n", "IN_VSCODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "\n", "try:\n", " import google.colab\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", - " # %pip install git+https://github.com/neelnanda-io/TransformerLens.git``\n", - " %pip install git+https://github.com/coolvision/TransformerLens.git@llama_4bit_v2``\n", - " %pip install circuitsvis\n", - "\n", - " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", - " # # Install another version of node that makes PySvelte work way faster\n", - " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", - " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", "except:\n", " IN_COLAB = False\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", @@ -196,7 +169,14 @@ " ipython = get_ipython()\n", " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", " ipython.magic(\"load_ext autoreload\")\n", - " ipython.magic(\"autoreload 2\")" + " ipython.magic(\"autoreload 2\")\n", + " \n", + "%pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8\n", + "%pip install sentencepiece # Llama tokenizer requires sentencepiece\n", + " \n", + "if IN_GITHUB or IN_COLAB:\n", + " %pip install transformer_lens\n", + " %pip install circuitsvis" ] }, { From 5ae2377b9b6854d6031d58a653f95de830bef8c4 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Fri, 3 May 2024 02:11:31 +0200 Subject: [PATCH 23/42] updated no position install to install deps in github --- demos/No_Position_Experiment.ipynb | 19 +++++++++++++++++-- demos/Othello_GPT.ipynb | 1 - 2 files changed, 17 insertions(+), 3 deletions(-) diff 
--git a/demos/No_Position_Experiment.ipynb b/demos/No_Position_Experiment.ipynb index 98b2ddf2a..e52d98765 100644 --- a/demos/No_Position_Experiment.ipynb +++ b/demos/No_Position_Experiment.ipynb @@ -39,14 +39,29 @@ "metadata": {}, "outputs": [], "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "import os\n", + "\n", + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "DEVELOPMENT_MODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", "try:\n", " import google.colab\n", "\n", " IN_COLAB = True\n", - " !pip install einops\n", - " %pip install transformer_lens\n", + " print(\"Running as a Colab notebook\")\n", "except:\n", " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " \n", + "if IN_COLAB or IN_GITHUB:\n", + " %pip install einops\n", + " %pip install transformer_lens\n", + "\n", + " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", + " # # Install another version of node that makes PySvelte work way faster\n", + " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", "\n", "from transformer_lens import HookedTransformer, HookedTransformerConfig\n", "import torch\n", diff --git a/demos/Othello_GPT.ipynb b/demos/Othello_GPT.ipynb index 7f11a494d..6c93eb57d 100644 --- a/demos/Othello_GPT.ipynb +++ b/demos/Othello_GPT.ipynb @@ -194,7 +194,6 @@ "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", "DEVELOPMENT_MODE = False\n", "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", - "IN_GITHUB = True\n", "try:\n", " import google.colab\n", "\n", From 6f5e77c3d88e6df8c41bc39980474888bd18edaf Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Fri, 3 May 2024 02:36:03 +0200 Subject: [PATCH 24/42] updated output of beginning areas --- demos/No_Position_Experiment.ipynb | 99 ++++++++++++++++++++++++++---- 1 file changed, 87 insertions(+), 12 deletions(-) diff --git a/demos/No_Position_Experiment.ipynb b/demos/No_Position_Experiment.ipynb index e52d98765..00ec0af94 100644 --- a/demos/No_Position_Experiment.ipynb +++ b/demos/No_Position_Experiment.ipynb @@ -37,7 +37,89 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n", + "Requirement already satisfied: einops in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (0.7.0)\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n", + "Requirement already satisfied: transformer_lens in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (0.0.0)\n", + "Requirement already satisfied: accelerate>=0.23.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.29.1)\n", + "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in 
/Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.14.1)\n", + "Requirement already satisfied: better-abc<0.0.4,>=0.0.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.0.3)\n", + "Requirement already satisfied: datasets>=2.7.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.18.0)\n", + "Requirement already satisfied: einops>=0.6.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.7.0)\n", + "Requirement already satisfied: fancy-einsum>=0.0.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.0.3)\n", + "Requirement already satisfied: jaxtyping>=0.2.11 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.2.19)\n", + "Requirement already satisfied: numpy>=1.24 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (1.26.4)\n", + "Requirement already satisfied: pandas>=1.1.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.0.3)\n", + "Requirement already satisfied: rich>=12.6.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (13.7.1)\n", + "Requirement already satisfied: sentencepiece in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.2.0)\n", + "Requirement already satisfied: torch!=2.0,!=2.1.0,>=1.10 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.1.2)\n", + "Requirement already satisfied: tqdm>=4.64.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.66.2)\n", + "Requirement already satisfied: transformers>=4.37.2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.39.3)\n", + "Requirement already satisfied: typing-extensions in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.11.0)\n", + "Requirement already satisfied: wandb>=0.13.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.16.6)\n", + "Requirement already satisfied: packaging>=20.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (24.0)\n", + "Requirement already satisfied: psutil in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (5.9.8)\n", + "Requirement already satisfied: pyyaml in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (6.0.1)\n", + "Requirement already satisfied: huggingface-hub in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (0.22.2)\n", + "Requirement already satisfied: safetensors>=0.3.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (0.4.2)\n", + "Requirement already satisfied: filelock in 
/Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (3.13.3)\n", + "Requirement already satisfied: pyarrow>=12.0.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (15.0.2)\n", + "Requirement already satisfied: pyarrow-hotfix in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (0.6)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (0.3.8)\n", + "Requirement already satisfied: requests>=2.19.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (2.31.0)\n", + "Requirement already satisfied: xxhash in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.2.0,>=2023.1.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets>=2.7.1->transformer_lens) (2024.2.0)\n", + "Requirement already satisfied: aiohttp in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (3.9.3)\n", + "Requirement already satisfied: typeguard>=2.13.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from jaxtyping>=0.2.11->transformer_lens) (4.2.1)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from pandas>=1.1.5->transformer_lens) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from pandas>=1.1.5->transformer_lens) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from pandas>=1.1.5->transformer_lens) (2024.1)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from rich>=12.6.0->transformer_lens) (2.2.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from rich>=12.6.0->transformer_lens) (2.17.2)\n", + "Requirement already satisfied: sympy in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (1.12)\n", + "Requirement already satisfied: networkx in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (3.1)\n", + "Requirement already satisfied: jinja2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (3.1.3)\n", + "Requirement already satisfied: regex!=2019.12.17 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from 
transformers>=4.37.2->transformer_lens) (2023.12.25)\n", + "Requirement already satisfied: tokenizers<0.19,>=0.14 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformers>=4.37.2->transformer_lens) (0.15.2)\n", + "Requirement already satisfied: Click!=8.0.0,>=7.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (8.1.7)\n", + "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (3.1.43)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (1.44.1)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (0.4.0)\n", + "Requirement already satisfied: setproctitle in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (1.3.3)\n", + "Requirement already satisfied: setuptools in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (69.2.0)\n", + "Requirement already satisfied: appdirs>=1.4.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (1.4.4)\n", + "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (4.25.3)\n", + "Requirement already satisfied: six>=1.4.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from docker-pycreds>=0.4.0->wandb>=0.13.5->transformer_lens) (1.16.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.9.4)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (4.0.11)\n", + "Requirement already satisfied: mdurl~=0.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer_lens) (0.1.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from 
requests>=2.19.0->datasets>=2.7.1->transformer_lens) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2.2.1)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2024.2.2)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from jinja2->torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (2.1.5)\n", + "Requirement already satisfied: mpmath>=0.19 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from sympy->torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (1.3.0)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (5.0.1)\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ "# NBVAL_IGNORE_OUTPUT\n", "import os\n", @@ -45,6 +127,7 @@ "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", "DEVELOPMENT_MODE = False\n", "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "IN_GITHUB = True\n", "try:\n", " import google.colab\n", "\n", @@ -53,7 +136,7 @@ "except:\n", " IN_COLAB = False\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", - " \n", + "\n", "if IN_COLAB or IN_GITHUB:\n", " %pip install einops\n", " %pip install transformer_lens\n", @@ -134,15 +217,7 @@ "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Moving model to device: cuda\n" - ] - } - ], + "outputs": [], "source": [ "cfg = HookedTransformerConfig(\n", " n_layers=2,\n", @@ -1405,7 +1480,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.13" + "version": "3.11.8" }, "vscode": { "interpreter": { From ef93d5da1f5e51630313b639f128dddd20e61b44 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Fri, 3 May 2024 02:46:00 +0200 Subject: [PATCH 25/42] updated starting block for llama --- demos/LLaMA.ipynb | 61 +++++++++++++---------------------------------- 1 file changed, 16 insertions(+), 45 deletions(-) diff --git a/demos/LLaMA.ipynb b/demos/LLaMA.ipynb index 9e9f428e6..35eb81d76 100644 --- a/demos/LLaMA.ipynb +++ b/demos/LLaMA.ipynb @@ -25,26 +25,6 @@ "## Setup (skip)" ] }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Note: you may need to restart the kernel to use updated packages.\n", - "Requirement already 
satisfied: sentencepiece in /root/TransformerLens/.venv/lib/python3.10/site-packages (0.1.99)\n", - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], - "source": [ - "%pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8\n", - "%pip install sentencepiece # Llama tokenizer requires sentencepiece" - ] - }, { "cell_type": "code", "execution_count": 2, @@ -69,19 +49,18 @@ } ], "source": [ + "# NBVAL_IGNORE_OUTPUT\n", "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", "DEVELOPMENT_MODE = False\n", + "IN_VSCODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "\n", "try:\n", " import google.colab\n", " IN_COLAB = True\n", " print(\"Running as a Colab notebook\")\n", - " %pip install git+https://github.com/neelnanda-io/TransformerLens.git``\n", - " %pip install circuitsvis\n", - " \n", - " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", - " # # Install another version of node that makes PySvelte work way faster\n", - " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", - " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", "except:\n", " IN_COLAB = False\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", @@ -90,23 +69,15 @@ " ipython = get_ipython()\n", " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", " ipython.magic(\"load_ext autoreload\")\n", - " ipython.magic(\"autoreload 2\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using renderer: colab\n" - ] - } - ], - "source": [ + " ipython.magic(\"autoreload 2\")\n", + " \n", + "%pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8\n", + "%pip install sentencepiece # Llama tokenizer requires sentencepiece\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", + " %pip install git+https://github.com/neelnanda-io/TransformerLens.git``\n", + " %pip install circuitsvis\n", + " \n", "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", "import plotly.io as pio\n", "if IN_COLAB or not DEVELOPMENT_MODE:\n", @@ -470,7 +441,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.11.8" }, "orig_nbformat": 4, "vscode": { From bef4d99909a7a7a8178fba4df5003c5cd2efe76b Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Fri, 3 May 2024 02:53:51 +0200 Subject: [PATCH 26/42] regenerated no position experiment --- demos/No_Position_Experiment.ipynb | 436 +++++++++++------------------ 1 file changed, 163 insertions(+), 273 deletions(-) diff --git a/demos/No_Position_Experiment.ipynb b/demos/No_Position_Experiment.ipynb index 00ec0af94..831a543f0 100644 --- a/demos/No_Position_Experiment.ipynb +++ b/demos/No_Position_Experiment.ipynb @@ -35,88 +35,14 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Running as a Jupyter notebook - intended for development only!\n", - "Requirement already satisfied: einops in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (0.7.0)\n", - "\n", - 
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", - "Note: you may need to restart the kernel to use updated packages.\n", - "Requirement already satisfied: transformer_lens in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (0.0.0)\n", - "Requirement already satisfied: accelerate>=0.23.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.29.1)\n", - "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.14.1)\n", - "Requirement already satisfied: better-abc<0.0.4,>=0.0.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.0.3)\n", - "Requirement already satisfied: datasets>=2.7.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.18.0)\n", - "Requirement already satisfied: einops>=0.6.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.7.0)\n", - "Requirement already satisfied: fancy-einsum>=0.0.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.0.3)\n", - "Requirement already satisfied: jaxtyping>=0.2.11 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.2.19)\n", - "Requirement already satisfied: numpy>=1.24 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (1.26.4)\n", - "Requirement already satisfied: pandas>=1.1.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.0.3)\n", - "Requirement already satisfied: rich>=12.6.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (13.7.1)\n", - "Requirement already satisfied: sentencepiece in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.2.0)\n", - "Requirement already satisfied: torch!=2.0,!=2.1.0,>=1.10 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (2.1.2)\n", - "Requirement already satisfied: tqdm>=4.64.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.66.2)\n", - "Requirement already satisfied: transformers>=4.37.2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.39.3)\n", - "Requirement already satisfied: typing-extensions in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (4.11.0)\n", - "Requirement already satisfied: wandb>=0.13.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformer_lens) (0.16.6)\n", - "Requirement already satisfied: packaging>=20.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) 
(24.0)\n", - "Requirement already satisfied: psutil in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (5.9.8)\n", - "Requirement already satisfied: pyyaml in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (6.0.1)\n", - "Requirement already satisfied: huggingface-hub in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (0.22.2)\n", - "Requirement already satisfied: safetensors>=0.3.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from accelerate>=0.23.0->transformer_lens) (0.4.2)\n", - "Requirement already satisfied: filelock in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (3.13.3)\n", - "Requirement already satisfied: pyarrow>=12.0.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (15.0.2)\n", - "Requirement already satisfied: pyarrow-hotfix in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (0.6)\n", - "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (0.3.8)\n", - "Requirement already satisfied: requests>=2.19.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (2.31.0)\n", - "Requirement already satisfied: xxhash in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (3.4.1)\n", - "Requirement already satisfied: multiprocess in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (0.70.16)\n", - "Requirement already satisfied: fsspec<=2024.2.0,>=2023.1.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets>=2.7.1->transformer_lens) (2024.2.0)\n", - "Requirement already satisfied: aiohttp in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from datasets>=2.7.1->transformer_lens) (3.9.3)\n", - "Requirement already satisfied: typeguard>=2.13.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from jaxtyping>=0.2.11->transformer_lens) (4.2.1)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from pandas>=1.1.5->transformer_lens) (2.9.0.post0)\n", - "Requirement already satisfied: pytz>=2020.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from pandas>=1.1.5->transformer_lens) (2024.1)\n", - "Requirement already satisfied: tzdata>=2022.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from pandas>=1.1.5->transformer_lens) (2024.1)\n", - "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from rich>=12.6.0->transformer_lens) (2.2.0)\n", - "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in 
/Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from rich>=12.6.0->transformer_lens) (2.17.2)\n", - "Requirement already satisfied: sympy in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (1.12)\n", - "Requirement already satisfied: networkx in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (3.1)\n", - "Requirement already satisfied: jinja2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (3.1.3)\n", - "Requirement already satisfied: regex!=2019.12.17 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformers>=4.37.2->transformer_lens) (2023.12.25)\n", - "Requirement already satisfied: tokenizers<0.19,>=0.14 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from transformers>=4.37.2->transformer_lens) (0.15.2)\n", - "Requirement already satisfied: Click!=8.0.0,>=7.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (8.1.7)\n", - "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (3.1.43)\n", - "Requirement already satisfied: sentry-sdk>=1.0.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (1.44.1)\n", - "Requirement already satisfied: docker-pycreds>=0.4.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (0.4.0)\n", - "Requirement already satisfied: setproctitle in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (1.3.3)\n", - "Requirement already satisfied: setuptools in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (69.2.0)\n", - "Requirement already satisfied: appdirs>=1.4.3 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (1.4.4)\n", - "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from wandb>=0.13.5->transformer_lens) (4.25.3)\n", - "Requirement already satisfied: six>=1.4.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from docker-pycreds>=0.4.0->wandb>=0.13.5->transformer_lens) (1.16.0)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.3.1)\n", - "Requirement already satisfied: attrs>=17.3.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (23.2.0)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.4.1)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages 
(from aiohttp->datasets>=2.7.1->transformer_lens) (6.0.5)\n", - "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.9.4)\n", - "Requirement already satisfied: gitdb<5,>=4.0.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (4.0.11)\n", - "Requirement already satisfied: mdurl~=0.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer_lens) (0.1.2)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (3.3.2)\n", - "Requirement already satisfied: idna<4,>=2.5 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (3.6)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2.2.1)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2024.2.2)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from jinja2->torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (2.1.5)\n", - "Requirement already satisfied: mpmath>=0.19 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from sympy->torch!=2.0,!=2.1.0,>=1.10->transformer_lens) (1.3.0)\n", - "Requirement already satisfied: smmap<6,>=3.0.1 in /Users/bryce/Projects/Lingwave/TransformerLens/.venv/lib/python3.11/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (5.0.1)\n", - "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", - "Note: you may need to restart the kernel to use updated packages.\n" + "Running as a Jupyter notebook - intended for development only!\n" ] } ], @@ -127,7 +53,6 @@ "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", "DEVELOPMENT_MODE = False\n", "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", - "IN_GITHUB = True\n", "try:\n", " import google.colab\n", "\n", @@ -139,7 +64,7 @@ "\n", "if IN_COLAB or IN_GITHUB:\n", " %pip install einops\n", - " %pip install transformer_lens\n", + " %pip install transformer_lens@v1.15.0\n", "\n", " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", @@ -155,7 +80,9 @@ "pio.renderers.default = \"colab\"\n", "import tqdm.auto as tqdm\n", "import einops\n", - "from transformer_lens.utils import to_numpy" + "from transformer_lens.utils import to_numpy\n", + "\n", + "device = \"cuda\" if 
torch.cuda.is_available() else \"cpu\"" ] }, { @@ -167,7 +94,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -215,7 +142,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -229,13 +156,14 @@ " n_ctx=50,\n", " act_fn=\"relu\",\n", " normalization_type=\"LN\",\n", + " device=device,\n", ")\n", "model = HookedTransformer(cfg)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -249,7 +177,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -262,7 +190,7 @@ " (pos_embed): PosEmbed()\n", " (hook_pos_embed): HookPoint()\n", " (blocks): ModuleList(\n", - " (0): TransformerBlock(\n", + " (0-1): 2 x TransformerBlock(\n", " (ln1): LayerNorm(\n", " (hook_scale): HookPoint()\n", " (hook_normalized): HookPoint()\n", @@ -277,41 +205,18 @@ " (hook_v): HookPoint()\n", " (hook_z): HookPoint()\n", " (hook_attn_scores): HookPoint()\n", - " (hook_attn): HookPoint()\n", - " (hook_result): HookPoint()\n", - " )\n", - " (mlp): MLP(\n", - " (hook_pre): HookPoint()\n", - " (hook_post): HookPoint()\n", - " )\n", - " (hook_attn_out): HookPoint()\n", - " (hook_mlp_out): HookPoint()\n", - " (hook_resid_pre): HookPoint()\n", - " (hook_resid_mid): HookPoint()\n", - " (hook_resid_post): HookPoint()\n", - " )\n", - " (1): TransformerBlock(\n", - " (ln1): LayerNorm(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (ln2): LayerNorm(\n", - " (hook_scale): HookPoint()\n", - " (hook_normalized): HookPoint()\n", - " )\n", - " (attn): Attention(\n", - " (hook_k): HookPoint()\n", - " (hook_q): HookPoint()\n", - " (hook_v): HookPoint()\n", - " (hook_z): HookPoint()\n", - " (hook_attn_scores): HookPoint()\n", - " (hook_attn): HookPoint()\n", + " (hook_pattern): HookPoint()\n", " (hook_result): HookPoint()\n", " )\n", " (mlp): MLP(\n", " (hook_pre): HookPoint()\n", " (hook_post): HookPoint()\n", " )\n", + " (hook_attn_in): HookPoint()\n", + " (hook_q_input): HookPoint()\n", + " (hook_k_input): HookPoint()\n", + " (hook_v_input): HookPoint()\n", + " (hook_mlp_in): HookPoint()\n", " (hook_attn_out): HookPoint()\n", " (hook_mlp_out): HookPoint()\n", " (hook_resid_pre): HookPoint()\n", @@ -341,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 38, "metadata": {}, "outputs": [ { @@ -375,7 +280,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -394,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 40, "metadata": {}, "outputs": [ { @@ -427,7 +332,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ @@ -452,13 +357,13 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "24bee9140cc842458ddb27db29714435", + "model_id": "122c183908104b04a600bfe4aca9f009", "version_major": 2, "version_minor": 0 }, @@ -473,46 +378,46 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0: 6.0046892166137695\n", - "Epoch 100: 5.636989593505859\n", - "Epoch 200: 5.43870210647583\n", - "Epoch 300: 5.084463119506836\n", - "Epoch 400: 4.105402946472168\n", - "Epoch 500: 2.667759418487549\n", - 
"Epoch 600: 1.4100089073181152\n", - "Epoch 700: 0.6534866094589233\n", - "Epoch 800: 0.26310086250305176\n", - "Epoch 900: 0.094235360622406\n", - "Epoch 1000: 0.07662342488765717\n", - "Epoch 1100: 0.05123501643538475\n", - "Epoch 1200: 0.0633467361330986\n", - "Epoch 1300: 0.0698024183511734\n", - "Epoch 1400: 0.03592035919427872\n", - "Epoch 1500: 0.06732264906167984\n", - "Epoch 1600: 0.028138982132077217\n", - "Epoch 1700: 0.02272624894976616\n", - "Epoch 1800: 0.02585722878575325\n", - "Epoch 1900: 0.04599686339497566\n", - "Epoch 2000: 0.21788650751113892\n", - "Epoch 2100: 0.052709151059389114\n", - "Epoch 2200: 0.025653734803199768\n", - "Epoch 2300: 0.03516862168908119\n", - "Epoch 2400: 0.017889760434627533\n", - "Epoch 2500: 0.013999780640006065\n", - "Epoch 2600: 0.036015357822179794\n", - "Epoch 2700: 0.021333860233426094\n", - "Epoch 2800: 0.07593370974063873\n", - "Epoch 2900: 0.01114147063344717\n", - "Epoch 3000: 0.007803339511156082\n", - "Epoch 3100: 0.008570970967411995\n", - "Epoch 3200: 0.007602860685437918\n", - "Epoch 3300: 0.007690392434597015\n", - "Epoch 3400: 0.007755259983241558\n", - "Epoch 3500: 0.008618775755167007\n", - "Epoch 3600: 0.003434383077546954\n", - "Epoch 3700: 0.010702412575483322\n", - "Epoch 3800: 0.005929325707256794\n", - "Epoch 3900: 0.0032337282318621874\n" + "Epoch 0: 6.039131164550781\n", + "Epoch 100: 5.773892879486084\n", + "Epoch 200: 5.573237895965576\n", + "Epoch 300: 5.444890022277832\n", + "Epoch 400: 5.3126444816589355\n", + "Epoch 500: 5.152464389801025\n", + "Epoch 600: 4.953516483306885\n", + "Epoch 700: 4.677230358123779\n", + "Epoch 800: 4.353099822998047\n", + "Epoch 900: 3.9406914710998535\n", + "Epoch 1000: 3.4933784008026123\n", + "Epoch 1100: 3.07138991355896\n", + "Epoch 1200: 2.6529295444488525\n", + "Epoch 1300: 2.2651336193084717\n", + "Epoch 1400: 1.9132359027862549\n", + "Epoch 1500: 1.576438307762146\n", + "Epoch 1600: 1.2859177589416504\n", + "Epoch 1700: 1.0253156423568726\n", + "Epoch 1800: 0.8068246841430664\n", + "Epoch 1900: 0.6299871802330017\n", + "Epoch 2000: 0.47548314929008484\n", + "Epoch 2100: 0.3611340820789337\n", + "Epoch 2200: 0.2577555775642395\n", + "Epoch 2300: 0.19410978257656097\n", + "Epoch 2400: 0.14035893976688385\n", + "Epoch 2500: 0.10599333792924881\n", + "Epoch 2600: 0.07851045578718185\n", + "Epoch 2700: 0.055136531591415405\n", + "Epoch 2800: 0.041809480637311935\n", + "Epoch 2900: 0.0317872129380703\n", + "Epoch 3000: 0.025179561227560043\n", + "Epoch 3100: 0.017474526539444923\n", + "Epoch 3200: 0.016747349873185158\n", + "Epoch 3300: 0.011254986748099327\n", + "Epoch 3400: 0.010616127401590347\n", + "Epoch 3500: 0.0072231595404446125\n", + "Epoch 3600: 0.006508614867925644\n", + "Epoch 3700: 0.0050914837047457695\n", + "Epoch 3800: 0.005069360602647066\n", + "Epoch 3900: 0.00448464322835207\n" ] }, { @@ -522,9 +427,9 @@ "\n", "\n", "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "