llamafile : improve moe prompt eval speed on cpu #6840
base: master
Conversation
Force-pushed from def794c to 828f3fe
nice to see this PR <3 Thank you so much
Does it also help the other K quants?
@USBhost Unfortunately no. The K quants were designed to exploit under-utilization of CPU resources when doing matvecs. I tried copying and pasting the …
I see, thanks for the explanation. Side note: I would love a doc that explains the speed differences between plain Q4_0 vs Q4_1 vs the K quants, because I keep seeing the simple ones getting buffs.
The tinyBLAS code upstreamed by Mozilla's llamafile project makes prompt processing go very fast for F32, F16, Q4_0, and Q8_0, as measured on an AMD Ryzen Threadripper PRO 7995WX with TinyLlama 1.1B. This PR ensures those performance wins will happen for MoE models too.
Note: I'm still in the process of testing this change and verifying it's correct on all compilers and architectures.
Force-pushed from 26ab943 to 89991a1
OK I've worked out the remaining kinks. This code was just shipped as part of the llamafile 0.8 release. Thanks to this change, I'm seeing a 2x prompt eval speed increase across the board. My Threadripper now runs Mixtral 2x faster. My M2 Ultra runs Mixtral 2x faster on CPU. This change even pumps up the Raspberry Pi 5 to 78 tok/sec performance on non-MoE F16 models, in case you want to buy a bag full of the things to build your next supercomputer. PTAL.
I became intrigued by your assumption that block tiling is required to speed up prompt processing for k-quants, so I spent some time optimizing the k-quant CPU matrix multiplications. I'm running on a 16-core Ryzen 7950X CPU, so I have only done a better AVX2 implementation. Baseline for this CPU (using your PR) for a 7B LLaMA, versus what I get for k-quants: (benchmark tables not reproduced here). There are 3 ingredients involved in this speedup: (list not reproduced here).
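For readers unfamiliar with the term, "block tiling" here means computing a small tile of the output matrix per pass over the shared dimension, so that every loaded value of A and B feeds several multiply-adds instead of one. A minimal scalar sketch of the idea (illustrative only; not the actual tinyBLAS or k-quant kernels):

```cpp
#include <cstddef>

// Block-tiled C = A * B for row-major float matrices (MxK times KxN).
// Each 4x4 output tile is accumulated in registers, so every element of A
// loaded in the inner loop is reused for 4 columns of B, and vice versa.
void gemm_tiled_4x4(const float *A, const float *B, float *C,
                    size_t M, size_t K, size_t N) {
    for (size_t i = 0; i + 4 <= M; i += 4) {
        for (size_t j = 0; j + 4 <= N; j += 4) {
            float acc[4][4] = {};
            for (size_t p = 0; p < K; ++p)
                for (size_t ii = 0; ii < 4; ++ii)
                    for (size_t jj = 0; jj < 4; ++jj)
                        acc[ii][jj] += A[(i + ii) * K + p] * B[p * N + (j + jj)];
            for (size_t ii = 0; ii < 4; ++ii)
                for (size_t jj = 0; jj < 4; ++jj)
                    C[(i + ii) * N + (j + jj)] = acc[ii][jj];
        }
    }
    // Edge handling for M or N not divisible by 4 is omitted for brevity.
}
```

Real kernels keep the accumulator tile in vector registers and handle the leftover rows and columns, but the data-reuse argument is the same.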
Force-pushed from f1a134a to c34c472
That's outstanding news @ikawrakow! I can't wait to see your code. Now I won't need to recommend the legacy quantization formats. Am I correct in understanding you used …
Yes, I used …
@ikawrakow Receiving a PR from you would honor the llamafile project. What you'd want to do is create a copy of …
Force-pushed from e717fec to e1c02a7
Here's a benchmark of an AMD V3C48 (a Zen 3 part) with this PR: (benchmark table not reproduced here). Without these changes, prompt processing for BF16 falls well behind F16. (I'm still a bit confused as to why F16 performs so much better than BF16 without tinyBLAS, and whether there is still something left on the table, but at least this way there is no compromise in using BF16 now.)
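As an aside on the BF16 point above: a bf16 value is just the upper 16 bits of an IEEE-754 float32, so widening it for a GEMM path is only a shift and a reinterpret. A generic illustration, not ggml's exact helper:

```cpp
#include <cstdint>
#include <cstring>

// bf16 is the top 16 bits of a float32, so widening it is just a left shift
// followed by a bit-for-bit reinterpretation as float.
static inline float bf16_to_f32(uint16_t bf16_bits) {
    uint32_t f32_bits = (uint32_t)bf16_bits << 16;
    float out;
    std::memcpy(&out, &f32_bits, sizeof(out));  // safe type-punning
    return out;
}
```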
@lemmi the biggest changes here will show up when running Mixtral (rather than Mistral), because MoE models use MUL_MAT_ID, which is what this change speeds up.
Take a look at the failing CI run: https://github.com/ggerganov/llama.cpp/actions/runs/9052429493/job/24870086560?pr=6840
@jart I think the following patch should fix the CI:

```diff
diff --git a/ggml-impl.h b/ggml-impl.h
index d85b152b..85d3f23f 100644
--- a/ggml-impl.h
+++ b/ggml-impl.h
@@ -17,6 +17,9 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
+// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
 /**
  * Converts brain16 to float32.
  *
diff --git a/ggml-quants.c b/ggml-quants.c
index 00334c5f..3677b2db 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -22,9 +22,6 @@
 #define UNUSED GGML_UNUSED
-// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
-#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
-
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 // multiply int8_t, add results pairwise twice
 static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
```
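For context, the macro being hoisted into ggml-impl.h is a portable stand-in for _mm256_set_m128i, which older compilers such as gcc 7 lack. A small standalone illustration of what it does (assumes an AVX2-capable build, e.g. g++ -mavx2):

```cpp
#include <immintrin.h>
#include <cstdio>

// Portable replacement for _mm256_set_m128i: put `b` in the low 128 bits
// and insert `a` into the high 128 bits of the resulting 256-bit vector.
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

int main() {
    __m128i lo = _mm_set1_epi32(1);
    __m128i hi = _mm_set1_epi32(2);
    __m256i v  = MM256_SET_M128I(hi, lo);   // lanes 0-3 = 1, lanes 4-7 = 2

    int out[8];
    _mm256_storeu_si256((__m256i *)out, v);
    for (int i = 0; i < 8; i++) std::printf("%d ", out[i]);  // prints: 1 1 1 1 2 2 2 2
    std::printf("\n");
    return 0;
}
```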
@ggerganov I've updated this change with your suggestion. Test flake looks unrelated. PTAL.
On M2 Ultra I'm observing some TG regression for the models below:

```sh
./scripts/compare-commits.sh master pr/6840 \
    -m models/mixtral-instruct-8x7b-fast/ggml-model-f16.gguf \
    -m models/mixtral-instruct-8x7b-fast/ggml-model-q8_0.gguf \
    -m models/mixtral-instruct-8x7b-fast/ggml-model-q4_0.gguf -t 16 -ngl 0
```

@jart Do you observe this regression on your M2 Ultra?
Thanks for pointing that out. I just reproduced the same regression. This change doesn't appear to be helpful for text generation, so I've disabled it. PTAL.
@jart just a heads up that this was marked as merge ready, but CI is not passing. If the failure isn't related to your code changes, you may want to rebase against the latest known-working CI state in master; as I recall, we had issues with CI in the master branch around that time.
This change introduces a llamafile_mixmul() API that allows tinyBLAS to speed up "Mixture of Experts" models. On my Threadripper, Mixtral's 8x7b F16 weights now process prompts 2x faster. I'm also seeing a 60 percent improvement with Mixtral 8x22b Q4_0. Q8_0 is supported by tinyBLAS as well. MoE models spend the majority of their time inside MUL_MAT_ID rather than MUL_MAT, which is why llamafile_sgemm() was not able to help them before. llamafile_mixmul() works by decomposing the mixmul operation into fast 2D llamafile_sgemm() calls. This change also adds BF16 support to tinyBLAS.
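To make the decomposition idea concrete, here is a heavily simplified, float-only sketch (hypothetical names; not the actual llamafile_mixmul()/llamafile_sgemm() signatures, and it assumes exactly one expert per token rather than the top-k routing real MoE layers use): group the token rows routed to each expert, run one dense 2D matmul per expert, then scatter the results back.

```cpp
#include <cstddef>
#include <vector>

// Naive dense matmul: C[m x n] = A[m x k] * B[k x n], row-major.
// Stand-in for a fast sgemm kernel.
static void sgemm(const float *A, const float *B, float *C,
                  size_t m, size_t k, size_t n) {
    for (size_t i = 0; i < m; ++i)
        for (size_t j = 0; j < n; ++j) {
            float acc = 0.0f;
            for (size_t p = 0; p < k; ++p)
                acc += A[i * k + p] * B[p * n + j];
            C[i * n + j] = acc;
        }
}

// Sketch of the decomposition: instead of multiplying each token row by its
// selected expert's weights one row at a time (a matvec), group rows by
// expert and run one 2D matmul per expert.
void mixmul_sketch(const std::vector<const float *> &expert_weights, // per expert: [k x n]
                   const float *tokens,                              // [n_tokens x k]
                   const std::vector<int> &expert_id,                // expert chosen per token
                   float *out,                                       // [n_tokens x n]
                   size_t k, size_t n) {
    const size_t n_tokens = expert_id.size();
    for (size_t e = 0; e < expert_weights.size(); ++e) {
        // Gather the rows routed to expert e.
        std::vector<size_t> rows;
        for (size_t t = 0; t < n_tokens; ++t)
            if (expert_id[t] == (int)e) rows.push_back(t);
        if (rows.empty()) continue;

        std::vector<float> A(rows.size() * k), C(rows.size() * n);
        for (size_t r = 0; r < rows.size(); ++r)
            for (size_t p = 0; p < k; ++p)
                A[r * k + p] = tokens[rows[r] * k + p];

        sgemm(A.data(), expert_weights[e], C.data(), rows.size(), k, n);

        // Scatter the results back into the original token order.
        for (size_t r = 0; r < rows.size(); ++r)
            for (size_t j = 0; j < n; ++j)
                out[rows[r] * n + j] = C[r * n + j];
    }
}
```

The payoff is that each expert's weight matrix is streamed through the cache once per batch instead of once per routed token, which is where the prompt-eval speedup comes from.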