From 8a7f2299a5cff3fd17e89c1aad7b30a7b7deb3d9 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Wed, 31 Jul 2024 18:34:19 -0700
Subject: [PATCH] little cleanup

---
 README.md             | 20 +++++++++++++++++---
 examples/benchmark.py | 33 ++++++++++++++++++---------------
 2 files changed, 35 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index 343b750..bfd4d21 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@ pip install .
 
 ## Usage
 
-Here's a quick example of how to use the FlexAttention API with a custom attention mechanism:
+Here's a quick example of how to use the FlexAttention API with a `causal_mask`:
 
 ```python
 from torch.nn.attention.flex_attention import flex_attention, create_block_mask
@@ -40,11 +40,25 @@ block_mask: BlockMask = create_block_mask(causal_mask, 1, 1, Q_LEN, KV_LEN)
 # Use FlexAttention with a causal mask modification
 output = flex_attention(query, key, value, block_mask=causal_mask)
 ```
+## 📁 Structure
 
-## 📚 Examples
+Attention Gym is organized for easy exploration of attention mechanisms:
 
-Check out the `examples/` directory for more detailed examples of different attention mechanisms and how to implement them using FlexAttention.
+### 🔍 Key Locations
+- `attn_gym.masks`: Examples of creating `BlockMask`s
+- `attn_gym.mods`: Examples of creating `score_mod`s
+- `examples/`: Detailed implementations using FlexAttention
 
+### 🏃‍♂️ Running Examples
+Files are both importable and runnable. To explore:
+
+1. Run files directly:
+   ```Shell
+   python attn_gym/masks/document_mask.py
+   ```
+2. Most files generate visualizations when run.
+
+Check out the `examples/` directory for end-to-end examples of using FlexAttention in real-world scenarios.
 ## Note
 
 Attention Gym is under active development, and we do not currently offer any backward compatibility guarantees. APIs and functionalities may change between versions. We recommend pinning to a specific version in your projects and carefully reviewing changes when upgrading.
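The new `attn_gym.mods` entry in the README points at `score_mods` but the usage example above only shows a mask. As a hedged aside (not part of this patch), here is a minimal sketch of what a `score_mod` looks like, assuming the standard `flex_attention` callable signature `(score, b, h, q_idx, kv_idx)`; the bias function, tensor shapes, and CUDA device are illustrative only:

```python
# Minimal score_mod sketch (illustrative, not taken from this repository).
import torch
from torch.nn.attention.flex_attention import flex_attention

def relative_position_bias(score, b, h, q_idx, kv_idx):
    # Subtract the query/key distance from the raw attention score,
    # biasing each query toward nearby keys.
    return score - (q_idx - kv_idx).abs()

# Illustrative shapes: (batch, heads, seq_len, head_dim).
q = k = v = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.float16)
out = flex_attention(q, k, v, score_mod=relative_position_bias)
```

attn_gym's `generate_alibi_bias(16)` used in the benchmark below follows the same callable pattern, scaling a distance-based penalty per head.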
diff --git a/examples/benchmark.py b/examples/benchmark.py
index 4a84542..4d5db0a 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -39,7 +39,7 @@
 
 data_type = torch.float16
 
-# The kernels will utilize block sparisty to increase performance
+# The kernels will utilize block sparsity to increase performance
 print(f"Using the default sparsity block size: {_DEFAULT_SPARSE_BLOCK_SIZE}")
 
 
@@ -181,18 +181,7 @@ def test_mask(
     print(f"\nBlock Mask:\n{block_mask}")
 
 
-def main():
-    test_mask(mask_mod=causal_mask)
-    # Correctness check here is simple and only works with mask_fns and not actual score_mods
-    test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True)
-
-    sliding_window_mask = generate_sliding_window(window_size=1024)
-    test_mask(mask_mod=sliding_window_mask)
-
-    prefix_lm_mask = generate_prefix_lm_mask(prefix_length=1024)
-    test_mask(mask_mod=prefix_lm_mask)
-
-    # Document masking
+def run_document_masking(max_seq_len: int, num_docs: int):
     import random
 
     random.seed(0)
@@ -209,12 +198,26 @@ def generate_random_lengths(total_length, num_documents):
 
         return lengths
 
-    max_seq_len, n_docs = 32768, 12
-    lengths = generate_random_lengths(max_seq_len, n_docs)
+    lengths = generate_random_lengths(max_seq_len, num_docs)
     offsets = length_to_offsets(lengths, "cuda")
     document_causal_mask = generate_doc_mask_mod(causal_mask, offsets)
     test_mask(mask_mod=document_causal_mask, S=32768)
 
 
+def main():
+    test_mask(mask_mod=causal_mask)
+    # Correctness check here is simple and only works with mask_fns and not actual score_mods
+    test_mask(score_mod=generate_alibi_bias(16), skip_correctness=True)
+
+    sliding_window_mask = generate_sliding_window(window_size=1024)
+    test_mask(mask_mod=sliding_window_mask)
+
+    prefix_lm_mask = generate_prefix_lm_mask(prefix_length=1024)
+    test_mask(mask_mod=prefix_lm_mask)
+
+    # Document masking
+    run_document_masking(max_seq_len=32768, num_docs=12)
+
+
 if __name__ == "__main__":
     main()
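For readers following the new `run_document_masking` helper, the toy sketch below illustrates the idea it wires together: per-document lengths determine which document each token belongs to, and the resulting mask only lets a token attend causally within its own document. This is an illustration only, not attn_gym's actual `generate_doc_mask_mod` / `length_to_offsets` implementation:

```python
# Toy illustration of document-causal masking (not the attn_gym implementation).
import torch

lengths = [3, 2, 3]  # per-document token counts, cf. generate_random_lengths
# Document id for every token position: [0, 0, 0, 1, 1, 2, 2, 2]
doc_ids = torch.repeat_interleave(torch.arange(len(lengths)), torch.tensor(lengths))

def document_causal_mask(b, h, q_idx, kv_idx):
    # A position may attend to a key iff both lie in the same document
    # and the key is not in the future (causal).
    same_doc = doc_ids[q_idx] == doc_ids[kv_idx]
    causal = q_idx >= kv_idx
    return same_doc & causal

# Quick sanity check of the full 8x8 mask on CPU.
q_idx = torch.arange(8).view(-1, 1)
kv_idx = torch.arange(8).view(1, -1)
print(document_causal_mask(0, 0, q_idx, kv_idx).int())
```

In the benchmark, a mask_mod of this shape is handed to `test_mask`, which builds a `BlockMask` via `create_block_mask` before calling `flex_attention`, so the kernels can exploit the block sparsity that the document boundaries create.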