Skip to content

Commit 72ff528

Browse files
ikawrakowKawrakow
andauthored
metal : add Q2_K implementation (ggml-org#1762)
* metal : add Q2_K implementation 27.1 ms / token on M2 Max 30-core GPU, so about the same speed as Q4_0. Memory throughput is ~156 GB/s. The access pattern used in the Q2_K CUDA implementation resulted in significantly lower performance (~31 ms/token). * Fixing merge conflicts --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 0bf7cf1 commit 72ff528

File tree

2 files changed

+200
-18
lines changed

2 files changed

+200
-18
lines changed

ggml-metal.m

+17
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,13 @@
4949
GGML_METAL_DECL_KERNEL(diag_mask_inf);
5050
GGML_METAL_DECL_KERNEL(get_rows_f16);
5151
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
52+
GGML_METAL_DECL_KERNEL(get_rows_q2_k);
5253
GGML_METAL_DECL_KERNEL(get_rows_q4_k);
5354
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
5455
GGML_METAL_DECL_KERNEL(rms_norm);
5556
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
5657
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
58+
GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
5759
GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
5860
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
5961
GGML_METAL_DECL_KERNEL(rope);
@@ -137,11 +139,13 @@
137139
GGML_METAL_ADD_KERNEL(diag_mask_inf);
138140
GGML_METAL_ADD_KERNEL(get_rows_f16);
139141
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
142+
GGML_METAL_ADD_KERNEL(get_rows_q2_k);
140143
GGML_METAL_ADD_KERNEL(get_rows_q4_k);
141144
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
142145
GGML_METAL_ADD_KERNEL(rms_norm);
143146
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
144147
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
148+
GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
145149
GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
146150
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
147151
GGML_METAL_ADD_KERNEL(rope);
@@ -525,6 +529,15 @@ void ggml_metal_graph_compute(
525529
nth1 = 4;
526530
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
527531
} break;
532+
case GGML_TYPE_Q2_K:
533+
{
534+
GGML_ASSERT(ne02 == 1);
535+
GGML_ASSERT(ne12 == 1);
536+
537+
nth0 = 4;
538+
nth1 = 16;
539+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
540+
} break;
528541
case GGML_TYPE_Q4_K:
529542
{
530543
GGML_ASSERT(ne02 == 1);
@@ -570,6 +583,9 @@ void ggml_metal_graph_compute(
570583
if (src0t == GGML_TYPE_Q4_0) {
571584
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
572585
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
586+
} else if (src0t == GGML_TYPE_Q2_K) {
587+
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
588+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
573589
} else if (src0t == GGML_TYPE_Q4_K) {
574590
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
575591
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -591,6 +607,7 @@ void ggml_metal_graph_compute(
591607
switch (src0->type) {
592608
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
593609
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
610+
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
594611
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
595612
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
596613
default: GGML_ASSERT(false && "not implemented");

ggml-metal.metal

+183-18
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,13 @@ kernel void kernel_cpy_f32_f32(
527527

528528
#define QK_K 256
529529

530+
typedef struct {
531+
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
532+
uint8_t qs[QK_K/4]; // quants
533+
half d; // super-block scale for quantized scales
534+
half dmin; // super-block scale for quantized mins
535+
} block_q2_k;
536+
530537
typedef struct {
531538
half d; // super-block scale for quantized scales
532539
half dmin; // super-block scale for quantized mins
@@ -555,6 +562,41 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
555562
return r;
556563
}
557564

565+
//========================================== dequantization =============================
566+
567+
static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
568+
assert(k % QK_K == 0);
569+
const int nb = k / QK_K;
570+
571+
for (int i = 0; i < nb; i++) {
572+
573+
const float d = x[i].d;
574+
const float min = x[i].dmin;
575+
576+
device const uint8_t * q = x[i].qs;
577+
578+
int is = 0;
579+
float dl, ml;
580+
for (int n = 0; n < QK_K; n += 128) {
581+
int shift = 0;
582+
for (int j = 0; j < 4; ++j) {
583+
584+
uint8_t sc = x[i].scales[is++];
585+
dl = d * (sc & 0xF); ml = min * (sc >> 4);
586+
for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
587+
588+
sc = x[i].scales[is++];
589+
dl = d * (sc & 0xF); ml = min * (sc >> 4);
590+
for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
591+
592+
shift += 2;
593+
}
594+
q += 32;
595+
}
596+
597+
}
598+
}
599+
558600
static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
559601
assert(k % QK_K == 0);
560602
const int nb = k / QK_K;
@@ -586,12 +628,12 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
586628

587629
for (int i = 0; i < nb; i++) {
588630

589-
const float d = x[i].d;
590-
591631
device const uint8_t * ql = x[i].ql;
592632
device const uint8_t * qh = x[i].qh;
593633
device const int8_t * sc = x[i].scales;
594634

635+
const float d = x[i].d;
636+
595637
for (int n = 0; n < QK_K; n += 128) {
596638
for (int l = 0; l < 32; ++l) {
597639
int is = l/16;
@@ -612,6 +654,22 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
612654
}
613655
}
614656

657+
kernel void kernel_get_rows_q2_k(
658+
device const void * src0,
659+
device const int * src1,
660+
device float * dst,
661+
constant int64_t & ne00,
662+
constant uint64_t & nb01,
663+
constant uint64_t & nb1,
664+
uint tpig[[thread_position_in_grid]]) {
665+
const int i = tpig;
666+
const int r = ((device int32_t *) src1)[i];
667+
668+
dequantize_row_q2_k(
669+
(device const block_q2_k *) ((device char *) src0 + r*nb01),
670+
(device float *) ((device char *) dst + i*nb1), ne00);
671+
}
672+
615673
kernel void kernel_get_rows_q4_k(
616674
device const void * src0,
617675
device const int * src1,
@@ -628,6 +686,129 @@ kernel void kernel_get_rows_q4_k(
628686
(device float *) ((device char *) dst + i*nb1), ne00);
629687
}
630688

689+
kernel void kernel_get_rows_q6_k(
690+
device const void * src0,
691+
device const int * src1,
692+
device float * dst,
693+
constant int64_t & ne00,
694+
constant uint64_t & nb01,
695+
constant uint64_t & nb1,
696+
uint tpig[[thread_position_in_grid]]) {
697+
const int i = tpig;
698+
const int r = ((device int32_t *) src1)[i];
699+
700+
dequantize_row_q6_k(
701+
(device const block_q6_k *) ((device char *) src0 + r*nb01),
702+
(device float *) ((device char *) dst + i*nb1), ne00);
703+
}
704+
705+
//====================================== dot products =========================
706+
707+
kernel void kernel_mul_mat_q2_k_f32(
708+
device const void * src0,
709+
device const float * src1,
710+
device float * dst,
711+
constant int64_t & ne00,
712+
constant int64_t & ne01,
713+
constant uint64_t & nb00,
714+
constant uint64_t & nb01,
715+
constant uint64_t & nb02,
716+
constant int64_t & ne10,
717+
constant int64_t & ne11,
718+
constant uint64_t & nb10,
719+
constant uint64_t & nb11,
720+
constant uint64_t & nb12,
721+
constant int64_t & ne0,
722+
constant int64_t & ne1,
723+
threadgroup float * sum [[threadgroup(0)]],
724+
uint2 tgpig[[threadgroup_position_in_grid]],
725+
uint2 tpig[[thread_position_in_grid]], // we don't use this for now
726+
uint2 tpitg[[thread_position_in_threadgroup]],
727+
uint2 tptg[[threads_per_threadgroup]]) {
728+
729+
const int nb = ne00/QK_K;
730+
731+
const int64_t r0 = tgpig.x;
732+
const int64_t r1 = tgpig.y;
733+
734+
device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
735+
device const float * yy = (device const float *) src1 + r1*ne10;
736+
737+
const int nth = tptg.x*tptg.y;
738+
const int ith = tptg.y*tpitg.x + tpitg.y;
739+
740+
741+
const int tid = tpitg.y; // 0...16
742+
const int il = tid/4; // 0...3
743+
const int ir = tid%4; // 0...3
744+
const int ip = il/2; // 0 or 1
745+
const int shift1 = 4*(il%2);// 0 or 4
746+
const int shift2 = shift1+2;// 2 or 6
747+
const int n = 8;
748+
const int is = 4*il + (n*ir)/16;
749+
750+
sum[ith] = 0.0f;
751+
752+
float sumf = 0;
753+
for (int i = tpitg.x; i < nb; i += tptg.x) {
754+
755+
device const uint8_t * q = x[i].qs + 32*ip + n*ir;
756+
device const uint8_t * scales = x[i].scales + is;
757+
758+
uint8_t d1 = scales[0] & 0xF;
759+
uint8_t m1 = scales[0] >> 4;
760+
uint8_t d2 = scales[2] & 0xF;
761+
uint8_t m2 = scales[2] >> 4;
762+
763+
device const float * y = yy + i*QK_K + 64*il + n*ir;
764+
765+
const float dall = (float)x[i].d;
766+
const float dmin = (float)x[i].dmin;
767+
768+
float4 s = {0.f, 0.f, 0.f, 0.f};
769+
for (int l = 0; l < n; ++l) {
770+
s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0];
771+
s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32];
772+
}
773+
sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2);
774+
775+
776+
}
777+
sum[ith] = sumf;
778+
779+
//
780+
// Accumulate the sum from all threads in the threadgroup
781+
// This version is slightly faster than the commented out one below,
782+
// which I copy-pasted from ggerganov's q4_0 dot product for metal.
783+
//
784+
threadgroup_barrier(mem_flags::mem_threadgroup);
785+
if (ith%4 == 0) {
786+
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
787+
}
788+
threadgroup_barrier(mem_flags::mem_threadgroup);
789+
if (ith%16 == 0) {
790+
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
791+
}
792+
threadgroup_barrier(mem_flags::mem_threadgroup);
793+
if (ith == 0) {
794+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
795+
dst[r1*ne0 + r0] = sum[0];
796+
}
797+
798+
//// accumulate the sum from all threads in the threadgroup
799+
//threadgroup_barrier(mem_flags::mem_threadgroup);
800+
//for (uint i = nth/2; i > 0; i /= 2) {
801+
// if (ith < i) {
802+
// sum[ith] += sum[ith + i];
803+
// }
804+
// threadgroup_barrier(mem_flags::mem_threadgroup);
805+
//}
806+
807+
//if (ith == 0) {
808+
// dst[r1*ne0 + r0] = sum[0];
809+
//}
810+
}
811+
631812
kernel void kernel_mul_mat_q4_k_f32(
632813
device const void * src0,
633814
device const float * src1,
@@ -724,22 +905,6 @@ kernel void kernel_mul_mat_q4_k_f32(
724905
//}
725906
}
726907

727-
kernel void kernel_get_rows_q6_k(
728-
device const void * src0,
729-
device const int * src1,
730-
device float * dst,
731-
constant int64_t & ne00,
732-
constant uint64_t & nb01,
733-
constant uint64_t & nb1,
734-
uint tpig[[thread_position_in_grid]]) {
735-
const int i = tpig;
736-
const int r = ((device int32_t *) src1)[i];
737-
738-
dequantize_row_q6_k(
739-
(device const block_q6_k *) ((device char *) src0 + r*nb01),
740-
(device float *) ((device char *) dst + i*nb1), ne00);
741-
}
742-
743908
kernel void kernel_mul_mat_q6_k_f32(
744909
device const void * src0,
745910
device const float * src1,

0 commit comments

Comments
 (0)