CUDA: optimize MMQ int8 tensor core performance #8062

Merged: 3 commits into ggerganov:master on Jun 24, 2024

Conversation

JohannesGaessler (Collaborator)

This PR adds the following optimizations for the CUDA MMQ kernels using int8 tensor cores:

  • After CUDA: stream-k decomposition for MMQ #8018, the restrictions on tile sizes have become much looser, so I changed the tile sizes for >= 48 tokens in such a way that less data needs to be loaded from shared memory per tensor core operation. This requires the tile size in terms of parallel tokens to be a multiple of 16 (instead of 8). As a side effect, fewer kernel versions need to be compiled, so compile time and binary size go down.
  • I changed the shared memory layout for the quantized data so that there are fewer shared memory bank conflicts when loading it.
  • I exposed the ldmatrix PTX instruction to load data in blocks of 16 bytes instead of 4 (a sketch follows this list).
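
For illustration, here is a minimal sketch of what exposing ldmatrix via inline PTX can look like. The function name, fragment shape, and addressing below are assumptions made for the sketch, not necessarily the exact code in this PR; the point is that ldmatrix.x4 hands each thread four 32-bit registers (16 bytes) in a single instruction instead of four separate 4-byte shared memory loads, and that the row stride can include padding as one way to reduce bank conflicts:

```cuda
// Minimal sketch, assuming Turing or newer (__CUDA_ARCH__ >= 750).
// Loads a tile fragment from shared memory: each thread ends up with four
// 32-bit registers (16 bytes) from a single ldmatrix instruction.
static __device__ __forceinline__ void load_ldmatrix_sketch(
        int * x,                      // out: 4 ints per thread (the fragment)
        const int * __restrict__ xs0, // in:  start of the tile in shared memory
        const int stride) {           // in:  row stride in ints (may include padding
                                      //      to reduce shared memory bank conflicts)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    // One address per thread: threads 0-15 point at the left 16-byte half of
    // rows 0-15, threads 16-31 at the right half (hypothetical tile layout).
    const int * xs = xs0 + (threadIdx.x % 16) * stride + (threadIdx.x / 16) * 4;
    asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
        : "=r"(x[0]), "=r"(x[1]), "=r"(x[2]), "=r"(x[3])
        : "l"(xs));
#else
    // Without ldmatrix support, a fallback of plain 4-byte shared memory
    // loads would be used instead (omitted here).
    (void) x; (void) xs0; (void) stride;
#endif
}
```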
Performance vs. master MMQ
| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-mmq-2xa-3 | Speedup |
|---|---|---|---|---|---|---|
| RTX 4090 | llama 8B Q2_K_M | 16 | pp2048 | 1959.31 | 2058.73 | 1.05 |
| RTX 4090 | llama 8B Q2_K_M | 32 | pp2048 | 3041.88 | 3093.97 | 1.02 |
| RTX 4090 | llama 8B Q2_K_M | 63 | pp2048 | 4327.05 | 4566.56 | 1.06 |
| RTX 4090 | llama 8B Q2_K_M | 128 | pp2048 | 5779.88 | 6088.58 | 1.05 |
| RTX 4090 | llama 8B Q2_K_M | 256 | pp2048 | 7145.78 | 7726.39 | 1.08 |
| RTX 4090 | llama 8B Q2_K_M | 512 | pp2048 | 7553.00 | 8194.04 | 1.08 |
| RTX 4090 | llama 8B Q2_K_M | 1024 | pp2048 | 7458.42 | 8046.76 | 1.08 |
| RTX 4090 | llama 8B Q2_K_M | 2048 | pp2048 | 6906.43 | 7424.69 | 1.08 |
| RTX 4090 | llama 8B Q3_K_S | 16 | pp2048 | 1777.22 | 1945.13 | 1.09 |
| RTX 4090 | llama 8B Q3_K_S | 32 | pp2048 | 3015.02 | 3090.91 | 1.03 |
| RTX 4090 | llama 8B Q3_K_S | 63 | pp2048 | 4620.00 | 4446.98 | 0.96 |
| RTX 4090 | llama 8B Q3_K_S | 128 | pp2048 | 6713.58 | 6592.28 | 0.98 |
| RTX 4090 | llama 8B Q3_K_S | 256 | pp2048 | 8489.05 | 8324.30 | 0.98 |
| RTX 4090 | llama 8B Q3_K_S | 512 | pp2048 | 8968.09 | 8768.22 | 0.98 |
| RTX 4090 | llama 8B Q3_K_S | 1024 | pp2048 | 8819.05 | 8589.54 | 0.97 |
| RTX 4090 | llama 8B Q3_K_S | 2048 | pp2048 | 8020.62 | 7882.24 | 0.98 |
| RTX 4090 | llama 8B Q4_0 | 16 | pp2048 | 1970.65 | 1988.00 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 32 | pp2048 | 3431.49 | 3445.99 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 63 | pp2048 | 5357.19 | 5284.59 | 0.99 |
| RTX 4090 | llama 8B Q4_0 | 128 | pp2048 | 7509.03 | 7502.92 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 256 | pp2048 | 9609.65 | 9726.08 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp2048 | 10285.00 | 10683.60 | 1.04 |
| RTX 4090 | llama 8B Q4_0 | 1024 | pp2048 | 10018.49 | 10364.80 | 1.03 |
| RTX 4090 | llama 8B Q4_0 | 2048 | pp2048 | 9155.64 | 9417.97 | 1.03 |
| RTX 4090 | llama 8B Q4_1 | 16 | pp2048 | 1862.71 | 1874.43 | 1.01 |
| RTX 4090 | llama 8B Q4_1 | 32 | pp2048 | 2979.06 | 3018.03 | 1.01 |
| RTX 4090 | llama 8B Q4_1 | 63 | pp2048 | 5145.03 | 5312.52 | 1.03 |
| RTX 4090 | llama 8B Q4_1 | 128 | pp2048 | 7229.84 | 7462.47 | 1.03 |
| RTX 4090 | llama 8B Q4_1 | 256 | pp2048 | 9130.95 | 9438.28 | 1.03 |
| RTX 4090 | llama 8B Q4_1 | 512 | pp2048 | 9747.77 | 10059.40 | 1.03 |
| RTX 4090 | llama 8B Q4_1 | 1024 | pp2048 | 9545.98 | 9830.56 | 1.03 |
| RTX 4090 | llama 8B Q4_1 | 2048 | pp2048 | 8699.96 | 8906.72 | 1.02 |
| RTX 4090 | llama 8B Q4_K_S | 16 | pp2048 | 1982.75 | 1987.60 | 1.00 |
| RTX 4090 | llama 8B Q4_K_S | 32 | pp2048 | 3361.63 | 3500.66 | 1.04 |
| RTX 4090 | llama 8B Q4_K_S | 63 | pp2048 | 5149.95 | 5331.79 | 1.04 |
| RTX 4090 | llama 8B Q4_K_S | 128 | pp2048 | 6994.57 | 7288.84 | 1.04 |
| RTX 4090 | llama 8B Q4_K_S | 256 | pp2048 | 8788.43 | 9232.26 | 1.05 |
| RTX 4090 | llama 8B Q4_K_S | 512 | pp2048 | 9403.70 | 9877.78 | 1.05 |
| RTX 4090 | llama 8B Q4_K_S | 1024 | pp2048 | 9294.39 | 9741.86 | 1.05 |
| RTX 4090 | llama 8B Q4_K_S | 2048 | pp2048 | 8552.56 | 8897.37 | 1.04 |
| RTX 4090 | llama 8B Q5_0 | 16 | pp2048 | 1660.02 | 1713.62 | 1.03 |
| RTX 4090 | llama 8B Q5_0 | 32 | pp2048 | 3029.92 | 3073.33 | 1.01 |
| RTX 4090 | llama 8B Q5_0 | 63 | pp2048 | 4691.46 | 5140.13 | 1.10 |
| RTX 4090 | llama 8B Q5_0 | 128 | pp2048 | 6847.64 | 7470.47 | 1.09 |
| RTX 4090 | llama 8B Q5_0 | 256 | pp2048 | 8927.18 | 9696.99 | 1.09 |
| RTX 4090 | llama 8B Q5_0 | 512 | pp2048 | 9707.85 | 10536.44 | 1.09 |
| RTX 4090 | llama 8B Q5_0 | 1024 | pp2048 | 9529.31 | 10299.92 | 1.08 |
| RTX 4090 | llama 8B Q5_0 | 2048 | pp2048 | 8725.90 | 9363.21 | 1.07 |
| RTX 4090 | llama 8B Q5_1 | 16 | pp2048 | 1572.10 | 1627.59 | 1.04 |
| RTX 4090 | llama 8B Q5_1 | 32 | pp2048 | 2527.92 | 2727.62 | 1.08 |
| RTX 4090 | llama 8B Q5_1 | 63 | pp2048 | 4273.50 | 4395.52 | 1.03 |
| RTX 4090 | llama 8B Q5_1 | 128 | pp2048 | 6769.11 | 7198.66 | 1.06 |
| RTX 4090 | llama 8B Q5_1 | 256 | pp2048 | 8493.96 | 9079.53 | 1.07 |
| RTX 4090 | llama 8B Q5_1 | 512 | pp2048 | 9064.31 | 9827.46 | 1.08 |
| RTX 4090 | llama 8B Q5_1 | 1024 | pp2048 | 9008.68 | 9696.96 | 1.08 |
| RTX 4090 | llama 8B Q5_1 | 2048 | pp2048 | 8304.00 | 8897.12 | 1.07 |
| RTX 4090 | llama 8B Q5_K_S | 16 | pp2048 | 1690.31 | 1717.22 | 1.02 |
| RTX 4090 | llama 8B Q5_K_S | 32 | pp2048 | 2760.78 | 3021.09 | 1.09 |
| RTX 4090 | llama 8B Q5_K_S | 63 | pp2048 | 4415.75 | 4883.11 | 1.11 |
| RTX 4090 | llama 8B Q5_K_S | 128 | pp2048 | 6646.87 | 6905.05 | 1.04 |
| RTX 4090 | llama 8B Q5_K_S | 256 | pp2048 | 8316.00 | 8707.63 | 1.05 |
| RTX 4090 | llama 8B Q5_K_S | 512 | pp2048 | 8924.37 | 9348.09 | 1.05 |
| RTX 4090 | llama 8B Q5_K_S | 1024 | pp2048 | 8858.57 | 9265.80 | 1.05 |
| RTX 4090 | llama 8B Q5_K_S | 2048 | pp2048 | 8203.50 | 8525.00 | 1.04 |
| RTX 4090 | llama 8B Q6_K | 16 | pp2048 | 1445.48 | 1466.03 | 1.01 |
| RTX 4090 | llama 8B Q6_K | 32 | pp2048 | 2736.55 | 2774.79 | 1.01 |
| RTX 4090 | llama 8B Q6_K | 63 | pp2048 | 4473.27 | 4541.41 | 1.02 |
| RTX 4090 | llama 8B Q6_K | 128 | pp2048 | 6516.45 | 6720.65 | 1.03 |
| RTX 4090 | llama 8B Q6_K | 256 | pp2048 | 8328.57 | 8696.60 | 1.04 |
| RTX 4090 | llama 8B Q6_K | 512 | pp2048 | 9004.15 | 9461.48 | 1.05 |
| RTX 4090 | llama 8B Q6_K | 1024 | pp2048 | 8816.76 | 9285.52 | 1.05 |
| RTX 4090 | llama 8B Q6_K | 2048 | pp2048 | 8056.57 | 8451.71 | 1.05 |
| RTX 4090 | llama 8B Q8_0 | 16 | pp2048 | 1271.09 | 1264.01 | 0.99 |
| RTX 4090 | llama 8B Q8_0 | 32 | pp2048 | 2362.24 | 2369.64 | 1.00 |
| RTX 4090 | llama 8B Q8_0 | 63 | pp2048 | 4111.23 | 4141.30 | 1.01 |
| RTX 4090 | llama 8B Q8_0 | 128 | pp2048 | 6820.92 | 6870.90 | 1.01 |
| RTX 4090 | llama 8B Q8_0 | 256 | pp2048 | 9500.61 | 9903.76 | 1.04 |
| RTX 4090 | llama 8B Q8_0 | 512 | pp2048 | 10607.57 | 11209.17 | 1.06 |
| RTX 4090 | llama 8B Q8_0 | 1024 | pp2048 | 10481.10 | 11106.60 | 1.06 |
| RTX 4090 | llama 8B Q8_0 | 2048 | pp2048 | 9610.52 | 10098.05 | 1.05 |
| RTX 3090 | llama 8B Q2_K_M | 16 | pp2048 | 1048.92 | 1083.08 | 1.03 |
| RTX 3090 | llama 8B Q2_K_M | 32 | pp2048 | 1535.68 | 1531.48 | 1.00 |
| RTX 3090 | llama 8B Q2_K_M | 63 | pp2048 | 2053.18 | 2165.83 | 1.05 |
| RTX 3090 | llama 8B Q2_K_M | 128 | pp2048 | 2477.58 | 2625.68 | 1.06 |
| RTX 3090 | llama 8B Q2_K_M | 256 | pp2048 | 2753.83 | 2906.06 | 1.06 |
| RTX 3090 | llama 8B Q2_K_M | 512 | pp2048 | 2848.71 | 3033.38 | 1.06 |
| RTX 3090 | llama 8B Q2_K_M | 1024 | pp2048 | 2885.83 | 3051.31 | 1.06 |
| RTX 3090 | llama 8B Q2_K_M | 2048 | pp2048 | 2831.48 | 3008.83 | 1.06 |
| RTX 3090 | llama 8B Q3_K_S | 16 | pp2048 | 981.92 | 1109.08 | 1.13 |
| RTX 3090 | llama 8B Q3_K_S | 32 | pp2048 | 1583.54 | 1595.45 | 1.01 |
| RTX 3090 | llama 8B Q3_K_S | 63 | pp2048 | 2242.77 | 2119.10 | 0.94 |
| RTX 3090 | llama 8B Q3_K_S | 128 | pp2048 | 2957.15 | 2854.64 | 0.97 |
| RTX 3090 | llama 8B Q3_K_S | 256 | pp2048 | 3296.41 | 3161.27 | 0.96 |
| RTX 3090 | llama 8B Q3_K_S | 512 | pp2048 | 3400.63 | 3258.34 | 0.96 |
| RTX 3090 | llama 8B Q3_K_S | 1024 | pp2048 | 3455.97 | 3309.41 | 0.96 |
| RTX 3090 | llama 8B Q3_K_S | 2048 | pp2048 | 3374.90 | 3232.75 | 0.96 |
| RTX 3090 | llama 8B Q4_0 | 16 | pp2048 | 1238.54 | 1301.00 | 1.05 |
| RTX 3090 | llama 8B Q4_0 | 32 | pp2048 | 1929.33 | 1940.42 | 1.01 |
| RTX 3090 | llama 8B Q4_0 | 63 | pp2048 | 2727.29 | 2722.48 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 128 | pp2048 | 3469.76 | 3491.56 | 1.01 |
| RTX 3090 | llama 8B Q4_0 | 256 | pp2048 | 3885.53 | 3924.80 | 1.01 |
| RTX 3090 | llama 8B Q4_0 | 512 | pp2048 | 4030.46 | 4124.14 | 1.02 |
| RTX 3090 | llama 8B Q4_0 | 1024 | pp2048 | 4043.94 | 4138.45 | 1.02 |
| RTX 3090 | llama 8B Q4_0 | 2048 | pp2048 | 3911.72 | 4000.28 | 1.02 |
| RTX 3090 | llama 8B Q4_1 | 16 | pp2048 | 1371.83 | 1409.82 | 1.03 |
| RTX 3090 | llama 8B Q4_1 | 32 | pp2048 | 1794.42 | 1754.93 | 0.98 |
| RTX 3090 | llama 8B Q4_1 | 63 | pp2048 | 2586.43 | 2676.11 | 1.03 |
| RTX 3090 | llama 8B Q4_1 | 128 | pp2048 | 3271.37 | 3354.55 | 1.03 |
| RTX 3090 | llama 8B Q4_1 | 256 | pp2048 | 3625.96 | 3734.81 | 1.03 |
| RTX 3090 | llama 8B Q4_1 | 512 | pp2048 | 3763.37 | 3876.53 | 1.03 |
| RTX 3090 | llama 8B Q4_1 | 1024 | pp2048 | 3791.94 | 3884.89 | 1.02 |
| RTX 3090 | llama 8B Q4_1 | 2048 | pp2048 | 3650.00 | 3776.34 | 1.03 |
| RTX 3090 | llama 8B Q4_K_S | 16 | pp2048 | 1304.90 | 1343.82 | 1.03 |
| RTX 3090 | llama 8B Q4_K_S | 32 | pp2048 | 1820.00 | 1998.71 | 1.10 |
| RTX 3090 | llama 8B Q4_K_S | 63 | pp2048 | 2491.79 | 2656.05 | 1.07 |
| RTX 3090 | llama 8B Q4_K_S | 128 | pp2048 | 3096.96 | 3205.98 | 1.04 |
| RTX 3090 | llama 8B Q4_K_S | 256 | pp2048 | 3449.95 | 3588.17 | 1.04 |
| RTX 3090 | llama 8B Q4_K_S | 512 | pp2048 | 3600.86 | 3736.54 | 1.04 |
| RTX 3090 | llama 8B Q4_K_S | 1024 | pp2048 | 3645.23 | 3780.41 | 1.04 |
| RTX 3090 | llama 8B Q4_K_S | 2048 | pp2048 | 3568.97 | 3683.96 | 1.03 |
| RTX 3090 | llama 8B Q5_0 | 16 | pp2048 | 1023.34 | 1094.85 | 1.07 |
| RTX 3090 | llama 8B Q5_0 | 32 | pp2048 | 1747.94 | 1804.44 | 1.03 |
| RTX 3090 | llama 8B Q5_0 | 63 | pp2048 | 2392.03 | 2744.89 | 1.15 |
| RTX 3090 | llama 8B Q5_0 | 128 | pp2048 | 3233.48 | 3444.51 | 1.07 |
| RTX 3090 | llama 8B Q5_0 | 256 | pp2048 | 3619.52 | 3870.99 | 1.07 |
| RTX 3090 | llama 8B Q5_0 | 512 | pp2048 | 3765.05 | 4056.88 | 1.08 |
| RTX 3090 | llama 8B Q5_0 | 1024 | pp2048 | 3769.53 | 4098.97 | 1.09 |
| RTX 3090 | llama 8B Q5_0 | 2048 | pp2048 | 3664.35 | 3941.37 | 1.08 |
| RTX 3090 | llama 8B Q5_1 | 16 | pp2048 | 1041.17 | 1125.51 | 1.08 |
| RTX 3090 | llama 8B Q5_1 | 32 | pp2048 | 1508.22 | 1632.28 | 1.08 |
| RTX 3090 | llama 8B Q5_1 | 63 | pp2048 | 2280.39 | 2443.51 | 1.07 |
| RTX 3090 | llama 8B Q5_1 | 128 | pp2048 | 3048.77 | 3207.34 | 1.05 |
| RTX 3090 | llama 8B Q5_1 | 256 | pp2048 | 3377.53 | 3569.53 | 1.06 |
| RTX 3090 | llama 8B Q5_1 | 512 | pp2048 | 3517.54 | 3743.99 | 1.06 |
| RTX 3090 | llama 8B Q5_1 | 1024 | pp2048 | 3547.11 | 3794.01 | 1.07 |
| RTX 3090 | llama 8B Q5_1 | 2048 | pp2048 | 3431.20 | 3676.47 | 1.07 |
| RTX 3090 | llama 8B Q5_K_S | 16 | pp2048 | 1083.87 | 1143.45 | 1.05 |
| RTX 3090 | llama 8B Q5_K_S | 32 | pp2048 | 1562.63 | 1737.00 | 1.11 |
| RTX 3090 | llama 8B Q5_K_S | 63 | pp2048 | 2212.98 | 2400.75 | 1.08 |
| RTX 3090 | llama 8B Q5_K_S | 128 | pp2048 | 2934.14 | 3021.05 | 1.03 |
| RTX 3090 | llama 8B Q5_K_S | 256 | pp2048 | 3284.45 | 3375.86 | 1.03 |
| RTX 3090 | llama 8B Q5_K_S | 512 | pp2048 | 3412.32 | 3532.41 | 1.04 |
| RTX 3090 | llama 8B Q5_K_S | 1024 | pp2048 | 3453.93 | 3553.91 | 1.03 |
| RTX 3090 | llama 8B Q5_K_S | 2048 | pp2048 | 3387.06 | 3478.17 | 1.03 |
| RTX 3090 | llama 8B Q6_K | 16 | pp2048 | 991.17 | 1011.44 | 1.02 |
| RTX 3090 | llama 8B Q6_K | 32 | pp2048 | 1660.10 | 1715.38 | 1.03 |
| RTX 3090 | llama 8B Q6_K | 63 | pp2048 | 2385.54 | 2456.09 | 1.03 |
| RTX 3090 | llama 8B Q6_K | 128 | pp2048 | 2957.59 | 3075.06 | 1.04 |
| RTX 3090 | llama 8B Q6_K | 256 | pp2048 | 3289.25 | 3459.49 | 1.05 |
| RTX 3090 | llama 8B Q6_K | 512 | pp2048 | 3433.88 | 3596.13 | 1.05 |
| RTX 3090 | llama 8B Q6_K | 1024 | pp2048 | 3467.94 | 3646.15 | 1.05 |
| RTX 3090 | llama 8B Q6_K | 2048 | pp2048 | 3385.46 | 3539.88 | 1.05 |
| RTX 3090 | llama 8B Q8_0 | 16 | pp2048 | 982.24 | 995.66 | 1.01 |
| RTX 3090 | llama 8B Q8_0 | 32 | pp2048 | 1706.93 | 1750.15 | 1.03 |
| RTX 3090 | llama 8B Q8_0 | 63 | pp2048 | 2597.08 | 2707.43 | 1.04 |
| RTX 3090 | llama 8B Q8_0 | 128 | pp2048 | 3516.55 | 3651.09 | 1.04 |
| RTX 3090 | llama 8B Q8_0 | 256 | pp2048 | 4043.19 | 4205.97 | 1.04 |
| RTX 3090 | llama 8B Q8_0 | 512 | pp2048 | 4209.31 | 4425.70 | 1.05 |
| RTX 3090 | llama 8B Q8_0 | 1024 | pp2048 | 4211.55 | 4466.05 | 1.06 |
| RTX 3090 | llama 8B Q8_0 | 2048 | pp2048 | 4070.12 | 4294.37 | 1.06 |
Performance vs. master FP16 cuBLAS
| GPU | Model | Microbatch size | Test | t/s master | t/s cuda-mmq-2xa-3 | Speedup |
|---|---|---|---|---|---|---|
| RTX 4090 | llama 8B Q2_K_M | 16 | pp2048 | 1958.10 | 2056.25 | 1.05 |
| RTX 4090 | llama 8B Q2_K_M | 32 | pp2048 | 3042.70 | 3088.43 | 1.02 |
| RTX 4090 | llama 8B Q2_K_M | 63 | pp2048 | 4365.36 | 4559.46 | 1.04 |
| RTX 4090 | llama 8B Q2_K_M | 128 | pp2048 | 3645.77 | 6110.00 | 1.68 |
| RTX 4090 | llama 8B Q2_K_M | 256 | pp2048 | 5901.11 | 7720.55 | 1.31 |
| RTX 4090 | llama 8B Q2_K_M | 512 | pp2048 | 7772.35 | 8107.88 | 1.04 |
| RTX 4090 | llama 8B Q2_K_M | 1024 | pp2048 | 9004.92 | 8059.08 | 0.89 |
| RTX 4090 | llama 8B Q2_K_M | 2048 | pp2048 | 8893.63 | 7388.58 | 0.83 |
| RTX 4090 | llama 8B Q3_K_S | 16 | pp2048 | 1768.07 | 1938.69 | 1.10 |
| RTX 4090 | llama 8B Q3_K_S | 32 | pp2048 | 3001.33 | 3079.55 | 1.03 |
| RTX 4090 | llama 8B Q3_K_S | 63 | pp2048 | 4611.69 | 4442.86 | 0.96 |
| RTX 4090 | llama 8B Q3_K_S | 128 | pp2048 | 3543.05 | 6586.41 | 1.86 |
| RTX 4090 | llama 8B Q3_K_S | 256 | pp2048 | 5776.37 | 8308.35 | 1.44 |
| RTX 4090 | llama 8B Q3_K_S | 512 | pp2048 | 7675.56 | 8721.77 | 1.14 |
| RTX 4090 | llama 8B Q3_K_S | 1024 | pp2048 | 8966.03 | 8576.01 | 0.96 |
| RTX 4090 | llama 8B Q3_K_S | 2048 | pp2048 | 8880.83 | 7863.38 | 0.89 |
| RTX 4090 | llama 8B Q4_0 | 16 | pp2048 | 1968.84 | 1982.97 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 32 | pp2048 | 3408.90 | 3447.84 | 1.01 |
| RTX 4090 | llama 8B Q4_0 | 63 | pp2048 | 5340.36 | 5289.94 | 0.99 |
| RTX 4090 | llama 8B Q4_0 | 128 | pp2048 | 3476.02 | 7461.56 | 2.15 |
| RTX 4090 | llama 8B Q4_0 | 256 | pp2048 | 5746.13 | 9681.05 | 1.68 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp2048 | 7762.86 | 10600.17 | 1.37 |
| RTX 4090 | llama 8B Q4_0 | 1024 | pp2048 | 9046.57 | 10342.53 | 1.14 |
| RTX 4090 | llama 8B Q4_0 | 2048 | pp2048 | 9023.52 | 9366.15 | 1.04 |
| RTX 4090 | llama 8B Q4_1 | 16 | pp2048 | 1861.62 | 1878.00 | 1.01 |
| RTX 4090 | llama 8B Q4_1 | 32 | pp2048 | 2968.46 | 3015.56 | 1.02 |
| RTX 4090 | llama 8B Q4_1 | 63 | pp2048 | 5127.76 | 5287.13 | 1.03 |
| RTX 4090 | llama 8B Q4_1 | 128 | pp2048 | 3467.76 | 7429.16 | 2.14 |
| RTX 4090 | llama 8B Q4_1 | 256 | pp2048 | 5686.38 | 9363.62 | 1.65 |
| RTX 4090 | llama 8B Q4_1 | 512 | pp2048 | 7623.50 | 10010.90 | 1.31 |
| RTX 4090 | llama 8B Q4_1 | 1024 | pp2048 | 8924.97 | 9762.37 | 1.09 |
| RTX 4090 | llama 8B Q4_1 | 2048 | pp2048 | 8823.83 | 8862.15 | 1.00 |
| RTX 4090 | llama 8B Q4_K_S | 16 | pp2048 | 1982.94 | 1999.60 | 1.01 |
| RTX 4090 | llama 8B Q4_K_S | 32 | pp2048 | 3368.79 | 3518.83 | 1.04 |
| RTX 4090 | llama 8B Q4_K_S | 63 | pp2048 | 5144.50 | 5334.99 | 1.04 |
| RTX 4090 | llama 8B Q4_K_S | 128 | pp2048 | 3475.47 | 7283.99 | 2.10 |
| RTX 4090 | llama 8B Q4_K_S | 256 | pp2048 | 5745.24 | 9216.08 | 1.60 |
| RTX 4090 | llama 8B Q4_K_S | 512 | pp2048 | 7637.48 | 9867.40 | 1.29 |
| RTX 4090 | llama 8B Q4_K_S | 1024 | pp2048 | 8954.20 | 9664.16 | 1.08 |
| RTX 4090 | llama 8B Q4_K_S | 2048 | pp2048 | 8975.33 | 8827.73 | 0.98 |
| RTX 4090 | llama 8B Q5_0 | 16 | pp2048 | 1656.90 | 1715.82 | 1.04 |
| RTX 4090 | llama 8B Q5_0 | 32 | pp2048 | 3017.03 | 3073.50 | 1.02 |
| RTX 4090 | llama 8B Q5_0 | 63 | pp2048 | 4674.71 | 5137.43 | 1.10 |
| RTX 4090 | llama 8B Q5_0 | 128 | pp2048 | 3386.65 | 7461.09 | 2.20 |
| RTX 4090 | llama 8B Q5_0 | 256 | pp2048 | 5618.24 | 9655.87 | 1.72 |
| RTX 4090 | llama 8B Q5_0 | 512 | pp2048 | 7578.65 | 10491.78 | 1.38 |
| RTX 4090 | llama 8B Q5_0 | 1024 | pp2048 | 8898.76 | 10266.71 | 1.15 |
| RTX 4090 | llama 8B Q5_0 | 2048 | pp2048 | 8957.79 | 9336.48 | 1.04 |
| RTX 4090 | llama 8B Q5_1 | 16 | pp2048 | 1567.99 | 1627.41 | 1.04 |
| RTX 4090 | llama 8B Q5_1 | 32 | pp2048 | 2529.22 | 2726.07 | 1.08 |
| RTX 4090 | llama 8B Q5_1 | 63 | pp2048 | 4257.82 | 4379.58 | 1.03 |
| RTX 4090 | llama 8B Q5_1 | 128 | pp2048 | 3425.06 | 7162.23 | 2.09 |
| RTX 4090 | llama 8B Q5_1 | 256 | pp2048 | 5601.90 | 9040.43 | 1.61 |
| RTX 4090 | llama 8B Q5_1 | 512 | pp2048 | 7515.86 | 9783.51 | 1.30 |
| RTX 4090 | llama 8B Q5_1 | 1024 | pp2048 | 8842.66 | 9664.97 | 1.09 |
| RTX 4090 | llama 8B Q5_1 | 2048 | pp2048 | 8889.69 | 8856.04 | 1.00 |
| RTX 4090 | llama 8B Q5_K_S | 16 | pp2048 | 1687.77 | 1710.00 | 1.01 |
| RTX 4090 | llama 8B Q5_K_S | 32 | pp2048 | 2760.63 | 3001.93 | 1.09 |
| RTX 4090 | llama 8B Q5_K_S | 63 | pp2048 | 4421.52 | 4867.15 | 1.10 |
| RTX 4090 | llama 8B Q5_K_S | 128 | pp2048 | 3455.52 | 6889.52 | 1.99 |
| RTX 4090 | llama 8B Q5_K_S | 256 | pp2048 | 5677.46 | 8695.14 | 1.53 |
| RTX 4090 | llama 8B Q5_K_S | 512 | pp2048 | 7601.29 | 9319.90 | 1.23 |
| RTX 4090 | llama 8B Q5_K_S | 1024 | pp2048 | 8894.83 | 9237.02 | 1.04 |
| RTX 4090 | llama 8B Q5_K_S | 2048 | pp2048 | 8911.79 | 8500.17 | 0.95 |
| RTX 4090 | llama 8B Q6_K | 16 | pp2048 | 1447.89 | 1468.62 | 1.01 |
| RTX 4090 | llama 8B Q6_K | 32 | pp2048 | 2741.57 | 2780.23 | 1.01 |
| RTX 4090 | llama 8B Q6_K | 63 | pp2048 | 4472.51 | 4543.34 | 1.02 |
| RTX 4090 | llama 8B Q6_K | 128 | pp2048 | 3385.99 | 6716.01 | 1.98 |
| RTX 4090 | llama 8B Q6_K | 256 | pp2048 | 5577.95 | 8680.28 | 1.56 |
| RTX 4090 | llama 8B Q6_K | 512 | pp2048 | 7491.83 | 9425.78 | 1.26 |
| RTX 4090 | llama 8B Q6_K | 1024 | pp2048 | 8701.13 | 9252.84 | 1.06 |
| RTX 4090 | llama 8B Q6_K | 2048 | pp2048 | 8602.28 | 8408.26 | 0.98 |
| RTX 4090 | llama 8B Q8_0 | 16 | pp2048 | 1271.44 | 1264.01 | 0.99 |
| RTX 4090 | llama 8B Q8_0 | 32 | pp2048 | 2362.88 | 2372.23 | 1.00 |
| RTX 4090 | llama 8B Q8_0 | 63 | pp2048 | 4123.72 | 4128.31 | 1.00 |
| RTX 4090 | llama 8B Q8_0 | 128 | pp2048 | 3304.49 | 6850.57 | 2.07 |
| RTX 4090 | llama 8B Q8_0 | 256 | pp2048 | 5472.11 | 9855.14 | 1.80 |
| RTX 4090 | llama 8B Q8_0 | 512 | pp2048 | 7388.37 | 11189.63 | 1.51 |
| RTX 4090 | llama 8B Q8_0 | 1024 | pp2048 | 8714.98 | 11046.46 | 1.27 |
| RTX 4090 | llama 8B Q8_0 | 2048 | pp2048 | 8850.18 | 10048.28 | 1.14 |
| RTX 3090 | llama 8B Q2_K_M | 16 | pp2048 | 1026.74 | 1078.88 | 1.05 |
| RTX 3090 | llama 8B Q2_K_M | 32 | pp2048 | 1489.03 | 1526.54 | 1.03 |
| RTX 3090 | llama 8B Q2_K_M | 63 | pp2048 | 1993.09 | 2160.55 | 1.08 |
| RTX 3090 | llama 8B Q2_K_M | 128 | pp2048 | 2181.14 | 2616.21 | 1.20 |
| RTX 3090 | llama 8B Q2_K_M | 256 | pp2048 | 3220.43 | 2912.93 | 0.90 |
| RTX 3090 | llama 8B Q2_K_M | 512 | pp2048 | 3802.51 | 3020.78 | 0.79 |
| RTX 3090 | llama 8B Q2_K_M | 1024 | pp2048 | 4440.76 | 3049.70 | 0.69 |
| RTX 3090 | llama 8B Q2_K_M | 2048 | pp2048 | 4489.62 | 2999.00 | 0.67 |
| RTX 3090 | llama 8B Q3_K_S | 16 | pp2048 | 967.25 | 1107.61 | 1.15 |
| RTX 3090 | llama 8B Q3_K_S | 32 | pp2048 | 1547.20 | 1592.01 | 1.03 |
| RTX 3090 | llama 8B Q3_K_S | 63 | pp2048 | 2200.63 | 2111.31 | 0.96 |
| RTX 3090 | llama 8B Q3_K_S | 128 | pp2048 | 2021.43 | 2855.68 | 1.41 |
| RTX 3090 | llama 8B Q3_K_S | 256 | pp2048 | 3036.72 | 3155.41 | 1.04 |
| RTX 3090 | llama 8B Q3_K_S | 512 | pp2048 | 3671.35 | 3266.99 | 0.89 |
| RTX 3090 | llama 8B Q3_K_S | 1024 | pp2048 | 4362.81 | 3304.01 | 0.76 |
| RTX 3090 | llama 8B Q3_K_S | 2048 | pp2048 | 4443.61 | 3233.17 | 0.73 |
| RTX 3090 | llama 8B Q4_0 | 16 | pp2048 | 1209.55 | 1298.82 | 1.07 |
| RTX 3090 | llama 8B Q4_0 | 32 | pp2048 | 1859.34 | 1943.08 | 1.05 |
| RTX 3090 | llama 8B Q4_0 | 63 | pp2048 | 2656.01 | 2730.11 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | 128 | pp2048 | 2287.37 | 3497.90 | 1.53 |
| RTX 3090 | llama 8B Q4_0 | 256 | pp2048 | 3345.17 | 3926.42 | 1.17 |
| RTX 3090 | llama 8B Q4_0 | 512 | pp2048 | 3938.75 | 4110.44 | 1.04 |
| RTX 3090 | llama 8B Q4_0 | 1024 | pp2048 | 4570.93 | 4140.54 | 0.91 |
| RTX 3090 | llama 8B Q4_0 | 2048 | pp2048 | 4583.01 | 3992.97 | 0.87 |
| RTX 3090 | llama 8B Q4_1 | 16 | pp2048 | 1326.28 | 1404.33 | 1.06 |
| RTX 3090 | llama 8B Q4_1 | 32 | pp2048 | 1732.38 | 1755.45 | 1.01 |
| RTX 3090 | llama 8B Q4_1 | 63 | pp2048 | 2500.20 | 2677.78 | 1.07 |
| RTX 3090 | llama 8B Q4_1 | 128 | pp2048 | 2251.75 | 3358.79 | 1.49 |
| RTX 3090 | llama 8B Q4_1 | 256 | pp2048 | 3303.73 | 3727.24 | 1.13 |
| RTX 3090 | llama 8B Q4_1 | 512 | pp2048 | 3874.91 | 3876.52 | 1.00 |
| RTX 3090 | llama 8B Q4_1 | 1024 | pp2048 | 4501.06 | 3904.43 | 0.87 |
| RTX 3090 | llama 8B Q4_1 | 2048 | pp2048 | 4538.51 | 3740.23 | 0.82 |
| RTX 3090 | llama 8B Q4_K_S | 16 | pp2048 | 1293.09 | 1340.95 | 1.04 |
| RTX 3090 | llama 8B Q4_K_S | 32 | pp2048 | 1788.69 | 1986.87 | 1.11 |
| RTX 3090 | llama 8B Q4_K_S | 63 | pp2048 | 2451.97 | 2653.12 | 1.08 |
| RTX 3090 | llama 8B Q4_K_S | 128 | pp2048 | 2211.32 | 3198.33 | 1.45 |
| RTX 3090 | llama 8B Q4_K_S | 256 | pp2048 | 3258.42 | 3562.22 | 1.09 |
| RTX 3090 | llama 8B Q4_K_S | 512 | pp2048 | 3820.36 | 3730.15 | 0.98 |
| RTX 3090 | llama 8B Q4_K_S | 1024 | pp2048 | 4469.41 | 3781.98 | 0.85 |
| RTX 3090 | llama 8B Q4_K_S | 2048 | pp2048 | 4501.87 | 3677.75 | 0.82 |
| RTX 3090 | llama 8B Q5_0 | 16 | pp2048 | 992.20 | 1093.29 | 1.10 |
| RTX 3090 | llama 8B Q5_0 | 32 | pp2048 | 1699.33 | 1795.21 | 1.06 |
| RTX 3090 | llama 8B Q5_0 | 63 | pp2048 | 2321.07 | 2734.43 | 1.18 |
| RTX 3090 | llama 8B Q5_0 | 128 | pp2048 | 2127.93 | 3417.65 | 1.61 |
| RTX 3090 | llama 8B Q5_0 | 256 | pp2048 | 3151.53 | 3840.28 | 1.22 |
| RTX 3090 | llama 8B Q5_0 | 512 | pp2048 | 3756.03 | 4028.02 | 1.07 |
| RTX 3090 | llama 8B Q5_0 | 1024 | pp2048 | 4428.28 | 4075.31 | 0.92 |
| RTX 3090 | llama 8B Q5_0 | 2048 | pp2048 | 4492.74 | 3950.17 | 0.88 |
| RTX 3090 | llama 8B Q5_1 | 16 | pp2048 | 1031.04 | 1126.66 | 1.09 |
| RTX 3090 | llama 8B Q5_1 | 32 | pp2048 | 1490.19 | 1623.55 | 1.09 |
| RTX 3090 | llama 8B Q5_1 | 63 | pp2048 | 2214.86 | 2430.90 | 1.10 |
| RTX 3090 | llama 8B Q5_1 | 128 | pp2048 | 2115.25 | 3185.31 | 1.51 |
| RTX 3090 | llama 8B Q5_1 | 256 | pp2048 | 3138.65 | 3534.36 | 1.13 |
| RTX 3090 | llama 8B Q5_1 | 512 | pp2048 | 3760.41 | 3723.92 | 0.99 |
| RTX 3090 | llama 8B Q5_1 | 1024 | pp2048 | 4400.37 | 3774.57 | 0.86 |
| RTX 3090 | llama 8B Q5_1 | 2048 | pp2048 | 4490.87 | 3691.30 | 0.82 |
| RTX 3090 | llama 8B Q5_K_S | 16 | pp2048 | 1066.05 | 1142.46 | 1.07 |
| RTX 3090 | llama 8B Q5_K_S | 32 | pp2048 | 1535.05 | 1734.99 | 1.13 |
| RTX 3090 | llama 8B Q5_K_S | 63 | pp2048 | 2176.52 | 2401.10 | 1.10 |
| RTX 3090 | llama 8B Q5_K_S | 128 | pp2048 | 2170.79 | 3021.91 | 1.39 |
| RTX 3090 | llama 8B Q5_K_S | 256 | pp2048 | 3192.81 | 3366.21 | 1.05 |
| RTX 3090 | llama 8B Q5_K_S | 512 | pp2048 | 3740.39 | 3532.97 | 0.94 |
| RTX 3090 | llama 8B Q5_K_S | 1024 | pp2048 | 4393.43 | 3564.43 | 0.81 |
| RTX 3090 | llama 8B Q5_K_S | 2048 | pp2048 | 4438.97 | 3481.66 | 0.78 |
| RTX 3090 | llama 8B Q6_K | 16 | pp2048 | 968.32 | 1010.00 | 1.04 |
| RTX 3090 | llama 8B Q6_K | 32 | pp2048 | 1629.22 | 1714.42 | 1.05 |
| RTX 3090 | llama 8B Q6_K | 63 | pp2048 | 2354.68 | 2462.10 | 1.05 |
| RTX 3090 | llama 8B Q6_K | 128 | pp2048 | 2184.32 | 3082.13 | 1.41 |
| RTX 3090 | llama 8B Q6_K | 256 | pp2048 | 3214.53 | 3457.05 | 1.08 |
| RTX 3090 | llama 8B Q6_K | 512 | pp2048 | 3767.91 | 3614.22 | 0.96 |
| RTX 3090 | llama 8B Q6_K | 1024 | pp2048 | 4395.47 | 3635.22 | 0.83 |
| RTX 3090 | llama 8B Q6_K | 2048 | pp2048 | 4422.92 | 3537.11 | 0.80 |
| RTX 3090 | llama 8B Q8_0 | 16 | pp2048 | 975.07 | 995.69 | 1.02 |
| RTX 3090 | llama 8B Q8_0 | 32 | pp2048 | 1681.55 | 1743.37 | 1.04 |
| RTX 3090 | llama 8B Q8_0 | 63 | pp2048 | 2538.27 | 2698.21 | 1.06 |
| RTX 3090 | llama 8B Q8_0 | 128 | pp2048 | 2149.75 | 3635.15 | 1.69 |
| RTX 3090 | llama 8B Q8_0 | 256 | pp2048 | 3200.75 | 4186.02 | 1.31 |
| RTX 3090 | llama 8B Q8_0 | 512 | pp2048 | 3772.73 | 4406.76 | 1.17 |
| RTX 3090 | llama 8B Q8_0 | 1024 | pp2048 | 4429.63 | 4462.88 | 1.01 |
| RTX 3090 | llama 8B Q8_0 | 2048 | pp2048 | 4485.71 | 4293.24 | 0.96 |

I now consider the performance good enough that MMQ should be made the default again; performance for the small quants is still suboptimal, but for those I think the memory savings outweigh the hit to speed. I would prefer to make the default change in a separate PR.

The github-actions bot added the Nvidia GPU (issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Jun 22, 2024.
slaren (Collaborator) left a comment:

I see similar improvements with a 3080 and a 3090 Ti.

JohannesGaessler merged commit 9a590c8 into ggerganov:master on Jun 24, 2024
53 of 58 checks passed
arthw pushed a commit to arthw/llama.cpp that referenced this pull request on Jun 30, 2024
* CUDA: optimize MMQ int8 tensor core performance

* only a single get_mma_tile_x_k function

* simplify code, make functions constexpr
MagnusS0 pushed a commit to MagnusS0/llama.cpp-normistral-tokenizer that referenced this pull request on Jul 1, 2024
* CUDA: optimize MMQ int8 tensor core performance

* only a single get_mma_tile_x_k function

* simplify code, make functions constexpr