From c34c47237148ba264a4d1ac75d1ecc70c7952589 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Mon, 22 Apr 2024 22:15:10 -0700 Subject: [PATCH] llamafile : improve moe prompt eval speed on cpu This change introduces a llamafile_mixmul() API that allows tinyBLAS to speed up "Mixture of Expert" models. On my Threadripper, Mixtral's 8x7b F16 weights now process prompts 2x faster. I'm also seeing a 60 percent improvement with Mixtral 8x22b Q4_0. The same applies to Q8_0, which is also supported by tinyBLAS. MoE models spend the majority of their time inside MUL_MAT_ID rather than MUL_MAT, which is why llamafile_sgemm was not able to help them before. llamafile_mixmul works by decomposing the mixmul operation into sgemm calls. --- common/common.cpp | 6 +- ggml.c | 8 +- sgemm.cpp | 785 +++++++++++++++++++++++++++++++++------------- sgemm.h | 16 +- 4 files changed, 599 insertions(+), 216 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 97f55b053eee11..7edf4eb15beb77 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -73,7 +73,7 @@ using json = nlohmann::ordered_json; int32_t get_num_physical_cores() { -#ifdef __linux__ +#if defined(__linux__) || defined(__COSMOPOLITAN__) // enumerate the set of thread siblings, num entries is num cores std::unordered_set siblings; for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) { @@ -108,7 +108,7 @@ int32_t get_num_physical_cores() { return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; } -#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) +#if defined(__x86_64__) && (defined(__linux__) || defined(__COSMOPOLITAN__)) && !defined(__ANDROID__) #include static void cpuid(unsigned leaf, unsigned subleaf, @@ -162,7 +162,7 @@ static int count_math_cpus(int cpu_count) { * Returns number of CPUs on system that are useful for math. 
*/ int get_math_cpu_count() { -#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) +#if defined(__x86_64__) && (defined(__linux__) || defined(__COSMOPOLITAN__)) && !defined(__ANDROID__) int cpu_count = sysconf(_SC_NPROCESSORS_ONLN); if (cpu_count < 1) { return get_num_physical_cores(); diff --git a/ggml.c b/ggml.c index 9f99a69d1158b9..c0701ef098436b 100644 --- a/ggml.c +++ b/ggml.c @@ -10991,11 +10991,14 @@ static void ggml_compute_forward_mul_mat_id( const struct ggml_tensor * src1 = dst->src[1]; const struct ggml_tensor * ids = dst->src[2]; - GGML_TENSOR_BINARY_OP_LOCALS + if (llamafile_mixmul(params, src0, src1, ids, dst)) + return; const int ith = params->ith; const int nth = params->nth; + GGML_TENSOR_BINARY_OP_LOCALS + const enum ggml_type type = src0->type; const bool src1_cont = ggml_is_contiguous(src1); @@ -18492,6 +18495,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur = 0; const struct ggml_tensor * src0 = node->src[0]; const struct ggml_tensor * src1 = node->src[1]; + const struct ggml_tensor * src2 = node->src[2]; const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; if (src1->type != vec_dot_type) { cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)); @@ -18500,6 +18504,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += GGML_PAD(cur, sizeof(int64_t)); // align cur += n_as * sizeof(int64_t); // matrix_row_counts cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows + size_t cur2 = llamafile_mixmul_needs(src0, src1, src2); + cur = cur > cur2 ? 
cur : cur2; } break; case GGML_OP_OUT_PROD: { diff --git a/sgemm.cpp b/sgemm.cpp index 4e0159804e8166..4b0e3068e3a506 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -53,6 +53,10 @@ #include "ggml-impl.h" #include "ggml-quants.h" +#define ROW_ALIGN 64 +#define MATRIX_ALIGN 4096 +#define MAX_ALIGN 4096 + #ifdef _MSC_VER #define NOINLINE __declspec(noinline) #else @@ -65,7 +69,7 @@ #define VECTOR_REGISTERS 16 #endif -#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) +typedef intptr_t dim; namespace { @@ -73,6 +77,23 @@ inline float unhalf(ggml_fp16_t d) { return GGML_FP16_TO_FP32(d); } +//////////////////////////////////////////////////////////////////////////////////////////////////// +// MATRIX MEMORY INDEXING + +#define NCA 1 +#define NCB 2 +#define NCC 4 + +#define INDEX(A, lda, j, i) (CONFIG & NC##A ? ((T##A *const *)A)[j] + i : A + lda * (j) + i) + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GGML TYPE TRAITS + +template struct ggml_type_trait; +template<> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_F32; }; +template<> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_F16; }; +template<> struct ggml_type_trait { static constexpr ggml_type id = GGML_TYPE_Q8_0; }; + //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED ARITHMETIC OPERATIONS @@ -117,6 +138,21 @@ inline U madd(T a, T b, U c) { return add(mul(a, b), c); } +/** + * Computes a * b + c with error correction. + * + * @see W. Kahan, "Further remarks on reducing truncation errors," + * Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965, + * doi: 10.1145/363707.363723. 
+ */ +template +inline U madder(T a, T b, U c, U *e) { + U y = sub(mul(a, b), *e); + U t = add(c, y); + *e = sub(sub(t, c), y); + return t; +} + #if defined(__FMA__) #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) template <> @@ -135,10 +171,10 @@ inline __m512 madd(__m512 a, __m512 b, __m512 c) { #if defined(__ARM_FEATURE_FMA) template <> inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { - return vfmaq_f32(c, b, a); + return vfmaq_f32(c, a, b); } -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) -template <> +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(__clang__) +template <> // this specialization chops gcc 12.3 performance in half inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { return vfmaq_f16(c, b, a); } @@ -156,6 +192,7 @@ inline float hsum(float32x4_t x) { #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) inline float hsum(float16x8_t x) { + // todo: works great for clang but produces sketchy code with gcc 12.3 return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_high_f16(x)))); } @@ -239,61 +276,57 @@ template <> inline __m512 load(const ggml_fp16_t *p) { //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION -template +template class tinyBLAS { public: - tinyBLAS(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, + tinyBLAS(dim k, + const TA *A, dim lda, + const TB *B, dim ldb, + TC *C, dim ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int64_t m, int64_t n, int task) { + void matmul(dim m, dim n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; + NOINLINE void mnpack(dim m0, dim m, dim n0, dim n) 
{ + dim mc, nc, mp, np; switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) { #if VECTOR_REGISTERS == 32 case 0x55: + case 0x54: mc = 5; - nc = 5; - gemm<5, 5>(m0, m, n0, n); + nc = 4; + gemm<5, 4, false>(m0, m, n0, n); break; case 0x45: mc = 4; nc = 5; - gemm<4, 5>(m0, m, n0, n); - break; - case 0x54: - mc = 5; - nc = 4; - gemm<5, 4>(m0, m, n0, n); + gemm<4, 5, false>(m0, m, n0, n); break; case 0x44: mc = 4; nc = 4; - gemm<4, 4>(m0, m, n0, n); + gemm<4, 4, false>(m0, m, n0, n); break; case 0x53: mc = 5; nc = 3; - gemm<5, 3>(m0, m, n0, n); + gemm<5, 3, false>(m0, m, n0, n); break; case 0x35: mc = 3; nc = 5; - gemm<3, 5>(m0, m, n0, n); + gemm<3, 5, false>(m0, m, n0, n); break; case 0x43: mc = 4; nc = 3; - gemm<4, 3>(m0, m, n0, n); + gemm<4, 3, false>(m0, m, n0, n); break; #else case 0x55: @@ -304,99 +337,99 @@ class tinyBLAS { case 0x43: mc = 4; nc = 3; - gemm<4, 3>(m0, m, n0, n); + gemm<4, 3, false>(m0, m, n0, n); break; case 0x35: #endif case 0x34: mc = 3; nc = 4; - gemm<3, 4>(m0, m, n0, n); + gemm<3, 4, false>(m0, m, n0, n); break; case 0x52: mc = 5; nc = 2; - gemm<5, 2>(m0, m, n0, n); + gemm<5, 2, false>(m0, m, n0, n); break; case 0x33: mc = 3; nc = 3; - gemm<3, 3>(m0, m, n0, n); + gemm<3, 3, false>(m0, m, n0, n); break; case 0x25: mc = 2; nc = 5; - gemm<2, 5>(m0, m, n0, n); + gemm<2, 5, false>(m0, m, n0, n); break; case 0x42: mc = 4; nc = 2; - gemm<4, 2>(m0, m, n0, n); + gemm<4, 2, false>(m0, m, n0, n); break; case 0x24: mc = 2; nc = 4; - gemm<2, 4>(m0, m, n0, n); + gemm<2, 4, false>(m0, m, n0, n); break; case 0x32: mc = 3; nc = 2; - gemm<3, 2>(m0, m, n0, n); + gemm<3, 2, true>(m0, m, n0, n); break; case 0x23: mc = 2; nc = 3; - gemm<2, 3>(m0, m, n0, n); + gemm<2, 3, true>(m0, m, n0, n); break; case 0x51: mc = 5; nc = 1; - gemm<5, 1>(m0, m, n0, n); + gemm<5, 1, true>(m0, m, n0, n); break; case 0x41: mc = 4; nc = 1; - gemm<4, 1>(m0, m, n0, n); + gemm<4, 1, true>(m0, m, n0, n); break; case 0x22: mc = 2; nc = 2; - gemm<2, 2>(m0, m, n0, n); + gemm<2, 2, true>(m0, m, n0, 
n); break; case 0x15: mc = 1; nc = 5; - gemm<1, 5>(m0, m, n0, n); + gemm<1, 5, true>(m0, m, n0, n); break; case 0x14: mc = 1; nc = 4; - gemm<1, 4>(m0, m, n0, n); + gemm<1, 4, true>(m0, m, n0, n); break; case 0x31: mc = 3; nc = 1; - gemm<3, 1>(m0, m, n0, n); + gemm<3, 1, true>(m0, m, n0, n); break; case 0x13: mc = 1; nc = 3; - gemm<1, 3>(m0, m, n0, n); + gemm<1, 3, true>(m0, m, n0, n); break; case 0x21: mc = 2; nc = 1; - gemm<2, 1>(m0, m, n0, n); + gemm<2, 1, true>(m0, m, n0, n); break; case 0x12: mc = 1; nc = 2; - gemm<1, 2>(m0, m, n0, n); + gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; - gemm<1, 1>(m0, m, n0, n); + gemm<1, 1, true>(m0, m, n0, n); break; default: return; @@ -407,39 +440,45 @@ class tinyBLAS { mnpack(m0, m, np, n); } - template - NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t ytiles = (m - m0) / RM; - int64_t xtiles = (n - n0) / RN; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; + template + NOINLINE void gemm(dim m0, dim m, dim n0, dim n) { + dim ytiles = RM > 1 ? (m - m0) / RM : 1; + dim xtiles = RN > 1 ? 
(n - n0) / RN : 1; + dim tiles = xtiles * ytiles; + dim duty = (tiles + nth - 1) / nth; + dim start = duty * ith; + dim end = start + duty; if (end > tiles) end = tiles; - for (int64_t job = start; job < end; ++job) { - int64_t ii = m0 + job / xtiles * RM; - int64_t jj = n0 + job % xtiles * RN; + for (dim job = start; job < end; ++job) { + dim ii = m0 + job / xtiles * RM; + dim jj = n0 + job % xtiles * RN; D Cv[RN][RM] = {}; - for (int64_t l = 0; l < k; l += KN) - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - Cv[j][i] = madd(load(A + lda * (ii + i) + l), - load(B + ldb * (jj + j) + l), - Cv[j][i]); - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + D Ce[RN][RM] = {}; + for (dim l = 0; l < k; l += KN) + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + if (KAHAN) + Cv[j][i] = madder(load(INDEX(A, lda, ii + i, l)), + load(INDEX(B, ldb, jj + j, l)), + Cv[j][i], &Ce[j][i]); + else + Cv[j][i] = madd(load(INDEX(A, lda, ii + i, l)), + load(INDEX(B, ldb, jj + j, l)), + Cv[j][i]); + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + *INDEX(C, ldc, jj + j, ii + i) = hsum(Cv[j][i]); } } const TA *const A; const TB *const B; TC *const C; - const int64_t k; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; + const dim k; + const dim lda; + const dim ldb; + const dim ldc; const int ith; const int nth; }; @@ -448,25 +487,25 @@ class tinyBLAS { // QUANT ZERO MATRIX MULTIPLICATION #if defined(__ARM_FEATURE_DOTPROD) -template +template class tinyBLAS_Q0_ARM { public: - tinyBLAS_Q0_ARM(int64_t k, - const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, - float *C, int64_t ldc, + tinyBLAS_Q0_ARM(dim k, + const TA *A, dim lda, + const TB *B, dim ldb, + TC *C, dim ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int64_t m, int64_t n, int task) { + void matmul(dim m, dim n, int task) { 
if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; + NOINLINE void mnpack(dim m0, dim m, dim n0, dim n) { + dim mc, nc, mp, np; switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) { case 0x33: mc = 3; @@ -523,34 +562,34 @@ class tinyBLAS_Q0_ARM { } template - NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t ytiles = (m - m0) / RM; - int64_t xtiles = (n - n0) / RN; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; + NOINLINE void gemm(dim m0, dim m, dim n0, dim n) { + dim ytiles = RM > 1 ? (m - m0) / RM : 1; + dim xtiles = RN > 1 ? (n - n0) / RN : 1; + dim tiles = xtiles * ytiles; + dim duty = (tiles + nth - 1) / nth; + dim start = duty * ith; + dim end = start + duty; if (end > tiles) end = tiles; - for (int64_t job = start; job < end; ++job) { - int64_t ii = m0 + job / xtiles * RM; - int64_t jj = n0 + job % xtiles * RN; + for (dim job = start; job < end; ++job) { + dim ii = m0 + job / xtiles * RM; + dim jj = n0 + job % xtiles * RN; float32x4_t Cv[RN][RM] = {}; - for (int64_t l = 0; l < k; ++l) - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - Cv[j][i] = vmlaq_n_f32(Cv[j][i], - vcvtq_f32_s32(vdotq_s32( - vdotq_s32(vdupq_n_s32(0), - load_lo(A + lda * (ii + i) + l), - load_lo(B + ldb * (jj + j) + l)), - load_hi(A + lda * (ii + i) + l), - load_hi(B + ldb * (jj + j) + l))), - unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)); - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + for (dim l = 0; l < k; ++l) + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + Cv[j][i] = vmlaq_n_f32( + Cv[j][i], + vcvtq_f32_s32(vdotq_s32(vdotq_s32(vdupq_n_s32(0), + load_lo(INDEX(A, lda, ii + i, l)), + load_lo(INDEX(B, ldb, jj + j, l))), + 
load_hi(INDEX(A, lda, ii + i, l)), + load_hi(INDEX(B, ldb, jj + j, l)))), + (unhalf(INDEX(A, lda, ii + i, l)->d) * + unhalf(INDEX(B, ldb, jj + j, l)->d))); + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + *INDEX(C, ldc, jj + j, ii + i) = hsum(Cv[j][i]); } } @@ -563,48 +602,46 @@ class tinyBLAS_Q0_ARM { } inline int8x16_t load_lo(const block_q4_0 *b) { - return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), - vdupq_n_u8(0x0f))), + return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), vdupq_n_u8(0x0f))), vdupq_n_s8(0x8)); } inline int8x16_t load_hi(const block_q4_0 *b) { - return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), - vdupq_n_s8(0x8)); + return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), vdupq_n_s8(0x8)); } const TA *const A; - const block_q8_0 *const B; - float *const C; - const int64_t k; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; + const TB *const B; + TC *const C; + const dim k; + const dim lda; + const dim ldb; + const dim ldc; const int ith; const int nth; }; #endif // __ARM_FEATURE_DOTPROD #if defined(__AVX2__) || defined(__AVX512F__) -template +template class tinyBLAS_Q0_AVX2 { public: - tinyBLAS_Q0_AVX2(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, + tinyBLAS_Q0_AVX2(dim k, + const TA *A, dim lda, + const TB *B, dim ldb, + TC *C, dim ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int64_t m, int64_t n, int task) { + void matmul(dim m, dim n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; + void mnpack(dim m0, dim m, dim n0, dim n) { + dim mc, nc, mp, np; switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) { #if VECTOR_REGISTERS == 32 case 0x44: @@ -713,32 +750,32 @@ class tinyBLAS_Q0_AVX2 { } template - NOINLINE void gemm(int64_t m0, int64_t 
m, int64_t n0, int64_t n) { - int64_t ytiles = (m - m0) / RM; - int64_t xtiles = (n - n0) / RN; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; + NOINLINE void gemm(dim m0, dim m, dim n0, dim n) { + dim ytiles = RM > 1 ? (m - m0) / RM : 1; + dim xtiles = RN > 1 ? (n - n0) / RN : 1; + dim tiles = xtiles * ytiles; + dim duty = (tiles + nth - 1) / nth; + dim start = duty * ith; + dim end = start + duty; if (end > tiles) end = tiles; - for (int64_t job = start; job < end; ++job) { - int64_t ii = m0 + job / xtiles * RM; - int64_t jj = n0 + job % xtiles * RN; + for (dim job = start; job < end; ++job) { + dim ii = m0 + job / xtiles * RM; + dim jj = n0 + job % xtiles * RN; __m256 Cv[RN][RM] = {}; - for (int64_t l = 0; l < k; ++l) - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)), - updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - _mm256_sign_epi8(load(B + ldb * (jj + j) + l), - load(A + lda * (ii + i) + l))), + for (dim l = 0; l < k; ++l) + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + Cv[j][i] = madd(_mm256_set1_ps(unhalf(INDEX(A, lda, ii + i, l)->d) * + unhalf(INDEX(B, ldb, jj + j, l)->d)), + updot(_mm256_sign_epi8(load(INDEX(A, lda, ii + i, l)), + load(INDEX(A, lda, ii + i, l))), + _mm256_sign_epi8(load(INDEX(B, ldb, jj + j, l)), + load(INDEX(A, lda, ii + i, l)))), Cv[j][i]); - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + *INDEX(C, ldc, jj + j, ii + i) = hsum(Cv[j][i]); } } @@ -747,7 +784,15 @@ class tinyBLAS_Q0_AVX2 { } inline __m256i load(const block_q4_0 *b) { - return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); + __m128i x = _mm_loadu_si128((const __m128i 
*)b->qs); + return _mm256_sub_epi8( + _mm256_and_si256( + _mm256_set1_epi8(15), + _mm256_insertf128_si256( + _mm256_castsi128_si256(x), + _mm_srli_epi16(x, 4), + 1)), + _mm256_set1_epi8(8)); } inline __m256 updot(__m256i u, __m256i s) { @@ -760,20 +805,13 @@ class tinyBLAS_Q0_AVX2 { return _mm256_cvtepi32_ps(res); } - static inline __m256i denibble(const uint8_t *p) { - __m128i x = _mm_loadu_si128((const __m128i *)p); - return _mm256_and_si256(_mm256_set1_epi8(15), - _mm256_insertf128_si256(_mm256_castsi128_si256(x), - _mm_srli_epi16(x, 4), 1)); - } - const TA *const A; const TB *const B; TC *const C; - const int64_t k; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; + const dim k; + const dim lda; + const dim ldb; + const dim ldc; const int ith; const int nth; }; @@ -812,8 +850,12 @@ class tinyBLAS_Q0_AVX2 { * @param Ctype is GGML data type of `C` * @return true if this function was able to service the matmul request */ -bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C, - int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) { +bool llamafile_sgemm(dim m, dim n, dim k, + const void *A, dim lda, + const void *B, dim ldb, + void *C, dim ldc, + int ith, int nth, int task, + int Atype, int Btype, int Ctype) { assert(m >= 0); assert(n >= 0); @@ -835,21 +877,15 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda #if defined(__AVX512F__) if (k % 16) return false; - tinyBLAS<16, __m512, __m512, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 16, __m512, __m512, float, float, float> tb{ + k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__AVX__) || defined(__AVX2__) if (k % 8) return false; - tinyBLAS<8, __m256, __m256, float, float, float> tb{ - k, (const float *)A, lda, - (const float 
*)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 8, __m256, __m256, float, float, float> tb{ + k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_NEON) @@ -857,11 +893,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda return false; if (k % 4) return false; - tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, float> tb{ + k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else @@ -875,11 +908,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda return false; if (Btype != GGML_TYPE_F32) return false; - tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) @@ -887,11 +917,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda return false; if (Btype != GGML_TYPE_F32) return false; - tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) @@ -901,11 +928,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda return false; if (Btype != GGML_TYPE_F16) 
return false; - tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const ggml_fp16_t *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_NEON) && !defined(_MSC_VER) @@ -913,11 +937,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda return false; if (Btype != GGML_TYPE_F32) return false; - tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else @@ -927,21 +948,15 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda case GGML_TYPE_Q8_0: { if (Btype != GGML_TYPE_Q8_0) - return false; + return false; #if defined(__AVX2__) || defined(__AVX512F__) - tinyBLAS_Q0_AVX2 tb{ - k, (const block_q8_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, float> tb{ + k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_FEATURE_DOTPROD) - tinyBLAS_Q0_ARM tb{ - k, (const block_q8_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, float> tb{ + k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else @@ -953,19 +968,13 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda if (Btype != GGML_TYPE_Q8_0) return false; #if 
defined(__AVX2__) || defined(__AVX512F__) - tinyBLAS_Q0_AVX2 tb{ - k, (const block_q4_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, float> tb{ + k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_FEATURE_DOTPROD) - tinyBLAS_Q0_ARM tb{ - k, (const block_q4_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, float> tb{ + k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else @@ -993,3 +1002,359 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda (void)Btype; (void)Ctype; } + +// +// _ _ ___ _ _ ___ +// | |_(_)_ _ _ _| _ ) | /_\ / __| +// | _| | ' \ || | _ \ |__ / _ \\__ \. +// \__|_|_||_\_, |___/____/_/ \_\___/ +// |__/ +// +// MIXTURE OF EXPERTS TENSOR MULTIPLICATION +// +// +// SHAPES +// +// - weights [cols, rows, experts] +// - thought [cols, tasks, tokens] w/ tasks ≤ thinkers +// - result [rows, thinkers, tokens] w/ thinkers ≤ experts +// - plan [thinkers, tokens] w/ i32 < experts +// +// DEFINITION +// +// for thinker in range(thinkers): +// for token in range(tokens): +// for row in range(rows): +// c = 0 +// for col in range(cols): +// expert = plan[token][thinker] +// a = weights[expert][row][col] +// b = thought[token][thinker % tasks][col] +// c += a * b +// result[token][thinker][row] = c +// +// REGULARITIES +// +// - tokens can be odd +// - thinkers is usually 2 +// - tasks is usually 1 or 2 +// - cols should be a multiple of 64 +// - rows should be a multiple of 64 +// - experts is usually 8 but could be 60 +// - tokens is always 1 for token generation +// - tokens can be huge for prompt processing +// +// EXAMPLE +// +// mixtral 8x7b w/ 217 token prompt +// +// | ne*0 ne*1 ne*2 ne*3 | nb*0 nb*1 nb*2 nb*3 | type +// 
========================================================================= +// weights | 16384 6144 8 1 | 18 0x2400 0x3600000 0x1b000000 | q4_0 +// thought | 16384 2 217 1 | 4 0x10000 0x20000 0x1b20000 | f32 +// result | 6144 2 217 1 | 4 0x6000 0xc000 0xa2c000 | f32 +// plan | 2 217 1 1 | 4 0x20 0x1b20 0x1b20 | i32 +// + +namespace { +class MixMul { + public: + MixMul(const ggml_compute_params *params, const ggml_tensor *weights, + const ggml_tensor *thought, const ggml_tensor *plan, ggml_tensor *result) + : params(params), + weights(weights), + thought(thought), + plan(plan), + result(result), + rows(weights->ne[1]), + cols(weights->ne[0]), + experts(weights->ne[2]), + thinkers(plan->ne[0]), + tasks(thought->ne[1]), + tokens(thought->ne[2]), + ldq((cols * 2 + ROW_ALIGN - 1) & -ROW_ALIGN), + wdata_((char *)(((uintptr_t)params->wdata + MAX_ALIGN - 1) & -MAX_ALIGN)), + allocated_(0) { + } + + bool allocate_shared_memory() { + if (!(quantized_thought_ = allocate(MATRIX_ALIGN, tokens * tasks * ldq))) + return false; + if (!(rowptr_result_ = allocate(ROW_ALIGN, experts * tokens * thinkers))) + return false; + if (!(rowptr_thought_ = allocate(ROW_ALIGN, experts * tokens * thinkers))) + return false; + if (!(rowptr_count_ = allocate(sizeof(dim), experts))) + return false; + return true; + } + + size_t get_allocated_bytes() { + return (wdata_ - (char *)params->wdata) + allocated_; + } + + bool mixmul() { + + // invariants + assert(tasks <= thinkers); + assert(thinkers <= experts); + assert(tokens == plan->ne[1]); + assert(rows == result->ne[0]); + assert(cols == thought->ne[0]); + assert(tokens == result->ne[2]); + assert(thinkers == result->ne[1]); + + // dimensionality + assert(plan->ne[2] == 1); + assert(plan->ne[3] == 1); + assert(result->ne[3] == 1); + assert(weights->ne[3] == 1); + assert(thought->ne[3] == 1); + + // miscellaneous + assert(params->nth > 0); + assert(params->ith < params->nth); + assert(plan->type == GGML_TYPE_I32); + + // supported types + if 
(result->type != GGML_TYPE_F32) + return false; + + // check nb01 is convertible to lda + if (weights->nb[1] % ggml_type_size(weights->type)) + return false; + + // no support for column strides + if (result->nb[0] != ggml_type_size(result->type)) + return false; + if (thought->nb[0] != ggml_type_size(thought->type)) + return false; + if (weights->nb[0] != ggml_type_size(weights->type)) + return false; + + switch (weights->type) { + + case GGML_TYPE_F32: + if (thought->type != GGML_TYPE_F32) + return false; +#if defined(__AVX512F__) + return mixmat<16, 1, tinyBLAS, + float, float, float>(); +#elif defined(__AVX__) || defined(__AVX2__) + return mixmat<8, 1, tinyBLAS, float, + float, float>(); +#elif defined(__SSE__) + return mixmat<4, 1, tinyBLAS, float, + float, float>(); +#elif defined(__ARM_NEON) + return mixmat<4, 1, + tinyBLAS, + float, float, float>(); +#else + return false; +#endif + + case GGML_TYPE_F16: + if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_F16) + return false; +#if defined(__AVX512F__) + return mixmat<16, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, float>(); +#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) + return mixmat< + 8, 1, tinyBLAS, + ggml_fp16_t, ggml_fp16_t, float>(); +#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) + return mixmat< + 8, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, float>(); +#elif defined(__ARM_NEON) && !defined(_MSC_VER) + return mixmat< + 4, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, float>(); +#else + return false; +#endif + + case GGML_TYPE_Q4_0: + if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) + return mixmat<32, 32, tinyBLAS_Q0_AVX2, + block_q4_0, block_q8_0, float>(); +#elif defined(__ARM_FEATURE_DOTPROD) + return mixmat<32, 32, tinyBLAS_Q0_ARM, + block_q4_0, block_q8_0, float>(); +#else + return false; +#endif + + case GGML_TYPE_Q8_0: + if (thought->type != 
GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) + return mixmat<32, 32, tinyBLAS_Q0_AVX2, + block_q8_0, block_q8_0, float>(); +#elif defined(__ARM_FEATURE_DOTPROD) + return mixmat<32, 32, tinyBLAS_Q0_ARM, + block_q8_0, block_q8_0, float>(); +#else + return false; +#endif + + default: + return false; + } + } + + private: + template + bool mixmat() { + if (cols % KN) + return false; + switch (params->type) { + case GGML_TASK_TYPE_INIT: + if (thought->type != ggml_type_trait::id) + quantize_thought(ggml_type_trait::id); + build_row_pointers(ggml_type_trait::id); + return true; + case GGML_TASK_TYPE_COMPUTE: + assert(!(cols % BS)); + assert(!(weights->nb[1] % sizeof(TA))); + for (int expert = 0; expert < experts; ++expert) { + BLAS tb{cols / BS, + (const TA *)((const char *)weights->data + expert * weights->nb[2]), + (dim)(weights->nb[1] / sizeof(TA)), + (const TB *)(rowptr_thought_ + expert * tokens * thinkers), 0, + (TC *)(rowptr_result_ + expert * tokens * thinkers), 0, + params->ith, + params->nth}; + tb.matmul(rows, rowptr_count_[expert], GGML_TASK_TYPE_COMPUTE); + } + return true; + default: + return true; + } + } + + void build_row_pointers(ggml_type vec_dot_type) { + for (int expert = params->ith; expert < experts; expert += params->nth) { + dim count = 0; + for (dim token = 0; token < tokens; ++token) + for (int thinker = 0; thinker < thinkers; ++thinker) + if (expert == *(const int32_t *)((const char *)plan->data + + token * plan->nb[1] + + thinker * plan->nb[0])) { + dim row = count++; + dim idx = expert * thinkers * tokens + row; + rowptr_result_[idx] = + (uintptr_t)((char *)result->data + token * result->nb[2] + + thinker * result->nb[1]); + if (thought->type == vec_dot_type) + rowptr_thought_[idx] = + (uintptr_t)((char *)thought->data + token * thought->nb[2] + + thinker % tasks * thought->nb[1]); + else + rowptr_thought_[idx] = + (uintptr_t)((char *)quantized_thought_ + token * tasks * ldq 
+ + thinker % tasks * ldq); + } + rowptr_count_[expert] = count; + } + } + + void quantize_thought(ggml_type vec_dot_type) { + dim chore = 0; + for (dim token = 0; token < tokens; ++token) + for (int task = 0; task < tasks; ++task) + if (chore++ % params->nth == params->ith) + quantize_row(quantized_thought_ + token * tasks * ldq + task * ldq, + (const float *)((const char *)thought->data + + token * thought->nb[2] + task * thought->nb[1]), + vec_dot_type); + } + + void quantize_row(void *dst, const float *src, ggml_type type) { + assert((dim)ggml_row_size(type, cols) <= ldq); + switch (type) { + case GGML_TYPE_F16: + ggml_fp32_to_fp16_row(src, (ggml_fp16_t *)dst, cols); + break; + case GGML_TYPE_Q8_0: + quantize_row_q8_0(src, (block_q8_0 *)dst, cols); + break; + default: + GGML_UNREACHABLE(); + } + } + + template + T *allocate(size_t align, size_t elems) { + T *res = nullptr; + size_t need = sizeof(T) * elems; + size_t base = allocated_; + base += align - 1; + base &= -align; + size_t toto = base + need; + if (toto >= allocated_ && toto <= params->wsize) { + res = (T *)(wdata_ + base); + allocated_ = toto; + } + return res; + } + + const ggml_compute_params *const params; + const ggml_tensor *const weights; + const ggml_tensor *const thought; + const ggml_tensor *const plan; + ggml_tensor *const result; + const dim rows; + const dim cols; + const int experts; + const int thinkers; + const int tasks; + const dim tokens; + const dim ldq; + + // variables + char *const wdata_; + size_t allocated_; + + // shared memory + dim *rowptr_count_ /*[experts]*/; + char *quantized_thought_ /*[tokens][tasks][cols][2]*/; + uintptr_t *rowptr_result_ /*[experts][tokens*thinkers]*/; + uintptr_t *rowptr_thought_ /*[experts][tokens*thinkers]*/; +}; +} // namespace + +/** + * Performs "mixture of experts" tensor multiplication on CPU. 
+ */ +bool llamafile_mixmul(const ggml_compute_params *params, + const ggml_tensor *weights, + const ggml_tensor *thought, + const ggml_tensor *plan, + ggml_tensor *result) { + MixMul mm{params, weights, thought, plan, result}; + return mm.allocate_shared_memory() && mm.mixmul(); +} + +/** + * Returns number of shared memory bytes llamafile_mixmul() needs. + */ +size_t llamafile_mixmul_needs(const ggml_tensor *weights, + const ggml_tensor *thought, + const ggml_tensor *plan) { + ggml_compute_params params{}; + params.wsize = 0x7ffff000; + params.wdata = (void *)0x1000; + MixMul mm{&params, weights, thought, plan, 0}; + if (mm.allocate_shared_memory()) + return mm.get_allocated_bytes(); + else + return 0; +} diff --git a/sgemm.h b/sgemm.h index f29747d0a477af..06b8d7f02ddc94 100644 --- a/sgemm.h +++ b/sgemm.h @@ -1,14 +1,26 @@ #pragma once #include <stdint.h> +#include <stddef.h> #include <stdbool.h> #ifdef __cplusplus extern "C" { #endif -bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t, - const void *, int64_t, void *, int64_t, int, int, +struct ggml_tensor; +struct ggml_compute_params; + +bool llamafile_sgemm(intptr_t, intptr_t, intptr_t, const void *, intptr_t, + const void *, intptr_t, void *, intptr_t, int, int, int, int, int, int); +bool llamafile_mixmul(const struct ggml_compute_params *, const struct ggml_tensor *, + const struct ggml_tensor *, const struct ggml_tensor *, + struct ggml_tensor *); + +size_t llamafile_mixmul_needs(const struct ggml_tensor *, + const struct ggml_tensor *, + const struct ggml_tensor *); + #ifdef __cplusplus } #endif