[0.0.17 / 0.0.18] nvcc version and performance #712

Closed
danthe3rd opened this issue Mar 30, 2023 · 3 comments

danthe3rd (Contributor) commented Mar 30, 2023

On A100, we seem to get the best performance using nvcc version 11.6.2. We need to investigate why performance is significantly worse with other nvcc versions.
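
For reference, the exact harness behind the numbers below is not shown here, but a minimal sketch reproducing a single row with `torch.utils.benchmark` and `xformers.ops.memory_efficient_attention` would look roughly like this (the shape labels, e.g. `f16 384-197-1-88`, are interpreted here as batch-seqlen-heads-headdim):

```python
# Minimal sketch, not the exact benchmark harness behind the tables below.
# Shape labels are interpreted as batch-seqlen-heads-headdim.
import torch
import xformers.ops as xops
from torch.utils import benchmark

B, M, H, K = 384, 197, 1, 88  # the `f16 384-197-1-88` row
q, k, v = (
    torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
    for _ in range(3)
)

timer = benchmark.Timer(
    stmt="xops.memory_efficient_attention(q, k, v, attn_bias=None, p=0.0)",
    globals={"xops": xops, "q": q, "k": k, "v": v},
)
print(timer.timeit(100))  # the tables report times in microseconds (us)
```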

fMHA FW - A100
[--------------------- attention (attn_bias=<class 'NoneType'>) ---------------------]
                                                 |   nv1107   |   nv1106   |   nv1108 
1 threads: ---------------------------------------------------------------------------
      f16 384-197-1-88, p=0.0, BiasT=NoneType    |     143.0  |     125.0  |     133.0
      f16 384-197-1-80, p=0.0, BiasT=NoneType    |     133.4  |     116.0  |     124.0
      f16 384-197-1-64, p=0.0, BiasT=NoneType    |      97.5  |      89.0  |      94.0
      f16 1024-197-1-88, p=0.0, BiasT=NoneType   |     356.8  |     316.0  |     331.0
      f16 1024-197-1-80, p=0.0, BiasT=NoneType   |     333.0  |     296.0  |     309.0
      f16 1024-197-1-64, p=0.0, BiasT=NoneType   |     236.2  |     216.0  |     228.0
      f16 512-197-1-80, p=0.0, BiasT=NoneType    |     173.0  |     154.0  |     160.0
      f16 32-197-16-80, p=0.0, BiasT=NoneType    |     172.7  |     155.0  |     161.0
      f16 32-197-16-64, p=0.0, BiasT=NoneType    |     125.9  |     114.0  |     121.0
      f16 32-197-16-128, p=0.0, BiasT=NoneType   |     187.6  |     166.0  |     174.0
      f16 256-197-1-88, p=0.0, BiasT=NoneType    |     100.0  |      88.0  |      93.0
      f16 16-197-16-88, p=0.0, BiasT=NoneType    |     100.8  |     103.0  |      94.0
      f16 16-197-16-64, p=0.0, BiasT=NoneType    |      68.2  |      72.0  |      66.0
      f16 16-197-16-128, p=0.0, BiasT=NoneType   |     101.9  |      89.0  |      94.0
      f16 1024-82-8-64, p=0.0, BiasT=NoneType    |     511.1  |     467.0  |     498.0
      f16 150-256-16-64, p=0.0, BiasT=NoneType   |     566.5  |     524.0  |     548.0
      f16 64-256-12-64, p=0.0, BiasT=NoneType    |     194.6  |     176.0  |     188.0
      f16 1-4096-16-40, p=0.0, BiasT=NoneType    |     958.7  |     868.0  |     931.0
      f16 1-16384-16-40, p=0.0, BiasT=NoneType   |   13591.2  |   12403.0  |   13229.0
      f16 1-4096-16-80, p=0.0, BiasT=NoneType    |    1311.0  |    1238.0  |    1283.0
      f16 1-16384-16-80, p=0.0, BiasT=NoneType   |   20207.2  |   19066.0  |   19696.0
      f16 4-4096-16-40, p=0.0, BiasT=NoneType    |    3435.9  |    3141.0  |    3347.0
      f16 4-16384-16-40, p=0.0, BiasT=NoneType   |   53435.1  |   48768.0  |   52215.0
      f16 4-4096-16-80, p=0.0, BiasT=NoneType    |    5105.8  |    4809.0  |    4976.0
      f16 4-16384-16-80, p=0.0, BiasT=NoneType   |   80385.0  |   75779.0  |   78274.0
      f16 256-4096-16-64, p=0.0, BiasT=NoneType  |  204914.6  |  187167.0  |  199315.0
      f16 8-2048-20-128, p=0.0, BiasT=NoneType   |    3509.2  |    3284.0  |    3373.0
      f16 1-2048-4-128, p=0.0, BiasT=NoneType    |     124.5  |     117.0  |     120.0
      f16 1-2048-8-128, p=0.0, BiasT=NoneType    |     220.6  |     200.0  |     208.0
      f16 1-4096-4-128, p=0.0, BiasT=NoneType    |     422.3  |     394.0  |     409.0
      f16 1-4096-8-128, p=0.0, BiasT=NoneType    |     752.1  |     709.0  |     728.0
      f16 1-8192-4-128, p=0.0, BiasT=NoneType    |    1480.9  |    1397.0  |    1438.0
      f16 1-8192-8-128, p=0.0, BiasT=NoneType    |    2767.2  |    2600.0  |    2679.0
      f16 2-2048-4-128, p=0.0, BiasT=NoneType    |     220.9  |     202.0  |     209.0
      f16 2-2048-8-128, p=0.0, BiasT=NoneType    |     384.9  |     360.0  |     373.0
      f16 2-4096-4-128, p=0.0, BiasT=NoneType    |     754.3  |     711.0  |     731.0
      f16 2-4096-8-128, p=0.0, BiasT=NoneType    |    1391.6  |    1311.0  |    1351.0
      f16 2-8192-4-128, p=0.0, BiasT=NoneType    |    2754.3  |    2602.0  |    2683.0
      f16 2-8192-8-128, p=0.0, BiasT=NoneType    |    5451.7  |    5139.0  |    5294.0
      f16 16-128-16-16, p=0.0, BiasT=NoneType    |      66.6  |      65.0  |      65.0
      f16 16-128-16-32, p=0.0, BiasT=NoneType    |      66.2  |      63.0  |      65.0
      f16 16-128-16-64, p=0.0, BiasT=NoneType    |      66.7  |      65.0  |      71.0
      f16 16-128-16-128, p=0.0, BiasT=NoneType   |      67.0  |      64.0  |      68.0
      f16 16-512-16-16, p=0.0, BiasT=NoneType    |     201.7  |     182.0  |     194.0
      f16 16-512-16-32, p=0.0, BiasT=NoneType    |     209.2  |     189.0  |     202.0
      f16 16-512-16-64, p=0.0, BiasT=NoneType    |     235.6  |     213.0  |     229.0
      f16 16-512-16-128, p=0.0, BiasT=NoneType   |     396.0  |     362.0  |     374.0
      f16 16-1024-16-16, p=0.0, BiasT=NoneType   |     765.1  |     695.0  |     742.0
      f16 16-1024-16-32, p=0.0, BiasT=NoneType   |     770.6  |     702.0  |     747.0
      f16 16-1024-16-64, p=0.0, BiasT=NoneType   |     860.8  |     786.0  |     836.0
      f16 16-1024-16-128, p=0.0, BiasT=NoneType  |    1448.7  |    1344.0  |    1386.0
      f16 384-197-1-88, p=0.3, BiasT=NoneType    |     194.4  |     177.0  |     182.0
      f16 384-197-1-64, p=0.3, BiasT=NoneType    |     157.9  |     151.0  |     155.0
      f32 1024-197-1-88, p=0.0, BiasT=NoneType   |    1158.2  |    1158.0  |    1148.0
      f32 1024-197-1-80, p=0.0, BiasT=NoneType   |    1147.2  |    1150.0  |    1135.0
      f16 1024-197-1-64, p=0.3, BiasT=NoneType   |     379.2  |     365.0  |     374.0
      f16 32-197-16-80, p=0.3, BiasT=NoneType    |     242.4  |     223.0  |     229.0
      b16 256-197-1-88, p=0.0, BiasT=NoneType    |     100.0  |      88.0  |      93.0
      f16 16-197-16-88, p=0.3, BiasT=NoneType    |     136.5  |     124.0  |     128.0
      f16 150-256-16-64, p=0.3, BiasT=NoneType   |     968.0  |     933.0  |     955.0
      f16 1-16384-16-40, p=0.3, BiasT=NoneType   |   24826.4  |   23879.0  |   24594.0
      f32 1-4096-16-80, p=0.0, BiasT=NoneType    |    5190.8  |    4978.0  |    5131.0
      b16 8-2048-20-128, p=0.0, BiasT=NoneType   |    3460.7  |    3244.0  |    3336.0
      b16 1-4096-4-128, p=0.0, BiasT=NoneType    |     422.2  |     394.0  |     409.0
      b16 1-4096-8-128, p=0.0, BiasT=NoneType    |     751.7  |     708.0  |     728.0
      f32 1-8192-4-128, p=0.0, BiasT=NoneType    |    5942.5  |    5825.0  |    5838.0
      f16 2-2048-8-128, p=0.3, BiasT=NoneType    |     585.1  |     551.0  |     562.0
      b16 2-4096-4-128, p=0.0, BiasT=NoneType    |     754.1  |     711.0  |     730.0
      b16 2-8192-8-128, p=0.0, BiasT=NoneType    |    5434.2  |    5086.0  |    5256.0
      f32 16-128-16-16, p=0.0, BiasT=NoneType    |      67.4  |      65.0  |      66.0
      b16 16-128-16-128, p=0.0, BiasT=NoneType   |      71.8  |      69.0  |      71.0
      b16 16-1024-16-16, p=0.0, BiasT=NoneType   |     765.2  |     695.0  |     742.0
      f16 16-1024-16-32, p=0.3, BiasT=NoneType   |    1474.3  |    1425.0  |    1458.0
      b16 16-1024-16-64, p=0.0, BiasT=NoneType   |     860.9  |     786.0  |     836.0
      b16 16-1024-16-128, p=0.0, BiasT=NoneType  |    1433.9  |    1337.0  |    1373.0

Times are in microseconds (us).

[--- attention (attn_bias=<class 'xformers.ops.fmha.attn_bias.LowerTriangularMask'>) ----]
                                                           |  nv1107  |  nv1106  |  nv1108
1 threads: -------------------------------------------------------------------------------
      f16 384-197-1-80, p=0.0, BiasT=LowerTriangularMask   |   104.9  |    91.0  |    95.0
      f16 32-197-16-128, p=0.0, BiasT=LowerTriangularMask  |   147.8  |   127.0  |   135.0
      f16 16-197-16-128, p=0.0, BiasT=LowerTriangularMask  |    81.7  |    71.0  |    75.0
      f16 1-4096-16-40, p=0.0, BiasT=LowerTriangularMask   |   589.0  |   538.0  |   572.0
      f16 4-4096-16-40, p=0.0, BiasT=LowerTriangularMask   |  1880.9  |  1716.0  |  1827.0
      f16 2-8192-4-128, p=0.0, BiasT=LowerTriangularMask   |  1585.6  |  1499.0  |  1537.0
      f16 16-128-16-64, p=0.0, BiasT=LowerTriangularMask   |    66.7  |    64.0  |    66.0
      f16 16-512-16-128, p=0.0, BiasT=LowerTriangularMask  |   265.5  |   237.0  |   249.0

Times are in microseconds (us).

[------------------ attention (attn_bias=<class 'torch.Tensor'>) -----------------]
                                              |   nv1107   |   nv1106   |   nv1108 
1 threads: ------------------------------------------------------------------------
      f16 1-16384-16-80, p=0.0, BiasT=Tensor  |   27293.7  |   25950.0  |   26765.0
      f16 4-4096-16-80, p=0.0, BiasT=Tensor   |    6874.4  |    6529.0  |    6734.0
      f16 4-16384-16-80, p=0.0, BiasT=Tensor  |  108550.4  |  103304.0  |  106499.0
      f16 1-2048-8-128, p=0.0, BiasT=Tensor   |     292.3  |     280.0  |     277.0
      f16 2-4096-8-128, p=0.0, BiasT=Tensor   |    1802.5  |    1719.0  |    1763.0
      f16 16-512-16-16, p=0.0, BiasT=Tensor   |     308.8  |     288.0  |     302.0
      f16 16-512-16-32, p=0.0, BiasT=Tensor   |     321.6  |     303.0  |     313.0
      f16 16-512-16-64, p=0.0, BiasT=Tensor   |     347.9  |     326.0  |     340.0

Times are in microseconds (us).

fMHA BW - A100
[------------------------ attention backward (attn_bias=<class 'NoneType'>) -------------------------]
                                                                 |   nv1107   |   nv1106   |   nv1108 
1 threads: -------------------------------------------------------------------------------------------
      f16 384-197-1-88, p=0.0, BiasT=NoneType, BiasGrad=False    |     719.7  |     742.0  |     709.0
      f16 384-197-1-80, p=0.0, BiasT=NoneType, BiasGrad=False    |     695.0  |     713.0  |     680.0
      f16 384-197-1-64, p=0.0, BiasT=NoneType, BiasGrad=False    |     471.5  |     452.0  |     481.0
      f16 1024-197-1-88, p=0.0, BiasT=NoneType, BiasGrad=False   |    1845.3  |    1894.0  |    1817.0
      f16 1024-197-1-80, p=0.0, BiasT=NoneType, BiasGrad=False   |    1776.3  |    1812.0  |    1746.0
      f16 1024-197-1-64, p=0.0, BiasT=NoneType, BiasGrad=False   |    1050.9  |    1015.0  |    1073.0
      f16 512-197-1-80, p=0.0, BiasT=NoneType, BiasGrad=False    |     891.5  |     921.0  |     876.0
      f16 32-197-16-80, p=0.0, BiasT=NoneType, BiasGrad=False    |     888.1  |     907.0  |     877.0
      f16 32-197-16-64, p=0.0, BiasT=NoneType, BiasGrad=False    |     536.1  |     514.0  |     546.0
      f16 32-197-16-128, p=0.0, BiasT=NoneType, BiasGrad=False   |    1024.2  |    1073.0  |    1020.0
      f16 256-197-1-88, p=0.0, BiasT=NoneType, BiasGrad=False    |     519.8  |     533.0  |     512.0
      f16 16-197-16-88, p=0.0, BiasT=NoneType, BiasGrad=False    |     516.7  |     527.0  |     511.0
      f16 16-197-16-64, p=0.0, BiasT=NoneType, BiasGrad=False    |     274.7  |     286.0  |     279.0
      f16 16-197-16-128, p=0.0, BiasT=NoneType, BiasGrad=False   |     572.7  |     592.0  |     567.0
      f16 1024-82-8-64, p=0.0, BiasT=NoneType, BiasGrad=False    |    2173.3  |    2172.0  |    2172.0
      f16 150-256-16-64, p=0.0, BiasT=NoneType, BiasGrad=False   |    2501.7  |    2381.0  |    2595.0
      f16 64-256-12-64, p=0.0, BiasT=NoneType, BiasGrad=False    |     863.3  |     812.0  |     893.0
      f16 1-4096-16-40, p=0.0, BiasT=NoneType, BiasGrad=False    |   25996.1  |   24943.0  |   27316.0
      f16 1-16384-16-40, p=0.0, BiasT=NoneType, BiasGrad=False   |  434043.8  |  415297.0  |  460636.0
      f16 1-4096-16-80, p=0.0, BiasT=NoneType, BiasGrad=False    |   28160.4  |   25455.0  |   28207.0
      f16 1-16384-16-80, p=0.0, BiasT=NoneType, BiasGrad=False   |  446704.9  |  414644.0  |  451900.0
      f16 4-4096-16-40, p=0.0, BiasT=NoneType, BiasGrad=False    |   28960.8  |   27548.0  |   30517.0
      f16 4-16384-16-40, p=0.0, BiasT=NoneType, BiasGrad=False   |  474622.0  |  447150.0  |  495221.0
      f16 4-4096-16-80, p=0.0, BiasT=NoneType, BiasGrad=False    |   29966.7  |   28440.0  |   30272.0
      f16 4-16384-16-80, p=0.0, BiasT=NoneType, BiasGrad=False   |  480510.1  |  455357.0  |  485223.0
      f16 256-4096-16-64, p=0.0, BiasT=NoneType, BiasGrad=False  |  775476.6  |  748286.0  |  805287.0
      f16 8-2048-20-128, p=0.0, BiasT=NoneType, BiasGrad=False   |   17572.9  |   16505.0  |   17578.0
      f16 1-2048-4-128, p=0.0, BiasT=NoneType, BiasGrad=False    |    6615.7  |    5742.0  |    6532.0
      f16 1-2048-8-128, p=0.0, BiasT=NoneType, BiasGrad=False    |    6671.8  |    5935.0  |    6665.0
      f16 1-4096-4-128, p=0.0, BiasT=NoneType, BiasGrad=False    |   26242.8  |   22863.0  |   25923.0
      f16 1-4096-8-128, p=0.0, BiasT=NoneType, BiasGrad=False    |   28542.1  |   26200.0  |   28567.0
      f16 1-8192-4-128, p=0.0, BiasT=NoneType, BiasGrad=False    |  111636.2  |  100469.0  |  110541.0
      f16 1-8192-8-128, p=0.0, BiasT=NoneType, BiasGrad=False    |  114243.9  |  105447.0  |  114525.0
      f16 2-2048-4-128, p=0.0, BiasT=NoneType, BiasGrad=False    |    6670.3  |    5877.0  |    6681.0
      f16 2-2048-8-128, p=0.0, BiasT=NoneType, BiasGrad=False    |    7434.4  |    6773.0  |    7445.0
      f16 2-4096-4-128, p=0.0, BiasT=NoneType, BiasGrad=False    |   28506.0  |   26092.0  |   28532.0
      f16 2-4096-8-128, p=0.0, BiasT=NoneType, BiasGrad=False    |   29408.8  |   26915.0  |   29459.0
      f16 2-8192-4-128, p=0.0, BiasT=NoneType, BiasGrad=False    |  114038.3  |  105098.0  |  114181.0
      f16 2-8192-8-128, p=0.0, BiasT=NoneType, BiasGrad=False    |  116973.1  |  107327.0  |  117304.0
      f16 16-128-16-16, p=0.0, BiasT=NoneType, BiasGrad=False    |     234.7  |     249.0  |     237.0
      f16 16-128-16-32, p=0.0, BiasT=NoneType, BiasGrad=False    |     235.7  |     249.0  |     236.0
      f16 16-128-16-64, p=0.0, BiasT=NoneType, BiasGrad=False    |     233.8  |     249.0  |     235.0
      f16 16-128-16-128, p=0.0, BiasT=NoneType, BiasGrad=False   |     233.9  |     247.0  |     235.0
      f16 16-512-16-16, p=0.0, BiasT=NoneType, BiasGrad=False    |     703.3  |     667.0  |     720.0
      f16 16-512-16-32, p=0.0, BiasT=NoneType, BiasGrad=False    |     792.2  |     749.0  |     811.0
      f16 16-512-16-64, p=0.0, BiasT=NoneType, BiasGrad=False    |     981.5  |     948.0  |    1025.0
      f16 16-512-16-128, p=0.0, BiasT=NoneType, BiasGrad=False   |    2015.6  |    1887.0  |    2013.0
      f16 16-1024-16-16, p=0.0, BiasT=NoneType, BiasGrad=False   |    2739.5  |    2652.0  |    2827.0
      f16 16-1024-16-32, p=0.0, BiasT=NoneType, BiasGrad=False   |    3011.9  |    2877.0  |    3082.0
      f16 16-1024-16-64, p=0.0, BiasT=NoneType, BiasGrad=False   |    3552.1  |    3440.0  |    3720.0
      f16 16-1024-16-128, p=0.0, BiasT=NoneType, BiasGrad=False  |    7177.2  |    6764.0  |    7169.0
      f16 384-197-1-88, p=0.3, BiasT=NoneType, BiasGrad=False    |     954.9  |     993.0  |     983.0
      f16 384-197-1-64, p=0.3, BiasT=NoneType, BiasGrad=False    |     626.0  |     557.0  |     591.0
      f32 1024-197-1-88, p=0.0, BiasT=NoneType, BiasGrad=False   |    6209.0  |    6133.0  |    6182.0
      f32 1024-197-1-80, p=0.0, BiasT=NoneType, BiasGrad=False   |    5936.1  |    5877.0  |    5921.0
      f16 1024-197-1-64, p=0.3, BiasT=NoneType, BiasGrad=False   |    1392.0  |    1234.0  |    1284.0
      f16 32-197-16-80, p=0.3, BiasT=NoneType, BiasGrad=False    |    1184.9  |    1227.0  |    1213.0
      b16 256-197-1-88, p=0.0, BiasT=NoneType, BiasGrad=False    |     506.1  |     528.0  |     493.0
      f16 16-197-16-88, p=0.3, BiasT=NoneType, BiasGrad=False    |     682.0  |     694.0  |     696.0
      f16 150-256-16-64, p=0.3, BiasT=NoneType, BiasGrad=False   |    2823.3  |    2574.0  |    2706.0
      f16 1-16384-16-40, p=0.3, BiasT=NoneType, BiasGrad=False   |  623139.4  |  569835.0  |  617652.0
      f32 1-4096-16-80, p=0.0, BiasT=NoneType, BiasGrad=False    |   94343.5  |   99903.0  |   94049.0
      b16 8-2048-20-128, p=0.0, BiasT=NoneType, BiasGrad=False   |   17279.3  |   15883.0  |   17310.0
      b16 1-4096-4-128, p=0.0, BiasT=NoneType, BiasGrad=False    |   25646.0  |   22230.0  |   25699.0
      b16 1-4096-8-128, p=0.0, BiasT=NoneType, BiasGrad=False    |   27548.3  |   24661.0  |   27513.0
      f32 1-8192-4-128, p=0.0, BiasT=NoneType, BiasGrad=False    |  404296.2  |  427443.0  |  403264.0
      f16 2-2048-8-128, p=0.3, BiasT=NoneType, BiasGrad=False    |   10368.9  |    9233.0  |    9862.0
      b16 2-4096-4-128, p=0.0, BiasT=NoneType, BiasGrad=False    |   27478.5  |   24591.0  |   27440.0
      b16 2-8192-8-128, p=0.0, BiasT=NoneType, BiasGrad=False    |  112630.8  |  100655.0  |  112807.0
      f32 16-128-16-16, p=0.0, BiasT=NoneType, BiasGrad=False    |     288.0  |     293.0  |     295.0
      b16 16-128-16-128, p=0.0, BiasT=NoneType, BiasGrad=False   |     244.9  |     246.0  |     251.0
      b16 16-1024-16-16, p=0.0, BiasT=NoneType, BiasGrad=False   |    2692.4  |    2507.0  |    2805.0
      f16 16-1024-16-32, p=0.3, BiasT=NoneType, BiasGrad=False   |    3835.8  |    3431.0  |    3760.0
      b16 16-1024-16-64, p=0.0, BiasT=NoneType, BiasGrad=False   |    3550.4  |    3289.0  |    3602.0
      b16 16-1024-16-128, p=0.0, BiasT=NoneType, BiasGrad=False  |    7034.5  |    6519.0  |    7003.0

Times are in microseconds (us).

[--------- attention backward (attn_bias=<class 'xformers.ops.fmha.attn_bias.LowerTriangularMask'>) --------]
                                                                           |   nv1107  |   nv1106  |   nv1108
1 threads: --------------------------------------------------------------------------------------------------
      f16 384-197-1-80, p=0.0, BiasT=LowerTriangularMask, BiasGrad=False   |    565.8  |    558.0  |    552.0
      f16 32-197-16-128, p=0.0, BiasT=LowerTriangularMask, BiasGrad=False  |    849.6  |    861.0  |    842.0
      f16 16-197-16-128, p=0.0, BiasT=LowerTriangularMask, BiasGrad=False  |    474.8  |    473.0  |    469.0
      f16 1-4096-16-40, p=0.0, BiasT=LowerTriangularMask, BiasGrad=False   |  12318.2  |  11579.0  |  12737.0
      f16 4-4096-16-40, p=0.0, BiasT=LowerTriangularMask, BiasGrad=False   |  15088.9  |  13942.0  |  15774.0
      f16 2-8192-4-128, p=0.0, BiasT=LowerTriangularMask, BiasGrad=False   |  57975.0  |  53121.0  |  58034.0
      f16 16-128-16-64, p=0.0, BiasT=LowerTriangularMask, BiasGrad=False   |    235.1  |    240.0  |    244.0
      f16 16-512-16-128, p=0.0, BiasT=LowerTriangularMask, BiasGrad=False  |   1416.3  |   1322.0  |   1408.0

Times are in microseconds (us).

[--------------------- attention backward (attn_bias=<class 'torch.Tensor'>) ---------------------]
                                                              |   nv1107   |   nv1106   |   nv1108 
1 threads: ----------------------------------------------------------------------------------------
      f16 64-256-12-64, p=0.0, BiasT=Tensor, BiasGrad=True    |     922.6  |     873.0  |     950.0
      f16 1-16384-16-80, p=0.0, BiasT=Tensor, BiasGrad=False  |  487628.1  |  452192.0  |  489676.0
      f16 4-16384-16-40, p=0.0, BiasT=Tensor, BiasGrad=True   |  536086.3  |  501462.0  |  548997.0
      f16 4-4096-16-80, p=0.0, BiasT=Tensor, BiasGrad=False   |   32934.7  |   31202.0  |   33046.0
      f16 4-16384-16-80, p=0.0, BiasT=Tensor, BiasGrad=False  |  524850.8  |  498805.0  |  526575.0
      f16 256-4096-16-64, p=0.0, BiasT=Tensor, BiasGrad=True  |  816832.0  |  782341.0  |  845349.0
      f16 1-2048-4-128, p=0.0, BiasT=Tensor, BiasGrad=True    |    7205.3  |    6287.0  |    7149.0
      f16 1-2048-8-128, p=0.0, BiasT=Tensor, BiasGrad=False   |    7259.8  |    6509.0  |    7269.0
      f16 1-8192-8-128, p=0.0, BiasT=Tensor, BiasGrad=True    |  123647.3  |  114378.0  |  124246.0
      f16 2-2048-4-128, p=0.0, BiasT=Tensor, BiasGrad=True    |    7274.7  |    6474.0  |    7292.0
      f16 2-4096-8-128, p=0.0, BiasT=Tensor, BiasGrad=False   |   31912.1  |   29229.0  |   31894.0
      f16 16-128-16-32, p=0.0, BiasT=Tensor, BiasGrad=True    |     245.3  |     249.0  |     253.0
      f16 16-512-16-16, p=0.0, BiasT=Tensor, BiasGrad=False   |     785.2  |     730.0  |     799.0
      f16 16-512-16-32, p=0.0, BiasT=Tensor, BiasGrad=False   |     867.6  |     814.0  |     879.0
      f16 16-512-16-64, p=0.0, BiasT=Tensor, BiasGrad=False   |    1051.0  |    1002.0  |    1088.0

Times are in microseconds (us).

We also expect to release a new version, 0.0.19, once this is fixed, so that our pre-built binaries have the best performance.

danthe3rd pinned this issue Mar 30, 2023
danthe3rd changed the title from "nvcc version and performance" to "[0.0.17] nvcc version and performance" Mar 30, 2023
danthe3rd added the "bug" label Mar 30, 2023
danthe3rd self-assigned this Mar 30, 2023
danthe3rd changed the title from "[0.0.17] nvcc version and performance" to "[0.0.17 / 0.0.18] nvcc version and performance" Mar 30, 2023
facebook-github-bot pushed a commit that referenced this issue Apr 24, 2023
 This helps mitigate a performance regression in nvcc>11.6.
nvcc 11.8 still performs worse than 11.6, but it's not that bad now

See #712

__original_commit__ = fairinternal/xformers@42d55eb5f438ec6907836fbd22056a50076f14d5
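
The commit itself is not reproduced above. As a hypothetical illustration only, forwarding options to ptxas from a PyTorch `CUDAExtension` generally looks like the sketch below; the extension name, sources, and the specific `-Xptxas` flag are placeholders, not necessarily what the fix uses:

```python
# Hypothetical illustration: the extension name, sources, and the exact ptxas
# flag are placeholders, not the actual change from the commit above.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="example_ext",
    ext_modules=[
        CUDAExtension(
            name="example_ext",
            sources=["example_kernel.cu"],
            extra_compile_args={
                "cxx": ["-O3"],
                # `-Xptxas <flag>` forwards <flag> to the PTX assembler; `-v`
                # (a register/spill usage report) stands in here for whatever
                # tuning flag the actual fix passes.
                "nvcc": ["-O3", "-Xptxas", "-v"],
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```
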
danthe3rd (Contributor, Author) commented:

This fixes this issue: 1c73b40

@danthe3rd danthe3rd unpinned this issue May 11, 2023
tmm1 (Contributor) commented Aug 3, 2023

Are you able to share post-fix benchmarks, since the commit says it doesn't quite achieve 11.6's performance?

Was this determined to be something specific to xformers' usage of CUDA, or is it something recommended and used across Facebook projects?

If building flash-attention with these same options would be faster, that may be worth doing now that v2.0.4 fixed Dao-AILab/flash-attention#359.

tmm1 mentioned this issue Aug 3, 2023
danthe3rd (Contributor, Author) commented Aug 17, 2023

@tmm1 I don't have the post-fix benchmarks anymore, unfortunately. The issue was that the PTX optimizer (ptxas) was saving intermediate calculations (instead of recalculating them), which increased register use and even caused registers to spill to global memory, hence a big slowdown.
I don't think we have register spilling in flash-attention (you can see it in the log when compiling), so I don't think it matters for flash.
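
For anyone wanting to check a build for spills, a minimal sketch of surfacing that compilation log (the `.cu` path and arch below are placeholders, and `nvcc` is assumed to be on PATH):

```python
# Sketch: compile one CUDA file with `-Xptxas -v` so ptxas reports per-kernel
# register counts and spill stores/loads; non-zero spill counts are the
# symptom described above.
import subprocess

result = subprocess.run(
    ["nvcc", "-c", "kernel.cu", "-arch=sm_80", "-Xptxas", "-v"],
    capture_output=True,
    text=True,
)
# ptxas prints its resource-usage report to stderr, e.g.:
#   ptxas info : Used 255 registers, 16 bytes spill stores, 16 bytes spill loads
print(result.stderr)
```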
