@@ -29,7 +29,7 @@ __pack_half2(const half x, const half y) {
 
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
   assert(false);
 #else
   static constexpr uint32_t ZERO = 0x0;
@@ -191,6 +191,39 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
         }
       }
       for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+#else
         {
           __asm__ __volatile__(
             "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
@@ -206,6 +239,8 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
             : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
         }
+
+#endif
       }
     }
   }
@@ -226,7 +261,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
 
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
   assert(false);
 #else
   static constexpr uint32_t ZERO = 0x0;
@@ -392,7 +427,39 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
 
       for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
       {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+        }
 
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
+            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+            : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+#else
         {
           __asm__ __volatile__(
             "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
@@ -408,6 +475,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
             : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
             : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
         }
+#endif
       }
     }
   }
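
Taken together, the patch does two things: it relaxes the compile-time guard in both kernels from __CUDA_ARCH__ < 800 to __CUDA_ARCH__ < 750, so Turing (sm_75) no longer hits assert(false), and on exactly sm_75 it replaces each m16n8k16 tensor-core MMA (available only on sm_80 and newer for this f16 shape) with two m16n8k8 MMAs. The two smaller MMAs split the K=16 step into two K=8 halves, feeding A registers 0..1 with B register 0 and then A registers 2..3 with B register 1, while accumulating into the same C fragment. Below is a minimal sketch of that decomposition; the helper name and array-based signature are illustrative and do not appear in the patch.

// Minimal sketch, not code from the patch: mma_m16n8k16_f16f32 is a
// hypothetical helper. a holds the four .f16x2 registers of a 16x16 A
// fragment, b the two registers of a 16x8 B fragment, and c the four f32
// accumulators of the 16x8 output tile.
__device__ __forceinline__ void mma_m16n8k16_f16f32(const unsigned a[4],
                                                    const unsigned b[2],
                                                    float c[4]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
  // Turing path, K = 0..7 slice: first half of the A fragment with the
  // first B register.
  __asm__ __volatile__(
      "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
      "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
      : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(b[0]),
        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
  // K = 8..15 slice: second half of A with the second B register,
  // accumulating into the same four f32 outputs.
  __asm__ __volatile__(
      "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
      "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
      : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
      : "r"(a[2]), "r"(a[3]), "r"(b[1]),
        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#else
  // sm_80 and newer: a single m16n8k16 MMA covers the whole K = 16 step.
  __asm__ __volatile__(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
      : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
#endif
}

The == 750 test keeps every newer architecture on the single-instruction path, and the kernels' own < 750 guard still routes Volta and older to assert(false), so only Turing takes the split path.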