From 957670a18ede62a481fcb235c3714de39bc784ea Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Mon, 22 Apr 2024 22:15:10 -0700 Subject: [PATCH] llamafile : improve moe prompt eval speed on cpu This change introduces a llamafile_mixmul() API that allows tinyBLAS to speed up "Mixture of Expert" models. On my Threadripper, Mixtral's 8x7b F16 weights now process prompts 2x faster. I'm also seeing a 60 percent improvement with Mixtral 8x22b Q4_0. The same applies to Q8_0, which is also supported by tinyBLAS. MoE models spend the majority of their time inside MUL_MAT_ID rather than MUL_MAT, which is why llamafile_sgemm was not able to help them before. llamafile_mixmul works by decomposing the mixmul operation into sgemm calls. --- common/common.cpp | 6 +- ggml.c | 8 +- sgemm.cpp | 895 +++++++++++++++++++++++++++++++++++----------- sgemm.h | 16 +- 4 files changed, 715 insertions(+), 210 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index ba1ecf0e59c8be..10401820140888 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -74,7 +74,7 @@ using json = nlohmann::ordered_json; int32_t get_num_physical_cores() { -#ifdef __linux__ +#if defined(__linux__) || defined(__COSMOPOLITAN__) // enumerate the set of thread siblings, num entries is num cores std::unordered_set siblings; for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) { @@ -109,7 +109,7 @@ int32_t get_num_physical_cores() { return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; } -#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) +#if defined(__x86_64__) && (defined(__linux__) || defined(__COSMOPOLITAN__)) && !defined(__ANDROID__) #include static void cpuid(unsigned leaf, unsigned subleaf, @@ -163,7 +163,7 @@ static int count_math_cpus(int cpu_count) { * Returns number of CPUs on system that are useful for math. */ int get_math_cpu_count() { -#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) +#if defined(__x86_64__) && (defined(__linux__) || defined(__COSMOPOLITAN__)) && !defined(__ANDROID__) int cpu_count = sysconf(_SC_NPROCESSORS_ONLN); if (cpu_count < 1) { return get_num_physical_cores(); diff --git a/ggml.c b/ggml.c index b96a82a41517de..fa01f6eecfa2e2 100644 --- a/ggml.c +++ b/ggml.c @@ -12068,11 +12068,14 @@ static void ggml_compute_forward_mul_mat_id( const struct ggml_tensor * src1 = dst->src[1]; const struct ggml_tensor * ids = dst->src[2]; - GGML_TENSOR_BINARY_OP_LOCALS + if (llamafile_mixmul(params, src0, src1, ids, dst)) + return; const int ith = params->ith; const int nth = params->nth; + GGML_TENSOR_BINARY_OP_LOCALS + const enum ggml_type type = src0->type; const bool src1_cont = ggml_is_contiguous(src1); @@ -19659,6 +19662,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur = 0; const struct ggml_tensor * src0 = node->src[0]; const struct ggml_tensor * src1 = node->src[1]; + const struct ggml_tensor * src2 = node->src[2]; const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; if (src1->type != vec_dot_type) { cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)); @@ -19667,6 +19671,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += GGML_PAD(cur, sizeof(int64_t)); // align cur += n_as * sizeof(int64_t); // matrix_row_counts cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows + size_t cur2 = llamafile_mixmul_needs(src0, src1, src2); + cur = cur > cur2 ? 
cur : cur2; } break; case GGML_OP_OUT_PROD: { diff --git a/sgemm.cpp b/sgemm.cpp index 40ba9d7e9a7b72..05e481536ef1ba 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -50,6 +50,10 @@ #include "ggml-impl.h" #include "ggml-quants.h" +#define ROW_ALIGN 64 +#define MATRIX_ALIGN 4096 +#define MAX_ALIGN 4096 + #ifdef _MSC_VER #define NOINLINE __declspec(noinline) #else @@ -62,13 +66,34 @@ #define VECTOR_REGISTERS 16 #endif -#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) +typedef intptr_t dim; namespace { inline float unhalf(ggml_fp16_t d) { return GGML_FP16_TO_FP32(d); } +inline float unhalf(ggml_bf16_t d) { + return GGML_BF16_TO_FP32(d); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// MATRIX MEMORY INDEXING + +#define NCA 1 +#define NCB 2 +#define NCC 4 + +#define INDEX(A, lda, j, i) (CONFIG & NC##A ? ((T##A *const *)A)[j] + i : A + lda * (j) + i) + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GGML TYPE TRAITS + +template <typename T> struct ggml_type_trait; +template<> struct ggml_type_trait<float> { static constexpr ggml_type id = GGML_TYPE_F32; }; +template<> struct ggml_type_trait<ggml_fp16_t> { static constexpr ggml_type id = GGML_TYPE_F16; }; +template<> struct ggml_type_trait<ggml_bf16_t> { static constexpr ggml_type id = GGML_TYPE_BF16; }; +template<> struct ggml_type_trait<block_q8_0> { static constexpr ggml_type id = GGML_TYPE_Q8_0; }; //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED ARITHMETIC OPERATIONS @@ -114,6 +139,21 @@ inline U madd(T a, T b, U c) { return add(mul(a, b), c); } +/** + * Computes a * b + c with error correction. + * + * @see W. Kahan, "Further remarks on reducing truncation errors," + * Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965, + * doi: 10.1145/363707.363723.
+ */ +template +inline U madder(T a, T b, U c, U *e) { + U y = sub(mul(a, b), *e); + U t = add(c, y); + *e = sub(sub(t, c), y); + return t; +} + #if defined(__FMA__) #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) template <> @@ -132,15 +172,27 @@ inline __m512 madd(__m512 a, __m512 b, __m512 c) { #if defined(__ARM_FEATURE_FMA) template <> inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { - return vfmaq_f32(c, b, a); + return vfmaq_f32(c, a, b); } -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) -template <> +#if 0 // todo: this specialization chops gcc 12.3 performance in half +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(__clang__) +template <> // this specialization chops gcc 12.3 performance in half inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { return vfmaq_f16(c, b, a); } #endif #endif +#endif + +#if defined(__AVX512BF16__) +template <> inline __m512 madd(__m512bh x, __m512bh y, __m512 z) { + return _mm512_dpbf16_ps(z, x, y); +} +template <> inline __m512 madder(__m512bh x, __m512bh y, __m512 z, __m512 *_) { + return _mm512_dpbf16_ps(z, x, y); + (void)_; +} +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED HORIZONTAL SUM @@ -153,6 +205,7 @@ inline float hsum(float32x4_t x) { #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) inline float hsum(float16x8_t x) { + // todo: works great for clang but produces sketchy code with gcc 12.3 return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_high_f16(x)))); } @@ -196,6 +249,9 @@ template T load(const U *); template <> inline float32x4_t load(const float *p) { return vld1q_f32(p); } +template <> inline float32x4_t load(const ggml_bf16_t *p) { + return vreinterpretq_f32_u32(vshll_n_u16(vld1_u16((const unsigned short *)p), 16)); +} #if !defined(_MSC_VER) template <> inline float16x8_t load(const ggml_fp16_t *p) { return vld1q_f16((const float16_t *)p); @@ -218,6 +274,13 @@ template <> inline __m256 load(const float *p) { } #endif // __AVX__ +#if defined(__AVX2__) || defined(__AVX512F__) +template <> inline __m256 load(const ggml_bf16_t *p) { + return _mm256_castsi256_ps( + _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16)); +} +#endif // __AVX2__ + #if defined(__F16C__) template <> inline __m256 load(const ggml_fp16_t *p) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p)); @@ -231,66 +294,77 @@ template <> inline __m512 load(const float *p) { template <> inline __m512 load(const ggml_fp16_t *p) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p)); } +template <> inline __m512 load(const ggml_bf16_t *p) { + return _mm512_castsi512_ps( + _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16)); +} #endif // __AVX512F__ +#if defined(__AVX512BF16__) +template <> +inline __m512bh load(const ggml_bf16_t *p) { + return (__m512bh)_mm512_loadu_ps((const float *)p); +} +template <> +inline __m512bh load(const float *p) { + return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p)); +} +#endif // __AVX512BF16__ + //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION -template +template class tinyBLAS { public: - tinyBLAS(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, + tinyBLAS(dim k, + const TA *A, dim lda, + const 
TB *B, dim ldb, + TC *C, dim ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int64_t m, int64_t n, int task) { + void matmul(dim m, dim n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; + NOINLINE void mnpack(dim m0, dim m, dim n0, dim n) { + dim mc, nc, mp, np; switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) { #if VECTOR_REGISTERS == 32 case 0x55: + case 0x54: mc = 5; - nc = 5; - gemm<5, 5>(m0, m, n0, n); + nc = 4; + gemm<5, 4, false>(m0, m, n0, n); break; case 0x45: mc = 4; nc = 5; - gemm<4, 5>(m0, m, n0, n); - break; - case 0x54: - mc = 5; - nc = 4; - gemm<5, 4>(m0, m, n0, n); + gemm<4, 5, false>(m0, m, n0, n); break; case 0x44: mc = 4; nc = 4; - gemm<4, 4>(m0, m, n0, n); + gemm<4, 4, false>(m0, m, n0, n); break; case 0x53: mc = 5; nc = 3; - gemm<5, 3>(m0, m, n0, n); + gemm<5, 3, false>(m0, m, n0, n); break; case 0x35: mc = 3; nc = 5; - gemm<3, 5>(m0, m, n0, n); + gemm<3, 5, false>(m0, m, n0, n); break; case 0x43: mc = 4; nc = 3; - gemm<4, 3>(m0, m, n0, n); + gemm<4, 3, false>(m0, m, n0, n); break; #else case 0x55: @@ -301,99 +375,99 @@ class tinyBLAS { case 0x43: mc = 4; nc = 3; - gemm<4, 3>(m0, m, n0, n); + gemm<4, 3, false>(m0, m, n0, n); break; case 0x35: #endif case 0x34: mc = 3; nc = 4; - gemm<3, 4>(m0, m, n0, n); + gemm<3, 4, false>(m0, m, n0, n); break; case 0x52: mc = 5; nc = 2; - gemm<5, 2>(m0, m, n0, n); + gemm<5, 2, false>(m0, m, n0, n); break; case 0x33: mc = 3; nc = 3; - gemm<3, 3>(m0, m, n0, n); + gemm<3, 3, false>(m0, m, n0, n); break; case 0x25: mc = 2; nc = 5; - gemm<2, 5>(m0, m, n0, n); + gemm<2, 5, false>(m0, m, n0, n); break; case 0x42: mc = 4; nc = 2; - gemm<4, 2>(m0, m, n0, n); + gemm<4, 2, false>(m0, m, n0, n); break; case 0x24: mc = 2; nc = 4; - gemm<2, 4>(m0, m, n0, n); + gemm<2, 4, false>(m0, m, n0, n); break; case 0x32: mc = 3; nc = 2; - gemm<3, 2>(m0, m, n0, n); + gemm<3, 2, true>(m0, m, n0, n); break; case 0x23: mc = 2; nc = 3; - gemm<2, 3>(m0, m, n0, n); + gemm<2, 3, true>(m0, m, n0, n); break; case 0x51: mc = 5; nc = 1; - gemm<5, 1>(m0, m, n0, n); + gemm<5, 1, true>(m0, m, n0, n); break; case 0x41: mc = 4; nc = 1; - gemm<4, 1>(m0, m, n0, n); + gemm<4, 1, true>(m0, m, n0, n); break; case 0x22: mc = 2; nc = 2; - gemm<2, 2>(m0, m, n0, n); + gemm<2, 2, true>(m0, m, n0, n); break; case 0x15: mc = 1; nc = 5; - gemm<1, 5>(m0, m, n0, n); + gemm<1, 5, true>(m0, m, n0, n); break; case 0x14: mc = 1; nc = 4; - gemm<1, 4>(m0, m, n0, n); + gemm<1, 4, true>(m0, m, n0, n); break; case 0x31: mc = 3; nc = 1; - gemm<3, 1>(m0, m, n0, n); + gemm<3, 1, true>(m0, m, n0, n); break; case 0x13: mc = 1; nc = 3; - gemm<1, 3>(m0, m, n0, n); + gemm<1, 3, true>(m0, m, n0, n); break; case 0x21: mc = 2; nc = 1; - gemm<2, 1>(m0, m, n0, n); + gemm<2, 1, true>(m0, m, n0, n); break; case 0x12: mc = 1; nc = 2; - gemm<1, 2>(m0, m, n0, n); + gemm<1, 2, true>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; - gemm<1, 1>(m0, m, n0, n); + gemm<1, 1, true>(m0, m, n0, n); break; default: return; @@ -404,39 +478,45 @@ class tinyBLAS { mnpack(m0, m, np, n); } - template - NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t ytiles = (m - m0) / RM; - int64_t xtiles = (n - n0) / RN; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; + template + NOINLINE void gemm(dim m0, dim m, dim n0, dim n) { + dim 
ytiles = RM > 1 ? (m - m0) / RM : 1; + dim xtiles = RN > 1 ? (n - n0) / RN : 1; + dim tiles = xtiles * ytiles; + dim duty = (tiles + nth - 1) / nth; + dim start = duty * ith; + dim end = start + duty; if (end > tiles) end = tiles; - for (int64_t job = start; job < end; ++job) { - int64_t ii = m0 + job / xtiles * RM; - int64_t jj = n0 + job % xtiles * RN; + for (dim job = start; job < end; ++job) { + dim ii = m0 + job / xtiles * RM; + dim jj = n0 + job % xtiles * RN; D Cv[RN][RM] = {}; - for (int64_t l = 0; l < k; l += KN) - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - Cv[j][i] = madd(load(A + lda * (ii + i) + l), - load(B + ldb * (jj + j) + l), - Cv[j][i]); - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + D Ce[RN][RM] = {}; + for (dim l = 0; l < k; l += KN) + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + if (KAHAN) + Cv[j][i] = madder(load(INDEX(A, lda, ii + i, l)), + load(INDEX(B, ldb, jj + j, l)), + Cv[j][i], &Ce[j][i]); + else + Cv[j][i] = madd(load(INDEX(A, lda, ii + i, l)), + load(INDEX(B, ldb, jj + j, l)), + Cv[j][i]); + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + *INDEX(C, ldc, jj + j, ii + i) = hsum(Cv[j][i]); } } const TA *const A; const TB *const B; TC *const C; - const int64_t k; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; + const dim k; + const dim lda; + const dim ldb; + const dim ldc; const int ith; const int nth; }; @@ -445,25 +525,25 @@ class tinyBLAS { // QUANT ZERO MATRIX MULTIPLICATION #if defined(__ARM_FEATURE_DOTPROD) -template +template class tinyBLAS_Q0_ARM { public: - tinyBLAS_Q0_ARM(int64_t k, - const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, - float *C, int64_t ldc, + tinyBLAS_Q0_ARM(dim k, + const TA *A, dim lda, + const TB *B, dim ldb, + TC *C, dim ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int64_t m, int64_t n, int task) { + void matmul(dim m, dim n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; + NOINLINE void mnpack(dim m0, dim m, dim n0, dim n) { + dim mc, nc, mp, np; switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) { case 0x33: mc = 3; @@ -520,34 +600,34 @@ class tinyBLAS_Q0_ARM { } template - NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t ytiles = (m - m0) / RM; - int64_t xtiles = (n - n0) / RN; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; + NOINLINE void gemm(dim m0, dim m, dim n0, dim n) { + dim ytiles = RM > 1 ? (m - m0) / RM : 1; + dim xtiles = RN > 1 ? 
(n - n0) / RN : 1; + dim tiles = xtiles * ytiles; + dim duty = (tiles + nth - 1) / nth; + dim start = duty * ith; + dim end = start + duty; if (end > tiles) end = tiles; - for (int64_t job = start; job < end; ++job) { - int64_t ii = m0 + job / xtiles * RM; - int64_t jj = n0 + job % xtiles * RN; + for (dim job = start; job < end; ++job) { + dim ii = m0 + job / xtiles * RM; + dim jj = n0 + job % xtiles * RN; float32x4_t Cv[RN][RM] = {}; - for (int64_t l = 0; l < k; ++l) - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - Cv[j][i] = vmlaq_n_f32(Cv[j][i], - vcvtq_f32_s32(vdotq_s32( - vdotq_s32(vdupq_n_s32(0), - load_lo(A + lda * (ii + i) + l), - load_lo(B + ldb * (jj + j) + l)), - load_hi(A + lda * (ii + i) + l), - load_hi(B + ldb * (jj + j) + l))), - unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)); - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + for (dim l = 0; l < k; ++l) + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + Cv[j][i] = vmlaq_n_f32( + Cv[j][i], + vcvtq_f32_s32(vdotq_s32(vdotq_s32(vdupq_n_s32(0), + load_lo(INDEX(A, lda, ii + i, l)), + load_lo(INDEX(B, ldb, jj + j, l))), + load_hi(INDEX(A, lda, ii + i, l)), + load_hi(INDEX(B, ldb, jj + j, l)))), + (unhalf(INDEX(A, lda, ii + i, l)->d) * + unhalf(INDEX(B, ldb, jj + j, l)->d))); + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + *INDEX(C, ldc, jj + j, ii + i) = hsum(Cv[j][i]); } } @@ -560,48 +640,46 @@ class tinyBLAS_Q0_ARM { } inline int8x16_t load_lo(const block_q4_0 *b) { - return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), - vdupq_n_u8(0x0f))), + return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), vdupq_n_u8(0x0f))), vdupq_n_s8(0x8)); } inline int8x16_t load_hi(const block_q4_0 *b) { - return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), - vdupq_n_s8(0x8)); + return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), vdupq_n_s8(0x8)); } const TA *const A; - const block_q8_0 *const B; - float *const C; - const int64_t k; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; + const TB *const B; + TC *const C; + const dim k; + const dim lda; + const dim ldb; + const dim ldc; const int ith; const int nth; }; #endif // __ARM_FEATURE_DOTPROD #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) -template -class tinyBLAS_Q0_AVX { +template +class tinyBLAS_Q0_AVX2 { public: - tinyBLAS_Q0_AVX(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, - int ith, int nth) + tinyBLAS_Q0_AVX2(dim k, + const TA *A, dim lda, + const TB *B, dim ldb, + TC *C, dim ldc, + int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int64_t m, int64_t n, int task) { + void matmul(dim m, dim n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; + void mnpack(dim m0, dim m, dim n0, dim n) { + dim mc, nc, mp, np; switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) { #if VECTOR_REGISTERS == 32 case 0x44: @@ -727,35 +805,34 @@ class tinyBLAS_Q0_AVX { for (int64_t j = 0; j < RN; ++j) for (int64_t i = 0; i < RM; ++i) { #if defined(__AVX2__) - __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - _mm256_sign_epi8(load(B + ldb * (jj + j) + l), - load(A + lda * (ii + i) + l))); + __m256 udTmp = 
updot(_mm256_sign_epi8(load(INDEX(A, lda, ii + i, l)), + load(INDEX(A, lda, ii + i, l))), + _mm256_sign_epi8(load(INDEX(B, ldb, jj + j, l)), + load(INDEX(A, lda, ii + i, l)))); #else - __m128i ali0 = load0(A + lda * (ii + i) + l); - __m128i ali1 = load1(A + lda * (ii + i) + l); - __m128i blj0 = load0(B + ldb * (jj + j) + l); - __m128i blj1 = load1(B + ldb * (jj + j) + l); - + __m128i ali0 = load0(INDEX(A, lda, ii + i, l)); + __m128i ali1 = load1(INDEX(A, lda, ii + i, l)); + __m128i blj0 = load0(INDEX(B, ldb, jj + j, l)); + __m128i blj1 = load1(INDEX(B, ldb, jj + j, l)); __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); __m128i sepBA0 = _mm_sign_epi8(blj0, ali0); __m128i sepBA1 = _mm_sign_epi8(blj1, ali1); - // updot const __m128i oneFill = _mm_set1_epi16(1); __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0); __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1); - __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0))); + __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), + _mm_madd_epi16(oneFill, mad0))); #endif - Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)), + Cv[j][i] = madd(_mm256_set1_ps(unhalf(INDEX(A, lda, ii + i, l)->d) * + unhalf(INDEX(B, ldb, jj + j, l)->d)), udTmp, Cv[j][i]); } for (int64_t j = 0; j < RN; ++j) for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + *INDEX(C, ldc, jj + j, ii + i) = hsum(Cv[j][i]); } } @@ -772,7 +849,15 @@ class tinyBLAS_Q0_AVX { } inline __m256i load(const block_q4_0 *b) { - return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); + __m128i x = _mm_loadu_si128((const __m128i *)b->qs); + return _mm256_sub_epi8( + _mm256_and_si256( + _mm256_set1_epi8(15), + _mm256_insertf128_si256( + _mm256_castsi128_si256(x), + _mm_srli_epi16(x, 4), + 1)), + _mm256_set1_epi8(8)); } inline __m128i load0(const block_q4_0 *b) { @@ -795,20 +880,13 @@ class tinyBLAS_Q0_AVX { return _mm256_cvtepi32_ps(res); } - static inline __m256i denibble(const uint8_t *p) { - __m128i x = _mm_loadu_si128((const __m128i *)p); - return _mm256_and_si256(_mm256_set1_epi8(15), - _mm256_insertf128_si256(_mm256_castsi128_si256(x), - _mm_srli_epi16(x, 4), 1)); - } - const TA *const A; const TB *const B; TC *const C; - const int64_t k; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; + const dim k; + const dim lda; + const dim ldb; + const dim ldc; const int ith; const int nth; }; @@ -847,8 +925,12 @@ class tinyBLAS_Q0_AVX { * @param Ctype is GGML data type of `C` * @return true if this function was able to service the matmul request */ -bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C, - int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) { +bool llamafile_sgemm(dim m, dim n, dim k, + const void *A, dim lda, + const void *B, dim ldb, + void *C, dim ldc, + int ith, int nth, int task, + int Atype, int Btype, int Ctype) { assert(m >= 0); assert(n >= 0); @@ -870,21 +952,15 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda #if defined(__AVX512F__) if (k % 16) return false; - tinyBLAS<16, __m512, __m512, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 16, __m512, __m512, float, float, float> tb{ + k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; 
tb.matmul(m, n, task); return true; #elif defined(__AVX__) || defined(__AVX2__) if (k % 8) return false; - tinyBLAS<8, __m256, __m256, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 8, __m256, __m256, float, float, float> tb{ + k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_NEON) @@ -892,11 +968,56 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda return false; if (k % 4) return false; - tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, float> tb{ + k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; + tb.matmul(m, n, task); + return true; +#else + return false; +#endif + } + + case GGML_TYPE_BF16: { +#if defined(__AVX512BF16__) + if (k % 32) + return false; + if (Btype == GGML_TYPE_F32 && n < 2) { + tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, float, float> tb{ + k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; + tb.matmul(m, n, task); + return true; + } + if (Btype == GGML_TYPE_F32) + return false; + if (Btype != GGML_TYPE_BF16) + return false; + tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ + k, (const ggml_bf16_t *)A, lda, (const ggml_bf16_t *)B, ldb, (float *)C, ldc, ith, nth}; + tb.matmul(m, n, task); + return true; +#elif defined(__AVX512F__) + if (k % 16) + return false; + tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, float> tb{ + k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; + tb.matmul(m, n, task); + return true; +#elif defined(__AVX2__) + if (k % 8) + return false; + if (Btype != GGML_TYPE_F32) + return false; + tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, float> tb{ + k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; + tb.matmul(m, n, task); + return true; +#elif defined(__ARM_NEON) && !defined(_MSC_VER) + if (k % 4) + return false; + if (Btype != GGML_TYPE_F32) + return false; + tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, float> tb{ + k, (const ggml_bf16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else @@ -910,11 +1031,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda return false; if (Btype != GGML_TYPE_F32) return false; - tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) @@ -922,11 +1040,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda return false; if (Btype != GGML_TYPE_F32) return false; - tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) 
&& !defined(_MSC_VER) @@ -936,11 +1051,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda return false; if (Btype != GGML_TYPE_F16) return false; - tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const ggml_fp16_t *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_NEON) && !defined(_MSC_VER) @@ -948,11 +1060,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda return false; if (Btype != GGML_TYPE_F32) return false; - tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ + k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else @@ -962,21 +1071,15 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda case GGML_TYPE_Q8_0: { if (Btype != GGML_TYPE_Q8_0) - return false; + return false; #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) - tinyBLAS_Q0_AVX tb{ - k, (const block_q8_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, float> tb{ + k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_FEATURE_DOTPROD) - tinyBLAS_Q0_ARM tb{ - k, (const block_q8_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, float> tb{ + k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else @@ -988,19 +1091,13 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda if (Btype != GGML_TYPE_Q8_0) return false; #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) - tinyBLAS_Q0_AVX tb{ - k, (const block_q4_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, float> tb{ + k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #elif defined(__ARM_FEATURE_DOTPROD) - tinyBLAS_Q0_ARM tb{ - k, (const block_q4_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; + tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, float> tb{ + k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n, task); return true; #else @@ -1028,3 +1125,393 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda (void)Btype; (void)Ctype; } + +// +// _ _ ___ _ _ ___ +// | |_(_)_ _ _ _| _ ) | /_\ / __| +// | _| | ' \ || | _ \ |__ / _ \\__ \. 
+// \__|_|_||_\_, |___/____/_/ \_\___/ +// |__/ +// +// MIXTURE OF EXPERTS TENSOR MULTIPLICATION +// +// +// SHAPES +// +// - weights [cols, rows, experts] +// - thought [cols, tasks, tokens] w/ tasks ≤ thinkers +// - result [rows, thinkers, tokens] w/ thinkers ≤ experts +// - plan [thinkers, tokens] w/ i32 < experts +// +// DEFINITION +// +// for thinker in range(thinkers): +// for token in range(tokens): +// for row in range(rows): +// c = 0 +// for col in range(cols): +// expert = plan[token][thinker] +// a = weights[expert][row][col] +// b = thought[token][thinker % tasks][col] +// c += a * b +// result[token][thinker][row] = c +// +// REGULARITIES +// +// - tokens can be odd +// - thinkers is usually 2 +// - tasks is usually 1 or 2 +// - cols should be a multiple of 64 +// - rows should be a multiple of 64 +// - experts is usually 8 but could be 60 +// - tokens is always 1 for token generation +// - tokens can be huge for prompt processing +// +// EXAMPLE +// +// mixtral 8x7b w/ 217 token prompt +// +// | ne*0 ne*1 ne*2 ne*3 | nb*0 nb*1 nb*2 nb*3 | type +// ========================================================================= +// weights | 16384 6144 8 1 | 18 0x2400 0x3600000 0x1b000000 | q4_0 +// thought | 16384 2 217 1 | 4 0x10000 0x20000 0x1b20000 | f32 +// result | 6144 2 217 1 | 4 0x6000 0xc000 0xa2c000 | f32 +// plan | 2 217 1 1 | 4 0x20 0x1b20 0x1b20 | i32 +// + +namespace { +class MixMul { + public: + MixMul(const ggml_compute_params *params, const ggml_tensor *weights, + const ggml_tensor *thought, const ggml_tensor *plan, ggml_tensor *result) + : params(params), + weights(weights), + thought(thought), + plan(plan), + result(result), + rows(weights->ne[1]), + cols(weights->ne[0]), + experts(weights->ne[2]), + thinkers(plan->ne[0]), + tasks(thought->ne[1]), + tokens(thought->ne[2]), + ldq((cols * 2 + ROW_ALIGN - 1) & -ROW_ALIGN), + wdata_((char *)(((uintptr_t)params->wdata + MAX_ALIGN - 1) & -MAX_ALIGN)), + allocated_(0) { + } + + bool allocate_shared_memory() { + if (!(quantized_thought_ = allocate(MATRIX_ALIGN, tokens * tasks * ldq))) + return false; + if (!(rowptr_result_ = allocate(ROW_ALIGN, experts * tokens * thinkers))) + return false; + if (!(rowptr_thought_ = allocate(ROW_ALIGN, experts * tokens * thinkers))) + return false; + if (!(rowptr_count_ = allocate(sizeof(dim), experts))) + return false; + return true; + } + + size_t get_allocated_bytes() { + return (wdata_ - (char *)params->wdata) + allocated_; + } + + bool mixmul() { + + // invariants + assert(tasks <= thinkers); + assert(thinkers <= experts); + assert(tokens == plan->ne[1]); + assert(rows == result->ne[0]); + assert(cols == thought->ne[0]); + assert(tokens == result->ne[2]); + assert(thinkers == result->ne[1]); + + // dimensionality + assert(plan->ne[2] == 1); + assert(plan->ne[3] == 1); + assert(result->ne[3] == 1); + assert(weights->ne[3] == 1); + assert(thought->ne[3] == 1); + + // miscellaneous + assert(params->nth > 0); + assert(params->ith < params->nth); + assert(plan->type == GGML_TYPE_I32); + + // supported types + if (result->type != GGML_TYPE_F32) + return false; + + // check nb01 is convertible to lda + if (weights->nb[1] % ggml_type_size(weights->type)) + return false; + + // no support for column strides + if (result->nb[0] != ggml_type_size(result->type)) + return false; + if (thought->nb[0] != ggml_type_size(thought->type)) + return false; + if (weights->nb[0] != ggml_type_size(weights->type)) + return false; + + switch (weights->type) { + + case GGML_TYPE_F32: + if (thought->type != 
GGML_TYPE_F32) + return false; +#if defined(__AVX512F__) + return mixmat<16, 1, tinyBLAS, + float, float, float>(); +#elif defined(__AVX__) || defined(__AVX2__) + return mixmat<8, 1, tinyBLAS, float, + float, float>(); +#elif defined(__SSE__) + return mixmat<4, 1, tinyBLAS, float, + float, float>(); +#elif defined(__ARM_NEON) + return mixmat<4, 1, + tinyBLAS, + float, float, float>(); +#else + return false; +#endif + + case GGML_TYPE_BF16: + if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_BF16) + return false; +#if defined(__AVX512BF16__) + return mixmat<32, 1, + tinyBLAS, + ggml_bf16_t, ggml_bf16_t, float>(); +#elif defined(__AVX512F__) + return mixmat<16, 1, + tinyBLAS, + ggml_bf16_t, ggml_bf16_t, float>(); +#elif defined(__AVX2__) + return mixmat<8, 1, + tinyBLAS, + ggml_bf16_t, ggml_bf16_t, float>(); +#elif defined(__ARM_NEON) && !defined(_MSC_VER) + return mixmat< + 4, 1, + tinyBLAS, + ggml_bf16_t, ggml_bf16_t, float>(); +#else + return false; +#endif + + case GGML_TYPE_F16: + if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_F16) + return false; +#if defined(__AVX512F__) + return mixmat<16, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, float>(); +#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) + return mixmat<8, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, float>(); +#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) + if (result->op_params[0] == GGML_PREC_F32) { + return mixmat< + 4, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, float>(); + } else { + return mixmat< + 8, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, float>(); + } +#elif defined(__ARM_NEON) && !defined(_MSC_VER) + return mixmat< + 4, 1, + tinyBLAS, + ggml_fp16_t, ggml_fp16_t, float>(); +#else + return false; +#endif + + case GGML_TYPE_Q4_0: + if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) + return mixmat<32, 32, tinyBLAS_Q0_AVX2, + block_q4_0, block_q8_0, float>(); +#elif defined(__ARM_FEATURE_DOTPROD) + return mixmat<32, 32, tinyBLAS_Q0_ARM, + block_q4_0, block_q8_0, float>(); +#else + return false; +#endif + + case GGML_TYPE_Q8_0: + if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) + return mixmat<32, 32, tinyBLAS_Q0_AVX2, + block_q8_0, block_q8_0, float>(); +#elif defined(__ARM_FEATURE_DOTPROD) + return mixmat<32, 32, tinyBLAS_Q0_ARM, + block_q8_0, block_q8_0, float>(); +#else + return false; +#endif + + default: + return false; + } + } + + private: + template + bool mixmat() { + if (cols % KN) + return false; + switch (params->type) { + case GGML_TASK_TYPE_INIT: + if (thought->type != ggml_type_trait::id) + quantize_thought(ggml_type_trait::id); + build_row_pointers(ggml_type_trait::id); + return true; + case GGML_TASK_TYPE_COMPUTE: + assert(!(cols % BS)); + assert(!(weights->nb[1] % sizeof(TA))); + for (int expert = 0; expert < experts; ++expert) { + BLAS tb{cols / BS, + (const TA *)((const char *)weights->data + expert * weights->nb[2]), + (dim)(weights->nb[1] / sizeof(TA)), + (const TB *)(rowptr_thought_ + expert * tokens * thinkers), 0, + (TC *)(rowptr_result_ + expert * tokens * thinkers), 0, + params->ith, + params->nth}; + tb.matmul(rows, rowptr_count_[expert], GGML_TASK_TYPE_COMPUTE); + } + return true; + default: + return true; + } + } + + void build_row_pointers(ggml_type vec_dot_type) { + for (int expert = params->ith; expert < experts; expert += params->nth) { + dim count 
= 0; + for (dim token = 0; token < tokens; ++token) + for (int thinker = 0; thinker < thinkers; ++thinker) + if (expert == *(const int32_t *)((const char *)plan->data + + token * plan->nb[1] + + thinker * plan->nb[0])) { + dim row = count++; + dim idx = expert * thinkers * tokens + row; + rowptr_result_[idx] = + (uintptr_t)((char *)result->data + token * result->nb[2] + + thinker * result->nb[1]); + if (thought->type == vec_dot_type) + rowptr_thought_[idx] = + (uintptr_t)((char *)thought->data + token * thought->nb[2] + + thinker % tasks * thought->nb[1]); + else + rowptr_thought_[idx] = + (uintptr_t)((char *)quantized_thought_ + token * tasks * ldq + + thinker % tasks * ldq); + } + rowptr_count_[expert] = count; + } + } + + void quantize_thought(ggml_type vec_dot_type) { + dim chore = 0; + for (dim token = 0; token < tokens; ++token) + for (int task = 0; task < tasks; ++task) + if (chore++ % params->nth == params->ith) + quantize_row(quantized_thought_ + token * tasks * ldq + task * ldq, + (const float *)((const char *)thought->data + + token * thought->nb[2] + task * thought->nb[1]), + vec_dot_type); + } + + void quantize_row(void *dst, const float *src, ggml_type type) { + assert((dim)ggml_row_size(type, cols) <= ldq); + switch (type) { + case GGML_TYPE_F16: + ggml_fp32_to_fp16_row(src, (ggml_fp16_t *)dst, cols); + break; + case GGML_TYPE_BF16: + ggml_fp32_to_bf16_row(src, (ggml_bf16_t *)dst, cols); + break; + case GGML_TYPE_Q8_0: + quantize_row_q8_0(src, (block_q8_0 *)dst, cols); + break; + default: + GGML_UNREACHABLE(); + } + } + + template + T *allocate(size_t align, size_t elems) { + T *res = nullptr; + size_t need = sizeof(T) * elems; + size_t base = allocated_; + base += align - 1; + base &= -align; + size_t toto = base + need; + if (toto >= allocated_ && toto <= params->wsize) { + res = (T *)(wdata_ + base); + allocated_ = toto; + } + return res; + } + + const ggml_compute_params *const params; + const ggml_tensor *const weights; + const ggml_tensor *const thought; + const ggml_tensor *const plan; + ggml_tensor *const result; + const dim rows; + const dim cols; + const int experts; + const int thinkers; + const int tasks; + const dim tokens; + const dim ldq; + + // variables + char *const wdata_; + size_t allocated_; + + // shared memory + dim *rowptr_count_ /*[experts]*/; + char *quantized_thought_ /*[tokens][tasks][cols][2]*/; + uintptr_t *rowptr_result_ /*[experts][tokens*thinkers]*/; + uintptr_t *rowptr_thought_ /*[experts][tokens*thinkers]*/; +}; +} // namespace + +/** + * Performs "mixture of experts" tensor multiplication on CPU. + */ +bool llamafile_mixmul(const ggml_compute_params *params, + const ggml_tensor *weights, + const ggml_tensor *thought, + const ggml_tensor *plan, + ggml_tensor *result) { + MixMul mm{params, weights, thought, plan, result}; + return mm.allocate_shared_memory() && mm.mixmul(); +} + +/** + * Returns number of shared memory bytes llamafile_mixmul() needs. 
+ */ +size_t llamafile_mixmul_needs(const ggml_tensor *weights, + const ggml_tensor *thought, + const ggml_tensor *plan) { + ggml_compute_params params{}; + params.wsize = 0x7ffff000; + params.wdata = (void *)0x1000; + MixMul mm{&params, weights, thought, plan, 0}; + if (mm.allocate_shared_memory()) + return mm.get_allocated_bytes(); + else + return 0; +} diff --git a/sgemm.h b/sgemm.h index f29747d0a477af..06b8d7f02ddc94 100644 --- a/sgemm.h +++ b/sgemm.h @@ -1,14 +1,26 @@ #pragma once #include <stdint.h> +#include <stddef.h> #include <stdbool.h> #ifdef __cplusplus extern "C" { #endif -bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t, - const void *, int64_t, void *, int64_t, int, int, +struct ggml_tensor; +struct ggml_compute_params; + +bool llamafile_sgemm(intptr_t, intptr_t, intptr_t, const void *, intptr_t, + const void *, intptr_t, void *, intptr_t, int, int, int, int, int, int); +bool llamafile_mixmul(const struct ggml_compute_params *, const struct ggml_tensor *, + const struct ggml_tensor *, const struct ggml_tensor *, + struct ggml_tensor *); + +size_t llamafile_mixmul_needs(const struct ggml_tensor *, + const struct ggml_tensor *, + const struct ggml_tensor *); + #ifdef __cplusplus } #endif
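
To make "decomposing the mixmul operation into sgemm calls" concrete: MixMul buckets every (token, thinker) pair by the expert selected in the plan tensor, then runs one ordinary matrix multiplication per expert over the gathered rows. The stand-alone sketch below is not part of the patch; plain float arrays and made-up names stand in for ggml tensors and tinyBLAS, and it only checks that the two orderings of the work agree.

    // Illustrative only. The reference loop mirrors the DEFINITION comment in
    // sgemm.cpp; the second version regroups the same work by expert, which is
    // what turns MUL_MAT_ID into a handful of dense GEMMs.
    #include <cassert>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const int experts = 4, thinkers = 2, tasks = 2;
        const int tokens = 5, rows = 3, cols = 8;

        std::vector<float> weights(experts * rows * cols); // [expert][row][col]
        std::vector<float> thought(tokens * tasks * cols); // [token][task][col]
        std::vector<int> plan(tokens * thinkers);          // [token][thinker] -> expert
        for (size_t i = 0; i < weights.size(); ++i) weights[i] = std::sin(0.1f * i);
        for (size_t i = 0; i < thought.size(); ++i) thought[i] = std::cos(0.2f * i);
        for (int i = 0; i < tokens * thinkers; ++i) plan[i] = (i * 7) % experts;

        // Reference semantics of the operation.
        std::vector<float> ref(tokens * thinkers * rows);
        for (int token = 0; token < tokens; ++token)
            for (int thinker = 0; thinker < thinkers; ++thinker)
                for (int row = 0; row < rows; ++row) {
                    int expert = plan[token * thinkers + thinker];
                    float c = 0;
                    for (int col = 0; col < cols; ++col)
                        c += weights[(expert * rows + row) * cols + col] *
                             thought[(token * tasks + thinker % tasks) * cols + col];
                    ref[(token * thinkers + thinker) * rows + row] = c;
                }

        // Decomposition: gather row pointers per expert, then one GEMM per expert.
        std::vector<float> out(tokens * thinkers * rows);
        for (int expert = 0; expert < experts; ++expert) {
            std::vector<const float *> in_rows;
            std::vector<float *> out_rows;
            for (int token = 0; token < tokens; ++token)
                for (int thinker = 0; thinker < thinkers; ++thinker)
                    if (plan[token * thinkers + thinker] == expert) {
                        in_rows.push_back(&thought[(token * tasks + thinker % tasks) * cols]);
                        out_rows.push_back(&out[(token * thinkers + thinker) * rows]);
                    }
            for (size_t j = 0; j < in_rows.size(); ++j) // naive GEMM stands in
                for (int row = 0; row < rows; ++row) {  // for tinyBLAS here
                    float c = 0;
                    for (int col = 0; col < cols; ++col)
                        c += weights[(expert * rows + row) * cols + col] * in_rows[j][col];
                    out_rows[j][row] = c;
                }
        }

        for (size_t i = 0; i < ref.size(); ++i)
            assert(std::fabs(ref[i] - out[i]) < 1e-4f);
        std::puts("per-expert decomposition matches the reference");
    }

Once the rows belonging to one expert have been gathered, the inner loops are exactly the dense GEMM that the existing tinyBLAS kernels accelerate, which is what lets them do the heavy lifting during prompt processing.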
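
The new CONFIG template parameter (NCA/NCB/NCC bits) and the INDEX() macro let one gemm kernel read a matrix either as a conventional strided array or through an array of per-row pointers; that is how MixMul hands tinyBLAS the scattered thought/result rows it gathered, without copying them into a contiguous matrix. A minimal sketch of the two addressing modes, with hypothetical names:

    // Hypothetical stand-alone version of what INDEX() selects at compile time.
    #include <cstdio>

    template <bool ROWPTR, typename T>
    const T *index_row(const void *m, long ld, long j) {
        if (ROWPTR)
            return ((const T *const *)m)[j]; // m is an array of row pointers
        return (const T *)m + ld * j;        // m is dense with row stride ld
    }

    int main() {
        float dense[3][4] = {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}};
        // Rows 2 and 0 gathered without copying, the way build_row_pointers()
        // collects result/thought rows for one expert.
        const float *gathered[2] = {dense[2], dense[0]};
        std::printf("%g %g\n", index_row<false, float>(dense, 4, 1)[2],  // 6
                    index_row<true, float>(gathered, 0, 0)[2]);          // 10
    }

Strided access keeps the ordinary MUL_MAT path unchanged (CONFIG == 0), while the pointer mode costs only one extra indirection per row.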
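
llamafile_mixmul_needs() sizes the scratch buffer for ggml_graph_plan() by replaying allocate_shared_memory() against a dummy wdata pointer and an effectively unlimited wsize, and MixMul::allocate() is an align-up bump allocator over that workspace. A compact sketch of the pattern, simplified to one quantized-row buffer plus two side arrays (the patch carves out four regions); the names here are illustrative:

    // Illustrative align-up bump allocation with "dry run to measure" sizing.
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct Bump {
        char *base;      // workspace start (may be a fake pointer when measuring)
        size_t capacity; // bytes available
        size_t used = 0; // high-water mark

        void *allocate(size_t align, size_t bytes) {
            size_t off = (used + align - 1) & ~(align - 1); // round up to alignment
            if (off + bytes > capacity)
                return nullptr;
            used = off + bytes;
            return base + off;
        }
    };

    // The same sequence both measures and really allocates.
    static bool carve(Bump &b, size_t tokens, size_t cols, size_t experts) {
        size_t ldq = (cols * 2 + 63) & ~size_t(63); // 64-byte aligned row, 2 bytes/col
        return b.allocate(4096, tokens * ldq) &&                       // quantized activations
               b.allocate(64, experts * tokens * sizeof(uintptr_t)) && // row pointers
               b.allocate(sizeof(int64_t), experts * sizeof(int64_t)); // per-expert row counts
    }

    int main() {
        // Dry run: fake base, huge capacity, only the high-water mark matters.
        Bump measure{(char *)0x1000, SIZE_MAX / 2};
        bool ok = carve(measure, 217, 16384, 8); // shapes from the Mixtral example
        std::printf("workspace needed: %zu bytes (ok=%d)\n", measure.used, ok);

        // Real run against a buffer of exactly that size.
        std::vector<char> buf(measure.used);
        Bump real{buf.data(), buf.size()};
        ok = carve(real, 217, 16384, 8) && ok;
        std::printf("real allocation ok=%d\n", ok);
    }

This is why ggml_graph_plan() can take the larger of its own estimate and llamafile_mixmul_needs() before any real buffers exist: the measuring pass never touches memory, it only advances an offset.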
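
madder() is madd() with Kahan-style compensation: each accumulator carries an error term holding the rounding lost by the previous update and feeds it back into the next one, which matters when k is large and many small products are summed in float. A scalar version of the same update rule (the vector code applies it lane-wise):

    // Scalar sketch of the compensated update used by madder().
    #include <cstdio>

    static float madder_scalar(float a, float b, float c, float *e) {
        float y = a * b - *e; // product corrected by the stored error
        float t = c + y;      // new accumulator value
        *e = (t - c) - y;     // rounding error introduced by this step
        return t;
    }

    int main() {
        const int n = 1 << 24;
        float naive = 0.0f, comp = 0.0f, err = 0.0f;
        double exact = 0.0;
        for (int i = 0; i < n; ++i) {
            naive += 1.0f * 1e-4f;                         // madd-style accumulation
            comp = madder_scalar(1.0f, 1e-4f, comp, &err); // compensated accumulation
            exact += 1e-4;
        }
        std::printf("exact=%.3f  naive=%.3f  compensated=%.3f\n", exact, naive, comp);
    }

The naive float sum drifts visibly once the accumulator gets large relative to the addends, while the compensated sum stays within rounding distance of the exact value. That also explains why the KAHAN flag is only switched on for the smaller tile shapes in mnpack(), presumably because spare vector registers are available there for the error terms.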
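
The new load() overloads for ggml_bf16_t convert bf16 to f32 by zero-extending each 16-bit value and shifting it into the top half of a 32-bit lane, since bf16 is simply the upper 16 bits of an IEEE-754 float. A scalar sketch of the bit trick; the truncating float-to-bf16 helper is demo-only (ggml's own conversion also takes care of rounding):

    // Scalar equivalent of the vectorized bf16 -> f32 loads in this patch.
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static float bf16_to_f32(uint16_t h) {
        uint32_t bits = (uint32_t)h << 16; // bf16 bits become the float's high half
        float f;
        std::memcpy(&f, &bits, sizeof f);
        return f;
    }

    static uint16_t f32_to_bf16_truncate(float f) { // demo-only, truncating
        uint32_t bits;
        std::memcpy(&bits, &f, sizeof bits);
        return (uint16_t)(bits >> 16);
    }

    int main() {
        float x = 3.14159265f;
        uint16_t h = f32_to_bf16_truncate(x);
        std::printf("%.7f -> bf16 -> %.7f\n", x, bf16_to_f32(h)); // 3.1415927 -> 3.1406250
    }

The same 16-bit shift appears in the AVX2, AVX-512, and NEON loads added above, while the AVX512BF16 path skips the widening entirely and feeds bf16 pairs straight into _mm512_dpbf16_ps.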