Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Encode with slice #334

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 78 additions & 28 deletions sae_lens/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ class SAEConfig:

@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":

# rename dict:
rename_dict = { # old : new
"hook_point": "hook_name",
Expand Down Expand Up @@ -155,17 +154,8 @@ def __init__(
self.device = torch.device(cfg.device)
self.use_error_term = use_error_term

if self.cfg.architecture == "standard":
self.initialize_weights_basic()
self.encode = self.encode_standard
elif self.cfg.architecture == "gated":
self.initialize_weights_gated()
self.encode = self.encode_gated
elif self.cfg.architecture == "jumprelu":
self.initialize_weights_jumprelu()
self.encode = self.encode_jumprelu
else:
raise (ValueError)
if self.cfg.architecture not in ["standard", "gated", "jumprelu"]:
Copy link
Collaborator

@chanind chanind Oct 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like tests are failing because initialize_weights() have been moved into their own function, but that function is never called now. IMO this error can be raised in the new initialize_weights() method instead.

raise ValueError(f"Architecture {self.cfg.architecture} not supported")

# handle presence / absence of scaling factor.
if self.cfg.finetuning_scaling_factor:
Expand Down Expand Up @@ -196,7 +186,6 @@ def __init__(

# handle run time activation normalization if needed:
if self.cfg.normalize_activations == "constant_norm_rescale":

# we need to scale the norm of the input and store the scaling factor
def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
Expand All @@ -212,7 +201,6 @@ def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor: #
self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out

elif self.cfg.normalize_activations == "layer_norm":

# we need to scale the norm of the input and store the scaling factor
def run_time_activation_ln_in(
x: torch.Tensor, eps: float = 1e-5
Expand All @@ -236,8 +224,17 @@ def run_time_activation_ln_out(x: torch.Tensor, eps: float = 1e-5):

self.setup() # Required for `HookedRootModule`s

def initialize_weights_basic(self):
def initialize_weights(self):
if self.cfg.architecture == "standard":
self.initialize_weights_basic()
elif self.cfg.architecture == "gated":
self.initialize_weights_gated()
elif self.cfg.architecture == "jumprelu":
self.initialize_weights_jumprelu()
else:
raise (ValueError)

def initialize_weights_basic(self):
# no config changes encoder bias init for now.
self.b_enc = nn.Parameter(
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
Expand Down Expand Up @@ -491,9 +488,39 @@ def forward(

return self.hook_sae_output(sae_out)

def encode(
self, x: torch.Tensor, latents: torch.Tensor | None = None
) -> torch.Tensor:
"""
Calculate SAE latents from inputs. Includes optional `latents` argument to only calculate a subset. Note that
this won't make sense for topk SAEs, because we need to compute all hidden values to apply the topk masking.
"""
if self.cfg.activation_fn_str == "topk":
assert (
latents is None
), "Computing a slice of SAE hidden values doesn't make sense in topk SAEs."

return {
"standard": self.encode_standard,
"gated": self.encode_gated,
"jumprelu": self.encode_jumprelu,
}[self.cfg.architecture](x, latents)

def encode_gated(
self, x: Float[torch.Tensor, "... d_in"]
self,
x: Float[torch.Tensor, "... d_in"],
latents: torch.Tensor | None = None,
) -> Float[torch.Tensor, "... d_sae"]:
"""
Computes the latent values of the Sparse Autoencoder (SAE) using a gated architecture. The activation values are
computed as the product of the masking term & the post-activation function magnitude term:

1[(x - b_dec) @ W_gate + b_gate > 0] * activation_fn((x - b_dec) @ W_enc + b_enc)

The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not
provided, all latent values will be computed.
"""
latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous implementation you had, latents_slice = slice(None) if latents is None else torch.tensor(latents), seems better to me since this new version will create a new tensor of size d_sae on every SAE forward pass when not selecting specific latents. This would likely reduce performance for most users, which seems counter-productive since this PR is just meant to be a performance improvement if I understand the goal correctly. Wouldn't the old implementation have worked fine just adding torch.Tensor to the type of latents? e.g. latents: Iterable[int] | torch.Tensor | None = None


x = x.to(self.dtype)
x = self.reshape_fn_in(x)
Expand All @@ -502,12 +529,15 @@ def encode_gated(
sae_in = x - self.b_dec * self.cfg.apply_b_dec_to_input

# Gating path
gating_pre_activation = sae_in @ self.W_enc + self.b_gate
gating_pre_activation = (
sae_in @ self.W_enc[:, latents_tensor] + self.b_gate[latents_tensor]
)
active_features = (gating_pre_activation > 0).to(self.dtype)

# Magnitude path with weight sharing
magnitude_pre_activation = self.hook_sae_acts_pre(
sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
sae_in @ (self.W_enc[:, latents_tensor] * self.r_mag[latents_tensor].exp())
+ self.b_mag[latents_tensor]
)
feature_magnitudes = self.activation_fn(magnitude_pre_activation)

Expand All @@ -516,11 +546,20 @@ def encode_gated(
return feature_acts

def encode_jumprelu(
self, x: Float[torch.Tensor, "... d_in"]
self,
x: Float[torch.Tensor, "... d_in"],
latents: torch.Tensor | None = None,
) -> Float[torch.Tensor, "... d_sae"]:
"""
Calculate SAE features from inputs
Computes the latent values of the Sparse Autoencoder (SAE) using a gated architecture. The activation values are
computed as:

activation_fn((x - b_dec) @ W_enc + b_enc) * 1[(x - b_dec) @ W_enc + b_enc > threshold]

The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not
provided, all latent values will be computed.
"""
latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents

# move x to correct dtype
x = x.to(self.dtype)
Expand All @@ -535,20 +574,32 @@ def encode_jumprelu(
sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input))

# "... d_in, d_in d_sae -> ... d_sae",
hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
hidden_pre = self.hook_sae_acts_pre(
sae_in @ self.W_enc[:, latents_tensor] + self.b_enc[latents_tensor]
)

feature_acts = self.hook_sae_acts_post(
self.activation_fn(hidden_pre) * (hidden_pre > self.threshold)
self.activation_fn(hidden_pre)
* (hidden_pre > self.threshold[latents_tensor])
)

return feature_acts

def encode_standard(
self, x: Float[torch.Tensor, "... d_in"]
self,
x: Float[torch.Tensor, "... d_in"],
latents: torch.Tensor | None = None,
) -> Float[torch.Tensor, "... d_sae"]:
"""
Calculate SAE features from inputs
Computes the latent values of the Sparse Autoencoder (SAE) using a gated architecture. The activation values are
computed as:

activation_fn((x - b_dec) @ W_enc + b_enc)

The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not
provided, all latent values will be computed.
"""
latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents

x = x.to(self.dtype)
x = self.reshape_fn_in(x)
Expand All @@ -559,7 +610,9 @@ def encode_standard(
sae_in = x - (self.b_dec * self.cfg.apply_b_dec_to_input)

# "... d_in, d_in d_sae -> ... d_sae",
hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
hidden_pre = self.hook_sae_acts_pre(
sae_in @ self.W_enc[:, latents_tensor] + self.b_enc[latents_tensor]
)
feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))

return feature_acts
Expand Down Expand Up @@ -606,7 +659,6 @@ def fold_activation_norm_scaling_factor(
self.cfg.normalize_activations = "none"

def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None):

if not os.path.exists(path):
os.mkdir(path)

Expand All @@ -627,7 +679,6 @@ def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None):
def load_from_pretrained(
cls, path: str, device: str = "cpu", dtype: str | None = None
) -> "SAE":

# get the config
config_path = os.path.join(path, SAE_CFG_PATH)
with open(config_path, "r") as f:
Expand Down Expand Up @@ -752,7 +803,6 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAE":
return cls(SAEConfig.from_dict(config_dict))

def turn_on_forward_pass_hook_z_reshaping(self):

assert self.cfg.hook_name.endswith(
"_z"
), "This method should only be called for hook_z SAEs."
Expand Down
19 changes: 6 additions & 13 deletions sae_lens/training/training_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ class TrainStepOutput:

@dataclass(kw_only=True)
class TrainingSAEConfig(SAEConfig):

# Sparsity Loss Calculations
l1_coefficient: float
lp_norm: float
Expand All @@ -55,7 +54,6 @@ class TrainingSAEConfig(SAEConfig):
def from_sae_runner_config(
cls, cfg: LanguageModelSAERunnerConfig
) -> "TrainingSAEConfig":

return cls(
# base config
architecture=cfg.architecture,
Expand Down Expand Up @@ -168,7 +166,6 @@ class TrainingSAE(SAE):
device: torch.device

def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):

base_sae_cfg = SAEConfig.from_dict(cfg.get_base_sae_cfg_dict())
super().__init__(base_sae_cfg)
self.cfg = cfg # type: ignore
Expand Down Expand Up @@ -203,18 +200,21 @@ def check_cfg_compatibility(self):
assert self.use_error_term is False, "Gated SAEs do not support error terms"

def encode_standard(
self, x: Float[torch.Tensor, "... d_in"]
self, x: Float[torch.Tensor, "... d_in"], latents: torch.Tensor | None = None
) -> Float[torch.Tensor, "... d_sae"]:
"""
Calcuate SAE features from inputs
Calcuate SAE features from inputs. The `latents` argument is ignored (this is just so the type signature matches
the parent class, which uses this argument to compute only a subset of the SAE hidden values)
"""
assert (
latents is None
), "Function `encode_standard` in training should always return activations for all latents"
feature_acts, _ = self.encode_with_hidden_pre_fn(x)
return feature_acts

def encode_with_hidden_pre(
self, x: Float[torch.Tensor, "... d_in"]
) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:

x = x.to(self.dtype)
x = self.reshape_fn_in(x) # type: ignore
x = self.hook_sae_input(x)
Expand All @@ -235,7 +235,6 @@ def encode_with_hidden_pre(
def encode_with_hidden_pre_gated(
self, x: Float[torch.Tensor, "... d_in"]
) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:

x = x.to(self.dtype)
x = self.reshape_fn_in(x) # type: ignore
x = self.hook_sae_input(x)
Expand Down Expand Up @@ -267,7 +266,6 @@ def forward(
self,
x: Float[torch.Tensor, "... d_in"],
) -> Float[torch.Tensor, "... d_in"]:

feature_acts, _ = self.encode_with_hidden_pre_fn(x)
sae_out = self.decode(feature_acts)

Expand All @@ -279,7 +277,6 @@ def training_forward_pass(
current_l1_coefficient: float,
dead_neuron_mask: Optional[torch.Tensor] = None,
) -> TrainStepOutput:

# do a forward pass to get SAE out, but we also need the
# hidden pre.
feature_acts, _ = self.encode_with_hidden_pre_fn(sae_in)
Expand All @@ -291,7 +288,6 @@ def training_forward_pass(

# GHOST GRADS
if self.cfg.use_ghost_grads and self.training and dead_neuron_mask is not None:

# first half of second forward pass
_, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
ghost_grad_loss = self.calculate_ghost_grad_loss(
Expand Down Expand Up @@ -362,7 +358,6 @@ def calculate_ghost_grad_loss(
hidden_pre: torch.Tensor,
dead_neuron_mask: torch.Tensor,
) -> torch.Tensor:

# 1.
residual = x - sae_out
l2_norm_residual = torch.norm(residual, dim=-1)
Expand Down Expand Up @@ -394,7 +389,6 @@ def calculate_ghost_grad_loss(

@torch.no_grad()
def _get_mse_loss_fn(self) -> Any:

def standard_mse_loss_fn(
preds: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
Expand All @@ -421,7 +415,6 @@ def load_from_pretrained(
device: str = "cpu",
dtype: str | None = None,
) -> "TrainingSAE":

# get the config
config_path = os.path.join(path, SAE_CFG_PATH)
with open(config_path, "r") as f:
Expand Down
16 changes: 15 additions & 1 deletion tests/unit/training/test_sae_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,21 @@ def test_sae_init(cfg: LanguageModelSAERunnerConfig):
assert sae.b_dec.shape == (cfg.d_in,)


@pytest.mark.parametrize("architecture", ["standard", "gated", "jumprelu"])
def test_sae_encode_with_different_architectures(architecture: str) -> None:
cfg = build_sae_cfg(architecture=architecture)
sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
assert isinstance(cfg.d_sae, int)

activations = torch.randn(10, 4, cfg.d_in, device=cfg.device)
latents = torch.randint(low=0, high=cfg.d_sae, size=(10,))
feature_activations = sae.encode(activations)
feature_activations_slice = sae.encode(activations, latents=latents)
torch.testing.assert_close(
feature_activations[..., latents], feature_activations_slice
)


def test_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig):
sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here.
Expand Down Expand Up @@ -106,7 +121,6 @@ def test_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig):


def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig):

norm_scaling_factor = 3.0

sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
Expand Down
Loading