Refactor4: Split fw/bw operators #560
[ghstack-poisoned]
ghstack-source-id: ecfdae2d0fcfd9e24e3535f832cc53f378a75f74 Pull Request resolved: #560
[ghstack-poisoned]
ghstack-source-id: f61b6c2a352efc612252db71c765b78d070ec7a3 Pull Request resolved: #560
[ghstack-poisoned]
ghstack-source-id: c152503f42202510a29b0e7ac08441c13da8466c Pull Request resolved: #560
<details>
<summary>BW perf on A100</summary>

```
[----------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ------------------------------------------]
                              | refactor[flshattB] | refactor[cutlassB] | flash2[flshatt] | vanilla | cutlass[cutlass] | flash1[flshatt]
1 threads: -----------------------------------------------------------------------------------------------------------------------------
f16 B=384, M=197, H=1, K=88   |              449.4 |                    |           447.6 |  2260.4 |            621.0 |
f16 B=384, M=197, H=1, K=80   |              433.8 |                    |           431.8 |  1916.2 |            589.4 |
f16 B=384, M=197, H=1, K=64   |              234.5 |                    |           233.6 |  1810.0 |            406.2 |           232.8
f16 B=1024, M=197, H=1, K=88  |             1112.9 |                    |          1111.3 |  5945.9 |           1557.7 |
f16 B=1024, M=197, H=1, K=80  |             1073.3 |                    |          1073.5 |  5022.5 |           1480.4 |
f16 B=1024, M=197, H=1, K=64  |              579.9 |                    |           579.4 |  4736.3 |            930.1 |           578.3
f16 B=512, M=197, H=1, K=80   |              546.0 |                    |           545.1 |  2539.2 |            749.0 |
f16 B=32, M=197, H=16, K=80   |              560.2 |                    |           558.7 |  2579.7 |            748.1 |
f16 B=32, M=197, H=16, K=64   |              298.4 |                    |           297.8 |  2430.3 |            480.1 |           295.8
f16 B=32, M=197, H=16, K=128  |              652.5 |                    |           647.4 |  4494.1 |            896.4 |           684.3
f16 B=256, M=197, H=1, K=88   |              330.1 |                    |           328.8 |  1527.1 |            449.9 |
f16 B=16, M=197, H=16, K=88   |              331.2 |                    |           329.1 |  1540.0 |            445.2 |
f16 B=16, M=197, H=16, K=64   |              198.7 |                    |           166.9 |  1243.8 |            246.4 |           165.9
f16 B=16, M=197, H=16, K=128  |              371.7 |                    |           367.4 |  2266.8 |            503.6 |           386.4
f16 B=1, M=4096, H=160, K=128 |            42897.1 |                    |         42915.9 | 45961.0 |          56457.5 |         54740.1
f16 B=2, M=4096, H=160, K=128 |            84681.5 |                    |         84646.9 |         |          93553.0 |         84262.0
f16 B=1, M=8192, H=160, K=128 |           165741.2 |                    |        165657.6 |         |         222850.4 |        215961.0
f16 B=2, M=8192, H=160, K=128 |           332314.2 |                    |        332593.3 |         |         364220.8 |        330727.9
f16 B=1024, M=82, H=8, K=64   |             1647.0 |                    |          1643.6 |  3821.1 |           1886.9 |          1621.7
f16 B=150, M=256, H=16, K=64  |             1632.0 |                    |          1626.4 |  4556.0 |           2302.2 |          1626.6
f16 B=64, M=256, H=12, K=64   |              570.4 |                    |           566.8 |  1499.1 |            778.4 |           567.1
f16 B=1, M=4096, H=16, K=40   |             2153.5 |                    |          2154.7 |  4193.7 |          23526.2 |
f16 B=1, M=16384, H=16, K=40  |            27859.4 |                    |         27850.3 |         |         389891.9 |
f16 B=256, M=4096, H=16, K=64 |           491968.1 |                    |        491246.6 |         |         729981.6 |        440392.8
f16 B=16, M=128, H=16, K=16   |              195.8 |                    |           118.3 |   276.8 |            129.3 |           120.4
f16 B=16, M=128, H=16, K=32   |              197.7 |                    |           117.6 |   276.8 |            128.8 |           113.7
f16 B=16, M=128, H=16, K=64   |              196.0 |                    |           122.9 |   324.0 |            129.9 |           113.6
f16 B=16, M=128, H=16, K=128  |              196.4 |                    |           170.0 |   331.5 |            182.2 |           157.9
f16 B=16, M=128, H=16, K=256  |                    |              791.8 |                 |   545.0 |            774.4 |
f16 B=16, M=512, H=16, K=16   |              378.0 |                    |           375.1 |  1203.0 |            623.4 |           325.4
f16 B=16, M=512, H=16, K=32   |              436.7 |                    |           433.8 |  1306.2 |            711.6 |           432.0
f16 B=16, M=512, H=16, K=64   |              684.6 |                    |           680.6 |  1543.4 |            917.7 |           702.3
f16 B=16, M=512, H=16, K=128  |             1565.8 |                    |          1560.9 |  1983.1 |           1656.5 |          1580.7
f16 B=16, M=512, H=16, K=256  |                    |             8678.8 |                 |  2901.2 |           8652.9 |
f16 B=16, M=1024, H=16, K=16  |             1432.3 |                    |          1435.8 |  4262.8 |           2422.2 |          1245.2
f16 B=16, M=1024, H=16, K=32  |             1608.0 |                    |          1605.9 |  4489.6 |           2697.4 |          1617.9
f16 B=16, M=1024, H=16, K=64  |             2549.3 |                    |          2541.9 |  4992.6 |           3341.4 |          2370.4
f16 B=16, M=1024, H=16, K=128 |             5034.4 |                    |          5023.6 |  5955.5 |           5903.8 |          5634.1
f16 B=16, M=1024, H=16, K=256 |                    |            31948.6 |                 |  7891.3 |          31850.4 |
f16 B=64, M=128, H=16, K=16   |              193.9 |                    |           159.7 |   439.8 |            158.3 |           145.0
f16 B=64, M=128, H=16, K=32   |              217.0 |                    |           215.3 |   545.3 |            209.4 |           211.2
f16 B=64, M=128, H=16, K=64   |              316.0 |                    |           313.6 |   767.0 |            328.6 |           310.2
f16 B=64, M=128, H=16, K=128  |              559.7 |                    |           555.9 |  1229.2 |            637.6 |           561.6
f16 B=64, M=128, H=16, K=256  |                    |             2851.8 |                 |  2126.3 |           2839.0 |
f16 B=64, M=512, H=16, K=16   |             1372.4 |                    |          1370.9 |  4484.9 |           2313.9 |          1219.4
f16 B=64, M=512, H=16, K=32   |             1542.0 |                    |          1539.2 |  4978.8 |           2706.7 |          1547.4
f16 B=64, M=512, H=16, K=64   |             2335.4 |                    |          2330.0 |  5885.5 |           3464.0 |          2417.8
f16 B=64, M=512, H=16, K=128  |             5357.0 |                    |          5347.8 |  7711.8 |           5909.5 |          5442.7
f16 B=64, M=512, H=16, K=256  |                    |            31136.1 |                 | 11501.8 |          30488.7 |
f16 B=64, M=1024, H=16, K=16  |             5226.1 |                    |          5228.4 | 16908.5 |           9092.8 |          4713.7
f16 B=64, M=1024, H=16, K=32  |             5717.3 |                    |          5723.1 | 17895.8 |          10439.0 |          5705.3
f16 B=64, M=1024, H=16, K=64  |             8672.2 |                    |          8664.4 | 19929.7 |          12937.8 |          8137.3
f16 B=64, M=1024, H=16, K=128 |            19104.3 |                    |         19080.9 | 23711.0 |          20963.4 |         19169.7
f16 B=64, M=1024, H=16, K=256 |                    |           115163.2 |                 | 32729.8 |         113715.6 |

Times are in microseconds (us).
```

</details>

[ghstack-poisoned]
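For anyone reading the table by hand, a small sketch of how the columns can be compared. It only hard-codes three rows copied from the "BW perf on A100" table; the `ROWS` dict and the `ratio` helper are hypothetical names for illustration and are not part of this PR.

```python
# Timings in microseconds, copied from three rows of the "BW perf on A100" table.
# Keys are (B, M, H, K); blank cells in the table are simply left out of the dict.
ROWS = {
    (384, 197, 1, 64): {
        "refactor[flshattB]": 234.5,
        "flash2[flshatt]": 233.6,
        "vanilla": 1810.0,
        "cutlass[cutlass]": 406.2,
    },
    (16, 128, 16, 64): {
        "refactor[flshattB]": 196.0,
        "flash2[flshatt]": 122.9,
        "vanilla": 324.0,
        "cutlass[cutlass]": 129.9,
    },
    (64, 1024, 16, 128): {
        "refactor[flshattB]": 19104.3,
        "flash2[flshatt]": 19080.9,
        "vanilla": 23711.0,
        "cutlass[cutlass]": 20963.4,
    },
}


def ratio(row, column, baseline="refactor[flshattB]"):
    """Return time(column) / time(baseline), or None if either cell is blank.

    A value above 1 means the baseline column is faster for that shape.
    """
    t, base = row.get(column), row.get(baseline)
    if t is None or base is None:
        return None
    return t / base


if __name__ == "__main__":
    for shape, row in ROWS.items():
        for column in ("flash2[flshatt]", "vanilla", "cutlass[cutlass]"):
            r = ratio(row, column)
            print(shape, column, "n/a" if r is None else f"{r:.2f}x")
```

A ratio below 1 against `flash2[flshatt]`, as in the `B=16, M=128` row, marks a shape where the refactored operator is slower than the column it is compared with.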
ghstack-source-id: f4446e42dc1645e7cafb81934b7a16dbbb03a57e Pull Request resolved: #560
NOTE: Dropout is entirely disabled until we figure out how to implement it properly.

[ghstack-poisoned]
Putting some early comments in (the PR was still changing while I added the first comments), so that I don't forget
# Saving attn_bias is a bit complicated, as the
# torch part should go in `save_for_backward`
if isinstance(inp.attn_bias, torch.Tensor):
    attn_bias_tensor = inp.attn_bias
    attn_bias_ctx = None
else:
    attn_bias_tensor = None
    attn_bias_ctx = inp.attn_bias
We should probably make `CausalTensor` actually inherit from `torch.Tensor` in the near future.
return attn_bias_ctx
return attn_bias_tensor

@classmethod
We should probably also make it `@once_differentiable` for correctness, but I don't know how that plays with `classmethod`.
Thanks!
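For reference, here is a minimal sketch of the two points raised above: tensor state goes through `save_for_backward`, non-tensor state lives on `ctx`, and `backward` is marked `@once_differentiable`. It is not the code from this PR. `ScaleByBias` is a hypothetical toy op (`out = x * bias`) standing in for the attention operator, it assumes `bias` is either a Python float or a tensor with the same shape as `x`, and it uses the `@staticmethod` form from the PyTorch docs rather than `@classmethod`.

```python
import torch
from torch.autograd.function import once_differentiable


class ScaleByBias(torch.autograd.Function):
    """Hypothetical toy op: out = x * bias (bias is a float or a same-shape tensor)."""

    @staticmethod
    def forward(ctx, x, bias):
        if isinstance(bias, torch.Tensor):
            # Tensor part: must go through save_for_backward.
            ctx.save_for_backward(x, bias)
            ctx.bias_ctx = None
        else:
            # Non-tensor part: stored as a plain attribute on ctx.
            ctx.save_for_backward(x)
            ctx.bias_ctx = bias
        return x * bias

    @staticmethod
    @once_differentiable  # backward runs without building a graph
    def backward(ctx, grad_out):
        if ctx.bias_ctx is None:
            x, bias = ctx.saved_tensors
            return grad_out * bias, grad_out * x
        # Non-tensor bias gets no gradient.
        return grad_out * ctx.bias_ctx, None


# Usage: the same op handles both kinds of bias.
x = torch.ones(3, requires_grad=True)
ScaleByBias.apply(x, 2.0).sum().backward()
print(x.grad)
```

Whether the decorator composes cleanly with `@classmethod` is left open here, as in the comment above.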
return (gradients.dq, gradients.dk, gradients.dv)


def _memory_efficient_attention(
nit: These `_`-prefixed functions seem like they could be folded into the original `memory_efficient_attention` function, avoiding the need for this private function?
ghstack-source-id: 548ab26edd3160663dd9bdf08e4be3f1f1ab809c Pull Request resolved: #560