@@ -503,3 +503,165 @@ kernel void kernel_cpy_f32_f32(
         dst_data[i00] = src[0];
     }
 }
+
+// ============================================ k-quants ======================================================
+
+#define QK_K 256
+
+typedef struct {
+    half d;                    // super-block scale for quantized scales
+    half dmin;                 // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+    uint8_t qs[QK_K/2];        // 4-bit quants
+} block_q4_k;
+
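+// Each super-block of QK_K = 256 weights holds 8 blocks of 32, each with its own
+// 6-bit scale and 6-bit min packed into the 12 'scales' bytes: the low 6 bits of
+// bytes 0-3 are scales 0-3 and of bytes 4-7 are mins 0-3; scales/mins 4-7 combine
+// a low nibble from bytes 8-11 with the top 2 bits of bytes 0-7. For an even block
+// index j, this helper unpacks (scale, min) for blocks j and j+1 into a uchar4 as
+// (sc_j, min_j, sc_{j+1}, min_{j+1}).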
+static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
+    uchar4 r;
+    if (j < 4) {
+        r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
+        r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
+    } else {
+        r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        r[1] = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+        r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
+        r[3] = (q[j+5] >>  4) | ((q[j+1] >> 6) << 4);
+    }
+    return r;
+}
+
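+// Expand k quantized weights (k must be a multiple of QK_K) to floats. Within a
+// super-block, every 64-weight chunk uses two (scale, min) pairs: the low nibbles
+// of 32 quant bytes are decoded with (d1, m1) and the high nibbles with (d2, m2),
+// reconstructing each weight as w = d*sc*q - dmin*m.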
+static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d   = x[i].d;
+        const float min = x[i].dmin;
+
+        device const uint8_t * q      = x[i].qs;
+        device const uint8_t * scales = x[i].scales;
+
+        int is = 0;
+        for (int j = 0; j < QK_K; j += 64) {
+            const uchar4 sc = get_scale_min_k4(is, scales);
+            const float d1 = d * sc[0]; const float m1 = min * sc[1];
+            const float d2 = d * sc[2]; const float m2 = min * sc[3];
+            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >>  4) - m2;
+            q += 32; is += 2;
+        }
+
+    }
+}
+
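+// Row gather for q4_k: thread i reads a row index from src1 and dequantizes that
+// row of src0 into row i of dst.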
+kernel void kernel_get_rows_q4_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q4_k(
+            (device const block_q4_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
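+// Dot product of a q4_k row of src0 with a float row of src1. Threadgroup (r0, r1)
+// computes the single output element dst[r1*ne0 + r0]: tpitg.x strides over the nb
+// super-blocks of the row, while the 16 tpitg.y lanes split each super-block
+// (il = tid/4 selects one of four 64-weight chunks, ir = tid%4 an 8-byte slice of
+// its quants). Per-thread partial sums are reduced in threadgroup memory below.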
+kernel void kernel_mul_mat_q4_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpig[[thread_position_in_grid]],        // we don't use this for now
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2 tptg[[threads_per_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
+    device const float     * yy = (device const float     *) src1 + r1*ne10;
+
+    const uint nth = tptg.x*tptg.y;
+    const uint ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;   // 0...15
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid%4;     // 0...3
+    const int n   = 8;
+    const int is  = 2*il;
+
+    sum[ith] = 0.0f;
+
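+    // Per super-block: sumf += dall*(sc0*dot(y_lo, q_lo) + sc2*dot(y_hi, q_hi))
+    //                        - dmin*(sc1*sum(y_lo) + sc3*sum(y_hi)).
+    // Tracking the raw sums of y lets the block minimum be subtracted once per
+    // 32-weight block instead of once per weight.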
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
+        device const float   * y = yy + i*QK_K + 64*il + n*ir;
+        device const uint8_t * scales = (x + i)->scales;
+
+        const float dall = (float)((x + i)->d);
+        const float dmin = (float)((x + i)->dmin);
+
+        const uchar4 sc = get_scale_min_k4(is, scales);
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < n; ++l) {
+            s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0];
+            s[2] += y[l+32] * (q[l] >>  4); s[3] += y[l+32];
+        }
+        sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);
+
+    }
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    // This version is slightly faster than the commented-out one below,
+    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+    // // accumulate the sum from all threads in the threadgroup
+    // threadgroup_barrier(mem_flags::mem_threadgroup);
+    // for (uint i = nth/2; i > 0; i /= 2) {
+    //     if (ith < i) {
+    //         sum[ith] += sum[ith + i];
+    //     }
+    //     threadgroup_barrier(mem_flags::mem_threadgroup);
+    // }
+
+    // if (ith == 0) {
+    //     dst[r1*ne0 + r0] = sum[0];
+    // }
+}