From bba8b881212bd8b8de69bb1028f187c4211eef19 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 30 May 2024 14:10:55 +0200 Subject: [PATCH 01/13] Arm for i-quants This carries over what I had done within llama.cpp. In llamafile we have nice performance gains for PP, but we get a performance regression for TG. For now, just adjusted iq2_xxs to also outperform in TG (~10% better @ 4 and 8 threads). Will tackle the other quants next. --- llamafile/iqk_mul_mat.inc | 478 +++++++++++++++++++++++++++++++++++++- 1 file changed, 474 insertions(+), 4 deletions(-) diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index a126f2c9ab..a208ad81c3 100644 --- a/llamafile/iqk_mul_mat.inc +++ b/llamafile/iqk_mul_mat.inc @@ -15,6 +15,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <cstring> #include #if defined __x86_64__ || defined __aarch64__ @@ -22,6 +23,9 @@ #include "llama.cpp/ggml-quants.h" #include "sgemm.h" +#define GGML_COMMON_IMPL_C +#include "llama.cpp/ggml-common.h" + // clang-format off // This matrix - vector and matrix - matrix multiplication implementation @@ -123,6 +127,41 @@ inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) { aux32[0] = a0 & 0x3f3f3f3f; } +const uint64_t keven_signs[128] = { + 0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff, + 0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff, + 0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff, + 0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff, + 0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff, + 0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff, + 0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff, + 0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff, + 0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff, + 0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff, + 0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff, + 0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff, + 0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff, + 0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff, + 0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff, + 0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff, + 0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff, + 0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff, + 0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff, + 0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff, + 0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff, + 0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff, + 0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff, + 0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff, + 0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff, + 0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff, + 0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff, + 0x01ffff01ffff0101, 
0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff, + 0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff, + 0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff, + 0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff, + 0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff, +}; + } bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, const void * B, @@ -1763,6 +1802,7 @@ struct DequantizerQ3K final : public BaseDequantizer { inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { d = GGML_FP16_TO_FP32(x[i].d); h.bits = vld1q_u8_x2(x[i].hmask); + mask = vdupq_n_u8(0x01); const uint16_t * sc16 = (const uint16_t *)x[i].scales; uint32_t aux0 = sc16[0] | (sc16[1] << 16); uint32_t aux1 = sc16[2] | (sc16[3] << 16); @@ -1771,19 +1811,43 @@ struct DequantizerQ3K final : public BaseDequantizer { aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030); aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030); aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030); - return process_scales_mins_16(vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)), q8, acc, i, -4.f*d); + auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)); + if (nrc > 1) { + return process_scales_mins_16(scales8, q8, acc, i, -4.f*d); + } + int16x8x2_t scales16; + scales16.val[0] = vmovl_s8(vget_low_s8(scales8)); + scales16.val[1] = vmovl_s8(vget_high_s8(scales8)); + return make_wider(scales16); } inline void prepare(int i, int j) { bits.prepare(x[i].qs+32*j); - h.apply(bits.b1, bits.b2, j == 0); + if (nrc > 1) { + h.apply(bits.b1, bits.b2, j == 0); + } else { + auto minus4 = vdupq_n_u8(0xfc); + auto zero = vdupq_n_u8(0); + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); + mask = vshlq_n_u8(mask, 1); + } } uint32_t aux32[4]; Q2bits bits; - const uint8x16_t mhb = vdupq_n_u8(0x04); + uint8x16_t mask; HighBit3 h; float d; @@ -1861,6 +1925,8 @@ struct DequantizerQ2K final : public BaseDequantizer { float d; }; +// ============================= i-quants + struct DequantizerIQ4XS final : public BaseDequantizer { static int8x16_t load_values() { @@ -1919,8 +1985,365 @@ struct DequantizerIQ4XS final : public BaseDequantizer { float d; }; +struct SimpleBits { + uint8x16x4_t b1; + uint8x16x4_t b2; +}; + +inline int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) { + int32x4x2_t scales; + scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v1, 28), 1), vdupq_n_u32(1))); + scales.val[1] = 
vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v2, 28), 1), vdupq_n_u32(1))); + return scales; +} + +inline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) { + auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127)))); + auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127)))); + b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1)); + b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2)); +} + +struct DequantizerIQ2XXS final : public BaseDequantizer { + DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + + // This is slower + //auto tmp = vld1q_u8_x2((const uint8_t *)x[i].qs); + //data.val[0] = vreinterpretq_u32_u8(vqtbl2q_u8(tmp, vreinterpretq_u8_u32(shuffle1))); + //data.val[1] = vreinterpretq_u32_u8(vqtbl2q_u8(tmp, vreinterpretq_u8_u32(shuffle2))); + //tmp = vld1q_u8_x2((const uint8_t *)x[i].qs + 32); + //data.val[2] = vreinterpretq_u32_u8(vqtbl2q_u8(tmp, vreinterpretq_u8_u32(shuffle1))); + //data.val[3] = vreinterpretq_u32_u8(vqtbl2q_u8(tmp, vreinterpretq_u8_u32(shuffle2))); + + auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs); + data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]); // codebook indices for blocks 0...3 + data.val[1] = vuzp2q_u32(tmp.val[0], tmp.val[1]); // scales and signs for blocks 0...3 + data.val[2] = vuzp1q_u32(tmp.val[2], tmp.val[3]); // codebook indices for blocks 4...7 + data.val[3] = vuzp2q_u32(tmp.val[2], tmp.val[3]); // scales and signs for blocks 4...7 + + return prepare_scales_8(data.val[1], data.val[3]); + } + + static inline void prepare2(uint8x16_t * b, const uint8_t * idx, const uint64_t * signs, uint32_t sidx) { + b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); + b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); + //b[0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2xxs_grid + idx[0])), vld1_u8((const uint8_t *)(iq2xxs_grid + idx[1]))); + //b[1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2xxs_grid + idx[2])), vld1_u8((const uint8_t *)(iq2xxs_grid + idx[3]))); + apply_signs_2(b, signs, sidx); + } + static inline void prepare4(uint8x16_t * b, const uint32_t * aux32) { + const uint8_t * idx = (const uint8_t *)aux32; + b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); + b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); + idx += 8; + b[2] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); + b[3] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); + apply_signs_2(b+0, keven_signs, aux32[1]); + apply_signs_2(b+2, keven_signs, aux32[3]); + } + static inline void prepare4_unsigned(uint8x16_t * b, const uint32_t * aux32) { + const uint8_t * idx = (const uint8_t *)aux32; + b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); + b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); + idx += 8; + b[2] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); + b[3] = 
vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); + } + static inline void sign4(uint8x16_t * b, const uint32_t * aux32) { + apply_signs_2(b+0, keven_signs, aux32[1]); + apply_signs_2(b+2, keven_signs, aux32[3]); + } + + inline void prepare(int /*i*/, int j) { + const uint8_t * idx = (const uint8_t *)(data.val + 2*j); + const uint32_t * sidx = (const uint32_t *)(data.val + 2*j+1); + prepare2(bits.b1.val + 0, idx, keven_signs, sidx[0]); idx += 4; + prepare2(bits.b1.val + 2, idx, keven_signs, sidx[1]); idx += 4; + prepare2(bits.b2.val + 0, idx, keven_signs, sidx[2]); idx += 4; + prepare2(bits.b2.val + 2, idx, keven_signs, sidx[3]); + } + + inline void new_block(int i) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + //auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs); + //data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]); // codebook indices for blocks 0...3 + //data.val[1] = vuzp2q_u32(tmp.val[0], tmp.val[1]); // scales and signs for blocks 0...3 + //data.val[2] = vuzp1q_u32(tmp.val[2], tmp.val[3]); // codebook indices for blocks 4...7 + //data.val[3] = vuzp2q_u32(tmp.val[2], tmp.val[3]); // scales and signs for blocks 4...7 + } + //inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { //const Q8& q8) { + // auto qs = x[i].qs + 16*j; + // prepare2(bits.b1.val + 0, (const uint8_t *)(qs + 0), keven_signs, qs[ 2] | (qs[ 3] << 16)); + // auto q = q8.load_quants(0, i, 4*j+0); + // auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), vreinterpretq_s8_u8(bits.b1.val[0]), q.val[0]), vreinterpretq_s8_u8(bits.b1.val[1]), q.val[1]); + // prepare2(bits.b1.val + 2, (const uint8_t *)(qs + 4), keven_signs, qs[ 6] | (qs[ 7] << 16)); + // q = q8.load_quants(0, i, 4*j+1); + // auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), vreinterpretq_s8_u8(bits.b1.val[2]), q.val[0]), vreinterpretq_s8_u8(bits.b1.val[3]), q.val[1]); + // prepare2(bits.b1.val + 0, (const uint8_t *)(qs + 8), keven_signs, qs[10] | (qs[11] << 16)); + // q = q8.load_quants(0, i, 4*j+2); + // auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), vreinterpretq_s8_u8(bits.b1.val[0]), q.val[0]), vreinterpretq_s8_u8(bits.b1.val[1]), q.val[1]); + // prepare2(bits.b1.val + 2, (const uint8_t *)(qs +12), keven_signs, qs[14] | (qs[15] << 16)); + // q = q8.load_quants(0, i, 4*j+3); + // auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), vreinterpretq_s8_u8(bits.b1.val[2]), q.val[0]), vreinterpretq_s8_u8(bits.b1.val[3]), q.val[1]); + // int32_t sum = vaddvq_s32(p1) * (2*(qs[ 3] >> 12) + 1) + vaddvq_s32(p2) * (2*(qs[ 7] >> 12) + 1) + // + vaddvq_s32(p3) * (2*(qs[11] >> 12) + 1) + vaddvq_s32(p4) * (2*(qs[15] >> 12) + 1); + // return sum; + //} + inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { + uint32_t aux32[8]; + std::memcpy(aux32, x[i].qs + 16*j, 32); + auto q1 = q8.load_quants_64(0, i, 2*j+0); + auto q2 = q8.load_quants_64(0, i, 2*j+1); + int8x16_t * b1 = (int8x16_t *)bits.b1.val; + int8x16_t * b2 = (int8x16_t *)bits.b2.val; + prepare4(bits.b1.val, aux32+0); + prepare4(bits.b2.val, aux32+4); + //prepare4_unsigned(bits.b1.val, aux32+0); + //prepare4_unsigned(bits.b2.val, aux32+4); + //sign4(bits.b1.val, aux32+0); + //sign4(bits.b2.val, aux32+4); + //prepare2(bits.b1.val + 0, (const uint8_t *)(aux32 + 0), keven_signs, aux32[1]); + //prepare2(bits.b1.val + 2, (const uint8_t *)(aux32 + 2), keven_signs, aux32[3]); + //prepare2(bits.b2.val + 0, (const uint8_t *)(aux32 + 4), keven_signs, aux32[5]); + //prepare2(bits.b2.val + 2, (const uint8_t *)(aux32 + 6), keven_signs, 
aux32[7]); + auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q1.val[0]), b1[1], q1.val[1]); + auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q1.val[2]), b1[3], q1.val[3]); + auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q2.val[0]), b2[1], q2.val[1]); + auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q2.val[2]), b2[3], q2.val[3]); + int32_t sumi = vaddvq_s32(p1) * (2*(aux32[1] >> 28) + 1) + + vaddvq_s32(p2) * (2*(aux32[3] >> 28) + 1) + + vaddvq_s32(p3) * (2*(aux32[5] >> 28) + 1) + + vaddvq_s32(p4) * (2*(aux32[7] >> 28) + 1); + return sumi; + //sumi[0] += vaddvq_s32(p1) * (2*(aux32[1] >> 28) + 1); + //sumi[1] += vaddvq_s32(p2) * (2*(aux32[3] >> 28) + 1); + //sumi[2] += vaddvq_s32(p3) * (2*(aux32[5] >> 28) + 1); + //sumi[3] += vaddvq_s32(p4) * (2*(aux32[7] >> 28) + 1); + } + + uint32x4x4_t data; + SimpleBits bits; + //const uint32x4_t shuffle1 = {0x03020100, 0x0b0a0908, 0x13121110, 0x1b1a1918}; + //const uint32x4_t shuffle2 = {0x07060504, 0x0f0e0d0c, 0x17161514, 0x1f1e1d1c}; + + float d; +}; + +inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) { + auto aux = vld1_u8(sc); + auto scales_l = vand_u8(aux, vdup_n_u8(0xf)); + auto scales_h = vshr_n_u8(aux, 4); + auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); + + auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1))); + int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) }; + return make_wider(scales16); +} + +struct DequantizerIQ2XS final : public BaseDequantizer { + DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + return prepare_4bit_scales16(x[i].scales); + } + + inline static uint8x16_t make1(const uint16_t * qs) { + auto b = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_u8((const uint8_t *)(iq2xs_grid + (qs[1] & 511)))); + auto s = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9)))); + return vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b), s)); + } + + inline static void make4(const uint16_t * qs, uint8x16_t * b) { + b[0] = make1(qs + 0); + b[1] = make1(qs + 2); + b[2] = make1(qs + 4); + b[3] = make1(qs + 6); + } + + inline void prepare(int i, int j) { + make4(x[i].qs + 16*j + 0, bits.b1.val); + make4(x[i].qs + 16*j + 8, bits.b2.val); + } + + SimpleBits bits; + + float d; + +}; + +struct SignHelper { + + inline void init() { shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); } + + inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16) { + auto aux = vqtbl1q_u8(signs16, shuffle); + auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1)); + b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s)); + shuffle = vaddq_u8(shuffle, step); + } + + const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); + const uint8x16_t m1 = vdupq_n_u8(1); + const uint8x16_t step = vdupq_n_u8(2); + uint8x16_t shuffle; +}; + +struct DequantizerIQ2S final : public BaseDequantizer { + DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr 
static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + return prepare_4bit_scales16(x[i].scales); + } + + static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { + uint32_t aux32[2]; + const uint16_t * aux16 = (const uint16_t *)aux32; + for (int k = 0; k < 2; ++k) { + aux32[1] = (qh[k] << 4) | (qh[k] << 18); + aux32[0] = (aux32[1] << 4) & 0x03000300; + aux32[1] &= 0x03000300; + b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))), + vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1])))); + sh.apply_signs_1(b+2*k+0, signs16); + + b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))), + vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3])))); + sh.apply_signs_1(b+2*k+1, signs16); + } + } + + inline void prepare(int i, int j) { + + const auto * qs = x[i].qs + 16*j; + const auto * qh = x[i].qh + 4*j; + const auto signs16 = vld1q_u8(qs + QK_K/8); + + sh.init(); + make4(sh, signs16, qs+0, qh+0, bits.b1.val); + make4(sh, signs16, qs+8, qh+2, bits.b2.val); + } + + SimpleBits bits; + SignHelper sh; + + float d; + +}; + +struct DequantizerIQ3XXS final : public BaseDequantizer { + DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = 0.25f * GGML_FP16_TO_FP32(x[i].d); + gas = vld1q_u32_x2((const uint32_t *)(x[i].qs + QK_K/4)); + return prepare_scales_8(gas.val[0], gas.val[1]); + } + + inline static void make2(const uint8_t * q3, uint32_t sidx, uint8x16_t * b) { + b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]}); + b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]}); + apply_signs_2(b, keven_signs, sidx); + } + inline void prepare(int i, int j) { + const auto * q3 = x[i].qs + 32*j; + const auto * signs = (const uint32_t *)(gas.val + j); + make2(q3, signs[0], bits.b1.val + 0); q3 += 8; + make2(q3, signs[1], bits.b1.val + 2); q3 += 8; + make2(q3, signs[2], bits.b2.val + 0); q3 += 8; + make2(q3, signs[3], bits.b2.val + 2); + } + + SimpleBits bits; + uint32x4x2_t gas; + + float d; + +}; + +struct DequantizerIQ3S final : public BaseDequantizer { + DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = GGML_FP16_TO_FP32(x[i].d); + uint32_t scales32[2]; + std::memcpy(scales32, x[i].scales, 4); + scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; + scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; + auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7 + scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400))); + auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8)); + int32x4x2_t scales; + scales.val[0] = vmovl_s16(vget_low_s16(scales16)); + scales.val[1] = vmovl_s16(vget_high_s16(scales16)); + return scales; + } + + 
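// [Editor's note] The scale unpacking in new_block() above, restated in scalar form as a
// reading aid (iq3s_unpack_scales_ref is a hypothetical reference helper, not part of this
// patch): each 4-bit field s maps to the odd value 2*s + 1; the low nibble of scale byte k
// covers 32-element sub-block 2k and the high nibble sub-block 2k+1, which is the
// 0, 2, 4, 6 / 1, 3, 5, 7 order that the vtbl1_u8 shuffle re-interleaves back to 0...7.
//
//   static inline void iq3s_unpack_scales_ref(const uint8_t * sc, int32_t * out) {
//       for (int k = 0; k < 4; ++k) {            // 4 scale bytes per super-block of 256
//           out[2*k+0] = 2*(sc[k] & 0xf) + 1;    // sub-blocks 0, 2, 4, 6
//           out[2*k+1] = 2*(sc[k] >>  4) + 1;    // sub-blocks 1, 3, 5, 7
//       }
//   }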
static inline void make2(SignHelper& sh, const uint8x16_t& signs16, const uint16x8_t& idx_l, uint8_t qh, + const int16x8_t& hshift, uint8x16_t * b) { + auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256))); + const uint16_t * idx = (const uint16_t *)&vindex; + b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]}); + b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]}); + sh.apply_signs_1(b+0, signs16); + sh.apply_signs_1(b+1, signs16); + } + static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, + const int16x8_t& hshift, uint8x16_t * b) { + auto idx_l = vld1q_u8(qs); + make2(sh, signs16, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0); + make2(sh, signs16, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2); + } + + inline void prepare(int i, int j) { + + static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + const auto hshift = vld1q_s16(k_shift); + + const auto * qs = x[i].qs + 32*j; + const auto * qh = x[i].qh + 4*j; + const auto signs16 = vld1q_u8(x[i].signs + 16*j); + + sh.init(); + make4(sh, signs16, qs+ 0, qh+0, hshift, bits.b1.val); + make4(sh, signs16, qs+16, qh+2, hshift, bits.b2.val); + } + + SimpleBits bits; + SignHelper sh; + uint32x4x2_t gas; + + float d; + +}; + + template -static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; @@ -1977,6 +2400,30 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf } } +template +void mul_mat_qX_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8<1, block_q8_K> q8(info); + + Dequantizer deq(vx, bx, 1); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + float32_t acc = 0; + for (int i = 0; i < nb; ++i) { + deq.new_block(i); + auto sumi = deq.process_block(i, 0, q8); + sumi += deq.process_block(i, 1, q8); + acc += deq.d*q8.scale(0, i)*sumi; + } + info.store(ix, 0, acc); + } +} + // =========================================== Legacy quants template @@ -2463,6 +2910,29 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int / case GGML_TYPE_IQ4_XS: MulMat::set_functions(m); break; + case GGML_TYPE_IQ2_XXS: + //MulMat::set_functions(m); + m.funcs[0] = mul_mat_qX_K_q8_K_1; + m.funcs[1] = mul_mat_qX_K_q8_K_T<2, DequantizerIQ2XXS>; + m.funcs[2] = mul_mat_qX_K_q8_K_T<3, DequantizerIQ2XXS>; + m.funcs[3] = mul_mat_qX_K_q8_K_T<4, DequantizerIQ2XXS>; + m.funcs[4] = mul_mat_qX_K_q8_K_T<5, DequantizerIQ2XXS>; + m.funcs[5] = mul_mat_qX_K_q8_K_T<6, DequantizerIQ2XXS>; + m.funcs[6] = mul_mat_qX_K_q8_K_T<7, DequantizerIQ2XXS>; + m.funcs[7] = mul_mat_qX_K_q8_K_T<8, DequantizerIQ2XXS>; + break; + case GGML_TYPE_IQ2_XS: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ2_S: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ3_XXS: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ3_S: + MulMat::set_functions(m); + break; case GGML_TYPE_Q4_0: MulMat::set_functions(m); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); From 9347a82ff21abf0b60ed8b1030322aa25dc75003 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 30 May 2024 15:13:32 +0200 Subject: [PATCH 02/13] Arm for 
i-quants: iq2_xxs So, improving TG speed results in a drop of performance for PP. Before I had PP-512 = 56.78 t/s, TG-128 = 12.42 t/s @ 8 threads. Now we have PP-512 = 52.77 t/s, TG-128 = 15.97 t/s @ 8 threads. --- llamafile/iqk_mul_mat.inc | 61 +-------------------------------------- 1 file changed, 1 insertion(+), 60 deletions(-) diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index a208ad81c3..ac77570571 100644 --- a/llamafile/iqk_mul_mat.inc +++ b/llamafile/iqk_mul_mat.inc @@ -2014,14 +2014,6 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); - // This is slower - //auto tmp = vld1q_u8_x2((const uint8_t *)x[i].qs); - //data.val[0] = vreinterpretq_u32_u8(vqtbl2q_u8(tmp, vreinterpretq_u8_u32(shuffle1))); - //data.val[1] = vreinterpretq_u32_u8(vqtbl2q_u8(tmp, vreinterpretq_u8_u32(shuffle2))); - //tmp = vld1q_u8_x2((const uint8_t *)x[i].qs + 32); - //data.val[2] = vreinterpretq_u32_u8(vqtbl2q_u8(tmp, vreinterpretq_u8_u32(shuffle1))); - //data.val[3] = vreinterpretq_u32_u8(vqtbl2q_u8(tmp, vreinterpretq_u8_u32(shuffle2))); - auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs); data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]); // codebook indices for blocks 0...3 data.val[1] = vuzp2q_u32(tmp.val[0], tmp.val[1]); // scales and signs for blocks 0...3 @@ -2034,8 +2026,6 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { static inline void prepare2(uint8x16_t * b, const uint8_t * idx, const uint64_t * signs, uint32_t sidx) { b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); - //b[0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2xxs_grid + idx[0])), vld1_u8((const uint8_t *)(iq2xxs_grid + idx[1]))); - //b[1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2xxs_grid + idx[2])), vld1_u8((const uint8_t *)(iq2xxs_grid + idx[3]))); apply_signs_2(b, signs, sidx); } static inline void prepare4(uint8x16_t * b, const uint32_t * aux32) { @@ -2048,18 +2038,6 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { apply_signs_2(b+0, keven_signs, aux32[1]); apply_signs_2(b+2, keven_signs, aux32[3]); } - static inline void prepare4_unsigned(uint8x16_t * b, const uint32_t * aux32) { - const uint8_t * idx = (const uint8_t *)aux32; - b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); - b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); - idx += 8; - b[2] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); - b[3] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); - } - static inline void sign4(uint8x16_t * b, const uint32_t * aux32) { - apply_signs_2(b+0, keven_signs, aux32[1]); - apply_signs_2(b+2, keven_signs, aux32[3]); - } inline void prepare(int /*i*/, int j) { const uint8_t * idx = (const uint8_t *)(data.val + 2*j); @@ -2072,30 +2050,7 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { inline void new_block(int i) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); - //auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs); - //data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]); // codebook indices for blocks 0...3 - //data.val[1] = vuzp2q_u32(tmp.val[0], tmp.val[1]); // scales and signs for blocks 0...3 - //data.val[2] = vuzp1q_u32(tmp.val[2], tmp.val[3]); // codebook indices for blocks 4...7 - //data.val[3] = vuzp2q_u32(tmp.val[2], 
tmp.val[3]); // scales and signs for blocks 4...7 - } - //inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { //const Q8& q8) { - // auto qs = x[i].qs + 16*j; - // prepare2(bits.b1.val + 0, (const uint8_t *)(qs + 0), keven_signs, qs[ 2] | (qs[ 3] << 16)); - // auto q = q8.load_quants(0, i, 4*j+0); - // auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), vreinterpretq_s8_u8(bits.b1.val[0]), q.val[0]), vreinterpretq_s8_u8(bits.b1.val[1]), q.val[1]); - // prepare2(bits.b1.val + 2, (const uint8_t *)(qs + 4), keven_signs, qs[ 6] | (qs[ 7] << 16)); - // q = q8.load_quants(0, i, 4*j+1); - // auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), vreinterpretq_s8_u8(bits.b1.val[2]), q.val[0]), vreinterpretq_s8_u8(bits.b1.val[3]), q.val[1]); - // prepare2(bits.b1.val + 0, (const uint8_t *)(qs + 8), keven_signs, qs[10] | (qs[11] << 16)); - // q = q8.load_quants(0, i, 4*j+2); - // auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), vreinterpretq_s8_u8(bits.b1.val[0]), q.val[0]), vreinterpretq_s8_u8(bits.b1.val[1]), q.val[1]); - // prepare2(bits.b1.val + 2, (const uint8_t *)(qs +12), keven_signs, qs[14] | (qs[15] << 16)); - // q = q8.load_quants(0, i, 4*j+3); - // auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), vreinterpretq_s8_u8(bits.b1.val[2]), q.val[0]), vreinterpretq_s8_u8(bits.b1.val[3]), q.val[1]); - // int32_t sum = vaddvq_s32(p1) * (2*(qs[ 3] >> 12) + 1) + vaddvq_s32(p2) * (2*(qs[ 7] >> 12) + 1) - // + vaddvq_s32(p3) * (2*(qs[11] >> 12) + 1) + vaddvq_s32(p4) * (2*(qs[15] >> 12) + 1); - // return sum; - //} + } inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { uint32_t aux32[8]; std::memcpy(aux32, x[i].qs + 16*j, 32); @@ -2105,14 +2060,6 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { int8x16_t * b2 = (int8x16_t *)bits.b2.val; prepare4(bits.b1.val, aux32+0); prepare4(bits.b2.val, aux32+4); - //prepare4_unsigned(bits.b1.val, aux32+0); - //prepare4_unsigned(bits.b2.val, aux32+4); - //sign4(bits.b1.val, aux32+0); - //sign4(bits.b2.val, aux32+4); - //prepare2(bits.b1.val + 0, (const uint8_t *)(aux32 + 0), keven_signs, aux32[1]); - //prepare2(bits.b1.val + 2, (const uint8_t *)(aux32 + 2), keven_signs, aux32[3]); - //prepare2(bits.b2.val + 0, (const uint8_t *)(aux32 + 4), keven_signs, aux32[5]); - //prepare2(bits.b2.val + 2, (const uint8_t *)(aux32 + 6), keven_signs, aux32[7]); auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q1.val[0]), b1[1], q1.val[1]); auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q1.val[2]), b1[3], q1.val[3]); auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q2.val[0]), b2[1], q2.val[1]); @@ -2122,16 +2069,10 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { + vaddvq_s32(p3) * (2*(aux32[5] >> 28) + 1) + vaddvq_s32(p4) * (2*(aux32[7] >> 28) + 1); return sumi; - //sumi[0] += vaddvq_s32(p1) * (2*(aux32[1] >> 28) + 1); - //sumi[1] += vaddvq_s32(p2) * (2*(aux32[3] >> 28) + 1); - //sumi[2] += vaddvq_s32(p3) * (2*(aux32[5] >> 28) + 1); - //sumi[3] += vaddvq_s32(p4) * (2*(aux32[7] >> 28) + 1); } uint32x4x4_t data; SimpleBits bits; - //const uint32x4_t shuffle1 = {0x03020100, 0x0b0a0908, 0x13121110, 0x1b1a1918}; - //const uint32x4_t shuffle2 = {0x07060504, 0x0f0e0d0c, 0x17161514, 0x1f1e1d1c}; float d; }; From ffcb2972c399a9011b955ced0f04da031463866f Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 30 May 2024 17:55:02 +0200 Subject: [PATCH 03/13] Arm for i-quants: iq3_s Improved TG from 4.96 t/s to 5.43 t/s. Still ~3.5% slower than mainline. 
PP-512 became slightly better (47.9 vs 46.8 t/s). This is 3.9X mainline (!) --- llamafile/iqk_mul_mat.inc | 117 +++++++++++++++++++++++++++----------- 1 file changed, 83 insertions(+), 34 deletions(-) diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index ac77570571..c17a6ac859 100644 --- a/llamafile/iqk_mul_mat.inc +++ b/llamafile/iqk_mul_mat.inc @@ -23,6 +23,12 @@ #include "llama.cpp/ggml-quants.h" #include "sgemm.h" +#ifdef _MSC_VER +#define IQK_NOINLINE __declspec(noinline) +#else +#define IQK_NOINLINE __attribute__((__noinline__)) +#endif + #define GGML_COMMON_IMPL_C #include "llama.cpp/ggml-common.h" @@ -86,8 +92,8 @@ typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& inf struct MulMat { std::array funcs = {}; - //std::array funcs = {}; - inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { + //inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { + IQK_NOINLINE void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { #ifdef __aarch64__ constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small) #else @@ -111,9 +117,9 @@ struct MulMat { funcs[n_left-1](n, vx, bx, info, nrc_x); } } - static bool set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int Ny); + static IQK_NOINLINE bool set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int Ny); private: - template static void set_functions(MulMat& m); + template static IQK_NOINLINE void set_functions(MulMat& m); }; inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) { @@ -2140,6 +2146,20 @@ struct SignHelper { const uint8x16_t step = vdupq_n_u8(2); uint8x16_t shuffle; }; +struct SignHelper1 { + + inline void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const { + auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1])); + //auto s = vceqq_u8(vandq_u8(aux, smask), smask); + //b[0] = vreinterpretq_u8_s8(vsubq_s8(vreinterpretq_s8_u8(veorq_u8(b[0], s)), vreinterpretq_s8_u8(s))); + // Not much of a difference compared to the above. Perhaps tiny little bit faster. 
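// [Editor's note] A scalar sketch of what apply_signs_1x computes, as an aid to reading the
// NEON below (v is a hypothetical int8_t view of b[0]; the packing assumption -- one sign bit
// per value, LSB first, sign_bits[0] for lanes 0..7 and sign_bits[1] for lanes 8..15 --
// follows from the vdup/vcombine above):
//
//   for (int k = 0; k < 16; ++k) {
//       int bit = (sign_bits[k >> 3] >> (k & 7)) & 1;  // test bit k of the two sign bytes
//       v[k] = bit ? -v[k] : v[k];                     // negate where the sign bit is set
//   }
//
// The vectorized form isolates one bit per lane with smask = 0x8040201008040201, turns set
// bits into 0xff via vceqq, ORs in 1 so each lane becomes +1 or -1, and multiplies.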
+ auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1)); + b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s)); + } + + const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); + const uint8x16_t m1 = vdupq_n_u8(1); +}; struct DequantizerIQ2S final : public BaseDequantizer { DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} @@ -2244,39 +2264,65 @@ struct DequantizerIQ3S final : public BaseDequantizer { return scales; } - static inline void make2(SignHelper& sh, const uint8x16_t& signs16, const uint16x8_t& idx_l, uint8_t qh, + static inline void make2(SignHelper1& sh, const uint8_t * sign_bits, const uint16x8_t& idx_l, uint8_t qh, const int16x8_t& hshift, uint8x16_t * b) { auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256))); const uint16_t * idx = (const uint16_t *)&vindex; b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]}); + sh.apply_signs_1x(b+0, sign_bits+0); b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]}); - sh.apply_signs_1(b+0, signs16); - sh.apply_signs_1(b+1, signs16); + sh.apply_signs_1x(b+1, sign_bits+2); } - static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, + static inline void make4(SignHelper1& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, const int16x8_t& hshift, uint8x16_t * b) { auto idx_l = vld1q_u8(qs); - make2(sh, signs16, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0); - make2(sh, signs16, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2); + make2(sh, sign_bits+0, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0); + make2(sh, sign_bits+4, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2); } - inline void prepare(int i, int j) { - + static int16x8_t load_shift() { static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; - const auto hshift = vld1q_s16(k_shift); + return vld1q_s16(k_shift); + } + + inline void prepare(int i, int j) { const auto * qs = x[i].qs + 32*j; const auto * qh = x[i].qh + 4*j; - const auto signs16 = vld1q_u8(x[i].signs + 16*j); - sh.init(); - make4(sh, signs16, qs+ 0, qh+0, hshift, bits.b1.val); - make4(sh, signs16, qs+16, qh+2, hshift, bits.b2.val); + make4(sh, x[i].signs + 16*j + 0, qs+ 0, qh+0, hshift, bits.b1.val); + make4(sh, x[i].signs + 16*j + 8, qs+16, qh+2, hshift, bits.b2.val); + } + + inline void new_block(int i) { + d = GGML_FP16_TO_FP32(x[i].d); + } + inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { + prepare(i, j); + const uint16_t * aux16 = (const uint16_t *)x[i].scales + j; + uint32_t scales32 = (((aux16[0] | (aux16[0] << 12)) & 0x0f0f0f0f) << 1) | 0x01010101; + int8x16_t * b1 = (int8x16_t *)bits.b1.val; + int8x16_t * b2 = (int8x16_t *)bits.b2.val; + auto q1 = q8.load_quants_64(0, i, 2*j+0); + auto q2 = q8.load_quants_64(0, i, 2*j+1); + //auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q1.val[0]), b1[1], q1.val[1]); + //auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q1.val[2]), b1[3], q1.val[3]); + //auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q2.val[0]), b2[1], q2.val[1]); + //auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q2.val[2]), b2[3], q2.val[3]); + //const int8_t * s8 = (const int8_t *)&scales32; + //return vaddvq_s32(p1) * s8[0] + vaddvq_s32(p2) * s8[2] + vaddvq_s32(p3) 
* s8[1] + vaddvq_s32(p4) * s8[3]; + auto zero = vdupq_n_s32(0); + auto p1 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b1[0], q1.val[0]), b1[1], q1.val[1])); + auto p2 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b1[2], q1.val[2]), b1[3], q1.val[3])); + auto p3 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b2[0], q2.val[0]), b2[1], q2.val[1])); + auto p4 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b2[2], q2.val[2]), b2[3], q2.val[3])); + const int8_t * s8 = (const int8_t *)&scales32; + return p1 * s8[0] + p2 * s8[2] + p3 * s8[1] + p4 * s8[3]; } SimpleBits bits; - SignHelper sh; - uint32x4x2_t gas; + SignHelper1 sh; + const int16x8_t hshift = load_shift(); float d; @@ -2284,7 +2330,7 @@ struct DequantizerIQ3S final : public BaseDequantizer { template -void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +IQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; @@ -2342,7 +2388,7 @@ void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info } template -void mul_mat_qX_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +IQK_NOINLINE void mul_mat_qX_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; @@ -2758,7 +2804,7 @@ inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& } template -static void mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void IQK_NOINLINE mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { Q81 q8(info); if constexpr (nrc_y == 1) { Dequantizer deq1(vx, bx), deq2(vx, bx); @@ -2770,7 +2816,7 @@ static void mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& } template -static void mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void IQK_NOINLINE mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { Q80 q8(info); if constexpr (nrc_y == 1) { Dequantizer deq1(vx, bx), deq2(vx, bx); @@ -2782,14 +2828,14 @@ static void mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& } template -static void mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void IQK_NOINLINE mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { Dequantizer deq1(vx, bx), deq2(vx, bx); Q81<1> q8(info); mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); } template -static void mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void IQK_NOINLINE mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { Dequantizer deq1(vx, bx), deq2(vx, bx); Q80<1> q8(info); mul_mat_qX_Y_q8_Y(n, deq1, deq2, q8, info, nrc_x); @@ -2817,6 +2863,17 @@ template void MulMat::set_functions(MulMat& m) { m.funcs[6] = mul_mat_qX_1_q8_1; m.funcs[7] = mul_mat_qX_1_q8_1; } + else if constexpr (std::is_same_v || std::is_same_v) { + m.funcs[0] = mul_mat_qX_K_q8_K_1; + //m.funcs[0] = mul_mat_qX_K_q8_K_T<1, DequantizerIQ2XXS>; + m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; + m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>; + m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>; + m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>; + m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>; + 
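// [Editor's note] Indexing convention, for reference: funcs[k] is the kernel dispatched for
// k+1 columns of the right-hand matrix (mul_mat_NxM calls funcs[n_left-1]), so funcs[0] is
// the Ny = 1 token-generation slot that gets the special-cased mul_mat_qX_K_q8_K_1 above,
// while the remaining slots keep the generic tiled template used for prompt processing.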
m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>; + m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>; + } else { m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>; m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; @@ -2852,15 +2909,7 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int / MulMat::set_functions(m); break; case GGML_TYPE_IQ2_XXS: - //MulMat::set_functions(m); - m.funcs[0] = mul_mat_qX_K_q8_K_1; - m.funcs[1] = mul_mat_qX_K_q8_K_T<2, DequantizerIQ2XXS>; - m.funcs[2] = mul_mat_qX_K_q8_K_T<3, DequantizerIQ2XXS>; - m.funcs[3] = mul_mat_qX_K_q8_K_T<4, DequantizerIQ2XXS>; - m.funcs[4] = mul_mat_qX_K_q8_K_T<5, DequantizerIQ2XXS>; - m.funcs[5] = mul_mat_qX_K_q8_K_T<6, DequantizerIQ2XXS>; - m.funcs[6] = mul_mat_qX_K_q8_K_T<7, DequantizerIQ2XXS>; - m.funcs[7] = mul_mat_qX_K_q8_K_T<8, DequantizerIQ2XXS>; + MulMat::set_functions(m); break; case GGML_TYPE_IQ2_XS: MulMat::set_functions(m); From 79b3ba92208b475b41b8fa4f2be90b2e1b4a5706 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 30 May 2024 19:05:31 +0200 Subject: [PATCH 04/13] Arm for i-quants: iq3_xxs PP stays the same - 3.67X mainline. TG improves slightly to 5.05 t/s from 4.74 t/s @ 4 threads. This is still 15% slower than mainline. --- llamafile/iqk_mul_mat.inc | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index c17a6ac859..6f7e17ecc7 100644 --- a/llamafile/iqk_mul_mat.inc +++ b/llamafile/iqk_mul_mat.inc @@ -2226,6 +2226,10 @@ struct DequantizerIQ3XXS final : public BaseDequantizer { b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]}); apply_signs_2(b, keven_signs, sidx); } + inline static void make2_unsigned(const uint8_t * q3, uint8x16_t * b) { + b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]}); + b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]}); + } inline void prepare(int i, int j) { const auto * q3 = x[i].qs + 32*j; const auto * signs = (const uint32_t *)(gas.val + j); @@ -2234,6 +2238,37 @@ struct DequantizerIQ3XXS final : public BaseDequantizer { make2(q3, signs[2], bits.b2.val + 0); q3 += 8; make2(q3, signs[3], bits.b2.val + 2); } + inline void prepare_v2(int i, int j, const uint32_t * signs) { + const auto * q3 = x[i].qs + 32*j; + make2_unsigned(q3, bits.b1.val + 0); q3 += 8; + make2_unsigned(q3, bits.b1.val + 2); q3 += 8; + apply_signs_2(bits.b1.val+0, keven_signs, signs[0]); + apply_signs_2(bits.b1.val+2, keven_signs, signs[1]); + make2_unsigned(q3, bits.b2.val + 0); q3 += 8; + make2_unsigned(q3, bits.b2.val + 2); + apply_signs_2(bits.b2.val+0, keven_signs, signs[2]); + apply_signs_2(bits.b2.val+2, keven_signs, signs[3]); + } + + inline void new_block(int i) { + d = 0.25f * GGML_FP16_TO_FP32(x[i].d); + } + + inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { + gas.val[0] = vld1q_u32((const uint32_t *)(x[i].qs + QK_K/4 + 16*j)); + prepare_v2(i, j, (const uint32_t *)gas.val); + gas.val[0] = vorrq_u32(vshlq_n_u32(vshrq_n_u32(gas.val[0], 28), 1), vdupq_n_u32(1)); + int8x16_t * b1 = (int8x16_t *)bits.b1.val; + int8x16_t * b2 = (int8x16_t *)bits.b2.val; + auto q1 = q8.load_quants_64(0, i, 2*j+0); + auto q2 = q8.load_quants_64(0, i, 2*j+1); + auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q1.val[0]), b1[1], q1.val[1]); + auto p2 = 
ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q1.val[2]), b1[3], q1.val[3]); + auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q2.val[0]), b2[1], q2.val[1]); + auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q2.val[2]), b2[3], q2.val[3]); + const int32_t * s32 = (const int32_t *)gas.val; + return vaddvq_s32(p1) * s32[0] + vaddvq_s32(p2) * s32[1] + vaddvq_s32(p3) * s32[2] + vaddvq_s32(p4) * s32[3]; + } SimpleBits bits; uint32x4x2_t gas; @@ -2863,7 +2898,8 @@ template void MulMat::set_functions(MulMat& m) { m.funcs[6] = mul_mat_qX_1_q8_1; m.funcs[7] = mul_mat_qX_1_q8_1; } - else if constexpr (std::is_same_v || std::is_same_v) { + else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { m.funcs[0] = mul_mat_qX_K_q8_K_1; //m.funcs[0] = mul_mat_qX_K_q8_K_T<1, DequantizerIQ2XXS>; m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; From 6960310ddc8e764d533085a39607338104c05b0e Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 31 May 2024 08:14:12 +0200 Subject: [PATCH 05/13] Arm for i-quants: iq2_s We get 3.32X mainline for PP. TG is, sadly, 0.92X @ 4 threads --- llamafile/iqk_mul_mat.inc | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index 6f7e17ecc7..800926f060 100644 --- a/llamafile/iqk_mul_mat.inc +++ b/llamafile/iqk_mul_mat.inc @@ -2170,10 +2170,12 @@ struct DequantizerIQ2S final : public BaseDequantizer { template inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + prepare_internal(i, 0, bits); + //prepare_internal(i, 1, bits1); return prepare_4bit_scales16(x[i].scales); } - static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { + static inline void make4(SignHelper1& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { uint32_t aux32[2]; const uint16_t * aux16 = (const uint16_t *)aux32; for (int k = 0; k < 2; ++k) { @@ -2182,27 +2184,31 @@ struct DequantizerIQ2S final : public BaseDequantizer { aux32[1] &= 0x03000300; b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))), vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1])))); - sh.apply_signs_1(b+2*k+0, signs16); - b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))), vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3])))); - sh.apply_signs_1(b+2*k+1, signs16); + sh.apply_signs_1x(b+2*k+0, sign_bits); sign_bits += 2; + sh.apply_signs_1x(b+2*k+1, sign_bits); sign_bits += 2; } } inline void prepare(int i, int j) { + if (j == 1) prepare_internal(i, 1, bits); + //if (j == 1) bits = bits1; + } + + inline void prepare_internal(int i, int j, SimpleBits& sb) { const auto * qs = x[i].qs + 16*j; const auto * qh = x[i].qh + 4*j; - const auto signs16 = vld1q_u8(qs + QK_K/8); + const auto * sign_bits = qs + QK_K/8; - sh.init(); - make4(sh, signs16, qs+0, qh+0, bits.b1.val); - make4(sh, signs16, qs+8, qh+2, bits.b2.val); + make4(sh, sign_bits+0, qs+0, qh+0, sb.b1.val); + make4(sh, sign_bits+8, qs+8, qh+2, sb.b2.val); } SimpleBits bits; - SignHelper sh; + //SimpleBits bits1; + SignHelper1 sh; float d; @@ -2899,7 +2905,7 @@ template void MulMat::set_functions(MulMat& m) { m.funcs[7] = mul_mat_qX_1_q8_1; } else if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v) {// || std::is_same_v) { m.funcs[0] = 
mul_mat_qX_K_q8_K_1; //m.funcs[0] = mul_mat_qX_K_q8_K_T<1, DequantizerIQ2XXS>; m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; From 3fc3cd012b79cc4dc2341571432871d16bfb09de Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 31 May 2024 08:31:40 +0200 Subject: [PATCH 06/13] Arm for i-quants: iq2_xs We get 2.87X mainline for PP. TG is, sadly, 0.95X @ 4 threads --- llamafile/iqk_mul_mat.inc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index 800926f060..fec148c64f 100644 --- a/llamafile/iqk_mul_mat.inc +++ b/llamafile/iqk_mul_mat.inc @@ -2103,6 +2103,7 @@ struct DequantizerIQ2XS final : public BaseDequantizer { template inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + prepare_internal(i, 0); return prepare_4bit_scales16(x[i].scales); } @@ -2119,11 +2120,15 @@ struct DequantizerIQ2XS final : public BaseDequantizer { b[3] = make1(qs + 6); } - inline void prepare(int i, int j) { + inline void prepare_internal(int i, int j) { make4(x[i].qs + 16*j + 0, bits.b1.val); make4(x[i].qs + 16*j + 8, bits.b2.val); } + inline void prepare(int i, int j) { + if (j == 1) prepare_internal(i, 1); + } + SimpleBits bits; float d; From bf442bbf87c6ecb5f8e0e68105bf1802ba02b851 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 31 May 2024 10:31:27 +0200 Subject: [PATCH 07/13] Arm for i-quants: abandoning special-casing Ny = 1 --- llamafile/iqk_mul_mat.inc | 222 +++++++++++++++++++------------------- 1 file changed, 111 insertions(+), 111 deletions(-) diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index fec148c64f..c2cb5159f6 100644 --- a/llamafile/iqk_mul_mat.inc +++ b/llamafile/iqk_mul_mat.inc @@ -2054,28 +2054,28 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { prepare2(bits.b2.val + 2, idx, keven_signs, sidx[3]); } - inline void new_block(int i) { - d = 0.125f * GGML_FP16_TO_FP32(x[i].d); - } - inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { - uint32_t aux32[8]; - std::memcpy(aux32, x[i].qs + 16*j, 32); - auto q1 = q8.load_quants_64(0, i, 2*j+0); - auto q2 = q8.load_quants_64(0, i, 2*j+1); - int8x16_t * b1 = (int8x16_t *)bits.b1.val; - int8x16_t * b2 = (int8x16_t *)bits.b2.val; - prepare4(bits.b1.val, aux32+0); - prepare4(bits.b2.val, aux32+4); - auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q1.val[0]), b1[1], q1.val[1]); - auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q1.val[2]), b1[3], q1.val[3]); - auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q2.val[0]), b2[1], q2.val[1]); - auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q2.val[2]), b2[3], q2.val[3]); - int32_t sumi = vaddvq_s32(p1) * (2*(aux32[1] >> 28) + 1) - + vaddvq_s32(p2) * (2*(aux32[3] >> 28) + 1) - + vaddvq_s32(p3) * (2*(aux32[5] >> 28) + 1) - + vaddvq_s32(p4) * (2*(aux32[7] >> 28) + 1); - return sumi; - } + //inline void new_block(int i) { + // d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + //} + //inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { + // uint32_t aux32[8]; + // std::memcpy(aux32, x[i].qs + 16*j, 32); + // auto q1 = q8.load_quants_64(0, i, 2*j+0); + // auto q2 = q8.load_quants_64(0, i, 2*j+1); + // int8x16_t * b1 = (int8x16_t *)bits.b1.val; + // int8x16_t * b2 = (int8x16_t *)bits.b2.val; + // prepare4(bits.b1.val, aux32+0); + // prepare4(bits.b2.val, aux32+4); + // auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], 
q1.val[0]), b1[1], q1.val[1]); + // auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q1.val[2]), b1[3], q1.val[3]); + // auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q2.val[0]), b2[1], q2.val[1]); + // auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q2.val[2]), b2[3], q2.val[3]); + // int32_t sumi = vaddvq_s32(p1) * (2*(aux32[1] >> 28) + 1) + // + vaddvq_s32(p2) * (2*(aux32[3] >> 28) + 1) + // + vaddvq_s32(p3) * (2*(aux32[5] >> 28) + 1) + // + vaddvq_s32(p4) * (2*(aux32[7] >> 28) + 1); + // return sumi; + //} uint32x4x4_t data; SimpleBits bits; @@ -2249,37 +2249,37 @@ struct DequantizerIQ3XXS final : public BaseDequantizer { make2(q3, signs[2], bits.b2.val + 0); q3 += 8; make2(q3, signs[3], bits.b2.val + 2); } - inline void prepare_v2(int i, int j, const uint32_t * signs) { - const auto * q3 = x[i].qs + 32*j; - make2_unsigned(q3, bits.b1.val + 0); q3 += 8; - make2_unsigned(q3, bits.b1.val + 2); q3 += 8; - apply_signs_2(bits.b1.val+0, keven_signs, signs[0]); - apply_signs_2(bits.b1.val+2, keven_signs, signs[1]); - make2_unsigned(q3, bits.b2.val + 0); q3 += 8; - make2_unsigned(q3, bits.b2.val + 2); - apply_signs_2(bits.b2.val+0, keven_signs, signs[2]); - apply_signs_2(bits.b2.val+2, keven_signs, signs[3]); - } + //inline void prepare_v2(int i, int j, const uint32_t * signs) { + // const auto * q3 = x[i].qs + 32*j; + // make2_unsigned(q3, bits.b1.val + 0); q3 += 8; + // make2_unsigned(q3, bits.b1.val + 2); q3 += 8; + // apply_signs_2(bits.b1.val+0, keven_signs, signs[0]); + // apply_signs_2(bits.b1.val+2, keven_signs, signs[1]); + // make2_unsigned(q3, bits.b2.val + 0); q3 += 8; + // make2_unsigned(q3, bits.b2.val + 2); + // apply_signs_2(bits.b2.val+0, keven_signs, signs[2]); + // apply_signs_2(bits.b2.val+2, keven_signs, signs[3]); + //} - inline void new_block(int i) { - d = 0.25f * GGML_FP16_TO_FP32(x[i].d); - } + //inline void new_block(int i) { + // d = 0.25f * GGML_FP16_TO_FP32(x[i].d); + //} - inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { - gas.val[0] = vld1q_u32((const uint32_t *)(x[i].qs + QK_K/4 + 16*j)); - prepare_v2(i, j, (const uint32_t *)gas.val); - gas.val[0] = vorrq_u32(vshlq_n_u32(vshrq_n_u32(gas.val[0], 28), 1), vdupq_n_u32(1)); - int8x16_t * b1 = (int8x16_t *)bits.b1.val; - int8x16_t * b2 = (int8x16_t *)bits.b2.val; - auto q1 = q8.load_quants_64(0, i, 2*j+0); - auto q2 = q8.load_quants_64(0, i, 2*j+1); - auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q1.val[0]), b1[1], q1.val[1]); - auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q1.val[2]), b1[3], q1.val[3]); - auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q2.val[0]), b2[1], q2.val[1]); - auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q2.val[2]), b2[3], q2.val[3]); - const int32_t * s32 = (const int32_t *)gas.val; - return vaddvq_s32(p1) * s32[0] + vaddvq_s32(p2) * s32[1] + vaddvq_s32(p3) * s32[2] + vaddvq_s32(p4) * s32[3]; - } + //inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { + // gas.val[0] = vld1q_u32((const uint32_t *)(x[i].qs + QK_K/4 + 16*j)); + // prepare_v2(i, j, (const uint32_t *)gas.val); + // gas.val[0] = vorrq_u32(vshlq_n_u32(vshrq_n_u32(gas.val[0], 28), 1), vdupq_n_u32(1)); + // int8x16_t * b1 = (int8x16_t *)bits.b1.val; + // int8x16_t * b2 = (int8x16_t *)bits.b2.val; + // auto q1 = q8.load_quants_64(0, i, 2*j+0); + // auto q2 = q8.load_quants_64(0, i, 2*j+1); + // auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q1.val[0]), b1[1], 
q1.val[1]); + // auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q1.val[2]), b1[3], q1.val[3]); + // auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q2.val[0]), b2[1], q2.val[1]); + // auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q2.val[2]), b2[3], q2.val[3]); + // const int32_t * s32 = (const int32_t *)gas.val; + // return vaddvq_s32(p1) * s32[0] + vaddvq_s32(p2) * s32[1] + vaddvq_s32(p3) * s32[2] + vaddvq_s32(p4) * s32[3]; + //} SimpleBits bits; uint32x4x2_t gas; @@ -2340,31 +2340,31 @@ struct DequantizerIQ3S final : public BaseDequantizer { make4(sh, x[i].signs + 16*j + 8, qs+16, qh+2, hshift, bits.b2.val); } - inline void new_block(int i) { - d = GGML_FP16_TO_FP32(x[i].d); - } - inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { - prepare(i, j); - const uint16_t * aux16 = (const uint16_t *)x[i].scales + j; - uint32_t scales32 = (((aux16[0] | (aux16[0] << 12)) & 0x0f0f0f0f) << 1) | 0x01010101; - int8x16_t * b1 = (int8x16_t *)bits.b1.val; - int8x16_t * b2 = (int8x16_t *)bits.b2.val; - auto q1 = q8.load_quants_64(0, i, 2*j+0); - auto q2 = q8.load_quants_64(0, i, 2*j+1); - //auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q1.val[0]), b1[1], q1.val[1]); - //auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q1.val[2]), b1[3], q1.val[3]); - //auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q2.val[0]), b2[1], q2.val[1]); - //auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q2.val[2]), b2[3], q2.val[3]); - //const int8_t * s8 = (const int8_t *)&scales32; - //return vaddvq_s32(p1) * s8[0] + vaddvq_s32(p2) * s8[2] + vaddvq_s32(p3) * s8[1] + vaddvq_s32(p4) * s8[3]; - auto zero = vdupq_n_s32(0); - auto p1 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b1[0], q1.val[0]), b1[1], q1.val[1])); - auto p2 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b1[2], q1.val[2]), b1[3], q1.val[3])); - auto p3 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b2[0], q2.val[0]), b2[1], q2.val[1])); - auto p4 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b2[2], q2.val[2]), b2[3], q2.val[3])); - const int8_t * s8 = (const int8_t *)&scales32; - return p1 * s8[0] + p2 * s8[2] + p3 * s8[1] + p4 * s8[3]; - } + //inline void new_block(int i) { + // d = GGML_FP16_TO_FP32(x[i].d); + //} + //inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { + // prepare(i, j); + // const uint16_t * aux16 = (const uint16_t *)x[i].scales + j; + // uint32_t scales32 = (((aux16[0] | (aux16[0] << 12)) & 0x0f0f0f0f) << 1) | 0x01010101; + // int8x16_t * b1 = (int8x16_t *)bits.b1.val; + // int8x16_t * b2 = (int8x16_t *)bits.b2.val; + // auto q1 = q8.load_quants_64(0, i, 2*j+0); + // auto q2 = q8.load_quants_64(0, i, 2*j+1); + // //auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q1.val[0]), b1[1], q1.val[1]); + // //auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q1.val[2]), b1[3], q1.val[3]); + // //auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q2.val[0]), b2[1], q2.val[1]); + // //auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q2.val[2]), b2[3], q2.val[3]); + // //const int8_t * s8 = (const int8_t *)&scales32; + // //return vaddvq_s32(p1) * s8[0] + vaddvq_s32(p2) * s8[2] + vaddvq_s32(p3) * s8[1] + vaddvq_s32(p4) * s8[3]; + // auto zero = vdupq_n_s32(0); + // auto p1 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b1[0], q1.val[0]), b1[1], q1.val[1])); + // auto p2 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b1[2], q1.val[2]), 
b1[3], q1.val[3])); + // auto p3 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b2[0], q2.val[0]), b2[1], q2.val[1])); + // auto p4 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b2[2], q2.val[2]), b2[3], q2.val[3])); + // const int8_t * s8 = (const int8_t *)&scales32; + // return p1 * s8[0] + p2 * s8[2] + p3 * s8[1] + p4 * s8[3]; + //} SimpleBits bits; SignHelper1 sh; @@ -2433,29 +2433,29 @@ IQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const D } } -template -IQK_NOINLINE void mul_mat_qX_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n % QK_K == 0); - const int nb = n / QK_K; - - Q8<1, block_q8_K> q8(info); - - Dequantizer deq(vx, bx, 1); - - for (int ix = 0; ix < nrc_x; ++ix) { - - deq.new_row(ix); - - float32_t acc = 0; - for (int i = 0; i < nb; ++i) { - deq.new_block(i); - auto sumi = deq.process_block(i, 0, q8); - sumi += deq.process_block(i, 1, q8); - acc += deq.d*q8.scale(0, i)*sumi; - } - info.store(ix, 0, acc); - } -} +//template +//IQK_NOINLINE void mul_mat_qX_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +// assert(n % QK_K == 0); +// const int nb = n / QK_K; +// +// Q8<1, block_q8_K> q8(info); +// +// Dequantizer deq(vx, bx, 1); +// +// for (int ix = 0; ix < nrc_x; ++ix) { +// +// deq.new_row(ix); +// +// float32_t acc = 0; +// for (int i = 0; i < nb; ++i) { +// deq.new_block(i); +// auto sumi = deq.process_block(i, 0, q8); +// sumi += deq.process_block(i, 1, q8); +// acc += deq.d*q8.scale(0, i)*sumi; +// } +// info.store(ix, 0, acc); +// } +//} // =========================================== Legacy quants @@ -2909,18 +2909,18 @@ template void MulMat::set_functions(MulMat& m) { m.funcs[6] = mul_mat_qX_1_q8_1; m.funcs[7] = mul_mat_qX_1_q8_1; } - else if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) {// || std::is_same_v) { - m.funcs[0] = mul_mat_qX_K_q8_K_1; - //m.funcs[0] = mul_mat_qX_K_q8_K_T<1, DequantizerIQ2XXS>; - m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; - m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>; - m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>; - m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>; - m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>; - m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>; - m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>; - } + //else if constexpr (std::is_same_v || std::is_same_v || + // std::is_same_v) {// || std::is_same_v) { + // m.funcs[0] = mul_mat_qX_K_q8_K_1; + // //m.funcs[0] = mul_mat_qX_K_q8_K_T<1, DequantizerIQ2XXS>; + // m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; + // m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>; + // m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>; + // m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>; + // m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>; + // m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>; + // m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>; + //} else { m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>; m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; From 6f60746898aba857d843c03880836840fa60071b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 31 May 2024 11:23:26 +0200 Subject: [PATCH 08/13] Arm for i-quants: cleanup and disable iqk_mul_mat for Ny = 1 --- llamafile/iqk_mul_mat.inc | 165 +++----------------------------------- 1 file changed, 9 insertions(+), 156 deletions(-) diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index c2cb5159f6..d9af2a73da 100644 --- a/llamafile/iqk_mul_mat.inc +++ 
b/llamafile/iqk_mul_mat.inc @@ -2034,16 +2034,6 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); apply_signs_2(b, signs, sidx); } - static inline void prepare4(uint8x16_t * b, const uint32_t * aux32) { - const uint8_t * idx = (const uint8_t *)aux32; - b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); - b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); - idx += 8; - b[2] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); - b[3] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); - apply_signs_2(b+0, keven_signs, aux32[1]); - apply_signs_2(b+2, keven_signs, aux32[3]); - } inline void prepare(int /*i*/, int j) { const uint8_t * idx = (const uint8_t *)(data.val + 2*j); @@ -2054,29 +2044,6 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { prepare2(bits.b2.val + 2, idx, keven_signs, sidx[3]); } - //inline void new_block(int i) { - // d = 0.125f * GGML_FP16_TO_FP32(x[i].d); - //} - //inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { - // uint32_t aux32[8]; - // std::memcpy(aux32, x[i].qs + 16*j, 32); - // auto q1 = q8.load_quants_64(0, i, 2*j+0); - // auto q2 = q8.load_quants_64(0, i, 2*j+1); - // int8x16_t * b1 = (int8x16_t *)bits.b1.val; - // int8x16_t * b2 = (int8x16_t *)bits.b2.val; - // prepare4(bits.b1.val, aux32+0); - // prepare4(bits.b2.val, aux32+4); - // auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q1.val[0]), b1[1], q1.val[1]); - // auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q1.val[2]), b1[3], q1.val[3]); - // auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q2.val[0]), b2[1], q2.val[1]); - // auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q2.val[2]), b2[3], q2.val[3]); - // int32_t sumi = vaddvq_s32(p1) * (2*(aux32[1] >> 28) + 1) - // + vaddvq_s32(p2) * (2*(aux32[3] >> 28) + 1) - // + vaddvq_s32(p3) * (2*(aux32[5] >> 28) + 1) - // + vaddvq_s32(p4) * (2*(aux32[7] >> 28) + 1); - // return sumi; - //} - uint32x4x4_t data; SimpleBits bits; @@ -2137,22 +2104,6 @@ struct DequantizerIQ2XS final : public BaseDequantizer { struct SignHelper { - inline void init() { shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); } - - inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16) { - auto aux = vqtbl1q_u8(signs16, shuffle); - auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1)); - b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s)); - shuffle = vaddq_u8(shuffle, step); - } - - const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); - const uint8x16_t m1 = vdupq_n_u8(1); - const uint8x16_t step = vdupq_n_u8(2); - uint8x16_t shuffle; -}; -struct SignHelper1 { - inline void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const { auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1])); //auto s = vceqq_u8(vandq_u8(aux, smask), smask); @@ -2176,11 +2127,10 @@ struct DequantizerIQ2S final : public BaseDequantizer { inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); prepare_internal(i, 0, bits); - //prepare_internal(i, 1, bits1); return prepare_4bit_scales16(x[i].scales); } - static inline void make4(SignHelper1& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { + static inline void 
make4(SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { uint32_t aux32[2]; const uint16_t * aux16 = (const uint16_t *)aux32; for (int k = 0; k < 2; ++k) { @@ -2198,7 +2148,6 @@ struct DequantizerIQ2S final : public BaseDequantizer { inline void prepare(int i, int j) { if (j == 1) prepare_internal(i, 1, bits); - //if (j == 1) bits = bits1; } inline void prepare_internal(int i, int j, SimpleBits& sb) { @@ -2212,8 +2161,7 @@ struct DequantizerIQ2S final : public BaseDequantizer { } SimpleBits bits; - //SimpleBits bits1; - SignHelper1 sh; + SignHelper sh; float d; @@ -2237,10 +2185,6 @@ struct DequantizerIQ3XXS final : public BaseDequantizer { b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]}); apply_signs_2(b, keven_signs, sidx); } - inline static void make2_unsigned(const uint8_t * q3, uint8x16_t * b) { - b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]}); - b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]}); - } inline void prepare(int i, int j) { const auto * q3 = x[i].qs + 32*j; const auto * signs = (const uint32_t *)(gas.val + j); @@ -2249,37 +2193,6 @@ struct DequantizerIQ3XXS final : public BaseDequantizer { make2(q3, signs[2], bits.b2.val + 0); q3 += 8; make2(q3, signs[3], bits.b2.val + 2); } - //inline void prepare_v2(int i, int j, const uint32_t * signs) { - // const auto * q3 = x[i].qs + 32*j; - // make2_unsigned(q3, bits.b1.val + 0); q3 += 8; - // make2_unsigned(q3, bits.b1.val + 2); q3 += 8; - // apply_signs_2(bits.b1.val+0, keven_signs, signs[0]); - // apply_signs_2(bits.b1.val+2, keven_signs, signs[1]); - // make2_unsigned(q3, bits.b2.val + 0); q3 += 8; - // make2_unsigned(q3, bits.b2.val + 2); - // apply_signs_2(bits.b2.val+0, keven_signs, signs[2]); - // apply_signs_2(bits.b2.val+2, keven_signs, signs[3]); - //} - - //inline void new_block(int i) { - // d = 0.25f * GGML_FP16_TO_FP32(x[i].d); - //} - - //inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { - // gas.val[0] = vld1q_u32((const uint32_t *)(x[i].qs + QK_K/4 + 16*j)); - // prepare_v2(i, j, (const uint32_t *)gas.val); - // gas.val[0] = vorrq_u32(vshlq_n_u32(vshrq_n_u32(gas.val[0], 28), 1), vdupq_n_u32(1)); - // int8x16_t * b1 = (int8x16_t *)bits.b1.val; - // int8x16_t * b2 = (int8x16_t *)bits.b2.val; - // auto q1 = q8.load_quants_64(0, i, 2*j+0); - // auto q2 = q8.load_quants_64(0, i, 2*j+1); - // auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q1.val[0]), b1[1], q1.val[1]); - // auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q1.val[2]), b1[3], q1.val[3]); - // auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q2.val[0]), b2[1], q2.val[1]); - // auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q2.val[2]), b2[3], q2.val[3]); - // const int32_t * s32 = (const int32_t *)gas.val; - // return vaddvq_s32(p1) * s32[0] + vaddvq_s32(p2) * s32[1] + vaddvq_s32(p3) * s32[2] + vaddvq_s32(p4) * s32[3]; - //} SimpleBits bits; uint32x4x2_t gas; @@ -2310,7 +2223,7 @@ struct DequantizerIQ3S final : public BaseDequantizer { return scales; } - static inline void make2(SignHelper1& sh, const uint8_t * sign_bits, const uint16x8_t& idx_l, uint8_t qh, + static inline void make2(SignHelper& sh, const uint8_t * sign_bits, const uint16x8_t& idx_l, uint8_t qh, const int16x8_t& hshift, uint8x16_t * b) { auto vindex 
= vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256))); const uint16_t * idx = (const uint16_t *)&vindex; @@ -2319,7 +2232,7 @@ struct DequantizerIQ3S final : public BaseDequantizer { b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]}); sh.apply_signs_1x(b+1, sign_bits+2); } - static inline void make4(SignHelper1& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, + static inline void make4(SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, const int16x8_t& hshift, uint8x16_t * b) { auto idx_l = vld1q_u8(qs); make2(sh, sign_bits+0, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0); @@ -2340,41 +2253,14 @@ struct DequantizerIQ3S final : public BaseDequantizer { make4(sh, x[i].signs + 16*j + 8, qs+16, qh+2, hshift, bits.b2.val); } - //inline void new_block(int i) { - // d = GGML_FP16_TO_FP32(x[i].d); - //} - //inline int32_t process_block(int i, int j, const Q8<1, block_q8_K>& q8) { - // prepare(i, j); - // const uint16_t * aux16 = (const uint16_t *)x[i].scales + j; - // uint32_t scales32 = (((aux16[0] | (aux16[0] << 12)) & 0x0f0f0f0f) << 1) | 0x01010101; - // int8x16_t * b1 = (int8x16_t *)bits.b1.val; - // int8x16_t * b2 = (int8x16_t *)bits.b2.val; - // auto q1 = q8.load_quants_64(0, i, 2*j+0); - // auto q2 = q8.load_quants_64(0, i, 2*j+1); - // //auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q1.val[0]), b1[1], q1.val[1]); - // //auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q1.val[2]), b1[3], q1.val[3]); - // //auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q2.val[0]), b2[1], q2.val[1]); - // //auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q2.val[2]), b2[3], q2.val[3]); - // //const int8_t * s8 = (const int8_t *)&scales32; - // //return vaddvq_s32(p1) * s8[0] + vaddvq_s32(p2) * s8[2] + vaddvq_s32(p3) * s8[1] + vaddvq_s32(p4) * s8[3]; - // auto zero = vdupq_n_s32(0); - // auto p1 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b1[0], q1.val[0]), b1[1], q1.val[1])); - // auto p2 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b1[2], q1.val[2]), b1[3], q1.val[3])); - // auto p3 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b2[0], q2.val[0]), b2[1], q2.val[1])); - // auto p4 = vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(zero, b2[2], q2.val[2]), b2[3], q2.val[3])); - // const int8_t * s8 = (const int8_t *)&scales32; - // return p1 * s8[0] + p2 * s8[2] + p3 * s8[1] + p4 * s8[3]; - //} - SimpleBits bits; - SignHelper1 sh; + SignHelper sh; const int16x8_t hshift = load_shift(); float d; }; - template IQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); @@ -2433,30 +2319,6 @@ IQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const D } } -//template -//IQK_NOINLINE void mul_mat_qX_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { -// assert(n % QK_K == 0); -// const int nb = n / QK_K; -// -// Q8<1, block_q8_K> q8(info); -// -// Dequantizer deq(vx, bx, 1); -// -// for (int ix = 0; ix < nrc_x; ++ix) { -// -// deq.new_row(ix); -// -// float32_t acc = 0; -// for (int i = 0; i < nb; ++i) { -// deq.new_block(i); -// auto sumi = deq.process_block(i, 0, q8); -// sumi += deq.process_block(i, 1, q8); -// acc += deq.d*q8.scale(0, i)*sumi; -// } -// info.store(ix, 0, acc); -// } -//} - // =========================================== Legacy quants template @@ -2909,18 
+2771,6 @@ template void MulMat::set_functions(MulMat& m) { m.funcs[6] = mul_mat_qX_1_q8_1; m.funcs[7] = mul_mat_qX_1_q8_1; } - //else if constexpr (std::is_same_v || std::is_same_v || - // std::is_same_v) {// || std::is_same_v) { - // m.funcs[0] = mul_mat_qX_K_q8_K_1; - // //m.funcs[0] = mul_mat_qX_K_q8_K_T<1, DequantizerIQ2XXS>; - // m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; - // m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>; - // m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>; - // m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>; - // m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>; - // m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>; - // m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>; - //} else { m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>; m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; @@ -2933,9 +2783,12 @@ template void MulMat::set_functions(MulMat& m) { } } -bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /*Ny*/) { +bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int Ny) { row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); + if (Ny == 1 && (typeA == GGML_TYPE_IQ2_XXS || typeA == GGML_TYPE_IQ2_XS || typeA == GGML_TYPE_IQ2_S || + typeA == GGML_TYPE_IQ3_XXS || typeA == GGML_TYPE_IQ3_S)) return false; + switch (typeA) { case GGML_TYPE_Q2_K: MulMat::set_functions(m);

From 0609aa6665f670acec8c0a4a9f9e7b2675bcad75 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 1 Jun 2024 08:03:42 +0200
Subject: [PATCH 09/13] Arm for i-quants: holding the compiler's hand

Turns out we can improve quite a bit by explicitly asking the compiler
to never inline some functions, and to always inline some others. With
that, PP performance gains are > 3X for all i-quants, reaching 4.3X for
iq3_s. TG is also always better, except for iq3_xxs, where it is 0.99X,
so re-enabled iqk_mul_mat for Ny = 1.
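For reference, a minimal standalone sketch of the inlining-control pattern the
message describes (illustrative C++ only, not part of the patch; the macro
spellings are a common portable variant, and the function and type names are
made up for the example):

// Hot per-block kernel: forced inline, so the compiler keeps its state in
// registers inside the caller's loop instead of emitting a call per block.
#ifdef _MSC_VER
#define EX_ALWAYS_INLINE __forceinline
#define EX_NOINLINE      __declspec(noinline)
#else
#define EX_ALWAYS_INLINE inline __attribute__((always_inline))
#define EX_NOINLINE      __attribute__((noinline))
#endif

static EX_ALWAYS_INLINE int dot_block(const int * x, const int * y) {
    int sum = 0;
    for (int k = 0; k < 32; ++k) sum += x[k] * y[k];
    return sum;
}

// Outer driver: never inlined, so each template instantiation stays one
// self-contained function and cannot bloat the code of its callers.
template <int nrc_y>
EX_NOINLINE void mul_rows(const int * x, const int * const * y, int nblocks, int * out) {
    for (int iy = 0; iy < nrc_y; ++iy) {
        int acc = 0;
        for (int i = 0; i < nblocks; ++i) acc += dot_block(x + 32*i, y[iy] + 32*i);
        out[iy] = acc;
    }
}

A call such as mul_rows<4>(...) then compiles to a single out-of-line function
whose inner loop contains the fully inlined kernel, which is the shape the
IQK_NOINLINE/IQK_ALWAYS_INLINE annotations in the diff below aim for.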
--- llamafile/iqk_mul_mat.inc | 80 ++++++++++++++++++++++++++------------- 1 file changed, 54 insertions(+), 26 deletions(-) diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index d9af2a73da..f155ed9022 100644 --- a/llamafile/iqk_mul_mat.inc +++ b/llamafile/iqk_mul_mat.inc @@ -25,8 +25,10 @@ #ifdef _MSC_VER #define IQK_NOINLINE __declspec(noinline) +#define IQK_ALWAYS_INLINE inline #else #define IQK_NOINLINE __attribute__((__noinline__)) +#define IQK_ALWAYS_INLINE __attribute__((always_inline)) #endif #define GGML_COMMON_IMPL_C @@ -1481,23 +1483,21 @@ template struct Q8 { }; template -inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, +IQK_ALWAYS_INLINE void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) { auto mzero = vdupq_n_s32(0); + const int8x16_t * qs_1 = (const int8x16_t *)qx_1.val; + const int8x16_t * qs_2 = (const int8x16_t *)qx_2.val; auto q8b_1 = q8.load_quants(iy, i, 4*j+0); - auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]), - vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // block 1 + auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[0], q8b_1.val[0]), qs_1[1], q8b_1.val[1]); // block 1 auto q8b_2 = q8.load_quants(iy, i, 4*j+1); - auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]), - vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // block 2 + auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[2], q8b_2.val[0]), qs_1[3], q8b_2.val[1]); // block 2 auto p12 = vpaddq_s32(p1, p2); auto q8b_3 = q8.load_quants(iy, i, 4*j+2); - auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]), - vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // block 1 + auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[0], q8b_3.val[0]), qs_2[1], q8b_3.val[1]); // block 3 auto q8b_4 = q8.load_quants(iy, i, 4*j+3); - auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]), - vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // block 2 + auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[2], q8b_4.val[0]), qs_2[3], q8b_4.val[1]); // block 4 auto p34 = vpaddq_s32(p3, p4); auto pall = vpaddq_s32(p12, p34); @@ -1505,7 +1505,7 @@ inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, } template -inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, +IQK_ALWAYS_INLINE void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) { auto mzero = vdupq_n_s32(0); @@ -1996,10 +1996,13 @@ struct SimpleBits { uint8x16x4_t b2; }; -inline int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) { +IQK_ALWAYS_INLINE int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) { int32x4x2_t scales; - scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v1, 28), 1), vdupq_n_u32(1))); - scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v2, 28), 1), vdupq_n_u32(1))); + auto one = vdupq_n_u32(1); + scales.val[0] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v1, 28), 1)); + scales.val[1] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v2, 28), 1)); + //scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v1, 28), 1), vdupq_n_u32(1))); + 
//scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v2, 28), 1), vdupq_n_u32(1))); return scales; } @@ -2020,6 +2023,9 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + //data = vld1q_u32_x4((const uint32_t *)x[i].qs); + //return prepare_scales_8(vuzp2q_u32(data.val[0], data.val[1]), vuzp2q_u32(data.val[2], data.val[3])); + auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs); data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]); // codebook indices for blocks 0...3 data.val[1] = vuzp2q_u32(tmp.val[0], tmp.val[1]); // scales and signs for blocks 0...3 @@ -2029,7 +2035,7 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { return prepare_scales_8(data.val[1], data.val[3]); } - static inline void prepare2(uint8x16_t * b, const uint8_t * idx, const uint64_t * signs, uint32_t sidx) { + inline void prepare2(uint8x16_t * b, const uint8_t * idx, const uint64_t * signs, uint32_t sidx) { b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); apply_signs_2(b, signs, sidx); @@ -2044,6 +2050,21 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { prepare2(bits.b2.val + 2, idx, keven_signs, sidx[3]); } + //static inline void prepare2(uint8x16_t * b, const uint32_t * bits, const uint64_t * signs) { + // const uint8_t * idx = (const uint8_t *)bits; + // b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); + // b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); + // apply_signs_2(b, signs, bits[1]); + //} + + //inline void prepare(int /*i*/, int j) { + // const uint32_t * q2 = (const uint32_t *)(data.val + 2*j); + // prepare2(bits.b1.val + 0, q2+0, keven_signs); + // prepare2(bits.b1.val + 2, q2+2, keven_signs); + // prepare2(bits.b2.val + 0, q2+4, keven_signs); + // prepare2(bits.b2.val + 2, q2+6, keven_signs); + //} + uint32x4x4_t data; SimpleBits bits; @@ -2074,7 +2095,7 @@ struct DequantizerIQ2XS final : public BaseDequantizer { return prepare_4bit_scales16(x[i].scales); } - inline static uint8x16_t make1(const uint16_t * qs) { + static inline uint8x16_t make1(const uint16_t * qs) { auto b = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_u8((const uint8_t *)(iq2xs_grid + (qs[1] & 511)))); auto s = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9)))); return vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b), s)); @@ -2087,7 +2108,7 @@ struct DequantizerIQ2XS final : public BaseDequantizer { b[3] = make1(qs + 6); } - inline void prepare_internal(int i, int j) { + IQK_ALWAYS_INLINE void prepare_internal(int i, int j) { make4(x[i].qs + 16*j + 0, bits.b1.val); make4(x[i].qs + 16*j + 8, bits.b2.val); } @@ -2104,7 +2125,7 @@ struct DequantizerIQ2XS final : public BaseDequantizer { struct SignHelper { - inline void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const { + IQK_ALWAYS_INLINE void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const { auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1])); //auto s = vceqq_u8(vandq_u8(aux, smask), smask); //b[0] = vreinterpretq_u8_s8(vsubq_s8(vreinterpretq_s8_u8(veorq_u8(b[0], s)), vreinterpretq_s8_u8(s))); @@ -2130,7 +2151,7 @@ struct DequantizerIQ2S final : public BaseDequantizer { return 
prepare_4bit_scales16(x[i].scales); } - static inline void make4(SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { + static inline void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { uint32_t aux32[2]; const uint16_t * aux16 = (const uint16_t *)aux32; for (int k = 0; k < 2; ++k) { @@ -2188,10 +2209,10 @@ struct DequantizerIQ3XXS final : public BaseDequantizer { inline void prepare(int i, int j) { const auto * q3 = x[i].qs + 32*j; const auto * signs = (const uint32_t *)(gas.val + j); - make2(q3, signs[0], bits.b1.val + 0); q3 += 8; - make2(q3, signs[1], bits.b1.val + 2); q3 += 8; - make2(q3, signs[2], bits.b2.val + 0); q3 += 8; - make2(q3, signs[3], bits.b2.val + 2); + make2(q3+ 0, signs[0], bits.b1.val + 0); + make2(q3+ 8, signs[1], bits.b1.val + 2); + make2(q3+16, signs[2], bits.b2.val + 0); + make2(q3+24, signs[3], bits.b2.val + 2); } SimpleBits bits; @@ -2223,7 +2244,7 @@ struct DequantizerIQ3S final : public BaseDequantizer { return scales; } - static inline void make2(SignHelper& sh, const uint8_t * sign_bits, const uint16x8_t& idx_l, uint8_t qh, + static inline void make2(const SignHelper& sh, const uint8_t * sign_bits, const uint16x8_t& idx_l, uint8_t qh, const int16x8_t& hshift, uint8x16_t * b) { auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256))); const uint16_t * idx = (const uint16_t *)&vindex; @@ -2232,7 +2253,7 @@ struct DequantizerIQ3S final : public BaseDequantizer { b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]}); sh.apply_signs_1x(b+1, sign_bits+2); } - static inline void make4(SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, + static inline void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, const int16x8_t& hshift, uint8x16_t * b) { auto idx_l = vld1q_u8(qs); make2(sh, sign_bits+0, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0); @@ -2277,6 +2298,7 @@ IQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const D float32x4_t acc[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); +//#pragma GCC unroll 4 for (int i = 0; i < nb; ++i) { int32x4_t sumi[nrc_y]; @@ -2292,15 +2314,19 @@ IQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const D if constexpr (Dequantizer::num_blocks() == 8) { auto scales = deq.new_block(i, q8, acc); deq.prepare(i, 0); +#pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); deq.prepare(i, 1); +#pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); } else if constexpr (Dequantizer::num_blocks() == 16) { auto scales = deq.new_block(i, q8, acc); deq.prepare(i, 0); +#pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); deq.prepare(i, 1); +#pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); } else { @@ -2308,11 +2334,13 @@ IQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const D } } +#pragma GCC unroll 8 for (int iy = 0; iy < nrc_y; ++iy) { acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); } } +#pragma GCC unroll 
8 for (int iy = 0; iy < nrc_y; ++iy) { info.store(ix, iy, vaddvq_f32(acc[iy])); } @@ -2786,8 +2814,8 @@ template void MulMat::set_functions(MulMat& m) { bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int Ny) { row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); - if (Ny == 1 && (typeA == GGML_TYPE_IQ2_XXS || typeA == GGML_TYPE_IQ2_XS || typeA == GGML_TYPE_IQ2_S || - typeA == GGML_TYPE_IQ3_XXS || typeA == GGML_TYPE_IQ3_S)) return false; + //if (Ny == 1 && (typeA == GGML_TYPE_IQ2_XXS || typeA == GGML_TYPE_IQ2_XS || typeA == GGML_TYPE_IQ2_S || + // typeA == GGML_TYPE_IQ3_XXS || typeA == GGML_TYPE_IQ3_S)) return false; switch (typeA) { case GGML_TYPE_Q2_K: MulMat::set_functions(m);

From eccc2c09fac517f9a79e01e4d0ebc1b14e8e2892 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sat, 1 Jun 2024 20:35:38 +0200
Subject: [PATCH 10/13] Arm for i-quants: iterating

Turns out changing one method of a quant affects the performance of
other quant(s). Is the compiler somehow trying to optimize all template
instantiations together? Anyway, with this version I have this:

| cpu_info | model_filename | size | test | t/s |
| ---------------------------: | -------------: | ---------: | ------: | ------: |
| Apple M2 Max (+fp16+dotprod) | iq2xxs | 1.73 GiB | tg128 | 9.02 |
| Apple M2 Max (+fp16+dotprod) | iq2xxs | 1.73 GiB | pp512 | 61.31 |
| Apple M2 Max (+fp16+dotprod) | iq2xs | 1.89 GiB | tg128 | 10.58 |
| Apple M2 Max (+fp16+dotprod) | iq2xs | 1.89 GiB | pp512 | 56.11 |
| Apple M2 Max (+fp16+dotprod) | iq2m | 2.20 GiB | tg128 | 7.07 |
| Apple M2 Max (+fp16+dotprod) | iq2m | 2.20 GiB | pp512 | 45.78 |
| Apple M2 Max (+fp16+dotprod) | iq3xxs | 2.41 GiB | tg128 | 6.40 |
| Apple M2 Max (+fp16+dotprod) | iq3xxs | 2.41 GiB | pp512 | 47.51 |
| Apple M2 Max (+fp16+dotprod) | iq3m | 2.90 GiB | tg128 | 5.97 |
| Apple M2 Max (+fp16+dotprod) | iq3m | 2.90 GiB | pp512 | 47.98 |

TG is with 4 threads, PP with 8.
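The diff below moves the i-quants onto their own driver (mul_mat_qX_K_q8_K_IQ)
while keeping every dequantizer behind the same small interface: num_blocks()
advertises how many scale blocks a superblock produces, and the driver picks
the matching compute path with if constexpr. A minimal sketch of that dispatch
(illustrative only; Deq8, Deq16 and process_superblock are made-up names):

struct Deq8  { constexpr static int num_blocks() { return 8;  } };
struct Deq16 { constexpr static int num_blocks() { return 16; } };

template <typename Dequantizer>
int process_superblock() {
    if constexpr (Dequantizer::num_blocks() == 8) {
        return 8;    // stand-in for the compute_8_blocks() path
    } else {
        return 16;   // stand-in for the compute_16_blocks() path
    }
}

Each instantiation, e.g. process_superblock<Deq8>(), keeps only its own branch;
the other path is discarded at compile time, so the 8-block and 16-block
kernels cannot interfere with each other inside one function.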
--- llamafile/iqk_mul_mat.inc | 369 ++++++++++++++++++++++---------------- 1 file changed, 217 insertions(+), 152 deletions(-) diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index f155ed9022..87d77478a5 100644 --- a/llamafile/iqk_mul_mat.inc +++ b/llamafile/iqk_mul_mat.inc @@ -1470,6 +1470,7 @@ template struct Q8 { for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy); } + inline int8x16_t load_quants_16(int iy, int i, int j) const { return vld1q_s8(y[iy][i].qs + 16*j); } inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); } inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); } inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); } @@ -1482,12 +1483,131 @@ template struct Q8 { const block_q8 * y[nrc_y]; }; +template +IQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8 q8(info); + + Dequantizer deq(vx, bx, nrc_y); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + float32x4_t acc[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); + +//#pragma GCC unroll 4 + for (int i = 0; i < nb; ++i) { + + int32x4_t sumi[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); + + if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) { + deq.process_scales(i, q8, acc); + deq.prepare(i, 0); + deq.compute(q8, i, 0, sumi); + deq.prepare(i, 1); + deq.compute(q8, i, 1, sumi); + } else { + if constexpr (Dequantizer::num_blocks() == 8) { + auto scales = deq.new_block(i, q8, acc); + deq.prepare(i, 0); +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); + deq.prepare(i, 1); +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); + } + else if constexpr (Dequantizer::num_blocks() == 16) { + auto scales = deq.new_block(i, q8, acc); + deq.prepare(i, 0); +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); + deq.prepare(i, 1); +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); + } + else { + GGML_ASSERT(false); + } + } + +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + } + } + +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(acc[iy])); + } + } +} +template +IQK_NOINLINE void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8 q8(info); + + Dequantizer deq(vx, bx, nrc_y); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + float32x4_t acc[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); + + for (int i = 0; i < nb; ++i) { + + int32x4_t sumi[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); + + if constexpr (Dequantizer::num_blocks() == 8) { + auto scales = deq.new_block(i); + deq.prepare(i, 0); +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); + 
deq.prepare(i, 1); +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); + } + else if constexpr (Dequantizer::num_blocks() == 16) { + auto scales = deq.new_block(i); + deq.prepare(i, 0); +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); + deq.prepare(i, 1); +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); + } + else { + GGML_ASSERT(false); + } +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); + } + } +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(acc[iy])); + } + } +} + template IQK_ALWAYS_INLINE void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) { auto mzero = vdupq_n_s32(0); const int8x16_t * qs_1 = (const int8x16_t *)qx_1.val; const int8x16_t * qs_2 = (const int8x16_t *)qx_2.val; + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[0], q8b_1.val[0]), qs_1[1], q8b_1.val[1]); // block 1 auto q8b_2 = q8.load_quants(iy, i, 4*j+1); @@ -2019,56 +2139,37 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } - template - inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { - d = 0.125f * GGML_FP16_TO_FP32(x[i].d); - - //data = vld1q_u32_x4((const uint32_t *)x[i].qs); - //return prepare_scales_8(vuzp2q_u32(data.val[0], data.val[1]), vuzp2q_u32(data.val[2], data.val[3])); - - auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs); - data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]); // codebook indices for blocks 0...3 - data.val[1] = vuzp2q_u32(tmp.val[0], tmp.val[1]); // scales and signs for blocks 0...3 - data.val[2] = vuzp1q_u32(tmp.val[2], tmp.val[3]); // codebook indices for blocks 4...7 - data.val[3] = vuzp2q_u32(tmp.val[2], tmp.val[3]); // scales and signs for blocks 4...7 + SimpleBits bits; + float d; - return prepare_scales_8(data.val[1], data.val[3]); + IQK_ALWAYS_INLINE int32x4x2_t new_block(int i) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + data = vld1q_u32_x4((const uint32_t *)x[i].qs); + prepare_block(0); + return prepare_scales_8(vuzp2q_u32(data.val[0], data.val[1]), vuzp2q_u32(data.val[2], data.val[3])); } + inline void prepare(int /*i*/, int j) { + if (j == 1) prepare_block(1); + } + +private: - inline void prepare2(uint8x16_t * b, const uint8_t * idx, const uint64_t * signs, uint32_t sidx) { + static inline void prepare2(uint8x16_t * b, const uint32_t * bits, const uint64_t * signs) { + const uint8_t * idx = (const uint8_t *)bits; b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); - apply_signs_2(b, signs, sidx); + apply_signs_2(b, signs, bits[1]); } - inline void prepare(int /*i*/, int j) { - const uint8_t * idx = (const uint8_t *)(data.val + 2*j); - const uint32_t * sidx = (const uint32_t *)(data.val + 2*j+1); - prepare2(bits.b1.val + 0, idx, keven_signs, sidx[0]); idx += 4; - prepare2(bits.b1.val + 2, idx, keven_signs, sidx[1]); idx += 4; - prepare2(bits.b2.val + 0, idx, keven_signs, 
sidx[2]); idx += 4; - prepare2(bits.b2.val + 2, idx, keven_signs, sidx[3]); - } - - //static inline void prepare2(uint8x16_t * b, const uint32_t * bits, const uint64_t * signs) { - // const uint8_t * idx = (const uint8_t *)bits; - // b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]}); - // b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]}); - // apply_signs_2(b, signs, bits[1]); - //} - - //inline void prepare(int /*i*/, int j) { - // const uint32_t * q2 = (const uint32_t *)(data.val + 2*j); - // prepare2(bits.b1.val + 0, q2+0, keven_signs); - // prepare2(bits.b1.val + 2, q2+2, keven_signs); - // prepare2(bits.b2.val + 0, q2+4, keven_signs); - // prepare2(bits.b2.val + 2, q2+6, keven_signs); - //} + inline void prepare_block(int j) { + const uint32_t * q2 = (const uint32_t *)(data.val + 2*j); + prepare2(bits.b1.val + 0, q2+0, keven_signs); + prepare2(bits.b1.val + 2, q2+2, keven_signs); + prepare2(bits.b2.val + 0, q2+4, keven_signs); + prepare2(bits.b2.val + 2, q2+6, keven_signs); + } uint32x4x4_t data; - SimpleBits bits; - - float d; }; inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) { @@ -2088,13 +2189,21 @@ struct DequantizerIQ2XS final : public BaseDequantizer { constexpr static int num_blocks() { return 16; } constexpr static bool should_scale_quants() { return false; } - template - inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + SimpleBits bits; + float d; + + inline int32x4x4_t new_block(int i) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); prepare_internal(i, 0); return prepare_4bit_scales16(x[i].scales); } + inline void prepare(int i, int j) { + if (j == 1) prepare_internal(i, 1); + } + +private: + static inline uint8x16_t make1(const uint16_t * qs) { auto b = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_u8((const uint8_t *)(iq2xs_grid + (qs[1] & 511)))); auto s = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9)))); @@ -2113,19 +2222,11 @@ struct DequantizerIQ2XS final : public BaseDequantizer { make4(x[i].qs + 16*j + 8, bits.b2.val); } - inline void prepare(int i, int j) { - if (j == 1) prepare_internal(i, 1); - } - - SimpleBits bits; - - float d; - }; struct SignHelper { - IQK_ALWAYS_INLINE void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const { + inline void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const { auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1])); //auto s = vceqq_u8(vandq_u8(aux, smask), smask); //b[0] = vreinterpretq_u8_s8(vsubq_s8(vreinterpretq_s8_u8(veorq_u8(b[0], s)), vreinterpretq_s8_u8(s))); @@ -2144,13 +2245,21 @@ struct DequantizerIQ2S final : public BaseDequantizer { constexpr static int num_blocks() { return 16; } constexpr static bool should_scale_quants() { return false; } - template - inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + SimpleBits bits; + float d; + + inline int32x4x4_t new_block(int i) { d = 0.125f * GGML_FP16_TO_FP32(x[i].d); prepare_internal(i, 0, bits); return prepare_4bit_scales16(x[i].scales); } + inline void prepare(int i, int j) { + if (j == 1) prepare_internal(i, 1, bits); + } + +private: + static inline void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { uint32_t aux32[2]; const uint16_t * aux16 = (const uint16_t *)aux32; @@ -2167,10 +2276,6 @@ struct DequantizerIQ2S final : public 
BaseDequantizer { } } - inline void prepare(int i, int j) { - if (j == 1) prepare_internal(i, 1, bits); - } - inline void prepare_internal(int i, int j, SimpleBits& sb) { const auto * qs = x[i].qs + 16*j; @@ -2181,11 +2286,7 @@ struct DequantizerIQ2S final : public BaseDequantizer { make4(sh, sign_bits+8, qs+8, qh+2, sb.b2.val); } - SimpleBits bits; SignHelper sh; - - float d; - }; struct DequantizerIQ3XXS final : public BaseDequantizer { @@ -2194,32 +2295,39 @@ struct DequantizerIQ3XXS final : public BaseDequantizer { constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } - template - inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + SimpleBits bits; + float d; + + inline int32x4x2_t new_block(int i) { d = 0.25f * GGML_FP16_TO_FP32(x[i].d); + auto q3data = vld1q_u8_x2(x[i].qs); gas = vld1q_u32_x2((const uint32_t *)(x[i].qs + QK_K/4)); + prepare_block((const uint8_t *)q3data.val, (const uint32_t *)gas.val, bits.b1.val, bits.b2.val); return prepare_scales_8(gas.val[0], gas.val[1]); } + inline void prepare(int i, int j) { + if (j == 1) { + auto q3data = vld1q_u8_x2(x[i].qs + 32); + prepare_block((const uint8_t *)q3data.val, (const uint32_t *)(gas.val + 1), bits.b1.val, bits.b2.val); + } + } + +private: + inline static void make2(const uint8_t * q3, uint32_t sidx, uint8x16_t * b) { b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]}); b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]}); apply_signs_2(b, keven_signs, sidx); } - inline void prepare(int i, int j) { - const auto * q3 = x[i].qs + 32*j; - const auto * signs = (const uint32_t *)(gas.val + j); - make2(q3+ 0, signs[0], bits.b1.val + 0); - make2(q3+ 8, signs[1], bits.b1.val + 2); - make2(q3+16, signs[2], bits.b2.val + 0); - make2(q3+24, signs[3], bits.b2.val + 2); + inline static void prepare_block(const uint8_t * q3, const uint32_t * signs, uint8x16_t * b1, uint8x16_t * b2) { + make2(q3+ 0, signs[0], b1 + 0); + make2(q3+ 8, signs[1], b1 + 2); + make2(q3+16, signs[2], b2 + 0); + make2(q3+24, signs[3], b2 + 2); } - SimpleBits bits; uint32x4x2_t gas; - - float d; - }; struct DequantizerIQ3S final : public BaseDequantizer { @@ -2228,10 +2336,17 @@ struct DequantizerIQ3S final : public BaseDequantizer { constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } - template - inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + SimpleBits bits; + float d; + + inline int32x4x2_t new_block(int i) { d = GGML_FP16_TO_FP32(x[i].d); uint32_t scales32[2]; + auto qs = vld1q_u8_x2(x[i].qs); + auto signs = vld1q_u8(x[i].signs); + + prepare_block((const uint8_t *)qs.val, x[i].qh, (const uint8_t *)&signs); + std::memcpy(scales32, x[i].scales, 4); scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; @@ -2244,6 +2359,16 @@ struct DequantizerIQ3S final : public BaseDequantizer { return scales; } + inline void prepare(int i, int j) { + if (j == 1) { + auto qs = vld1q_u8_x2(x[i].qs + 32); + auto signs = vld1q_u8(x[i].signs + 16); + prepare_block((const uint8_t *)qs.val, x[i].qh + 4, (const uint8_t *)&signs); + } + } + +private: + static inline void make2(const SignHelper& sh, const uint8_t * sign_bits, const uint16x8_t& idx_l, uint8_t qh, const int16x8_t& hshift, uint8x16_t * b) { auto 
vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256))); @@ -2265,88 +2390,16 @@ struct DequantizerIQ3S final : public BaseDequantizer { return vld1q_s16(k_shift); } - inline void prepare(int i, int j) { - - const auto * qs = x[i].qs + 32*j; - const auto * qh = x[i].qh + 4*j; - - make4(sh, x[i].signs + 16*j + 0, qs+ 0, qh+0, hshift, bits.b1.val); - make4(sh, x[i].signs + 16*j + 8, qs+16, qh+2, hshift, bits.b2.val); + inline void prepare_block(const uint8_t * qs, const uint8_t * qh, const uint8_t * signs) { + make4(sh, signs + 0, qs+ 0, qh+0, hshift, bits.b1.val); + make4(sh, signs + 8, qs+16, qh+2, hshift, bits.b2.val); } - SimpleBits bits; SignHelper sh; const int16x8_t hshift = load_shift(); - float d; - }; -template -IQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n % QK_K == 0); - const int nb = n / QK_K; - - Q8 q8(info); - - Dequantizer deq(vx, bx, nrc_y); - - for (int ix = 0; ix < nrc_x; ++ix) { - - deq.new_row(ix); - - float32x4_t acc[nrc_y]; - for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); - -//#pragma GCC unroll 4 - for (int i = 0; i < nb; ++i) { - - int32x4_t sumi[nrc_y]; - for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); - - if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) { - deq.process_scales(i, q8, acc); - deq.prepare(i, 0); - deq.compute(q8, i, 0, sumi); - deq.prepare(i, 1); - deq.compute(q8, i, 1, sumi); - } else { - if constexpr (Dequantizer::num_blocks() == 8) { - auto scales = deq.new_block(i, q8, acc); - deq.prepare(i, 0); -#pragma GCC unroll 8 - for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); - deq.prepare(i, 1); -#pragma GCC unroll 8 - for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); - } - else if constexpr (Dequantizer::num_blocks() == 16) { - auto scales = deq.new_block(i, q8, acc); - deq.prepare(i, 0); -#pragma GCC unroll 8 - for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); - deq.prepare(i, 1); -#pragma GCC unroll 8 - for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); - } - else { - GGML_ASSERT(false); - } - } - -#pragma GCC unroll 8 - for (int iy = 0; iy < nrc_y; ++iy) { - acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i))); - } - } - -#pragma GCC unroll 8 - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, vaddvq_f32(acc[iy])); - } - } -} - // =========================================== Legacy quants template @@ -2799,6 +2852,18 @@ template void MulMat::set_functions(MulMat& m) { m.funcs[6] = mul_mat_qX_1_q8_1; m.funcs[7] = mul_mat_qX_1_q8_1; } + else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + m.funcs[0] = mul_mat_qX_K_q8_K_IQ<1, Dequantizer>; + m.funcs[1] = mul_mat_qX_K_q8_K_IQ<2, Dequantizer>; + m.funcs[2] = mul_mat_qX_K_q8_K_IQ<3, Dequantizer>; + m.funcs[3] = mul_mat_qX_K_q8_K_IQ<4, Dequantizer>; + m.funcs[4] = mul_mat_qX_K_q8_K_IQ<5, Dequantizer>; + m.funcs[5] = mul_mat_qX_K_q8_K_IQ<6, Dequantizer>; + m.funcs[6] = mul_mat_qX_K_q8_K_IQ<7, Dequantizer>; + m.funcs[7] = mul_mat_qX_K_q8_K_IQ<8, Dequantizer>; + } else { m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>; m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; From 91393f6d5a12b1d3f58a901ddd882f507058d775 
Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 3 Jun 2024 10:29:45 +0200
Subject: [PATCH 11/13] Arm for i-quants: iterating

With this version we get

| cpu_info | model_filename | size | test | t/s |
| ---------------------------: | -------------: | ---------: | -----: | ------: |
| Apple M2 Max (+fp16+dotprod) | iq2xxs | 1.73 GiB | tg128 | 10.83 |
| Apple M2 Max (+fp16+dotprod) | iq2xxs | 1.73 GiB | pp512 | 60.82 |
| Apple M2 Max (+fp16+dotprod) | iq2xs | 1.89 GiB | tg128 | 10.79 |
| Apple M2 Max (+fp16+dotprod) | iq2xs | 1.89 GiB | pp512 | 57.10 |
| Apple M2 Max (+fp16+dotprod) | iq2m | 2.20 GiB | tg128 | 7.45 |
| Apple M2 Max (+fp16+dotprod) | iq2m | 2.20 GiB | pp512 | 46.39 |
| Apple M2 Max (+fp16+dotprod) | iq3xxs | 2.41 GiB | tg128 | 6.77 |
| Apple M2 Max (+fp16+dotprod) | iq3xxs | 2.41 GiB | pp512 | 48.74 |
| Apple M2 Max (+fp16+dotprod) | iq3m | 2.90 GiB | tg128 | 5.97 |
| Apple M2 Max (+fp16+dotprod) | iq3m | 2.90 GiB | pp512 | 48.59 |

--- llamafile/iqk_mul_mat.inc | 293 ++++++++++++++++++++++---------- 1 file changed, 215 insertions(+), 78 deletions(-) diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index 87d77478a5..5720dedb2a 100644 --- a/llamafile/iqk_mul_mat.inc +++ b/llamafile/iqk_mul_mat.inc @@ -1623,6 +1623,26 @@ IQK_ALWAYS_INLINE void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16 auto pall = vpaddq_s32(p12, p34); sumi = vmlaq_s32(sumi, scales.val[j], pall); } +template +IQK_ALWAYS_INLINE void compute_8_blocks(const int8x16_t * qx, const Q8& q8, + const int32x4_t& scales, int iy, int i, int j, int32x4_t& sumi) { + auto mzero = vdupq_n_s32(0); + + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[0], q8b_1.val[0]), qx[1], q8b_1.val[1]); // block 1 + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[2], q8b_2.val[0]), qx[3], q8b_2.val[1]); // block 2 + auto p12 = vpaddq_s32(p1, p2); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[4], q8b_3.val[0]), qx[5], q8b_3.val[1]); // block 3 + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[6], q8b_4.val[0]), qx[7], q8b_4.val[1]); // block 4 + auto p34 = vpaddq_s32(p3, p4); + + auto pall = vpaddq_s32(p12, p34); + sumi = vmlaq_s32(sumi, scales, pall); +} template IQK_ALWAYS_INLINE void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, @@ -2133,23 +2153,19 @@ inline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2)); } +IQK_ALWAYS_INLINE int32x4_t prepare_scales_8(const uint32x4_t& v1) { + return vreinterpretq_s32_u32(vsliq_n_u32(vdupq_n_u32(1), vshrq_n_u32(v1, 28), 1)); +} + struct DequantizerIQ2XXS final : public BaseDequantizer { DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} - constexpr static int num_blocks() { return 8; } - constexpr static bool should_scale_quants() { return false; } - - SimpleBits bits; - float d; + IQK_ALWAYS_INLINE float new_block(int i) const { return 0.125f * GGML_FP16_TO_FP32(x[i].d); } - IQK_ALWAYS_INLINE int32x4x2_t new_block(int i) { - d = 0.125f * GGML_FP16_TO_FP32(x[i].d); - data = vld1q_u32_x4((const uint32_t *)x[i].qs); - prepare_block(0); - return prepare_scales_8(vuzp2q_u32(data.val[0], data.val[1]), vuzp2q_u32(data.val[2], data.val[3])); - } - inline void prepare(int /*i*/, int j) { - if (j == 1) 
prepare_block(1); + inline int32x4_t unpack(int i, int j, uint8x16_t * q) const { + auto data = vld1q_u32_x2((const uint32_t *)(x[i].qs + 16*j)); + prepare_all(data, q); + return prepare_scales_8(vuzp2q_u32(data.val[0], data.val[1])); } private: @@ -2161,15 +2177,13 @@ private: apply_signs_2(b, signs, bits[1]); } - inline void prepare_block(int j) { - const uint32_t * q2 = (const uint32_t *)(data.val + 2*j); - prepare2(bits.b1.val + 0, q2+0, keven_signs); - prepare2(bits.b1.val + 2, q2+2, keven_signs); - prepare2(bits.b2.val + 0, q2+4, keven_signs); - prepare2(bits.b2.val + 2, q2+6, keven_signs); + inline static void prepare_all(const uint32x4x2_t& data, uint8x16_t * quants) { + const uint32_t * q2 = (const uint32_t *)data.val; + prepare2(quants+0, q2+0, keven_signs); + prepare2(quants+2, q2+2, keven_signs); + prepare2(quants+4, q2+4, keven_signs); + prepare2(quants+6, q2+6, keven_signs); } - - uint32x4x4_t data; }; inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) { @@ -2204,17 +2218,18 @@ struct DequantizerIQ2XS final : public BaseDequantizer { private: - static inline uint8x16_t make1(const uint16_t * qs) { - auto b = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_u8((const uint8_t *)(iq2xs_grid + (qs[1] & 511)))); - auto s = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9)))); - return vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b), s)); + static void make2(const uint16_t * qs, uint8x16_t * b) { + auto v1 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[1] & 511)))); + auto v2 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[2] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[3] & 511)))); + auto s1 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9)))); + auto s2 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[2] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[3] >> 9)))); + b[0] = vreinterpretq_u8_s8(vmulq_s8(v1, s1)); + b[1] = vreinterpretq_u8_s8(vmulq_s8(v2, s2)); } inline static void make4(const uint16_t * qs, uint8x16_t * b) { - b[0] = make1(qs + 0); - b[1] = make1(qs + 2); - b[2] = make1(qs + 4); - b[3] = make1(qs + 6); + make2(qs + 0, b + 0); + make2(qs + 4, b + 2); } IQK_ALWAYS_INLINE void prepare_internal(int i, int j) { @@ -2224,19 +2239,87 @@ private: }; +static const uint64_t kall_signs[256] = { + 0x0101010101010101, 0x01010101010101ff, 0x010101010101ff01, 0x010101010101ffff, + 0x0101010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0x0101010101ffffff, + 0x01010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0x01010101ff01ffff, + 0x01010101ffff0101, 0x01010101ffff01ff, 0x01010101ffffff01, 0x01010101ffffffff, + 0x010101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0x010101ff0101ffff, + 0x010101ff01ff0101, 0x010101ff01ff01ff, 0x010101ff01ffff01, 0x010101ff01ffffff, + 0x010101ffff010101, 0x010101ffff0101ff, 0x010101ffff01ff01, 0x010101ffff01ffff, + 0x010101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0x010101ffffffffff, + 0x0101ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0x0101ff010101ffff, + 0x0101ff0101ff0101, 0x0101ff0101ff01ff, 0x0101ff0101ffff01, 0x0101ff0101ffffff, + 0x0101ff01ff010101, 0x0101ff01ff0101ff, 0x0101ff01ff01ff01, 0x0101ff01ff01ffff, + 0x0101ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0x0101ff01ffffffff, + 0x0101ffff01010101, 0x0101ffff010101ff, 
0x0101ffff0101ff01, 0x0101ffff0101ffff, + 0x0101ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0x0101ffff01ffffff, + 0x0101ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0x0101ffffff01ffff, + 0x0101ffffffff0101, 0x0101ffffffff01ff, 0x0101ffffffffff01, 0x0101ffffffffffff, + 0x01ff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0x01ff01010101ffff, + 0x01ff010101ff0101, 0x01ff010101ff01ff, 0x01ff010101ffff01, 0x01ff010101ffffff, + 0x01ff0101ff010101, 0x01ff0101ff0101ff, 0x01ff0101ff01ff01, 0x01ff0101ff01ffff, + 0x01ff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0x01ff0101ffffffff, + 0x01ff01ff01010101, 0x01ff01ff010101ff, 0x01ff01ff0101ff01, 0x01ff01ff0101ffff, + 0x01ff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0x01ff01ff01ffffff, + 0x01ff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0x01ff01ffff01ffff, + 0x01ff01ffffff0101, 0x01ff01ffffff01ff, 0x01ff01ffffffff01, 0x01ff01ffffffffff, + 0x01ffff0101010101, 0x01ffff01010101ff, 0x01ffff010101ff01, 0x01ffff010101ffff, + 0x01ffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0x01ffff0101ffffff, + 0x01ffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0x01ffff01ff01ffff, + 0x01ffff01ffff0101, 0x01ffff01ffff01ff, 0x01ffff01ffffff01, 0x01ffff01ffffffff, + 0x01ffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0x01ffffff0101ffff, + 0x01ffffff01ff0101, 0x01ffffff01ff01ff, 0x01ffffff01ffff01, 0x01ffffff01ffffff, + 0x01ffffffff010101, 0x01ffffffff0101ff, 0x01ffffffff01ff01, 0x01ffffffff01ffff, + 0x01ffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0x01ffffffffffffff, + 0xff01010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0xff0101010101ffff, + 0xff01010101ff0101, 0xff01010101ff01ff, 0xff01010101ffff01, 0xff01010101ffffff, + 0xff010101ff010101, 0xff010101ff0101ff, 0xff010101ff01ff01, 0xff010101ff01ffff, + 0xff010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0xff010101ffffffff, + 0xff0101ff01010101, 0xff0101ff010101ff, 0xff0101ff0101ff01, 0xff0101ff0101ffff, + 0xff0101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0xff0101ff01ffffff, + 0xff0101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0xff0101ffff01ffff, + 0xff0101ffffff0101, 0xff0101ffffff01ff, 0xff0101ffffffff01, 0xff0101ffffffffff, + 0xff01ff0101010101, 0xff01ff01010101ff, 0xff01ff010101ff01, 0xff01ff010101ffff, + 0xff01ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0xff01ff0101ffffff, + 0xff01ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0xff01ff01ff01ffff, + 0xff01ff01ffff0101, 0xff01ff01ffff01ff, 0xff01ff01ffffff01, 0xff01ff01ffffffff, + 0xff01ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0xff01ffff0101ffff, + 0xff01ffff01ff0101, 0xff01ffff01ff01ff, 0xff01ffff01ffff01, 0xff01ffff01ffffff, + 0xff01ffffff010101, 0xff01ffffff0101ff, 0xff01ffffff01ff01, 0xff01ffffff01ffff, + 0xff01ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0xff01ffffffffffff, + 0xffff010101010101, 0xffff0101010101ff, 0xffff01010101ff01, 0xffff01010101ffff, + 0xffff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0xffff010101ffffff, + 0xffff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0xffff0101ff01ffff, + 0xffff0101ffff0101, 0xffff0101ffff01ff, 0xffff0101ffffff01, 0xffff0101ffffffff, + 0xffff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0xffff01ff0101ffff, + 0xffff01ff01ff0101, 0xffff01ff01ff01ff, 0xffff01ff01ffff01, 0xffff01ff01ffffff, + 0xffff01ffff010101, 0xffff01ffff0101ff, 0xffff01ffff01ff01, 0xffff01ffff01ffff, + 0xffff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 
0xffff01ffffffffff, + 0xffffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0xffffff010101ffff, + 0xffffff0101ff0101, 0xffffff0101ff01ff, 0xffffff0101ffff01, 0xffffff0101ffffff, + 0xffffff01ff010101, 0xffffff01ff0101ff, 0xffffff01ff01ff01, 0xffffff01ff01ffff, + 0xffffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0xffffff01ffffffff, + 0xffffffff01010101, 0xffffffff010101ff, 0xffffffff0101ff01, 0xffffffff0101ffff, + 0xffffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0xffffffff01ffffff, + 0xffffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0xffffffffff01ffff, + 0xffffffffffff0101, 0xffffffffffff01ff, 0xffffffffffffff01, 0xffffffffffffffff, +}; + struct SignHelper { - inline void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const { - auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1])); - //auto s = vceqq_u8(vandq_u8(aux, smask), smask); - //b[0] = vreinterpretq_u8_s8(vsubq_s8(vreinterpretq_s8_u8(veorq_u8(b[0], s)), vreinterpretq_s8_u8(s))); - // Not much of a difference compared to the above. Perhaps tiny little bit faster. - auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1)); + IQK_ALWAYS_INLINE void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const { + auto s = vreinterpretq_s8_u64(uint64x2_t{kall_signs[sign_bits[0]], kall_signs[sign_bits[1]]}); + //auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1])); + ////auto s = vceqq_u8(vandq_u8(aux, smask), smask); + ////b[0] = vreinterpretq_u8_s8(vsubq_s8(vreinterpretq_s8_u8(veorq_u8(b[0], s)), vreinterpretq_s8_u8(s))); + //// Not much of a difference compared to the above. Perhaps tiny little bit faster. + //auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1)); b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s)); } - const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); - const uint8x16_t m1 = vdupq_n_u8(1); + //const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); + //const uint8x16_t m1 = vdupq_n_u8(1); }; struct DequantizerIQ2S final : public BaseDequantizer { @@ -2260,7 +2343,25 @@ struct DequantizerIQ2S final : public BaseDequantizer { private: - static inline void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { + static void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { + // This commented out version is faster by a few percent. + // But using it makes iq2_xxs run 20% slower. This is the strangest thing I have ever seen. 
+ //uint32_t aux32[2]; + //const uint16_t * aux16 = (const uint16_t *)aux32; + //aux32[1] = (qh[0] << 4) | (qh[0] << 18); + //aux32[0] = (aux32[1] << 4) & 0x03000300; + //aux32[1] &= 0x03000300; + //b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2s_grid[qs[0] | aux16[0]], iq2s_grid[qs[1] | aux16[1]]}); + //b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2s_grid[qs[2] | aux16[2]], iq2s_grid[qs[3] | aux16[3]]}); + //sh.apply_signs_1x(b+0, sign_bits+0); + //sh.apply_signs_1x(b+1, sign_bits+2); + //aux32[1] = (qh[1] << 4) | (qh[1] << 18); + //aux32[0] = (aux32[1] << 4) & 0x03000300; + //aux32[1] &= 0x03000300; + //b[2] = vreinterpretq_u8_u64(uint64x2_t{iq2s_grid[qs[4] | aux16[0]], iq2s_grid[qs[5] | aux16[1]]}); + //b[3] = vreinterpretq_u8_u64(uint64x2_t{iq2s_grid[qs[6] | aux16[2]], iq2s_grid[qs[7] | aux16[3]]}); + //sh.apply_signs_1x(b+2, sign_bits+4); + //sh.apply_signs_1x(b+3, sign_bits+6); uint32_t aux32[2]; const uint16_t * aux16 = (const uint16_t *)aux32; for (int k = 0; k < 2; ++k) { @@ -2276,7 +2377,7 @@ private: } } - inline void prepare_internal(int i, int j, SimpleBits& sb) { + void prepare_internal(int i, int j, SimpleBits& sb) { const auto * qs = x[i].qs + 16*j; const auto * qh = x[i].qh + 4*j; @@ -2292,42 +2393,28 @@ private: struct DequantizerIQ3XXS final : public BaseDequantizer { DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} - constexpr static int num_blocks() { return 8; } - constexpr static bool should_scale_quants() { return false; } - - SimpleBits bits; - float d; - - inline int32x4x2_t new_block(int i) { - d = 0.25f * GGML_FP16_TO_FP32(x[i].d); - auto q3data = vld1q_u8_x2(x[i].qs); - gas = vld1q_u32_x2((const uint32_t *)(x[i].qs + QK_K/4)); - prepare_block((const uint8_t *)q3data.val, (const uint32_t *)gas.val, bits.b1.val, bits.b2.val); - return prepare_scales_8(gas.val[0], gas.val[1]); - } + IQK_ALWAYS_INLINE float new_block(int i) const { return 0.25f * GGML_FP16_TO_FP32(x[i].d); } - inline void prepare(int i, int j) { - if (j == 1) { - auto q3data = vld1q_u8_x2(x[i].qs + 32); - prepare_block((const uint8_t *)q3data.val, (const uint32_t *)(gas.val + 1), bits.b1.val, bits.b2.val); - } + inline int32x4_t unpack(int i, int j, uint8x16_t * q) const { + auto q3data = vld1q_u8_x2(x[i].qs + 32*j); + auto gas = vld1q_u32((const uint32_t *)(x[i].qs + QK_K/4 + 16*j)); + prepare_block((const uint8_t *)q3data.val, (const uint32_t *)&gas, q); + return prepare_scales_8(gas); } private: - inline static void make2(const uint8_t * q3, uint32_t sidx, uint8x16_t * b) { + inline static void make2(const uint8_t * q3, const uint32_t sidx, uint8x16_t * b) { b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]}); b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]}); apply_signs_2(b, keven_signs, sidx); } - inline static void prepare_block(const uint8_t * q3, const uint32_t * signs, uint8x16_t * b1, uint8x16_t * b2) { - make2(q3+ 0, signs[0], b1 + 0); - make2(q3+ 8, signs[1], b1 + 2); - make2(q3+16, signs[2], b2 + 0); - make2(q3+24, signs[3], b2 + 2); + inline static void prepare_block(const uint8_t * q3, const uint32_t * signs, uint8x16_t * quants) { + make2(q3+ 0, signs[0], quants + 0); + make2(q3+ 8, signs[1], quants + 2); + make2(q3+16, signs[2], quants + 4); + make2(q3+24, signs[3], quants + 6); } - - uint32x4x2_t gas; }; struct DequantizerIQ3S final : public BaseDequantizer { @@ -2390,9 +2477,11 @@ private: return 
vld1q_s16(k_shift); } - inline void prepare_block(const uint8_t * qs, const uint8_t * qh, const uint8_t * signs) { - make4(sh, signs + 0, qs+ 0, qh+0, hshift, bits.b1.val); - make4(sh, signs + 8, qs+16, qh+2, hshift, bits.b2.val); + inline void prepare_block(const uint8_t * qs, const uint8_t * qh, const uint8_t * sign_bits) { + auto signs = vld1q_u8(sign_bits); + auto s = (const uint8_t *)&signs; + make4(sh, s + 0, qs+ 0, qh+0, hshift, bits.b1.val); + make4(sh, s + 8, qs+16, qh+2, hshift, bits.b2.val); } SignHelper sh; @@ -2400,6 +2489,44 @@ private: }; +template +IQK_NOINLINE void mul_mat_qX_K_q8_K_IQXXS(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8 q8(info); + Dequantizer deq(vx, bx, nrc_y); + uint8x16_t qx[8]; + int32x4_t sumi[nrc_y]; + float32x4_t acc[nrc_y]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); + + for (int i = 0; i < nb; ++i) { + float d = deq.new_block(i); + auto scales = deq.unpack(i, 0, qx); +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) { + sumi[iy] = vdupq_n_s32(0); + compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 0, sumi[iy]); + } + scales = deq.unpack(i, 1, qx); +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) { + compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 1, sumi[iy]); + acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, i)), vcvtq_f32_s32(sumi[iy])); + } + } +#pragma GCC unroll 8 + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(acc[iy])); + } + } +} + // =========================================== Legacy quants template @@ -2852,8 +2979,18 @@ template void MulMat::set_functions(MulMat& m) { m.funcs[6] = mul_mat_qX_1_q8_1; m.funcs[7] = mul_mat_qX_1_q8_1; } - else if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || + else if constexpr (std::is_same_v || std::is_same_v) { + m.funcs[0] = mul_mat_qX_K_q8_K_IQXXS<1, Dequantizer>; + m.funcs[1] = mul_mat_qX_K_q8_K_IQXXS<2, Dequantizer>; + m.funcs[2] = mul_mat_qX_K_q8_K_IQXXS<3, Dequantizer>; + m.funcs[3] = mul_mat_qX_K_q8_K_IQXXS<4, Dequantizer>; + m.funcs[4] = mul_mat_qX_K_q8_K_IQXXS<5, Dequantizer>; + m.funcs[5] = mul_mat_qX_K_q8_K_IQXXS<6, Dequantizer>; + m.funcs[6] = mul_mat_qX_K_q8_K_IQXXS<7, Dequantizer>; + m.funcs[7] = mul_mat_qX_K_q8_K_IQXXS<8, Dequantizer>; + } + else if constexpr (std::is_same_v || + std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qX_K_q8_K_IQ<1, Dequantizer>; m.funcs[1] = mul_mat_qX_K_q8_K_IQ<2, Dequantizer>; @@ -2901,20 +3038,20 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int N case GGML_TYPE_IQ4_XS: MulMat::set_functions(m); break; - case GGML_TYPE_IQ2_XXS: - MulMat::set_functions(m); + case GGML_TYPE_IQ3_S: + MulMat::set_functions(m); break; - case GGML_TYPE_IQ2_XS: - MulMat::set_functions(m); + case GGML_TYPE_IQ3_XXS: + MulMat::set_functions(m); break; case GGML_TYPE_IQ2_S: MulMat::set_functions(m); break; - case GGML_TYPE_IQ3_XXS: - MulMat::set_functions(m); + case GGML_TYPE_IQ2_XS: + MulMat::set_functions(m); break; - case GGML_TYPE_IQ3_S: - MulMat::set_functions(m); + case GGML_TYPE_IQ2_XXS: + MulMat::set_functions(m); break; case GGML_TYPE_Q4_0: MulMat::set_functions(m); From 7396ce0c82c574332e3942ae7ff70ad9bf1f4585 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 3 Jun 2024 11:34:16 +0200 Subject: [PATCH 12/13] Arm for i-quants: cleanup and comments --- 
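A note on the dispatch table populated above: funcs[k] holds the kernel instantiation for k+1 right-hand columns, so mul_mat_qX_K_q8_K_IQXXS<1, Dequantizer> through <8, Dequantizer> cover 1-8 columns of the Q8 matrix. A minimal sketch of how such a table could be driven, assuming only the kernel signature shown above; the helper name and call site are illustrative, not part of the file:

    // Sketch: pick the kernel matching the number of Q8 columns to process.
    // Assumes funcs[k] has the signature from the patch:
    //   void (int n, const void * vx, size_t bx, const DataInfo & info, int nrc_x)
    static void run_tile(MulMat & m, int n, const void * vx, size_t bx,
                         const DataInfo & info, int nrc_x, int nrc_y) {
        int k = nrc_y < 8 ? nrc_y : 8;  // the table covers 1..8 columns
        m.funcs[k - 1](n, vx, bx, info, nrc_x);
    }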
llamafile/iqk_mul_mat.inc | 68 ++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 40 deletions(-) diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index 5720dedb2a..d58d1b0518 100644 --- a/llamafile/iqk_mul_mat.inc +++ b/llamafile/iqk_mul_mat.inc @@ -23,6 +23,12 @@ #include "llama.cpp/ggml-quants.h" #include "sgemm.h" +// For i-quants, I had to explicitly specify which +// functions to inline / not inline (at least for some +// of the functions), else performance would be significantly +// lower. This is worrisome as things can change with, +// e.g., a different compiler version or running on a different +// CPU. #ifdef _MSC_VER #define IQK_NOINLINE __declspec(noinline) #define IQK_ALWAYS_INLINE inline @@ -37,16 +43,21 @@ // clang-format off // This matrix - vector and matrix - matrix multiplication implementation -// for k-quants and IQ4_XS makes prompt processing 150-200% faster +// for legacy quants, k-quants and i-quants makes prompt processing 150-200% +// (legacy and k-quants) or 250-400% (i-quants) faster. // compared to mainline llama.cpp (and llamafile). -// It is AVX2 only for now. +// It provides implementations for ARM_NEON (all quants) and AVX2 +// (all quants except sub-4 bit i-quants). // // Main idea is that unpacking the quants and the block scales to -// be ready for dot products with the corresponding Q8_K quants -// takes time. Hence, if we are performing a QX x Q8_K matrix matrix +// be ready for dot products with the corresponding Q8_Y quants +// takes time (here 'Y' stands for K, 0, or 1, depending on quantization type). +// Hence, if we are performing a QX x Q8_Y matrix matrix // multiplication (as needed for prompt processing), we can get // a significant speedup by reusing the unpacked QX quants and scales -// for multiplication with several Q8_K columns. +// for multiplication with several Q8_K columns. We also achieve fewer +// loads from memory, which is the main purpose of tiling in general +// purpose matrix multiplication packages.
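As a scalar sketch of that reuse (every name below is hypothetical, for illustration only):

    // Tiling idea: hoist the expensive unpack out of the column loop, so each
    // QX super-block is unpacked once and reused for all ny Q8 columns.
    // X, Y, unpack_block and block_dot are hypothetical stand-ins.
    void tiled_mul(const void * X, const void * const * Y, float * acc,
                   int n_blocks, int ny) {
        for (int ib = 0; ib < n_blocks; ++ib) {
            int8_t xq[256];                  // hypothetical unpacked quants
            float  xs[16];                   // hypothetical unpacked block scales
            unpack_block(X, ib, xq, xs);     // done once per block ...
            for (int iy = 0; iy < ny; ++iy)  // ... amortized over every Q8 column
                acc[iy] += block_dot(xq, xs, Y[iy], ib);
        }
    }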
#include #include @@ -96,11 +107,7 @@ struct MulMat { std::array funcs = {}; //inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { IQK_NOINLINE void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { -#ifdef __aarch64__ - constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small) -#else - constexpr int k_x_step = 64; // This works best on my Ryzen-7950X (but differences to other tile size are small) -#endif + constexpr int k_x_step = 64; // This works best on my Ryzen-7950X and M2 Max CPUs (but differences to other tile sizes are small) int n_step = (nrc_y - info.cur_y)/funcs.size(); if (n_step > 0) { for (int ix = 0; ix < nrc_x; ix += k_x_step) { @@ -2111,11 +2118,6 @@ struct DequantizerIQ4XS final : public BaseDequantizer { } inline void prepare(int i, int j) { bits.prepare16(x[i].qs+64*j); - //if (nrc == 1) { - // bits.prepare16_v2(x[i].qs+64*j); - //} else { - // bits.prepare16(x[i].qs+64*j); - //} for (int k = 0; k < 4; ++k) { bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k])); bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k])); @@ -2141,8 +2143,6 @@ IQK_ALWAYS_INLINE int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint3 auto one = vdupq_n_u32(1); scales.val[0] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v1, 28), 1)); scales.val[1] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v2, 28), 1)); - //scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v1, 28), 1), vdupq_n_u32(1))); - //scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v2, 28), 1), vdupq_n_u32(1))); return scales; } @@ -2239,6 +2239,11 @@ private: }; +// So, I hate to include this table, but with the GCC 12.3 compiler +// bundled in the Cosmopolitan tools, loading the unpacked sign bytes +// from this table using the packed 8 sign bits as an index is faster than +// using the standard trick of vceqq_u8(vandq_u8(bits, mask), mask) to +// expand the bits to bytes. static const uint64_t kall_signs[256] = { 0x0101010101010101, 0x01010101010101ff, 0x010101010101ff01, 0x010101010101ffff, 0x0101010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0x0101010101ffffff, @@ -2310,14 +2315,13 @@ struct SignHelper { IQK_ALWAYS_INLINE void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const { auto s = vreinterpretq_s8_u64(uint64x2_t{kall_signs[sign_bits[0]], kall_signs[sign_bits[1]]}); - //auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1])); - ////auto s = vceqq_u8(vandq_u8(aux, smask), smask); - ////b[0] = vreinterpretq_u8_s8(vsubq_s8(vreinterpretq_s8_u8(veorq_u8(b[0], s)), vreinterpretq_s8_u8(s))); - //// Not much of a difference compared to the above. Perhaps tiny little bit faster. - //auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1)); + // Normally we would expect this to be faster, but it isn't. + // auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1])); + // auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1)); b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s)); } + // We would need these two if we weren't loading from the unpacked sign table.
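The table layout follows directly from the packed bits: byte j of kall_signs[i] is 0xff (-1) when bit j of i is set and 0x01 (+1) otherwise, which is what lets apply_signs_1x() above gather two entries into a 16-byte vector of ±1 factors. A hypothetical generator, shown only to document that layout (it is not part of the file):

    // Illustration: rebuild one kall_signs entry from its 8-bit index.
    static uint64_t sign_entry(uint8_t i) {
        uint64_t v = 0;
        for (int j = 0; j < 8; ++j)
            v |= (uint64_t)(((i >> j) & 1) ? 0xff : 0x01) << (8 * j);
        return v;
    }

The two commented-out constants that follow are the smask/m1 pair that the bit-expansion trick would need.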
//const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); //const uint8x16_t m1 = vdupq_n_u8(1); }; @@ -2344,24 +2348,6 @@ struct DequantizerIQ2S final : public BaseDequantizer { private: static void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { - // This commented out version is faster by a few percent. - // But using it makes iq2_xxs run 20% slower. This is the strangest thing I have ever seen. - //uint32_t aux32[2]; - //const uint16_t * aux16 = (const uint16_t *)aux32; - //aux32[1] = (qh[0] << 4) | (qh[0] << 18); - //aux32[0] = (aux32[1] << 4) & 0x03000300; - //aux32[1] &= 0x03000300; - //b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2s_grid[qs[0] | aux16[0]], iq2s_grid[qs[1] | aux16[1]]}); - //b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2s_grid[qs[2] | aux16[2]], iq2s_grid[qs[3] | aux16[3]]}); - //sh.apply_signs_1x(b+0, sign_bits+0); - //sh.apply_signs_1x(b+1, sign_bits+2); - //aux32[1] = (qh[1] << 4) | (qh[1] << 18); - //aux32[0] = (aux32[1] << 4) & 0x03000300; - //aux32[1] &= 0x03000300; - //b[2] = vreinterpretq_u8_u64(uint64x2_t{iq2s_grid[qs[4] | aux16[0]], iq2s_grid[qs[5] | aux16[1]]}); - //b[3] = vreinterpretq_u8_u64(uint64x2_t{iq2s_grid[qs[6] | aux16[2]], iq2s_grid[qs[7] | aux16[3]]}); - //sh.apply_signs_1x(b+2, sign_bits+4); - //sh.apply_signs_1x(b+3, sign_bits+6); uint32_t aux32[2]; const uint16_t * aux16 = (const uint16_t *)aux32; for (int k = 0; k < 2; ++k) { @@ -3016,6 +3002,8 @@ template void MulMat::set_functions(MulMat& m) { bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int Ny) { row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); + (void)Ny; + // Uncommenting this would disable iqk_mul_mat for matrix x vector multiplications.
//if (Ny == 1 && (typeA == GGML_TYPE_IQ2_XXS || typeA == GGML_TYPE_IQ2_XS || typeA == GGML_TYPE_IQ2_S || // typeA == GGML_TYPE_IQ3_XXS || typeA == GGML_TYPE_IQ3_S)) return false; From 7a962c78480524426c1d012812e2d10d56307e75 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 3 Jun 2024 12:40:25 +0200 Subject: [PATCH 13/13] Remove forgotten experimental change in q3_K implementation --- llamafile/iqk_mul_mat.inc | 33 +++------------------------------ 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index d58d1b0518..5b9c7b1a5d 100644 --- a/llamafile/iqk_mul_mat.inc +++ b/llamafile/iqk_mul_mat.inc @@ -1877,8 +1877,7 @@ struct DequantizerQ5K final : public BaseDequantizer { return s8.process_scales_mins(x[i], q8, i, acc); } inline void prepare(int i, int j) { - if (nrc == 1) bits.prepare_v2(x[i].qs+64*j); - else bits.prepare(x[i].qs+64*j); + bits.prepare(x[i].qs+64*j); h.apply(bits.b1, bits.b2, j == 0); } @@ -1955,7 +1954,6 @@ struct DequantizerQ3K final : public BaseDequantizer { inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { d = GGML_FP16_TO_FP32(x[i].d); h.bits = vld1q_u8_x2(x[i].hmask); - mask = vdupq_n_u8(0x01); const uint16_t * sc16 = (const uint16_t *)x[i].scales; uint32_t aux0 = sc16[0] | (sc16[1] << 16); uint32_t aux1 = sc16[2] | (sc16[3] << 16); @@ -1964,43 +1962,18 @@ struct DequantizerQ3K final : public BaseDequantizer { aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030); aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030); aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030); - auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)); - if (nrc > 1) { - return process_scales_mins_16(scales8, q8, acc, i, -4.f*d); - } - int16x8x2_t scales16; - scales16.val[0] = vmovl_s8(vget_low_s8(scales8)); - scales16.val[1] = vmovl_s8(vget_high_s8(scales8)); - return make_wider(scales16); + return process_scales_mins_16(vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)), q8, acc, i, -4.f*d); } inline void prepare(int i, int j) { bits.prepare(x[i].qs+32*j); - if (nrc > 1) { - h.apply(bits.b1, bits.b2, j == 0); - } else { - auto minus4 = vdupq_n_u8(0xfc); - auto zero = vdupq_n_u8(0); - bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); - bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); - mask = vshlq_n_u8(mask, 1); - bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); - bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); - mask = vshlq_n_u8(mask, 1); - bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); - bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); - mask = vshlq_n_u8(mask, 1); - bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero))); - bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero))); - mask = vshlq_n_u8(mask, 1); - } + h.apply(bits.b1, bits.b2, j == 0); } uint32_t aux32[4]; Q2bits bits; - uint8x16_t mask; HighBit3 h; float d;
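For context on the branch removed above: q3_K stores the low two bits of each quant in qs and the high bit in hmask, and the experimental nrc == 1 path OR-ed 0xfc (-4 as an int8_t) into the low bits wherever the high bit was clear, producing final values in [-4, 3] directly. A scalar sketch of that mapping, with illustrative names:

    // Scalar form of the reverted NEON trick in DequantizerQ3K::prepare():
    // map (low 2 bits, high-mask bit) to a signed q3_K value.
    static int8_t q3_value(uint8_t low2, int hbit) {
        return hbit ? (int8_t)low2             //  0 ..  3 when the high bit is set
                    : (int8_t)(low2 | 0xfc);   // -4 .. -1 when it is clear
    }

The restored shared path instead applies the high bits via h.apply() and folds the constant -4 offset into the scale/min accumulation (the -4.f*d argument to process_scales_mins_16 above).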