@@ -130,9 +130,19 @@ typedef struct {
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");

+ #define WARP_SIZE 32
+
#define CUDA_MUL_BLOCK_SIZE 256
+
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
- #define CUDA_DMMV_BLOCK_SIZE 64 // dmmv = dequantize_mul_mat_vec
+
+ // dmmv = dequantize_mul_mat_vec
+ #ifndef GGML_CUDA_DMMV_X
+ #define GGML_CUDA_DMMV_X 32
+ #endif
+ #ifndef GGML_CUDA_DMMV_Y
+ #define GGML_CUDA_DMMV_Y 1
+ #endif

static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
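
// Editor's note, not part of the patch: a minimal sketch of what the tunables
// above imply, assuming the default values. Each thread block is one warp
// (WARP_SIZE threads) wide and GGML_CUDA_DMMV_Y rows tall, and per outer
// iteration a warp consumes 2*GGML_CUDA_DMMV_X quantized values, i.e.
// 2*32/32 = 2 values per thread with the defaults. These compile-time checks
// merely restate the divisibility the kernel below relies on.
static_assert(2*GGML_CUDA_DMMV_X % WARP_SIZE == 0,
              "each thread must process a whole number of values per iteration");
static_assert((2*GGML_CUDA_DMMV_X/WARP_SIZE) % 2 == 0,
              "vals_per_iter must be even: the inner loop handles 2 values per step");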
@@ -247,41 +257,51 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
    dequantize_kernel(vx, ib, iqs, v0, v1);
}

- template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
+ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
-     const int row = blockIdx.x;
+     // qk = quantized weights per x block
+     // qr = number of quantized weights per data value in x block
+     const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

+     const int iter_stride = 2*GGML_CUDA_DMMV_X;
+     const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
    const int y_offset = qr == 1 ? 1 : qk/2;

-     __shared__ float tmp[block_size]; // separate sum for each thread
-     tmp[tid] = 0;
+     float tmp = 0; // partial sum for thread in warp

-     for (int i = 0; i < ncols/block_size; i += 2) {
-         const int col = i*block_size + 2*tid;
-         const int ib = (row*ncols + col)/qk; // block index
-         const int iqs = (col%qk)/qr; // quant index
+     for (int i = 0; i < ncols; i += iter_stride) {
+         const int col = i + vals_per_iter*tid;
+         const int ib = (row*ncols + col)/qk; // x block index
+         const int iqs = (col%qk)/qr; // x quant index
        const int iybs = col - col%qk; // y block start index

-         // dequantize
-         float v0, v1;
-         dequantize_kernel(vx, ib, iqs, v0, v1);
+         // processing >2 values per i iter is faster for fast GPUs
+ #pragma unroll
+         for (int j = 0; j < vals_per_iter; j += 2) {
+             // process 2 vals per j iter
+
+             // dequantize
+             float v0, v1;
+             dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
+             // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val

-         // matrix multiplication
-         tmp[tid] += v0 * y[iybs + iqs + 0];
-         tmp[tid] += v1 * y[iybs + iqs + y_offset];
+             // matrix multiplication
+             tmp += v0 * y[iybs + iqs + j/qr + 0];
+             tmp += v1 * y[iybs + iqs + j/qr + y_offset];
+             // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+         }
    }

    // sum up partial sums and write back result
    __syncthreads();
-     for (int s=block_size/2; s>0; s>>=1) {
-         if (tid < s) {
-             tmp[tid] += tmp[tid + s];
-         }
-         __syncthreads();
+ #pragma unroll
+     for (int mask = 16; mask > 0; mask >>= 1) {
+         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }
+
    if (tid == 0) {
-         dst[row] = tmp[0];
+         dst[row] = tmp;
    }
}
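
// Editor's sketch, standalone and not part of the patch: the butterfly
// reduction that replaces the shared-memory tree above. Each __shfl_xor_sync
// halves the number of distinct partial sums; after five steps every lane,
// including lane 0 which writes dst[row] in the real kernel, holds the sum of
// all 32 lanes. Compiles with nvcc as-is; all names here are illustrative.
#include <cstdio>

__global__ void warp_sum(const float * x, float * out) {
    float tmp = x[threadIdx.x]; // one partial sum per lane, like tmp above

#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    if (threadIdx.x == 0) {
        *out = tmp; // corresponds to dst[row] = tmp;
    }
}

int main() {
    float h_x[32], h_out = 0.0f, *d_x, *d_out;
    for (int i = 0; i < 32; ++i) {
        h_x[i] = float(i); // 0 + 1 + ... + 31 = 496
    }
    cudaMalloc(&d_x, sizeof(h_x));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);
    warp_sum<<<1, 32>>>(d_x, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %.0f\n", h_out); // prints 496
    cudaFree(d_x);
    cudaFree(d_out);
    return 0;
}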
@@ -316,33 +336,43 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
}

static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
-         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+     dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+         <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
-         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+     dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+         <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
-         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+     dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+         <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
-         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+     dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+         <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
-         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+     dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+         <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
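
// Editor's worked example, host-side and not part of the patch: the index math
// used by the q4_0 wrapper above, with qk = QK4_0 = 32 and qr = QR4_0 = 2.
// Two 4-bit weights share one byte, so iqs steps over 16 bytes per block while
// the paired y elements sit y_offset = qk/2 = 16 apart; ncols = 64 is a
// made-up sample size.
#include <cstdio>

int main() {
    const int qk = 32, qr = 2, ncols = 64, row = 0;
    const int y_offset = qr == 1 ? 1 : qk/2; // 16 for q4_0
    for (int col = 0; col < ncols; col += 8) { // sample every 8th column
        const int ib   = (row*ncols + col)/qk; // x block index
        const int iqs  = (col%qk)/qr;          // x quant (byte) index
        const int iybs = col - col%qk;         // y block start index
        printf("col=%2d -> ib=%d iqs=%2d pairs with y[%2d] and y[%2d]\n",
               col, ib, iqs, iybs + iqs, iybs + iqs + y_offset);
    }
    return 0;
}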
@@ -351,9 +381,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
}

static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
-         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+     GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+     dequantize_mul_mat_vec<1, 1, convert_f16>
+         <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
}

static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
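
// Editor's sketch, an assumption since convert_f16 is outside this diff: with
// template arguments qk = qr = 1 in the f16 wrapper above, every "block" is a
// single half, ib addresses the element directly, and y_offset = 1, so a
// dequantize_kernel_t for f16 plausibly reads two consecutive half values:
#include <cuda_fp16.h>

static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1) {
    const half * x = (const half *) vx;
    v0 = __half2float(x[ib + iqs + 0]); // consecutive halves; matches
    v1 = __half2float(x[ib + iqs + 1]); // y_offset = 1 when qr == 1
}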