refactor attn layers #240
Conversation
Force-pushed from c468655 to aa912dd
Great work Vitaliy, this looks clean and ready for caching.
Running a 125M (with clip 6) using flash, triton, and torch in wandb proj attn_refactor; MFU is about where it should be and the training curves look identical. The attn_refactor project also has a 7B run with triton (on 16 GPUs) to verify MFU is where it should be: it's at > 43% MFU, which checks out.
Force-pushed from 3b845b2 to e8707f6
For inference, ideally we should have export tests for ONNX/TorchScript, but those can be added in a follow-up PR.
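A rough sketch of what such an export test could look like (the module here is a placeholder, not code from this repo; names are illustrative):

```python
import io
import torch

model = torch.nn.Linear(8, 8).eval()  # placeholder; would be the refactored attn module / model
example = torch.randn(1, 8)

# TorchScript: trace and check the traced module matches eager execution
traced = torch.jit.trace(model, example)
torch.testing.assert_close(traced(example), model(example))

# ONNX: export to an in-memory buffer just to verify the export succeeds
buf = io.BytesIO()
torch.onnx.export(model, example, buf)
```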
This looks awesome @vchiley! I added a few comments, but they're minor.
Looks great. Much cleaner now.
^ can be done later once we move to 2.0.
This PR refactors all attn_impl options to use one class which calls the different attn fns (flash, triton, torch) internally.
This means the state dict of one attn_impl is guaranteed to be the same as the others (this is now tested).
Note: this PR uses the unpacked attn variants (i.e. q, k, v are chunked out of qkv), since kv caching will concatenate cached tokens onto k and v, so q will have a different seq len from k and v. The overall shape of the refactor is sketched below.
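A minimal sketch of the pattern (not the actual code in this PR; the fn names, signatures, and kv-cache handling here are illustrative, and masking, alibi, clip_qkv, and qk_ln are omitted):

```python
import torch
import torch.nn as nn

def torch_attn_fn(q, k, v, n_heads):
    # plain PyTorch attention; the flash/triton fns would share this signature
    b, s_q, d = q.shape
    s_k = k.shape[1]
    q = q.view(b, s_q, n_heads, d // n_heads).transpose(1, 2)
    k = k.view(b, s_k, n_heads, d // n_heads).transpose(1, 2)
    v = v.view(b, s_k, n_heads, d // n_heads).transpose(1, 2)
    scores = q @ k.transpose(-2, -1) / (d // n_heads) ** 0.5
    ctx = scores.softmax(dim=-1) @ v
    return ctx.transpose(1, 2).reshape(b, s_q, d)

# the real fns differ internally; flash/triton are just stand-ins in this sketch
ATTN_FNS = {'torch': torch_attn_fn, 'flash': torch_attn_fn, 'triton': torch_attn_fn}

class MultiheadAttention(nn.Module):
    """One attn cls; attn_impl only selects the inner attn fn, so the params
    (Wqkv, out_proj), and therefore the state dict, are identical across impls."""

    def __init__(self, d_model, n_heads, attn_impl='triton'):
        super().__init__()
        self.n_heads = n_heads
        self.Wqkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.attn_fn = ATTN_FNS[attn_impl]

    def forward(self, x, past_key_value=None):
        qkv = self.Wqkv(x)
        # unpacked variant: chunk q, k, v out of qkv so their seq lens can diverge
        q, k, v = qkv.chunk(3, dim=-1)
        if past_key_value is not None:
            # kv caching concatenates cached tokens onto k/v; q stays at the new-token len
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)
        ctx = self.attn_fn(q, k, v, self.n_heads)
        return self.out_proj(ctx), (k, v)
```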
Tests also now compare all variants to each other with clip, qk_ln, and alibi (when possible). This adds roughly 55 very quick tests.
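The comparison tests have roughly this shape (again illustrative; this reuses the `MultiheadAttention` sketch above, and the real tests also sweep clip, qk_ln, and alibi where the impl supports them):

```python
import pytest
import torch

@pytest.mark.parametrize('impl_0,impl_1', [
    ('flash', 'triton'), ('flash', 'torch'), ('triton', 'torch')])
def test_attn_impls_match(impl_0, impl_1):
    torch.manual_seed(17)
    attn_0 = MultiheadAttention(d_model=64, n_heads=4, attn_impl=impl_0)
    attn_1 = MultiheadAttention(d_model=64, n_heads=4, attn_impl=impl_1)
    # identical state dicts across impls: a checkpoint from one loads into any other
    attn_1.load_state_dict(attn_0.state_dict())

    x = torch.randn(2, 16, 64)
    y_0, _ = attn_0(x)
    y_1, _ = attn_1(x)
    # fused kernels run in lower precision, so compare with a loose tolerance
    torch.testing.assert_close(y_0, y_1, atol=1e-2, rtol=0)
```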
@dskhudia lmk if this works for the inference work
@dskhudia also noted that we'll want to move to torch.nn.functional.scaled_dot_product_attention, either in this PR or a later one.
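For reference, a minimal use of that API (available in torch >= 2.0; q/k/v are (batch, n_heads, seq_len, head_dim)):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)

# dispatches to a fused (flash / mem-efficient) kernel when one is available,
# otherwise falls back to the math implementation
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 16, 32])
```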
@samhavens re: IFT. Note: the flash path will unpad the input and do the attn calculation unpadded; the other impls do the calculation on the entire padded input, but the batch is only padded up to the longest seq in the batch (not the max seq len of the model).
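A toy illustration of that padding behavior (the numbers are made up; the flash path's unpadding is only described in the comments):

```python
import torch

max_seq_len = 2048                     # model max; no impl pays for this on short batches
lengths = torch.tensor([7, 11, 5, 9])  # real (non-pad) tokens per example
longest = int(lengths.max())           # 11

# collate/pad only up to the longest seq in the batch, not max_seq_len
attention_mask = torch.arange(longest)[None, :] < lengths[:, None]  # (4, 11) bool

padded_tokens = attention_mask.numel()  # 44: what the torch/triton paths attend over (with masking)
real_tokens = int(lengths.sum())        # 32: what the flash path attends over after unpadding
print(longest, padded_tokens, real_tokens)
```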