From 2dd5d1f4b38a9f58e9eabcc3d810962d4e00a1d3 Mon Sep 17 00:00:00 2001
From: Justine Tunney
Date: Fri, 28 Jun 2024 16:18:33 -0700
Subject: [PATCH] llamafile : improve moe prompt eval speed on cpu

This change introduces a llamafile_mixmul() API that allows tinyBLAS to
speed up "Mixture of Experts" models. On my Threadripper, the Mixtral
8x7b F16 weights now process prompts 2x faster. I am also seeing a 60
percent improvement with Mixtral 8x22b Q4_0. Q8_0 weights are supported
as well, since tinyBLAS also handles that format.

MoE models spend most of their time in MUL_MAT_ID rather than MUL_MAT,
which is why llamafile_sgemm() was not able to help them before. The
new code works by decomposing the mixmul operation into fast 2d
llamafile_sgemm() calls.

This change also adds BF16 support to tinyBLAS.
---
 common/common.cpp   |   6 +-
 ggml/include/ggml.h |  15 +
 ggml/src/ggml.c     |  29 +-
 ggml/src/sgemm.cpp  | 774 ++++++++++++++++++++++++++++++++++++++------
 ggml/src/sgemm.h    |   8 +
 5 files changed, 714 insertions(+), 118 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 6a00d25be1316..0e0f4fa07315c 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -78,7 +78,7 @@ using json = nlohmann::ordered_json;
 //
 int32_t cpu_get_num_physical_cores() {
-#ifdef __linux__
+#if defined(__linux__) || defined(__COSMOPOLITAN__)
     // enumerate the set of thread siblings, num entries is num cores
     std::unordered_set<std::string> siblings;
     for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) {
@@ -113,7 +113,7 @@ int32_t cpu_get_num_physical_cores() {
     return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
 }
 
-#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
+#if defined(__x86_64__) && (defined(__linux__) || defined(__COSMOPOLITAN__)) && !defined(__ANDROID__)
 #include <pthread.h>
 
 static void cpuid(unsigned leaf, unsigned subleaf,
@@ -167,7 +167,7 @@ static int cpu_count_math_cpus(int n_cpu) {
  * Returns number of CPUs on system that are useful for math.
*/ int32_t cpu_get_num_math() { -#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) +#if defined(__x86_64__) && (defined(__linux__) || defined(__COSMOPOLITAN__)) && !defined(__ANDROID__) int n_cpu = sysconf(_SC_NPROCESSORS_ONLN); if (n_cpu < 1) { return cpu_get_num_physical_cores(); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index d895c9acdb596..10e3834822e57 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -650,6 +650,21 @@ extern "C" { enum ggml_cgraph_eval_order order; }; + struct ggml_compute_state_shared; + + struct ggml_compute_params { + // ith = thread index, nth = number of threads + int ith, nth; + + // work buffer for all threads + size_t wsize; + void * wdata; + + struct ggml_compute_state_shared * shared; + }; + + void ggml_barrier(struct ggml_compute_state_shared * shared); + // scratch buffer struct ggml_scratch { size_t offs; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f5502afbe98b3..70521884c3274 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1754,17 +1754,6 @@ struct ggml_compute_state { struct ggml_compute_state_shared * shared; }; -struct ggml_compute_params { - // ith = thread index, nth = number of threads - int ith, nth; - - // work buffer for all threads - size_t wsize; - void * wdata; - - struct ggml_compute_state_shared * shared; -}; - // // fundamental operations // @@ -2857,7 +2846,7 @@ inline static void ggml_critical_section_start(void) { } #ifdef GGML_USE_OPENMP -static void ggml_barrier(struct ggml_compute_state_shared * shared) { +void ggml_barrier(struct ggml_compute_state_shared * shared) { if (shared->n_threads == 1) { return; } @@ -2865,7 +2854,7 @@ static void ggml_barrier(struct ggml_compute_state_shared * shared) { #pragma omp barrier } #else -static void ggml_barrier(struct ggml_compute_state_shared * shared) { +void ggml_barrier(struct ggml_compute_state_shared * shared) { if (shared->n_threads == 1) { return; } @@ -12306,11 +12295,16 @@ static void ggml_compute_forward_mul_mat_id( const struct ggml_tensor * src1 = dst->src[1]; const struct ggml_tensor * ids = dst->src[2]; - GGML_TENSOR_BINARY_OP_LOCALS +#if GGML_USE_LLAMAFILE + if (llamafile_mixmul(params, src0, src1, ids, dst)) + return; +#endif const int ith = params->ith; const int nth = params->nth; + GGML_TENSOR_BINARY_OP_LOCALS + const enum ggml_type type = src0->type; const bool src1_cont = ggml_is_contiguous(src1); @@ -18536,6 +18530,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur = 0; const struct ggml_tensor * src0 = node->src[0]; const struct ggml_tensor * src1 = node->src[1]; +#if GGML_USE_LLAMAFILE + const struct ggml_tensor * src2 = node->src[2]; +#endif const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; if (src1->type != vec_dot_type) { cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)); @@ -18544,6 +18541,10 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += GGML_PAD(cur, sizeof(int64_t)); // align cur += n_as * sizeof(int64_t); // matrix_row_counts cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows +#if GGML_USE_LLAMAFILE + size_t cur2 = llamafile_mixmul_needs(src0, src1, src2); + cur = cur > cur2 ? cur : cur2; +#endif } break; case GGML_OP_OUT_PROD: { diff --git a/ggml/src/sgemm.cpp b/ggml/src/sgemm.cpp index 6626ceb26213f..6b61d905fb770 100644 --- a/ggml/src/sgemm.cpp +++ b/ggml/src/sgemm.cpp @@ -21,13 +21,15 @@ // SOFTWARE. 
// -// _ _ ___ _ _ ___ -// | |_(_)_ _ _ _| _ ) | /_\ / __| -// | _| | ' \ || | _ \ |__ / _ \\__ \. -// \__|_|_||_\_, |___/____/_/ \_\___/ -// |__/ // -// BASIC LINEAR ALGEBRA SUBPROGRAMS +// ██████╗ ██╗ █████╗ ██████╗ +// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝ +// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗ +// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║ +// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║ +// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝ +// +// BASIC LINEAR ALGEBRA SUBPROGRAMS // // // This file implements multithreaded CPU matrix multiplication for the @@ -52,6 +54,10 @@ #include "ggml-impl.h" #include "ggml-quants.h" +#define ROW_ALIGN 64 +#define MATRIX_ALIGN 4096 +#define MAX_ALIGN 4096 + #ifdef _MSC_VER #define NOINLINE __declspec(noinline) #else @@ -64,14 +70,61 @@ #define VECTOR_REGISTERS 16 #endif +#if 0 +#define NOT_SUPPORTED tinyBLAS_not_supported(__FILE__, __LINE__) +#else +#define NOT_SUPPORTED false +#endif +#define WANT_QUANTIZATION false + #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) namespace { +bool tinyBLAS_not_supported(const char *file, int line) { + fprintf(stderr, "%s:%d: tinyBLAS not supported\n", file, line); + return false; +} + inline float unhalf(ggml_fp16_t d) { return GGML_FP16_TO_FP32(d); } +inline float unhalf(ggml_bf16_t d) { + return GGML_BF16_TO_FP32(d); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// MATRIX MEMORY INDEXING + +#define NCA 1 +#define NCB 2 +#define NCC 4 + +#define INDEX(A, lda, j, i) (CONFIG & NC##A ? ((T##A *const *)A)[j] + i : A + lda * (j) + i) + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GGML TYPE TRAITS + +template +struct ggml_type_trait; +template <> +struct ggml_type_trait { + static constexpr ggml_type id = GGML_TYPE_F32; +}; +template <> +struct ggml_type_trait { + static constexpr ggml_type id = GGML_TYPE_BF16; +}; +template <> +struct ggml_type_trait { + static constexpr ggml_type id = GGML_TYPE_F16; +}; +template <> +struct ggml_type_trait { + static constexpr ggml_type id = GGML_TYPE_Q8_0; +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED ARITHMETIC OPERATIONS @@ -144,6 +197,13 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { #endif #endif +#if defined(__AVX512BF16__) +template <> +inline __m512 madd(__m512bh x, __m512bh y, __m512 z) { + return _mm512_dpbf16_ps(z, x, y); +} +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED HORIZONTAL SUM @@ -194,10 +254,18 @@ inline float hsum(__m512 x) { template T load(const U *); +template <> inline float load(const float *p) { return *p; } +template <> inline float load(const ggml_fp16_t *p) { return unhalf(*p); } +template <> inline float load(const ggml_bf16_t *p) { return unhalf(*p); } + #if defined(__ARM_NEON) template <> inline float32x4_t load(const float *p) { return vld1q_f32(p); } +template <> +inline float32x4_t load(const ggml_bf16_t *p) { + return vreinterpretq_f32_u32(vshll_n_u16(vld1_u16((const unsigned short *)p), 16)); +} #if !defined(_MSC_VER) template <> inline float16x8_t load(const ggml_fp16_t *p) { return vld1q_f16((const float16_t *)p); @@ -220,6 +288,13 @@ template <> inline __m256 load(const float *p) { } #endif // __AVX__ +#if defined(__AVX2__) || defined(__AVX512F__) +template <> inline 
__m256 load(const ggml_bf16_t *p) { + return _mm256_castsi256_ps( + _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16)); +} +#endif // __AVX2__ + #if defined(__F16C__) template <> inline __m256 load(const ggml_fp16_t *p) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p)); @@ -233,12 +308,42 @@ template <> inline __m512 load(const float *p) { template <> inline __m512 load(const ggml_fp16_t *p) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p)); } +template <> inline __m512 load(const ggml_bf16_t *p) { + return _mm512_castsi512_ps( + _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16)); +} #endif // __AVX512F__ +#if defined(__AVX512BF16__) +template <> +inline __m512bh load(const ggml_bf16_t *p) { + return (__m512bh)_mm512_loadu_ps((const float *)p); +} +template <> +inline __m512bh load(const float *p) { + return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p)); +} +#endif // __AVX512BF16__ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// FLOATING POINT OUTPUT STREAMING + +inline void store(float *p, float f) { + *p = f; +} + +inline void store(ggml_fp16_t *p, float f) { + *p = GGML_FP32_TO_FP16(f); +} + +inline void store(ggml_bf16_t *p, float f) { + *p = GGML_FP32_TO_BF16(f); +} + //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION -template +template class tinyBLAS { public: tinyBLAS(int64_t k, @@ -249,7 +354,7 @@ class tinyBLAS { : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int64_t m, int64_t n) { + void matmul(long m, long n) { mnpack(0, m, 0, n); } @@ -420,14 +525,18 @@ class tinyBLAS { int64_t jj = n0 + job % xtiles * RN; D Cv[RN][RM] = {}; for (int64_t l = 0; l < k; l += KN) +#pragma GCC unroll 100 for (int64_t j = 0; j < RN; ++j) +#pragma GCC unroll 100 for (int64_t i = 0; i < RM; ++i) - Cv[j][i] = madd(load(A + lda * (ii + i) + l), - load(B + ldb * (jj + j) + l), - Cv[j][i]); + Cv[j][i] = madd(load(INDEX(A, lda, ii + i, l)), // + load(INDEX(B, ldb, jj + j, l)), // + Cv[j][i]); +#pragma GCC unroll 100 for (int64_t j = 0; j < RN; ++j) +#pragma GCC unroll 100 for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i])); } } @@ -446,18 +555,18 @@ class tinyBLAS { // QUANT ZERO MATRIX MULTIPLICATION #if defined(__ARM_FEATURE_DOTPROD) -template +template class tinyBLAS_Q0_ARM { public: tinyBLAS_Q0_ARM(int64_t k, const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, + const TB *B, int64_t ldb, float *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int64_t m, int64_t n) { + void matmul(long m, long n) { mnpack(0, m, 0, n); } @@ -539,15 +648,15 @@ class tinyBLAS_Q0_ARM { Cv[j][i] = vmlaq_n_f32(Cv[j][i], vcvtq_f32_s32(vdotq_s32( vdotq_s32(vdupq_n_s32(0), - load_lo(A + lda * (ii + i) + l), - load_lo(B + ldb * (jj + j) + l)), - load_hi(A + lda * (ii + i) + l), - load_hi(B + ldb * (jj + j) + l))), - unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)); + load_lo(INDEX(A, lda, ii + i, l)), + load_lo(INDEX(B, ldb, jj + j, l))), + load_hi(INDEX(A, lda, ii + i, l)), + load_hi(INDEX(B, ldb, jj + j, l)))), + unhalf(INDEX(A, lda, ii + i, l)->d) * + unhalf(INDEX(B, ldb, jj + j, l)->d)); for (int64_t j = 0; j < RN; ++j) for 
(int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i])); } } @@ -571,8 +680,8 @@ class tinyBLAS_Q0_ARM { } const TA *const A; - const block_q8_0 *const B; - float *const C; + const TB *const B; + TC *const C; const int64_t k; const int64_t lda; const int64_t ldb; @@ -583,7 +692,7 @@ class tinyBLAS_Q0_ARM { #endif // __ARM_FEATURE_DOTPROD #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) -template +template class tinyBLAS_Q0_AVX { public: tinyBLAS_Q0_AVX(int64_t k, @@ -594,7 +703,7 @@ class tinyBLAS_Q0_AVX { : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int64_t m, int64_t n) { + void matmul(long m, long n) { mnpack(0, m, 0, n); } @@ -726,15 +835,15 @@ class tinyBLAS_Q0_AVX { for (int64_t j = 0; j < RN; ++j) for (int64_t i = 0; i < RM; ++i) { #if defined(__AVX2__) - __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - _mm256_sign_epi8(load(B + ldb * (jj + j) + l), - load(A + lda * (ii + i) + l))); + __m256 udTmp = updot(_mm256_sign_epi8(load(INDEX(A, lda, ii + i, l)), + load(INDEX(A, lda, ii + i, l))), + _mm256_sign_epi8(load(INDEX(B, ldb, jj + j, l)), + load(INDEX(A, lda, ii + i, l)))); #else - __m128i ali0 = load0(A + lda * (ii + i) + l); - __m128i ali1 = load1(A + lda * (ii + i) + l); - __m128i blj0 = load0(B + ldb * (jj + j) + l); - __m128i blj1 = load1(B + ldb * (jj + j) + l); + __m128i ali0 = load0(INDEX(A, lda, ii + i, l)); + __m128i ali1 = load1(INDEX(A, lda, ii + i, l)); + __m128i blj0 = load0(INDEX(B, ldb, jj + j, l)); + __m128i blj1 = load1(INDEX(B, ldb, jj + j, l)); __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); @@ -747,14 +856,14 @@ class tinyBLAS_Q0_AVX { __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1); __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0))); #endif - Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)), + Cv[j][i] = madd(_mm256_set1_ps(unhalf(INDEX(A, lda, ii + i, l)->d) * + unhalf(INDEX(B, ldb, jj + j, l)->d)), udTmp, Cv[j][i]); } for (int64_t j = 0; j < RN; ++j) for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i])); } } @@ -857,6 +966,9 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda assert(nth > 0); assert(ith < nth); + if (n < 2) + return NOT_SUPPORTED; + if (Ctype != GGML_TYPE_F32) return false; @@ -864,105 +976,166 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda case GGML_TYPE_F32: { if (Btype != GGML_TYPE_F32) - return false; + return NOT_SUPPORTED; #if defined(__AVX512F__) if (k % 16) - return false; - tinyBLAS<16, __m512, __m512, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + return NOT_SUPPORTED; + tinyBLAS<0, 16, __m512, __m512, float, float, float> tb{ + k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; #elif defined(__AVX__) || defined(__AVX2__) if (k % 8) - return false; - tinyBLAS<8, __m256, __m256, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + return NOT_SUPPORTED; + tinyBLAS<0, 8, __m256, __m256, float, float, float> tb{ + k, (const float *)A, lda, (const float *)B, ldb, (float 
*)C, ldc, ith, nth}; tb.matmul(m, n); return true; #elif defined(__ARM_NEON) - if (n < 4) - return false; if (k % 4) - return false; - tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + return NOT_SUPPORTED; + tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, float> tb{ + k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; #else - return false; + return NOT_SUPPORTED; #endif } - case GGML_TYPE_F16: { -#if defined(__AVX512F__) + case GGML_TYPE_BF16: { +#if defined(__AVX512BF16__) + if (k % 32) + return NOT_SUPPORTED; + if (Btype == GGML_TYPE_F32 && n < 2) { + tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, float> tb{ + k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; + tb.matmul(m, n); + return true; + } + if (Btype == GGML_TYPE_F32) + return WANT_QUANTIZATION; + if (Btype != GGML_TYPE_BF16) + return NOT_SUPPORTED; + tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ + k, (const ggml_bf16_t *)A, lda, (const ggml_bf16_t *)B, ldb, (float *)C, ldc, ith, nth}; + tb.matmul(m, n); + return true; +#elif defined(__AVX512F__) if (k % 16) - return false; - if (Btype != GGML_TYPE_F32) - return false; - tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + return NOT_SUPPORTED; + tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, float> tb{ + k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; -#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) +#elif defined(__AVX2__) if (k % 8) - return false; + return NOT_SUPPORTED; if (Btype != GGML_TYPE_F32) - return false; - tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + return NOT_SUPPORTED; + tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, float> tb{ + k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; -#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) - if (n < 8) - return false; - if (k % 8) - return false; +#elif defined(__ARM_NEON) && !defined(_MSC_VER) + if (k % 4) + return NOT_SUPPORTED; + if (Btype != GGML_TYPE_F32) + return NOT_SUPPORTED; + tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, float> tb{ + k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; + tb.matmul(m, n); + return true; +#else + return NOT_SUPPORTED; +#endif + } + + case GGML_TYPE_F16: { +#if defined(__AVX512F__) + if (k % 16) + return NOT_SUPPORTED; + if (Btype == GGML_TYPE_F32 && n < 2) { + tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; + tb.matmul(m, n); + return true; + } + if (Btype == GGML_TYPE_F32) + return WANT_QUANTIZATION; if (Btype != GGML_TYPE_F16) - return false; - tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const ggml_fp16_t *)B, ldb, - (float *)C, ldc, - ith, nth}; + return NOT_SUPPORTED; + tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; +#elif 
(defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) + if (X86_CHECK(F16C)) { + if (k % 8) + return NOT_SUPPORTED; + if (Btype == GGML_TYPE_F32 && n < 2) { + tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; + tb.matmul(m, n); + return true; + } + if (Btype == GGML_TYPE_F32) + return WANT_QUANTIZATION; + if (Btype != GGML_TYPE_F16) + return NOT_SUPPORTED; + tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, (float *)C, ldc, ith, nth}; + tb.matmul(m, n); + return true; + } else { + return NOT_SUPPORTED; + } +#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) + if (n < 2) + // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec? + return NOT_SUPPORTED; + if (precision == GGML_PREC_F32) { + if (k % 4) + return NOT_SUPPORTED; + if (Btype != GGML_TYPE_F32) + return NOT_SUPPORTED; + tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; + tb.matmul(m, n); + return true; + } else { + if (k % 8) + return NOT_SUPPORTED; + if (Btype == GGML_TYPE_F32) + return WANT_QUANTIZATION; + if (Btype != GGML_TYPE_F16) + return NOT_SUPPORTED; + tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, (float *)C, ldc, ith, nth}; + tb.matmul(m, n); + return true; + } #elif defined(__ARM_NEON) && !defined(_MSC_VER) if (k % 4) - return false; + return NOT_SUPPORTED; if (Btype != GGML_TYPE_F32) - return false; - tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + return NOT_SUPPORTED; + tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; #else - return false; + return NOT_SUPPORTED; #endif } case GGML_TYPE_Q8_0: { + if (Btype == GGML_TYPE_F32) + return WANT_QUANTIZATION; if (Btype != GGML_TYPE_Q8_0) - return false; + return NOT_SUPPORTED; #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) - tinyBLAS_Q0_AVX tb{ + tinyBLAS_Q0_AVX<0, block_q8_0, block_q8_0, float> tb{ k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, @@ -970,7 +1143,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda tb.matmul(m, n); return true; #elif defined(__ARM_FEATURE_DOTPROD) - tinyBLAS_Q0_ARM tb{ + tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, float> tb{ k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, @@ -986,7 +1159,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda if (Btype != GGML_TYPE_Q8_0) return false; #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) - tinyBLAS_Q0_AVX tb{ + tinyBLAS_Q0_AVX<0, block_q4_0, block_q8_0, float> tb{ k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, @@ -994,7 +1167,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda tb.matmul(m, n); return true; #elif defined(__ARM_FEATURE_DOTPROD) - tinyBLAS_Q0_ARM tb{ + tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, float> tb{ k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, @@ -1025,3 +1198,402 @@ bool 
llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda (void)Btype; (void)Ctype; } + +// +// +// ██████╗ ██╗ █████╗ ██████╗ +// ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║ ██╔══██╗██╔═══╝ +// ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║ ███████║██████╗ +// ██║ ██║██▀███║╚███╔╝██╔══██╗██║ ██╔══██║╔═══██║ +// ██║ ██║██║ ██║ ███║ ██████╔╝████╗██║ ██║██████║ +// ╚═╝ ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝ ╚═╝╚═════╝ +// +// MIXTURE OF EXPERTS TENSOR MULTIPLICATION +// +// +// SHAPES +// +// - weights [cols, rows, experts] +// - thought [cols, tasks, tokens] w/ tasks ≤ thinkers +// - result [rows, thinkers, tokens] w/ thinkers ≤ experts +// - plan [thinkers, tokens] w/ i32 < experts +// +// DEFINITION +// +// for thinker in range(thinkers): +// for token in range(tokens): +// for row in range(rows): +// c = 0 +// for col in range(cols): +// expert = plan[token][thinker] +// a = weights[expert][row][col] +// b = thought[token][thinker % tasks][col] +// c += a * b +// result[token][thinker][row] = c +// +// REGULARITIES +// +// - tokens can be odd +// - thinkers is usually 2 +// - tasks is usually 1 or 2 +// - cols should be a multiple of 64 +// - rows should be a multiple of 64 +// - experts is usually 8 but could be 60 +// - tokens is always 1 for token generation +// - tokens can be huge for prompt processing +// +// EXAMPLE +// +// mixtral 8x7b w/ 217 token prompt +// +// | ne*0 ne*1 ne*2 ne*3 | nb*0 nb*1 nb*2 nb*3 | type +// ========================================================================= +// weights | 16384 6144 8 1 | 18 0x2400 0x3600000 0x1b000000 | q4_0 +// thought | 16384 2 217 1 | 4 0x10000 0x20000 0x1b20000 | f32 +// result | 6144 2 217 1 | 4 0x6000 0xc000 0xa2c000 | f32 +// plan | 2 217 1 1 | 4 0x20 0x1b20 0x1b20 | i32 +// + +namespace { + +class MixMul { + public: + MixMul(const ggml_compute_params *params, const ggml_tensor *weights, + const ggml_tensor *thought, const ggml_tensor *plan, ggml_tensor *result) + : params(params), + weights(weights), + thought(thought), + plan(plan), + result(result), + rows(weights->ne[1]), + cols(weights->ne[0]), + experts(weights->ne[2]), + thinkers(plan->ne[0]), + tasks(thought->ne[1]), + tokens(thought->ne[2]), + ldq((cols * 2 + ROW_ALIGN - 1) & -ROW_ALIGN), + wdata_((char *)(((uintptr_t)params->wdata + MAX_ALIGN - 1) & -MAX_ALIGN)), + allocated_(0) { + } + + bool allocate_shared_memory() { + if (!(quantized_thought_ = allocate(MATRIX_ALIGN, tokens * tasks * ldq))) + return false; + if (!(rowptr_result_ = allocate(ROW_ALIGN, experts * tokens * thinkers))) + return false; + if (!(rowptr_thought_ = allocate(ROW_ALIGN, experts * tokens * thinkers))) + return false; + if (!(rowptr_count_ = allocate(sizeof(long), experts))) + return false; + return true; + } + + size_t get_allocated_bytes() { + return (wdata_ - (char *)params->wdata) + allocated_; + } + + bool mixmul() { + + // invariants + assert(tasks <= thinkers); + assert(thinkers <= experts); + assert(tokens == plan->ne[1]); + assert(rows == result->ne[0]); + assert(cols == thought->ne[0]); + assert(tokens == result->ne[2]); + assert(thinkers == result->ne[1]); + + // dimensionality + assert(plan->ne[2] == 1); + assert(plan->ne[3] == 1); + assert(result->ne[3] == 1); + assert(weights->ne[3] == 1); + assert(thought->ne[3] == 1); + + // miscellaneous + assert(params->nth > 0); + assert(params->ith < params->nth); + assert(plan->type == GGML_TYPE_I32); + + // check nb01 is convertible to lda + if (weights->nb[1] % ggml_type_size(weights->type)) + return false; + + // no support for column strides 
+ if (result->nb[0] != ggml_type_size(result->type)) + return false; + if (thought->nb[0] != ggml_type_size(thought->type)) + return false; + if (weights->nb[0] != ggml_type_size(weights->type)) + return false; + + if (rows < 2) + return NOT_SUPPORTED; + + // supported output types + switch (result->type) { + case GGML_TYPE_F32: + return mixmuler(); + default: + return false; + } + } + + private: + template + bool mixmuler() { + switch (weights->type) { + + case GGML_TYPE_F32: + if (thought->type != GGML_TYPE_F32) + return false; +#if defined(__AVX512F__) + return mixmat<16, 1, tinyBLAS, float, + float, TC>(); +#elif defined(__AVX__) || defined(__AVX2__) + return mixmat<8, 1, tinyBLAS, float, + float, TC>(); +#elif defined(__SSE__) + return mixmat<4, 1, tinyBLAS, float, + float, TC>(); +#elif defined(__ARM_NEON) + return mixmat<4, 1, tinyBLAS, + float, float, TC>(); +#else + return false; +#endif + + case GGML_TYPE_BF16: + if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_BF16) + return false; +#if defined(__AVX512BF16__) + return mixmat< + 32, 1, tinyBLAS, + ggml_bf16_t, ggml_bf16_t, TC>(); +#elif defined(__AVX512F__) + return mixmat<16, 1, + tinyBLAS, + ggml_bf16_t, ggml_bf16_t, TC>(); +#elif defined(__AVX2__) + return mixmat<8, 1, + tinyBLAS, + ggml_bf16_t, ggml_bf16_t, TC>(); +#elif defined(__ARM_NEON) && !defined(_MSC_VER) + return mixmat< + 4, 1, + tinyBLAS, + ggml_bf16_t, ggml_bf16_t, TC>(); +#else + return false; +#endif + + case GGML_TYPE_F16: + if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_F16) + return false; +#if defined(__AVX512F__) + return mixmat<16, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, TC>(); +#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) + if (X86_CHECK(F16C)) { + return mixmat<8, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, TC>(); + } else { + return false; + } +#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) + if (result->op_params[0] == GGML_PREC_F32) { + return mixmat< + 4, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, TC>(); + } else { + return mixmat< + 8, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, TC>(); + } +#elif defined(__ARM_NEON) && !defined(_MSC_VER) + return mixmat< + 4, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, TC>(); +#else + return false; +#endif + + case GGML_TYPE_Q4_0: + if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) + return mixmat<32, 32, tinyBLAS_Q0_AVX, + block_q4_0, block_q8_0, TC>(); +#elif defined(__ARM_FEATURE_DOTPROD) + return mixmat<32, 32, tinyBLAS_Q0_ARM, + block_q4_0, block_q8_0, TC>(); +#else + return false; +#endif + + case GGML_TYPE_Q8_0: + if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) + return mixmat<32, 32, tinyBLAS_Q0_AVX, + block_q8_0, block_q8_0, TC>(); +#elif defined(__ARM_FEATURE_DOTPROD) + return mixmat<32, 32, tinyBLAS_Q0_ARM, + block_q8_0, block_q8_0, TC>(); +#else + return false; +#endif + + default: + return false; + } + } + + template + bool mixmat() { + if (cols % KN) + return false; + if (thought->type != ggml_type_trait::id) + quantize_thought(ggml_type_trait::id); + build_row_pointers(ggml_type_trait::id); + ggml_barrier(params->shared); + assert(!(cols % BS)); + assert(!(weights->nb[1] % sizeof(TA))); + for (int expert = 0; expert < experts; ++expert) { + BLAS tb{cols / BS, + (const TA *)((const char *)weights->data + expert * weights->nb[2]), + (long)(weights->nb[1] 
/ sizeof(TA)), + (const TB *)(rowptr_thought_ + expert * tokens * thinkers), + 0, + (TC *)(rowptr_result_ + expert * tokens * thinkers), + 0, + params->ith, + params->nth}; + tb.matmul(rows, rowptr_count_[expert]); + } + return true; + } + + void build_row_pointers(ggml_type vec_dot_type) { + for (int expert = params->ith; expert < experts; expert += params->nth) { + long count = 0; + for (long token = 0; token < tokens; ++token) + for (int thinker = 0; thinker < thinkers; ++thinker) + if (expert == *(const int32_t *)((const char *)plan->data + + token * plan->nb[1] + thinker * plan->nb[0])) { + long row = count++; + long idx = expert * thinkers * tokens + row; + rowptr_result_[idx] = + (uintptr_t)((char *)result->data + token * result->nb[2] + + thinker * result->nb[1]); + if (thought->type == vec_dot_type) + rowptr_thought_[idx] = + (uintptr_t)((char *)thought->data + token * thought->nb[2] + + thinker % tasks * thought->nb[1]); + else + rowptr_thought_[idx] = + (uintptr_t)((char *)quantized_thought_ + token * tasks * ldq + + thinker % tasks * ldq); + } + rowptr_count_[expert] = count; + } + } + + void quantize_thought(ggml_type vec_dot_type) { + long chore = 0; + for (long token = 0; token < tokens; ++token) + for (int task = 0; task < tasks; ++task) + if (chore++ % params->nth == params->ith) + quantize_row(quantized_thought_ + token * tasks * ldq + task * ldq, + (const float *)((const char *)thought->data + + token * thought->nb[2] + task * thought->nb[1]), + vec_dot_type); + } + + void quantize_row(void *dst, const float *src, ggml_type type) { + assert((long)ggml_row_size(type, cols) <= ldq); + switch (type) { + case GGML_TYPE_F16: + ggml_fp32_to_fp16_row(src, (ggml_fp16_t *)dst, cols); + break; + case GGML_TYPE_BF16: + ggml_fp32_to_bf16_row(src, (ggml_bf16_t *)dst, cols); + break; + case GGML_TYPE_Q8_0: + quantize_row_q8_0((const float *)src, (block_q8_0 *)dst, cols); + break; + default: + GGML_UNREACHABLE(); + } + } + + template + T *allocate(size_t align, size_t elems) { + T *res = nullptr; + size_t need = sizeof(T) * elems; + size_t base = allocated_; + base += align - 1; + base &= -align; + size_t toto = base + need; + if (toto >= allocated_ && toto <= params->wsize) { + res = (T *)(wdata_ + base); + allocated_ = toto; + } + return res; + } + + const ggml_compute_params *const params; + const ggml_tensor *const weights; + const ggml_tensor *const thought; + const ggml_tensor *const plan; + ggml_tensor *const result; + const long rows; + const long cols; + const int experts; + const int thinkers; + const int tasks; + const long tokens; + const long ldq; + + // variables + char *const wdata_; + size_t allocated_; + + // shared memory + long *rowptr_count_ /*[experts]*/; + char *quantized_thought_ /*[tokens][tasks][cols][2]*/; + uintptr_t *rowptr_result_ /*[experts][tokens*thinkers]*/; + uintptr_t *rowptr_thought_ /*[experts][tokens*thinkers]*/; +}; + +} // namespace + +/** + * Performs "mixture of experts" tensor multiplication on CPU. + */ +bool llamafile_mixmul(const ggml_compute_params *params, const ggml_tensor *weights, + const ggml_tensor *thought, const ggml_tensor *plan, ggml_tensor *result) { + MixMul mm{params, weights, thought, plan, result}; + return mm.allocate_shared_memory() && mm.mixmul(); +} + +/** + * Returns number of shared memory bytes llamafile_mixmul() needs. 
+ */
+size_t llamafile_mixmul_needs(const ggml_tensor *weights,
+                              const ggml_tensor *thought,
+                              const ggml_tensor *plan) {
+    ggml_compute_params params{};
+    params.wsize = 0x7ffff000;
+    params.wdata = (void *)0x1000;
+    MixMul mm{&params, weights, thought, plan, 0};
+    if (mm.allocate_shared_memory())
+        return mm.get_allocated_bytes();
+    else
+        return 0;
+}
diff --git a/ggml/src/sgemm.h b/ggml/src/sgemm.h
index caf6dd5567b3a..2a89e20970ec4 100644
--- a/ggml/src/sgemm.h
+++ b/ggml/src/sgemm.h
@@ -1,4 +1,5 @@
 #pragma once
+#include "ggml.h"
 #include <stdint.h>
 #include <stdbool.h>
 #ifdef __cplusplus
@@ -9,6 +10,13 @@ bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
                      const void *, int64_t, void *, int64_t, int, int,
                      int, int, int);
 
+bool llamafile_mixmul(const struct ggml_compute_params *, const struct ggml_tensor *,
+                      const struct ggml_tensor *, const struct ggml_tensor *, struct ggml_tensor *);
+
+size_t llamafile_mixmul_needs(const struct ggml_tensor *,
+                              const struct ggml_tensor *,
+                              const struct ggml_tensor *);
+
 #ifdef __cplusplus
 }
 #endif
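
The decomposition described in the commit message and in the MIXTURE OF EXPERTS
comment block above can be summarized with a scalar reference loop: rows of the
activation ("thought") and output ("result") tensors are grouped by the expert
that the router ("plan") assigned them to, and each group then reduces to an
ordinary 2-D matrix multiplication against that expert's weight slice. That is
the shape handed to tinyBLAS, with ldb = ldc = 0 and B/C addressed through row
pointer tables (the NCB/NCC bits consumed by the INDEX macro). The sketch below
is illustrative only; the names (mixmul_reference, b_rows, c_rows) are not part
of the patch, and the real code additionally quantizes the activations, aligns
its buffers, and splits the work across threads.

#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar sketch of the MUL_MAT_ID -> per-expert GEMM decomposition.
// weights: [experts][rows][cols], thought: [tokens][tasks][cols],
// plan:    [tokens][thinkers] (expert id per thinker),
// result:  [tokens][thinkers][rows]. All row-major f32, no padding.
void mixmul_reference(const float *weights, const float *thought,
                      const int32_t *plan, float *result,
                      int experts, int rows, int cols,
                      int tokens, int thinkers, int tasks) {
    for (int expert = 0; expert < experts; ++expert) {
        // Gather row pointers for every (token, thinker) routed to this expert.
        std::vector<const float *> b_rows;  // activation rows
        std::vector<float *> c_rows;        // matching output rows
        for (int token = 0; token < tokens; ++token)
            for (int thinker = 0; thinker < thinkers; ++thinker)
                if (plan[token * thinkers + thinker] == expert) {
                    b_rows.push_back(thought + ((int64_t)token * tasks + thinker % tasks) * cols);
                    c_rows.push_back(result + ((int64_t)token * thinkers + thinker) * rows);
                }
        // One dense 2-D matmul per expert: A is [rows x cols], while B and C
        // are reached through the row-pointer tables built above.
        const float *A = weights + (int64_t)expert * rows * cols;
        for (std::size_t j = 0; j < b_rows.size(); ++j)
            for (int i = 0; i < rows; ++i) {
                float acc = 0.0f;
                for (int l = 0; l < cols; ++l)
                    acc += A[(int64_t)i * cols + l] * b_rows[j][l];
                c_rows[j][i] = acc;
            }
    }
}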
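
The new BF16 paths in tinyBLAS lean on the fact that bfloat16 is the upper 16
bits of an IEEE-754 binary32 value, which is why the AVX2/AVX-512 loads in the
diff widen it with a zero-extending conversion followed by a 16-bit left shift
(_mm256_slli_epi32 / _mm512_slli_epi32). A scalar illustration of the same
trick follows; the helper names here are made up for this note, and ggml's own
GGML_BF16_TO_FP32 / GGML_FP32_TO_BF16 additionally handle rounding and NaN on
the narrowing side.

#include <cstdint>
#include <cstring>

// Widen one bf16 value (given as its raw 16-bit pattern) to f32 by placing it
// in the upper half of the 32-bit word -- the scalar form of the vector shift
// used by the patch's bf16 loads.
static inline float bf16_to_f32(uint16_t raw) {
    uint32_t bits = (uint32_t)raw << 16;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

// Truncating f32 -> bf16 conversion, shown for symmetry; ggml's helper also
// rounds to nearest-even and quiets NaNs before dropping the low 16 bits.
static inline uint16_t f32_to_bf16(float value) {
    uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits));
    return (uint16_t)(bits >> 16);
}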