[feat] Added four blocksparsity layouts #320

Conversation

@igormolybogFB (Contributor) commented Jun 1, 2022

Four sparsity layouts from DeepSpeed are now available for blocksparse attention in xFormers:

- Fixed
- BSLongformer
- BigBird
- Variable

sparsity_configs.py (https://fburl.com/code/s2n7x8gs) contains flexible configuration objects with many parameters. The default parameters can be invoked through the quick_ helpers in attention_patterns.py (https://fburl.com/code/hya0t9e7), and the produced layouts can be turned into a pattern with the layout_to_pattern function (https://fburl.com/code/1qmsntyj).

Detailed notes for the task: https://docs.google.com/document/d/1cBlZeccvphI-d5avLkgKwZ4ScXQM1q6igqvlyFDXdDc/edit?usp=sharing
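
For reference, a minimal usage sketch. The module path and `layout_to_pattern` come from the description above; the specific helper name `quick_bigbird_layout` and the exact signatures are assumptions for illustration, not necessarily the merged API.

```python
# Minimal usage sketch -- quick_bigbird_layout and the (num_heads, block_size, seq_len)
# signature are assumptions for illustration; only layout_to_pattern and the quick_
# naming scheme are taken from the PR description above.
from xformers.components.attention.attention_patterns import (
    layout_to_pattern,
    quick_bigbird_layout,  # assumed: one of the quick_ helpers with default parameters
)

num_heads, block_size, seq_len = 4, 16, 512

# Block-level layout: one (seq_len // block_size) x (seq_len // block_size) boolean grid per head
layout = quick_bigbird_layout(num_heads, block_size, seq_len)

# Expand the block-level layout into a full seq_len x seq_len attention pattern
pattern = layout_to_pattern(layout=layout, block_size=block_size)
print(pattern.shape)  # expected to cover the full (seq_len, seq_len) attention matrix
```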

What does this PR do?

Fixes T102859240.

Before submitting

  • [x] Did you have fun?
    • Make sure you had fun coding 🙃
  • [x] Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
    • [x] N/A
  • Did you make sure to update the docs?
    • [x] N/A
  • Did you write any new necessary tests?
    • [x] N/A
  • Did you update the changelog? (if needed)
    • [x] N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot (Contributor) commented:
Hi @igormolybogFB!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@blefaudeux (Contributor) commented:

Hey @igormolybogFB, thanks for the PR! @suchenzang, I remember that you were asking for that at some point?

@blefaudeux (Contributor) commented Jun 1, 2022

@igormolybogFB if you can, I would recommend setting up pre-commit; it should be explained in CONTRIBUTING, and it helps a lot with all the formatting/linting. Also, would you mind adding this to the changelog? Thank you!

@blefaudeux blefaudeux changed the title added four blocksparsity layouts [feat] Added four blocksparsity layouts Jun 1, 2022
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 1, 2022
@facebook-github-bot (Contributor) commented:
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

Inline review on the sparsity config docstring (quoted diff context):

    For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
    Arguments:
         num_heads: required: an integer determining number of attention heads of the layer.
         block: optional: an integer determining the block size. Current implementation of sparse self-attention is based on blocked sparse matrices. In which this parameter defines size of such blocks, `Block X Block`.

Reviewer (Contributor): nit: `block` could be renamed `block_size` to stay consistent with the rest of the repo

"""
def __init__(self, num_heads, block=16, different_layout_per_head=False):
"""Initialize the Sparsity Pattern Config.
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should be deleted probably since we don't have that in this repo?
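
For illustration, here is a minimal sketch of what the constructor could look like with both nits applied (the `block` -> `block_size` rename and the DeepSpeed TODO dropped); everything beyond the names quoted in the diff context above is an assumption, not the merged code.

```python
# Hypothetical sketch only -- reflects the review nits applied to the quoted diff
# context; attribute names and docstring wording are assumptions, not the merged code.
class SparsityConfig:
    """Abstract configuration for a blocksparse attention layout."""

    def __init__(self, num_heads: int, block_size: int = 16, different_layout_per_head: bool = False):
        """Initialize the sparsity pattern config.

        Args:
            num_heads: number of attention heads of the layer.
            block_size: side length of the square blocks; the blocksparse layout
                is defined on block_size x block_size tiles of the attention matrix.
            different_layout_per_head: whether each head gets its own layout.
        """
        self.num_heads = num_heads
        self.block_size = block_size
        self.different_layout_per_head = different_layout_per_head
```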

Follow-up commits:
- linting errors fixed for std-blocksparsity-attention-layouts
- Added copyright
- fixed block -> block_size
- removed DeepSpeed TODO
@codecov-commenter commented Jun 1, 2022

Codecov Report

Merging #320 (878ea72) into main (9e27c8f) will increase coverage by 0.22%.
The diff coverage is 96.40%.

@@            Coverage Diff             @@
##             main     #320      +/-   ##
==========================================
+ Coverage   93.04%   93.26%   +0.22%     
==========================================
  Files          66       67       +1     
  Lines        3564     3814     +250     
==========================================
+ Hits         3316     3557     +241     
- Misses        248      257       +9     
| Flag | Coverage Δ |
|------|------------|
| Python | 93.26% <96.40%> (+0.22%) ⬆️ |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
|----------------|------------|
| xformers/components/attention/sparsity_config.py | 96.17% <96.17%> (ø) |
| ...formers/components/attention/attention_patterns.py | 89.20% <100.00%> (+1.30%) ⬆️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 9e27c8f...878ea72.

@blefaudeux (Contributor) commented:

@igormolybogFB perfect for the lint fixes! Could you add this PR to the changelog (CHANGELOG.md) and add these patterns to the unit tests, to fix the code coverage? Thanks!

@kashif (Contributor) commented Jun 2, 2022

@igormolybogFB can you kindly also remove the .DS_.. files?

@blefaudeux (Contributor) left a review:

Looks great to me, thanks a lot @igormolybogFB! Super complete PR, very clean in my view; thanks for all the tests in particular.

@blefaudeux blefaudeux merged commit 6bc77d3 into facebookresearch:main Jun 3, 2022
Labels: CLA Signed (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)

6 participants