@@ -18,6 +18,21 @@ typedef struct {
18
18
uint8_t qs[QK4_1 / 2 ]; // nibbles / quants
19
19
} block_q4_1;
20
20
21
#define QK5_0 32
// 5-bit quantized block: 32 quants share one fp16 scale d; each quant is a
// 4-bit low nibble in qs plus one high bit packed into qh.
typedef struct {
    half d;                 // delta (scale)
    uint8_t qh[4];          // 5-th bit of quants (read elsewhere as one uint32_t)
    uint8_t qs[QK5_0 / 2];  // nibbles / quants (two 4-bit values per byte)
} block_q5_0;
27
+
28
#define QK5_1 32
// 5-bit quantized block with explicit minimum: 32 quants share fp16 scale d
// and fp16 min m; each quant is a 4-bit low nibble in qs plus one bit in qh.
typedef struct {
    half d;                 // delta (scale)
    half m;                 // min
    uint8_t qh[4];          // 5-th bit of quants (read elsewhere as one uint32_t)
    uint8_t qs[QK5_1 / 2];  // nibbles / quants (two 4-bit values per byte)
} block_q5_1;
35
+
21
36
#define QK8_0 32
22
37
typedef struct {
23
38
half d; // delta
@@ -399,8 +414,11 @@ kernel void kernel_rms_norm(
399
414
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
400
415
inline float block_q_n_dot_y (device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
401
416
float d = qb_curr->d ;
417
+
402
418
float2 acc = 0 .f ;
419
+
403
420
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2 );
421
+
404
422
for (int i = 0 ; i < 8 ; i+=2 ) {
405
423
acc[0 ] += yl[i + 0 ] * (qs[i / 2 ] & 0x000F )
406
424
+ yl[i + 1 ] * (qs[i / 2 ] & 0x0F00 );
@@ -417,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
417
435
inline float block_q_n_dot_y (device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
418
436
float d = qb_curr->d ;
419
437
float m = qb_curr->m ;
420
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/ 2 );
438
+
421
439
float2 acc = 0 .f ;
440
+
441
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2 );
442
+
422
443
for (int i = 0 ; i < 8 ; i+=2 ) {
423
444
acc[0 ] += yl[i + 0 ] * (qs[i / 2 ] & 0x000F )
424
445
+ yl[i + 1 ] * (qs[i / 2 ] & 0x0F00 );
@@ -428,6 +449,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
428
449
return d * (acc[0 ] + acc[1 ]) + sumy * m;
429
450
}
430
451
452
// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q5 quants begin (0 or QK5_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;

    float2 acc = 0.f;

    // quants start 3 uint16s into the block: d (1 x uint16) + qh[4] (2 x uint16)
    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);

    for (int i = 0; i < 8; i+=2) {
        // merge each 4-bit nibble from qs with its 5th bit from qh, keeping the
        // quant in its byte-shifted lane to match the pre-scaled yl values
        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
    }
    // q5_0 stores quants biased by +16; subtract 16 per element via sumy
    return d * (sumy * -16.f + acc[0] + acc[1]);
}
472
+
473
// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q5 quants begin (0 or QK5_1/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;
    float m = qb_curr->m;

    float2 acc = 0.f;

    // quants start 4 uint16s into the block: d + m (2 x uint16) + qh[4] (2 x uint16)
    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);

    for (int i = 0; i < 8; i+=2) {
        // merge each 4-bit nibble from qs with its 5th bit from qh, keeping the
        // quant in its byte-shifted lane to match the pre-scaled yl values
        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
        // fixed: use QK5_1 (not QK5_0) for the upper-half offset — the values
        // are equal (32), so behavior is unchanged, but the constant now
        // matches the block type this routine operates on
        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_1/2) << 8 ) & 0x00100))
                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_1/2) << 16) & 0x10000));
    }
    // q5_1 quants are unsigned; the explicit minimum m is added once per element via sumy
    return d * (acc[0] + acc[1]) + sumy * m;
}
494
+
431
495
// putting them in the kernel cause a significant performance penalty
432
496
#define N_DST 4 // each SIMD group works on 4 rows
433
497
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
@@ -525,6 +589,43 @@ kernel void kernel_mul_mv_q4_1_f32(
525
589
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
526
590
}
527
591
592
// mat-vec multiply entry point for q5_0 src0 with f32 src1: thin wrapper that
// forwards all arguments to the shared mul_vec_q_n_f32 template specialized
// for block_q5_0
kernel void kernel_mul_mv_q5_0_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
609
+
610
// mat-vec multiply entry point for q5_1 src0 with f32 src1: thin wrapper that
// forwards all arguments to the shared mul_vec_q_n_f32 template specialized
// for block_q5_1
kernel void kernel_mul_mv_q5_1_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
627
+
628
+
528
629
#define NB_Q8_0 8
529
630
530
631
kernel void kernel_mul_mv_q8_0_f32 (
@@ -2149,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
2149
2250
}
2150
2251
}
2151
2252
2253
// dequantize half of a q5_0 block into a 4x4 register tile: value = d*q + md,
// where q in [0, 31] is rebuilt from a 4-bit nibble in qs plus its 5th bit in qh.
// il selects which nibble half of the block is processed (0: low, nonzero: high).
template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
    // quants start 3 uint16s into the block: d (1 x uint16) + qh[4] (2 x uint16)
    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
    const float d = xb->d;
    const float md = -16.h * xb->d;  // offset: q5_0 quants are stored biased by +16
    const ushort mask = il ? 0x00F0 : 0x000F;  // which nibble of each byte to keep

    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int x_mv = il ? 4 : 0;   // shift bringing the selected nibble down to bits 0..3

    const int gh_mv = il ? 12 : 0; // NOTE(review): base shift into qh for this half — assumes the quantizer packs the high half's 5th bits 12 positions up; confirm against the encoder
    const int gh_bk = il ? 0 : 4;  // shift placing the extracted high bit at position 4

    for (int i = 0; i < 8; i++) {
        // extract the 5-th bits for x0 and x1
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4-bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        reg[i/2][2*(i%2)+0] = d * x0 + md;
        reg[i/2][2*(i%2)+1] = d * x1 + md;
    }
}
2280
+
2281
// dequantize half of a q5_1 block into a 4x4 register tile: value = d*q + m,
// where q in [0, 31] is rebuilt from a 4-bit nibble in qs plus its 5th bit in qh
// and m is the block's explicit minimum.
// il selects which nibble half of the block is processed (0: low, nonzero: high).
template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
    // quants start 4 uint16s into the block: d + m (2 x uint16) + qh[4] (2 x uint16)
    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
    const float d = xb->d;
    const float m = xb->m;
    const ushort mask = il ? 0x00F0 : 0x000F;  // which nibble of each byte to keep

    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int x_mv = il ? 4 : 0;   // shift bringing the selected nibble down to bits 0..3

    const int gh_mv = il ? 12 : 0; // NOTE(review): base shift into qh for this half — assumes the quantizer packs the high half's 5th bits 12 positions up; confirm against the encoder
    const int gh_bk = il ? 0 : 4;  // shift placing the extracted high bit at position 4

    for (int i = 0; i < 8; i++) {
        // extract the 5-th bits for x0 and x1
        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;

        // combine the 4-bits from qs with the 5th bit
        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);

        reg[i/2][2*(i%2)+0] = d * x0 + m;
        reg[i/2][2*(i%2)+1] = d * x1 + m;
    }
}
2308
+
2152
2309
template <typename type4x4>
2153
2310
void dequantize_q8_0 (device const block_q8_0 *xb, short il, thread type4x4 & reg) {
2154
2311
device const int8_t * qs = ((device const int8_t *)xb->qs );
@@ -2490,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
2490
2647
template [[host_name(" kernel_get_rows_f16" )]] kernel get_rows_t kernel_get_rows<half4x4, 1 , dequantize_f16>;
2491
2648
template [[host_name(" kernel_get_rows_q4_0" )]] kernel get_rows_t kernel_get_rows<block_q4_0, 2 , dequantize_q4_0>;
2492
2649
template [[host_name(" kernel_get_rows_q4_1" )]] kernel get_rows_t kernel_get_rows<block_q4_1, 2 , dequantize_q4_1>;
2650
// explicit get_rows instantiations for the q5 block types (host looks these up by name)
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
2493
2652
template [[host_name(" kernel_get_rows_q8_0" )]] kernel get_rows_t kernel_get_rows<block_q8_0, 2 , dequantize_q8_0>;
2494
2653
template [[host_name(" kernel_get_rows_q2_K" )]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
2495
2654
template [[host_name(" kernel_get_rows_q3_K" )]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2518,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
2518
2677
template [[host_name(" kernel_mul_mm_f16_f32" )]] kernel mat_mm_t kernel_mul_mm<half4x4, 1 , dequantize_f16>;
2519
2678
template [[host_name(" kernel_mul_mm_q4_0_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2 , dequantize_q4_0>;
2520
2679
template [[host_name(" kernel_mul_mm_q4_1_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2 , dequantize_q4_1>;
2680
// explicit mat-mat multiply instantiations for the q5 block types (host looks these up by name)
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
2521
2682
template [[host_name(" kernel_mul_mm_q8_0_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2 , dequantize_q8_0>;
2522
2683
template [[host_name(" kernel_mul_mm_q2_K_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
2523
2684
template [[host_name(" kernel_mul_mm_q3_K_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
0 commit comments