llama: Fix KV cache quants q4_0 and q8_0 leading to server abort in large-context chat. ggerganov#8073
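The underlying failure: with a quantized KV cache (e.g. `-ctk q4_0 -ctv q4_0`), long-context operations issue ggml_cpy/dup on q4_0/q8_0 tensors, and the CPU dup dispatch previously fell through to GGML_ABORT("fatal error"). Below is a minimal sketch of the failing pattern on CPU (hypothetical repro, not part of the commit; the tensor size and the 16 MiB arena are arbitrary):

// hypothetical repro sketch: before this fix, computing this graph aborted in
// ggml_compute_forward_dup ("fatal error") because no quantized-src case existed
#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // a q4_0 "KV cache" row and an f32 destination of the same shape
    struct ggml_tensor * q = ggml_new_tensor_1d(ctx, GGML_TYPE_Q4_0, 128);
    struct ggml_tensor * f = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  128);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, ggml_cpy(ctx, q, f)); // CPY lowers to dup
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1); // aborted before this fix

    ggml_free(ctx);
    return 0;
}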

Credit: @mengkin
Nexesenex committed Dec 24, 2024
1 parent 52dcabb commit 1336ff7
Showing 2 changed files with 339 additions and 1 deletion.
312 changes: 311 additions & 1 deletion ggml/src/ggml-cpu/ggml-cpu.c
@@ -3358,6 +3358,308 @@ static void ggml_compute_forward_dup_same_cont(
}
}

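// dup/convert a q4_0 src tensor: same-type row copy, dequantize to f32/f16/bf16,
// or re-quantize through the dst type's from_float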
static void ggml_compute_forward_dup_q4(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
GGML_TENSOR_UNARY_OP_LOCALS
const int ith = params->ith; // thread index
const int nth = params->nth; // number of threads
// parallelize by rows
const int nr = ne01;
// number of rows per thread
const int dr = (nr + nth - 1) / nth;
// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);
if (src0->type == dst->type &&
ne00 == ne0 &&
nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
// copy by rows
        const size_t rs = ggml_row_size(src0->type, ne00); // row size in bytes (ne00 elements pack into ne00/QK4_0 blocks)
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ir0; i01 < ir1; i01++) {
memcpy(
((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3),
((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03),
rs);
}
}
}
return;
}
if (ggml_is_contiguous(dst)) {
if (nb00 == sizeof(block_q4_0)) {
            const size_t rs = ne00; // elements per dst row
if (dst->type == GGML_TYPE_F32) {
                float * dst_ptr = (float *) dst->data;
                size_t id = 0;
                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0; // skip the rows handled by lower-index threads
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
                            dequantize_row_q4_0(src_ptr, dst_ptr + id, ne00);
                            id += rs;
                        }
                        id += rs * (ne01 - ir1); // skip the rows handled by higher-index threads
                    }
                }
} else if (dst->type == GGML_TYPE_F16) {
                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
                size_t id = 0;
                float tmp[QK4_0];
                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
                            // convert one block at a time: tmp holds QK4_0 dequantized floats
                            for (int ib = 0; ib < ne00 / QK4_0; ib++) {
                                dequantize_row_q4_0(src_ptr + ib, tmp, QK4_0);
                                for (int i00 = 0; i00 < QK4_0; i00++) {
                                    dst_ptr[id + ib * QK4_0 + i00] = GGML_FP32_TO_FP16(tmp[i00]);
                                }
                            }
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
} else if (dst->type == GGML_TYPE_BF16) {
                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
                size_t id = 0;
                float tmp[QK4_0];
                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
                            for (int ib = 0; ib < ne00 / QK4_0; ib++) {
                                dequantize_row_q4_0(src_ptr + ib, tmp, QK4_0);
                                for (int i00 = 0; i00 < QK4_0; i00++) {
                                    dst_ptr[id + ib * QK4_0 + i00] = GGML_FP32_TO_BF16(tmp[i00]);
                                }
                            }
                            id += rs;
                        }
                        id += rs * (ne01 - ir1);
                    }
                }
} else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
                // re-quantize block-by-block; requires the dst block size to divide QK4_0
                GGML_ASSERT(QK4_0 % ggml_blck_size(dst->type) == 0);
                const size_t rs_dst = ggml_row_size(dst->type, ne00);  // bytes per dst row
                const size_t bs_dst = ggml_row_size(dst->type, QK4_0); // bytes per QK4_0 elements
                float tmp[QK4_0];
                size_t id = 0;
                for (int i03 = 0; i03 < ne03; i03++) {
                    for (int i02 = 0; i02 < ne02; i02++) {
                        id += rs_dst * ir0;
                        for (int i01 = ir0; i01 < ir1; i01++) {
                            const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
                            for (int ib = 0; ib < ne00 / QK4_0; ib++) {
                                dequantize_row_q4_0(src_ptr + ib, tmp, QK4_0);
                                quantize_row_q(tmp, (char *) dst->data + id + ib * bs_dst, QK4_0);
                            }
                            id += rs_dst;
                        }
                        id += rs_dst * (ne01 - ir1);
                    }
                }
} else {
GGML_ABORT("fatal error"); // TODO: implement
}
        } else {
            // contiguous dst but src rows are not plain q4_0 blocks
            GGML_ABORT("fatal error"); // TODO: implement
        }
    } else {
        // dst is not contiguous: write element by element through the dst strides
        GGML_ASSERT(ggml_are_same_shape(src0, dst));
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    const block_q4_0 * src_row = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
                    char * dst_row = (char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3;
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const block_q4_0 * b = src_row + i00 / QK4_0;
                        const int j = (int) (i00 % QK4_0);
                        // q4_0 packs element j in the low nibble and element j + QK4_0/2 in the high nibble
                        const int q = j < QK4_0 / 2 ? (b->qs[j] & 0x0F) : (b->qs[j - QK4_0 / 2] >> 4);
                        const float v = GGML_FP16_TO_FP32(b->d) * (q - 8);
                        if (dst->type == GGML_TYPE_F32) {
                            *(float *) (dst_row + i00 * nb0) = v;
                        } else if (dst->type == GGML_TYPE_F16) {
                            *(ggml_fp16_t *) (dst_row + i00 * nb0) = GGML_FP32_TO_FP16(v);
                        } else if (dst->type == GGML_TYPE_BF16) {
                            *(ggml_bf16_t *) (dst_row + i00 * nb0) = GGML_FP32_TO_BF16(v);
                        } else {
                            GGML_ABORT("fatal error"); // TODO: implement
                        }
                    }
                }
            }
        }
    }
return;
}
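
// dup/convert a q8_0 src tensor, mirroring ggml_compute_forward_dup_q4 above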
static void ggml_compute_forward_dup_q8(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
GGML_TENSOR_UNARY_OP_LOCALS
const int ith = params->ith; // thread index
const int nth = params->nth; // number of threads
// parallelize by rows
const int nr = ne01;
// number of rows per thread
const int dr = (nr + nth - 1) / nth;
// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);
if (src0->type == dst->type &&
ne00 == ne0 &&
        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
        // copy by rows
        const size_t rs = ggml_row_size(src0->type, ne00); // row size in bytes
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ir0; i01 < ir1; i01++) {
memcpy(
((char * ) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3),
((char * ) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03),
rs);
}
}
}
return;
}
if (ggml_is_contiguous(dst)) {
        GGML_ASSERT(nb00 == sizeof(block_q8_0)); // src rows must be plain q8_0 blocks
        const size_t rs = ne00; // elements per dst row
if (dst->type == GGML_TYPE_F32) {
            float * dst_ptr = (float *) dst->data;
            size_t id = 0;
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
                        dequantize_row_q8_0(src_ptr, dst_ptr + id, ne00);
                        id += rs;
                    }
                    id += rs * (ne01 - ir1);
                }
            }
} else if (dst->type == GGML_TYPE_F16) {
            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
            size_t id = 0;
            float tmp[QK8_0];
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
                        // convert one block at a time: tmp holds QK8_0 dequantized floats
                        for (int64_t ib = 0; ib < ne00 / QK8_0; ib++) {
                            dequantize_row_q8_0(src_ptr + ib, tmp, QK8_0);
                            for (int i00 = 0; i00 < QK8_0; i00++) {
                                dst_ptr[id + ib * QK8_0 + i00] = GGML_FP32_TO_FP16(tmp[i00]);
                            }
                        }
                        id += rs;
                    }
                    id += rs * (ne01 - ir1);
                }
            }
} else if (dst->type == GGML_TYPE_BF16) {
            ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
            size_t id = 0;
            float tmp[QK8_0];
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
                        for (int64_t ib = 0; ib < ne00 / QK8_0; ib++) {
                            dequantize_row_q8_0(src_ptr + ib, tmp, QK8_0);
                            for (int i00 = 0; i00 < QK8_0; i00++) {
                                dst_ptr[id + ib * QK8_0 + i00] = GGML_FP32_TO_BF16(tmp[i00]);
                            }
                        }
                        id += rs;
                    }
                    id += rs * (ne01 - ir1);
                }
            }
} else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
            ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
            // re-quantize block-by-block; requires the dst block size to divide QK8_0
            GGML_ASSERT(QK8_0 % ggml_blck_size(dst->type) == 0);
            const size_t rs_dst = ggml_row_size(dst->type, ne00);  // bytes per dst row
            const size_t bs_dst = ggml_row_size(dst->type, QK8_0); // bytes per QK8_0 elements
            float tmp[QK8_0];
            size_t id = 0;
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs_dst * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
                        for (int64_t ib = 0; ib < ne00 / QK8_0; ib++) {
                            dequantize_row_q8_0(src_ptr + ib, tmp, QK8_0);
                            quantize_row_q(tmp, (char *) dst->data + id + ib * bs_dst, QK8_0);
                        }
                        id += rs_dst;
                    }
                    id += rs_dst * (ne01 - ir1);
                }
            }
} else {
GGML_ABORT("fatal error"); // TODO: implement
}
    } else {
        // dst is not contiguous: write element by element through the dst strides
        GGML_ASSERT(ggml_are_same_shape(src0, dst));
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    const block_q8_0 * src_row = (const block_q8_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
                    char * dst_row = (char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3;
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const block_q8_0 * b = src_row + i00 / QK8_0;
                        const float v = GGML_FP16_TO_FP32(b->d) * b->qs[i00 % QK8_0];
                        if (dst->type == GGML_TYPE_F32) {
                            *(float *) (dst_row + i00 * nb0) = v;
                        } else if (dst->type == GGML_TYPE_F16) {
                            *(ggml_fp16_t *) (dst_row + i00 * nb0) = GGML_FP32_TO_FP16(v);
                        } else if (dst->type == GGML_TYPE_BF16) {
                            *(ggml_bf16_t *) (dst_row + i00 * nb0) = GGML_FP32_TO_BF16(v);
                        } else {
                            GGML_ABORT("fatal error"); // TODO: implement
                        }
                    }
                }
            }
        }
    }
return;
}

static void ggml_compute_forward_dup_f16(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@@ -4457,9 +4759,17 @@ static void ggml_compute_forward_dup(
{
ggml_compute_forward_dup_f32(params, dst);
} break;
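        // quantized KV cache types (see ggerganov#8073)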
case GGML_TYPE_Q4_0:
{
ggml_compute_forward_dup_q4(params, dst);
} break;
case GGML_TYPE_Q8_0:
{
ggml_compute_forward_dup_q8(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
GGML_ABORT("fatal error, not support forward dup oper from %d ot %d", src0->type, dst->type);
}
}
}
28 changes: 28 additions & 0 deletions ggml/src/ggml-cuda/cpy.cu
@@ -131,6 +131,20 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
}
}

static __device__ void cpy_blck_q4_0_f32(const char * cxi, char * cdsti) {
const block_q4_0 * xi = (const block_q4_0 *) cxi;
float * dsti = (float *) cdsti;

const float d = (float)xi->d;

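    // q4_0 packs element j in the low nibble of qs[j] and element j + QK4_0/2 in the high nibble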
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = (xi->qs[j] & 0x0F) - 8;
const float x1 = (xi->qs[j] >> 4) - 8;
dsti[j + 0] = x0 * d;
dsti[j + QK4_0/2] = x1 * d;
}
}

static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q4_1 * dsti = (block_q4_1 *) cdsti;
@@ -446,6 +460,16 @@ static void ggml_cpy_f32_q4_0_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q4_0_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

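    // grid setup mirrors the existing ggml_cpy_q8_0_f32_cuda path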
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q4_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q4_1_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -556,6 +580,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
@@ -598,6 +624,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q4_0_f32, QK4_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {