tweaking probing formula following experimentation
chanind committed Jul 22, 2024
1 parent ac0b16b commit 320db8c
Showing 2 changed files with 37 additions and 24 deletions.
47 changes: 23 additions & 24 deletions sae_spelling/probing.py
@@ -1,3 +1,4 @@
+from math import exp, log
 from typing import Callable
 
 import torch
@@ -7,8 +8,6 @@
 
 from sae_spelling.util import DEFAULT_DEVICE
 
-EPS = 1e-8
-
 
 class LinearProbe(nn.Module):
     """
@@ -40,10 +39,10 @@ def train_multi_probe(
     batch_size: int = 256,
     num_epochs: int = 100,
     lr: float = 0.01,
-    weight_decay: float = 1e-7,
+    end_lr: float = 1e-5,
+    weight_decay: float = 1e-6,
     show_progress: bool = True,
     verbose: bool = False,
-    early_stopping: bool = True,
     device: torch.device = DEFAULT_DEVICE,
 ) -> LinearProbe:
     """
@@ -75,10 +74,10 @@ def train_multi_probe(
         loss_fn=nn.BCEWithLogitsLoss(pos_weight=_calc_pos_weights(y_train)),
         num_epochs=num_epochs,
         lr=lr,
+        end_lr=end_lr,
         weight_decay=weight_decay,
         show_progress=show_progress,
         verbose=verbose,
-        early_stopping=early_stopping,
     )
 
     return probe
@@ -90,10 +89,10 @@ def train_binary_probe(
     batch_size: int = 256,
     num_epochs: int = 100,
     lr: float = 0.01,
-    weight_decay: float = 1e-7,
+    end_lr: float = 1e-5,
+    weight_decay: float = 1e-6,
     show_progress: bool = True,
     verbose: bool = False,
-    early_stopping: bool = True,
     device: torch.device = DEFAULT_DEVICE,
 ) -> LinearProbe:
     """
@@ -116,33 +115,36 @@ def train_binary_probe(
         batch_size=batch_size,
         num_epochs=num_epochs,
         lr=lr,
+        end_lr=end_lr,
         weight_decay=weight_decay,
         show_progress=show_progress,
         verbose=verbose,
         device=device,
-        early_stopping=early_stopping,
     )
 
 
+def _get_exponential_decay_scheduler(
+    optimizer: optim.Optimizer, start_lr: float, end_lr: float, num_steps: int
+) -> optim.lr_scheduler.ExponentialLR:
+    gamma = exp(log(end_lr / start_lr) / num_steps)
+    return optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
+
+
 def _run_probe_training(
     probe: LinearProbe,
     loader: DataLoader,
     loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
     num_epochs: int,
     lr: float,
+    end_lr: float,
     weight_decay: float,
     show_progress: bool,
     verbose: bool,
-    early_stopping: bool,
 ) -> None:
     probe.train()
     optimizer = optim.Adam(probe.parameters(), lr=lr, weight_decay=weight_decay)
-    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
-        optimizer,
-        mode="min",
-        factor=0.1,
-        patience=3,
-        eps=EPS,
+    scheduler = _get_exponential_decay_scheduler(
+        optimizer, start_lr=lr, end_lr=end_lr, num_steps=num_epochs
     )
     pbar = tqdm(total=len(loader), disable=not show_progress)
     for epoch in range(num_epochs):
@@ -153,22 +155,19 @@ def _run_probe_training(
             loss = loss_fn(logits, batch_labels)
             loss.backward()
             optimizer.step()
-            epoch_sum_loss += loss.item()
+            batch_loss = loss.item()
+            epoch_sum_loss += batch_loss
             pbar.set_description(
-                f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.8f}"
+                f"Epoch {epoch + 1}/{num_epochs}, Loss: {batch_loss:.8f}"
             )
             pbar.update()
         pbar.reset()
-        epoch_mean_loss = epoch_sum_loss / len(loader)
-        last_lr = scheduler.get_last_lr()
-        if last_lr[0] <= 2 * EPS and early_stopping:
-            if verbose:
-                print("Early stopping")
-            break
         if verbose:
+            epoch_mean_loss = epoch_sum_loss / len(loader)
+            last_lr = scheduler.get_last_lr()
             print(
                 f"epoch {epoch} sum loss: {epoch_sum_loss:.8f}, mean loss: {epoch_mean_loss:.8f} lr: {last_lr}"
             )
-        scheduler.step(epoch_mean_loss)
+        scheduler.step()
     pbar.close()
     probe.eval()
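
A note on the new gamma formula (not part of the diff): ExponentialLR multiplies the learning rate by gamma on every scheduler.step(), so after num_steps steps the rate is start_lr * gamma ** num_steps. Choosing gamma = exp(log(end_lr / start_lr) / num_steps) makes that product exactly end_lr, giving a smooth decay from lr down to end_lr over the full run. A standalone sketch to check the arithmetic:

from math import exp, log

start_lr, end_lr, num_steps = 0.01, 1e-5, 100
gamma = exp(log(end_lr / start_lr) / num_steps)

lr = start_lr
for _ in range(num_steps):
    lr *= gamma  # the multiplicative update ExponentialLR applies once per scheduler.step()
print(lr)  # ~1e-05, i.e. end_lr
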
14 changes: 14 additions & 0 deletions tests/test_probing.py
@@ -5,11 +5,25 @@
 
 from sae_spelling.probing import (
     _calc_pos_weights,
+    _get_exponential_decay_scheduler,
     train_binary_probe,
     train_multi_probe,
 )
 
 
+def test_get_exponential_decay_scheduler_decays_from_lr_to_end_lr_over_num_epochs():
+    optim = torch.optim.Adam([torch.zeros(1)], lr=0.01)
+    scheduler = _get_exponential_decay_scheduler(
+        optim, start_lr=0.01, end_lr=1e-5, num_steps=100
+    )
+    lrs = []
+    for _ in range(100):
+        lrs.append(scheduler.get_last_lr()[0])
+        scheduler.step()
+    assert lrs[0] == pytest.approx(0.01, abs=1e-6)
+    assert lrs[-1] == pytest.approx(1e-5, abs=1e-6)
+
+
 def test_calc_pos_weights_returns_1_for_equal_weights():
     y_train = torch.cat([torch.ones(5), torch.zeros(5)]).unsqueeze(1)
     pos_weights = _calc_pos_weights(y_train)
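
For context, a minimal usage sketch of the updated train_multi_probe API. The positional x_train / y_train arguments and their shapes are assumptions (those lines fall outside the visible hunks); the keyword arguments match the new signature above:

import torch
from sae_spelling.probing import train_multi_probe

# hypothetical data: 1000 activation vectors of width 512, with 3 binary labels each
x_train = torch.randn(1000, 512)
y_train = (torch.rand(1000, 3) > 0.5).float()

probe = train_multi_probe(
    x_train,
    y_train,
    num_epochs=100,
    lr=0.01,
    end_lr=1e-5,        # new: final learning rate of the exponential decay schedule
    weight_decay=1e-6,  # new default (previously 1e-7)
)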