From 1336ff7c02c7cd110c12bbae91a2d65aed51a7d7 Mon Sep 17 00:00:00 2001
From: Nexesenex <124105151+Nexesenex@users.noreply.github.com>
Date: Mon, 23 Dec 2024 05:01:45 +0100
Subject: [PATCH] llama: fix the KV cache quants q4_0 and q8_0 leading to a
 server abort in large-context chat. #8073

Credit: @mengkin
---
 ggml/src/ggml-cpu/ggml-cpu.c | 312 ++++++++++++++++++++++++++++++++++-
 ggml/src/ggml-cuda/cpy.cu    |  28 ++++
 2 files changed, 339 insertions(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 4640f04fd0567..1d2c68c1d7540 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -3358,6 +3358,308 @@ static void ggml_compute_forward_dup_same_cont(
     }
 }
 
+static void ggml_compute_forward_dup_q4(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_TENSOR_UNARY_OP_LOCALS
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+        // copy by rows
+        const size_t rs = ne00 * nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3),
+                        ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+    if (ggml_is_contiguous(dst)) {
+        if (nb00 == sizeof(block_q4_0)) {
+            const size_t rs = ne00 / 2; // ne00/2 packed q4_0 bytes per row
+            if (dst->type == GGML_TYPE_F32) {
+                float * dst_ptr = (float *) dst->data;
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        size_t id = rs * ith;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                            dequantize_row_q4_0(src_ptr, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        size_t id = rs * ith;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                            float tmp[QK4_0];
+                            dequantize_row_q4_0(src_ptr, tmp, ne00);
+                            for (int i00 = 0; i00 < QK4_0; i00++) {
+                                dst_ptr[id + i00] = GGML_FP32_TO_FP16(tmp[i00]);
+                            }
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_BF16) {
+                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        size_t id = rs * ith;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                            float tmp[QK4_0];
+                            dequantize_row_q4_0(src_ptr, tmp, ne00);
+                            for (int i00 = 0; i00 < QK4_0; i00++) {
+                                dst_ptr[id + i00] = GGML_FP32_TO_BF16(tmp[i00]);
+                            }
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+                float tmp[QK4_0];
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        size_t id = rs * ith;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                            dequantize_row_q4_0(src_ptr, tmp, ne00);
+                            quantize_row_q(tmp, (char *) dst->data + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        }
+    } else {
+        if (dst->type == GGML_TYPE_F32) {
+            float * dst_ptr = (float *) dst->data;
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    size_t id = ith * QK4_0 / 2;
+                    for (int i01 = ir0; i01 < ir1; i01++) {
+                        const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                        for (int i00 = 0; i00 < QK4_0 / 2; i00++) {
+                            dst_ptr[id] = GGML_FP16_TO_FP32(src_ptr->d) * ((src_ptr->qs[i00] & 0x0F) - 8);
+                            dst_ptr[id + 1] = GGML_FP16_TO_FP32(src_ptr->d) * ((src_ptr->qs[i00] >> 4) - 8);
+                            id += 2;
+                        }
+                    }
+                }
+            }
+        } else if (dst->type == GGML_TYPE_F16) {
+            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    size_t id = ith * QK4_0 / 2;
+                    for (int i01 = ir0; i01 < ir1; i01++) {
+                        const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                        for (int i00 = 0; i00 < QK4_0 / 2; i00++) {
+                            dst_ptr[id] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src_ptr->d) * ((src_ptr->qs[i00] & 0x0F) - 8));
+                            dst_ptr[id + 1] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src_ptr->d) * ((src_ptr->qs[i00] >> 4) - 8));
+                            id += 2;
+                        }
+                    }
+                }
+            }
+        } else if (dst->type == GGML_TYPE_BF16) {
+            ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    size_t id = ith * QK4_0 / 2;
+                    for (int i01 = ir0; i01 < ir1; i01++) {
+                        const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                        for (int i00 = 0; i00 < QK4_0 / 2; i00++) {
+                            dst_ptr[id] = GGML_FP32_TO_BF16(GGML_FP16_TO_FP32(src_ptr->d) * ((src_ptr->qs[i00] & 0x0F) - 8));
+                            dst_ptr[id + 1] = GGML_FP32_TO_BF16(GGML_FP16_TO_FP32(src_ptr->d) * ((src_ptr->qs[i00] >> 4) - 8));
+                            id += 2;
+                        }
+                    }
+                }
+            }
+        } else {
+            GGML_ABORT("fatal error"); // TODO: implement
+        }
+    }
+    return;
+}
+static void ggml_compute_forward_dup_q8(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_TENSOR_UNARY_OP_LOCALS
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 >= ggml_type_size(src0->type) && nb0 >= ggml_type_size(dst->type)) {
+        // copy by rows
+        const size_t rs = ne00 * nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3),
+                        ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+    if (ggml_is_contiguous(dst)) {
+        const size_t rs = ne00 / QK8_0; // number of q8_0 blocks per row
+        if (dst->type == GGML_TYPE_F32) {
+            float * dst_ptr = (float *) dst->data;
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    size_t id = rs * ith;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                        dequantize_row_q8_0(src_ptr, dst_ptr + id * QK8_0, ne00);
+                        id += rs;
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        } else if (dst->type == GGML_TYPE_F16) {
+            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    size_t id = rs * ith;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                        float tmp[QK8_0];
+                        dequantize_row_q8_0(src_ptr, tmp, ne00);
+                        for (int64_t i00 = 0; i00 < QK8_0; i00++) {
+                            dst_ptr[id * QK8_0 + i00] = GGML_FP32_TO_FP16(tmp[i00]);
+                        }
+                        id += rs;
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        } else if (dst->type == GGML_TYPE_BF16) {
+            ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    size_t id = rs * ith;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                        float tmp[QK8_0];
+                        dequantize_row_q8_0(src_ptr, tmp, ne00);
+                        for (int64_t i00 = 0; i00 < QK8_0; i00++) {
+                            dst_ptr[id * QK8_0 + i00] = GGML_FP32_TO_BF16(tmp[i00]);
+                        }
+                        id += rs;
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+            ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+            float tmp[QK8_0];
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    size_t id = rs * ith;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                        dequantize_row_q8_0(src_ptr, tmp, ne00);
+                        quantize_row_q(tmp, (char *) dst->data + id * QK8_0, ne00);
+                        id += rs;
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        } else {
+            GGML_ABORT("fatal error"); // TODO: implement
+        }
+    } else {
+        if (dst->type == GGML_TYPE_F32) {
+            float * dst_ptr = (float *) dst->data;
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    size_t id = ith * QK8_0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                        for (int64_t i00 = 0; i00 < QK8_0; i00++) {
+                            dst_ptr[id] = GGML_FP16_TO_FP32(src_ptr->d) * src_ptr->qs[i00];
+                            id += 1;
+                        }
+                    }
+                }
+            }
+        } else if (dst->type == GGML_TYPE_F16) {
+            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    size_t id = ith * QK8_0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                        for (int64_t i00 = 0; i00 < QK8_0; i00++) {
+                            dst_ptr[id] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src_ptr->d) * src_ptr->qs[i00]);
+                            id += 1;
+                        }
+                    }
+                }
+            }
+        } else if (dst->type == GGML_TYPE_BF16) {
+            ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    size_t id = ith * QK8_0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+                        for (int64_t i00 = 0; i00 < QK8_0; i00++) {
+                            dst_ptr[id] = GGML_FP32_TO_BF16(GGML_FP16_TO_FP32(src_ptr->d) * src_ptr->qs[i00]);
+                            id += 1;
+                        }
+                    }
+                }
+            }
+        } else {
+            GGML_ABORT("fatal error"); // TODO: implement
+        }
+    }
+    return;
+}
+
 static void ggml_compute_forward_dup_f16(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -4457,9 +4759,17 @@ static void ggml_compute_forward_dup(
             {
                 ggml_compute_forward_dup_f32(params, dst);
             } break;
+        case GGML_TYPE_Q4_0:
+            {
+                ggml_compute_forward_dup_q4(params, dst);
+            } break;
+        case GGML_TYPE_Q8_0:
+            {
+                ggml_compute_forward_dup_q8(params, dst);
+            } break;
         default:
             {
-                GGML_ABORT("fatal error");
+                GGML_ABORT("fatal error: unsupported forward dup from type %d to %d", src0->type, dst->type);
             }
     }
 }
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 0fe4268b86924..798f04dc720f3 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -131,6 +131,20 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
     }
 }
 
+static __device__ void cpy_blck_q4_0_f32(const char * cxi, char * cdsti) {
+    const block_q4_0 * xi = (const block_q4_0 *) cxi;
+    float * dsti = (float *) cdsti;
+
+    const float d = (float)xi->d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = (xi->qs[j] & 0x0F) - 8;
+        const float x1 = (xi->qs[j] >>   4) - 8;
+        dsti[j + 0]        = x0 * d;
+        dsti[j + QK4_0/2]  = x1 * d;
+    }
+}
+
 static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     block_q4_1 * dsti = (block_q4_1 *) cdsti;
@@ -446,6 +460,16 @@ static void ggml_cpy_f32_q4_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q4_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q4_0_f32, QK4_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -556,6 +580,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
@@ -598,6 +624,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q4_0_f32, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
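
Note (illustrative, not part of the patch): before this change, any CPU-side copy out of
a q4_0 or q8_0 tensor fell through to the GGML_ABORT("fatal error") default in
ggml_compute_forward_dup, which is the abort reported with quantized KV caches in
large-context chats. Below is a minimal sketch of a program that reaches the new
q4_0 -> F32 CPU path through the public ggml API; it assumes the ggml.h / ggml-cpu.h
headers from this tree, the tensor size (64 elements, i.e. two q4_0 blocks) is an
arbitrary example value, and the tensor contents are left uninitialized since the
point is only to exercise the copy path.

    #include "ggml.h"
    #include "ggml-cpu.h"

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        // source: one row of 64 q4_0 elements (two blocks); destination: F32
        struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_Q4_0, 64);
        struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  64);

        // ggml_cpy lowers to GGML_OP_CPY; on the CPU backend this now dispatches
        // into ggml_compute_forward_dup_q4() instead of aborting
        struct ggml_tensor * out = ggml_cpy(ctx, src, dst);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        ggml_free(ctx);
        return 0;
    }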