A faster version for Q4_1 x Q8_0 dot products

The idea behind this change is that Q8_0-quantized values
get used many times in the matrix multiplications where they
are involved. In the current implementation, every dot product
evaluation recomputes the sum of the quants in the Q8_0 vector,
so the same operation is repeated many times. Here we pre-compute
the sum during Q8_0 quantization, store it in the now modified
block_q8_0 struct, and reuse the result in the subsequent dot
products.
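
To make the bookkeeping concrete, here is a minimal scalar sketch
of one Q4_1 x Q8_0 block dot product using the cached sum. This is
an illustration, not code from this commit: it assumes the block
layouts from ggml.c (QK4_1 == QK8_0 == 32, with byte l of a Q4_1
block holding quant 2l in the low nibble and quant 2l+1 in the high
nibble), and the function name is made up.

static float q4_1_q8_0_block_dot(const block_q4_1 * restrict x, const block_q8_0 * restrict y) {
    // sum_l (x->d*q_l + x->m) * (y->d*y_l)
    //   = x->d * y->d * sum_l(q_l * y_l)  +  x->m * (y->d * sum_l(y_l))
    //   = x->d * y->d * sumi              +  x->m * y->s
    int sumi = 0;
    for (int l = 0; l < QK8_0/2; ++l) {
        const int q0 = x->qs[l] & 0xf; // quant 2l   (low nibble)
        const int q1 = x->qs[l] >> 4;  // quant 2l+1 (high nibble)
        sumi += q0 * y->qs[2*l + 0] + q1 * y->qs[2*l + 1];
    }
    return x->d * y->d * sumi + x->m * y->s; // offset term: one multiply-add per block
}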

In a synthetic benchmark (just computing a bunch of dot
products), this change speeds up the Q4_1 * Q8_0 dot product
by 80%, making its performance identical to that of
Q4_0 * Q8_0.

In practical use, I see a ~15% gain in token-prediction speed
on an M2, and a ~5% gain on a Ryzen 7950X. The speed gain in
prompt evaluation is much bigger (around 50%).

I have only made the change for the scalar version, ARM_NEON,
and AVX2, so we still need an AVX implementation.
Kawrakow committed Apr 20, 2023
1 parent 02d6988 commit b51101a
Showing 3 changed files with 268 additions and 37 deletions.
128 changes: 91 additions & 37 deletions ggml.c
@@ -647,9 +647,10 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
#define QK8_0 32
typedef struct {
float d; // delta
float s; // d * sum(qs[i])
int8_t qs[QK8_0]; // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");


// reference implementation for deterministic creation of model files
@@ -1247,10 +1248,13 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r

y[i].d = d;

int sum = 0;
for (int l = 0; l < QK8_0; ++l) {
const float v = x[i*QK8_0 + l]*id;
y[i].qs[l] = roundf(v);
sum += y[i].qs[l];
}
y[i].s = d * sum;
}
}

@@ -1280,6 +1284,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int

y[i].d = d;

int32x4_t accv = vdupq_n_s32(0);

for (int l = 0; l < 8; l++) {
const float32x4_t v = vmulq_n_f32(srcv[l], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
@@ -1288,7 +1294,11 @@
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);

accv = vaddq_s32(accv, vi);
}
int32_t sum = vaddvq_s32(accv);
y[i].s = d * sum;
}
#elif defined(__AVX2__) || defined(__AVX__)
for (int i = 0; i < nb; i++) {
@@ -1336,6 +1346,16 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
__m256i i3 = _mm256_cvtps_epi32( v3 );

#if defined(__AVX2__)

// Compute the sum of the quants
// Is there no better way of doing this???
__m256i acc = _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3));
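// (Fold the 256-bit sum to 128 bits, then two pairwise in-lane folds leave
// the full horizontal sum in lane 0 for extraction below.)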
__m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(acc), _mm256_extracti128_si256(acc, 1));
__m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
__m128i sum64 = _mm_add_epi32(hi64, sum128);
__m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
y[i].s = d * _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));

// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
@@ -1378,6 +1398,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
// scalar
quantize_row_q8_0_reference(x, y, k);
#endif
#if defined __AVX__
// TODO: vectorize this
for (int i=0; i<nb; ++i) {
int sum = 0;
for (int l=0; l<QK8_0; ++l) sum += y[i].qs[l];
y[i].s = y[i].d * sum;
}
#endif
}

static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
@@ -2282,14 +2310,18 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);

float sum8 = 0;

for (int i = 0; i < nb; i += 2) {
const block_q4_0 * restrict x0 = &x[i + 0];
const block_q4_0 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];

sum8 += x0->d * y0->s + x1->d * y1->s;
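// With s = d * sum(qs) cached per Q8_0 block, the Q4_0 "subtract 8" folds out
// of the SIMD loop:
//   sum_l x.d*(q_l - 8) * y.d*y_l  =  x.d*y.d * sum_l(q_l*y_l)  -  8 * x.d*y.s,
// so the s8b subtraction below is dropped and a single "- 8 * sum8" correction
// is applied after the loop.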

const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
//const int8x16_t s8b = vdupq_n_s8(0x8);

const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2301,10 +2333,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

// sub 8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
//const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
//const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
//const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
//const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);

// load y
const int8x16_t v1_0l = vld1q_s8(y0->qs);
@@ -2320,21 +2352,31 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *

#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
//const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
//const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);

sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
//const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
//const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
//const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
//const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));

const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
//const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
//const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
//const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
//const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));

const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2346,7 +2388,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
#endif
}

sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
@@ -2479,12 +2521,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);

float summs = 0;

for (int i = 0; i < nb; i += 2) {
const block_q4_1 * restrict x0 = &x[i + 0];
const block_q4_1 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];

summs += x0->m * y0->s + x1->m * y1->s;
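// x->m * y->s is the Q4_1 offset contribution (m times the sum of the y quants,
// already scaled by y->d); it replaces the commented-out vectorized y-sum below.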

const uint8x16_t m4b = vdupq_n_u8(0xf);

const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2508,16 +2554,18 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);

const int16x8_t s0i = vaddq_s16(
vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
// We no longer need this. We have computed the sum of the y quants during quantization,
// so we get the same as these via the scalar instruction above (summs += x0->m * y0->s + x1->m * y1->s)
//const int16x8_t s0i = vaddq_s16(
// vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
// vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));

const int16x8_t s1i = vaddq_s16(
vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
//const int16x8_t s1i = vaddq_s16(
// vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
// vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));

sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
//sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
//sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);

#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
@@ -2547,24 +2595,28 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
#endif
}

sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();

float summs = 0;

// Main loop
for (int i = 0; i < nb; ++i) {
const float * d0 = &x[i].d;
const float * d1 = &y[i].d;
const float * m0 = &x[i].m;
//const float * m0 = &x[i].m;

summs += x[i].m * y[i].s;

const __m256 d0v = _mm256_broadcast_ss( d0 );
const __m256 d1v = _mm256_broadcast_ss( d1 );
const __m256 m0v = _mm256_broadcast_ss( m0 );
//const __m256 m0v = _mm256_broadcast_ss( m0 );

// Compute combined scales
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
//const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );

// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
const __m256i bx = bytesFromNibbles( x[i].qs );
@@ -2587,14 +2639,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
// Accumulate d0*d1*x*y
acc = _mm256_fmadd_ps( d0d1, xy, acc );

// Compute sum of y values
const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
// We no longer need this. We have computed the sum of the y quants during quantization,
// so we get the same as these via the single scalar instruction above (summs += x[i].m * y[i].s)
//// Compute sum of y values
//const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
//const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
//const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
//const __m256 ysum = _mm256_cvtepi32_ps( ysumi );

// Accumulate d1*m0*y
acc = _mm256_fmadd_ps( d1m0, ysum, acc );
//// Accumulate d1*m0*y
//acc = _mm256_fmadd_ps( d1m0, ysum, acc );
}

// Return horizontal sum of the acc vector
@@ -2603,7 +2657,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );

sumf = _mm_cvtss_f32( res );
sumf = _mm_cvtss_f32( res ) + summs;
#else
// scalar
for (int i = 0; i < nb; i++) {
5 changes: 5 additions & 0 deletions pocs/vdot/CMakeLists.txt
@@ -2,3 +2,8 @@ set(TARGET vdot)
add_executable(${TARGET} vdot.cpp)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

set(TARGET q8dot)
add_executable(${TARGET} q8dot.cpp)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
172 changes: 172 additions & 0 deletions pocs/vdot/q8dot.cpp
