@@ -83,9 +83,19 @@ typedef struct {
83
83
} block_q8_0;
84
84
static_assert (sizeof (block_q8_0) == sizeof(float ) + QK8_0, "wrong q8_0 block size/padding");
85
85
86
+ #define CUDA_MUL_BLOCK_SIZE 256
86
87
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
87
88
#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
88
89
90
+ static __global__ void mul_f32 (const float * x, const float * y, float * dst, const int kx, const int ky) {
91
+ const int i = blockDim .x *blockIdx .x + threadIdx .x ;
92
+
93
+ if (i >= kx) {
94
+ return ;
95
+ }
96
+ dst[i] = x[i] * y[i%ky];
97
+ }
98
+
89
99
static __device__ void dequantize_q4_0 (const void * vx, const int ib, const int iqs, float & v0, float & v1){
90
100
const block_q4_0 * x = (const block_q4_0 *) vx;
91
101
@@ -228,6 +238,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
228
238
}
229
239
}
230
240
241
+ static void mul_f32_cuda (const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
242
+ const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1 ) / CUDA_MUL_BLOCK_SIZE;
243
+ mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0 , stream>>> (x, y, dst, kx, ky);
244
+ }
245
+
231
246
static void dequantize_row_q4_0_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
232
247
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE;
233
248
dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0 , stream>>> (vx, y, k);
@@ -467,6 +482,67 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
467
482
}
468
483
}
469
484
485
+ static void ggml_cuda_mul_f32 (const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
486
+ GGML_ASSERT (src1->backend == GGML_BACKEND_CUDA);
487
+ const int64_t ne00 = src0->ne [0 ];
488
+ const int64_t ne01 = src0->ne [1 ];
489
+ const int64_t ne02 = src0->ne [2 ];
490
+ const int64_t ne03 = src0->ne [2 ];
491
+ const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
492
+ const int64_t ne10 = src1->ne [0 ];
493
+ const int64_t ne11 = src1->ne [1 ];
494
+ const int64_t ne12 = src1->ne [2 ];
495
+ const int64_t ne13 = src1->ne [3 ];
496
+ const int nb2 = dst->nb [2 ];
497
+ const int nb3 = dst->nb [3 ];
498
+ size_t x_size, d_size;
499
+
500
+ float * d_X = (float *) ggml_cuda_pool_malloc (ne0 * sizeof (float ), &x_size); // src0
501
+ float * d_Y = (float *) src1->data ; // src1 is already on device, broadcasted.
502
+ float * d_D = (float *) ggml_cuda_pool_malloc (ne0 * sizeof (float ), &d_size); // dst
503
+
504
+ for (int64_t i03 = 0 ; i03 < ne03; i03++) {
505
+ for (int64_t i02 = 0 ; i02 < ne02; i02++) {
506
+ const int i0 = i03*ne02 + i02;
507
+ float * c_X2 = d_X + i0*ne01*ne00;
508
+ float * c_D2 = d_D + i0*ne01*ne00;
509
+
510
+ cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
511
+ cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
512
+ cudaEvent_t cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];
513
+
514
+ // copy src0 to device
515
+ CUDA_CHECK (ggml_cuda_h2d_tensor_2d (c_X2, src0, i03, i02, cudaStream2));
516
+ CUDA_CHECK (cudaEventRecord (cudaEvent, cudaStream2));
517
+
518
+ // wait for data
519
+ CUDA_CHECK (cudaStreamWaitEvent (cudaStream, cudaEvent, 0 ));
520
+
521
+ for (int64_t i01 = 0 ; i01 < ne01; i01++) {
522
+ const int64_t i13 = i03%ne13;
523
+ const int64_t i12 = i02%ne12;
524
+ const int64_t i11 = i01%ne11;
525
+ const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
526
+
527
+ float * c_X1 = c_X2 + i01*ne00;
528
+ float * c_Y = d_Y + i1*ne10;
529
+ float * c_D1 = c_D2 + i01*ne00;
530
+
531
+ // compute
532
+ mul_f32_cuda (c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
533
+ CUDA_CHECK (cudaGetLastError ());
534
+ }
535
+
536
+ // copy dst to host
537
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
538
+ CUDA_CHECK (cudaMemcpyAsync (d, c_D2, sizeof (float )*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
539
+ }
540
+ }
541
+ CUDA_CHECK (cudaDeviceSynchronize ());
542
+ ggml_cuda_pool_free (d_X, x_size);
543
+ ggml_cuda_pool_free (d_D, d_size);
544
+ }
545
+
470
546
static void ggml_cuda_mul_mat_f32 (const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
471
547
const int64_t ne00 = src0->ne [0 ];
472
548
const int64_t ne01 = src0->ne [1 ];
@@ -724,6 +800,11 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
724
800
ggml_cuda_pool_free (d_Q, q_size);
725
801
}
726
802
803
+ void ggml_cuda_mul (const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
804
+ GGML_ASSERT (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
805
+ ggml_cuda_mul_f32 (src0, src1, dst);
806
+ }
807
+
727
808
bool ggml_cuda_can_mul_mat (const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
728
809
const int64_t ne10 = src1->ne [0 ];
729
810
@@ -797,14 +878,18 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
797
878
const size_t q_sz = ggml_type_size (type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size (type);
798
879
799
880
size_t q_size;
800
- char * d_Q = (char *) ggml_cuda_pool_malloc (q_sz, &q_size);
881
+ char * dst = (char *) ggml_cuda_pool_malloc (q_sz, &q_size);
801
882
802
883
cudaStream_t cudaStream2 = g_cudaStreams2[0 ];
803
884
804
885
// copy tensor to device
805
- CUDA_CHECK (ggml_cuda_h2d_tensor_2d (d_Q, tensor, 0 , 0 , cudaStream2));
806
- CUDA_CHECK (cudaDeviceSynchronize ());
886
+ for (int64_t i3 = 0 ; i3 < ne3; i3++) {
887
+ for (int64_t i2 = 0 ; i2 < ne2; i2++) {
888
+ int i = i3*ne2 + i2;
889
+ CUDA_CHECK (ggml_cuda_h2d_tensor_2d (dst + i*ne0*ne1, tensor, i3, i2, cudaStream2));
890
+ }
891
+ }
807
892
808
- tensor->data = d_Q ;
893
+ tensor->data = dst ;
809
894
tensor->backend = GGML_BACKEND_CUDA;
810
895
}
0 commit comments