Skip to content

Commit

Permalink
Add layernorm things in analysis_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonGross committed Oct 16, 2023
1 parent 9c96100 commit deeddd8
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions training/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,19 @@ def make_local_tqdm(tqdm):
else:
return tqdm

# %%
@torch.no_grad()
def layernorm_noscale(x: torch.Tensor) -> torch.Tensor:
    """Center ``x`` by subtracting its mean over the last dimension.

    No rescaling is applied. The mean is taken with ``keepdim=True`` so it
    broadcasts against ``x``; the result has the same shape as ``x``.
    Runs under ``torch.no_grad()``, so the output is not tracked by autograd.
    """
    last_dim_mean = x.mean(dim=-1, keepdim=True)
    return x - last_dim_mean

# %%
@torch.no_grad()
def layernorm_scales(x: torch.Tensor, eps: float = 1e-5, recip: bool = True) -> torch.Tensor:
    """Compute the scale factor associated with centering ``x`` over its last dimension.

    ``x`` is first centered with ``layernorm_noscale``; the scale is then
    ``sqrt(mean(centered ** 2) + eps)``, with the mean taken over the last
    dimension and that dimension kept with size 1.

    Args:
        x: Input tensor.
        eps: Constant added to the mean of squares before the square root.
        recip: If true (the default), return ``1 / scale`` instead of ``scale``.

    Returns:
        The scale (or its reciprocal), with the last dimension reduced to size 1.
    """
    centered = layernorm_noscale(x)
    mean_sq = centered.pow(2).mean(dim=-1, keepdim=True)
    root = (mean_sq + eps).sqrt()
    return 1 / root if recip else root

# %%

def display_size_direction_stats(size_direction: torch.Tensor, QK: torch.Tensor, U: torch.Tensor, Vh: torch.Tensor, S: torch.Tensor,
Expand Down Expand Up @@ -1065,13 +1078,12 @@ def find_size_direction(model: HookedTransformer, plot_heatmaps=False, renderer=



# from train_max_of_2 import get_model
# from tqdm.auto import tqdm


# if __name__ == '__main__':
# from train_max_of_2 import get_model
# from tqdm.auto import tqdm

# TRAIN_IF_NECESSARY = False
# model = get_model(train_if_necessary=TRAIN_IF_NECESSARY)

# find_size_direction(model, plot_heatmaps=True, renderer='png')
# find_size_direction(model, plot_heatmaps=True)#, renderer='png')
# %%

0 comments on commit deeddd8

Please sign in to comment.