
Commit af005ce

CUDA kernel for ggml_mul, norms in VRAM
1 parent 9ca9b35 commit af005ce

4 files changed: +100 −4 lines

ggml-cuda.cu: +89 −4
@@ -83,9 +83,19 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 
+#define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 #define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
 
+static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= kx) {
+        return;
+    }
+    dst[i] = x[i] * y[i%ky];
+}
+
 static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
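Note: the `i%ky` index is what implements the broadcast: `y` is reused cyclically, so a `y` of length `ky` (e.g. 1D norm weights) multiplies every row of a flattened `x`. A plain C reference of the same semantics, for illustration only (names mirror the kernel above):

    // Host-side reference of mul_f32's broadcast semantics (illustration only).
    // With kx = 8, ky = 4: dst[0..3] = x[0..3]*y[0..3], dst[4..7] = x[4..7]*y[0..3].
    static void mul_f32_ref(const float * x, const float * y, float * dst,
                            const int kx, const int ky) {
        for (int i = 0; i < kx; ++i) {
            dst[i] = x[i] * y[i % ky];
        }
    }
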
@@ -228,6 +238,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
     }
 }
 
+static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
+    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
+}
+
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
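
Note: the launcher sizes the grid with a ceiling division so every element is covered, and the `i >= kx` guard in the kernel masks the overhang of the last block. Worked example with the 256-thread blocks used here:

    // Enough CUDA_MUL_BLOCK_SIZE-thread blocks to cover kx elements:
    // kx = 4096: (4096 + 255) / 256 = 16 blocks (exact fit)
    // kx = 4097: (4097 + 255) / 256 = 17 blocks (last block almost fully masked)
    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;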
@@ -467,6 +482,67 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
     }
 }
 
+static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+    size_t x_size, d_size;
+
+    float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
+    float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
+    float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            const int i0 = i03*ne02 + i02;
+            float * c_X2 = d_X + i0*ne01*ne00;
+            float * c_D2 = d_D + i0*ne01*ne00;
+
+            cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaEvent_t cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];
+
+            // copy src0 to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
+            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+            // wait for data
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                const int64_t i13 = i03%ne13;
+                const int64_t i12 = i02%ne12;
+                const int64_t i11 = i01%ne11;
+                const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
+
+                float * c_X1 = c_X2 + i01*ne00;
+                float * c_Y = d_Y + i1*ne10;
+                float * c_D1 = c_D2 + i01*ne00;
+
+                // compute
+                mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
+        }
+    }
+    CUDA_CHECK(cudaDeviceSynchronize());
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_D, d_size);
+}
+
 static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
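
Note: each (i03, i02) slice is round-robined over the stream pool so uploads and kernels for different slices can overlap; the recorded event is the handshake that keeps each slice's kernel behind its own upload. A minimal standalone sketch of that handshake, assuming a single stream pair, device buffers d_x/d_y/d_dst of n floats, and a host buffer h_x (pool and tensor helpers omitted):

    // copy on one stream -> record event -> compute stream waits -> launch kernel
    cudaStream_t copyStream, computeStream;
    cudaEvent_t  uploaded;
    CUDA_CHECK(cudaStreamCreate(&copyStream));
    CUDA_CHECK(cudaStreamCreate(&computeStream));
    CUDA_CHECK(cudaEventCreate(&uploaded));

    CUDA_CHECK(cudaMemcpyAsync(d_x, h_x, n*sizeof(float), cudaMemcpyHostToDevice, copyStream));
    CUDA_CHECK(cudaEventRecord(uploaded, copyStream));           // upload finished here
    CUDA_CHECK(cudaStreamWaitEvent(computeStream, uploaded, 0)); // gate the kernel on it
    mul_f32<<<(n + 255)/256, 256, 0, computeStream>>>(d_x, d_y, d_dst, n, n);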
@@ -724,6 +800,11 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     ggml_cuda_pool_free(d_Q, q_size);
 }
 
+void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_mul_f32(src0, src1, dst);
+}
+
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
     const int64_t ne10 = src1->ne[0];
 
@@ -797,14 +878,18 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
     const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
 
     size_t q_size;
-    char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+    char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
 
     cudaStream_t cudaStream2 = g_cudaStreams2[0];
 
     // copy tensor to device
-    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
-    CUDA_CHECK(cudaDeviceSynchronize());
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            int i = i3*ne2 + i2;
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2));
+        }
+    }
 
-    tensor->data = d_Q;
+    tensor->data = dst;
     tensor->backend = GGML_BACKEND_CUDA;
 }

ggml-cuda.h: +1

@@ -6,6 +6,7 @@ extern "C" {
 
 void ggml_init_cublas(void);
 
+void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);

ggml.c: +8

@@ -7961,6 +7961,14 @@ static void ggml_compute_forward_mul_f32(
     }
     const int ith = params->ith;
     const int nth = params->nth;
+#ifdef GGML_USE_CUBLAS
+    if (src1->backend == GGML_BACKEND_CUDA) {
+        if (ith == 0) {
+            ggml_cuda_mul(src0, src1, dst);
+        }
+        return;
+    }
+#endif
 
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];

llama.cpp: +2

@@ -1038,10 +1038,12 @@ static void llama_model_load_internal(
         for (int i = 0; i < n_gpu; ++i) {
             const auto & layer = model.layers[i];
 
+            ggml_cuda_transform_tensor(layer.attention_norm); vram_total += ggml_nbytes(layer.attention_norm);
             ggml_cuda_transform_tensor(layer.wq); vram_total += ggml_nbytes(layer.wq);
             ggml_cuda_transform_tensor(layer.wk); vram_total += ggml_nbytes(layer.wk);
             ggml_cuda_transform_tensor(layer.wv); vram_total += ggml_nbytes(layer.wv);
             ggml_cuda_transform_tensor(layer.wo); vram_total += ggml_nbytes(layer.wo);
+            ggml_cuda_transform_tensor(layer.ffn_norm); vram_total += ggml_nbytes(layer.ffn_norm);
             ggml_cuda_transform_tensor(layer.w1); vram_total += ggml_nbytes(layer.w1);
             ggml_cuda_transform_tensor(layer.w2); vram_total += ggml_nbytes(layer.w2);
             ggml_cuda_transform_tensor(layer.w3); vram_total += ggml_nbytes(layer.w3);
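
Note: attention_norm and ffn_norm are 1D f32 tensors of length n_embd, so keeping them in VRAM only pays off because the broadcast mul above can apply them on the GPU. A rough sketch of the shapes involved (graph-building calls and names are illustrative, not the exact llama.cpp graph):

    // x: [n_embd, n_tokens] f32 activations; attention_norm: [n_embd], now in VRAM.
    // Because src1's backend is GGML_BACKEND_CUDA, ggml_compute_forward_mul_f32
    // dispatches to ggml_cuda_mul, and mul_f32 broadcasts the weights row-wise via i%ky.
    struct ggml_tensor * normed = ggml_rms_norm(ctx0, x);
    struct ggml_tensor * scaled = ggml_mul(ctx0, normed, model.layers[il].attention_norm);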
