-
Notifications
You must be signed in to change notification settings - Fork 633
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat] Added four blocksparsity layouts #320
[feat] Added four blocksparsity layouts #320
Conversation
Four sparsity layouts from DeepSpeed are now available for blocksparse attention on xFormer: Fixed BSLongformer BigBird Variable sparsity_configs.py (https://fburl.com/code/s2n7x8gs) contains flexible objects with many parameters. The default parameters can be invoked through the quick_ functionality in attention_patterns.py (https://fburl.com/code/hya0t9e7) the produced layouts can be turned into a pattern with layout_to_pattern function (https://fburl.com/code/1qmsntyj) Detailed notes for the task: https://docs.google.com/document/d/1cBlZeccvphI-d5avLkgKwZ4ScXQM1q6igqvlyFDXdDc/edit?usp=sharing
Hi @igormolybogFB! Thank you for your pull request and welcome to our community. Action RequiredIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks! |
hey @igormolybogFB, thanks for the PR ! @suchenzang I remember that you were asking for that at some point ? |
@igormolybogFB if you can, I would recommend setting up pre-commit, it should be explained in CONTRIBUTING, it helps a lot with all the formatting/linting. Also, would you mind adding this to the changelog ? Thank you ! |
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks! |
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial | ||
Arguments: | ||
num_heads: required: an integer determining number of attention heads of the layer. | ||
block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: block could be renamed block_size to stay consistent with the rest of the repo
""" | ||
def __init__(self, num_heads, block=16, different_layout_per_head=False): | ||
"""Initialize the Sparsity Pattern Config. | ||
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: should be deleted probably since we don't have that in this repo?
linting errors fixed for std-blocksparsity-attention-layouts
Added copyright fixed block -> block_size removed DeepSpeed TODO
Codecov Report
@@ Coverage Diff @@
## main #320 +/- ##
==========================================
+ Coverage 93.04% 93.26% +0.22%
==========================================
Files 66 67 +1
Lines 3564 3814 +250
==========================================
+ Hits 3316 3557 +241
- Misses 248 257 +9
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
@igormolybogFB perfect for the lint fixes ! Could you add this PR to the changelog (CHANGELOG.md) and add these patterns to the unit tests, to fix the code coverage ? Thanks ! |
@igormolybogFB can you kindly also remove the .DS_.. files? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great to me, thanks a lot @igormolybogFB ! Super complete PR, very clean in my view, thanks for all the tests in particular
Four sparsity layouts from DeepSpeed are now available for blocksparse attention on xFormer:
Fixed
BSLongformer
BigBird
Variable
sparsity_configs.py (https://fburl.com/code/s2n7x8gs) contains flexible objects with many parameters. The default parameters can be invoked through the quick_ functionality in attention_patterns.py (https://fburl.com/code/hya0t9e7) the produced layouts can be turned into a pattern with layout_to_pattern function (https://fburl.com/code/1qmsntyj)
Detailed notes for the task: https://docs.google.com/document/d/1cBlZeccvphI-d5avLkgKwZ4ScXQM1q6igqvlyFDXdDc/edit?usp=sharing
What does this PR do?
Fixes T102859240.
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.