Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memmap shuffling #216

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
ee78d91
call dataloader next consistent with async loading
jbloomAus Jul 2, 2024
eda7db7
start using np.memmap (but with none of the actual advantages
jbloomAus Jul 2, 2024
bff803f
various changes, iterating toward np.memmap
jbloomAus Jul 6, 2024
c486f05
deduplicate small mathematical operations
Lewington-pitsos Jul 6, 2024
112960a
named parameters for shuffle idxs
Lewington-pitsos Jul 6, 2024
e01238d
add diagnostic prints, fix typing in activations_store
Lewington-pitsos Jul 6, 2024
545a33c
add diagnostic prints, fix typing in activations_store
Lewington-pitsos Jul 6, 2024
7eec1c2
rename shuffling methods
Lewington-pitsos Jul 6, 2024
a720ad0
add dataset override for cache activations runner, update test to use…
Lewington-pitsos Jul 7, 2024
18f4216
replicate error using activationstore alone
Lewington-pitsos Jul 7, 2024
c053910
replicate error using activationstore alone
Lewington-pitsos Jul 7, 2024
a76e4ed
fix float32 vs float16 memmap double size issue
Lewington-pitsos Jul 7, 2024
4b7fbdb
fix typing
Lewington-pitsos Jul 7, 2024
bf3e28d
skip test_load_cached_activations
Lewington-pitsos Jul 7, 2024
7a74cf6
get all unit test passing with new next_batch functionality
Lewington-pitsos Jul 7, 2024
9ca4aba
Merge branch 'main' into memmap_shuffling
Lewington-pitsos Jul 7, 2024
43ccbb8
merge with main
Lewington-pitsos Jul 7, 2024
d82b17c
format
Lewington-pitsos Jul 7, 2024
a955a7f
add small test of next_batch functionality
Lewington-pitsos Jul 7, 2024
ac45f4c
format again
Lewington-pitsos Jul 7, 2024
c80de83
map bfloat16 to np.float32
Lewington-pitsos Jul 7, 2024
fc147d5
reformat
Jul 7, 2024
4e8847b
add memory pinning
Jul 7, 2024
fd0d499
reformat
Lewington-pitsos Jul 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 96 additions & 66 deletions sae_lens/cache_activations_runner.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
import math
import os
from typing import Tuple

import numpy as np
import torch
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from tqdm import tqdm

from sae_lens.config import DTYPE_MAP, CacheActivationsRunnerConfig
from sae_lens.load_model import load_model
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.activations_store import FILE_EXTENSION, ActivationsStore


class CacheActivationsRunner:

def __init__(self, cfg: CacheActivationsRunnerConfig):
def __init__(
self,
cfg: CacheActivationsRunnerConfig,
override_dataset: (
DatasetDict | Dataset | IterableDatasetDict | IterableDataset | None
) = None,
):
self.cfg = cfg
self.model = load_model(
model_class_name=cfg.model_class_name,
Expand All @@ -23,9 +30,10 @@ def __init__(self, cfg: CacheActivationsRunnerConfig):
self.activations_store = ActivationsStore.from_config(
self.model,
cfg,
override_dataset=override_dataset,
)

self.file_extension = "safetensors"
self.file_extension = FILE_EXTENSION

def __str__(self):
"""
Expand All @@ -40,24 +48,31 @@ def __str__(self):
if isinstance(self.cfg.dtype, torch.dtype)
else DTYPE_MAP[self.cfg.dtype].itemsize
)
tokens_in_buffer = (
self.cfg.n_batches_in_buffer
* self.cfg.store_batch_size_prompts
* self.cfg.context_size
)
total_training_tokens = self.cfg.training_tokens
total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

return (
f"Activation Cache Runner:\n"
f"Total training tokens: {total_training_tokens}\n"
f"Number of buffers: {math.ceil(total_training_tokens / tokens_in_buffer)}\n"
f"Tokens per buffer: {tokens_in_buffer}\n"
f"Number of buffers: {self.n_buffers}\n"
f"Tokens per buffer: {self.tokens_in_buffer}\n"
f"Disk space required: {total_disk_space_gb:.2f} GB\n"
f"Configuration:\n"
f"{self.cfg}"
)

@property
def tokens_in_buffer(self):
return (
self.cfg.n_batches_in_buffer
* self.cfg.store_batch_size_prompts
* self.cfg.context_size
)

@property
def n_buffers(self):
return math.ceil(self.cfg.training_tokens / self.tokens_in_buffer)

@torch.no_grad()
def run(self):

Expand All @@ -73,94 +88,109 @@ def run(self):
else:
os.makedirs(new_cached_activations_path)

print(f"Started caching {self.cfg.training_tokens} activations")
tokens_per_buffer = (
self.cfg.store_batch_size_prompts
* self.cfg.context_size
* self.cfg.n_batches_in_buffer
)

n_buffers = math.ceil(self.cfg.training_tokens / tokens_per_buffer)

for i in tqdm(range(n_buffers), desc="Caching activations"):
for i in tqdm(range(self.n_buffers), desc="Caching activations"):
try:
buffer = self.activations_store.get_buffer(self.cfg.n_batches_in_buffer)

self.activations_store.save_buffer(
buffer, f"{new_cached_activations_path}/{i}.safetensors"
)
buffer = self.activations_store.get_buffer()
buffer_path = f"{new_cached_activations_path}/{i}.{self.file_extension}"
self.activations_store.save_buffer(buffer, buffer_path)

del buffer

if i % self.cfg.shuffle_every_n_buffers == 0 and i > 0:
if i > 0 and i % self.cfg.shuffle_every_n_buffers == 0:
# Shuffle the buffers on disk

# Do random pairwise shuffling between the last shuffle_every_n_buffers buffers
for _ in range(self.cfg.n_shuffles_with_last_section):
self.shuffle_activations_pairwise(
self.shuffle_two_random_buffers(
new_cached_activations_path,
buffer_idx_range=(i - self.cfg.shuffle_every_n_buffers, i),
start_idx=i - self.cfg.shuffle_every_n_buffers,
end_idx=i,
)

# Do more random pairwise shuffling between all the buffers
for _ in range(self.cfg.n_shuffles_in_entire_dir):
self.shuffle_activations_pairwise(
new_cached_activations_path,
buffer_idx_range=(0, i),
self.shuffle_two_random_buffers(
new_cached_activations_path, start_idx=0, end_idx=i
)
except StopIteration:
print(
f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {n_buffers} batches. No more caching will occur."
f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.n_buffers} batches. No more caching will occur."
)
break

# More final shuffling (mostly in case we didn't end on an i divisible by shuffle_every_n_buffers)
if n_buffers > 1:
if self.n_buffers > 1:
for _ in tqdm(range(self.cfg.n_shuffles_final), desc="Final shuffling"):
self.shuffle_activations_pairwise(
self.shuffle_two_random_buffers(
new_cached_activations_path,
buffer_idx_range=(0, n_buffers),
start_idx=0,
end_idx=self.n_buffers,
)

@torch.no_grad()
def shuffle_activations_pairwise(
self, datapath: str, buffer_idx_range: Tuple[int, int]
):
def shuffle_two_random_buffers(self, datapath: str, start_idx: int, end_idx: int):
"""
Shuffles two buffers on disk.
Shuffles two randomly selected buffers on disk.
"""
assert (
buffer_idx_range[0] < buffer_idx_range[1] - 1
start_idx < end_idx - 1
), "buffer_idx_range[0] must be smaller than buffer_idx_range[1] by at least 1"

buffer_idx1 = torch.randint(
buffer_idx_range[0], buffer_idx_range[1], (1,)
).item()
buffer_idx2 = torch.randint(
buffer_idx_range[0], buffer_idx_range[1], (1,)
).item()
buffer_idx1 = int(torch.randint(start_idx, end_idx, (1,)).item())
buffer_idx2 = int(torch.randint(start_idx, end_idx, (1,)).item())
while buffer_idx1 == buffer_idx2: # Make sure they're not the same
buffer_idx2 = torch.randint(
buffer_idx_range[0], buffer_idx_range[1], (1,)
).item()
buffer_idx2 = int(torch.randint(start_idx, end_idx, (1,)).item())

buffer1 = self.activations_store.load_buffer(
f"{datapath}/{buffer_idx1}.{self.file_extension}"
)
buffer2 = self.activations_store.load_buffer(
f"{datapath}/{buffer_idx2}.{self.file_extension}"
self.shuffle_two_buffers(datapath, buffer_idx1, buffer_idx2)

@torch.no_grad()
def shuffle_two_buffers(self, datapath: str, buffer_idx1: int, buffer_idx2: int):
path1 = f"{datapath}/{buffer_idx1}.{self.file_extension}"
path2 = f"{datapath}/{buffer_idx2}.{self.file_extension}"

buffer1 = self.activations_store.load_buffer(path1)
buffer2 = self.activations_store.load_buffer(path2)

# Get total size and create a joint buffer
total_size = buffer1.shape[0] + buffer2.shape[0]
joint_buffer = np.memmap(
f"{datapath}/temp_joint_buffer",
dtype=buffer1.dtype,
mode="w+",
shape=(total_size,) + buffer1.shape[1:],
)
joint_buffer = torch.cat([buffer1, buffer2])

# Shuffle them
joint_buffer = joint_buffer[torch.randperm(joint_buffer.shape[0])]
shuffled_buffer1 = joint_buffer[: buffer1.shape[0]]
shuffled_buffer2 = joint_buffer[buffer1.shape[0] :]
# Copy data to joint buffer
joint_buffer[: buffer1.shape[0]] = buffer1
joint_buffer[buffer1.shape[0] :] = buffer2

# Save them back
self.activations_store.save_buffer(
shuffled_buffer1, f"{datapath}/{buffer_idx1}.{self.file_extension}"
# Generate random permutation
permutation = np.random.permutation(total_size)

# Create shuffled buffers
shuffled_buffer1 = np.memmap(
f"{datapath}/temp_shuffled_1",
dtype=buffer1.dtype,
mode="w+",
shape=buffer1.shape,
)
self.activations_store.save_buffer(
shuffled_buffer2, f"{datapath}/{buffer_idx2}.{self.file_extension}"
shuffled_buffer2 = np.memmap(
f"{datapath}/temp_shuffled_2",
dtype=buffer2.dtype,
mode="w+",
shape=buffer2.shape,
)

# Apply permutation
shuffled_buffer1[:] = joint_buffer[permutation[: buffer1.shape[0]]]
shuffled_buffer2[:] = joint_buffer[permutation[buffer1.shape[0] :]]

# Save shuffled buffers back to original files
self.activations_store.save_buffer(shuffled_buffer1, path1)
self.activations_store.save_buffer(shuffled_buffer2, path2)

# Clean up temporary files
import os

os.remove(f"{datapath}/temp_joint_buffer")
os.remove(f"{datapath}/temp_shuffled_1")
os.remove(f"{datapath}/temp_shuffled_2")
4 changes: 2 additions & 2 deletions sae_lens/sae_training_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,15 @@ def _init_sae_group_b_decs(
"""

if self.cfg.b_dec_init_method == "geometric_median":
layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
layer_acts = torch.tensor(self.activations_store.storage_buffer)[:, 0, :]
# get geometric median of the activations if we're using those.
median = compute_geometric_median(
layer_acts,
maxiter=100,
).median
self.sae.initialize_b_dec_with_precalculated(median) # type: ignore
elif self.cfg.b_dec_init_method == "mean":
layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
layer_acts = torch.tensor(self.activations_store.storage_buffer)[:, 0, :]
self.sae.initialize_b_dec_with_mean(layer_acts) # type: ignore

def save_checkpoint(
Expand Down
Loading
Loading