@@ -527,6 +527,13 @@ kernel void kernel_cpy_f32_f32(
527
527
528
528
#define QK_K 256
529
529
530
+ typedef struct {
531
+ uint8_t scales[QK_K/16 ]; // scales and mins, quantized with 4 bits
532
+ uint8_t qs[QK_K/4 ]; // quants
533
+ half d; // super-block scale for quantized scales
534
+ half dmin; // super-block scale for quantized mins
535
+ } block_q2_k;
536
+
530
537
typedef struct {
531
538
half d; // super-block scale for quantized scales
532
539
half dmin; // super-block scale for quantized mins
@@ -555,6 +562,41 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
555
562
return r;
556
563
}
557
564
565
+ // ========================================== dequantization =============================
566
+
567
+ static void dequantize_row_q2_k (device const block_q2_k * x, device float * y, int k) {
568
+ assert (k % QK_K == 0 );
569
+ const int nb = k / QK_K;
570
+
571
+ for (int i = 0 ; i < nb; i++) {
572
+
573
+ const float d = x[i].d ;
574
+ const float min = x[i].dmin ;
575
+
576
+ device const uint8_t * q = x[i].qs ;
577
+
578
+ int is = 0 ;
579
+ float dl, ml;
580
+ for (int n = 0 ; n < QK_K; n += 128 ) {
581
+ int shift = 0 ;
582
+ for (int j = 0 ; j < 4 ; ++j) {
583
+
584
+ uint8_t sc = x[i].scales [is++];
585
+ dl = d * (sc & 0xF ); ml = min * (sc >> 4 );
586
+ for (int l = 0 ; l < 16 ; ++l) *y++ = dl * ((int8_t )((q[l] >> shift) & 3 )) - ml;
587
+
588
+ sc = x[i].scales [is++];
589
+ dl = d * (sc & 0xF ); ml = min * (sc >> 4 );
590
+ for (int l = 0 ; l < 16 ; ++l) *y++ = dl * ((int8_t )((q[l+16 ] >> shift) & 3 )) - ml;
591
+
592
+ shift += 2 ;
593
+ }
594
+ q += 32 ;
595
+ }
596
+
597
+ }
598
+ }
599
+
558
600
static void dequantize_row_q4_k (device const block_q4_k * x, device float * y, int k) {
559
601
assert (k % QK_K == 0 );
560
602
const int nb = k / QK_K;
@@ -586,12 +628,12 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
586
628
587
629
for (int i = 0 ; i < nb; i++) {
588
630
589
- const float d = x[i].d ;
590
-
591
631
device const uint8_t * ql = x[i].ql ;
592
632
device const uint8_t * qh = x[i].qh ;
593
633
device const int8_t * sc = x[i].scales ;
594
634
635
+ const float d = x[i].d ;
636
+
595
637
for (int n = 0 ; n < QK_K; n += 128 ) {
596
638
for (int l = 0 ; l < 32 ; ++l) {
597
639
int is = l/16 ;
@@ -612,6 +654,22 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
612
654
}
613
655
}
614
656
657
+ kernel void kernel_get_rows_q2_k (
658
+ device const void * src0,
659
+ device const int * src1,
660
+ device float * dst,
661
+ constant int64_t & ne00,
662
+ constant uint64_t & nb01,
663
+ constant uint64_t & nb1,
664
+ uint tpig[[thread_position_in_grid]]) {
665
+ const int i = tpig;
666
+ const int r = ((device int32_t *) src1)[i];
667
+
668
+ dequantize_row_q2_k (
669
+ (device const block_q2_k *) ((device char *) src0 + r*nb01),
670
+ (device float *) ((device char *) dst + i*nb1), ne00);
671
+ }
672
+
615
673
kernel void kernel_get_rows_q4_k (
616
674
device const void * src0,
617
675
device const int * src1,
@@ -628,6 +686,129 @@ kernel void kernel_get_rows_q4_k(
628
686
(device float *) ((device char *) dst + i*nb1), ne00);
629
687
}
630
688
689
+ kernel void kernel_get_rows_q6_k (
690
+ device const void * src0,
691
+ device const int * src1,
692
+ device float * dst,
693
+ constant int64_t & ne00,
694
+ constant uint64_t & nb01,
695
+ constant uint64_t & nb1,
696
+ uint tpig[[thread_position_in_grid]]) {
697
+ const int i = tpig;
698
+ const int r = ((device int32_t *) src1)[i];
699
+
700
+ dequantize_row_q6_k (
701
+ (device const block_q6_k *) ((device char *) src0 + r*nb01),
702
+ (device float *) ((device char *) dst + i*nb1), ne00);
703
+ }
704
+
705
+ // ====================================== dot products =========================
706
+
707
+ kernel void kernel_mul_mat_q2_k_f32 (
708
+ device const void * src0,
709
+ device const float * src1,
710
+ device float * dst,
711
+ constant int64_t & ne00,
712
+ constant int64_t & ne01,
713
+ constant uint64_t & nb00,
714
+ constant uint64_t & nb01,
715
+ constant uint64_t & nb02,
716
+ constant int64_t & ne10,
717
+ constant int64_t & ne11,
718
+ constant uint64_t & nb10,
719
+ constant uint64_t & nb11,
720
+ constant uint64_t & nb12,
721
+ constant int64_t & ne0,
722
+ constant int64_t & ne1,
723
+ threadgroup float * sum [[threadgroup(0 )]],
724
+ uint2 tgpig[[threadgroup_position_in_grid]],
725
+ uint2 tpig[[thread_position_in_grid]], // we don't use this for now
726
+ uint2 tpitg[[thread_position_in_threadgroup]],
727
+ uint2 tptg[[threads_per_threadgroup]]) {
728
+
729
+ const int nb = ne00/QK_K;
730
+
731
+ const int64_t r0 = tgpig.x ;
732
+ const int64_t r1 = tgpig.y ;
733
+
734
+ device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
735
+ device const float * yy = (device const float *) src1 + r1*ne10;
736
+
737
+ const int nth = tptg.x *tptg.y ;
738
+ const int ith = tptg.y *tpitg.x + tpitg.y ;
739
+
740
+
741
+ const int tid = tpitg.y ; // 0...16
742
+ const int il = tid/4 ; // 0...3
743
+ const int ir = tid%4 ; // 0...3
744
+ const int ip = il/2 ; // 0 or 1
745
+ const int shift1 = 4 *(il%2 );// 0 or 4
746
+ const int shift2 = shift1+2 ;// 2 or 6
747
+ const int n = 8 ;
748
+ const int is = 4 *il + (n*ir)/16 ;
749
+
750
+ sum[ith] = 0 .0f ;
751
+
752
+ float sumf = 0 ;
753
+ for (int i = tpitg.x ; i < nb; i += tptg.x ) {
754
+
755
+ device const uint8_t * q = x[i].qs + 32 *ip + n*ir;
756
+ device const uint8_t * scales = x[i].scales + is;
757
+
758
+ uint8_t d1 = scales[0 ] & 0xF ;
759
+ uint8_t m1 = scales[0 ] >> 4 ;
760
+ uint8_t d2 = scales[2 ] & 0xF ;
761
+ uint8_t m2 = scales[2 ] >> 4 ;
762
+
763
+ device const float * y = yy + i*QK_K + 64 *il + n*ir;
764
+
765
+ const float dall = (float )x[i].d ;
766
+ const float dmin = (float )x[i].dmin ;
767
+
768
+ float4 s = {0 .f , 0 .f , 0 .f , 0 .f };
769
+ for (int l = 0 ; l < n; ++l) {
770
+ s[0 ] += y[l+ 0 ] * ((q[l] >> shift1) & 3 ); s[1 ] += y[l+ 0 ];
771
+ s[2 ] += y[l+32 ] * ((q[l] >> shift2) & 3 ); s[3 ] += y[l+32 ];
772
+ }
773
+ sumf += dall * (s[0 ] * d1 + s[2 ] * d2) - dmin * (s[1 ] * m1 + s[3 ] * m2);
774
+
775
+
776
+ }
777
+ sum[ith] = sumf;
778
+
779
+ //
780
+ // Accumulate the sum from all threads in the threadgroup
781
+ // This version is slightly faster than the commented out one below,
782
+ // which I copy-pasted from ggerganov's q4_0 dot product for metal.
783
+ //
784
+ threadgroup_barrier (mem_flags::mem_threadgroup);
785
+ if (ith%4 == 0 ) {
786
+ for (int i = 1 ; i < 4 ; ++i) sum[ith] += sum[ith + i];
787
+ }
788
+ threadgroup_barrier (mem_flags::mem_threadgroup);
789
+ if (ith%16 == 0 ) {
790
+ for (int i = 4 ; i < 16 ; i += 4 ) sum[ith] += sum[ith + i];
791
+ }
792
+ threadgroup_barrier (mem_flags::mem_threadgroup);
793
+ if (ith == 0 ) {
794
+ for (int i = 16 ; i < nth; i += 16 ) sum[0 ] += sum[i];
795
+ dst[r1*ne0 + r0] = sum[0 ];
796
+ }
797
+
798
+ // // accumulate the sum from all threads in the threadgroup
799
+ // threadgroup_barrier(mem_flags::mem_threadgroup);
800
+ // for (uint i = nth/2; i > 0; i /= 2) {
801
+ // if (ith < i) {
802
+ // sum[ith] += sum[ith + i];
803
+ // }
804
+ // threadgroup_barrier(mem_flags::mem_threadgroup);
805
+ // }
806
+
807
+ // if (ith == 0) {
808
+ // dst[r1*ne0 + r0] = sum[0];
809
+ // }
810
+ }
811
+
631
812
kernel void kernel_mul_mat_q4_k_f32 (
632
813
device const void * src0,
633
814
device const float * src1,
@@ -724,22 +905,6 @@ kernel void kernel_mul_mat_q4_k_f32(
724
905
// }
725
906
}
726
907
727
- kernel void kernel_get_rows_q6_k (
728
- device const void * src0,
729
- device const int * src1,
730
- device float * dst,
731
- constant int64_t & ne00,
732
- constant uint64_t & nb01,
733
- constant uint64_t & nb1,
734
- uint tpig[[thread_position_in_grid]]) {
735
- const int i = tpig;
736
- const int r = ((device int32_t *) src1)[i];
737
-
738
- dequantize_row_q6_k (
739
- (device const block_q6_k *) ((device char *) src0 + r*nb01),
740
- (device float *) ((device char *) dst + i*nb1), ne00);
741
- }
742
-
743
908
kernel void kernel_mul_mat_q6_k_f32 (
744
909
device const void * src0,
745
910
device const float * src1,
0 commit comments