Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor4: Split fw/bw operators #560

Merged
merged 16 commits into from
Dec 9, 2022

Conversation

danthe3rd
Copy link
Contributor

@danthe3rd danthe3rd commented Dec 7, 2022

Stack from ghstack (oldest at bottom):

NOTE: Dropout is entirely disabled until we figure out how to implement it properly

BW perf on A100
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).

[ghstack-poisoned]
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 7, 2022
danthe3rd pushed a commit that referenced this pull request Dec 7, 2022
ghstack-source-id: ecfdae2d0fcfd9e24e3535f832cc53f378a75f74
Pull Request resolved: #560
danthe3rd pushed a commit that referenced this pull request Dec 7, 2022
ghstack-source-id: f61b6c2a352efc612252db71c765b78d070ec7a3
Pull Request resolved: #560
danthe3rd pushed a commit that referenced this pull request Dec 7, 2022
ghstack-source-id: c152503f42202510a29b0e7ac08441c13da8466c
Pull Request resolved: #560
<details>
<summary>BW perf on A100</summary>

```
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]
danthe3rd pushed a commit that referenced this pull request Dec 7, 2022
ghstack-source-id: f4446e42dc1645e7cafb81934b7a16dbbb03a57e
Pull Request resolved: #560
danthe3rd added 6 commits December 7, 2022 19:28
<details>
<summary>BW perf on A100</summary>

```
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]
…ors"


<details>
<summary>BW perf on A100</summary>

```
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]
<details>
<summary>BW perf on A100</summary>

```
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]
<details>
<summary>BW perf on A100</summary>

```
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]
<details>
<summary>BW perf on A100</summary>

```
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]
<details>
<summary>BW perf on A100</summary>

```
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]

NOTE: Dropout is entirely disabled until we figure out how to implement it properly
<details>
<summary>BW perf on A100</summary>

```
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]
danthe3rd added 4 commits December 8, 2022 10:28

NOTE: Dropout is entirely disabled until we figure out how to implement it properly
<details>
<summary>BW perf on A100</summary>

```
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]

NOTE: Dropout is entirely disabled until we figure out how to implement it properly
<details>
<summary>BW perf on A100</summary>

```
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]

NOTE: Dropout is entirely disabled until we figure out how to implement it properly
<details>
<summary>BW perf on A100</summary>

```
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]

NOTE: Dropout is entirely disabled until we figure out how to implement it properly
<details>
<summary>BW perf on A100</summary>

```
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]
Copy link
Contributor

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Putting some early comments in (the PR was still changing while I added the first comments), so that I don't forget

Comment on lines +46 to +53
# Saving attn_bias is a bit complicated, as the
# torch part should go in `save_for_backward`
if isinstance(inp.attn_bias, torch.Tensor):
attn_bias_tensor = inp.attn_bias
attn_bias_ctx = None
else:
attn_bias_tensor = None
attn_bias_ctx = inp.attn_bias
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably make the CausalTensor actually inherit from a torch.Tensor in the near future

xformers/ops/fmha/__init__.py Show resolved Hide resolved
return attn_bias_ctx
return attn_bias_tensor

@classmethod
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably also make it @once_differentiable for correctness. But I don't know how it plays with classmethod

Copy link
Contributor

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

return (gradients.dq, gradients.dk, gradients.dv)


def _memory_efficient_attention(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Those _ functions seems like they could be folded into the original memory_efficient_attention function, avoiding the need for this private function?


NOTE: Dropout is entirely disabled until we figure out how to implement it properly
<details>
<summary>BW perf on A100</summary>

```
[-------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ---------------------------------------------------]                                                                                                                                         
                                     |  refactor[flshattB]  |  refactor[cutlassB]  |  flash2[flshatt]  |  vanilla  |  cutlass[cutlass]  |  flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |          449.4       |                      |         447.6     |   2260.4  |         621.0      |                 
      f16 B=384, M=197, H=1, K=80    |          433.8       |                      |         431.8     |   1916.2  |         589.4      |                 
      f16 B=384, M=197, H=1, K=64    |          234.5       |                      |         233.6     |   1810.0  |         406.2      |         232.8   
      f16 B=1024, M=197, H=1, K=88   |         1112.9       |                      |        1111.3     |   5945.9  |        1557.7      |                 
      f16 B=1024, M=197, H=1, K=80   |         1073.3       |                      |        1073.5     |   5022.5  |        1480.4      |                 
      f16 B=1024, M=197, H=1, K=64   |          579.9       |                      |         579.4     |   4736.3  |         930.1      |         578.3   
      f16 B=512, M=197, H=1, K=80    |          546.0       |                      |         545.1     |   2539.2  |         749.0      |                 
      f16 B=32, M=197, H=16, K=80    |          560.2       |                      |         558.7     |   2579.7  |         748.1      |                 
      f16 B=32, M=197, H=16, K=64    |          298.4       |                      |         297.8     |   2430.3  |         480.1      |         295.8   
      f16 B=32, M=197, H=16, K=128   |          652.5       |                      |         647.4     |   4494.1  |         896.4      |         684.3   
      f16 B=256, M=197, H=1, K=88    |          330.1       |                      |         328.8     |   1527.1  |         449.9      |                 
      f16 B=16, M=197, H=16, K=88    |          331.2       |                      |         329.1     |   1540.0  |         445.2      |                 
      f16 B=16, M=197, H=16, K=64    |          198.7       |                      |         166.9     |   1243.8  |         246.4      |         165.9   
      f16 B=16, M=197, H=16, K=128   |          371.7       |                      |         367.4     |   2266.8  |         503.6      |         386.4   
      f16 B=1, M=4096, H=160, K=128  |        42897.1       |                      |       42915.9     |  45961.0  |       56457.5      |       54740.1   
      f16 B=2, M=4096, H=160, K=128  |        84681.5       |                      |       84646.9     |           |       93553.0      |       84262.0   
      f16 B=1, M=8192, H=160, K=128  |       165741.2       |                      |      165657.6     |           |      222850.4      |      215961.0   
      f16 B=2, M=8192, H=160, K=128  |       332314.2       |                      |      332593.3     |           |      364220.8      |      330727.9   
      f16 B=1024, M=82, H=8, K=64    |         1647.0       |                      |        1643.6     |   3821.1  |        1886.9      |        1621.7   
      f16 B=150, M=256, H=16, K=64   |         1632.0       |                      |        1626.4     |   4556.0  |        2302.2      |        1626.6   
      f16 B=64, M=256, H=12, K=64    |          570.4       |                      |         566.8     |   1499.1  |         778.4      |         567.1   
      f16 B=1, M=4096, H=16, K=40    |         2153.5       |                      |        2154.7     |   4193.7  |       23526.2      |                 
      f16 B=1, M=16384, H=16, K=40   |        27859.4       |                      |       27850.3     |           |      389891.9      |                 
      f16 B=256, M=4096, H=16, K=64  |       491968.1       |                      |      491246.6     |           |      729981.6      |      440392.8   
      f16 B=16, M=128, H=16, K=16    |          195.8       |                      |         118.3     |    276.8  |         129.3      |         120.4   
      f16 B=16, M=128, H=16, K=32    |          197.7       |                      |         117.6     |    276.8  |         128.8      |         113.7   
      f16 B=16, M=128, H=16, K=64    |          196.0       |                      |         122.9     |    324.0  |         129.9      |         113.6   
      f16 B=16, M=128, H=16, K=128   |          196.4       |                      |         170.0     |    331.5  |         182.2      |         157.9   
      f16 B=16, M=128, H=16, K=256   |                      |          791.8       |                   |    545.0  |         774.4      |                 
      f16 B=16, M=512, H=16, K=16    |          378.0       |                      |         375.1     |   1203.0  |         623.4      |         325.4   
      f16 B=16, M=512, H=16, K=32    |          436.7       |                      |         433.8     |   1306.2  |         711.6      |         432.0   
      f16 B=16, M=512, H=16, K=64    |          684.6       |                      |         680.6     |   1543.4  |         917.7      |         702.3   
      f16 B=16, M=512, H=16, K=128   |         1565.8       |                      |        1560.9     |   1983.1  |        1656.5      |        1580.7   
      f16 B=16, M=512, H=16, K=256   |                      |         8678.8       |                   |   2901.2  |        8652.9      |                 
      f16 B=16, M=1024, H=16, K=16   |         1432.3       |                      |        1435.8     |   4262.8  |        2422.2      |        1245.2   
      f16 B=16, M=1024, H=16, K=32   |         1608.0       |                      |        1605.9     |   4489.6  |        2697.4      |        1617.9   
      f16 B=16, M=1024, H=16, K=64   |         2549.3       |                      |        2541.9     |   4992.6  |        3341.4      |        2370.4   
      f16 B=16, M=1024, H=16, K=128  |         5034.4       |                      |        5023.6     |   5955.5  |        5903.8      |        5634.1   
      f16 B=16, M=1024, H=16, K=256  |                      |        31948.6       |                   |   7891.3  |       31850.4      |                 
      f16 B=64, M=128, H=16, K=16    |          193.9       |                      |         159.7     |    439.8  |         158.3      |         145.0   
      f16 B=64, M=128, H=16, K=32    |          217.0       |                      |         215.3     |    545.3  |         209.4      |         211.2   
      f16 B=64, M=128, H=16, K=64    |          316.0       |                      |         313.6     |    767.0  |         328.6      |         310.2   
      f16 B=64, M=128, H=16, K=128   |          559.7       |                      |         555.9     |   1229.2  |         637.6      |         561.6   
      f16 B=64, M=128, H=16, K=256   |                      |         2851.8       |                   |   2126.3  |        2839.0      |                 
      f16 B=64, M=512, H=16, K=16    |         1372.4       |                      |        1370.9     |   4484.9  |        2313.9      |        1219.4   
      f16 B=64, M=512, H=16, K=32    |         1542.0       |                      |        1539.2     |   4978.8  |        2706.7      |        1547.4   
      f16 B=64, M=512, H=16, K=64    |         2335.4       |                      |        2330.0     |   5885.5  |        3464.0      |        2417.8   
      f16 B=64, M=512, H=16, K=128   |         5357.0       |                      |        5347.8     |   7711.8  |        5909.5      |        5442.7   
      f16 B=64, M=512, H=16, K=256   |                      |        31136.1       |                   |  11501.8  |       30488.7      |                 
      f16 B=64, M=1024, H=16, K=16   |         5226.1       |                      |        5228.4     |  16908.5  |        9092.8      |        4713.7   
      f16 B=64, M=1024, H=16, K=32   |         5717.3       |                      |        5723.1     |  17895.8  |       10439.0      |        5705.3   
      f16 B=64, M=1024, H=16, K=64   |         8672.2       |                      |        8664.4     |  19929.7  |       12937.8      |        8137.3   
      f16 B=64, M=1024, H=16, K=128  |        19104.3       |                      |       19080.9     |  23711.0  |       20963.4      |       19169.7   
      f16 B=64, M=1024, H=16, K=256  |                      |       115163.2       |                   |  32729.8  |      113715.6      |                 

Times are in microseconds (us).
```
</details>

[ghstack-poisoned]
@danthe3rd danthe3rd merged commit 3b82aec into gh/danthe3rd/59/base Dec 9, 2022
danthe3rd pushed a commit that referenced this pull request Dec 9, 2022
ghstack-source-id: 548ab26edd3160663dd9bdf08e4be3f1f1ab809c
Pull Request resolved: #560
@danthe3rd danthe3rd deleted the gh/danthe3rd/59/head branch December 9, 2022 15:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants