fctls_bflsh: New OP that combines cutlass's fw + flash's bw #469

Merged: 2 commits into gh/danthe3rd/49/base from gh/danthe3rd/49/head, Oct 7, 2022

Conversation

@danthe3rd (Contributor) commented on Oct 7, 2022

Stack from ghstack (oldest at bottom):

PERFORMANCE

A100 bw
[--------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------]
                                     |  flash[flshatt]  |  vanilla  |  fwbw[fctls_bflsh]  |  48_chunk3_31735f9[cutlass]
1 threads: ------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |        232.7     |   1813.9  |          240.3      |              391.7         
      f16 B=1024, M=197, H=1, K=64   |        577.0     |   4746.9  |          582.8      |              876.9         
      f16 B=32, M=197, H=16, K=64    |        296.2     |   2434.6  |          303.2      |              459.9         
      f16 B=32, M=197, H=16, K=128   |        682.8     |   4504.9  |          688.6      |              792.5         
      f16 B=16, M=197, H=16, K=64    |        164.9     |   1246.6  |          172.4      |              235.4         
      f16 B=16, M=197, H=16, K=128   |        385.9     |   2272.5  |          394.1      |              455.4         
      f16 B=1, M=4096, H=160, K=128  |      54810.6     |  45967.1  |        54876.8      |            62454.4         
      f16 B=2, M=4096, H=160, K=128  |      84422.1     |           |        84371.5      |            98791.3         
      f16 B=1, M=8192, H=160, K=128  |     216095.0     |           |       216170.6      |           248498.9         
      f16 B=2, M=8192, H=160, K=128  |     330754.4     |           |       331201.2      |           389207.8         
      f16 B=1024, M=82, H=8, K=64    |       1621.7     |   3820.0  |         1625.6      |             1872.4         
      f16 B=150, M=256, H=16, K=64   |       1625.7     |   4551.9  |         1629.7      |             2126.4         
      f16 B=64, M=256, H=12, K=64    |        567.7     |   1493.7  |          569.7      |              741.2         
      f16 B=256, M=4096, H=16, K=64  |     441302.3     |           |       441526.8      |           597391.6         
      f16 B=16, M=128, H=16, K=16    |        114.4     |    266.0  |          148.1      |               93.1         
      f16 B=16, M=128, H=16, K=32    |        112.6     |    269.7  |          243.3      |              127.9         
      f16 B=16, M=128, H=16, K=64    |        113.6     |    267.8  |          149.5      |              131.4         
      f16 B=16, M=128, H=16, K=128   |        158.6     |    298.0  |          160.3      |              175.6         
      f16 B=16, M=512, H=16, K=16    |        323.5     |   1203.4  |          325.9      |              558.2         
      f16 B=16, M=512, H=16, K=32    |        435.2     |   1305.2  |          436.8      |              653.5         
      f16 B=16, M=512, H=16, K=64    |        703.1     |   1543.5  |          706.4      |              848.8         
      f16 B=16, M=512, H=16, K=128   |       1586.5     |   1982.6  |         1588.1      |             1735.4         
      f16 B=16, M=1024, H=16, K=16   |       1252.2     |   4273.8  |         1251.6      |             2236.4         
      f16 B=16, M=1024, H=16, K=32   |       1621.6     |   4494.4  |         1623.8      |             2430.8         
      f16 B=16, M=1024, H=16, K=64   |       2376.6     |   5007.3  |         2381.7      |             3007.2         
      f16 B=16, M=1024, H=16, K=128  |       5647.1     |   5956.1  |         5650.4      |             6296.2         
      f16 B=64, M=128, H=16, K=16    |        145.6     |    439.2  |          148.6      |              165.5         
      f16 B=64, M=128, H=16, K=32    |        212.1     |    544.4  |          214.4      |              210.4         
      f16 B=64, M=128, H=16, K=64    |        310.1     |    767.3  |          312.5      |              330.4         
      f16 B=64, M=128, H=16, K=128   |        562.3     |   1226.6  |          564.7      |              605.5         
      f16 B=64, M=512, H=16, K=16    |       1202.3     |   4481.6  |         1203.0      |             2004.7         
      f16 B=64, M=512, H=16, K=32    |       1543.5     |   4966.9  |         1548.2      |             2379.3         
      f16 B=64, M=512, H=16, K=64    |       2421.7     |   5886.7  |         2424.0      |             3129.6         
      f16 B=64, M=512, H=16, K=128   |       5446.6     |   7713.8  |         5452.7      |             6054.1         
      f16 B=64, M=1024, H=16, K=16   |       4725.5     |  16880.0  |         4719.8      |             7929.4         
      f16 B=64, M=1024, H=16, K=32   |       5716.0     |  17869.0  |         5722.7      |             8876.8         
      f16 B=64, M=1024, H=16, K=64   |       8155.3     |  19924.2  |         8163.5      |            11198.7         
      f16 B=64, M=1024, H=16, K=128  |      19213.5     |  23735.2  |        19228.8      |            21618.9  

Times are in microseconds (us). (B: batch size, M: sequence length, H: number of heads, K: head dimension.)
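For context on what the new OP wires together: the forward pass calls the cutlass kernel and the backward pass calls the flash kernel. Below is a minimal, hypothetical sketch of that idea as a `torch.autograd.Function`. `_reference_fw` and `_reference_bw` are plain-PyTorch stand-ins for the real CUDA kernels (they are not xformers APIs), and the detail that the forward saves the softmax log-sum-exp (lse) for the backward reflects how such kernels typically compose, not this PR's exact interface.

```python
import torch

def _reference_fw(q, k, v):
    # Stand-in for the cutlass forward kernel. Besides the attention
    # output it returns the softmax log-sum-exp (lse), the quantity a
    # flash-style backward consumes.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    lse = scores.logsumexp(dim=-1)
    out = scores.softmax(dim=-1) @ v
    return out, lse

def _reference_bw(grad_out, q, k, v, out, lse):
    # Stand-in for the flash backward kernel: recompute the forward under
    # autograd and differentiate. The real kernel does this blockwise and
    # reuses the saved lse instead of materializing full attention.
    q, k, v = (t.detach().requires_grad_() for t in (q, k, v))
    with torch.enable_grad():
        recomputed, _ = _reference_fw(q, k, v)
    return torch.autograd.grad(recomputed, (q, k, v), grad_out)

class FwCutlassBwFlash(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        # Forward runs one kernel...
        out, lse = _reference_fw(q, k, v)
        ctx.save_for_backward(q, k, v, out, lse)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # ...backward runs a different one, fed from the saved state.
        q, k, v, out, lse = ctx.saved_tensors
        return _reference_bw(grad_out, q, k, v, out, lse)
```

Usage would be `out = FwCutlassBwFlash.apply(q, k, v)`; a subsequent `out.backward(grad)` then routes through the flash-style backward rather than the forward kernel's own.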

danthe3rd pushed a commit that referenced this pull request Oct 7, 2022
ghstack-source-id: 80b6cd182490d70feaeae7a72a6c8b2682e927ea
Pull Request resolved: #469
@facebook-github-bot added the CLA Signed label on Oct 7, 2022.
@danthe3rd requested review from fmassa and dianaml0 on Oct 7, 2022 at 10:09.
@fmassa (Contributor) left a comment:
LGTM, thanks!

I think it might be good to think about refactoring how our OPs are set up so that the forward and backward are independent -- but that could be looked at in the future.

Also, can you share some benchmark numbers?
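One hypothetical shape such a refactor could take (a sketch only; these registry names and stub kernels are invented for illustration and do not exist in xformers): register forward and backward kernels independently, so that a combined op like fctls_bflsh is composed from the registries rather than hand-written as its own class.

```python
from dataclasses import dataclass
from typing import Callable

# Stub kernels standing in for the real CUDA bindings.
def cutlass_fw(q, k, v): raise NotImplementedError
def cutlass_bw(grad_out, q, k, v): raise NotImplementedError
def flash_fw(q, k, v): raise NotImplementedError
def flash_bw(grad_out, q, k, v): raise NotImplementedError

@dataclass(frozen=True)
class AttentionOp:
    fw: Callable  # forward kernel
    bw: Callable  # backward kernel

FORWARD_OPS = {"cutlass": cutlass_fw, "flash": flash_fw}
BACKWARD_OPS = {"cutlass": cutlass_bw, "flash": flash_bw}

def make_op(fw_name: str, bw_name: str) -> AttentionOp:
    # Pair any registered forward with any registered backward.
    return AttentionOp(fw=FORWARD_OPS[fw_name], bw=BACKWARD_OPS[bw_name])

# fctls_bflsh becomes a composition instead of a bespoke OP:
fctls_bflsh = make_op("cutlass", "flash")
```

The constraint is that a forward can only be paired with a backward that understands the state the forward saves (e.g. the same lse layout), which is what makes the cutlass-forward/flash-backward pairing viable here.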

danthe3rd pushed a commit that referenced this pull request Oct 7, 2022
ghstack-source-id: 8c2c5e6c4aa8c7cdf9bc972b5434e6c0fed58c6f
Pull Request resolved: #469
@danthe3rd merged commit 6f069c9 into gh/danthe3rd/49/base on Oct 7, 2022.
@danthe3rd deleted the gh/danthe3rd/49/head branch on October 7, 2022 at 11:40.