cuda : add batched cuBLAS GEMM for faster attention #3749

Merged 10 commits into master on Oct 24, 2023

Conversation

@ggerganov (Owner) commented Oct 23, 2023

Resolves #3479

Description

For baseline performance on master, check the info in #3726

We identified the KQ and KQV operations in the attention layer as a bottleneck. To overcome this, we parallelize these GEMMs across the head dimension using cublasGemmBatchedEx (a minimal sketch of such a batched call is included after the list below).

  • The overall text-generation speedup with more than one parallel sequence is significant - up to 2-3x compared to master
  • The prompt processing speed is also improved for all models
  • The single-batch text-generation speed stays essentially the same, though there may still be room for improvement at large contexts
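
To make the batching concrete, below is a minimal, self-contained sketch of the KQ step expressed as a single cublasGemmBatchedEx call over all heads. It is illustrative only: the sizes, variable names, FP16 storage and FP32 accumulation are assumptions for the sketch, not the code of this PR.

```c++
// Hypothetical sketch: KQ[h] = K[h]^T * Q[h] for all heads in one batched call,
// instead of launching one GEMM per head. Column-major layout, as cuBLAS expects.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

#define CHECK_CUDA(x)   do { cudaError_t   e = (x); if (e != cudaSuccess)           { printf("cuda error %d\n",   (int) e); return 1; } } while (0)
#define CHECK_CUBLAS(x) do { cublasStatus_t s = (x); if (s != CUBLAS_STATUS_SUCCESS) { printf("cublas error %d\n", (int) s); return 1; } } while (0)

int main() {
    const int head_dim = 128, n_kv = 512, n_tokens = 8, n_head = 32; // illustrative sizes

    cublasHandle_t handle;
    CHECK_CUBLAS(cublasCreate(&handle));

    // One K (head_dim x n_kv), Q (head_dim x n_tokens) and KQ (n_kv x n_tokens) buffer per head.
    // Buffers are left uninitialized - the sketch only demonstrates the call.
    std::vector<void *> k_ptrs(n_head), q_ptrs(n_head), kq_ptrs(n_head);
    for (int h = 0; h < n_head; ++h) {
        CHECK_CUDA(cudaMalloc(&k_ptrs[h],  sizeof(half) * head_dim * n_kv));
        CHECK_CUDA(cudaMalloc(&q_ptrs[h],  sizeof(half) * head_dim * n_tokens));
        CHECK_CUDA(cudaMalloc(&kq_ptrs[h], sizeof(half) * n_kv     * n_tokens));
    }

    // cublasGemmBatchedEx expects the arrays of pointers themselves to live in device memory,
    // hence the extra allocations and host-to-device copies discussed later in this thread.
    void **d_k, **d_q, **d_kq;
    CHECK_CUDA(cudaMalloc((void **) &d_k,  n_head * sizeof(void *)));
    CHECK_CUDA(cudaMalloc((void **) &d_q,  n_head * sizeof(void *)));
    CHECK_CUDA(cudaMalloc((void **) &d_kq, n_head * sizeof(void *)));
    CHECK_CUDA(cudaMemcpy(d_k,  k_ptrs.data(),  n_head * sizeof(void *), cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(d_q,  q_ptrs.data(),  n_head * sizeof(void *), cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(d_kq, kq_ptrs.data(), n_head * sizeof(void *), cudaMemcpyHostToDevice));

    const float alpha = 1.0f, beta = 0.0f;
    CHECK_CUBLAS(cublasGemmBatchedEx(handle,
        CUBLAS_OP_T, CUBLAS_OP_N,
        n_kv, n_tokens, head_dim,          // m, n, k
        &alpha,
        d_k,  CUDA_R_16F, head_dim,        // A = K, lda
        d_q,  CUDA_R_16F, head_dim,        // B = Q, ldb
        &beta,
        d_kq, CUDA_R_16F, n_kv,            // C = KQ, ldc
        n_head,                            // batch count = number of heads
        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

    CHECK_CUDA(cudaDeviceSynchronize());
    CHECK_CUBLAS(cublasDestroy(handle));
    return 0;
}
```

The point of the batching is that the per-head GEMMs are small, so replacing n_head separate GEMM launches with a single batched launch removes most of the kernel-launch and scheduling overhead.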

I did some perplexity tests as well and I think everything works correctly, although I tested mostly F16 models. I don't think the accuracy of quantum models would be affected by this change, but it still needs to be double-checked.

Please take a careful look and let me know if you observe improved performance and whether the perplexity looks good.

Results

Below are some results on A100. I also ran similar tests on V100 and the results are very similar.

Batched decoding for 7B F16 models

LLAMA_CUBLAS=1 make -j && ./batched-bench /workspace/openllama-7b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
  • master
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
512 128 1 640 0.126 4063.46 1.973 64.88 2.099 304.91
512 128 2 768 0.130 3941.89 6.521 39.26 6.651 115.48
512 128 3 896 0.118 4327.61 6.988 54.95 7.106 126.09
512 128 4 1024 0.111 4597.99 6.209 82.47 6.320 162.03
512 128 5 1152 0.110 4664.21 7.266 88.09 7.375 156.20
512 128 6 1280 0.108 4745.09 7.300 105.20 7.408 172.78
512 128 7 1408 0.111 4632.48 7.369 121.60 7.479 188.26
512 128 8 1536 0.111 4603.20 7.200 142.22 7.312 210.08
512 128 16 2560 0.111 4602.00 7.168 285.71 7.279 351.68
512 128 32 4608 0.113 4521.81 7.693 532.46 7.806 590.32
  • PR
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
512 128 1 640 0.091 5648.97 1.839 69.61 1.929 331.71
512 128 2 768 0.087 5876.55 2.023 126.57 2.110 364.02
512 128 3 896 0.090 5690.47 2.064 186.01 2.154 415.89
512 128 4 1024 0.090 5683.21 2.096 244.30 2.186 468.46
512 128 5 1152 0.090 5661.59 2.158 296.58 2.248 512.38
512 128 6 1280 0.090 5687.18 2.185 351.47 2.275 562.61
512 128 7 1408 0.090 5709.32 2.231 401.60 2.321 606.70
512 128 8 1536 0.091 5645.36 2.268 451.60 2.358 651.34
512 128 16 2560 0.091 5640.63 2.784 735.57 2.875 890.43
512 128 32 4608 0.092 5550.26 2.841 1441.86 2.933 1571.08

Batched decoding for 1B F16 models

LLAMA_CUBLAS=1 make -j && ./batched-bench /workspace/tinyllama-1b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
  • master
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
512 128 1 640 0.076 6716.34 0.662 193.33 0.738 866.82
512 128 2 768 0.060 8520.84 3.218 79.56 3.278 234.29
512 128 3 896 0.064 7995.38 3.258 117.86 3.322 269.70
512 128 4 1024 0.054 9441.96 3.302 155.05 3.356 305.08
512 128 5 1152 0.064 7943.90 3.259 196.39 3.323 346.65
512 128 6 1280 0.053 9593.22 3.325 230.98 3.378 378.88
512 128 7 1408 0.066 7770.41 3.425 261.62 3.491 403.36
512 128 8 1536 0.054 9464.83 3.463 295.66 3.518 436.67
512 128 16 2560 0.050 10192.71 3.369 607.92 3.419 748.74
512 128 32 4608 0.056 9119.08 3.770 1086.37 3.826 1204.24
  • PR
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
512 128 1 640 0.041 12337.94 0.560 228.57 0.602 1063.99
512 128 2 768 0.035 14805.39 0.831 307.97 0.866 887.01
512 128 3 896 0.027 18887.41 0.918 418.19 0.945 947.79
512 128 4 1024 0.040 12897.05 0.976 524.85 1.015 1008.65
512 128 5 1152 0.071 7234.19 0.983 651.16 1.054 1093.36
512 128 6 1280 0.040 12742.66 0.984 780.35 1.024 1249.58
512 128 7 1408 0.032 15873.02 1.017 880.62 1.050 1341.30
512 128 8 1536 0.040 12690.23 1.037 987.82 1.077 1426.22
512 128 16 2560 0.032 15780.07 1.309 1565.10 1.341 1909.04
512 128 32 4608 0.041 12502.75 1.374 2980.22 1.415 3255.74

llama-bench for 7B F16

make -j && ../scripts/run-all-perf.sh openllama-7b "f16" "-ngl 99 -t 4 -p 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,256,512,1024,2048 -n 128"
  • master
model size params backend ngl threads test t/s
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 1 37.82 ± 11.16
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 2 50.23 ± 0.07
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 3 73.45 ± 0.23
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 4 109.46 ± 0.08
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 5 120.57 ± 0.22
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 6 146.11 ± 0.14
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 7 166.63 ± 0.10
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 8 214.78 ± 0.17
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 9 212.03 ± 2.23
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 10 237.94 ± 0.12
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 11 259.16 ± 0.09
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 12 316.69 ± 2.23
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 13 302.72 ± 0.68
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 14 328.28 ± 0.88
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 15 347.02 ± 1.11
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 16 415.30 ± 1.38
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 32 721.61 ± 47.33
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 64 1338.41 ± 4.46
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 128 2538.48 ± 153.70
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 256 3924.85 ± 204.85
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 512 5089.59 ± 146.64
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 1024 4665.57 ± 24.93
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 2048 3421.92 ± 6.96
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 tg 128 76.16 ± 0.11
  • PR
model size params backend ngl threads test t/s
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 1 62.60 ± 23.28
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 2 130.64 ± 1.33
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 3 194.78 ± 2.66
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 4 259.73 ± 1.01
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 5 315.80 ± 1.48
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 6 380.62 ± 0.99
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 7 435.28 ± 11.00
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 8 498.21 ± 3.11
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 9 550.94 ± 1.53
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 10 605.28 ± 13.77
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 11 665.26 ± 2.55
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 12 716.74 ± 7.20
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 13 774.57 ± 7.29
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 14 815.00 ± 19.86
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 15 874.26 ± 7.58
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 16 930.06 ± 9.51
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 32 1585.78 ± 230.53
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 64 2809.06 ± 53.23
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 128 4990.19 ± 434.57
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 256 6431.97 ± 506.03
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 512 6777.28 ± 324.30
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 1024 6441.81 ± 60.10
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 pp 2048 5581.89 ± 49.59
llama 7B mostly F16 12.55 GiB 6.74 B CUDA 99 4 tg 128 76.13 ± 0.16

Serving multiple clients with parallel decoding and continuous batching

# 128 sequences, 8 parallel
make -j && ./bin/parallel -m ../models/openllama-7b/ggml-model-f16.gguf -t 1 -ngl 100 -c 4096 -b 512 -s 1 -np 8 -ns 128 -n 100 -cb

main: n_parallel = 8, n_sequences = 128, cont_batching = 1, system tokens = 293
External prompt file: used built-in defaults
Model and path used:  ../models/openllama-7b/ggml-model-f16.gguf

Total prompt tokens:   2011, speed: 111.82 t/s
Total gen tokens:      6448, speed: 358.54 t/s
Total speed (AVG):           speed: 470.36 t/s
Cache misses:             0

llama_print_timings:        load time =    2515.70 ms
llama_print_timings:      sample time =     952.12 ms /  6576 runs   (    0.14 ms per token,  6906.72 tokens per second)
llama_print_timings: prompt eval time =   15236.39 ms /  8721 tokens (    1.75 ms per token,   572.38 tokens per second)
llama_print_timings:        eval time =     456.73 ms /    31 runs   (   14.73 ms per token,    67.87 tokens per second)
llama_print_timings:       total time =   17984.59 ms


# 128 sequences, 16 parallel
make -j && ./bin/parallel -m ../models/openllama-7b/ggml-model-f16.gguf -t 1 -ngl 100 -c 4096 -b 512 -s 1 -np 16 -ns 128 -n 100 -cb

main: n_parallel = 16, n_sequences = 128, cont_batching = 1, system tokens = 293
External prompt file: used built-in defaults
Model and path used:  ../models/openllama-7b/ggml-model-f16.gguf

Total prompt tokens:   2011, speed: 148.40 t/s
Total gen tokens:      6883, speed: 507.91 t/s
Total speed (AVG):           speed: 656.30 t/s
Cache misses:             0


llama_print_timings:        load time =    2484.72 ms
llama_print_timings:      sample time =     988.93 ms /  7011 runs   (    0.14 ms per token,  7089.52 tokens per second)
llama_print_timings: prompt eval time =   10928.88 ms /  9172 tokens (    1.19 ms per token,   839.24 tokens per second)
llama_print_timings:        eval time =     230.13 ms /    15 runs   (   15.34 ms per token,    65.18 tokens per second)
llama_print_timings:       total time =   13552.41 ms


# 128 sequences, 32 parallel

main: n_parallel = 32, n_sequences = 128, cont_batching = 1, system tokens = 293
External prompt file: used built-in defaults
Model and path used:  ../models/openllama-7b/ggml-model-f16.gguf

Total prompt tokens:   2011, speed: 200.03 t/s
Total gen tokens:      6491, speed: 645.64 t/s
Total speed (AVG):           speed: 845.67 t/s
Cache misses:             0


llama_print_timings:        load time =    2435.73 ms
llama_print_timings:      sample time =     923.99 ms /  6619 runs   (    0.14 ms per token,  7163.53 tokens per second)
llama_print_timings: prompt eval time =    7695.20 ms /  8790 tokens (    0.88 ms per token,  1142.27 tokens per second)
llama_print_timings:        eval time =     105.69 ms /     5 runs   (   21.14 ms per token,    47.31 tokens per second)
llama_print_timings:       total time =   10054.29 ms

| Clients | Master t/s | PR t/s |
| ------: | ---------: | -----: |
|       8 |     247.30 | 470.36 |
|      16 |     368.59 | 656.30 |
|      32 |     422.33 | 845.67 |

After the update to cublasGemmStridedBatchedEx, with Codellama 7B F16:

| Clients | PR t/s |
| ------: | -----: |
|       4 | 306.07 |
|       8 | 487.71 |
|      12 | 616.16 |
|      16 | 707.03 |
|      20 | 802.23 |
|      24 | 829.71 |
|      32 | 857.19 |

Perplexity for F16 7B

make -j perplexity && ./bin/perplexity -m ../models/openllama-7b/ggml-model-f16.gguf -f ./wikitext-2-raw/wiki.test.raw  -ngl 100 -nommq
  • master
llm_load_tensors: ggml ctx size =    0.10 MB
llm_load_tensors: using CUDA for GPU acceleration
llm_load_tensors: mem required  =  250.10 MB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 35/35 layers to GPU
llm_load_tensors: VRAM used: 12603.02 MB
...................................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: offloading v cache to GPU
llama_kv_cache_init: offloading k cache to GPU
llama_kv_cache_init: VRAM kv self = 256.00 MB
llama_new_context_with_model: kv self size  =  256.00 MB
llama_new_context_with_model: compute buffer total size = 76.63 MB
llama_new_context_with_model: VRAM scratch buffer: 70.50 MB
llama_new_context_with_model: total VRAM used: 12929.52 MB (model: 12603.02 MB, context: 326.50 MB)

system_info: n_threads = 96 / 192 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 |
perplexity: tokenizing the input ..
perplexity: tokenization took 858.533 ms
perplexity: calculating perplexity over 616 chunks, batch_size=512
perplexity: 0.15 seconds per pass - ETA 1.55 minutes
[1]4.3552,[2]6.1570,[3]5.7857,[4]6.7030,[5]6.8301,[6]7.5447,[7]8.1180,[8]8.4626,[9]8.8826,[10]9.2268,[11]9.1924,[12]9.2277,[13]9.3662,[14]9.8428,[15]9.3058,[16]9.0531,[17]8.9313,[18]8.4224,[19]8.3904,[20]8.2445,[21]8.1787,[22]8.0825,[23]8.0098,[24]7.8000,[25]7.5941,[26]7.4137,[27]7.2866,[28]7.1175,[29]7.0882,[30]6.9224,[31]6.9037,[32]6.8609,[33]6.8255,[34]6.8352,[35]6.7436,[36]6.7418,[37]6.7644,[38]6.8275,[39]6.8727,[40]6.9264,[41]6.8493,[42]6.9215,[43]6.8709,[44]6.8286,[45]6.8529,[46]6.8773,[47]6.8260,[48]6.8115,[49]6.7427,[50]6.8259,[51]6.8257,[52]6.7904,[53]6.7868,[54]6.7588,[55]6.8117,[56]6.8324,[57]6.8663,[58]6.8800,[59]6.9294,[60]6.9350,[61]6.9554,[62]7.0007,[63]7.0148,[64]7.0640,[65]7.0465,[66]7.0968,[67]7.1372,[68]7.1725,[69]7.2322,[70]7.2555,[71]7.2747,[72]7.3050,[73]7.3265,[74]7.3242,[75]7.3835,[76]7.3709,[77]7.3524,[78]7.3084,[79]7.2847,[80]7.2624,[81]7.2813,[82]7.2645,[83]7.2153,[84]7.2197,[85]7.1988,[86]7.2443,[87]7.2306,[88]7.2263,[89]7.2207,[90]7.2672,[91]7.2925,[92]7.2798,[93]7.2905,[94]7.2740,[95]7.3029,[96]7.3081,[97]7.2877,[98]7.3098,[99]7.3100,[100]7.3283,[101]7.3356,[102]7.3264,[103]7.3282,[104]7.3220,[105]7.3574,[106]7.3884,[107]7.3912,[108]7.3950,[109]7.3952,[110]7.3859,[111]7.3867,[112]7.4256,[113]7.4826,[114]7.5352,[115]7.5431,[116]7.5892,[117]7.6039,[118]7.6070,[119]7.6485,[120]7.6949,[121]7.7201,[122]7.7005,[123]7.7121,[124]7.7098,[125]7.6806,[126]7.6824,[127]7.6780,[128]7.6740,[129]7.6429,[130]7.6317,[131]7.6164,[132]7.6071,[133]7.5911,[134]7.5450,[135]7.5202,[136]7.5132,[137]7.4874,[138]7.4549,[139]7.4332,[140]7.4346,[141]7.4344,[142]7.4320,[143]7.4429,[144]7.4386,[145]7.4244,[146]7.4116,[147]7.4206,[148]7.4217,[149]7.4598,[150]7.4689,[151]7.4711,[152]7.5067,[153]7.4848,[154]7.4548,[155]7.4233,[156]7.3866,[157]7.3613,[158]7.3155,[159]7.2914,[160]7.2745,[161]7.2450,[162]7.2122,[163]7.1877,[164]7.1620,[165]7.1231,[166]7.0945,[167]7.0707,[168]7.0328,[169]7.0099,[170]6.9896,[171]6.9676,[172]6.9389,[173]6.9267,[174]6.8934,[175]6.8843,[176]6.8869,[177]6.8940,[178]6.8857,[179]6.9077,[180]6.9161,[181]6.9517,[182]6.9830,[183]7.0022,[184]7.0422,[185]7.0526,[186]7.0762,[187]7.1001,[188]7.1083,[189]7.1125,[190]7.1035,[191]7.1281,[192]7.1384,[193]7.1376,[194]7.1392,[195]7.1459,[196]7.1533,[197]7.1623,[198]7.1575,[199]7.1679,[200]7.1873,[201]7.2013,[202]7.2044,[203]7.2002,[204]7.2189,[205]7.2350,[206]7.2619,[207]7.2657,[208]7.2701,[209]7.2671,[210]7.2562,[211]7.2451,[212]7.2331,[213]7.2569,[214]7.2619,[215]7.2647,[216]7.2627,[217]7.2532,[218]7.2491,[219]7.2321,[220]7.2414,[221]7.2292,[222]7.2196,[223]7.2221,[224]7.2193,[225]7.2072,[226]7.1956,[227]7.2059,[228]7.2015,[229]7.2007,[230]7.1854,[231]7.1795,[232]7.1733,[233]7.1631,[234]7.1594,[235]7.1525,[236]7.1545,[237]7.1333,[238]7.1281,[239]7.1232,[240]7.1013,[241]7.0859,[242]7.0682,[243]7.0576,[244]7.0535,[245]7.0442,[246]7.0397,[247]7.0367,[248]7.0250,[249]7.0194,[250]7.0203,[251]7.0148,[252]7.0140,[253]7.0239,[254]7.0263,[255]7.0369,[256]7.0380,[257]7.0374,[258]7.0354,[259]7.0442,[260]7.0470,[261]7.0581,[262]7.0727,[263]7.0836,[264]7.0897,[265]7.0980,[266]7.0994,[267]7.1120,[268]7.1305,[269]7.1314,[270]7.1353,[271]7.1373,[272]7.1301,[273]7.1098,[274]7.0968,[275]7.0844,[276]7.0692,[277]7.0670,[278]7.0669,[279]7.0712,[280]7.0696,[281]7.0685,[282]7.0575,[283]7.0507,[284]7.0468,[285]7.0366,[286]7.0280,[287]7.0212,[288]7.0100,[289]7.0058,[290]7.0051,[291]6.9897,[292]6.9844,[293]6.9784,[294]6.9708,[295]6.9661,[296]6.9670,[297]6.9540,[298]6.9437,[299]6.9183,[300]6.9185,[301]6.9313,[302]6.9373,[303]6.9220,[304]6.9190,[305]6.9133,[30
6]6.9264,[307]6.9312,[308]6.9302,[309]6.9350,[310]6.9375,[311]6.9443,[312]6.9577,[313]6.9677,[314]6.9669,[315]6.9577,[316]6.9514,[317]6.9472,[318]6.9419,[319]6.9367,[320]6.9366,[321]6.9444,[322]6.9467,[323]6.9501,[324]6.9534,[325]6.9534,[326]6.9492,[327]6.9516,[328]6.9584,[329]6.9570,[330]6.9534,[331]6.9525,[332]6.9461,[333]6.9406,[334]6.9267,[335]6.9341,[336]6.9348,[337]6.9353,[338]6.9433,[339]6.9316,[340]6.9400,[341]6.9494,[342]6.9640,[343]6.9714,[344]6.9754,[345]6.9722,[346]6.9748,[347]6.9651,[348]6.9707,[349]6.9652,[350]6.9551,[351]6.9541,[352]6.9517,[353]6.9529,[354]6.9461,[355]6.9493,[356]6.9500,[357]6.9633,[358]6.9625,[359]6.9596,[360]6.9537,[361]6.9437,[362]6.9455,[363]6.9421,[364]6.9447,[365]6.9423,[366]6.9343,[367]6.9272,[368]6.9205,[369]6.9121,[370]6.9065,[371]6.9071,[372]6.9093,[373]6.9074,[374]6.9006,[375]6.9111,[376]6.9198,[377]6.9230,[378]6.9141,[379]6.9174,[380]6.9182,[381]6.9220,[382]6.9167,[383]6.9140,[384]6.9189,[385]6.9219,[386]6.9394,[387]6.9521,[388]6.9686,[389]6.9790,[390]6.9883,[391]7.0046,[392]7.0165,[393]7.0306,[394]7.0334,[395]7.0384,[396]7.0507,[397]7.0593,[398]7.0634,[399]7.0726,[400]7.0764,[401]7.0859,[402]7.0928,[403]7.0997,[404]7.1083,[405]7.1207,[406]7.1334,[407]7.1314,[408]7.1262,[409]7.1176,[410]7.1179,[411]7.1274,[412]7.1370,[413]7.1425,[414]7.1505,[415]7.1432,[416]7.1425,[417]7.1446,[418]7.1486,[419]7.1492,[420]7.1533,[421]7.1548,[422]7.1639,[423]7.1667,[424]7.1621,[425]7.1534,[426]7.1467,[427]7.1414,[428]7.1323,[429]7.1306,[430]7.1315,[431]7.1285,[432]7.1280,[433]7.1361,[434]7.1306,[435]7.1287,[436]7.1336,[437]7.1324,[438]7.1257,[439]7.1263,[440]7.1316,[441]7.1347,[442]7.1329,[443]7.1308,[444]7.1342,[445]7.1254,[446]7.1302,[447]7.1222,[448]7.1206,[449]7.1100,[450]7.1092,[451]7.1120,[452]7.1178,[453]7.1208,[454]7.1167,[455]7.1158,[456]7.1168,[457]7.1084,[458]7.1052,[459]7.1066,[460]7.0958,[461]7.0925,[462]7.0898,[463]7.0839,[464]7.0856,[465]7.0748,[466]7.0724,[467]7.0692,[468]7.0632,[469]7.0613,[470]7.0546,[471]7.0547,[472]7.0527,[473]7.0503,[474]7.0379,[475]7.0391,[476]7.0402,[477]7.0347,[478]7.0293,[479]7.0270,[480]7.0292,[481]7.0278,[482]7.0288,[483]7.0382,[484]7.0475,[485]7.0453,[486]7.0437,[487]7.0439,[488]7.0491,[489]7.0514,[490]7.0542,[491]7.0586,[492]7.0610,[493]7.0667,[494]7.0685,[495]7.0663,[496]7.0692,[497]7.0653,[498]7.0663,[499]7.0563,[500]7.0577,[501]7.0617,[502]7.0606,[503]7.0569,[504]7.0518,[505]7.0484,[506]7.0543,[507]7.0585,[508]7.0601,[509]7.0602,[510]7.0608,[511]7.0574,[512]7.0571,[513]7.0499,[514]7.0469,[515]7.0473,[516]7.0446,[517]7.0444,[518]7.0407,[519]7.0405,[520]7.0333,[521]7.0310,[522]7.0248,[523]7.0240,[524]7.0286,[525]7.0250,[526]7.0237,[527]7.0239,[528]7.0226,[529]7.0168,[530]7.0194,[531]7.0257,[532]7.0269,[533]7.0274,[534]7.0289,[535]7.0284,[536]7.0304,[537]7.0270,[538]7.0282,[539]7.0273,[540]7.0236,[541]7.0224,[542]7.0230,[543]7.0221,[544]7.0216,[545]7.0237,[546]7.0201,[547]7.0218,[548]7.0183,[549]7.0191,[550]7.0167,[551]7.0118,[552]7.0133,[553]7.0181,[554]7.0235,[555]7.0209,[556]7.0199,[557]7.0156,[558]7.0116,[559]7.0100,[560]7.0120,[561]7.0100,[562]7.0076,[563]7.0096,[564]7.0132,[565]7.0100,[566]7.0082,[567]7.0123,[568]7.0132,[569]7.0176,[570]7.0239,[571]7.0156,[572]7.0078,[573]7.0113,[574]7.0124,[575]7.0194,[576]7.0167,[577]7.0099,[578]7.0023,[579]7.0032,[580]6.9927,[581]6.9839,[582]6.9786,[583]6.9615,[584]6.9526,[585]6.9504,[586]6.9543,[587]6.9568,[588]6.9565,[589]6.9583,[590]6.9609,[591]6.9626,[592]6.9678,[593]6.9746,[594]6.9817,[595]6.9837,[596]6.9902,[597]6.9887,[598]6.991
2,[599]6.9861,[600]6.9878,[601]6.9827,[602]6.9841,[603]6.9861,[604]6.9925,[605]6.9933,[606]6.9930,[607]6.9915,[608]6.9923,[609]6.9886,[610]6.9921,[611]6.9954,[612]7.0020,[613]7.0046,[614]7.0089,[615]6.9998,[616]6.9985,
Final estimate: PPL = 6.9985 +/- 0.04227

llama_print_timings:        load time =    1879.26 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =   73628.96 ms / 315392 tokens (    0.23 ms per token,  4283.53 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =   96843.77 ms
  • PR
llm_load_tensors: ggml ctx size =    0.10 MB
llm_load_tensors: using CUDA for GPU acceleration
llm_load_tensors: mem required  =  250.10 MB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 35/35 layers to GPU
llm_load_tensors: VRAM used: 12603.02 MB
...................................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: offloading v cache to GPU
llama_kv_cache_init: offloading k cache to GPU
llama_kv_cache_init: VRAM kv self = 256.00 MB
llama_new_context_with_model: kv self size  =  256.00 MB
llama_new_context_with_model: compute buffer total size = 76.63 MB
llama_new_context_with_model: VRAM scratch buffer: 70.50 MB
llama_new_context_with_model: total VRAM used: 12929.52 MB (model: 12603.02 MB, context: 326.50 MB)

system_info: n_threads = 96 / 192 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 |
perplexity: tokenizing the input ..
perplexity: tokenization took 846.531 ms
perplexity: calculating perplexity over 616 chunks, batch_size=512
perplexity: 0.13 seconds per pass - ETA 1.33 minutes
[1]4.3517,[2]6.1586,[3]5.7870,[4]6.7052,[5]6.8317,[6]7.5458,[7]8.1176,[8]8.4616,[9]8.8808,[10]9.2266,[11]9.1916,[12]9.2280,[13]9.3661,[14]9.8421,[15]9.3047,[16]9.0525,[17]8.9304,[18]8.4211,[19]8.3890,[20]8.2428,[21]8.1773,[22]8.0807,[23]8.0084,[24]7.7985,[25]7.5929,[26]7.4123,[27]7.2853,[28]7.1162,[29]7.0870,[30]6.9210,[31]6.9025,[32]6.8596,[33]6.8242,[34]6.8339,[35]6.7425,[36]6.7408,[37]6.7634,[38]6.8264,[39]6.8718,[40]6.9258,[41]6.8490,[42]6.9214,[43]6.8705,[44]6.8281,[45]6.8528,[46]6.8772,[47]6.8258,[48]6.8111,[49]6.7421,[50]6.8254,[51]6.8251,[52]6.7898,[53]6.7861,[54]6.7581,[55]6.8110,[56]6.8319,[57]6.8657,[58]6.8795,[59]6.9289,[60]6.9346,[61]6.9548,[62]7.0001,[63]7.0141,[64]7.0635,[65]7.0459,[66]7.0965,[67]7.1369,[68]7.1723,[69]7.2321,[70]7.2555,[71]7.2748,[72]7.3050,[73]7.3264,[74]7.3242,[75]7.3835,[76]7.3710,[77]7.3525,[78]7.3085,[79]7.2847,[80]7.2623,[81]7.2812,[82]7.2643,[83]7.2151,[84]7.2195,[85]7.1986,[86]7.2440,[87]7.2304,[88]7.2261,[89]7.2204,[90]7.2669,[91]7.2921,[92]7.2795,[93]7.2902,[94]7.2737,[95]7.3027,[96]7.3080,[97]7.2875,[98]7.3096,[99]7.3099,[100]7.3282,[101]7.3353,[102]7.3260,[103]7.3277,[104]7.3215,[105]7.3569,[106]7.3878,[107]7.3906,[108]7.3945,[109]7.3947,[110]7.3854,[111]7.3861,[112]7.4251,[113]7.4821,[114]7.5346,[115]7.5426,[116]7.5888,[117]7.6035,[118]7.6067,[119]7.6482,[120]7.6946,[121]7.7198,[122]7.7002,[123]7.7118,[124]7.7096,[125]7.6804,[126]7.6823,[127]7.6779,[128]7.6739,[129]7.6427,[130]7.6315,[131]7.6163,[132]7.6070,[133]7.5910,[134]7.5447,[135]7.5200,[136]7.5130,[137]7.4872,[138]7.4545,[139]7.4328,[140]7.4343,[141]7.4341,[142]7.4316,[143]7.4426,[144]7.4383,[145]7.4241,[146]7.4112,[147]7.4201,[148]7.4213,[149]7.4592,[150]7.4683,[151]7.4706,[152]7.5062,[153]7.4843,[154]7.4542,[155]7.4227,[156]7.3861,[157]7.3608,[158]7.3150,[159]7.2910,[160]7.2740,[161]7.2445,[162]7.2117,[163]7.1872,[164]7.1615,[165]7.1226,[166]7.0940,[167]7.0702,[168]7.0322,[169]7.0093,[170]6.9890,[171]6.9671,[172]6.9383,[173]6.9262,[174]6.8929,[175]6.8838,[176]6.8863,[177]6.8934,[178]6.8851,[179]6.9071,[180]6.9155,[181]6.9511,[182]6.9823,[183]7.0016,[184]7.0417,[185]7.0521,[186]7.0757,[187]7.0997,[188]7.1078,[189]7.1121,[190]7.1030,[191]7.1276,[192]7.1379,[193]7.1371,[194]7.1388,[195]7.1455,[196]7.1529,[197]7.1618,[198]7.1570,[199]7.1673,[200]7.1868,[201]7.2008,[202]7.2039,[203]7.1998,[204]7.2184,[205]7.2346,[206]7.2614,[207]7.2652,[208]7.2696,[209]7.2665,[210]7.2556,[211]7.2446,[212]7.2326,[213]7.2564,[214]7.2613,[215]7.2640,[216]7.2621,[217]7.2526,[218]7.2485,[219]7.2315,[220]7.2408,[221]7.2286,[222]7.2190,[223]7.2214,[224]7.2186,[225]7.2065,[226]7.1949,[227]7.2052,[228]7.2008,[229]7.2000,[230]7.1847,[231]7.1788,[232]7.1726,[233]7.1624,[234]7.1587,[235]7.1519,[236]7.1539,[237]7.1327,[238]7.1275,[239]7.1225,[240]7.1007,[241]7.0853,[242]7.0676,[243]7.0570,[244]7.0529,[245]7.0436,[246]7.0391,[247]7.0361,[248]7.0244,[249]7.0187,[250]7.0197,[251]7.0142,[252]7.0133,[253]7.0232,[254]7.0257,[255]7.0362,[256]7.0374,[257]7.0368,[258]7.0348,[259]7.0436,[260]7.0463,[261]7.0574,[262]7.0720,[263]7.0829,[264]7.0890,[265]7.0974,[266]7.0987,[267]7.1112,[268]7.1298,[269]7.1306,[270]7.1345,[271]7.1365,[272]7.1293,[273]7.1090,[274]7.0960,[275]7.0836,[276]7.0684,[277]7.0662,[278]7.0661,[279]7.0704,[280]7.0688,[281]7.0676,[282]7.0567,[283]7.0498,[284]7.0459,[285]7.0357,[286]7.0271,[287]7.0203,[288]7.0092,[289]7.0049,[290]7.0042,[291]6.9888,[292]6.9836,[293]6.9776,[294]6.9699,[295]6.9653,[296]6.9662,[297]6.9533,[298]6.9430,[299]6.9176,[300]6.9178,[301]6.9306,[302]6.9366,[303]6.9213,[304]6.9183,[305]6.9126,[30
6]6.9257,[307]6.9305,[308]6.9296,[309]6.9343,[310]6.9369,[311]6.9436,[312]6.9570,[313]6.9670,[314]6.9663,[315]6.9571,[316]6.9508,[317]6.9465,[318]6.9412,[319]6.9360,[320]6.9359,[321]6.9437,[322]6.9459,[323]6.9494,[324]6.9527,[325]6.9528,[326]6.9486,[327]6.9510,[328]6.9578,[329]6.9564,[330]6.9527,[331]6.9519,[332]6.9455,[333]6.9400,[334]6.9261,[335]6.9335,[336]6.9343,[337]6.9347,[338]6.9427,[339]6.9311,[340]6.9395,[341]6.9488,[342]6.9634,[343]6.9708,[344]6.9748,[345]6.9716,[346]6.9742,[347]6.9645,[348]6.9701,[349]6.9646,[350]6.9545,[351]6.9535,[352]6.9511,[353]6.9522,[354]6.9455,[355]6.9486,[356]6.9494,[357]6.9626,[358]6.9618,[359]6.9590,[360]6.9531,[361]6.9430,[362]6.9449,[363]6.9414,[364]6.9440,[365]6.9417,[366]6.9337,[367]6.9265,[368]6.9199,[369]6.9115,[370]6.9059,[371]6.9066,[372]6.9088,[373]6.9068,[374]6.9000,[375]6.9106,[376]6.9193,[377]6.9225,[378]6.9136,[379]6.9169,[380]6.9177,[381]6.9215,[382]6.9162,[383]6.9135,[384]6.9185,[385]6.9215,[386]6.9389,[387]6.9517,[388]6.9681,[389]6.9785,[390]6.9879,[391]7.0041,[392]7.0160,[393]7.0301,[394]7.0329,[395]7.0379,[396]7.0501,[397]7.0588,[398]7.0629,[399]7.0721,[400]7.0759,[401]7.0854,[402]7.0923,[403]7.0993,[404]7.1079,[405]7.1203,[406]7.1330,[407]7.1310,[408]7.1258,[409]7.1172,[410]7.1176,[411]7.1271,[412]7.1367,[413]7.1421,[414]7.1502,[415]7.1428,[416]7.1420,[417]7.1441,[418]7.1482,[419]7.1488,[420]7.1529,[421]7.1543,[422]7.1635,[423]7.1663,[424]7.1616,[425]7.1529,[426]7.1462,[427]7.1409,[428]7.1319,[429]7.1301,[430]7.1310,[431]7.1280,[432]7.1275,[433]7.1356,[434]7.1302,[435]7.1282,[436]7.1331,[437]7.1318,[438]7.1251,[439]7.1257,[440]7.1311,[441]7.1341,[442]7.1324,[443]7.1302,[444]7.1337,[445]7.1249,[446]7.1297,[447]7.1217,[448]7.1201,[449]7.1094,[450]7.1087,[451]7.1115,[452]7.1172,[453]7.1203,[454]7.1161,[455]7.1153,[456]7.1162,[457]7.1078,[458]7.1046,[459]7.1061,[460]7.0953,[461]7.0919,[462]7.0893,[463]7.0834,[464]7.0850,[465]7.0742,[466]7.0718,[467]7.0687,[468]7.0626,[469]7.0607,[470]7.0540,[471]7.0541,[472]7.0521,[473]7.0497,[474]7.0373,[475]7.0385,[476]7.0397,[477]7.0341,[478]7.0288,[479]7.0264,[480]7.0286,[481]7.0272,[482]7.0282,[483]7.0376,[484]7.0469,[485]7.0448,[486]7.0432,[487]7.0433,[488]7.0486,[489]7.0509,[490]7.0537,[491]7.0580,[492]7.0605,[493]7.0661,[494]7.0679,[495]7.0657,[496]7.0686,[497]7.0647,[498]7.0657,[499]7.0558,[500]7.0571,[501]7.0611,[502]7.0600,[503]7.0563,[504]7.0512,[505]7.0479,[506]7.0538,[507]7.0579,[508]7.0595,[509]7.0596,[510]7.0602,[511]7.0568,[512]7.0565,[513]7.0493,[514]7.0463,[515]7.0468,[516]7.0440,[517]7.0438,[518]7.0401,[519]7.0399,[520]7.0326,[521]7.0304,[522]7.0242,[523]7.0234,[524]7.0280,[525]7.0244,[526]7.0231,[527]7.0233,[528]7.0221,[529]7.0163,[530]7.0188,[531]7.0252,[532]7.0264,[533]7.0269,[534]7.0283,[535]7.0278,[536]7.0298,[537]7.0265,[538]7.0277,[539]7.0268,[540]7.0231,[541]7.0218,[542]7.0224,[543]7.0215,[544]7.0210,[545]7.0231,[546]7.0196,[547]7.0212,[548]7.0178,[549]7.0186,[550]7.0162,[551]7.0113,[552]7.0128,[553]7.0176,[554]7.0230,[555]7.0204,[556]7.0194,[557]7.0152,[558]7.0112,[559]7.0095,[560]7.0116,[561]7.0095,[562]7.0071,[563]7.0091,[564]7.0128,[565]7.0095,[566]7.0078,[567]7.0119,[568]7.0128,[569]7.0171,[570]7.0234,[571]7.0152,[572]7.0074,[573]7.0109,[574]7.0120,[575]7.0190,[576]7.0163,[577]7.0094,[578]7.0019,[579]7.0028,[580]6.9922,[581]6.9834,[582]6.9781,[583]6.9611,[584]6.9521,[585]6.9499,[586]6.9539,[587]6.9563,[588]6.9561,[589]6.9578,[590]6.9604,[591]6.9622,[592]6.9674,[593]6.9742,[594]6.9813,[595]6.9833,[596]6.9898,[597]6.9883,[598]6.9908,[599]6.9857,[600]6.9873,[601]6.9823,[602]6
.9836,[603]6.9857,[604]6.9920,[605]6.9928,[606]6.9925,[607]6.9910,[608]6.9919,[609]6.9881,[610]6.9917,[611]6.9950,[612]7.0015,[613]7.0041,[614]7.0084,[615]6.9993,[616]6.9980,
Final estimate: PPL = 6.9980 +/- 0.04226

llama_print_timings:        load time =    1866.01 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =   57857.87 ms / 315392 tokens (    0.18 ms per token,  5451.15 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =   80736.93 ms

Quantum batched decoding (Q4_0)

  • PR + ampere constants patch
-#define  MMQ_X_Q4_0_AMPERE 64
-#define  MMQ_Y_Q4_0_AMPERE 128
-#define NWARPS_Q4_0_AMPERE 4
+#define  MMQ_X_Q4_0_AMPERE 8
+#define  MMQ_Y_Q4_0_AMPERE 32
+#define NWARPS_Q4_0_AMPERE 8
LLAMA_CUBLAS=1 make -j && ./batched-bench /workspace/openllama-7b/ggml-model-q4_0.gguf 4608 1 99 1 512 128 1,2,3,4,5,6,7,8,16,32

LLAMA_CUBLAS=1 make -j && ./batched-bench /workspace/openllama-7b/ggml-model-q4_0.gguf 4608 0 99 1 50 100 1,2,3,4,5,6,7,8,16,32
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
512 128 1 640 0.449 1140.48 1.037 123.40 1.486 430.62
512 128 2 768 0.448 1142.55 2.201 116.29 2.649 289.87
512 128 3 896 0.446 1148.97 2.239 171.50 2.685 333.74
512 128 4 1024 0.446 1148.06 2.275 225.03 2.721 376.30
512 128 5 1152 0.445 1149.38 2.343 273.19 2.788 413.18
512 128 6 1280 0.445 1149.87 2.380 322.75 2.825 453.12
512 128 7 1408 0.446 1147.91 2.444 366.61 2.890 487.19
512 128 8 1536 0.445 1151.08 2.500 409.67 2.944 521.67
512 128 16 2560 0.448 1143.01 3.601 568.65 4.049 632.19
512 128 32 4608 0.450 1136.52 4.976 823.16 5.426 849.17
50 100 1 150 0.070 714.76 0.713 140.28 0.783 191.62
50 100 2 300 0.100 999.31 1.676 119.34 1.776 168.92
50 100 3 450 0.143 1046.70 1.715 174.90 1.859 242.12
50 100 4 600 0.183 1091.19 1.744 229.31 1.928 311.25
50 100 5 750 0.232 1079.15 1.785 280.14 2.016 371.94
50 100 6 900 0.275 1089.17 1.821 329.54 2.096 429.36
50 100 7 1050 0.318 1101.73 1.870 374.31 2.188 479.93
50 100 8 1200 0.355 1126.32 1.913 418.22 2.268 529.10
50 100 16 2400 0.703 1138.72 2.827 565.96 3.530 679.97

@ggerganov added the performance (Speed related topics), high priority (Very important issue), need feedback (Testing and feedback with results are needed), and Nvidia GPU (Issues specific to Nvidia GPUs) labels on Oct 23, 2023
@slaren (Collaborator) commented Oct 23, 2023

cublasGemmStridedBatchedEx may also work, and it wouldn't require copying an array of pointers.
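
For reference, a hedged sketch of what the strided variant could look like when the per-head K, Q and KQ slices are contiguous (illustrative names and layout, not llama.cpp code); the three stride arguments describe the distance between consecutive heads, so no array of device pointers has to be built and copied:

```c++
// Hypothetical sketch of the strided variant: one contiguous K, Q and KQ buffer,
// with the distance between consecutive heads given as a stride.
#include <cublas_v2.h>
#include <cuda_fp16.h>

// K:  head_dim x n_kv     per head, heads back to back
// Q:  head_dim x n_tokens per head
// KQ: n_kv     x n_tokens per head (output)
cublasStatus_t kq_strided_batched(cublasHandle_t handle,
                                  const half *K, const half *Q, half *KQ,
                                  int head_dim, int n_kv, int n_tokens, int n_head) {
    const float alpha = 1.0f, beta = 0.0f;
    return cublasGemmStridedBatchedEx(handle,
        CUBLAS_OP_T, CUBLAS_OP_N,
        n_kv, n_tokens, head_dim,
        &alpha,
        K,  CUDA_R_16F, head_dim, (long long) head_dim * n_kv,     // strideA: one K head
        Q,  CUDA_R_16F, head_dim, (long long) head_dim * n_tokens, // strideB: one Q head
        &beta,
        KQ, CUDA_R_16F, n_kv,     (long long) n_kv * n_tokens,     // strideC: one KQ head
        n_head,
        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```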

@KerfuffleV2 (Collaborator) left a review comment:

These changes plus:

```c++
#define cublasGemmBatchedEx hipblasGemmBatchedEx
```

are needed to compile with ROCM. I haven't done performance testing, but it seems to work.

I couldn't figure out how to propose a change for lines outside what the pull changed, also this is the first time trying to create a multi-part review so please forgive me if I mess something up.
@KerfuffleV2 (Collaborator) commented:

@ggerganov I am so sorry, it said "commit suggestion". I thought I was just committing the suggested changes as a comment, not changing your pull. I definitely did not mean to do that. And I don't even know how to revert it. :(

@KerfuffleV2 (Collaborator) commented:

I ran the batched-bench tests on my RX6600 with a Q5_K Mistral model. The results look weird, but this pull does seem to increase performance a little bit:


PR

PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
512 128 1 640 1.023 500.72 4.011 31.91 5.034 127.14
512 128 2 768 1.025 499.73 41.000 6.24 42.024 18.28
512 128 3 896 1.029 497.74 40.861 9.40 41.889 21.39
512 128 4 1024 1.025 499.31 40.541 12.63 41.566 24.64
512 128 5 1152 1.026 498.87 40.685 15.73 41.712 27.62
512 128 6 1280 1.029 497.39 41.978 18.30 43.008 29.76
512 128 7 1408 1.031 496.63 40.992 21.86 42.023 33.51
512 128 8 1536 1.029 497.56 41.915 24.43 42.944 35.77
512 128 16 2560 1.031 496.44 43.351 47.24 44.383 57.68
512 128 32 4608 1.033 495.83 52.089 78.63 53.122 86.74

Master

PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
512 128 1 640 1.016 503.94 4.318 29.65 5.334 119.99
512 128 2 768 1.013 505.28 48.181 5.31 49.195 15.61
512 128 3 896 1.018 502.93 46.285 8.30 47.303 18.94
512 128 4 1024 1.017 503.32 46.901 10.92 47.918 21.37
512 128 5 1152 1.019 502.61 47.411 13.50 48.429 23.79
512 128 6 1280 1.020 501.80 47.845 16.05 48.866 26.19
512 128 7 1408 1.021 501.69 47.975 18.68 48.996 28.74
512 128 8 1536 1.046 489.63 48.433 21.14 49.479 31.04
512 128 16 2560 1.020 501.90 53.560 38.24 54.580 46.90
512 128 32 4608 1.025 499.35 64.507 63.50 65.532 70.32

I tried messing with the MMQ_X / MMQ_Y / NWARPS constants for Q5_K and Q6_K (the quants the model actually uses), but it didn't seem to make much of a difference.

@ggerganov (Owner, Author) commented:

The MMQ constants take effect only if you run with MMQ=1 (set argv[5] to 1)

@KerfuffleV2 (Collaborator) commented Oct 23, 2023

> The MMQ constants take effect only if you run with MMQ=1 (set argv[5] to 1)

Ahh, derp.


PR

Tweaked

PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
512 128 1 640 2.948 173.66 4.043 31.66 6.992 91.54
512 128 2 768 2.955 173.29 7.518 34.05 10.473 73.33
512 128 3 896 2.963 172.77 7.655 50.17 10.618 84.39
512 128 4 1024 2.966 172.64 7.750 66.07 10.716 95.56
512 128 5 1152 2.966 172.63 7.853 81.50 10.819 106.48
512 128 6 1280 2.970 172.39 7.976 96.29 10.946 116.94
512 128 7 1408 2.970 172.37 8.147 109.98 11.117 126.65
512 128 8 1536 2.989 171.28 8.350 122.63 11.340 135.45
512 128 16 2560 2.972 172.26 15.349 133.43 18.321 139.73
512 128 32 4608 2.976 172.03 32.125 127.50 35.101 131.28

Non-tweaked

PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
512 128 1 640 0.991 516.90 4.007 31.94 4.998 128.06
512 128 2 768 0.991 516.50 17.215 14.87 18.207 42.18
512 128 3 896 0.994 515.10 17.382 22.09 18.376 48.76
512 128 4 1024 0.996 514.23 17.465 29.32 18.461 55.47
512 128 5 1152 0.997 513.33 17.580 36.41 18.577 62.01
512 128 6 1280 0.998 512.96 17.669 43.47 18.667 68.57
512 128 7 1408 0.999 512.76 17.805 50.32 18.803 74.88
512 128 8 1536 1.000 512.18 17.992 56.91 18.992 80.88
512 128 16 2560 1.000 511.99 19.738 103.76 20.738 123.45
512 128 32 4608 1.001 511.72 26.297 155.76 27.298 168.80

Master

Tweaked

PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
512 128 1 640 3.757 136.28 4.013 31.90 7.770 82.37
512 128 2 768 2.950 173.55 13.647 18.76 16.597 46.27
512 128 3 896 2.955 173.26 13.631 28.17 16.586 54.02
512 128 4 1024 2.957 173.17 14.361 35.65 17.318 59.13
512 128 5 1152 2.959 173.02 14.878 43.02 17.837 64.59
512 128 6 1280 2.962 172.83 15.303 50.19 18.265 70.08
512 128 7 1408 2.958 173.06 15.415 58.13 18.373 76.63
512 128 8 1536 2.965 172.67 15.907 64.38 18.872 81.39
512 128 16 2560 2.965 172.70 25.846 79.24 28.811 88.86
512 128 32 4608 2.976 172.07 45.309 90.40 48.285 95.43

Non-tweaked

PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
512 128 1 640 0.982 521.50 4.075 31.41 5.057 126.56
512 128 2 768 0.983 520.68 22.996 11.13 23.980 32.03
512 128 3 896 0.987 518.87 22.990 16.70 23.977 37.37
512 128 4 1024 0.988 518.35 23.714 21.59 24.701 41.46
512 128 5 1152 0.989 517.69 24.227 26.42 25.216 45.69
512 128 6 1280 0.989 517.68 24.627 31.19 25.616 49.97
512 128 7 1408 0.990 516.94 24.687 36.29 25.678 54.83
512 128 8 1536 0.990 517.12 25.182 40.66 26.173 58.69
512 128 16 2560 0.989 517.57 30.094 68.05 31.084 82.36
512 128 32 4608 0.989 517.65 39.087 104.79 40.076 114.98

Tweaked here means using

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index e2dea9e..43d7825 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4041,8 +4041,8 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
-#define  MMQ_X_Q5_K_RDNA2  64
-#define  MMQ_Y_Q5_K_RDNA2  128
+#define  MMQ_X_Q5_K_RDNA2  8
+#define  MMQ_Y_Q5_K_RDNA2  32
 #define NWARPS_Q5_K_RDNA2  8
 #define  MMQ_X_Q5_K_RDNA1  32
 #define  MMQ_Y_Q5_K_RDNA1  64
@@ -4102,8 +4102,8 @@ mul_mat_q5_K(
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
-#define  MMQ_X_Q6_K_RDNA2  64
-#define  MMQ_Y_Q6_K_RDNA2  128
+#define  MMQ_X_Q6_K_RDNA2  8
+#define  MMQ_Y_Q6_K_RDNA2  32
 #define NWARPS_Q6_K_RDNA2  8
 #define  MMQ_X_Q6_K_RDNA1  32
 #define  MMQ_Y_Q6_K_RDNA1  64

I couldn't increase NWARPS without causing a compile error. I also have no idea what changes would be good for this weak RDNA2 card; I tried to do something similar to what you suggested to see what would happen.

I also tried running with NGL 0. I wasn't expecting to see any difference between the PR and master, but the PR is consistently a little bit faster even when only the prompt is using the GPU.

If you/anyone has suggestions for different changes to try, I'm willing to help test.


edit: I tried various settings:

| MX | MY  | NW | PP1   | PP2   | TG1  | TG2  | TG3  | TG4  | TG5  | TG6  | TG7   |
| -: | --: | -: | ----: | ----: | ---: | ---: | ---: | ---: | ---: | ---: | ----: |
|  4 | 32  |  4 | 157.6 | 190.5 | 29.8 | 37.9 | 55.8 | 73.4 | 82.2 | 97.3 | 111.3 |
|  8 | 32  |  8 | 172.4 | 172.2 | 30.0 | 34.0 | 50.1 | 66.0 | 81.3 | 96.0 | 109.6 |
| 16 | 32  |  8 | 252.3 | 251.2 | 31.9 | 27.4 | 41.6 | 54.9 | 68.0 | 80.7 |  92.9 |
| 16 | 32  |  4 | 223.7 | 286.5 | 31.9 | 31.4 | 46.3 | 61.1 | 75.4 | 89.5 | 102.8 |
| 64 | 128 |  8 | 516.0 | 515.8 | 29.2 | 14.4 | 21.9 | 29.1 | 36.2 | 43.3 |  50.1 |
| 32 | 32  |  8 | 284.1 | 283.8 | 31.9 | 17.2 | 21.6 | 34.0 | 42.3 | 50.5 |  58.5 |

If you don't really care about prompt speed, then at least on this hardware 4,32,4 looks best. I only included PP for batch sizes 1 and 2 because it really only varied going from 1 to 2; after that it pretty much always stayed the same.

@ggerganov (Owner, Author) commented:

@slaren

Using cublasGemmStridedBatchedEx makes it even faster, but I think we can apply it only for MQA.
With Grouped-Query Attention, I don't think we can use it.
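
A hedged host-side sketch (illustrative names, not the PR's code) of why a single constant stride does not cover GQA: each K/V head is shared by a group of query heads, so the K pointers repeat within a group while the Q pointers keep advancing, and the per-GEMM K offsets are no longer one arithmetic progression:

```c++
// Hypothetical sketch: per-head pointer arrays for grouped-query attention.
#include <cuda_fp16.h>
#include <vector>

void build_gqa_ptrs(const half *K, const half *Q,
                    size_t k_head_stride, size_t q_head_stride,
                    int n_head, int n_head_kv,
                    std::vector<const void *> & k_ptrs,
                    std::vector<const void *> & q_ptrs) {
    const int group = n_head / n_head_kv; // query heads served by one K/V head
    k_ptrs.resize(n_head);
    q_ptrs.resize(n_head);
    for (int h = 0; h < n_head; ++h) {
        k_ptrs[h] = K + (h / group) * k_head_stride; // repeats inside each group
        q_ptrs[h] = Q +  h          * q_head_stride; // regular arithmetic progression
    }
    // MQA is the n_head_kv == 1 special case: every k_ptrs entry is identical,
    // so that batch can be described without a pointer array at all.
}
```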

  • V100 LLaMA 7B F16, cublasGemmBatchedEx
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
512 128 1 640 0.174 2948.07 2.612 49.01 2.785 229.78
512 128 2 768 0.168 3046.75 3.280 78.04 3.448 222.71
512 128 3 896 0.160 3194.83 3.327 115.41 3.488 256.92
512 128 4 1024 0.161 3189.71 3.394 150.87 3.554 288.10
512 128 5 1152 0.160 3194.75 3.501 182.82 3.661 314.66
512 128 6 1280 0.164 3118.81 3.618 212.27 3.782 338.42
512 128 7 1408 0.163 3135.99 3.709 241.58 3.872 363.62
512 128 8 1536 0.160 3190.97 3.717 275.47 3.878 396.11
512 128 16 2560 0.161 3182.71 4.518 453.25 4.679 547.09
512 128 32 4608 0.164 3119.50 5.308 771.73 5.472 842.15
  • V100 LLaMA 7B F16, cublasGemmStridedBatchedEx
PP TG B N_KV T_PP s S_PP t/s T_TG s S_TG t/s T s S t/s
512 128 1 640 0.173 2957.91 2.619 48.88 2.792 229.24
512 128 2 768 0.161 3189.63 3.099 82.61 3.259 235.63
512 128 3 896 0.160 3205.57 3.163 121.40 3.323 269.66
512 128 4 1024 0.160 3203.92 3.194 160.30 3.354 305.32
512 128 5 1152 0.160 3203.38 3.275 195.40 3.435 335.35
512 128 6 1280 0.160 3198.94 3.425 224.21 3.585 357.01
512 128 7 1408 0.162 3167.38 3.502 255.85 3.664 384.31
512 128 8 1536 0.162 3165.10 3.551 288.33 3.713 413.66
512 128 16 2560 0.161 3186.79 4.366 469.05 4.527 565.50
512 128 32 4608 0.162 3161.76 5.216 785.23 5.378 856.79

Review comment on ggml-cuda.cu (outdated), lines 7180 to 7182:
CUDA_CHECK(cudaMemcpy(src0_ptrs_as, src0_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(src1_ptrs_as, src1_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy( dst_ptrs_as, dst_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
@ggerganov (Owner, Author) commented Oct 24, 2023:

Does anyone know whether, if I changed these cudaMemcpy calls to cudaMemcpyAsync, I would need to add some synchronization before calling cublasGemmBatchedEx?

A collaborator replied:

You only need to run them on the same stream, but I don't think that this can be made async because the host memory may already be freed by the time the copy happens. Running memcpy asynchronously also requires using host pinned memory.

If the cublasGemmBatchedEx needs to stay to support GQA, I would consider writing a kernel to calculate these values and calling cublas from the kernel. Additionally, cudaMalloc is usually very slow, which is why we have the memory pool. These allocations should be changed to use the memory pool.
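
A hedged sketch of the "compute the pointers in a kernel" idea (illustrative names and layout, not the PR's implementation): the pointer arrays are filled on the device, on the same stream as the batched GEMM, so no host-built array, pinned memory or extra synchronization is needed.

```c++
// Hypothetical sketch: fill the per-head pointer arrays on the device.
#include <cuda_fp16.h>

__global__ void compute_batched_ptrs(const half *K, const half *Q, half *KQ,
                                     const void **k_ptrs, const void **q_ptrs, void **kq_ptrs,
                                     size_t k_stride, size_t q_stride, size_t kq_stride,
                                     int n_head) {
    const int h = blockIdx.x * blockDim.x + threadIdx.x;
    if (h >= n_head) {
        return;
    }
    // Each thread fills one slot of the three pointer arrays.
    k_ptrs [h] = K  + h * k_stride;
    q_ptrs [h] = Q  + h * q_stride;
    kq_ptrs[h] = KQ + h * kq_stride;
}
```

It would be launched on the stream used for the following cublasGemmBatchedEx call, e.g. compute_batched_ptrs<<<1, n_head, 0, stream>>>(...), so ordering is guaranteed by the stream itself.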

@ggerganov (Owner, Author) replied:

I reduced the mallocs from 3 to 1, but when I try to replace it with ggml_cuda_pool_malloc() (see commented code) the computation crashes with an illegal memory access somewhere later. I couldn't figure out the cause - probably some memory alignment issue in the pointer arrays.
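
Purely as a generic illustration of the alignment concern (whether it is the actual cause of the crash is unclear): when several pointer arrays are carved out of one allocation, each sub-array offset can be rounded up to a fixed alignment. The sketch below uses a plain cudaMalloc and a hypothetical 256-byte alignment, not the ggml memory pool.

```c++
// Hypothetical sketch: one allocation split into three aligned pointer arrays.
#include <cuda_runtime.h>
#include <cstddef>

static size_t align_up(size_t n, size_t a) {
    return (n + a - 1) & ~(a - 1);
}

cudaError_t alloc_ptr_arrays(int n_head, void **block,
                             const void ***k_ptrs, const void ***q_ptrs, void ***kq_ptrs) {
    const size_t sub = align_up(n_head * sizeof(void *), 256); // size of one aligned sub-array
    cudaError_t  err = cudaMalloc(block, 3 * sub);             // single allocation instead of three
    if (err != cudaSuccess) {
        return err;
    }
    char *base = (char *) *block;
    *k_ptrs  = (const void **)  base;             // offset 0
    *q_ptrs  = (const void **) (base +     sub);  // offset rounded up to 256 B
    *kq_ptrs = (void **)       (base + 2 * sub);
    return cudaSuccess;
}
```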

@ggerganov (Owner, Author) commented:

I've been testing this quite extensively today and I think it is an all-around improvement compared to master. The F16 batched decoding performance is now quite respectable and I believe llama.cpp now offers a very good solution for hosted inference.

@ggerganov merged commit 2b4ea35 into master on Oct 24, 2023 (32 checks passed)
mattgauf added a commit to mattgauf/llama.cpp that referenced this pull request Oct 27, 2023
* master: (350 commits)
  speculative : ensure draft and target model vocab matches (ggerganov#3812)
  llama : correctly report GGUFv3 format (ggerganov#3818)
  simple : fix batch handling (ggerganov#3803)
  cuda : improve text-generation and batched decoding performance (ggerganov#3776)
  server : do not release slot on image input (ggerganov#3798)
  batched-bench : print params at start
  log : disable pid in log filenames
  server : add parameter -tb N, --threads-batch N (ggerganov#3584) (ggerganov#3768)
  server : do not block system prompt update (ggerganov#3767)
  sync : ggml (conv ops + cuda MSVC fixes) (ggerganov#3765)
  cmake : add missed dependencies (ggerganov#3763)
  cuda : add batched cuBLAS GEMM for faster attention (ggerganov#3749)
  Add more tokenizer tests (ggerganov#3742)
  metal : handle ggml_scale for n%4 != 0 (close ggerganov#3754)
  Revert "make : add optional CUDA_NATIVE_ARCH (ggerganov#2482)"
  issues : separate bug and enhancement template + no default title (ggerganov#3748)
  Update special token handling in conversion scripts for gpt2 derived tokenizers (ggerganov#3746)
  llama : remove token functions with `context` args in favor of `model` (ggerganov#3720)
  Fix baichuan convert script not detecing model (ggerganov#3739)
  make : add optional CUDA_NATIVE_ARCH (ggerganov#2482)
  ...