Add Triton Flash Attention #479

Merged: 31 commits, Dec 6, 2022

dianaml0 commented Oct 11, 2022

What does this PR do?

Adds Triton Flash Attention

Performance Compared to Vanilla
[--------- attention (attn_bias=<class 'NoneType'>) --------]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     1260.4  |    372.9
      b16 B=384, M=197, H=1, K=88    |     1269.9  |    375.3
      f16 B=384, M=197, H=1, K=80    |      146.7  |    344.5
      b16 B=384, M=197, H=1, K=80    |      149.2  |    346.6
      f16 B=384, M=197, H=1, K=64    |       92.0  |    293.2
      b16 B=384, M=197, H=1, K=64    |       94.2  |    295.1
      f16 B=1024, M=197, H=1, K=88   |     3242.3  |    938.8
      b16 B=1024, M=197, H=1, K=88   |     3264.9  |    945.1
      f16 B=1024, M=197, H=1, K=80   |      348.7  |    864.5
      b16 B=1024, M=197, H=1, K=80   |      354.3  |    871.5
      f16 B=1024, M=197, H=1, K=64   |      213.7  |    729.9
      b16 B=1024, M=197, H=1, K=64   |      222.1  |    735.5
      f16 B=512, M=197, H=1, K=80    |      185.4  |    448.4
      b16 B=512, M=197, H=1, K=80    |      188.5  |    451.6
      f16 B=32, M=197, H=16, K=80    |      193.8  |    547.9
      b16 B=32, M=197, H=16, K=80    |      193.7  |    550.6
      f16 B=32, M=197, H=16, K=64    |      114.4  |    467.1
      b16 B=32, M=197, H=16, K=64    |      120.8  |    469.2
      f16 B=32, M=197, H=16, K=128   |      193.8  |    717.4
      b16 B=32, M=197, H=16, K=128   |      195.1  |    720.3
      f16 B=256, M=197, H=1, K=88    |      868.4  |    262.1
      b16 B=256, M=197, H=1, K=88    |      869.9  |    262.2
      f16 B=16, M=197, H=16, K=88    |      869.8  |    317.7
      b16 B=16, M=197, H=16, K=88    |      872.4  |    319.0
      f16 B=16, M=197, H=16, K=64    |       88.6  |    257.7
      b16 B=16, M=197, H=16, K=64    |       89.6  |    259.8
      f16 B=16, M=197, H=16, K=128   |      101.5  |    385.8
      b16 B=16, M=197, H=16, K=128   |      102.2  |    387.1
      f16 B=1, M=4096, H=160, K=128  |     9807.3  |  20343.8
      b16 B=1, M=4096, H=160, K=128  |    10034.8  |  21406.5
      f16 B=2, M=4096, H=160, K=128  |    19298.5  |  42595.1
      b16 B=2, M=4096, H=160, K=128  |    19841.7  |  44399.3
      f16 B=1, M=8192, H=160, K=128  |    37266.9  |  88426.6
      b16 B=1, M=8192, H=160, K=128  |    38570.7  |  87103.1
      f16 B=2, M=8192, H=160, K=128  |    74084.0  |
      b16 B=2, M=8192, H=160, K=128  |    76617.8  |
      f16 B=1024, M=82, H=8, K=64    |      461.1  |   1767.3
      b16 B=1024, M=82, H=8, K=64    |      479.9  |   1862.5
      f16 B=150, M=256, H=16, K=64   |      382.5  |   1725.6
      b16 B=150, M=256, H=16, K=64   |      420.9  |   1760.7
      f16 B=64, M=256, H=12, K=64    |      131.7  |    587.0
      b16 B=64, M=256, H=12, K=64    |      145.7  |    597.4
      f16 B=1, M=4096, H=16, K=40    |    30705.0  |   1937.8
      b16 B=1, M=4096, H=16, K=40    |    30832.3  |   1990.1
      f16 B=1, M=16384, H=16, K=40   |   423123.6  |  28845.9
      b16 B=1, M=16384, H=16, K=40   |   420611.6  |  30167.0
      f16 B=256, M=4096, H=16, K=64  |   118530.4  |
      b16 B=256, M=4096, H=16, K=64  |   132843.9  |
      f16 B=8, M=2048, H=20, K=128   |     2692.1  |   5704.5
      b16 B=8, M=2048, H=20, K=128   |     2743.9  |   6044.7
      f16 B=16, M=128, H=16, K=16    |       90.2  |    135.5
      b16 B=16, M=128, H=16, K=16    |       87.0  |    136.7
      f16 B=16, M=128, H=16, K=32    |       88.4  |    137.5
      b16 B=16, M=128, H=16, K=32    |       91.0  |    138.1
      f16 B=16, M=128, H=16, K=64    |       89.1  |    137.3
      b16 B=16, M=128, H=16, K=64    |       90.3  |    137.7
      f16 B=16, M=128, H=16, K=128   |       91.6  |    139.6
      b16 B=16, M=128, H=16, K=128   |       90.6  |    140.7
      f16 B=16, M=512, H=16, K=16    |       89.1  |    461.4
      b16 B=16, M=512, H=16, K=16    |       99.3  |    553.7
      f16 B=16, M=512, H=16, K=32    |      105.6  |    512.1
      b16 B=16, M=512, H=16, K=32    |      123.2  |    594.4
      f16 B=16, M=512, H=16, K=64    |      152.1  |    595.9
      b16 B=16, M=512, H=16, K=64    |      169.4  |    613.0
      f16 B=16, M=512, H=16, K=128   |      319.0  |    786.8
      b16 B=16, M=512, H=16, K=128   |      328.0  |    804.9
      f16 B=16, M=1024, H=16, K=16   |      292.1  |   1642.8
      b16 B=16, M=1024, H=16, K=16   |      385.7  |   2025.9
      f16 B=16, M=1024, H=16, K=32   |      369.5  |   1732.0
      b16 B=16, M=1024, H=16, K=32   |      428.2  |   2127.2
      f16 B=16, M=1024, H=16, K=64   |      518.4  |   2033.7
      b16 B=16, M=1024, H=16, K=64   |      579.9  |   2064.6
      f16 B=16, M=1024, H=16, K=128  |     1097.7  |   2410.3
      b16 B=16, M=1024, H=16, K=128  |     1135.3  |   2490.3
      f16 B=64, M=128, H=16, K=16    |       87.5  |    183.0
      b16 B=64, M=128, H=16, K=16    |       89.6  |    184.9
      f16 B=64, M=128, H=16, K=32    |       89.2  |    224.6
      b16 B=64, M=128, H=16, K=32    |       88.2  |    225.5
      f16 B=64, M=128, H=16, K=64    |       87.5  |    321.5
      b16 B=64, M=128, H=16, K=64    |       87.2  |    322.9
      f16 B=64, M=128, H=16, K=128   |      127.0  |    484.7
      b16 B=64, M=128, H=16, K=128   |      129.4  |    486.1
      f16 B=64, M=512, H=16, K=16    |      307.5  |   1727.1
      b16 B=64, M=512, H=16, K=16    |      403.5  |   2094.0
      f16 B=64, M=512, H=16, K=32    |      387.2  |   1888.1
      b16 B=64, M=512, H=16, K=32    |      450.6  |   2250.3
      f16 B=64, M=512, H=16, K=64    |      557.9  |   2233.9
      b16 B=64, M=512, H=16, K=64    |      620.4  |   2294.9
      f16 B=64, M=512, H=16, K=128   |     1205.0  |   3008.3
      b16 B=64, M=512, H=16, K=128   |     1235.5  |   3072.8
      f16 B=64, M=1024, H=16, K=16   |     1101.8  |   6422.7
      b16 B=64, M=1024, H=16, K=16   |     1473.2  |   8024.6
      f16 B=64, M=1024, H=16, K=32   |     1409.8  |   6744.7
      b16 B=64, M=1024, H=16, K=32   |     1667.6  |   8387.5
      f16 B=64, M=1024, H=16, K=64   |     2022.1  |   7966.1
      b16 B=64, M=1024, H=16, K=64   |     2265.4  |   8127.2
      f16 B=64, M=1024, H=16, K=128  |     4282.5  |   9524.9
      b16 B=64, M=1024, H=16, K=128  |     4413.0  |   9842.0

Times are in microseconds (us).
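
To read the table above as speedups rather than raw times, the rows can be parsed and the `vanilla` time divided by the `optimized` time. The following is a minimal sketch, not part of the PR: the names `ROW`, `parse_row`, and `speedup` are hypothetical helpers, and the row format is assumed to match the lines shown here. Rows with an empty `vanilla` cell are treated as missing; the output does not say why those cells are empty.

```python
import re

# Matches rows such as "f16 B=384, M=197, H=1, K=64    |       92.0  |    293.2".
# The second timing column may be empty, as in some rows of the table.
ROW = re.compile(
    r"^\s*(?P<dtype>\w+) B=(?P<B>\d+), M=(?P<M>\d+), H=(?P<H>\d+), K=(?P<K>\d+)"
    r"\s*\|\s*(?P<optimized>[\d.]+)\s*\|\s*(?P<baseline>[\d.]*)\s*$"
)


def parse_row(line):
    """Return (label, optimized_us, baseline_us or None), or None for non-row lines."""
    m = ROW.match(line)
    if m is None:
        return None
    label = "{dtype} B={B}, M={M}, H={H}, K={K}".format(**m.groupdict())
    baseline = float(m["baseline"]) if m["baseline"] else None
    return label, float(m["optimized"]), baseline


def speedup(line):
    """Baseline time divided by optimized time; None if the row or baseline is missing."""
    row = parse_row(line)
    if row is None or row[2] is None:
        return None
    return row[2] / row[1]


# Three rows copied from the table above.
rows = [
    "      f16 B=384, M=197, H=1, K=88    |     1260.4  |    372.9",
    "      f16 B=384, M=197, H=1, K=64    |       92.0  |    293.2",
    "      f16 B=2, M=8192, H=160, K=128  |    74084.0  |",
]

for r in rows:
    label, _, _ = parse_row(r)
    s = speedup(r)
    print(label, "->", "n/a" if s is None else f"{s:.2f}x")
```

A ratio above 1 means the `optimized` column is faster. For the sample rows, K=64 is about 3x faster than vanilla, while K=88 is slower than vanilla.
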
[ attention (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  vanilla
1 threads: ---------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     3701.7  |     447.0
      b16 B=384, M=197, H=1, K=88    |     1178.3  |     452.6
      f16 B=384, M=197, H=1, K=80    |      120.2  |     421.6
      b16 B=384, M=197, H=1, K=80    |      122.5  |     427.5
      f16 B=384, M=197, H=1, K=64    |       87.3  |     370.6
      b16 B=384, M=197, H=1, K=64    |       90.1  |     376.4
      f16 B=1024, M=197, H=1, K=88   |     9586.4  |    1124.4
      b16 B=1024, M=197, H=1, K=88   |     2981.6  |    1139.9
      f16 B=1024, M=197, H=1, K=80   |      294.0  |    1059.9
      b16 B=1024, M=197, H=1, K=80   |      300.3  |    1075.3
      f16 B=1024, M=197, H=1, K=64   |      187.5  |     925.9
      b16 B=1024, M=197, H=1, K=64   |      197.2  |     940.5
      f16 B=512, M=197, H=1, K=80    |      153.5  |     549.2
      b16 B=512, M=197, H=1, K=80    |      156.6  |     557.4
      f16 B=32, M=197, H=16, K=80    |      157.4  |     647.5
      b16 B=32, M=197, H=16, K=80    |      160.6  |     654.6
      f16 B=32, M=197, H=16, K=64    |      103.8  |     565.6
      b16 B=32, M=197, H=16, K=64    |      109.2  |     573.3
      f16 B=32, M=197, H=16, K=128   |      157.7  |     810.9
      b16 B=32, M=197, H=16, K=128   |      159.7  |     818.4
      f16 B=256, M=197, H=1, K=88    |     2551.1  |     310.8
      b16 B=256, M=197, H=1, K=88    |      813.3  |     314.9
      f16 B=16, M=197, H=16, K=88    |     2554.5  |     367.9
      b16 B=16, M=197, H=16, K=88    |      806.5  |     372.0
      f16 B=16, M=197, H=16, K=64    |       90.3  |     307.1
      b16 B=16, M=197, H=16, K=64    |       89.6  |     310.9
      f16 B=16, M=197, H=16, K=128   |       87.8  |     435.4
      b16 B=16, M=197, H=16, K=128   |       91.6  |     438.6
      f16 B=1, M=4096, H=160, K=128  |     5176.5  |   37308.5
      b16 B=1, M=4096, H=160, K=128  |     5330.5  |   37818.1
      f16 B=2, M=4096, H=160, K=128  |    10225.6  |   76322.3
      b16 B=2, M=4096, H=160, K=128  |    10547.6  |   77302.1
      f16 B=1, M=8192, H=160, K=128  |    19289.3  |  152433.6
      b16 B=1, M=8192, H=160, K=128  |    19922.2  |  148602.0
      f16 B=2, M=8192, H=160, K=128  |    38313.9  |
      b16 B=2, M=8192, H=160, K=128  |    39610.6  |
      f16 B=1024, M=82, H=8, K=64    |      488.4  |    1979.9
      b16 B=1024, M=82, H=8, K=64    |      515.8  |    2085.1
      f16 B=150, M=256, H=16, K=64   |      333.7  |    2401.0
      b16 B=150, M=256, H=16, K=64   |      349.7  |    2446.2
      f16 B=64, M=256, H=12, K=64    |      118.8  |     805.3
      b16 B=64, M=256, H=12, K=64    |      124.8  |     819.5
      f16 B=1, M=4096, H=16, K=40    |    15004.4  |    3428.4
      b16 B=1, M=4096, H=16, K=40    |    14890.7  |    3465.7
      f16 B=1, M=16384, H=16, K=40   |   221390.1  |   55383.6
      b16 B=1, M=16384, H=16, K=40   |   218101.2  |   55427.0
      f16 B=256, M=4096, H=16, K=64  |    67757.8  |
      b16 B=256, M=4096, H=16, K=64  |    74047.4  |
      f16 B=8, M=2048, H=20, K=128   |     1487.8  |    9352.9
      b16 B=8, M=2048, H=20, K=128   |     1534.3  |    9539.5
      f16 B=16, M=128, H=16, K=16    |       92.2  |     146.0
      b16 B=16, M=128, H=16, K=16    |       87.0  |     143.8
      f16 B=16, M=128, H=16, K=32    |       87.2  |     144.5
      b16 B=16, M=128, H=16, K=32    |       89.3  |     145.7
      f16 B=16, M=128, H=16, K=64    |       90.6  |     144.4
      b16 B=16, M=128, H=16, K=64    |       90.3  |     141.9
      f16 B=16, M=128, H=16, K=128   |       87.2  |     160.0
      b16 B=16, M=128, H=16, K=128   |       87.2  |     161.7
      f16 B=16, M=512, H=16, K=16    |       90.4  |     715.1
      b16 B=16, M=512, H=16, K=16    |       91.0  |     792.1
      f16 B=16, M=512, H=16, K=32    |       88.5  |     766.6
      b16 B=16, M=512, H=16, K=32    |       98.1  |     832.2
      f16 B=16, M=512, H=16, K=64    |      120.3  |     881.1
      b16 B=16, M=512, H=16, K=64    |      129.0  |     910.1
      f16 B=16, M=512, H=16, K=128   |      225.5  |    1066.3
      b16 B=16, M=512, H=16, K=128   |      231.9  |    1095.5
      f16 B=16, M=1024, H=16, K=16   |      196.5  |    2658.0
      b16 B=16, M=1024, H=16, K=16   |      253.2  |    3042.2
      f16 B=16, M=1024, H=16, K=32   |      244.3  |    2749.6
      b16 B=16, M=1024, H=16, K=32   |      291.8  |    3121.5
      f16 B=16, M=1024, H=16, K=64   |      355.3  |    2963.7
      b16 B=16, M=1024, H=16, K=64   |      384.7  |    3287.5
      f16 B=16, M=1024, H=16, K=128  |      683.0  |    3534.2
      b16 B=16, M=1024, H=16, K=128  |      705.2  |    3711.2
      f16 B=64, M=128, H=16, K=16    |       87.6  |     245.8
      b16 B=64, M=128, H=16, K=16    |       89.9  |     251.7
      f16 B=64, M=128, H=16, K=32    |       88.4  |     294.2
      b16 B=64, M=128, H=16, K=32    |       90.7  |     298.7
      f16 B=64, M=128, H=16, K=64    |       92.0  |     391.1
      b16 B=64, M=128, H=16, K=64    |       88.6  |     395.9
      f16 B=64, M=128, H=16, K=128   |      131.0  |     557.0
      b16 B=64, M=128, H=16, K=128   |      133.2  |     562.1
      f16 B=64, M=512, H=16, K=16    |      227.6  |    2704.8
      b16 B=64, M=512, H=16, K=16    |      289.3  |    3019.3
      f16 B=64, M=512, H=16, K=32    |      285.1  |    2891.7
      b16 B=64, M=512, H=16, K=32    |      336.8  |    3163.4
      f16 B=64, M=512, H=16, K=64    |      414.3  |    3351.7
      b16 B=64, M=512, H=16, K=64    |      444.2  |    3475.0
      f16 B=64, M=512, H=16, K=128   |      829.5  |    4097.1
      b16 B=64, M=512, H=16, K=128   |      854.6  |    4208.8
      f16 B=64, M=1024, H=16, K=16   |      722.1  |   10479.3
      b16 B=64, M=1024, H=16, K=16   |      931.6  |   12030.5
      f16 B=64, M=1024, H=16, K=32   |      900.9  |   10856.5
      b16 B=64, M=1024, H=16, K=32   |     1075.2  |   12361.2
      f16 B=64, M=1024, H=16, K=64   |     1313.1  |   11684.0
      b16 B=64, M=1024, H=16, K=64   |     1424.4  |   13002.0
      f16 B=64, M=1024, H=16, K=128  |     2610.0  |   13969.2
      b16 B=64, M=1024, H=16, K=128  |     2697.7  |   14687.4

Times are in microseconds (us).
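
One way to sanity-check the two forward tables above: a lower-triangular (causal) mask keeps only the query-key pairs on or below the diagonal, M(M+1)/2 of the M² pairs. A kernel that skips masked work could therefore approach half the unmasked time for long sequences. Whether this kernel skips masked blocks is an assumption here, not something the PR description states. The sketch below compares that ideal fraction with one pair of measured times from the tables; `causal_fraction` is a hypothetical helper.

```python
def causal_fraction(m):
    """Fraction of the m x m score matrix kept by a lower-triangular mask."""
    kept = m * (m + 1) // 2  # entries on or below the diagonal
    return kept / (m * m)


# "optimized" forward times for f16 B=1, M=4096, H=160, K=128, from the tables above.
unmasked_us = 9807.3  # attn_bias=None
causal_us = 5176.5    # attn_bias=LowerTriangularMask

measured_ratio = causal_us / unmasked_us
print(f"ideal fraction of work: {causal_fraction(4096):.4f}")
print(f"measured time ratio:    {measured_ratio:.4f}")
```

The measured ratio for this shape is slightly above the ideal fraction. That would be consistent with per-block overhead and partially masked diagonal blocks, though the tables alone do not confirm the cause.
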
[---- attention backward (attn_bias=<class 'NoneType'>) ----]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     3702.8  |    820.9
      b16 B=384, M=197, H=1, K=88    |     3701.9  |    823.4
      f16 B=384, M=197, H=1, K=80    |      761.9  |    768.3
      b16 B=384, M=197, H=1, K=80    |      762.5  |    769.8
      f16 B=384, M=197, H=1, K=64    |      592.9  |    651.8
      b16 B=384, M=197, H=1, K=64    |      372.7  |    652.9
      f16 B=1024, M=197, H=1, K=88   |     9107.4  |   2103.4
      b16 B=1024, M=197, H=1, K=88   |     9115.3  |   2105.4
      f16 B=1024, M=197, H=1, K=80   |     1882.7  |   1958.4
      b16 B=1024, M=197, H=1, K=80   |     1891.6  |   1958.3
      f16 B=1024, M=197, H=1, K=64   |      916.6  |   1649.8
      b16 B=1024, M=197, H=1, K=64   |      924.4  |   1649.1
      f16 B=512, M=197, H=1, K=80    |      987.1  |    994.2
      b16 B=512, M=197, H=1, K=80    |      990.8  |    997.0
      f16 B=32, M=197, H=16, K=80    |     1037.3  |   1034.2
      b16 B=32, M=197, H=16, K=80    |     1040.1  |   1035.8
      f16 B=32, M=197, H=16, K=64    |      486.2  |    885.3
      b16 B=32, M=197, H=16, K=64    |      488.7  |    887.6
      f16 B=32, M=197, H=16, K=128   |     1165.0  |   1343.0
      b16 B=32, M=197, H=16, K=128   |     1168.3  |   1345.8
      f16 B=256, M=197, H=1, K=88    |     2360.9  |    576.3
      b16 B=256, M=197, H=1, K=88    |     2361.7  |    576.9
      f16 B=16, M=197, H=16, K=88    |     2377.4  |    599.7
      b16 B=16, M=197, H=16, K=88    |     2376.8  |    600.9
      f16 B=16, M=197, H=16, K=64    |      303.0  |    488.1
      b16 B=16, M=197, H=16, K=64    |      330.5  |    488.6
      f16 B=16, M=197, H=16, K=128   |      617.9  |    713.8
      b16 B=16, M=197, H=16, K=128   |      619.7  |    716.1
      f16 B=1, M=4096, H=160, K=128  |    41860.6  |  38927.7
      b16 B=1, M=4096, H=160, K=128  |    41971.4  |  39833.9
      f16 B=2, M=4096, H=160, K=128  |    83623.8  |  78708.1
      b16 B=2, M=4096, H=160, K=128  |    84009.4  |  80592.3
      f16 B=1, M=8192, H=160, K=128  |   160725.6  |
      b16 B=1, M=8192, H=160, K=128  |   161001.2  |
      f16 B=2, M=8192, H=160, K=128  |   321049.6  |
      b16 B=2, M=8192, H=160, K=128  |   321800.6  |
      f16 B=1024, M=82, H=8, K=64    |     2608.7  |   3609.7
      b16 B=1024, M=82, H=8, K=64    |     2525.2  |   3783.9
      f16 B=150, M=256, H=16, K=64   |     2552.5  |   3809.6
      b16 B=150, M=256, H=16, K=64   |     2464.5  |   3838.1
      f16 B=64, M=256, H=12, K=64    |      827.8  |   1257.5
      b16 B=64, M=256, H=12, K=64    |      835.1  |   1269.2
      f16 B=1, M=4096, H=16, K=40    |    43085.7  |   3556.9
      b16 B=1, M=4096, H=16, K=40    |    42898.9  |   3595.8
      f16 B=1, M=16384, H=16, K=40   |   664283.3  |  53870.2
      b16 B=1, M=16384, H=16, K=40   |   664702.6  |  54308.0
      f16 B=256, M=4096, H=16, K=64  |   503458.9  |
      b16 B=256, M=4096, H=16, K=64  |   503461.7  |
      f16 B=8, M=2048, H=20, K=128   |    11442.9  |  10536.5
      b16 B=8, M=2048, H=20, K=128   |    11362.9  |  10790.8
      f16 B=16, M=128, H=16, K=16    |      349.5  |    311.2
      b16 B=16, M=128, H=16, K=16    |      303.8  |    314.0
      f16 B=16, M=128, H=16, K=32    |      316.6  |    340.6
      b16 B=16, M=128, H=16, K=32    |      328.7  |    337.9
      f16 B=16, M=128, H=16, K=64    |      318.3  |    333.9
      b16 B=16, M=128, H=16, K=64    |      332.7  |    335.4
      f16 B=16, M=128, H=16, K=128   |      452.5  |    331.2
      b16 B=16, M=128, H=16, K=128   |      329.1  |    335.9
      f16 B=16, M=512, H=16, K=16    |      527.1  |    982.9
      b16 B=16, M=512, H=16, K=16    |      322.8  |   1078.9
      f16 B=16, M=512, H=16, K=32    |      680.9  |   1090.2
      b16 B=16, M=512, H=16, K=32    |      487.4  |   1179.3
      f16 B=16, M=512, H=16, K=64    |     1016.7  |   1276.3
      b16 B=16, M=512, H=16, K=64    |      845.2  |   1295.1
      f16 B=16, M=512, H=16, K=128   |     1891.7  |   1712.6
      b16 B=16, M=512, H=16, K=128   |     1766.5  |   1744.0
      f16 B=16, M=1024, H=16, K=16   |     1257.4  |   3532.6
      b16 B=16, M=1024, H=16, K=16   |     1066.4  |   3953.2
      f16 B=16, M=1024, H=16, K=32   |     1770.3  |   3731.3
      b16 B=16, M=1024, H=16, K=32   |     1615.7  |   4158.6
      f16 B=16, M=1024, H=16, K=64   |     2686.2  |   4281.1
      b16 B=16, M=1024, H=16, K=64   |     2545.2  |   4348.9
      f16 B=16, M=1024, H=16, K=128  |     5407.5  |   5109.9
      b16 B=16, M=1024, H=16, K=128  |     5312.2  |   5248.8
      f16 B=64, M=128, H=16, K=16    |      304.3  |    365.2
      b16 B=64, M=128, H=16, K=16    |      325.1  |    372.8
      f16 B=64, M=128, H=16, K=32    |      304.3  |    465.8
      b16 B=64, M=128, H=16, K=32    |      301.5  |    472.4
      f16 B=64, M=128, H=16, K=64    |      464.3  |    679.8
      b16 B=64, M=128, H=16, K=64    |      469.2  |    681.4
      f16 B=64, M=128, H=16, K=128   |      980.1  |   1076.0
      b16 B=64, M=128, H=16, K=128   |      987.5  |   1079.2
      f16 B=64, M=512, H=16, K=16    |     1109.7  |   3705.7
      b16 B=64, M=512, H=16, K=16    |     1111.9  |   4078.7
      f16 B=64, M=512, H=16, K=32    |     1678.8  |   4135.2
      b16 B=64, M=512, H=16, K=32    |     1684.5  |   4504.2
      f16 B=64, M=512, H=16, K=64    |     2982.7  |   4906.5
      b16 B=64, M=512, H=16, K=64    |     3001.9  |   4992.0
      f16 B=64, M=512, H=16, K=128   |     6641.2  |   6644.6
      b16 B=64, M=512, H=16, K=128   |     6697.3  |   6756.1
      f16 B=64, M=1024, H=16, K=16   |     3610.9  |  13882.0
      b16 B=64, M=1024, H=16, K=16   |     3614.8  |  15555.3
      f16 B=64, M=1024, H=16, K=32   |     6138.0  |  14767.7
      b16 B=64, M=1024, H=16, K=32   |     6147.9  |  16450.9
      f16 B=64, M=1024, H=16, K=64   |     9813.8  |  16866.6
      b16 B=64, M=1024, H=16, K=64   |     9848.6  |  17160.6
      f16 B=64, M=1024, H=16, K=128  |    20556.5  |  20371.7
      b16 B=64, M=1024, H=16, K=128  |    20674.3  |  20916.7

Times are in microseconds (us).
[ attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     9480.6  |    822.5
      b16 B=384, M=197, H=1, K=88    |     9355.1  |    823.8
      f16 B=384, M=197, H=1, K=80    |      669.6  |    768.6
      b16 B=384, M=197, H=1, K=80    |      674.6  |    770.2
      f16 B=384, M=197, H=1, K=64    |      542.3  |    651.9
      b16 B=384, M=197, H=1, K=64    |      351.1  |    653.8
      f16 B=1024, M=197, H=1, K=88   |    23565.1  |   2105.1
      b16 B=1024, M=197, H=1, K=88   |    23556.3  |   2105.4
      f16 B=1024, M=197, H=1, K=80   |     1646.1  |   1961.7
      b16 B=1024, M=197, H=1, K=80   |     1651.9  |   1958.3
      f16 B=1024, M=197, H=1, K=64   |      861.8  |   1649.0
      b16 B=1024, M=197, H=1, K=64   |      869.0  |   1648.1
      f16 B=512, M=197, H=1, K=80    |      863.6  |    995.4
      b16 B=512, M=197, H=1, K=80    |      866.2  |    997.1
      f16 B=32, M=197, H=16, K=80    |      867.8  |   1034.4
      b16 B=32, M=197, H=16, K=80    |      869.5  |   1036.2
      f16 B=32, M=197, H=16, K=64    |      456.3  |    883.7
      b16 B=32, M=197, H=16, K=64    |      458.7  |    885.6
      f16 B=32, M=197, H=16, K=128   |     1042.0  |   1340.5
      b16 B=32, M=197, H=16, K=128   |     1045.1  |   1343.6
      f16 B=256, M=197, H=1, K=88    |     6342.1  |    577.0
      b16 B=256, M=197, H=1, K=88    |     6344.3  |    578.3
      f16 B=16, M=197, H=16, K=88    |     6335.6  |    600.3
      b16 B=16, M=197, H=16, K=88    |     6341.5  |    601.6
      f16 B=16, M=197, H=16, K=64    |      330.3  |    488.4
      b16 B=16, M=197, H=16, K=64    |      325.5  |    490.0
      f16 B=16, M=197, H=16, K=128   |      557.4  |    715.2
      b16 B=16, M=197, H=16, K=128   |      557.8  |    717.4
      f16 B=1, M=4096, H=160, K=128  |    26152.2  |  38914.3
      b16 B=1, M=4096, H=160, K=128  |    26166.4  |  39924.4
      f16 B=2, M=4096, H=160, K=128  |    51711.2  |  78726.9
      b16 B=2, M=4096, H=160, K=128  |    51925.3  |  80711.7
      f16 B=1, M=8192, H=160, K=128  |    92696.3  |
      b16 B=1, M=8192, H=160, K=128  |    92960.2  |
      f16 B=2, M=8192, H=160, K=128  |   184624.7  |
      b16 B=2, M=8192, H=160, K=128  |   185330.2  |
      f16 B=1024, M=82, H=8, K=64    |     2724.8  |   3608.8
      b16 B=1024, M=82, H=8, K=64    |     2642.1  |   3784.4
      f16 B=150, M=256, H=16, K=64   |     2379.5  |   3803.7
      b16 B=150, M=256, H=16, K=64   |     2287.7  |   3832.2
      f16 B=64, M=256, H=12, K=64    |      774.1  |   1255.6
      b16 B=64, M=256, H=12, K=64    |      781.0  |   1270.5
      f16 B=1, M=4096, H=16, K=40    |     6799.9  |   3563.7
      b16 B=1, M=4096, H=16, K=40    |     6624.0  |   3598.3
      f16 B=1, M=16384, H=16, K=40   |    94594.6  |  53902.5
      b16 B=1, M=16384, H=16, K=40   |    94692.4  |  54354.3
      f16 B=256, M=4096, H=16, K=64  |   274333.4  |
      b16 B=256, M=4096, H=16, K=64  |   274964.5  |
      f16 B=8, M=2048, H=20, K=128   |     7844.4  |  10545.5
      b16 B=8, M=2048, H=20, K=128   |     7764.3  |  10795.4
      f16 B=16, M=128, H=16, K=16    |      334.9  |    293.7
      b16 B=16, M=128, H=16, K=16    |      326.3  |    314.7
      f16 B=16, M=128, H=16, K=32    |      354.5  |    319.1
      b16 B=16, M=128, H=16, K=32    |      304.6  |    291.4
      f16 B=16, M=128, H=16, K=64    |      350.7  |    312.6
      b16 B=16, M=128, H=16, K=64    |      331.6  |    301.8
      f16 B=16, M=128, H=16, K=128   |      486.3  |    313.5
      b16 B=16, M=128, H=16, K=128   |      326.8  |    314.4
      f16 B=16, M=512, H=16, K=16    |      456.9  |    986.5
      b16 B=16, M=512, H=16, K=16    |      306.0  |   1080.0
      f16 B=16, M=512, H=16, K=32    |      566.6  |   1091.4
      b16 B=16, M=512, H=16, K=32    |      393.9  |   1179.2
      f16 B=16, M=512, H=16, K=64    |      841.5  |   1276.5
      b16 B=16, M=512, H=16, K=64    |      669.4  |   1294.9
      f16 B=16, M=512, H=16, K=128   |     1590.8  |   1713.6
      b16 B=16, M=512, H=16, K=128   |     1486.2  |   1740.0
      f16 B=16, M=1024, H=16, K=16   |      858.6  |   3529.7
      b16 B=16, M=1024, H=16, K=16   |      686.6  |   3951.6
      f16 B=16, M=1024, H=16, K=32   |     1233.9  |   3733.7
      b16 B=16, M=1024, H=16, K=32   |     1059.3  |   4154.4
      f16 B=16, M=1024, H=16, K=64   |     1875.2  |   4281.8
      b16 B=16, M=1024, H=16, K=64   |     1754.2  |   4347.6
      f16 B=16, M=1024, H=16, K=128  |     4062.9  |   5105.1
      b16 B=16, M=1024, H=16, K=128  |     3983.0  |   5247.9
      f16 B=64, M=128, H=16, K=16    |      304.4  |    368.6
      b16 B=64, M=128, H=16, K=16    |      301.9  |    376.4
      f16 B=64, M=128, H=16, K=32    |      303.9  |    466.2
      b16 B=64, M=128, H=16, K=32    |      327.3  |    471.8
      f16 B=64, M=128, H=16, K=64    |      475.8  |    682.3
      b16 B=64, M=128, H=16, K=64    |      479.3  |    679.5
      f16 B=64, M=128, H=16, K=128   |     1041.7  |   1076.2
      b16 B=64, M=128, H=16, K=128   |     1047.8  |   1077.4
      f16 B=64, M=512, H=16, K=16    |      912.1  |   3704.3
      b16 B=64, M=512, H=16, K=16    |      917.2  |   4082.8
      f16 B=64, M=512, H=16, K=32    |     1458.2  |   4133.9
      b16 B=64, M=512, H=16, K=32    |     1461.2  |   4501.9
      f16 B=64, M=512, H=16, K=64    |     2466.8  |   4903.1
      b16 B=64, M=512, H=16, K=64    |     2492.7  |   4993.7
      f16 B=64, M=512, H=16, K=128   |     5545.0  |   6642.9
      b16 B=64, M=512, H=16, K=128   |     5595.6  |   6754.2
      f16 B=64, M=1024, H=16, K=16   |     2608.6  |  13864.4
      b16 B=64, M=1024, H=16, K=16   |     2610.5  |  15560.3
      f16 B=64, M=1024, H=16, K=32   |     4025.8  |  14755.8
      b16 B=64, M=1024, H=16, K=32   |     4032.1  |  16463.7
      f16 B=64, M=1024, H=16, K=64   |     6658.9  |  16880.7
      b16 B=64, M=1024, H=16, K=64   |     6722.0  |  17150.6
      f16 B=64, M=1024, H=16, K=128  |    15345.5  |  20371.0
      b16 B=64, M=1024, H=16, K=128  |    15438.5  |  20920.1

Times are in microseconds (us).
[ attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  fctls_bflsh
1 threads: ------------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      542.1  |       241.0
      b16 B=384, M=197, H=1, K=64    |      344.0  |       236.7
      f16 B=1024, M=197, H=1, K=64   |      852.4  |       539.2
      b16 B=1024, M=197, H=1, K=64   |      857.6  |       544.6
      f16 B=32, M=197, H=16, K=64    |      449.0  |       280.6
      b16 B=32, M=197, H=16, K=64    |      451.6  |       283.2
      f16 B=32, M=197, H=16, K=128   |     1161.4  |       517.8
      b16 B=32, M=197, H=16, K=128   |     1030.0  |       518.8
      f16 B=16, M=197, H=16, K=64    |      312.1  |       214.0
      b16 B=16, M=197, H=16, K=64    |      309.8  |       213.7
      f16 B=16, M=197, H=16, K=128   |      547.4  |       296.6
      b16 B=16, M=197, H=16, K=128   |      548.1  |       298.1
      f16 B=1, M=4096, H=160, K=128  |    25831.9  |     31365.0
      b16 B=1, M=4096, H=160, K=128  |    25811.2  |     31382.5
      f16 B=2, M=4096, H=160, K=128  |    51084.6  |     48413.3
      b16 B=2, M=4096, H=160, K=128  |    51265.2  |     48371.8
      f16 B=1, M=8192, H=160, K=128  |    91356.3  |    121976.0
      b16 B=1, M=8192, H=160, K=128  |    91520.0  |    122016.8
      f16 B=2, M=8192, H=160, K=128  |   181984.4  |    187070.0
      b16 B=2, M=8192, H=160, K=128  |   182646.4  |    187302.9
      f16 B=1024, M=82, H=8, K=64    |     2703.5  |      1502.2
      b16 B=1024, M=82, H=8, K=64    |     2615.9  |      1511.7
      f16 B=150, M=256, H=16, K=64   |     2380.2  |      1505.7
      b16 B=150, M=256, H=16, K=64   |     2264.4  |      1521.6
      f16 B=64, M=256, H=12, K=64    |      768.4  |       526.4
      b16 B=64, M=256, H=12, K=64    |      774.2  |       530.1
      f16 B=8, M=2048, H=20, K=128   |     7798.2  |      8299.4
      b16 B=8, M=2048, H=20, K=128   |     7678.0  |      8323.0
      f16 B=16, M=128, H=16, K=16    |      322.9  |       212.5
      b16 B=16, M=128, H=16, K=16    |      311.7  |       192.5
      f16 B=16, M=128, H=16, K=32    |      359.6  |       212.4
      b16 B=16, M=128, H=16, K=32    |      333.6  |       194.5
      f16 B=16, M=128, H=16, K=64    |      330.3  |       196.0
      b16 B=16, M=128, H=16, K=64    |      314.8  |       196.9
      f16 B=16, M=128, H=16, K=128   |      489.5  |       215.7
      b16 B=16, M=128, H=16, K=128   |      343.9  |       215.5
      f16 B=16, M=512, H=16, K=16    |      463.5  |       261.1
      b16 B=16, M=512, H=16, K=16    |      318.1  |       263.3
      f16 B=16, M=512, H=16, K=32    |      573.5  |       337.3
      b16 B=16, M=512, H=16, K=32    |      391.6  |       339.6
      f16 B=16, M=512, H=16, K=64    |      829.6  |       517.0
      b16 B=16, M=512, H=16, K=64    |      666.4  |       521.2
      f16 B=16, M=512, H=16, K=128   |     1608.1  |      1034.4
      b16 B=16, M=512, H=16, K=128   |     1472.1  |      1032.4
      f16 B=16, M=1024, H=16, K=16   |      866.3  |       792.3
      b16 B=16, M=1024, H=16, K=16   |      684.4  |       793.2
      f16 B=16, M=1024, H=16, K=32   |     1269.0  |      1022.8
      b16 B=16, M=1024, H=16, K=32   |     1056.2  |      1023.4
      f16 B=16, M=1024, H=16, K=64   |     1885.2  |      1563.7
      b16 B=16, M=1024, H=16, K=64   |     1749.4  |      1564.6
      f16 B=16, M=1024, H=16, K=128  |     4052.5  |      3428.0
      b16 B=16, M=1024, H=16, K=128  |     3929.8  |      3432.8
      f16 B=64, M=128, H=16, K=16    |      313.7  |       218.0
      b16 B=64, M=128, H=16, K=16    |      336.9  |       217.6
      f16 B=64, M=128, H=16, K=32    |      317.4  |       212.5
      b16 B=64, M=128, H=16, K=32    |      318.1  |       208.0
      f16 B=64, M=128, H=16, K=64    |      470.3  |       284.8
      b16 B=64, M=128, H=16, K=64    |      474.6  |       286.9
      f16 B=64, M=128, H=16, K=128   |     1024.8  |       504.8
      b16 B=64, M=128, H=16, K=128   |     1030.2  |       507.4
      f16 B=64, M=512, H=16, K=16    |      909.5  |       924.7
      b16 B=64, M=512, H=16, K=16    |      913.6  |       930.5
      f16 B=64, M=512, H=16, K=32    |     1454.8  |      1176.4
      b16 B=64, M=512, H=16, K=32    |     1459.1  |      1181.8
      f16 B=64, M=512, H=16, K=64    |     2460.3  |      1752.1
      b16 B=64, M=512, H=16, K=64    |     2485.5  |      1773.0
      f16 B=64, M=512, H=16, K=128   |     5503.4  |      3564.6
      b16 B=64, M=512, H=16, K=128   |     5557.4  |      3592.5
      f16 B=64, M=1024, H=16, K=16   |     2599.1  |      2866.5
      b16 B=64, M=1024, H=16, K=16   |     2605.9  |      2868.9
      f16 B=64, M=1024, H=16, K=32   |     4017.3  |      3577.7
      b16 B=64, M=1024, H=16, K=32   |     4022.4  |      3592.0
      f16 B=64, M=1024, H=16, K=64   |     6648.5  |      5349.7
      b16 B=64, M=1024, H=16, K=64   |     6716.9  |      5374.7
      f16 B=64, M=1024, H=16, K=128  |    15206.5  |     11777.5
      b16 B=64, M=1024, H=16, K=128  |    15313.3  |     11814.0

Times are in microseconds (us).
Performance Compared to MemoryEfficientAttentionCutlassFwdFlashBwOp

[----------- attention (attn_bias=<class 'NoneType'>) ----------]
                                     |  optimized  |  fctls_bflsh
1 threads: ------------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |       89.9  |        88.9
      b16 B=384, M=197, H=1, K=64    |       94.2  |        88.8
      f16 B=1024, M=197, H=1, K=64   |      213.5  |       215.1
      b16 B=1024, M=197, H=1, K=64   |      221.9  |       215.1
      f16 B=32, M=197, H=16, K=64    |      114.0  |       113.8
      b16 B=32, M=197, H=16, K=64    |      120.0  |       113.8
      f16 B=32, M=197, H=16, K=128   |      193.3  |       168.6
      b16 B=32, M=197, H=16, K=128   |      194.6  |       166.9
      f16 B=16, M=197, H=16, K=64    |       85.2  |        60.7
      b16 B=16, M=197, H=16, K=64    |       87.2  |        60.7
      f16 B=16, M=197, H=16, K=128   |      101.6  |        88.6
      b16 B=16, M=197, H=16, K=128   |      102.0  |        88.0
      f16 B=1, M=4096, H=160, K=128  |     9816.7  |     15435.3
      b16 B=1, M=4096, H=160, K=128  |    10046.1  |     15280.9
      f16 B=2, M=4096, H=160, K=128  |    19357.5  |     30760.0
      b16 B=2, M=4096, H=160, K=128  |    19869.8  |     30546.3
      f16 B=1, M=8192, H=160, K=128  |    37248.5  |     61367.7
      b16 B=1, M=8192, H=160, K=128  |    38471.5  |     61015.5
      f16 B=2, M=8192, H=160, K=128  |    73817.0  |    123123.0
      b16 B=2, M=8192, H=160, K=128  |    76761.6  |    121918.5
      f16 B=1024, M=82, H=8, K=64    |      459.6  |       455.0
      b16 B=1024, M=82, H=8, K=64    |      477.1  |       455.1
      f16 B=150, M=256, H=16, K=64   |      383.4  |       521.7
      b16 B=150, M=256, H=16, K=64   |      419.8  |       518.6
      f16 B=64, M=256, H=12, K=64    |      131.4  |       176.2
      b16 B=64, M=256, H=12, K=64    |      145.1  |       173.2
      f16 B=256, M=4096, H=16, K=64  |   118450.0  |    191728.6
      b16 B=256, M=4096, H=16, K=64  |   131960.6  |    190391.4
      f16 B=8, M=2048, H=20, K=128   |     2694.7  |      3380.7
      b16 B=8, M=2048, H=20, K=128   |     2738.6  |      3332.9
      f16 B=16, M=128, H=16, K=16    |       90.9  |        31.2
      b16 B=16, M=128, H=16, K=16    |       89.4  |        30.2
      f16 B=16, M=128, H=16, K=32    |       87.6  |        30.7
      b16 B=16, M=128, H=16, K=32    |       90.0  |        30.7
      f16 B=16, M=128, H=16, K=64    |       88.7  |        30.2
      b16 B=16, M=128, H=16, K=64    |       89.4  |        30.2
      f16 B=16, M=128, H=16, K=128   |       87.2  |        35.8
      b16 B=16, M=128, H=16, K=128   |       89.3  |        35.8
      f16 B=16, M=512, H=16, K=16    |       89.2  |       180.0
      b16 B=16, M=512, H=16, K=16    |       99.6  |       179.9
      f16 B=16, M=512, H=16, K=32    |      105.3  |       186.2
      b16 B=16, M=512, H=16, K=32    |      123.0  |       186.1
      f16 B=16, M=512, H=16, K=64    |      152.1  |       213.7
      b16 B=16, M=512, H=16, K=64    |      169.3  |       213.7
      f16 B=16, M=512, H=16, K=128   |      318.6  |       367.5
      b16 B=16, M=512, H=16, K=128   |      327.5  |       364.1
      f16 B=16, M=1024, H=16, K=16   |      291.8  |       686.6
      b16 B=16, M=1024, H=16, K=16   |      384.4  |       689.4
      f16 B=16, M=1024, H=16, K=32   |      369.7  |       693.3
      b16 B=16, M=1024, H=16, K=32   |      427.3  |       695.3
      f16 B=16, M=1024, H=16, K=64   |      518.2  |       802.4
      b16 B=16, M=1024, H=16, K=64   |      579.3  |       794.1
      f16 B=16, M=1024, H=16, K=128  |     1098.6  |      1383.8
      b16 B=16, M=1024, H=16, K=128  |     1135.7  |      1371.4
      f16 B=64, M=128, H=16, K=16    |       85.6  |        53.8
      b16 B=64, M=128, H=16, K=16    |       85.7  |        53.8
      f16 B=64, M=128, H=16, K=32    |       89.5  |        58.3
      b16 B=64, M=128, H=16, K=32    |       87.5  |        58.4
      f16 B=64, M=128, H=16, K=64    |       86.1  |        69.5
      b16 B=64, M=128, H=16, K=64    |       85.6  |        69.3
      f16 B=64, M=128, H=16, K=128   |      127.2  |       121.0
      b16 B=64, M=128, H=16, K=128   |      129.7  |       119.1
      f16 B=64, M=512, H=16, K=16    |      307.1  |       703.0
      b16 B=64, M=512, H=16, K=16    |      403.4  |       703.5
      f16 B=64, M=512, H=16, K=32    |      387.1  |       712.1
      b16 B=64, M=512, H=16, K=32    |      449.8  |       711.7
      f16 B=64, M=512, H=16, K=64    |      558.4  |       821.8
      b16 B=64, M=512, H=16, K=64    |      619.9  |       818.0
      f16 B=64, M=512, H=16, K=128   |     1200.5  |      1451.8
      b16 B=64, M=512, H=16, K=128   |     1235.0  |      1434.0
      f16 B=64, M=1024, H=16, K=16   |     1101.1  |      2692.5
      b16 B=64, M=1024, H=16, K=16   |     1472.9  |      2691.6
      f16 B=64, M=1024, H=16, K=32   |     1409.8  |      2720.0
      b16 B=64, M=1024, H=16, K=32   |     1669.5  |      2715.8
      f16 B=64, M=1024, H=16, K=64   |     2025.4  |      3136.1
      b16 B=64, M=1024, H=16, K=64   |     2263.5  |      3105.2
      f16 B=64, M=1024, H=16, K=128  |     4279.5  |      5500.2
      b16 B=64, M=1024, H=16, K=128  |     4413.0  |      5406.7

Times are in microseconds (us).
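
The forward numbers above are easier to compare as ratios. The following is a minimal sketch, not part of the PR: it parses a few rows copied from the table above and prints the speedup of the `optimized` column over `fctls_bflsh`. The helper names `parse_rows` and `speedup` are made up for this illustration.

```python
import re

# Rows copied verbatim from the forward table above (attn_bias=None).
TABLE = """\
      f16 B=16, M=128, H=16, K=64    |       88.7  |        30.2
      f16 B=16, M=512, H=16, K=64    |      152.1  |       213.7
      f16 B=16, M=1024, H=16, K=64   |      518.2  |       802.4
      f16 B=1, M=4096, H=160, K=128  |     9816.7  |     15435.3
"""

# One benchmark row: dtype, shape, then the two timing columns.
ROW = re.compile(
    r"^\s*(?P<dtype>\w+) B=(?P<B>\d+), M=(?P<M>\d+), H=(?P<H>\d+), K=(?P<K>\d+)"
    r"\s*\|\s*(?P<optimized>[\d.]+)\s*\|\s*(?P<baseline>[\d.]+)\s*$"
)


def parse_rows(text):
    """Parse benchmark rows into dicts; lines that do not match are skipped."""
    rows = []
    for line in text.splitlines():
        m = ROW.match(line)
        if m is None:
            continue
        rows.append({
            "dtype": m["dtype"],
            "B": int(m["B"]),
            "M": int(m["M"]),
            "H": int(m["H"]),
            "K": int(m["K"]),
            "optimized_us": float(m["optimized"]),
            "baseline_us": float(m["baseline"]),
        })
    return rows


def speedup(row):
    """Values above 1.0 mean the `optimized` column is faster."""
    return row["baseline_us"] / row["optimized_us"]


rows = parse_rows(TABLE)
for row in rows:
    print(f"M={row['M']:>5}  speedup={speedup(row):.2f}x")
```

In this sample the ratio is below 1 at M=128 and above 1 from M=512 upward, which matches the pattern visible in the full table.
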
[ attention (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  fctls_bflsh
1 threads: ------------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      89.1   |       66.3
      b16 B=384, M=197, H=1, K=64    |      87.5   |       66.3
      f16 B=1024, M=197, H=1, K=64   |     187.0   |      153.3
      b16 B=1024, M=197, H=1, K=64   |     196.9   |      153.2
      f16 B=32, M=197, H=16, K=64    |     103.6   |       84.7
      b16 B=32, M=197, H=16, K=64    |     109.5   |       84.8
      f16 B=32, M=197, H=16, K=128   |     157.6   |      126.1
      b16 B=32, M=197, H=16, K=128   |     159.3   |      124.3
      f16 B=16, M=197, H=16, K=64    |      88.5   |       46.3
      b16 B=16, M=197, H=16, K=64    |      86.2   |       46.3
      f16 B=16, M=197, H=16, K=128   |      86.1   |       68.3
      b16 B=16, M=197, H=16, K=128   |      90.7   |       68.3
      f16 B=1, M=4096, H=160, K=128  |    5167.6   |     7921.9
      b16 B=1, M=4096, H=160, K=128  |    5335.3   |     7870.1
      f16 B=2, M=4096, H=160, K=128  |   10226.7   |    15736.6
      b16 B=2, M=4096, H=160, K=128  |   10540.9   |    15639.6
      f16 B=1, M=8192, H=160, K=128  |   19262.0   |    31230.1
      b16 B=1, M=8192, H=160, K=128  |   19908.8   |    30905.6
      f16 B=2, M=8192, H=160, K=128  |   38306.6   |    62229.2
      b16 B=2, M=8192, H=160, K=128  |   39533.3   |    61578.2
      f16 B=1024, M=82, H=8, K=64    |     486.7   |      375.9
      b16 B=1024, M=82, H=8, K=64    |     514.7   |      376.0
      f16 B=150, M=256, H=16, K=64   |     333.8   |      363.1
      b16 B=150, M=256, H=16, K=64   |     349.4   |      359.8
      f16 B=64, M=256, H=12, K=64    |     118.3   |      124.9
      b16 B=64, M=256, H=12, K=64    |     124.5   |      123.9
      f16 B=256, M=4096, H=16, K=64  |   67841.8   |    99189.2
      b16 B=256, M=4096, H=16, K=64  |   73773.6   |    97542.4
      f16 B=8, M=2048, H=20, K=128   |    1487.1   |     1851.1
      b16 B=8, M=2048, H=20, K=128   |    1533.2   |     1818.1
      f16 B=16, M=128, H=16, K=16    |      87.7   |       30.3
      b16 B=16, M=128, H=16, K=16    |      85.4   |       30.6
      f16 B=16, M=128, H=16, K=32    |      86.8   |       30.7
      b16 B=16, M=128, H=16, K=32    |      89.3   |       30.6
      f16 B=16, M=128, H=16, K=64    |      85.6   |       30.8
      b16 B=16, M=128, H=16, K=64    |      88.8   |       30.5
      f16 B=16, M=128, H=16, K=128   |      86.6   |       34.3
      b16 B=16, M=128, H=16, K=128   |      87.6   |       34.4
      f16 B=16, M=512, H=16, K=16    |      86.6   |      119.7
      b16 B=16, M=512, H=16, K=16    |      86.1   |      119.7
      f16 B=16, M=512, H=16, K=32    |      86.2   |      124.0
      b16 B=16, M=512, H=16, K=32    |      98.0   |      124.0
      f16 B=16, M=512, H=16, K=64    |     120.6   |      142.7
      b16 B=16, M=512, H=16, K=64    |     129.1   |      142.8
      f16 B=16, M=512, H=16, K=128   |     225.3   |      241.0
      b16 B=16, M=512, H=16, K=128   |     231.9   |      238.4
      f16 B=16, M=1024, H=16, K=16   |     196.4   |      399.8
      b16 B=16, M=1024, H=16, K=16   |     252.5   |      399.4
      f16 B=16, M=1024, H=16, K=32   |     244.2   |      405.7
      b16 B=16, M=1024, H=16, K=32   |     291.7   |      405.3
      f16 B=16, M=1024, H=16, K=64   |     356.4   |      464.3
      b16 B=16, M=1024, H=16, K=64   |     385.0   |      463.6
      f16 B=16, M=1024, H=16, K=128  |     683.6   |      805.3
      b16 B=16, M=1024, H=16, K=128  |     704.4   |      794.6
      f16 B=64, M=128, H=16, K=16    |      85.7   |       47.1
      b16 B=64, M=128, H=16, K=16    |      86.2   |       47.1
      f16 B=64, M=128, H=16, K=32    |      89.4   |       50.2
      b16 B=64, M=128, H=16, K=32    |      88.1   |       50.2
      f16 B=64, M=128, H=16, K=64    |      85.8   |       61.1
      b16 B=64, M=128, H=16, K=64    |      85.6   |       61.1
      f16 B=64, M=128, H=16, K=128   |     131.1   |      109.6
      b16 B=64, M=128, H=16, K=128   |     133.4   |      108.7
      f16 B=64, M=512, H=16, K=16    |     227.3   |      428.1
      b16 B=64, M=512, H=16, K=16    |     289.1   |      428.1
      f16 B=64, M=512, H=16, K=32    |     285.0   |      434.4
      b16 B=64, M=512, H=16, K=32    |     336.7   |      434.3
      f16 B=64, M=512, H=16, K=64    |     415.5   |      507.7
      b16 B=64, M=512, H=16, K=64    |     444.6   |      502.7
      f16 B=64, M=512, H=16, K=128   |     829.4   |      919.7
      b16 B=64, M=512, H=16, K=128   |     854.1   |      910.8
      f16 B=64, M=1024, H=16, K=16   |     722.3   |     1497.8
      b16 B=64, M=1024, H=16, K=16   |     931.8   |     1495.7
      f16 B=64, M=1024, H=16, K=32   |     900.6   |     1512.9
      b16 B=64, M=1024, H=16, K=32   |    1075.4   |     1513.4
      f16 B=64, M=1024, H=16, K=64   |    1316.9   |     1764.7
      b16 B=64, M=1024, H=16, K=64   |    1424.1   |     1739.0
      f16 B=64, M=1024, H=16, K=128  |    2607.3   |     3139.7
      b16 B=64, M=1024, H=16, K=128  |    2692.0   |     3106.5

Times are in microseconds (us).
[------ attention backward (attn_bias=<class 'NoneType'>) ------]
                                     |  optimized  |  fctls_bflsh
1 threads: ------------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      545.1  |       236.0
      b16 B=384, M=197, H=1, K=64    |      366.9  |       238.5
      f16 B=1024, M=197, H=1, K=64   |      908.5  |       532.5
      b16 B=1024, M=197, H=1, K=64   |      915.2  |       531.5
      f16 B=32, M=197, H=16, K=64    |      477.8  |       276.7
      b16 B=32, M=197, H=16, K=64    |      481.0  |       277.0
      f16 B=32, M=197, H=16, K=128   |     1307.9  |       634.5
      b16 B=32, M=197, H=16, K=128   |     1156.2  |       635.5
      f16 B=16, M=197, H=16, K=64    |      333.1  |       213.3
      b16 B=16, M=197, H=16, K=64    |      310.0  |       212.3
      f16 B=16, M=197, H=16, K=128   |      597.0  |       365.3
      b16 B=16, M=197, H=16, K=128   |      598.4  |       366.0
      f16 B=1, M=4096, H=160, K=128  |    41711.0  |     54006.5
      b16 B=1, M=4096, H=160, K=128  |    41674.5  |     53975.9
      f16 B=2, M=4096, H=160, K=128  |    82813.6  |     82572.8
      b16 B=2, M=4096, H=160, K=128  |    83121.0  |     82540.1
      f16 B=1, M=8192, H=160, K=128  |   158913.4  |    213065.0
      b16 B=1, M=8192, H=160, K=128  |   159351.2  |    213233.8
      f16 B=2, M=8192, H=160, K=128  |   318080.6  |    324743.8
      b16 B=2, M=8192, H=160, K=128  |   318049.8  |    324995.7
      f16 B=1024, M=82, H=8, K=64    |     2628.2  |      1480.0
      b16 B=1024, M=82, H=8, K=64    |     2514.7  |      1502.1
      f16 B=150, M=256, H=16, K=64   |     2529.2  |      1485.6
      b16 B=150, M=256, H=16, K=64   |     2434.0  |      1485.3
      f16 B=64, M=256, H=12, K=64    |      824.4  |       520.0
      b16 B=64, M=256, H=12, K=64    |      831.4  |       517.1
      f16 B=8, M=2048, H=20, K=128   |    11325.1  |     13894.0
      b16 B=8, M=2048, H=20, K=128   |    11260.2  |     13895.4
      f16 B=16, M=128, H=16, K=16    |      323.3  |       193.0
      b16 B=16, M=128, H=16, K=16    |      339.6  |       213.9
      f16 B=16, M=128, H=16, K=32    |      368.4  |       215.0
      b16 B=16, M=128, H=16, K=32    |      333.9  |       210.6
      f16 B=16, M=128, H=16, K=64    |      328.8  |       196.2
      b16 B=16, M=128, H=16, K=64    |      314.8  |       197.2
      f16 B=16, M=128, H=16, K=128   |      496.4  |       222.7
      b16 B=16, M=128, H=16, K=128   |      334.9  |       215.7
      f16 B=16, M=512, H=16, K=16    |      536.8  |       320.3
      b16 B=16, M=512, H=16, K=16    |      340.9  |       323.6
      f16 B=16, M=512, H=16, K=32    |      660.7  |       422.7
      b16 B=16, M=512, H=16, K=32    |      479.8  |       424.9
      f16 B=16, M=512, H=16, K=64    |      985.6  |       670.5
      b16 B=16, M=512, H=16, K=64    |      824.3  |       672.4
      f16 B=16, M=512, H=16, K=128   |     1864.8  |      1504.7
      b16 B=16, M=512, H=16, K=128   |     1751.5  |      1508.3
      f16 B=16, M=1024, H=16, K=16   |     1255.1  |      1238.7
      b16 B=16, M=1024, H=16, K=16   |     1050.9  |      1241.7
      f16 B=16, M=1024, H=16, K=32   |     1798.1  |      1588.3
      b16 B=16, M=1024, H=16, K=32   |     1610.2  |      1596.3
      f16 B=16, M=1024, H=16, K=64   |     2691.8  |      2304.6
      b16 B=16, M=1024, H=16, K=64   |     2548.6  |      2310.8
      f16 B=16, M=1024, H=16, K=128  |     5376.4  |      5464.7
      b16 B=16, M=1024, H=16, K=128  |     5268.8  |      5473.3
      f16 B=64, M=128, H=16, K=16    |      328.3  |       194.6
      b16 B=64, M=128, H=16, K=16    |      320.7  |       194.9
      f16 B=64, M=128, H=16, K=32    |      317.8  |       221.3
      b16 B=64, M=128, H=16, K=32    |      315.1  |       199.9
      f16 B=64, M=128, H=16, K=64    |      461.0  |       281.6
      b16 B=64, M=128, H=16, K=64    |      465.7  |       284.7
      f16 B=64, M=128, H=16, K=128   |      967.1  |       491.4
      b16 B=64, M=128, H=16, K=128   |      974.3  |       496.1
      f16 B=64, M=512, H=16, K=16    |     1095.1  |      1180.2
      b16 B=64, M=512, H=16, K=16    |     1099.8  |      1188.7
      f16 B=64, M=512, H=16, K=32    |     1657.9  |      1487.9
      b16 B=64, M=512, H=16, K=32    |     1665.4  |      1497.0
      f16 B=64, M=512, H=16, K=64    |     2927.7  |      2291.3
      b16 B=64, M=512, H=16, K=64    |     2949.2  |      2302.8
      f16 B=64, M=512, H=16, K=128   |     6573.7  |      5139.4
      b16 B=64, M=512, H=16, K=128   |     6632.5  |      5169.0
      f16 B=64, M=1024, H=16, K=16   |     3570.0  |      4673.2
      b16 B=64, M=1024, H=16, K=16   |     3582.5  |      4671.1
      f16 B=64, M=1024, H=16, K=32   |     6134.0  |      5590.1
      b16 B=64, M=1024, H=16, K=32   |     6145.1  |      5607.8
      f16 B=64, M=1024, H=16, K=64   |     9832.3  |      7863.1
      b16 B=64, M=1024, H=16, K=64   |     9859.6  |      7890.6
      f16 B=64, M=1024, H=16, K=128  |    20411.4  |     18492.4
      b16 B=64, M=1024, H=16, K=128  |    20521.8  |     18545.9

Times are in microseconds (us).
[ attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  fctls_bflsh
1 threads: ------------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      542.1  |       241.0
      b16 B=384, M=197, H=1, K=64    |      344.0  |       236.7
      f16 B=1024, M=197, H=1, K=64   |      852.4  |       539.2
      b16 B=1024, M=197, H=1, K=64   |      857.6  |       544.6
      f16 B=32, M=197, H=16, K=64    |      449.0  |       280.6
      b16 B=32, M=197, H=16, K=64    |      451.6  |       283.2
      f16 B=32, M=197, H=16, K=128   |     1161.4  |       517.8
      b16 B=32, M=197, H=16, K=128   |     1030.0  |       518.8
      f16 B=16, M=197, H=16, K=64    |      312.1  |       214.0
      b16 B=16, M=197, H=16, K=64    |      309.8  |       213.7
      f16 B=16, M=197, H=16, K=128   |      547.4  |       296.6
      b16 B=16, M=197, H=16, K=128   |      548.1  |       298.1
      f16 B=1, M=4096, H=160, K=128  |    25831.9  |     31365.0
      b16 B=1, M=4096, H=160, K=128  |    25811.2  |     31382.5
      f16 B=2, M=4096, H=160, K=128  |    51084.6  |     48413.3
      b16 B=2, M=4096, H=160, K=128  |    51265.2  |     48371.8
      f16 B=1, M=8192, H=160, K=128  |    91356.3  |    121976.0
      b16 B=1, M=8192, H=160, K=128  |    91520.0  |    122016.8
      f16 B=2, M=8192, H=160, K=128  |   181984.4  |    187070.0
      b16 B=2, M=8192, H=160, K=128  |   182646.4  |    187302.9
      f16 B=1024, M=82, H=8, K=64    |     2703.5  |      1502.2
      b16 B=1024, M=82, H=8, K=64    |     2615.9  |      1511.7
      f16 B=150, M=256, H=16, K=64   |     2380.2  |      1505.7
      b16 B=150, M=256, H=16, K=64   |     2264.4  |      1521.6
      f16 B=64, M=256, H=12, K=64    |      768.4  |       526.4
      b16 B=64, M=256, H=12, K=64    |      774.2  |       530.1
      f16 B=8, M=2048, H=20, K=128   |     7798.2  |      8299.4
      b16 B=8, M=2048, H=20, K=128   |     7678.0  |      8323.0
      f16 B=16, M=128, H=16, K=16    |      322.9  |       212.5
      b16 B=16, M=128, H=16, K=16    |      311.7  |       192.5
      f16 B=16, M=128, H=16, K=32    |      359.6  |       212.4
      b16 B=16, M=128, H=16, K=32    |      333.6  |       194.5
      f16 B=16, M=128, H=16, K=64    |      330.3  |       196.0
      b16 B=16, M=128, H=16, K=64    |      314.8  |       196.9
      f16 B=16, M=128, H=16, K=128   |      489.5  |       215.7
      b16 B=16, M=128, H=16, K=128   |      343.9  |       215.5
      f16 B=16, M=512, H=16, K=16    |      463.5  |       261.1
      b16 B=16, M=512, H=16, K=16    |      318.1  |       263.3
      f16 B=16, M=512, H=16, K=32    |      573.5  |       337.3
      b16 B=16, M=512, H=16, K=32    |      391.6  |       339.6
      f16 B=16, M=512, H=16, K=64    |      829.6  |       517.0
      b16 B=16, M=512, H=16, K=64    |      666.4  |       521.2
      f16 B=16, M=512, H=16, K=128   |     1608.1  |      1034.4
      b16 B=16, M=512, H=16, K=128   |     1472.1  |      1032.4
      f16 B=16, M=1024, H=16, K=16   |      866.3  |       792.3
      b16 B=16, M=1024, H=16, K=16   |      684.4  |       793.2
      f16 B=16, M=1024, H=16, K=32   |     1269.0  |      1022.8
      b16 B=16, M=1024, H=16, K=32   |     1056.2  |      1023.4
      f16 B=16, M=1024, H=16, K=64   |     1885.2  |      1563.7
      b16 B=16, M=1024, H=16, K=64   |     1749.4  |      1564.6
      f16 B=16, M=1024, H=16, K=128  |     4052.5  |      3428.0
      b16 B=16, M=1024, H=16, K=128  |     3929.8  |      3432.8
      f16 B=64, M=128, H=16, K=16    |      313.7  |       218.0
      b16 B=64, M=128, H=16, K=16    |      336.9  |       217.6
      f16 B=64, M=128, H=16, K=32    |      317.4  |       212.5
      b16 B=64, M=128, H=16, K=32    |      318.1  |       208.0
      f16 B=64, M=128, H=16, K=64    |      470.3  |       284.8
      b16 B=64, M=128, H=16, K=64    |      474.6  |       286.9
      f16 B=64, M=128, H=16, K=128   |     1024.8  |       504.8
      b16 B=64, M=128, H=16, K=128   |     1030.2  |       507.4
      f16 B=64, M=512, H=16, K=16    |      909.5  |       924.7
      b16 B=64, M=512, H=16, K=16    |      913.6  |       930.5
      f16 B=64, M=512, H=16, K=32    |     1454.8  |      1176.4
      b16 B=64, M=512, H=16, K=32    |     1459.1  |      1181.8
      f16 B=64, M=512, H=16, K=64    |     2460.3  |      1752.1
      b16 B=64, M=512, H=16, K=64    |     2485.5  |      1773.0
      f16 B=64, M=512, H=16, K=128   |     5503.4  |      3564.6
      b16 B=64, M=512, H=16, K=128   |     5557.4  |      3592.5
      f16 B=64, M=1024, H=16, K=16   |     2599.1  |      2866.5
      b16 B=64, M=1024, H=16, K=16   |     2605.9  |      2868.9
      f16 B=64, M=1024, H=16, K=32   |     4017.3  |      3577.7
      b16 B=64, M=1024, H=16, K=32   |     4022.4  |      3592.0
      f16 B=64, M=1024, H=16, K=64   |     6648.5  |      5349.7
      b16 B=64, M=1024, H=16, K=64   |     6716.9  |      5374.7
      f16 B=64, M=1024, H=16, K=128  |    15206.5  |     11777.5
      b16 B=64, M=1024, H=16, K=128  |    15313.3  |     11814.0

Times are in microseconds (us).

TODO:

  • get bf16 working
  • get non-causal working
  • benchmarking
  • specify when op can be used
  • flash bwd, triton fwd
  • packed
  • different block sizes for N and M
  • add Triton Autotune

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

@facebook-github-bot added the CLA Signed label on Oct 11, 2022
Contributor

@fmassa fmassa left a comment

Thanks for the PR Diana!

Can you add some tests to this backend? You basically only need to add the TritonFlashAttentionOp in

xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp,

Also, can you run the benchmarks in https://github.com/facebookresearch/xformers/blob/main/xformers/benchmarks/benchmark_mem_eff_attention.py with this backend so that we can compare to the existing ones?

xformers/ops.py (outdated review thread, resolved)
@dianaml0 dianaml0 force-pushed the triton_flash branch 2 times, most recently from 6ec9fc0 to 6280055 Compare October 13, 2022 14:05
@blefaudeux
Contributor

blefaudeux commented Oct 13, 2022

Hey @dianaml0, this is great, but you would need a modern Triton, right? I've had a branch up for that for a while; some of the API changed and we need to adapt a lot of the layers. Raising this just in case you bumped into it.

@dianaml0
Contributor Author

Thanks @blefaudeux, that's a good point. I have the updated Triton locally and am still facing some errors, but the update is still needed. Do you think you'll push the changes you have, or should I look into it?

@blefaudeux
Contributor

#483 should help; it should accept any modern Triton pip package!
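
The thread does not say which Triton release counts as "modern", so a version threshold is left out here. As a minimal sketch, assuming only the standard library, one way to see which Triton pip package (if any) is installed before debugging API differences is shown below. The helper name `installed_version` is made up for this illustration.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


triton_version = installed_version("triton")
if triton_version is None:
    print("triton is not installed")
else:
    print(f"triton {triton_version} is installed")
```
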

Contributor

@danthe3rd danthe3rd left a comment

Thanks a lot! The results really look amazing!
I'd be curious to see whether we can use Triton's fwd with Flash's bwd, for instance, to get the best of both worlds (not in this PR).
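
The mixed forward/backward idea can be estimated from the numbers already posted. The sketch below is a rough back-of-the-envelope calculation, not something from the PR: it takes three f16 rows from the attn_bias=None forward and backward tables in the PR description, assumes the `optimized` column is the Triton op, and picks the faster backend for each pass independently. Summing the two times assumes the kernels can actually be combined (for example, that the forward saves what the other backward needs), which the thread does not confirm. The names `TIMINGS_US` and `best_combo` are made up for this illustration.

```python
# Times in microseconds, copied from the f16 rows of the attn_bias=None
# forward and backward tables in the PR description. "triton" stands for the
# column labelled "optimized" there (an assumption); "fctls_bflsh" is the
# CUTLASS-forward / Flash-backward op.
TIMINGS_US = {
    "B=16, M=128, H=16, K=64": {
        "fwd": {"triton": 88.7, "fctls_bflsh": 30.2},
        "bwd": {"triton": 328.8, "fctls_bflsh": 196.2},
    },
    "B=64, M=1024, H=16, K=64": {
        "fwd": {"triton": 2025.4, "fctls_bflsh": 3136.1},
        "bwd": {"triton": 9832.3, "fctls_bflsh": 7863.1},
    },
    "B=1, M=4096, H=160, K=128": {
        "fwd": {"triton": 9816.7, "fctls_bflsh": 15435.3},
        "bwd": {"triton": 41711.0, "fctls_bflsh": 54006.5},
    },
}


def best_combo(timings):
    """Pick the faster backend per pass and return (choice, summed time in us)."""
    choice = {p: min(timings[p], key=timings[p].get) for p in ("fwd", "bwd")}
    total_us = sum(timings[p][choice[p]] for p in ("fwd", "bwd"))
    return choice, total_us


for shape, timings in TIMINGS_US.items():
    choice, total_us = best_combo(timings)
    print(f"{shape}: fwd={choice['fwd']}, bwd={choice['bwd']}, total={total_us:.1f} us")
```

For B=64, M=1024, H=16, K=64 the mixed choice sums to roughly 9.9 ms, against roughly 11.0 ms for fctls_bflsh alone and 11.9 ms for the Triton op alone; the other two shapes prefer a single backend for both passes.
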

tests/test_triton_flashattention.py (outdated review thread, resolved)
xformers/ops.py (outdated review thread, resolved)
@dianaml0
Contributor Author

dianaml0 commented Nov 15, 2022

Updated Numbers

Performance Compared to Vanilla FWD
[--------- attention (attn_bias=<class 'NoneType'>) --------]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     1258.8  |    372.9
      b16 B=384, M=197, H=1, K=88    |     1270.5  |    375.4
      f16 B=384, M=197, H=1, K=80    |      146.8  |    344.5
      b16 B=384, M=197, H=1, K=80    |      149.5  |    346.8
      f16 B=384, M=197, H=1, K=64    |       90.1  |    293.4
      b16 B=384, M=197, H=1, K=64    |       94.3  |    295.4
      f16 B=1024, M=197, H=1, K=88   |     3241.4  |    938.6
      b16 B=1024, M=197, H=1, K=88   |     3263.3  |    945.0
      f16 B=1024, M=197, H=1, K=80   |      349.2  |    865.1
      b16 B=1024, M=197, H=1, K=80   |      354.9  |    871.3
      f16 B=1024, M=197, H=1, K=64   |      213.8  |    729.5
      b16 B=1024, M=197, H=1, K=64   |      222.1  |    735.7
      f16 B=512, M=197, H=1, K=80    |      185.6  |    448.6
      b16 B=512, M=197, H=1, K=80    |      188.8  |    451.6
      f16 B=32, M=197, H=16, K=80    |      193.8  |    548.3
      b16 B=32, M=197, H=16, K=80    |      193.9  |    551.0
      f16 B=32, M=197, H=16, K=64    |      114.4  |    466.7
      b16 B=32, M=197, H=16, K=64    |      120.5  |    469.2
      f16 B=32, M=197, H=16, K=128   |      193.9  |    717.6
      b16 B=32, M=197, H=16, K=128   |      195.0  |    720.7
      f16 B=256, M=197, H=1, K=88    |      868.3  |    260.9
      b16 B=256, M=197, H=1, K=88    |      870.0  |    262.3
      f16 B=16, M=197, H=16, K=88    |      869.9  |    317.9
      b16 B=16, M=197, H=16, K=88    |      870.8  |    319.3
      f16 B=16, M=197, H=16, K=64    |       89.0  |    257.0
      b16 B=16, M=197, H=16, K=64    |       88.2  |    258.2
      f16 B=16, M=197, H=16, K=128   |      101.6  |    385.5
      b16 B=16, M=197, H=16, K=128   |      102.2  |    387.0
      f16 B=1, M=4096, H=160, K=128  |     9824.7  |  20462.2
      b16 B=1, M=4096, H=160, K=128  |    10063.9  |  21439.0
      f16 B=2, M=4096, H=160, K=128  |    19368.3  |  42648.8
      b16 B=2, M=4096, H=160, K=128  |    19877.8  |  44515.5
      f16 B=1, M=8192, H=160, K=128  |    37335.5  |  88470.3
      b16 B=1, M=8192, H=160, K=128  |    38634.4  |  87261.8
      f16 B=2, M=8192, H=160, K=128  |    74047.2  |
      b16 B=2, M=8192, H=160, K=128  |    76836.0  |
      f16 B=1024, M=82, H=8, K=64    |      460.9  |   1769.1
      b16 B=1024, M=82, H=8, K=64    |      480.2  |   1863.2
      f16 B=150, M=256, H=16, K=64   |      383.1  |   1725.0
      b16 B=150, M=256, H=16, K=64   |      421.2  |   1760.5
      f16 B=64, M=256, H=12, K=64    |      131.6  |    587.0
      b16 B=64, M=256, H=12, K=64    |      145.8  |    597.5
      f16 B=1, M=4096, H=16, K=40    |    30691.3  |   1943.9
      b16 B=1, M=4096, H=16, K=40    |    30831.1  |   1998.1
      f16 B=1, M=16384, H=16, K=40   |   422956.0  |  28855.4
      b16 B=1, M=16384, H=16, K=40   |   420924.0  |  30204.7
      f16 B=256, M=4096, H=16, K=64  |   118507.9  |
      b16 B=256, M=4096, H=16, K=64  |   132804.6  |
      f16 B=8, M=2048, H=20, K=128   |     2697.1  |   5700.7
      b16 B=8, M=2048, H=20, K=128   |     2744.9  |   6041.7
      f16 B=16, M=128, H=16, K=16    |       87.3  |    139.1
      b16 B=16, M=128, H=16, K=16    |       87.1  |    137.9
      f16 B=16, M=128, H=16, K=32    |       89.6  |    139.4
      b16 B=16, M=128, H=16, K=32    |       89.2  |    138.3
      f16 B=16, M=128, H=16, K=64    |       86.7  |    137.1
      b16 B=16, M=128, H=16, K=64    |       87.0  |    136.9
      f16 B=16, M=128, H=16, K=128   |       88.7  |    139.6
      b16 B=16, M=128, H=16, K=128   |       86.6  |    140.6
      f16 B=16, M=512, H=16, K=16    |       87.6  |    461.5
      b16 B=16, M=512, H=16, K=16    |       99.5  |    553.9
      f16 B=16, M=512, H=16, K=32    |      105.4  |    512.3
      b16 B=16, M=512, H=16, K=32    |      123.2  |    594.6
      f16 B=16, M=512, H=16, K=64    |      152.1  |    596.1
      b16 B=16, M=512, H=16, K=64    |      169.4  |    616.0
      f16 B=16, M=512, H=16, K=128   |      318.6  |    786.9
      b16 B=16, M=512, H=16, K=128   |      327.6  |    804.4
      f16 B=16, M=1024, H=16, K=16   |      292.0  |   1642.8
      b16 B=16, M=1024, H=16, K=16   |      384.8  |   2028.9
      f16 B=16, M=1024, H=16, K=32   |      369.4  |   1731.9
      b16 B=16, M=1024, H=16, K=32   |      428.1  |   2126.7
      f16 B=16, M=1024, H=16, K=64   |      518.5  |   2034.7
      b16 B=16, M=1024, H=16, K=64   |      579.7  |   2077.2
      f16 B=16, M=1024, H=16, K=128  |     1097.8  |   2421.0
      b16 B=16, M=1024, H=16, K=128  |     1134.9  |   2489.6
      f16 B=64, M=128, H=16, K=16    |       89.3  |    183.7
      b16 B=64, M=128, H=16, K=16    |       87.1  |    185.1
      f16 B=64, M=128, H=16, K=32    |       87.5  |    224.8
      b16 B=64, M=128, H=16, K=32    |       88.1  |    225.5
      f16 B=64, M=128, H=16, K=64    |       86.7  |    321.5
      b16 B=64, M=128, H=16, K=64    |       88.1  |    323.1
      f16 B=64, M=128, H=16, K=128   |      127.1  |    484.8
      b16 B=64, M=128, H=16, K=128   |      129.5  |    486.1
      f16 B=64, M=512, H=16, K=16    |      307.5  |   1727.2
      b16 B=64, M=512, H=16, K=16    |      403.6  |   2094.0
      f16 B=64, M=512, H=16, K=32    |      387.3  |   1893.5
      b16 B=64, M=512, H=16, K=32    |      450.5  |   2250.3
      f16 B=64, M=512, H=16, K=64    |      558.2  |   2234.4
      b16 B=64, M=512, H=16, K=64    |      620.5  |   2294.3
      f16 B=64, M=512, H=16, K=128   |     1201.0  |   3016.3
      b16 B=64, M=512, H=16, K=128   |     1235.3  |   3069.4
      f16 B=64, M=1024, H=16, K=16   |     1101.7  |   6428.2
      b16 B=64, M=1024, H=16, K=16   |     1473.2  |   8038.7
      f16 B=64, M=1024, H=16, K=32   |     1411.1  |   6770.6
      b16 B=64, M=1024, H=16, K=32   |     1667.5  |   8383.2
      f16 B=64, M=1024, H=16, K=64   |     2033.4  |   7978.2
      b16 B=64, M=1024, H=16, K=64   |     2265.9  |   8160.9
      f16 B=64, M=1024, H=16, K=128  |     4281.6  |   9521.1
      b16 B=64, M=1024, H=16, K=128  |     4413.3  |   9886.2

Times are in microseconds (us).
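
Not every row in the table above is a win, and some rows have no vanilla timing at all. The following minimal sketch, not part of the PR, classifies a few rows copied from that table as faster, slower, or missing a baseline. The table does not say why the vanilla cell is empty, so the code only records that it is missing. The helper names `parse_row` and `classify` are made up for this illustration.

```python
# Rows copied verbatim from the updated forward table above (attn_bias=None).
# The last row has an empty vanilla cell in the original.
TABLE = """\
      f16 B=384, M=197, H=1, K=88    |     1258.8  |    372.9
      f16 B=384, M=197, H=1, K=80    |      146.8  |    344.5
      f16 B=384, M=197, H=1, K=64    |       90.1  |    293.4
      f16 B=1, M=4096, H=16, K=40    |    30691.3  |   1943.9
      f16 B=2, M=8192, H=160, K=128  |    74047.2  |
"""


def parse_row(line):
    """Split one row on '|' and convert the timing cells; empty cells become None."""
    label, optimized, vanilla = (cell.strip() for cell in line.split("|"))
    return {
        "label": label,
        "K": int(label.rsplit("K=", 1)[1]),
        "optimized_us": float(optimized),
        "vanilla_us": float(vanilla) if vanilla else None,
    }


def classify(row):
    """Compare the optimized time against the vanilla baseline, if there is one."""
    if row["vanilla_us"] is None:
        return "no baseline"
    return "slower" if row["optimized_us"] > row["vanilla_us"] else "faster"


rows = [parse_row(line) for line in TABLE.splitlines() if line.strip()]
for row in rows:
    print(f"{row['label']:<32} {classify(row)}")

# Head dimensions (K) for which this sample shows a slowdown.
slow_head_dims = sorted({row["K"] for row in rows if classify(row) == "slower"})
print("head dims with a slowdown in this sample:", slow_head_dims)
```

In this sample the slow rows are the K=88 and K=40 shapes, while K=64 and K=80 are faster than vanilla.
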
[ attention (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  vanilla
1 threads: ---------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     3697.6  |     446.9
      b16 B=384, M=197, H=1, K=88    |     1178.0  |     452.8
      f16 B=384, M=197, H=1, K=80    |      120.3  |     421.8
      b16 B=384, M=197, H=1, K=80    |      122.6  |     427.6
      f16 B=384, M=197, H=1, K=64    |       87.6  |     370.9
      b16 B=384, M=197, H=1, K=64    |       90.4  |     376.4
      f16 B=1024, M=197, H=1, K=88   |     9588.0  |    1124.2
      b16 B=1024, M=197, H=1, K=88   |     2983.1  |    1140.5
      f16 B=1024, M=197, H=1, K=80   |      294.1  |    1059.7
      b16 B=1024, M=197, H=1, K=80   |      300.2  |    1075.1
      f16 B=1024, M=197, H=1, K=64   |      187.5  |     925.9
      b16 B=1024, M=197, H=1, K=64   |      197.2  |     940.5
      f16 B=512, M=197, H=1, K=80    |      153.7  |     549.6
      b16 B=512, M=197, H=1, K=80    |      156.5  |     557.4
      f16 B=32, M=197, H=16, K=80    |      157.5  |     647.7
      b16 B=32, M=197, H=16, K=80    |      160.8  |     655.5
      f16 B=32, M=197, H=16, K=64    |      103.8  |     565.7
      b16 B=32, M=197, H=16, K=64    |      109.3  |     573.3
      f16 B=32, M=197, H=16, K=128   |      157.6  |     811.7
      b16 B=32, M=197, H=16, K=128   |      159.5  |     819.3
      f16 B=256, M=197, H=1, K=88    |     2552.6  |     311.0
      b16 B=256, M=197, H=1, K=88    |      812.5  |     315.1
      f16 B=16, M=197, H=16, K=88    |     2555.6  |     368.1
      b16 B=16, M=197, H=16, K=88    |      805.8  |     372.4
      f16 B=16, M=197, H=16, K=64    |       87.2  |     307.2
      b16 B=16, M=197, H=16, K=64    |       88.9  |     311.0
      f16 B=16, M=197, H=16, K=128   |       87.7  |     435.1
      b16 B=16, M=197, H=16, K=128   |       87.1  |     438.6
      f16 B=1, M=4096, H=160, K=128  |     5174.9  |   37307.0
      b16 B=1, M=4096, H=160, K=128  |     5338.7  |   37824.6
      f16 B=2, M=4096, H=160, K=128  |    10220.4  |   76294.3
      b16 B=2, M=4096, H=160, K=128  |    10545.0  |   77265.7
      f16 B=1, M=8192, H=160, K=128  |    19334.4  |  152522.4
      b16 B=1, M=8192, H=160, K=128  |    19974.0  |  148646.9
      f16 B=2, M=8192, H=160, K=128  |    38377.2  |
      b16 B=2, M=8192, H=160, K=128  |    39641.7  |
      f16 B=1024, M=82, H=8, K=64    |      488.3  |    1981.4
      b16 B=1024, M=82, H=8, K=64    |      515.8  |    2083.8
      f16 B=150, M=256, H=16, K=64   |      335.1  |    2402.9
      b16 B=150, M=256, H=16, K=64   |      350.9  |    2445.6
      f16 B=64, M=256, H=12, K=64    |      118.9  |     805.2
      b16 B=64, M=256, H=12, K=64    |      124.6  |     819.3
      f16 B=1, M=4096, H=16, K=40    |    14939.0  |    3432.0
      b16 B=1, M=4096, H=16, K=40    |    14915.2  |    3468.0
      f16 B=1, M=16384, H=16, K=40   |   222056.7  |   55382.5
      b16 B=1, M=16384, H=16, K=40   |   218015.9  |   55427.4
      f16 B=256, M=4096, H=16, K=64  |    67733.2  |
      b16 B=256, M=4096, H=16, K=64  |    74031.3  |
      f16 B=8, M=2048, H=20, K=128   |     1488.1  |    9349.8
      b16 B=8, M=2048, H=20, K=128   |     1534.1  |    9538.5
      f16 B=16, M=128, H=16, K=16    |       86.9  |     147.4
      b16 B=16, M=128, H=16, K=16    |       87.9  |     143.6
      f16 B=16, M=128, H=16, K=32    |       86.6  |     144.3
      b16 B=16, M=128, H=16, K=32    |       89.8  |     143.9
      f16 B=16, M=128, H=16, K=64    |       86.4  |     143.1
      b16 B=16, M=128, H=16, K=64    |       86.5  |     141.4
      f16 B=16, M=128, H=16, K=128   |       89.7  |     160.0
      b16 B=16, M=128, H=16, K=128   |       88.3  |     161.9
      f16 B=16, M=512, H=16, K=16    |       90.7  |     715.3
      b16 B=16, M=512, H=16, K=16    |       87.6  |     792.0
      f16 B=16, M=512, H=16, K=32    |       88.5  |     766.8
      b16 B=16, M=512, H=16, K=32    |       98.1  |     832.3
      f16 B=16, M=512, H=16, K=64    |      120.3  |     881.4
      b16 B=16, M=512, H=16, K=64    |      129.1  |     910.7
      f16 B=16, M=512, H=16, K=128   |      225.4  |    1066.2
      b16 B=16, M=512, H=16, K=128   |      232.0  |    1095.8
      f16 B=16, M=1024, H=16, K=16   |      196.6  |    2658.4
      b16 B=16, M=1024, H=16, K=16   |      253.0  |    3042.3
      f16 B=16, M=1024, H=16, K=32   |      244.2  |    2749.6
      b16 B=16, M=1024, H=16, K=32   |      291.8  |    3123.0
      f16 B=16, M=1024, H=16, K=64   |      355.3  |    2962.7
      b16 B=16, M=1024, H=16, K=64   |      384.7  |    3287.6
      f16 B=16, M=1024, H=16, K=128  |      683.0  |    3533.6
      b16 B=16, M=1024, H=16, K=128  |      705.0  |    3714.0
      f16 B=64, M=128, H=16, K=16    |       87.6  |     246.0
      b16 B=64, M=128, H=16, K=16    |       87.2  |     251.8
      f16 B=64, M=128, H=16, K=32    |       87.1  |     294.3
      b16 B=64, M=128, H=16, K=32    |       87.5  |     298.8
      f16 B=64, M=128, H=16, K=64    |       87.8  |     391.4
      b16 B=64, M=128, H=16, K=64    |       86.8  |     396.0
      f16 B=64, M=128, H=16, K=128   |      130.9  |     557.1
      b16 B=64, M=128, H=16, K=128   |      133.1  |     562.3
      f16 B=64, M=512, H=16, K=16    |      227.5  |    2704.1
      b16 B=64, M=512, H=16, K=16    |      288.9  |    3019.5
      f16 B=64, M=512, H=16, K=32    |      285.1  |    2892.2
      b16 B=64, M=512, H=16, K=32    |      336.9  |    3163.7
      f16 B=64, M=512, H=16, K=64    |      414.2  |    3352.4
      b16 B=64, M=512, H=16, K=64    |      444.3  |    3470.6
      f16 B=64, M=512, H=16, K=128   |      828.5  |    4093.5
      b16 B=64, M=512, H=16, K=128   |      854.7  |    4208.4
      f16 B=64, M=1024, H=16, K=16   |      721.9  |   10475.9
      b16 B=64, M=1024, H=16, K=16   |      932.5  |   12026.3
      f16 B=64, M=1024, H=16, K=32   |      900.6  |   10855.0
      b16 B=64, M=1024, H=16, K=32   |     1075.5  |   12354.0
      f16 B=64, M=1024, H=16, K=64   |     1314.2  |   11673.0
      b16 B=64, M=1024, H=16, K=64   |     1423.5  |   13000.3
      f16 B=64, M=1024, H=16, K=128  |     2608.3  |   13961.5
      b16 B=64, M=1024, H=16, K=128  |     2693.0  |   14694.5

Times are in microseconds (us).
Performance Compared to Vanilla BWD
[---- attention backward (attn_bias=<class 'NoneType'>) ----]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     3679.0  |    820.0
      b16 B=384, M=197, H=1, K=88    |     3560.9  |    823.2
      f16 B=384, M=197, H=1, K=80    |      762.1  |    767.0
      b16 B=384, M=197, H=1, K=80    |      762.6  |    768.6
      f16 B=384, M=197, H=1, K=64    |      543.5  |    651.8
      b16 B=384, M=197, H=1, K=64    |      372.8  |    652.8
      f16 B=1024, M=197, H=1, K=88   |     9127.6  |   2103.8
      b16 B=1024, M=197, H=1, K=88   |     9123.4  |   2104.3
      f16 B=1024, M=197, H=1, K=80   |     1881.4  |   1957.8
      b16 B=1024, M=197, H=1, K=80   |     1889.3  |   1957.7
      f16 B=1024, M=197, H=1, K=64   |      916.7  |   1648.4
      b16 B=1024, M=197, H=1, K=64   |      924.1  |   1649.6
      f16 B=512, M=197, H=1, K=80    |      987.0  |    993.8
      b16 B=512, M=197, H=1, K=80    |      991.6  |    995.8
      f16 B=32, M=197, H=16, K=80    |     1038.9  |   1032.6
      b16 B=32, M=197, H=16, K=80    |     1040.7  |   1035.2
      f16 B=32, M=197, H=16, K=64    |      485.6  |    886.1
      b16 B=32, M=197, H=16, K=64    |      487.9  |    887.0
      f16 B=32, M=197, H=16, K=128   |     1165.6  |   1341.5
      b16 B=32, M=197, H=16, K=128   |     1168.8  |   1344.7
      f16 B=256, M=197, H=1, K=88    |     2364.1  |    576.1
      b16 B=256, M=197, H=1, K=88    |     2363.5  |    577.3
      f16 B=16, M=197, H=16, K=88    |     2378.9  |    599.2
      b16 B=16, M=197, H=16, K=88    |     2377.1  |    600.3
      f16 B=16, M=197, H=16, K=64    |      355.0  |    486.9
      b16 B=16, M=197, H=16, K=64    |      309.7  |    488.6
      f16 B=16, M=197, H=16, K=128   |      617.3  |    713.9
      b16 B=16, M=197, H=16, K=128   |      620.4  |    715.8
      f16 B=1, M=4096, H=160, K=128  |    41943.7  |  38971.0
      b16 B=1, M=4096, H=160, K=128  |    42033.0  |  39938.5
      f16 B=2, M=4096, H=160, K=128  |    83588.1  |  79017.0
      b16 B=2, M=4096, H=160, K=128  |    83792.2  |  80856.7
      f16 B=1, M=8192, H=160, K=128  |   160571.8  |
      b16 B=1, M=8192, H=160, K=128  |   160843.2  |
      f16 B=2, M=8192, H=160, K=128  |   320831.4  |
      b16 B=2, M=8192, H=160, K=128  |   322116.0  |
      f16 B=1024, M=82, H=8, K=64    |     2629.7  |   3609.1
      b16 B=1024, M=82, H=8, K=64    |     2524.3  |   3782.4
      f16 B=150, M=256, H=16, K=64   |     2577.6  |   3790.4
      b16 B=150, M=256, H=16, K=64   |     2458.9  |   3837.4
      f16 B=64, M=256, H=12, K=64    |      827.1  |   1254.5
      b16 B=64, M=256, H=12, K=64    |      835.7  |   1267.6
      f16 B=1, M=4096, H=16, K=40    |    43086.4  |   3562.3
      b16 B=1, M=4096, H=16, K=40    |    43079.7  |   3588.6
      f16 B=1, M=16384, H=16, K=40   |   664389.3  |  53939.0
      b16 B=1, M=16384, H=16, K=40   |   663918.7  |  54291.3
      f16 B=256, M=4096, H=16, K=64  |   504278.4  |
      b16 B=256, M=4096, H=16, K=64  |   503469.4  |
      f16 B=8, M=2048, H=20, K=128   |    11424.6  |  10535.0
      b16 B=8, M=2048, H=20, K=128   |    11365.6  |  10784.8
      f16 B=16, M=128, H=16, K=16    |      351.2  |    348.4
      b16 B=16, M=128, H=16, K=16    |      329.6  |    346.4
      f16 B=16, M=128, H=16, K=32    |      351.3  |    343.9
      b16 B=16, M=128, H=16, K=32    |      321.3  |    319.8
      f16 B=16, M=128, H=16, K=64    |      329.9  |    316.2
      b16 B=16, M=128, H=16, K=64    |      329.8  |    340.4
      f16 B=16, M=128, H=16, K=128   |      453.7  |    338.1
      b16 B=16, M=128, H=16, K=128   |      332.6  |    337.6
      f16 B=16, M=512, H=16, K=16    |      526.4  |    985.4
      b16 B=16, M=512, H=16, K=16    |      322.8  |   1078.1
      f16 B=16, M=512, H=16, K=32    |      683.4  |   1089.3
      b16 B=16, M=512, H=16, K=32    |      486.9  |   1178.8
      f16 B=16, M=512, H=16, K=64    |     1020.1  |   1276.7
      b16 B=16, M=512, H=16, K=64    |      845.3  |   1296.3
      f16 B=16, M=512, H=16, K=128   |     1900.1  |   1715.2
      b16 B=16, M=512, H=16, K=128   |     1766.8  |   1742.1
      f16 B=16, M=1024, H=16, K=16   |     1260.6  |   3525.5
      b16 B=16, M=1024, H=16, K=16   |     1067.4  |   3953.2
      f16 B=16, M=1024, H=16, K=32   |     1794.2  |   3728.4
      b16 B=16, M=1024, H=16, K=32   |     1615.1  |   4153.2
      f16 B=16, M=1024, H=16, K=64   |     2685.4  |   4278.5
      b16 B=16, M=1024, H=16, K=64   |     2545.9  |   4349.5
      f16 B=16, M=1024, H=16, K=128  |     5420.6  |   5108.0
      b16 B=16, M=1024, H=16, K=128  |     5317.9  |   5240.9
      f16 B=64, M=128, H=16, K=16    |      318.0  |    365.1
      b16 B=64, M=128, H=16, K=16    |      333.2  |    373.0
      f16 B=64, M=128, H=16, K=32    |      332.6  |    466.7
      b16 B=64, M=128, H=16, K=32    |      334.3  |    470.8
      f16 B=64, M=128, H=16, K=64    |      464.2  |    680.0
      b16 B=64, M=128, H=16, K=64    |      469.7  |    681.6
      f16 B=64, M=128, H=16, K=128   |      980.2  |   1076.1
      b16 B=64, M=128, H=16, K=128   |      986.8  |   1078.6
      f16 B=64, M=512, H=16, K=16    |     1109.1  |   3706.3
      b16 B=64, M=512, H=16, K=16    |     1112.0  |   4077.9
      f16 B=64, M=512, H=16, K=32    |     1679.8  |   4137.0
      b16 B=64, M=512, H=16, K=32    |     1684.1  |   4507.5
      f16 B=64, M=512, H=16, K=64    |     2988.8  |   4903.0
      b16 B=64, M=512, H=16, K=64    |     3001.1  |   4991.4
      f16 B=64, M=512, H=16, K=128   |     6639.6  |   6643.4
      b16 B=64, M=512, H=16, K=128   |     6698.5  |   6749.8
      f16 B=64, M=1024, H=16, K=16   |     3618.2  |  13888.1
      b16 B=64, M=1024, H=16, K=16   |     3625.5  |  15564.5
      f16 B=64, M=1024, H=16, K=32   |     6140.2  |  14750.0
      b16 B=64, M=1024, H=16, K=32   |     6156.8  |  16452.7
      f16 B=64, M=1024, H=16, K=64   |     9818.7  |  16864.4
      b16 B=64, M=1024, H=16, K=64   |     9843.5  |  17158.3
      f16 B=64, M=1024, H=16, K=128  |    20554.6  |  20403.1
      b16 B=64, M=1024, H=16, K=128  |    20676.6  |  20899.8

Times are in microseconds (us).
[ attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     9482.2  |    821.1
      b16 B=384, M=197, H=1, K=88    |     9334.4  |    823.2
      f16 B=384, M=197, H=1, K=80    |      669.6  |    768.4
      b16 B=384, M=197, H=1, K=80    |      672.2  |    769.8
      f16 B=384, M=197, H=1, K=64    |      544.9  |    651.9
      b16 B=384, M=197, H=1, K=64    |      351.0  |    653.7
      f16 B=1024, M=197, H=1, K=88   |    23569.2  |   2103.4
      b16 B=1024, M=197, H=1, K=88   |    23571.4  |   2105.2
      f16 B=1024, M=197, H=1, K=80   |     1644.3  |   1957.5
      b16 B=1024, M=197, H=1, K=80   |     1649.9  |   1957.4
      f16 B=1024, M=197, H=1, K=64   |      862.7  |   1648.0
      b16 B=1024, M=197, H=1, K=64   |      868.7  |   1649.4
      f16 B=512, M=197, H=1, K=80    |      863.2  |    993.9
      b16 B=512, M=197, H=1, K=80    |      866.3  |    996.2
      f16 B=32, M=197, H=16, K=80    |      867.2  |   1035.0
      b16 B=32, M=197, H=16, K=80    |      869.5  |   1035.5
      f16 B=32, M=197, H=16, K=64    |      456.5  |    884.3
      b16 B=32, M=197, H=16, K=64    |      458.6  |    885.4
      f16 B=32, M=197, H=16, K=128   |     1042.2  |   1339.5
      b16 B=32, M=197, H=16, K=128   |     1046.8  |   1343.0
      f16 B=256, M=197, H=1, K=88    |     6343.3  |    575.9
      b16 B=256, M=197, H=1, K=88    |     6346.2  |    578.0
      f16 B=16, M=197, H=16, K=88    |     6339.5  |    600.1
      b16 B=16, M=197, H=16, K=88    |     6338.8  |    601.4
      f16 B=16, M=197, H=16, K=64    |      323.2  |    488.7
      b16 B=16, M=197, H=16, K=64    |      309.0  |    491.1
      f16 B=16, M=197, H=16, K=128   |      555.8  |    715.2
      b16 B=16, M=197, H=16, K=128   |      557.6  |    717.0
      f16 B=1, M=4096, H=160, K=128  |    26198.1  |  38857.2
      b16 B=1, M=4096, H=160, K=128  |    26165.0  |  39906.6
      f16 B=2, M=4096, H=160, K=128  |    51727.5  |  78346.7
      b16 B=2, M=4096, H=160, K=128  |    51935.9  |  80796.5
      f16 B=1, M=8192, H=160, K=128  |    92706.3  |
      b16 B=1, M=8192, H=160, K=128  |    92957.2  |
      f16 B=2, M=8192, H=160, K=128  |   184853.7  |
      b16 B=2, M=8192, H=160, K=128  |   185265.1  |
      f16 B=1024, M=82, H=8, K=64    |     2746.8  |   3611.4
      b16 B=1024, M=82, H=8, K=64    |     2641.2  |   3784.6
      f16 B=150, M=256, H=16, K=64   |     2381.3  |   3824.6
      b16 B=150, M=256, H=16, K=64   |     2291.5  |   3863.9
      f16 B=64, M=256, H=12, K=64    |      776.2  |   1254.8
      b16 B=64, M=256, H=12, K=64    |      780.0  |   1269.1
      f16 B=1, M=4096, H=16, K=40    |     6826.9  |   3563.2
      b16 B=1, M=4096, H=16, K=40    |     6622.5  |   3596.4
      f16 B=1, M=16384, H=16, K=40   |    94566.7  |  53887.6
      b16 B=1, M=16384, H=16, K=40   |    94648.8  |  54316.2
      f16 B=256, M=4096, H=16, K=64  |   274458.4  |
      b16 B=256, M=4096, H=16, K=64  |   275062.8  |
      f16 B=8, M=2048, H=20, K=128   |     7871.9  |  10544.8
      b16 B=8, M=2048, H=20, K=128   |     7770.2  |  10794.5
      f16 B=16, M=128, H=16, K=16    |      348.7  |    323.2
      b16 B=16, M=128, H=16, K=16    |      326.5  |    325.8
      f16 B=16, M=128, H=16, K=32    |      350.7  |    323.9
      b16 B=16, M=128, H=16, K=32    |      305.9  |    300.9
      f16 B=16, M=128, H=16, K=64    |      361.3  |    328.4
      b16 B=16, M=128, H=16, K=64    |      309.5  |    296.7
      f16 B=16, M=128, H=16, K=128   |      487.7  |    321.4
      b16 B=16, M=128, H=16, K=128   |      329.5  |    305.0
      f16 B=16, M=512, H=16, K=16    |      435.9  |    986.1
      b16 B=16, M=512, H=16, K=16    |      333.1  |   1079.5
      f16 B=16, M=512, H=16, K=32    |      589.3  |   1090.4
      b16 B=16, M=512, H=16, K=32    |      394.0  |   1178.8
      f16 B=16, M=512, H=16, K=64    |      846.2  |   1280.3
      b16 B=16, M=512, H=16, K=64    |      670.0  |   1298.3
      f16 B=16, M=512, H=16, K=128   |     1616.0  |   1712.2
      b16 B=16, M=512, H=16, K=128   |     1485.7  |   1740.0
      f16 B=16, M=1024, H=16, K=16   |      881.3  |   3527.4
      b16 B=16, M=1024, H=16, K=16   |      686.9  |   3948.4
      f16 B=16, M=1024, H=16, K=32   |     1241.7  |   3729.6
      b16 B=16, M=1024, H=16, K=32   |     1059.4  |   4152.4
      f16 B=16, M=1024, H=16, K=64   |     1858.0  |   4284.3
      b16 B=16, M=1024, H=16, K=64   |     1752.9  |   4348.1
      f16 B=16, M=1024, H=16, K=128  |     4083.5  |   5106.9
      b16 B=16, M=1024, H=16, K=128  |     3969.3  |   5246.5
      f16 B=64, M=128, H=16, K=16    |      306.6  |    368.1
      b16 B=64, M=128, H=16, K=16    |      333.1  |    376.4
      f16 B=64, M=128, H=16, K=32    |      327.0  |    466.3
      b16 B=64, M=128, H=16, K=32    |      329.4  |    472.3
      f16 B=64, M=128, H=16, K=64    |      475.4  |    680.6
      b16 B=64, M=128, H=16, K=64    |      479.8  |    680.9
      f16 B=64, M=128, H=16, K=128   |     1041.1  |   1071.8
      b16 B=64, M=128, H=16, K=128   |     1046.3  |   1077.4
      f16 B=64, M=512, H=16, K=16    |      914.1  |   3704.1
      b16 B=64, M=512, H=16, K=16    |      916.3  |   4083.0
      f16 B=64, M=512, H=16, K=32    |     1458.5  |   4134.4
      b16 B=64, M=512, H=16, K=32    |     1461.1  |   4502.1
      f16 B=64, M=512, H=16, K=64    |     2464.9  |   4900.6
      b16 B=64, M=512, H=16, K=64    |     2490.6  |   4991.0
      f16 B=64, M=512, H=16, K=128   |     5549.3  |   6642.1
      b16 B=64, M=512, H=16, K=128   |     5594.7  |   6753.4
      f16 B=64, M=1024, H=16, K=16   |     2605.9  |  13864.0
      b16 B=64, M=1024, H=16, K=16   |     2613.2  |  15564.9
      f16 B=64, M=1024, H=16, K=32   |     4026.7  |  14759.8
      b16 B=64, M=1024, H=16, K=32   |     4029.4  |  16449.0
      f16 B=64, M=1024, H=16, K=64   |     6661.4  |  16899.4
      b16 B=64, M=1024, H=16, K=64   |     6721.7  |  17166.3
      f16 B=64, M=1024, H=16, K=128  |    15343.3  |  20357.1
      b16 B=64, M=1024, H=16, K=128  |    15446.7  |  20917.8

Times are in microseconds (us).
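
The tables above are easier to compare as speedup ratios (vanilla time divided by optimized time). The following is a minimal sketch, not part of the PR: `Row` and `parse_rows` are hypothetical helper names, and the sample rows are copied from the causal backward table above. A row whose vanilla cell is empty in the source is kept with no ratio, since the reason for the missing value is not stated.

```python
from typing import List, NamedTuple, Optional


class Row(NamedTuple):
    """One benchmark row; times are in microseconds."""
    label: str
    optimized_us: float
    vanilla_us: Optional[float]

    @property
    def speedup(self) -> Optional[float]:
        # A ratio above 1.0 means the optimized kernel is faster than vanilla.
        if self.vanilla_us is None:
            return None
        return self.vanilla_us / self.optimized_us


def parse_rows(table: str) -> List[Row]:
    """Parse 'label | optimized | vanilla' lines, skipping headers and rules."""
    rows = []
    for line in table.splitlines():
        cells = [cell.strip() for cell in line.split("|")]
        if len(cells) != 3 or "B=" not in cells[0]:
            continue  # header, separator, or blank line
        vanilla = float(cells[2]) if cells[2] else None
        rows.append(Row(cells[0], float(cells[1]), vanilla))
    return rows


# Rows copied verbatim from the LowerTriangularMask backward table.
SAMPLE = """\
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     9482.2  |    821.1
      f16 B=384, M=197, H=1, K=64    |      544.9  |    651.9
      f16 B=1, M=8192, H=160, K=128  |    92706.3  |
"""

rows = parse_rows(SAMPLE)
for row in rows:
    ratio = "n/a" if row.speedup is None else f"{row.speedup:.2f}x"
    print(f"{row.label}: {ratio}")
```

Read this way, the K=88 row comes out below 1x (slower than vanilla) while the K=64 row comes out above 1x, which matches the pattern visible across the K=88 rows in the tables.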

@danthe3rd danthe3rd mentioned this pull request Nov 16, 2022
@dianaml0 dianaml0 force-pushed the triton_flash branch 2 times, most recently from 3c5496f to d8c610c, November 16, 2022 19:04
Contributor

@fmassa fmassa left a comment

Thanks for the PR Diana!

I have some comments. Also, if there were no changes to the triton implementation from the one present in flashattention, would it make sense to just use it directly, instead of copying its code?

We already have triton compiled and installed (for the CUDA dependency), so we could directly call into its Python API. What do you think?

tests/test_triton_flashattention.py (outdated thread, resolved)
Comment on lines +726 to +766
def supports(cls, d: "AttentionOpDispatch") -> bool:
    if not has_triton_flashattention:
        return False
    device_capability = torch.cuda.get_device_capability(d.device)
    is_sm80 = device_capability[0] >= 8
    if not is_sm80:
        return False
    return super(TritonFlashAttentionOp, cls).supports(d)
Contributor

Looks like the triton implementation supports all values of K, is that right? Can you make the test_mem_eff_attention run on this operator so that we double-check that this all works as expected?

Contributor Author

Yes, it supports all values of K up to 128. I ran the tests and they all pass.

Contributor

Interesting, so it also supports values of K that are not multiples of 8? cc @danthe3rd

Contributor Author

Yes, I ran the tests locally and they pass for all values of K less than 128, including those that are not multiples of 8.
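
The constraints discussed in this thread can be summarized as a small predicate. This is only an illustrative sketch under stated assumptions, not the code in the PR: `triton_flash_supported` and `MAX_HEAD_DIM` are hypothetical names, and the bound is treated as inclusive of 128 because the benchmark tables include K=128 rows (the replies say both "until 128" and "less than 128").

```python
from typing import Tuple

# Assumption: inclusive upper bound on the head dimension K, inferred from
# the replies above and from the K=128 rows in the benchmark tables.
MAX_HEAD_DIM = 128


def triton_flash_supported(device_capability: Tuple[int, int], head_dim: int) -> bool:
    """Hypothetical check mirroring the thread: compute capability major >= 8, K <= 128.

    `device_capability` has the (major, minor) shape returned by
    torch.cuda.get_device_capability; torch is not needed to evaluate the rule.
    """
    major, _minor = device_capability
    if major < 8:
        return False
    # No multiple-of-8 requirement on K, per the author's reply.
    return 0 < head_dim <= MAX_HEAD_DIM


print(triton_flash_supported((8, 0), 88))   # K not a power of two
print(triton_flash_supported((7, 5), 64))   # capability below 8.0
```

Passing such a check says nothing about speed: in the tables above, the K=88 rows are supported but run slower than the vanilla baseline.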

xformers/triton/flash_attention.py (outdated thread, resolved)
xformers/triton/k_flash_attn.py (outdated thread, resolved)
xformers/triton/utils.py (outdated thread, resolved)
xformers/triton/utils.py (outdated thread, resolved)
@dianaml0
Contributor Author

Thanks a lot for the reviews, @fmassa and @danthe3rd! I've made some changes and added an op that uses the Triton forward with the Flash backward.

@dianaml0
Contributor Author

dianaml0 commented Nov 29, 2022

Forwards for Triton fwd and Flash bwd:
[--------- attention (attn_bias=<class 'NoneType'>) --------]
                                     |  optimized  |   eager
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |       91.7  |    292.3
      b16 B=384, M=197, H=1, K=64    |       94.2  |    295.4
      f16 B=1024, M=197, H=1, K=64   |      213.1  |    727.8
      b16 B=1024, M=197, H=1, K=64   |      221.4  |    734.1
      f16 B=32, M=197, H=16, K=64    |      114.2  |    466.1
      b16 B=32, M=197, H=16, K=64    |      120.6  |    468.4
      f16 B=32, M=197, H=16, K=128   |      194.0  |    716.8
      b16 B=32, M=197, H=16, K=128   |      195.4  |    719.7
      f16 B=16, M=197, H=16, K=64    |       91.5  |    256.3
      b16 B=16, M=197, H=16, K=64    |       90.9  |    256.9
      f16 B=16, M=197, H=16, K=128   |      102.4  |    384.7
      b16 B=16, M=197, H=16, K=128   |      104.6  |    386.0
      f16 B=1, M=4096, H=160, K=128  |     9785.7  |  20529.6
      b16 B=1, M=4096, H=160, K=128  |    10020.5  |  21512.5
      f16 B=2, M=4096, H=160, K=128  |    19305.3  |  42686.7
      b16 B=2, M=4096, H=160, K=128  |    19817.6  |  44503.4
      f16 B=1, M=8192, H=160, K=128  |    37278.0  |  88499.1
      b16 B=1, M=8192, H=160, K=128  |    38494.5  |  87043.0
      f16 B=2, M=8192, H=160, K=128  |    73829.8  |
      b16 B=2, M=8192, H=160, K=128  |    76805.0  |
      f16 B=1024, M=82, H=8, K=64    |      460.8  |   1769.1
      b16 B=1024, M=82, H=8, K=64    |      478.7  |   1864.2
      f16 B=150, M=256, H=16, K=64   |      383.5  |   1727.0
      b16 B=150, M=256, H=16, K=64   |      420.7  |   1761.2
      f16 B=64, M=256, H=12, K=64    |      131.7  |    588.1
      b16 B=64, M=256, H=12, K=64    |      145.6  |    598.2
      f16 B=256, M=4096, H=16, K=64  |   118090.6  |
      b16 B=256, M=4096, H=16, K=64  |   132264.8  |
      f16 B=8, M=2048, H=20, K=128   |     2692.1  |   5746.4
      b16 B=8, M=2048, H=20, K=128   |     2743.3  |   6049.4
      f16 B=16, M=128, H=16, K=16    |       90.0  |    145.5
      b16 B=16, M=128, H=16, K=16    |       92.8  |    143.1
      f16 B=16, M=128, H=16, K=32    |       93.7  |    147.5
      b16 B=16, M=128, H=16, K=32    |       90.5  |    145.4
      f16 B=16, M=128, H=16, K=64    |       90.4  |    146.0
      b16 B=16, M=128, H=16, K=64    |       92.5  |    142.1
      f16 B=16, M=128, H=16, K=128   |       91.0  |    142.3
      b16 B=16, M=128, H=16, K=128   |       90.2  |    143.1
      f16 B=16, M=512, H=16, K=16    |       94.1  |    462.3
      b16 B=16, M=512, H=16, K=16    |      100.9  |    553.7
      f16 B=16, M=512, H=16, K=32    |      105.4  |    513.1
      b16 B=16, M=512, H=16, K=32    |      122.9  |    595.7
      f16 B=16, M=512, H=16, K=64    |      152.1  |    596.8
      b16 B=16, M=512, H=16, K=64    |      169.2  |    613.5
      f16 B=16, M=512, H=16, K=128   |      319.0  |    788.5
      b16 B=16, M=512, H=16, K=128   |      327.9  |    805.1
      f16 B=16, M=1024, H=16, K=16   |      291.8  |   1644.3
      b16 B=16, M=1024, H=16, K=16   |      384.1  |   2026.1
      f16 B=16, M=1024, H=16, K=32   |      369.0  |   1734.6
      b16 B=16, M=1024, H=16, K=32   |      426.7  |   2128.1
      f16 B=16, M=1024, H=16, K=64   |      521.1  |   2035.3
      b16 B=16, M=1024, H=16, K=64   |      578.1  |   2079.1
      f16 B=16, M=1024, H=16, K=128  |     1093.8  |   2420.7
      b16 B=16, M=1024, H=16, K=128  |     1131.1  |   2491.4
      f16 B=64, M=128, H=16, K=16    |       90.9  |    182.8
      b16 B=64, M=128, H=16, K=16    |       90.3  |    184.6
      f16 B=64, M=128, H=16, K=32    |       90.6  |    225.4
      b16 B=64, M=128, H=16, K=32    |       90.3  |    226.7
      f16 B=64, M=128, H=16, K=64    |       90.9  |    324.6
      b16 B=64, M=128, H=16, K=64    |       89.9  |    326.3
      f16 B=64, M=128, H=16, K=128   |      128.4  |    485.0
      b16 B=64, M=128, H=16, K=128   |      130.8  |    486.5
      f16 B=64, M=512, H=16, K=16    |      305.7  |   1729.1
      b16 B=64, M=512, H=16, K=16    |      401.8  |   2094.5
      f16 B=64, M=512, H=16, K=32    |      386.0  |   1899.7
      b16 B=64, M=512, H=16, K=32    |      448.9  |   2252.9
      f16 B=64, M=512, H=16, K=64    |      561.1  |   2246.8
      b16 B=64, M=512, H=16, K=64    |      618.4  |   2302.4
      f16 B=64, M=512, H=16, K=128   |     1195.5  |   3013.4
      b16 B=64, M=512, H=16, K=128   |     1228.0  |   3069.2
      f16 B=64, M=1024, H=16, K=16   |     1097.8  |   6432.0
      b16 B=64, M=1024, H=16, K=16   |     1468.3  |   8046.7
      f16 B=64, M=1024, H=16, K=32   |     1405.7  |   6777.6
      b16 B=64, M=1024, H=16, K=32   |     1662.4  |   8426.7
      f16 B=64, M=1024, H=16, K=64   |     2035.0  |   7999.1
      b16 B=64, M=1024, H=16, K=64   |     2255.6  |   8174.9
      f16 B=64, M=1024, H=16, K=128  |     4269.3  |   9546.4
      b16 B=64, M=1024, H=16, K=128  |     4404.0  |   9904.9

Times are in microseconds (us).

[ attention (attn_bias=<class 'xformers.ops.memory_efficient_attention.LowerTriangularMask'>) ]
                                     |  optimized  |   eager
1 threads: ---------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      92.4   |     369.8
      b16 B=384, M=197, H=1, K=64    |      94.6   |     376.0
      f16 B=1024, M=197, H=1, K=64   |     186.6   |     923.4
      b16 B=1024, M=197, H=1, K=64   |     196.6   |     938.1
      f16 B=32, M=197, H=16, K=64    |     104.4   |     564.2
      b16 B=32, M=197, H=16, K=64    |     108.8   |     572.2
      f16 B=32, M=197, H=16, K=128   |     157.9   |     812.1
      b16 B=32, M=197, H=16, K=128   |     159.5   |     818.4
      f16 B=16, M=197, H=16, K=64    |      93.2   |     306.1
      b16 B=16, M=197, H=16, K=64    |      90.6   |     310.1
      f16 B=16, M=197, H=16, K=128   |      92.2   |     433.2
      b16 B=16, M=197, H=16, K=128   |      91.6   |     437.6
      f16 B=1, M=4096, H=160, K=128  |    5166.1   |   37295.6
      b16 B=1, M=4096, H=160, K=128  |    5326.0   |   37828.6
      f16 B=2, M=4096, H=160, K=128  |   10191.6   |   76206.0
      b16 B=2, M=4096, H=160, K=128  |   10507.9   |   77209.3
      f16 B=1, M=8192, H=160, K=128  |   19285.9   |  152334.7
      b16 B=1, M=8192, H=160, K=128  |   19893.0   |  148896.7
      f16 B=2, M=8192, H=160, K=128  |   38335.1   |
      b16 B=2, M=8192, H=160, K=128  |   39490.4   |
      f16 B=1024, M=82, H=8, K=64    |     486.3   |    1981.2
      b16 B=1024, M=82, H=8, K=64    |     514.4   |    2086.7
      f16 B=150, M=256, H=16, K=64   |     334.3   |    2402.0
      b16 B=150, M=256, H=16, K=64   |     350.1   |    2447.2
      f16 B=64, M=256, H=12, K=64    |     118.7   |     806.5
      b16 B=64, M=256, H=12, K=64    |     124.9   |     820.5
      f16 B=256, M=4096, H=16, K=64  |   67692.5   |
      b16 B=256, M=4096, H=16, K=64  |   73779.3   |
      f16 B=8, M=2048, H=20, K=128   |    1487.2   |    9354.3
      b16 B=8, M=2048, H=20, K=128   |    1529.4   |    9535.1
      f16 B=16, M=128, H=16, K=16    |      94.6   |     153.5
      b16 B=16, M=128, H=16, K=16    |      93.1   |     150.6
      f16 B=16, M=128, H=16, K=32    |      91.0   |     154.5
      b16 B=16, M=128, H=16, K=32    |      94.2   |     154.0
      f16 B=16, M=128, H=16, K=64    |      93.7   |     149.8
      b16 B=16, M=128, H=16, K=64    |      93.2   |     152.3
      f16 B=16, M=128, H=16, K=128   |      92.0   |     160.4
      b16 B=16, M=128, H=16, K=128   |      91.1   |     162.3
      f16 B=16, M=512, H=16, K=16    |      93.0   |     715.4
      b16 B=16, M=512, H=16, K=16    |      93.9   |     792.7
      f16 B=16, M=512, H=16, K=32    |      91.0   |     767.7
      b16 B=16, M=512, H=16, K=32    |      98.1   |     833.7
      f16 B=16, M=512, H=16, K=64    |     120.8   |     882.7
      b16 B=16, M=512, H=16, K=64    |     129.8   |     911.8
      f16 B=16, M=512, H=16, K=128   |     225.5   |    1067.1
      b16 B=16, M=512, H=16, K=128   |     231.7   |    1096.1
      f16 B=16, M=1024, H=16, K=16   |     195.9   |    2658.6
      b16 B=16, M=1024, H=16, K=16   |     252.9   |    3043.7
      f16 B=16, M=1024, H=16, K=32   |     243.7   |    2752.9
      b16 B=16, M=1024, H=16, K=32   |     291.1   |    3121.6
      f16 B=16, M=1024, H=16, K=64   |     355.5   |    2966.1
      b16 B=16, M=1024, H=16, K=64   |     384.2   |    3287.9
      f16 B=16, M=1024, H=16, K=128  |     682.0   |    3533.0
      b16 B=16, M=1024, H=16, K=128  |     703.3   |    3713.4
      f16 B=64, M=128, H=16, K=16    |      90.4   |     245.3
      b16 B=64, M=128, H=16, K=16    |      90.3   |     251.1
      f16 B=64, M=128, H=16, K=32    |      91.1   |     294.9
      b16 B=64, M=128, H=16, K=32    |      91.8   |     299.4
      f16 B=64, M=128, H=16, K=64    |      90.1   |     393.0
      b16 B=64, M=128, H=16, K=64    |      92.7   |     397.9
      f16 B=64, M=128, H=16, K=128   |     132.3   |     557.3
      b16 B=64, M=128, H=16, K=128   |     134.1   |     562.5
      f16 B=64, M=512, H=16, K=16    |     227.0   |    2705.2
      b16 B=64, M=512, H=16, K=16    |     288.3   |    3017.0
      f16 B=64, M=512, H=16, K=32    |     284.5   |    2893.9
      b16 B=64, M=512, H=16, K=32    |     336.8   |    3167.5
      f16 B=64, M=512, H=16, K=64    |     418.1   |    3338.8
      b16 B=64, M=512, H=16, K=64    |     444.2   |    3460.2
      f16 B=64, M=512, H=16, K=128   |     830.2   |    4091.8
      b16 B=64, M=512, H=16, K=128   |     855.8   |    4207.5
      f16 B=64, M=1024, H=16, K=16   |     720.1   |   10487.1
      b16 B=64, M=1024, H=16, K=16   |     929.6   |   12038.4
      f16 B=64, M=1024, H=16, K=32   |     898.2   |   10852.4
      b16 B=64, M=1024, H=16, K=32   |    1072.3   |   12350.3
      f16 B=64, M=1024, H=16, K=64   |    1325.3   |   11682.5
      b16 B=64, M=1024, H=16, K=64   |    1421.4   |   13010.3
      f16 B=64, M=1024, H=16, K=128  |    2604.2   |   13959.2
      b16 B=64, M=1024, H=16, K=128  |    2683.7   |   14677.4

Times are in microseconds (us).
Backwards for Triton fwd and Flash bwd:
[---- attention backward (attn_bias=<class 'NoneType'>) ----]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      242.0  |    650.9
      b16 B=384, M=197, H=1, K=64    |      221.9  |    652.5
      f16 B=1024, M=197, H=1, K=64   |      534.5  |   1647.7
      b16 B=1024, M=197, H=1, K=64   |      532.9  |   1648.0
      f16 B=32, M=197, H=16, K=64    |      279.1  |    882.9
      b16 B=32, M=197, H=16, K=64    |      278.5  |    883.9
      f16 B=32, M=197, H=16, K=128   |      638.0  |   1344.0
      b16 B=32, M=197, H=16, K=128   |      640.2  |   1346.5
      f16 B=16, M=197, H=16, K=64    |      235.9  |    484.8
      b16 B=16, M=197, H=16, K=64    |      245.0  |    485.8
      f16 B=16, M=197, H=16, K=128   |      367.5  |    711.5
      b16 B=16, M=197, H=16, K=128   |      368.5  |    713.9
      f16 B=1, M=4096, H=160, K=128  |    53888.1  |  38893.2
      b16 B=1, M=4096, H=160, K=128  |    53989.6  |  39876.1
      f16 B=2, M=4096, H=160, K=128  |    82534.0  |  78629.0
      b16 B=2, M=4096, H=160, K=128  |    82718.4  |  80537.6
      f16 B=1, M=8192, H=160, K=128  |   213238.1  |
      b16 B=1, M=8192, H=160, K=128  |   213335.0  |
      f16 B=2, M=8192, H=160, K=128  |   324974.6  |
      b16 B=2, M=8192, H=160, K=128  |   325014.0  |
      f16 B=1024, M=82, H=8, K=64    |     1494.2  |   3611.5
      b16 B=1024, M=82, H=8, K=64    |     1515.2  |   3787.3
      f16 B=150, M=256, H=16, K=64   |     1493.8  |   3800.9
      b16 B=150, M=256, H=16, K=64   |     1491.3  |   3843.4
      f16 B=64, M=256, H=12, K=64    |      523.3  |   1256.9
      b16 B=64, M=256, H=12, K=64    |      520.9  |   1271.3
      f16 B=256, M=4096, H=16, K=64  |   430986.6  |
      b16 B=256, M=4096, H=16, K=64  |   430693.7  |
      f16 B=8, M=2048, H=20, K=128   |    13945.0  |  10546.8
      b16 B=8, M=2048, H=20, K=128   |    13939.3  |  10798.0
      f16 B=16, M=128, H=16, K=16    |      196.4  |    327.9
      b16 B=16, M=128, H=16, K=16    |      197.0  |    353.1
      f16 B=16, M=128, H=16, K=32    |      215.2  |    350.5
      b16 B=16, M=128, H=16, K=32    |      197.6  |    325.6
      f16 B=16, M=128, H=16, K=64    |      195.9  |    346.2
      b16 B=16, M=128, H=16, K=64    |      216.0  |    357.0
      f16 B=16, M=128, H=16, K=128   |      197.1  |    322.8
      b16 B=16, M=128, H=16, K=128   |      218.6  |    320.6
      f16 B=16, M=512, H=16, K=16    |      321.2  |    984.5
      b16 B=16, M=512, H=16, K=16    |      323.6  |   1077.2
      f16 B=16, M=512, H=16, K=32    |      423.1  |   1089.6
      b16 B=16, M=512, H=16, K=32    |      425.6  |   1178.5
      f16 B=16, M=512, H=16, K=64    |      671.8  |   1285.5
      b16 B=16, M=512, H=16, K=64    |      673.3  |   1306.2
      f16 B=16, M=512, H=16, K=128   |     1512.0  |   1718.0
      b16 B=16, M=512, H=16, K=128   |     1514.4  |   1748.0
      f16 B=16, M=1024, H=16, K=16   |     1237.7  |   3532.8
      b16 B=16, M=1024, H=16, K=16   |     1240.7  |   3957.6
      f16 B=16, M=1024, H=16, K=32   |     1593.6  |   3733.2
      b16 B=16, M=1024, H=16, K=32   |     1594.9  |   4156.9
      f16 B=16, M=1024, H=16, K=64   |     2309.4  |   4278.0
      b16 B=16, M=1024, H=16, K=64   |     2311.7  |   4350.1
      f16 B=16, M=1024, H=16, K=128  |     5478.4  |   5110.4
      b16 B=16, M=1024, H=16, K=128  |     5498.3  |   5251.9
      f16 B=64, M=128, H=16, K=16    |      197.3  |    364.8
      b16 B=64, M=128, H=16, K=16    |      215.9  |    372.4
      f16 B=64, M=128, H=16, K=32    |      200.8  |    464.8
      b16 B=64, M=128, H=16, K=32    |      201.4  |    470.3
      f16 B=64, M=128, H=16, K=64    |      283.5  |    680.8
      b16 B=64, M=128, H=16, K=64    |      286.8  |    682.9
      f16 B=64, M=128, H=16, K=128   |      496.9  |   1076.1
      b16 B=64, M=128, H=16, K=128   |      501.1  |   1078.8
      f16 B=64, M=512, H=16, K=16    |     1181.5  |   3715.9
      b16 B=64, M=512, H=16, K=16    |     1187.2  |   4087.0
      f16 B=64, M=512, H=16, K=32    |     1497.3  |   4141.0
      b16 B=64, M=512, H=16, K=32    |     1505.0  |   4508.7
      f16 B=64, M=512, H=16, K=64    |     2296.0  |   4917.1
      b16 B=64, M=512, H=16, K=64    |     2309.8  |   5008.3
      f16 B=64, M=512, H=16, K=128   |     5157.6  |   6644.4
      b16 B=64, M=512, H=16, K=128   |     5178.1  |   6751.8
      f16 B=64, M=1024, H=16, K=16   |     4668.2  |  13869.2
      b16 B=64, M=1024, H=16, K=16   |     4671.8  |  15561.4
      f16 B=64, M=1024, H=16, K=32   |     5595.4  |  14761.1
      b16 B=64, M=1024, H=16, K=32   |     5609.4  |  16450.7
      f16 B=64, M=1024, H=16, K=64   |     7872.6  |  16891.0
      b16 B=64, M=1024, H=16, K=64   |     7899.7  |  17185.1
      f16 B=64, M=1024, H=16, K=128  |    18538.4  |  20376.2
      b16 B=64, M=1024, H=16, K=128  |    18576.9  |  20931.0

Times are in microseconds (us).
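
To read the table as speedups rather than raw timings, each row can be parsed and the vanilla time divided by the optimized time. The sketch below is illustrative only: `parse_row` and `speedup` are hypothetical helpers that are not part of this PR, and the three sample rows are copied from the table above. A row with an empty vanilla cell yields `None` instead of a ratio.

```python
from typing import Optional, Tuple


def parse_row(line: str) -> Tuple[str, float, Optional[float]]:
    """Split one benchmark row into (label, optimized_us, vanilla_us)."""
    label, optimized, vanilla = (cell.strip() for cell in line.split("|"))
    # An empty vanilla cell means no vanilla timing was reported for this shape.
    return label, float(optimized), float(vanilla) if vanilla else None


def speedup(optimized_us: float, vanilla_us: Optional[float]) -> Optional[float]:
    """Return vanilla / optimized; values above 1 mean the optimized op was faster."""
    if vanilla_us is None:
        return None
    return vanilla_us / optimized_us


# Sample rows copied verbatim from the table above.
ROWS = [
    "      f16 B=64, M=1024, H=16, K=16   |     4668.2  |  13869.2",
    "      f16 B=16, M=1024, H=16, K=128  |     5478.4  |   5110.4",
    "      f16 B=1, M=8192, H=160, K=128  |   213238.1  |",
]

for row in ROWS:
    label, optimized_us, vanilla_us = parse_row(row)
    ratio = speedup(optimized_us, vanilla_us)
    print(f"{label}: " + ("n/a" if ratio is None else f"{ratio:.2f}x"))
```

A ratio below 1, as in the second sample row, marks a shape where the vanilla implementation measured faster.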

[ attention backward (attn_bias=<class 'xformers.ops.memory_efficient_attention.LowerTriangularMask'>) ]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      225.1  |    652.1
      b16 B=384, M=197, H=1, K=64    |      227.0  |    652.6
      f16 B=1024, M=197, H=1, K=64   |      542.6  |   1647.2
      b16 B=1024, M=197, H=1, K=64   |      546.3  |   1647.3
      f16 B=32, M=197, H=16, K=64    |      282.4  |    885.0
      b16 B=32, M=197, H=16, K=64    |      285.3  |    886.9
      f16 B=32, M=197, H=16, K=128   |      521.3  |   1345.3
      b16 B=32, M=197, H=16, K=128   |      521.9  |   1346.8
      f16 B=16, M=197, H=16, K=64    |      243.5  |    485.8
      b16 B=16, M=197, H=16, K=64    |      215.6  |    487.3
      f16 B=16, M=197, H=16, K=128   |      298.9  |    714.0
      b16 B=16, M=197, H=16, K=128   |      299.4  |    715.9
      f16 B=1, M=4096, H=160, K=128  |    31532.4  |  38844.2
      b16 B=1, M=4096, H=160, K=128  |    31579.6  |  39971.8
      f16 B=2, M=4096, H=160, K=128  |    48359.2  |  78614.4
      b16 B=2, M=4096, H=160, K=128  |    48379.2  |  80423.5
      f16 B=1, M=8192, H=160, K=128  |   121826.9  |
      b16 B=1, M=8192, H=160, K=128  |   121908.8  |
      f16 B=2, M=8192, H=160, K=128  |   186970.4  |
      b16 B=2, M=8192, H=160, K=128  |   186947.6  |
      f16 B=1024, M=82, H=8, K=64    |     1513.1  |   3612.9
      b16 B=1024, M=82, H=8, K=64    |     1524.6  |   3789.6
      f16 B=150, M=256, H=16, K=64   |     1513.9  |   3821.8
      b16 B=150, M=256, H=16, K=64   |     1530.0  |   3863.2
      f16 B=64, M=256, H=12, K=64    |      529.9  |   1257.3
      b16 B=64, M=256, H=12, K=64    |      533.3  |   1272.5
      f16 B=256, M=4096, H=16, K=64  |   241098.0  |
      b16 B=256, M=4096, H=16, K=64  |   241300.4  |
      f16 B=8, M=2048, H=20, K=128   |     8338.3  |  10544.2
      b16 B=8, M=2048, H=20, K=128   |     8347.7  |  10792.9
      f16 B=16, M=128, H=16, K=16    |      197.3  |    328.7
      b16 B=16, M=128, H=16, K=16    |      215.8  |    329.7
      f16 B=16, M=128, H=16, K=32    |      196.4  |    308.2
      b16 B=16, M=128, H=16, K=32    |      214.5  |    332.4
      f16 B=16, M=128, H=16, K=64    |      221.3  |    302.9
      b16 B=16, M=128, H=16, K=64    |      195.7  |    335.6
      f16 B=16, M=128, H=16, K=128   |      216.6  |    324.5
      b16 B=16, M=128, H=16, K=128   |      217.9  |    321.4
      f16 B=16, M=512, H=16, K=16    |      261.4  |    986.3
      b16 B=16, M=512, H=16, K=16    |      263.7  |   1078.2
      f16 B=16, M=512, H=16, K=32    |      338.6  |   1089.9
      b16 B=16, M=512, H=16, K=32    |      341.3  |   1178.4
      f16 B=16, M=512, H=16, K=64    |      519.7  |   1286.0
      b16 B=16, M=512, H=16, K=64    |      523.9  |   1306.3
      f16 B=16, M=512, H=16, K=128   |     1036.2  |   1717.5
      b16 B=16, M=512, H=16, K=128   |     1039.4  |   1744.6
      f16 B=16, M=1024, H=16, K=16   |      790.6  |   3531.1
      b16 B=16, M=1024, H=16, K=16   |      793.3  |   3954.9
      f16 B=16, M=1024, H=16, K=32   |     1023.9  |   3736.3
      b16 B=16, M=1024, H=16, K=32   |     1023.6  |   4155.7
      f16 B=16, M=1024, H=16, K=64   |     1565.6  |   4279.1
      b16 B=16, M=1024, H=16, K=64   |     1572.5  |   4350.4
      f16 B=16, M=1024, H=16, K=128  |     3437.1  |   5117.6
      b16 B=16, M=1024, H=16, K=128  |     3449.8  |   5254.7
      f16 B=64, M=128, H=16, K=16    |      199.3  |    367.9
      b16 B=64, M=128, H=16, K=16    |      196.3  |    376.1
      f16 B=64, M=128, H=16, K=32    |      206.7  |    465.6
      b16 B=64, M=128, H=16, K=32    |      214.1  |    471.7
      f16 B=64, M=128, H=16, K=64    |      286.8  |    683.6
      b16 B=64, M=128, H=16, K=64    |      289.0  |    685.6
      f16 B=64, M=128, H=16, K=128   |      512.0  |   1075.4
      b16 B=64, M=128, H=16, K=128   |      514.3  |   1078.1
      f16 B=64, M=512, H=16, K=16    |      924.1  |   3713.1
      b16 B=64, M=512, H=16, K=16    |      929.9  |   4090.5
      f16 B=64, M=512, H=16, K=32    |     1180.3  |   4138.9
      b16 B=64, M=512, H=16, K=32    |     1186.6  |   4505.3
      f16 B=64, M=512, H=16, K=64    |     1761.6  |   4914.9
      b16 B=64, M=512, H=16, K=64    |     1780.9  |   5004.2
      f16 B=64, M=512, H=16, K=128   |     3586.7  |   6644.1
      b16 B=64, M=512, H=16, K=128   |     3601.9  |   6750.4
      f16 B=64, M=1024, H=16, K=16   |     2859.7  |  13877.5
      b16 B=64, M=1024, H=16, K=16   |     2863.2  |  15563.7
      f16 B=64, M=1024, H=16, K=32   |     3587.1  |  14759.8
      b16 B=64, M=1024, H=16, K=32   |     3587.2  |  16462.9
      f16 B=64, M=1024, H=16, K=64   |     5363.7  |  16889.4
      b16 B=64, M=1024, H=16, K=64   |     5384.1  |  17213.8
      f16 B=64, M=1024, H=16, K=128  |    11821.0  |  20391.7
      b16 B=64, M=1024, H=16, K=128  |    11861.3  |  20907.0

Times are in microseconds (us).

Contributor

@fmassa fmassa left a comment

This LGTM, thanks a lot Diana!

I've added a few more comments that can be addressed in the future, but we can get started with this for now!

Comment on lines 32 to 34

# Dependency for triton flash attn
flash-attn
Contributor

I wonder if this is needed, given that we have flash-attn as a submodule, which is set up in our build system?

Contributor

Looks like we would need to change our build scripts to also install flash-attn. Otherwise this might override the version of _C_flashattn that we compile.

Contributor

Actually, all the Triton tests in CI seem to have been skipped. Worth looking into.

Contributor Author

I think they're skipping because I'm requiring sm80 for the Triton implementation

Contributor Author

I added flash-attention as a dependency in setup.py, is that what you had in mind?

Contributor

> I think they're skipping because I'm requiring sm80 for the Triton implementation

Indeed, we don't have sm80 in the CI, but it should work with sm75 (which we do have in the CI).

Contributor Author

I added sm75 in #556, will check if it works!
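
The `supports` check in this PR gates on `device_capability[0] >= 8`, which looks only at the major version and so rejects sm75. One way to admit sm75 is to compare the full `(major, minor)` tuple against a minimum. This is a minimal sketch, not the change made in #556: `MIN_CAPABILITY` and `meets_min_capability` are hypothetical names, and whether the Triton kernel actually runs correctly on sm75 is exactly what the thread leaves to be checked.

```python
from typing import Tuple

# Assumption for illustration: sm75 as the lowest accepted compute capability.
MIN_CAPABILITY: Tuple[int, int] = (7, 5)


def meets_min_capability(
    capability: Tuple[int, int],
    minimum: Tuple[int, int] = MIN_CAPABILITY,
) -> bool:
    """Compare (major, minor) tuples lexicographically.

    Unlike a check on the major version alone, this distinguishes
    sm75 (7, 5) from sm70 (7, 0).
    """
    return tuple(capability) >= tuple(minimum)


for cap in [(7, 0), (7, 5), (8, 0), (8, 6)]:
    print(cap, meets_min_capability(cap))
```

In the op itself, the tuple returned by `torch.cuda.get_device_capability(d.device)` could be passed straight to such a helper.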

Comment on lines +726 to +766
def supports(cls, d: "AttentionOpDispatch") -> bool:
    if not has_triton_flashattention:
        return False
    device_capability = torch.cuda.get_device_capability(d.device)
    is_sm80 = device_capability[0] >= 8
    if not is_sm80:
        return False
    return super(TritonFlashAttentionOp, cls).supports(d)
Contributor

Interesting, so it also supports values of K that are not a multiple of 8? cc @danthe3rd

xformers/ops/memory_efficient_attention.py (outdated)
@@ -80,6 +80,8 @@ def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
xformers.ops.MemoryEfficientAttentionCutlassOp,
xformers.ops.MemoryEfficientAttentionFlashAttentionOp,
xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp,
xformers.ops.TritonFlashAttentionOp,
Contributor

For the future, it might be worth adding a test that checks for race conditions. This has been illustrated in this comment in flashattention, and I know that @danthe3rd has had issues with race conditions in the past, so it might be good to extend our tests to cover this case.

Contributor Author

They already have testing for race conditions in the HazyResearch repo here, but maybe it makes sense to add it if we want the same testing for other implementations as well?

Contributor

We check that a bit by running with very large batches, to ensure the GPU is saturated (so a missing __syncthreads() causes wrong results). We could also have something similar to what flash is doing, though.

Comment on lines 955 to 956
MemoryEfficientAttentionTritonFwdFlashBwOp,
TritonFlashAttentionOp,
Contributor

Should we change the order of priority in here so that TritonFwdFlashBwdOp is dispatched more often?
Maybe writing another _is_triton_faster_than_..?

Contributor Author

Sounds good! I added a method for now but haven't filled it out. It seems like Triton is faster than cutlass for most shapes, except for the following:

B=64, M=128, H=16, K=128
B=64, M=128, H=16, K=64
B=64, M=128, H=16, K=16
B=16, M=128, H=16, K=128
B=16, M=128, H=16, K=64 
B=16, M=128, H=16, K=32
B=16, M=128, H=16, K=16
B=1024, M=82, H=8, K=64
B=16, M=197, H=16, K=128
B=16, M=197, H=16, K=64
B=32, M=197, H=16, K=128
B=32, M=197, H=16, K=64
B=384, M=197, H=1, K=64

@danthe3rd do you know for which cases in general we should expect cutlass to be faster?
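
As a starting point for such a method, the shapes listed above can be recorded as a lookup table. This is only a sketch under stated assumptions: `CUTLASS_FASTER_SHAPES` and `cutlass_measured_faster` are hypothetical names, the table holds nothing beyond the thirteen shapes in this comment, and it says nothing about shapes that were not benchmarked.

```python
from typing import FrozenSet, Tuple

Shape = Tuple[int, int, int, int]  # (B, M, H, K)

# Shapes from the comment above where Triton was not the faster implementation.
CUTLASS_FASTER_SHAPES: FrozenSet[Shape] = frozenset(
    {
        (64, 128, 16, 128),
        (64, 128, 16, 64),
        (64, 128, 16, 16),
        (16, 128, 16, 128),
        (16, 128, 16, 64),
        (16, 128, 16, 32),
        (16, 128, 16, 16),
        (1024, 82, 8, 64),
        (16, 197, 16, 128),
        (16, 197, 16, 64),
        (32, 197, 16, 128),
        (32, 197, 16, 64),
        (384, 197, 1, 64),
    }
)


def cutlass_measured_faster(B: int, M: int, H: int, K: int) -> bool:
    """True only for shapes explicitly listed as faster with cutlass."""
    return (B, M, H, K) in CUTLASS_FASTER_SHAPES


print(cutlass_measured_faster(64, 128, 16, 128))
print(cutlass_measured_faster(64, 1024, 16, 16))
```

Most of the listed shapes have a short sequence length (M of 197 or less), which may hint at a general rule, but the list alone does not establish one.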

Contributor

@danthe3rd danthe3rd Dec 6, 2022

I don't know - we could be conservative for now (eg keep the old behavior) and enable it one-by-one if we see opportunities.

Contributor Author

Sounds good, I'll post in the xFormers group so people know how to try it out if they want

@dianaml0
Contributor Author

dianaml0 commented Dec 6, 2022

@fmassa Thanks a lot for the helpful comments and for taking another pass! I've updated with related changes. Okay to merge for now?

@codecov-commenter

codecov-commenter commented Dec 6, 2022

Codecov Report

Base: 89.79% // Head: 89.06% // Decreases project coverage by -0.73% ⚠️

Coverage data is based on head (5724663) compared to base (71205ec).
Patch coverage: 48.86% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #479      +/-   ##
==========================================
- Coverage   89.79%   89.06%   -0.74%     
==========================================
  Files          80       80              
  Lines        4839     4927      +88     
==========================================
+ Hits         4345     4388      +43     
- Misses        494      539      +45     
Flag Coverage Δ
Python 89.06% <48.86%> (-0.74%) ⬇️


Impacted Files Coverage Δ
xformers/info.py 0.00% <ø> (ø)
xformers/ops/__init__.py 82.35% <ø> (ø)
xformers/ops/memory_efficient_attention.py 78.05% <48.86%> (-6.66%) ⬇️
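
The headline percentages can be reproduced from the hit and line counts in the diff table. The sketch below is plain arithmetic on those counts; `coverage_pct` is a hypothetical helper, and treating added hits over added lines as the patch coverage is an assumption that happens to match the reported 48.86%.

```python
def coverage_pct(hits: int, lines: int) -> float:
    """Coverage as a percentage of covered lines."""
    return 100.0 * hits / lines


# Counts taken from the coverage diff table above.
base = coverage_pct(4345, 4839)
head = coverage_pct(4388, 4927)
# Assumption: patch coverage approximated as added hits over added lines.
patch = coverage_pct(4388 - 4345, 4927 - 4839)

print(f"base={base:.2f}% head={head:.2f}% patch={patch:.2f}%")
print(f"delta={head - base:.2f} percentage points")
```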



Contributor

@danthe3rd danthe3rd left a comment

Thanks a lot Diana! This looks great. A few nits that can be addressed later - let's get this merged :)

del flatten_diff
assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
f"{msg}: "
f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)"
f"/ atol={atol}, rtol={rtol}"
f"/ total failing elements: {num_different}, percentage={percentage}"
Contributor

thanks!
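
The fields in that failure message (position of the largest difference, failing-element count, percentage) can be computed alongside the tolerance check. The following is a pure-Python sketch over flat lists of floats, not the test helper from this PR: `allclose_report` is a hypothetical name, and it assumes non-empty inputs of equal length. It uses the same element-wise rule as `torch.allclose`, `|out - ref| <= atol + rtol * |ref|`.

```python
from typing import Sequence, Tuple


def allclose_report(
    out: Sequence[float],
    ref: Sequence[float],
    rtol: float,
    atol: float,
) -> Tuple[bool, str]:
    """Return (ok, message) describing how far out is from ref."""
    diffs = [abs(o - r) for o, r in zip(out, ref)]
    # An element fails when its difference exceeds the combined tolerance.
    failing = [d > atol + rtol * abs(r) for d, r in zip(diffs, ref)]
    max_pos = max(range(len(diffs)), key=diffs.__getitem__)
    num_different = sum(failing)
    percentage = 100.0 * num_different / len(diffs)
    message = (
        f"out={out[max_pos]} and ref={ref[max_pos]} (diff={diffs[max_pos]})"
        f" / atol={atol}, rtol={rtol}"
        f" / total failing elements: {num_different}, percentage={percentage:.2f}"
    )
    return num_different == 0, message


ok, message = allclose_report([1.0, 2.0, 3.5], [1.0, 2.0, 3.0], rtol=0.0, atol=0.1)
print(ok, message)
```

Reporting the count and percentage makes it easier to tell a single outlier from a systematically wrong result.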

MemoryEfficientAttentionCutlassOp,
TritonFlashAttentionOp,
Contributor

I'm not sure we should include the Triton backward into the priority list, unless we are confident it works fine without correctness issues (Tri wasn't really sure about that).
Since it seems properly tested in the tests, I guess we should be good to go

Contributor Author

Makes sense, I removed it in #556 for now

Contributor

@fmassa fmassa left a comment

Let's get this merged.

Thanks Diana!

@@ -267,6 +267,7 @@ def run(self):
version=version,
install_requires=fetch_requirements(),
packages=setuptools.find_packages(exclude=("tests", "tests.*")),
dependency_links=["file:///./third_party/flash-attention#egg=flash-attention"],
Contributor

@danthe3rd let's test this afterwards. I think we will need to change the way we depend on flashattention in our code

@danthe3rd danthe3rd merged commit f2f3424 into main Dec 6, 2022
@danthe3rd danthe3rd deleted the triton_flash branch December 6, 2022 15:01
@dianaml0 dianaml0 mentioned this pull request Dec 6, 2022
10 tasks
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

6 participants