-
Notifications
You must be signed in to change notification settings - Fork 10k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix flash attention for ROCm #7011
base: master
Are you sure you want to change the base?
Conversation
I didn't close that other PR on accident. As I said before, I don't think we should be adding a dependency with rocWMMA when the performance is no better than master and we have no dev to test and support it. And I will do an implementation of FlashAttention without any tensor cores at all which may end up being faster anyways. |
I don't know how to get compile on windows :( |
Sorry, I didn't realize it had been closed on purpose. Is the dependency that bad, though? rocwmma is header only, so no link time requirement, and it enables sharing the existing CUDA code. The performance is not better, but the VRAM saving can be very significant, 1 GB in one case. The PR is not ready to merge as is anyway, I need to disable flash-attn in CMake by default for AMD GPUs, or enable it only if rocwmma is detected installed. I might not be a ROCm expert, but I am a C++ dev and I own a 7900xtx, if not merged, I might maintain this fork anyway. Of course, if you already have planned to work on that other implementation soon, all of this comment is irrelevant, but having access to a rocwmma based version as a comparison could be useful, I don't know. Please let me know what you think. |
I wasn't able to test flash attention on Windows with 7900XTX yet. |
So i can say that for CDNA this makes a big difference: This pr:
Lastest Master:
Both of those are still terrible compared to exllama but this pr dose make a big difference in the right direction Note that i had to make some trivial changes to this pr to make it choose the wmma path for gfx908 |
Id like to mention it here too, that after some optimization work to the gemm kernels (#8082) this pr now improves pp performance on CDNA by almost 2x and i really think the stance towards this pr needs to be revised. A tiny optional header only dependency is for sure worth a 2x or even 10% increase in speed and the fact that the cuda equivalent depedancy is fine but the rocm equivalent is not speaks volumes, as dose the comment on rocm perfomance here: #7716. |
My original plan was to buy an AMD GPU with tensor cores so that I can test and maintain these changes myself (currently I only have an RX 6800). But I currently have issues finding a slot for it in one of my machines. However, if I can get a pledge from you that you will help with maintenance I would be fine with merging a PR like this. Keep in mind though that the WMMA FlashAttention kernels that I wrote for CUDA are bad in the first place. They rely on the "high-level" WMMA functionality to use tensor cores but after talking to an NVIDIA engineer and doing some related work myself the better way to utilize tensor cores is via PTX instructions (CUDA equivalent of assembly). So I want to at some point rewrite the code accordingly. Instead of rocWMMA it would be much better to implement the equivalent AMD functionality in |
i cant accept maintainership of llamacpp/hip. I can promise to run regular testing (automated even if desired) on cdna. The current state of affairs also strongly discourage any optimization effort on my and others part, as even if you do some work optimize the hip back end, and even if you manage to get that merged, the nvidia centric churn in the common code base invetiably breaks performance again, usually only shortly later. also note that gfx11's wmma and gfx908/a/4x's mfma are very different with totally different hw implementation performance characteristics. |
When I make changes to the CUDA code I test it for correctness and performance using my RX 6800. My standard for numerical software is that correct results are the first and foremost priority. I very much do not want to have any broken code in my repositories. So if I cannot test or assert that the code produces correct results myself and if I also cannot delegate this to anyone else then I am simply not willing to merge the corresponding PR. The simple rocWMMA prototype that I did still required fixes from other people to work at all. My current stance towards HIP performance is that I am willing to invest some effort for AMD support "within reason". When it comes to MMQ in particular the performance depends very heavily on the exact data layout and for good AMD performance you would have to completely re-write the code anyways. |
@JohannesGaessler RDNA3 doesn't have dedicated tensor cores like CDNA does. So you will not see the same 2x perf boost @IMbackK sees. This is happening because rocWMMA translates to MFMA instructions on CDNA archs, which in turn runs directly on their matrix (tensor) cores. On RDNA3, rocWMMA translates to WMMA instructions that run on the shader cores which don't give as much of a perf boost as the dedicated matrix cores. This is why you are not seeing much of a perf boost when not using rocWMMA on RDNA3. I would highly recommend this PR gets merged for the 2x perf boost on CDNA alone. Otherwise you are not using those matrix cores at all. You might remember me recommending rocWMMA a while ago in #4801 (comment) Now, I realize that you are going to deprecate this soon. But let's not leave this 2x perf on the table and keep this code even if it breaks. I wish I had a AMD GPU locally to maintain this, but alas, I don't have one at the moment since I left AMD... |
@JohannesGaessler what's your email address? I reached out to AMD to see if someone can lend a hand in maintenance. If possible please share your email so they could reach out to you to see if they can support. |
My email address can be found on my Github profile. But as I said, as of right now my plan is to remove the WMMA-based implementation and I don't want to invest the effort to maintain it long-term. |
Is my understanding correct that AMD GPUs with no matrix cores (for example, I have 2x AMD MI60 - gfx906 with no matrix cores) can see good improvements in text generation (say >=1.5x) if you implement generic flash attention with no matrix core dependency? On a similar note, Is it possible to use a custom compiled flash attention that works with ROCm (e.g. AMD MI60) in llama.cpp? Someone on reddit shared that they have successfully compiled a flash attention library for MI60. Some benchmarks they shared for that compiled flash attention:
It seems there is almost 10x improvement in fwd pass and around 4x speed up in bwd pass compared to Pytorch. |
There already are implementations that do this, the performance on AMD is just bad, especially for large batch sizes (i.e. prompt processing). CUDA has seen performance improvements comparable to what you're asking so my expectation would be that a proper ROCm implementation instead of a HIP port of the CUDA code would have a similar speedup.
llama.cpp/GGML has no support whatsoever for external FlashAttention implementations. |
Hello, any update on it? Is it possible for us to see it merged? |
llama-bench
buffer = ROCm0 compute buffer size