Demo usage of Flash Attention #778

Closed · ggerganov wants to merge 1 commit into master from flash-attn

Conversation

ggerganov
Owner

This is my understanding of how Flash Attention works based on this picture:

[Flash Attention algorithm diagram from the flash-attention repository]

ref: https://github.com/HazyResearch/flash-attention

The implementation is here:

https://github.com/ggerganov/llama.cpp/blob/flash-attn/ggml.c#L8122-L8367
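In essence, the trick in that picture is to process K and V in blocks and keep only a running max, a running sum, and a partial output between blocks (the "online softmax"), so the full QK^T matrix is never materialized. A minimal standalone sketch of the idea for a single query row (toy sizes, plain C; not the actual ggml.c code linked above):

#include <math.h>
#include <stdio.h>

#define D     4   /* head dimension (toy size)            */
#define N_KV  8   /* number of key/value positions        */
#define BLOCK 4   /* how many keys are processed per step */

static void flash_attn_row(const float *q,      /* [D]       */
                           const float *k,      /* [N_KV][D] */
                           const float *v,      /* [N_KV][D] */
                           float       *out) {  /* [D]       */
    const float scale = 1.0f/sqrtf((float) D);

    float M = -INFINITY;   /* running max of the scores     */
    float S = 0.0f;        /* running sum of exp(score - M) */
    float acc[D] = {0};    /* running (unnormalized) output */

    for (int j0 = 0; j0 < N_KV; j0 += BLOCK) {
        float s[BLOCK];
        float m_blk = -INFINITY;

        /* scores for this block: s[j] = scale * dot(q, k_j) */
        for (int j = 0; j < BLOCK; ++j) {
            float dot = 0.0f;
            for (int d = 0; d < D; ++d) dot += q[d]*k[(j0 + j)*D + d];
            s[j] = scale*dot;
            if (s[j] > m_blk) m_blk = s[j];
        }

        /* new running max; rescale the previous accumulators */
        const float M_new = m_blk > M ? m_blk : M;
        const float corr  = expf(M - M_new);
        S *= corr;
        for (int d = 0; d < D; ++d) acc[d] *= corr;

        /* accumulate this block into S and acc */
        for (int j = 0; j < BLOCK; ++j) {
            const float p = expf(s[j] - M_new);
            S += p;
            for (int d = 0; d < D; ++d) acc[d] += p*v[(j0 + j)*D + d];
        }

        M = M_new;
    }

    /* final normalization: out = acc / S */
    for (int d = 0; d < D; ++d) out[d] = acc[d]/S;
}

int main(void) {
    float q[D] = {0.1f, 0.2f, 0.3f, 0.4f};
    float k[N_KV*D], v[N_KV*D], out[D];
    for (int i = 0; i < N_KV*D; ++i) { k[i] = 0.01f*i; v[i] = 0.02f*i; }
    flash_attn_row(q, k, v, out);
    for (int d = 0; d < D; ++d) printf("out[%d] = %f\n", d, out[d]);
    return 0;
}

The actual ggml.c version linked above additionally handles the causal mask and multi-threading, but the per-row bookkeeping follows the same pattern.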

I don't plan on merging this because on M1 the performance is the same as without FA.
However, in whisper.cpp I have gained performance by using this exact same call in the Encoder:

https://github.com/ggerganov/whisper.cpp/blob/0a2d1210bcb98978214bbf4e100922a413afd39d/whisper.cpp#L1482-L1508

Putting this here in case someone wants to play with it or figures out how to implement sparse attention.
The idea is simply to merge the ggml operators into a single op and avoid the intermediate tensors, as in the sketch below.
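For reference, here is roughly what the fusion amounts to at the graph level. This is an illustrative sketch, not a verbatim copy of llama.cpp, and the exact ggml signatures have changed over time, so treat the calls as an approximation of the API at the time of this PR:

#include "ggml.h"

/* unfused path: five ops and four intermediate tensors per layer */
static struct ggml_tensor * attn_unfused(struct ggml_context * ctx0,
                                         struct ggml_tensor  * Q,
                                         struct ggml_tensor  * K,
                                         struct ggml_tensor  * V_trans,  /* transposed V */
                                         int                   n_past,
                                         float                 scale) {
    struct ggml_tensor * KQ          = ggml_mul_mat(ctx0, K, Q);
    struct ggml_tensor * KQ_scaled   = ggml_scale(ctx0, KQ, ggml_new_f32(ctx0, scale));
    struct ggml_tensor * KQ_masked   = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
    struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
    return ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
}

/* fused path: a single op, no intermediate tensors in the graph */
static struct ggml_tensor * attn_fused(struct ggml_context * ctx0,
                                       struct ggml_tensor  * Q,
                                       struct ggml_tensor  * K,
                                       struct ggml_tensor  * V) {
    return ggml_flash_attn(ctx0, Q, K, V, /*masked =*/ true);
}

In the unfused path the whole KQ matrix is materialized for every token (and it grows with context length); the fused op streams over K/V and only keeps small per-row accumulators.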

@bakamomi

bakamomi commented Apr 5, 2023

Please merge this because it's amazing on x86 with longer context. I tried generating 1500 tokens with the 7B model ( --ignore-eos -c 2048 -n 1500). On the master branch the generation took 1385 seconds. On the flash-attn branch it took 200 seconds.

Base automatically changed from fix-cpy to master April 5, 2023 19:07
@rabidcopy
Contributor

rabidcopy commented Apr 5, 2023

Strange, when comparing #775 to this I noticed a regression in the time it took to generate 1024 tokens.
#775

llama_print_timings:        load time =  2776.69 ms
llama_print_timings:      sample time =   801.28 ms /  1024 runs   (    0.78 ms per run)
llama_print_timings: prompt eval time =  1912.48 ms /    14 tokens (  136.61 ms per token)
llama_print_timings:        eval time = 189655.06 ms /  1023 runs   (  185.39 ms per run)
llama_print_timings:       total time = 193245.67 ms

#778 + #775 (fluke)

llama_print_timings:        load time =  2745.93 ms
llama_print_timings:      sample time =   814.02 ms /  1024 runs   (    0.79 ms per run)
llama_print_timings: prompt eval time =  1880.81 ms /    14 tokens (  134.34 ms per token)
llama_print_timings:        eval time = 250896.90 ms /  1023 runs   (  245.26 ms per run)
llama_print_timings:       total time = 254470.03 ms

Please merge this because it's amazing on x86 with longer context. I tried generating 1500 tokens with the 7B model ( --ignore-eos -c 2048 -n 1500). On the master branch the generation took 1385 seconds. On the flash-attn branch it took 200 seconds.

Are you certain that uplift isn't a result of #775? If you cloned the flash-attn branch it included that commit.

Edit: I will run some more tests just to make sure this isn't coincidental on my machine.
Edit 2: It was a fluke. Re-running on this PR again, I got a slightly better result this time.

llama_print_timings:        load time =  2875.85 ms
llama_print_timings:      sample time =   802.98 ms /  1024 runs   (    0.78 ms per run)
llama_print_timings: prompt eval time =  1938.22 ms /    14 tokens (  138.44 ms per token)
llama_print_timings:        eval time = 180435.09 ms /  1023 runs   (  176.38 ms per run)
llama_print_timings:       total time = 184126.72 ms

@bakamomi

bakamomi commented Apr 5, 2023

Alright, #775 clearly contributed to the results I got. I pulled master again with #775 already merged and now I'm getting:

llama_print_timings:        load time =   928.29 ms
llama_print_timings:      sample time =   859.52 ms /  1500 runs   (    0.57 ms per run)
llama_print_timings: prompt eval time =   454.82 ms /     8 tokens (   56.85 ms per token)
llama_print_timings:        eval time = 200382.48 ms /  1500 runs   (  133.59 ms per run)
llama_print_timings:       total time = 202195.50 ms

Exactly the same result as with flash attention. Just for reference this is what I got previously:

llama_print_timings:        load time =  1817.75 ms
llama_print_timings:      sample time =   857.91 ms /  1500 runs   (    0.57 ms per run)
llama_print_timings: prompt eval time =  1454.90 ms /     8 tokens (  181.86 ms per token)
llama_print_timings:        eval time = 1383048.29 ms /  1500 runs   (  922.03 ms per run)
llama_print_timings:       total time = 1385748.68 ms

I guess FA needs more testing.

@rabidcopy
Contributor

rabidcopy commented Apr 5, 2023

Alright, #775 clearly contributed to the results I got. I pulled master again with #775 already merged and now I'm getting exactly the same result as with flash attention. [...] I guess FA needs more testing.

Wow, that's quite a dramatic change nonetheless! I guess some systems were hit way harder than others by the V transpose on every token.

@slaren
Collaborator

slaren commented Apr 5, 2023

I couldn't find a measurable difference between this and master on a 9900k.

@rabidcopy
Contributor

Yeah, no noticeable difference on a Ryzen 2600 either. Still, it will be interesting to see if this can go somewhere.

@jon-chuang
Contributor

There's a good chance that the CPU is more bottlenecked by compute than the GPU is, and that the original implementation already prefetches cache lines.

See: Dao-AILab/flash-attention#59

@jon-chuang jon-chuang mentioned this pull request Apr 12, 2023
@ggerganov ggerganov added the "demo" label (Demonstrate some concept or idea, not intended to be merged) on Apr 22, 2023
@NikolaBorisov

Would this implementation also work on GPUs? Has anyone tried it to see how well it performs there?

@KerfuffleV2 KerfuffleV2 mentioned this pull request Sep 20, 2023
@WilliamTambellini
Contributor

@ggerganov Thanks. What about merging it but disabling it by default, so that people can at least try it from master via a CLI arg?

@Aya-ZIbra

How do you do the benchmarking?

@Waylon-Zhu

Is this the right way to implement it? I think the benefit on the GPU comes from using shared memory, which has much higher bandwidth. On the CPU, should the block data be kept in registers to get a similar bandwidth win?

@sroussey
Contributor

A faster Metal implementation:

https://github.com/philipturner/metal-flash-attention

cc: @philipturner

@WilliamTambellini
Contributor

@ggerganov I now see ggml_flash_attn_ext/back in recent ggml, and it is already used in llama.cpp (when flash_attn is true), so should this PR now be closed? Thanks

@philipturner

I just got pinged for this PR. Does LLaMA.cpp even exist anymore? It was a thing like 1.5 years ago.

@ggerganov ggerganov closed this Aug 29, 2024