
Commit 7c12ec3

llama: fix KV cache quants q4_0 and q8_0 causing server abort in large-context chat (ggml-org#8073)
Credit: @mengkin (this commit covers only the CUDA q4_0 path).
1 parent 52dcabb commit 7c12ec3
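
For context: when the KV cache is stored quantized (llama.cpp's --cache-type-k / --cache-type-v options), long-context operation eventually needs to copy quantized cache cells back to F32 on the GPU, for example when the cache is shifted. The CUDA backend already had a Q8_0 → F32 copy kernel but no Q4_0 → F32 one, so that copy had nothing to dispatch to and the server aborted, as reported in ggml-org#8073. The new kernel reads ggml's q4_0 block layout, sketched here for reference (field names follow ggml-common.h; the snippet is illustrative and not part of the commit):

    // Sketch of the q4_0 block layout read by the new copy kernel (mirrors
    // ggml-common.h; ggml_half is spelled as CUDA's half type here).
    #include <cuda_fp16.h>
    #include <cstdint>

    #define QK4_0 32                      // weights per q4_0 block

    typedef struct {
        half    d;                        // per-block FP16 scale
        uint8_t qs[QK4_0 / 2];            // 32 packed 4-bit quants, two per byte
    } block_q4_0;
    // a stored nibble q in [0, 15] decodes to (q - 8) * d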

1 file changed: +28 −0 lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 28 additions & 0 deletions
@@ -131,6 +131,20 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
     }
 }
 
+static __device__ void cpy_blck_q4_0_f32(const char * cxi, char * cdsti) {
+    const block_q4_0 * xi = (const block_q4_0 *) cxi;
+    float * dsti = (float *) cdsti;
+
+    const float d = (float)xi->d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = (xi->qs[j] & 0x0F) - 8;
+        const float x1 = (xi->qs[j] >> 4) - 8;
+        dsti[j + 0] = x0 * d;
+        dsti[j + QK4_0/2] = x1 * d;
+    }
+}
+
 static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     block_q4_1 * dsti = (block_q4_1 *) cdsti;
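
For anyone wanting to sanity-check the device routine above, here is a hedged host-side sketch of the same q4_0 → F32 unpacking. It is not part of the commit: BlockQ4_0 keeps the scale as a plain float instead of FP16, and the test pattern is arbitrary.

    // Host-side reference for the q4_0 -> f32 conversion added above.
    #include <cstdint>
    #include <cstdio>

    constexpr int kQK4_0 = 32;                 // elements per q4_0 block

    struct BlockQ4_0 {
        float   d;                             // per-block scale (FP16 in ggml)
        uint8_t qs[kQK4_0 / 2];                // two 4-bit quants per byte
    };

    static void dequantize_block_q4_0_ref(const BlockQ4_0 & b, float * dst) {
        for (int j = 0; j < kQK4_0 / 2; ++j) {
            const int x0 = (b.qs[j] & 0x0F) - 8;   // low nibble  -> first half
            const int x1 = (b.qs[j] >> 4)   - 8;   // high nibble -> second half
            dst[j]              = x0 * b.d;
            dst[j + kQK4_0 / 2] = x1 * b.d;
        }
    }

    int main() {
        BlockQ4_0 b{0.5f, {}};
        for (int j = 0; j < kQK4_0 / 2; ++j) {
            b.qs[j] = uint8_t((j % 16) | ((15 - j % 16) << 4));  // arbitrary test pattern
        }
        float out[kQK4_0];
        dequantize_block_q4_0_ref(b, out);
        // expect out[0] = (0 - 8) * 0.5 = -4.00 and out[16] = (15 - 8) * 0.5 = 3.50
        printf("out[0]=%.2f out[16]=%.2f\n", out[0], out[16]);
        return 0;
    }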
@@ -446,6 +460,16 @@ static void ggml_cpy_f32_q4_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q4_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q4_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
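
One small inconsistency worth noting: the launcher above instantiates cpy_q_f32 with QK8_0, while the function-pointer branch added later in this diff uses QK4_0. Both constants are defined as 32 in ggml, so the two instantiations are identical; QK4_0 would simply be the more descriptive choice. A standalone guard documenting that assumption might look like this (illustrative only, not part of the commit):

    // Illustrative guard; in ggml-common.h both block sizes are defined as 32,
    // which is why QK8_0 and QK4_0 are interchangeable as the template argument.
    #define QK4_0 32
    #define QK8_0 32
    static_assert(QK4_0 == QK8_0, "q4_0 and q8_0 copy kernels assume the same block size");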
@@ -556,6 +580,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
@@ -598,6 +624,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q4_0_f32, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
