
Memmap shuffling #216

Open
wants to merge 24 commits into main
Conversation

Lewington-pitsos
Contributor

@Lewington-pitsos commented Jul 7, 2024

Description

A fork of the improved_io branch https://github.com/jbloomAus/SAELens/tree/improved_io

The overall intention is to load cached activations in the ActivationStore via memmap and add code to shuffle those buffers on disk using the CacheActivationsRunner. At this stage the shuffling is probably not sufficient; it hasn't been altered since improved_io.

This feature probably isn't complete until we have a script which generates and shuffles buffers while training runs concurrently, with the training code making use of those buffers.

  • fix test_cache_activations_runner_saving (a memmap was being initialized as float32 and then read as float16, causing the loaded memmap to seem twice as large as expected)
  • fix all existing typing errors across ActivationsStore and CacheActivationsRunner
  • make ActivationsStore.get_batch use the new memmap strategy instead of the old dataloader
  • get all prior unit tests passing after this change
  • bring in line with main
  • add a new test to cover the new get_batch functionality
  • do a speed test

NOTE: with the new next_batch functionality we EITHER assume some other process is creating buffers in the cache for us OR generate activations on the fly without caching ever. At some point we will want the second option to build a cache as it goes, but IMO supporting this will be very cumbersome and should be left for future work.
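The memmap-based loading described above could be sketched roughly as follows. The function name, file layout, and dtype default are illustrative assumptions, not the actual ActivationsStore API; note how passing a dtype that doesn't match the written buffer reproduces the float32/float16 size bug fixed in the checklist above.

```python
import numpy as np
import torch

# Hypothetical sketch of memmap-based batch loading.
def load_buffer_batch(path, batch_size, d_in, dtype=np.float16):
    # The dtype passed here must match the dtype the buffer was written
    # with; otherwise the memmap appears to contain the wrong number of
    # elements (e.g. float32 vs float16 makes it look 2x too large).
    buffer = np.memmap(path, dtype=dtype, mode="r").reshape(-1, d_in)
    batch = np.array(buffer[:batch_size])  # copy rows out of the memmap
    return torch.from_numpy(batch)
```

The copy via `np.array` detaches the batch from the on-disk file before handing it to torch, so the memmap can be closed or rewritten safely.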

Type of change

Please delete options that are not relevant.

  • New feature (non-breaking change which adds functionality)

You have tested formatting, typing and unit tests (acceptance tests not currently in use)

  • I have run make check-ci to check format and linting. (you can run make format to format code if needed.)

Performance Check.

If you have implemented a training change, please indicate precisely how performance changes with respect to the following metrics:

  • L0
  • CE Loss
  • MSE Loss
  • Feature Dashboard Interpretability

Please link to wandb dashboards with a control and a test group.

@jbloomAus
Owner

Some notes:

  • We've lost the 2x activation store shuffling, which may degrade results for SAEs trained with activations generated in real time. Maybe not worth retaining, as it's messy code.
  • We need to pin the memory and move it to the SAE device asynchronously (eg: as here). If we are loading straight from disk then the notion of the activation store having a device makes less sense and we should let the caller move it.
  • We should eliminate shuffling entirely in the cache activations runner and replace it with a specific shuffling script (per Anthropic's blog post).
  • Since we're clearly going for performance improvements, it might be nice to do an A/B test against the old code for a short run where we pre-cache activations and then train an SAE.
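The pinned-memory point above could look roughly like this; a minimal sketch assuming a plain torch.Tensor batch and a hypothetical to_device_async helper, not the PR's actual code.

```python
import torch

# Sketch of a pinned, asynchronous host-to-device transfer.
def to_device_async(tensor_batch: torch.Tensor, device: torch.device) -> torch.Tensor:
    if device.type == "cuda":
        # pin_memory() copies into page-locked host memory so the
        # subsequent non_blocking transfer can overlap with GPU compute
        # (e.g. an SAE training step running concurrently).
        return tensor_batch.pin_memory().to(device, non_blocking=True)
    # On CPU there is nothing to overlap; a plain move suffices.
    return tensor_batch.to(device)
```

If the caller owns the device (as suggested above), this helper would live outside the activation store and be invoked on whatever the store returns.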

@Lewington-pitsos
Contributor Author

Lewington-pitsos commented Jul 8, 2024

Profiling has been completed on this branch: https://github.com/Lewington-pitsos/SAELens/tree/memmap_profiling in the memmap_profiling directory

It contains a generate script which creates activations and a profile script which loads 1000 batches of these activations. torch.profiler.profile was used in each case on a cloud-based Nvidia A40 GPU.

The overall outcome is that the old script consumes around 26 CPU seconds and 3 GPU seconds,
the new version with pinning consumes 3 CPU seconds and 11 GPU milliseconds,
and the new version without pinning consumes around 300 CPU milliseconds and 19 GPU milliseconds.

This means that with no pinning we can achieve the following speeds:

Time taken to get 1000 batches: 0.39 seconds
Batches Per Second: 2538.15 batches per second
Samples Per Second: 649,766.16 samples per second
Tokens Per Second: 665,360,545.22 tokens per second

More thorough profiling would be required to give us a clearer picture on the tradeoff between pinning and not pinning.

  • memmap-pinning.txt - run with memmap and pinning as implemented on 4e8847b.
  • memmap-no-pin.txt - run with the condition above the pinning line always set to false, so instead of the pinning code we simply ran tensor_batch = tensor_batch.to(self.device).
  • safetensors-9adba61b03b.txt - the same generate and profile scripts run from the main branch as of 9adba61b03b.

Note that the profile script appears to hang when testing the main branch on 9adba61b03b for some reason pertaining to the torch profiling, but the profiling will complete after 20 minutes or so.
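For reference, the profiling setup described above can be approximated with torch.profiler; get_batch and the iteration count here are placeholders, not the actual profile script.

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Rough sketch of profiling repeated batch loads on CPU; add
# ProfilerActivity.CUDA to also capture GPU time on a CUDA machine.
def profile_batches(get_batch, n_batches=1000):
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        for _ in range(n_batches):
            get_batch()
    # Summed self/total CPU time per op, comparable to the numbers above.
    return prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
```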

@Lewington-pitsos
Contributor Author

I'm still not 100% sure what would be required to make this PR mergeable, by the way.

@jbloomAus
Owner

Huh, pinning slowed it down? I feel like something's wrong here. Will read up and play with the code a bunch.

@Lewington-pitsos
Contributor Author

Well, it led to 10x the CPU time but 1/2 the GPU time so possibly? I wasn't sure if a thorough comparison of pinning vs non-pinning was in scope for this PR

@jbloomAus
Owner

I think pinning maybe only helps when it's facilitating asynchronous work (i.e. we're also doing SAE training), which you weren't doing here? So maybe expected? I'm inclined to say sorting this out is out of scope for this PR, and I'll probably merge once I've checked that training an SAE with this kind of shuffling, as opposed to our previous one, isn't worse (for non-cached activation training).

@Lewington-pitsos
Contributor Author

I actually had no idea what pinning even was before this PR so I feel underqualified to make a final judgement on it in the context of a system designed to do bleeding edge machine learning training XD

@jbloomAus
Owner

jbloomAus commented Jul 8, 2024

I actually had no idea what pinning even was before this PR so I feel underqualified to make a final judgement on it in the context of a system designed to do bleeding edge machine learning training XD

I'm fairly happy with this PR but want to spend a bit of time on it. Sorry for the delay here. Writing the shuffling utility will keep us moving if you've got time.

@Lewington-pitsos
Contributor Author

no stress fam, I'm working on a bit of side-research at the moment to do with a possible metric for measuring SAE quality, I'll ping this board if I have time to work on shuffling.

What I would basically do is make a shuffler which achieves approximately the same level of randomness as we had prior (but on disk).
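A minimal sketch of that kind of on-disk shuffler, in the spirit of Anthropic's approach: repeatedly pick pairs of buffer files, merge them, shuffle the rows, and write the halves back. The function name, file layout, and dtype are assumptions for illustration.

```python
import numpy as np

# Hypothetical pairwise on-disk shuffle step; repeating this over many
# random pairs of buffer files approximates a global shuffle.
def shuffle_pair(path_a, path_b, d_in, dtype=np.float16, seed=None):
    rng = np.random.default_rng(seed)
    a = np.fromfile(path_a, dtype=dtype).reshape(-1, d_in)
    b = np.fromfile(path_b, dtype=dtype).reshape(-1, d_in)
    merged = np.concatenate([a, b])
    rng.shuffle(merged)  # shuffles rows (axis 0) in place
    merged[: len(a)].tofile(path_a)
    merged[len(a):].tofile(path_b)
```

Each step only needs two buffers in memory at once, so the working set stays small no matter how large the cache is.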
