cuda : add batched cuBLAS GEMM for faster attention #3749
Conversation
These changes plus:
#define cublasGemmBatchedEx hipblasGemmBatchedEx
are needed to compile with ROCm. I haven't done performance testing, but it seems to work.
I couldn't figure out how to propose a change for lines outside what the pull changed. Also, this is the first time I'm trying to create a multi-part review, so please forgive me if I mess something up.
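For readers hitting the same ROCm build issue: ggml-cuda.cu already contains a hipBLAS compatibility block that maps cuBLAS names to hipBLAS equivalents via #define. Below is a minimal sketch of that pattern; the header path and surrounding defines are from memory and may not match the source exactly, and only the cublasGemmBatchedEx line corresponds to the suggestion above.

```c++
// Sketch of the hipBLAS compatibility shim pattern used in ggml-cuda.cu.
// Only the cublasGemmBatchedEx mapping is the suggestion from this review;
// the other lines illustrate the existing pattern and may differ from the source.
#if defined(GGML_USE_HIPBLAS)
#include <hipblas/hipblas.h>
#define cublasGemmEx          hipblasGemmEx
#define cublasGemmBatchedEx   hipblasGemmBatchedEx   // <- addition needed for this PR
#define cublasSgemm           hipblasSgemm
#define CUBLAS_OP_N           HIPBLAS_OP_N
#define CUBLAS_OP_T           HIPBLAS_OP_T
#else
#include <cublas_v2.h>
#endif
```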
@ggerganov I am so sorry, it said "commit suggestion". I thought I was just committing the suggested changes as a comment, not changing your pull. I definitely did not mean to do that. And I don't even know how to revert it. :(
I ran the PR and master benchmarks (results collapsed). I tried messing with the MMQ_X/Y and NWARPS stuff for Q5_K and Q6_K (the quants the model actually used), but it didn't seem to make much of a difference.
The MMQ constants take effect only if you run with MMQ=1 (set argv[5] to 1).
Ahh, derp.

(Results collapsed: PR tweaked, PR non-tweaked, master tweaked, master non-tweaked.)
Tweaked here means using:

```diff
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index e2dea9e..43d7825 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4041,8 +4041,8 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
-#define MMQ_X_Q5_K_RDNA2 64
-#define MMQ_Y_Q5_K_RDNA2 128
+#define MMQ_X_Q5_K_RDNA2 8
+#define MMQ_Y_Q5_K_RDNA2 32
 #define NWARPS_Q5_K_RDNA2 8
 #define MMQ_X_Q5_K_RDNA1 32
 #define MMQ_Y_Q5_K_RDNA1 64
@@ -4102,8 +4102,8 @@ mul_mat_q5_K(
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
-#define MMQ_X_Q6_K_RDNA2 64
-#define MMQ_Y_Q6_K_RDNA2 128
+#define MMQ_X_Q6_K_RDNA2 8
+#define MMQ_Y_Q6_K_RDNA2 32
 #define NWARPS_Q6_K_RDNA2 8
 #define MMQ_X_Q6_K_RDNA1 32
 #define MMQ_Y_Q6_K_RDNA1 64
```

I couldn't increase NWARPS without causing a compile error. I also have no idea what changes would be good for this weak RDNA2 card; I tried to do something similar to what you suggested to see what would happen. I also tried running with NGL 0. I wasn't expecting to see any difference between the PR and master, but the PR is consistently a little bit faster even when only the prompt is using the GPU. If you/anyone has suggestions for different changes to try, I'm willing to help test.

edit: I tried various settings:
If you don't really care about prompt speed, at least on this hardware 4,32,4 looks best. I only included PP for batch sizes 1 and 2 because it really only varied going from 1 to 2; after that it pretty much always stayed the same.
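For context on what these constants control, here is a simplified sketch (names and layout are assumptions based on the mul_mat_q launch pattern in ggml-cuda.cu, not the exact code): MMQ_X/MMQ_Y are the tile sizes each thread block processes and NWARPS is the number of warps per block, so shrinking them launches more, smaller blocks, which can help occupancy on a small GPU at some cost in per-block work.

```c++
// Simplified sketch of how MMQ_X/MMQ_Y/NWARPS-style constants feed the kernel
// launch configuration (illustrative only; the real code picks per-architecture values).
#include <cuda_runtime.h>

#define WARP_SIZE 32

// nrows_x: rows of src0, ncols_y: columns of src1
static void mmq_launch_dims(int nrows_x, int ncols_y,
                            int mmq_x, int mmq_y, int nwarps,
                            dim3 & block_nums, dim3 & block_dims) {
    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; // tiles along the rows of src0
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; // tiles along the columns of src1
    block_nums = dim3(block_num_x, block_num_y, 1);        // grid: one block per tile
    block_dims = dim3(WARP_SIZE, nwarps, 1);               // block: nwarps warps of 32 threads
}
```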
ggml-cuda.cu (outdated):

```c++
CUDA_CHECK(cudaMemcpy(src0_ptrs_as, src0_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(src1_ptrs_as, src1_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy( dst_ptrs_as,  dst_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
```
Does anyone know: if I changed these cudaMemcpy calls to cudaMemcpyAsync, would I need to add some synchronization before calling cublasGemmBatchedEx?
You only need to run them on the same stream, but I don't think that this can be made async, because the host memory may already be freed by the time the copy happens. Running memcpy asynchronously also requires using host pinned memory.

If the cublasGemmBatchedEx needs to stay to support GQA, I would consider writing a kernel to calculate these values and calling cuBLAS from the kernel. Additionally, cudaMalloc is usually very slow, which is why we have the memory pool. These allocations should be changed to use the memory pool.
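To make the constraint concrete, here is a minimal sketch (not the PR's code; the function and buffer names are made up) of what an async version would have to satisfy: the host pointer table must be pinned and must stay alive until the copy completes, and the copy and the batched GEMM must be ordered on the same stream.

```c++
// Minimal sketch of the async-copy constraints discussed above (illustrative names).
#include <cublas_v2.h>
#include <cuda_runtime.h>

void upload_ptrs_and_gemm(cublasHandle_t handle, cudaStream_t stream,
                          void ** src0_ptrs_host,  // must be pinned (cudaMallocHost) and
                                                   // stay alive until the copy completes
                          void ** src0_ptrs_dev,   // device array of device pointers
                          int     ne23) {
    // With pageable host memory this silently degrades to a synchronous copy.
    cudaMemcpyAsync(src0_ptrs_dev, src0_ptrs_host, ne23*sizeof(void *),
                    cudaMemcpyHostToDevice, stream);

    // Ordering the GEMM after the copy only requires putting cuBLAS on the same stream.
    cublasSetStream(handle, stream);
    // ... cublasGemmBatchedEx(handle, ...) using src0_ptrs_dev would follow here ...
}
```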
I reduced the mallocs from 3 to 1, but when I try to replace it with ggml_cuda_pool_malloc() (see commented code), the computation crashes with an illegal memory access somewhere later. Couldn't figure out the cause - probably some memory alignment issue in the pointer arrays.
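One plausible culprit (purely a guess that matches the alignment hypothesis above, not taken from the PR): if the three pointer tables are carved out of a single allocation, each sub-table's offset needs to be padded to a suitable alignment. A sketch of what that padding would look like, with the 256-byte alignment chosen as an assumption:

```c++
// Hypothetical illustration of packing the three pointer tables (src0/src1/dst)
// into one allocation with padded offsets; names and the 256-byte alignment are
// assumptions for the example.
#include <cstddef>

static size_t align_up(size_t offset, size_t alignment) {
    return (offset + alignment - 1) / alignment * alignment;
}

static void ptr_table_offsets(size_t ne23, size_t offsets[3]) {
    const size_t table_bytes = ne23 * sizeof(void *);
    offsets[0] = 0;
    offsets[1] = align_up(offsets[0] + table_bytes, 256);
    offsets[2] = align_up(offsets[1] + table_bytes, 256);
    // total size to request from the pool: offsets[2] + table_bytes
}
```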
I've been testing this quite extensively today and I think it is an all-around improvement compared to master.
* master: (350 commits)
  speculative : ensure draft and target model vocab matches (ggerganov#3812)
  llama : correctly report GGUFv3 format (ggerganov#3818)
  simple : fix batch handling (ggerganov#3803)
  cuda : improve text-generation and batched decoding performance (ggerganov#3776)
  server : do not release slot on image input (ggerganov#3798)
  batched-bench : print params at start
  log : disable pid in log filenames
  server : add parameter -tb N, --threads-batch N (ggerganov#3584) (ggerganov#3768)
  server : do not block system prompt update (ggerganov#3767)
  sync : ggml (conv ops + cuda MSVC fixes) (ggerganov#3765)
  cmake : add missed dependencies (ggerganov#3763)
  cuda : add batched cuBLAS GEMM for faster attention (ggerganov#3749)
  Add more tokenizer tests (ggerganov#3742)
  metal : handle ggml_scale for n%4 != 0 (close ggerganov#3754)
  Revert "make : add optional CUDA_NATIVE_ARCH (ggerganov#2482)"
  issues : separate bug and enhancement template + no default title (ggerganov#3748)
  Update special token handling in conversion scripts for gpt2 derived tokenizers (ggerganov#3746)
  llama : remove token functions with `context` args in favor of `model` (ggerganov#3720)
  Fix baichuan convert script not detecing model (ggerganov#3739)
  make : add optional CUDA_NATIVE_ARCH (ggerganov#2482)
  ...
resolve #3479
Description
For baseline performance on master, check the info in #3726.

We identified the KQ and KQV operations in the attention layer as the bottleneck. To overcome this, we parallelize the GEMMs across the head dimension using cublasGemmBatchedEx.

The single-batch text generation speed is the same, but maybe it can be improved for large contexts. I did some perplexity tests as well and I think everything works correctly, although I tested mostly F16 models. I don't think the accuracy of quantum models would be affected by this change, but still - need to double-check.
Please take a careful look and let me know if you observe improved performance and if the ppl is good
Results
Below are some results for A100. Also did similar tests for V100 and the results are very similar.
Batched decoding for 7B F16 models
Batched decoding for 1B F16 models
llama-bench for 7B F16
Serving multiple clients with parallel decoding and continuous batching
After update with cublasGemmStridedBatchedEx and Codellama 7B F16:
Perplexity for F16 7B
make -j perplexity && ./bin/perplexity -m ../models/openllama-7b/ggml-model-f16.gguf -f ./wikitext-2-raw/wiki.test.raw -ngl 100 -nommq
Quantum batched decoding (Q4_0)
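Since the PR was later updated to use cublasGemmStridedBatchedEx, here is a hedged sketch of the design difference (again with illustrative names, not the repository's code): when every head's K, Q, and KQ slices sit at a constant stride in memory, the strided variant needs no per-head pointer tables at all, which removes the extra allocations and host-to-device pointer copies discussed earlier in this thread.

```c++
// Illustrative sketch of the strided variant: constant per-head strides replace
// the pointer tables, so no extra allocations or H2D copies are needed.
#include <cublas_v2.h>

void kq_for_all_heads_strided(cublasHandle_t handle,
                              const void * K, const void * Q, void * KQ,
                              int d_head, int n_kv, int n_tokens, int n_head) {
    const float alpha = 1.0f, beta = 0.0f;
    cublasGemmStridedBatchedEx(handle,
        CUBLAS_OP_T, CUBLAS_OP_N,
        n_kv, n_tokens, d_head,
        &alpha,
        K,  CUDA_R_16F, d_head, (long long) d_head * n_kv,      // stride between heads of K
        Q,  CUDA_R_16F, d_head, (long long) d_head * n_tokens,  // stride between heads of Q
        &beta,
        KQ, CUDA_R_16F, n_kv,   (long long) n_kv * n_tokens,    // stride between heads of KQ
        n_head,
        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}
```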