Commit 7538246

cuda : add f32 to bf16 copy op (#12806)
This allows BF16 KV-cache on CUDA.
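As a usage sketch (illustrative, not part of the commit): with a CUDA build of llama.cpp, a BF16 KV cache is selected through the existing cache-type flags, e.g. llama-cli -m model.gguf --cache-type-k bf16 --cache-type-v bf16, assuming the build lists bf16 among the supported cache types (check the --help output).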
1 parent b32efad

File tree

2 files changed: +24 −0 lines changed

ggml/src/ggml-cuda/cpy.cu (+21)
@@ -10,6 +10,13 @@ static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
     *dsti = *xi;
 }
 
+static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
+
+    *dsti = *xi;
+}
+
 static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     half * dsti = (half *) cdsti;
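The new device function leans on the implicit float-to-nv_bfloat16 conversion from cuda_bf16.h: the assignment *dsti = *xi; invokes the nv_bfloat16(float) converting constructor, which rounds to nearest-even and keeps f32's full exponent range while shrinking the significand to 8 bits. A minimal standalone sketch of that conversion, assuming CUDA 11+ (illustrative, not part of the commit; the kernel name f32_to_bf16 is made up):

#include <cstdio>
#include <cuda_bf16.h>

// Each thread converts one float to bf16; the assignment uses the
// nv_bfloat16(float) converting constructor (round-to-nearest-even).
__global__ void f32_to_bf16(const float * x, nv_bfloat16 * dst, int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = x[i];
    }
}

int main() {
    const int n = 4;
    const float hx[4] = {1.0f, 3.14159f, 1e30f, -2.5f}; // 1e30 fits: bf16 has f32's exponent range
    float * dx;
    nv_bfloat16 * dy;
    cudaMalloc(&dx, n*sizeof(float));
    cudaMalloc(&dy, n*sizeof(nv_bfloat16));
    cudaMemcpy(dx, hx, n*sizeof(float), cudaMemcpyHostToDevice);
    f32_to_bf16<<<1, n>>>(dx, dy, n);
    nv_bfloat16 hy[4];
    cudaMemcpy(hy, dy, n*sizeof(nv_bfloat16), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) {
        // bf16 keeps roughly 2-3 significant decimal digits
        printf("%g -> %g\n", hx[i], __bfloat162float(hy[i]));
    }
    cudaFree(dx);
    cudaFree(dy);
    return 0;
}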
@@ -386,6 +393,16 @@ static void ggml_cpy_f32_f32_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
+static void ggml_cpy_f32_bf16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+}
+
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
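The launcher above uses the standard ceil-division grid sizing, num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE, so every element gets a thread even when ne is not a multiple of the block size. As a worked example (256 is an assumed block size for illustration; the real CUDA_CPY_BLOCK_SIZE is defined in ggml's CUDA headers): ne = 1000 gives num_blocks = (1000 + 255) / 256 = 4, i.e. 1024 threads, and the 24 surplus threads are masked off by a bounds check against ne inside the kernel.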
@@ -581,6 +598,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
@@ -634,6 +653,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return nullptr;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {

ggml/src/ggml-cuda/ggml-cuda.cu (+3)
@@ -3079,6 +3079,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
                     return true;
                 }
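Broadly speaking (an explanatory note, not part of the commit): ggml's scheduler queries the backend's supports_op hook before placing an operation on a device, so this check is what allows the F32-to-BF16 GGML_OP_CPY to run on the CUDA backend at all; without it the copy would not be offloaded and the BF16 KV cache could not be kept on the GPU.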
