refactor attn layers #240

Merged: 18 commits merged into mosaicml:main from the attn_refactor branch on Mar 18, 2023
Conversation

@vchiley (Contributor) commented on Mar 16, 2023
This PR refactors all attn_impl options to use a single class which internally calls the different attention functions (flash, triton, torch).
This means the state dict of one attn_impl is guaranteed to match the others (this is now tested).

Note: this PR uses unpacked attention variants (i.e. q, k, v are chunked out of qkv), since kv caching will concatenate cached tokens onto k and v, so q can have a different sequence length than k and v.
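
For context, a minimal sketch of the pattern described above (illustrative only, not the repo's actual `attention.py`; only the torch path is filled in, and all names are stand-ins):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def torch_attn_fn(q, k, v, n_heads, attn_bias=None):
    # plain PyTorch softmax attention over unpacked q, k, v
    b, s_q, d = q.shape
    s_k = k.shape[1]
    head_dim = d // n_heads
    q = q.view(b, s_q, n_heads, head_dim).transpose(1, 2)
    k = k.view(b, s_k, n_heads, head_dim).transpose(1, 2)
    v = v.view(b, s_k, n_heads, head_dim).transpose(1, 2)
    scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
    if attn_bias is not None:
        scores = scores + attn_bias
    out = F.softmax(scores, dim=-1) @ v
    return out.transpose(1, 2).reshape(b, s_q, d)


class MultiheadAttention(nn.Module):
    """One attention class: the parameters (Wqkv, out_proj) are identical
    across backends, so the state dict of one attn_impl matches the others;
    only the attention function that is swapped in differs."""

    def __init__(self, d_model, n_heads, attn_impl='torch'):
        super().__init__()
        self.n_heads = n_heads
        self.Wqkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # only the torch path is spelled out in this sketch; 'flash' and
        # 'triton' would register their own kernels behind the same interface
        impls = {'torch': torch_attn_fn}
        self.attn_fn = impls[attn_impl]

    def forward(self, x, attn_bias=None):
        qkv = self.Wqkv(x)
        # unpacked variant: q, k, v are chunked out of qkv so kv caching can
        # later concatenate past tokens onto k and v while q keeps only the
        # current tokens' sequence length
        q, k, v = qkv.chunk(3, dim=-1)
        return self.out_proj(self.attn_fn(q, k, v, self.n_heads, attn_bias))
```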

Tests now also compare all variants against each other, with clip, qkln, and alibi enabled when possible. This adds roughly 55 very quick tests.
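
For illustration, a self-contained sketch of what this kind of cross-variant parity test looks like (not the repo's actual tests; it compares a plain softmax attention against `torch.nn.functional.scaled_dot_product_attention`, with and without an ALiBi-style bias):

```python
import pytest
import torch
import torch.nn.functional as F


def reference_attn(q, k, v, bias=None):
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if bias is not None:
        scores = scores + bias
    return F.softmax(scores, dim=-1) @ v


@pytest.mark.parametrize('use_alibi_bias', [False, True])
def test_attn_variants_match(use_alibi_bias):
    torch.manual_seed(0)
    b, h, s, d = 2, 4, 16, 32
    q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

    bias = None
    if use_alibi_bias:
        # simple ALiBi-like additive bias per head (illustrative slopes)
        slopes = torch.tensor([2.0 ** -(i + 1) for i in range(h)]).view(1, h, 1, 1)
        dist = torch.arange(s).view(1, 1, 1, s) - torch.arange(s).view(1, 1, s, 1)
        bias = slopes * -dist.abs().float()

    out_ref = reference_attn(q, k, v, bias)
    out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    torch.testing.assert_close(out_ref, out_sdpa, atol=1e-5, rtol=1e-4)
```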

@dskhudia let me know if this works for the inference work.

@dskhudia also noted that we'll want to move to torch.nn.functional.scaled_dot_product_attention; maybe now, maybe later.

@samhavens re: IFT. Note: the flash path will unpad the input and do the attention calculation on the unpadded tokens; the other impls do the calculation on the entire padded input, but only up to the longest sequence in the batch (not the model's max sequence length).
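
A hedged, pure-PyTorch sketch of the two padding treatments described here (the real flash path relies on helpers from the flash-attn package; everything below is illustrative):

```python
import torch


def trim_to_longest(x: torch.Tensor, attention_mask: torch.Tensor):
    """Non-flash impls: keep the padding, but only out to the longest
    sequence in the batch rather than the model's max sequence length."""
    longest = int(attention_mask.sum(dim=1).max())
    return x[:, :longest], attention_mask[:, :longest]


def unpad(x: torch.Tensor, attention_mask: torch.Tensor):
    """Flash path: drop pad tokens entirely and track cumulative sequence
    lengths so attention can run over the packed (unpadded) tokens."""
    seqlens = attention_mask.sum(dim=1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten()).flatten()
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    b, s, d = x.shape
    return x.reshape(b * s, d)[indices], indices, cu_seqlens, int(seqlens.max())


# usage: batch of 2 sequences with 5 and 6 real tokens, padded to length 8
x = torch.randn(2, 8, 16)
mask = torch.tensor([[1] * 5 + [0] * 3, [1] * 6 + [0] * 2])
x_trim, mask_trim = trim_to_longest(x, mask)        # x_trim: (2, 6, 16)
x_unpad, idx, cu_seqlens, max_len = unpad(x, mask)  # x_unpad: (11, 16), cu_seqlens: [0, 5, 11]
```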

@vchiley self-assigned this on Mar 16, 2023
@vchiley force-pushed the attn_refactor branch 2 times, most recently from c468655 to aa912dd on March 16, 2023 06:59
@vchiley requested a review from alextrott16 on March 16, 2023 16:29
@vchiley marked this pull request as draft on March 17, 2023 02:20
@vchiley marked this pull request as ready for review on March 17, 2023 06:56
@vchiley requested a review from samhavens on March 17, 2023 16:00
@bmosaicml (Contributor) left a comment:

Great work Vitaliy, this looks clean and ready for caching.

@vchiley (Contributor, Author) commented on Mar 17, 2023

Running a 125M model (with clip 6) using flash, triton, and torch in the wandb project attn_refactor. MFU is about where it should be, and the training curves look identical.

The wandb project attn_refactor also has a 7B run with triton (using 16 GPUs) to verify MFU is where it should be; it is at > 43% MFU, which checks out.
Also adding a 7B run with flash to see the difference (not training to convergence, just ~100 steps). The flash version gets 41% MFU.
Note: the 7B runs use activation_checkpointing_reentrant: false and limit_all_gathers: true.
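
For reference, the corresponding fragment of the run config looks roughly like this (only the two flags quoted above come from this thread; placing them under fsdp_config is an assumption):

```yaml
# illustrative fragment of the 7B run YAML; other fsdp_config keys omitted
fsdp_config:
  activation_checkpointing_reentrant: false
  limit_all_gathers: true
```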

@vchiley force-pushed the attn_refactor branch 2 times, most recently from 3b845b2 to e8707f6 on March 17, 2023 21:38
@dskhudia (Contributor) commented:

For inference, ideally we should have an export test for ONNX/TorchScript, but that can be added in a follow-up PR.

@abhi-mosaic (Contributor) left a comment:

This looks awesome @vchiley! I added a few comments but they're minor.

Review comments on examples/llm/src/models/layers/attention.py (resolved)
@dskhudia (Contributor) commented:

Looks great. Much cleaner now.

> also noted that we'll want to move to torch.nn.functional.scaled_dot_product_attention.

^ This can be done later once we move to PyTorch 2.0.
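
A hedged sketch of what that swap could look like inside the torch attention path once PyTorch 2.0 is the baseline (illustrative; this PR does not make the change):

```python
import torch
import torch.nn.functional as F


def torch_attn_fn_sdpa(q, k, v, n_heads, attn_bias=None):
    # q: (batch, seq_q, d_model); k, v: (batch, seq_kv, d_model)
    b, s_q, d = q.shape
    s_k = k.shape[1]
    head_dim = d // n_heads
    q = q.view(b, s_q, n_heads, head_dim).transpose(1, 2)
    k = k.view(b, s_k, n_heads, head_dim).transpose(1, 2)
    v = v.view(b, s_k, n_heads, head_dim).transpose(1, 2)
    # the fused kernel replaces the explicit softmax(Q K^T / sqrt(d_head)) V
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
    return out.transpose(1, 2).reshape(b, s_q, d)
```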

@vchiley merged commit 2b09481 into mosaicml:main on Mar 18, 2023