Skip to content

Commit 4161bdc

Browse files
ikawrakowKawrakow
andauthored
metal : add Q4_K implementation (ggml-org#1733)
* Metal implementation for Q4_K Very slow for now: 42 ms / token, Q4_0 runs in 28 ms/token on my 30-core M2 Max GPU. * Optimizing Q4_K on metal The first token always takes longer, I guess because the metal kernel is being jit-compiled. So, using n = 128 to measure time. At this point Q4_K takes 29.5 ms / token compared to 27.2 ms / token for Q4_0. Quite a bit better than the initial attempt, but still not good enough. * Optimizing q4_K metal dot some more For n = 256 it is now 28.1 ms/token compared to 27 ms/token for q4_0. * Fix after merge with master --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 0035858 commit 4161bdc

File tree

3 files changed

+184
-19
lines changed

3 files changed

+184
-19
lines changed

.clang-tidy

-18
This file was deleted.

ggml-metal.m

+22-1
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,11 @@
4949
GGML_METAL_DECL_KERNEL(diag_mask_inf);
5050
GGML_METAL_DECL_KERNEL(get_rows_f16);
5151
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
52+
GGML_METAL_DECL_KERNEL(get_rows_q4_k);
5253
GGML_METAL_DECL_KERNEL(rms_norm);
5354
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
5455
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
56+
GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
5557
GGML_METAL_DECL_KERNEL(rope);
5658
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
5759
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -133,9 +135,11 @@
133135
GGML_METAL_ADD_KERNEL(diag_mask_inf);
134136
GGML_METAL_ADD_KERNEL(get_rows_f16);
135137
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
138+
GGML_METAL_ADD_KERNEL(get_rows_q4_k);
136139
GGML_METAL_ADD_KERNEL(rms_norm);
137140
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
138141
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
142+
GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
139143
GGML_METAL_ADD_KERNEL(rope);
140144
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
141145
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -517,7 +521,20 @@ void ggml_metal_graph_compute(
517521
nth1 = 4;
518522
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
519523
} break;
520-
default: GGML_ASSERT(false && "not implemented");
524+
case GGML_TYPE_Q4_K:
525+
{
526+
GGML_ASSERT(ne02 == 1);
527+
GGML_ASSERT(ne12 == 1);
528+
529+
nth0 = 4;
530+
nth1 = 16;
531+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
532+
} break;
533+
default:
534+
{
535+
fprintf(stderr, "Asserting on type %d\n",(int)src0t);
536+
GGML_ASSERT(false && "not implemented");
537+
}
521538
};
522539

523540

@@ -540,6 +557,9 @@ void ggml_metal_graph_compute(
540557
if (src0t == GGML_TYPE_Q4_0) {
541558
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
542559
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
560+
} else if (src0t == GGML_TYPE_Q4_K) {
561+
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
562+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
543563
} else {
544564
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
545565
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -555,6 +575,7 @@ void ggml_metal_graph_compute(
555575
switch (src0->type) {
556576
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
557577
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
578+
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
558579
default: GGML_ASSERT(false && "not implemented");
559580
}
560581

ggml-metal.metal

+162
Original file line numberDiff line numberDiff line change
@@ -503,3 +503,165 @@ kernel void kernel_cpy_f32_f32(
503503
dst_data[i00] = src[0];
504504
}
505505
}
506+
507+
//============================================ k-quants ======================================================
508+
509+
#define QK_K 256
510+
511+
typedef struct {
512+
half d; // super-block scale for quantized scales
513+
half dmin; // super-block scale for quantized mins
514+
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
515+
uint8_t qs[QK_K/2]; // 4--bit quants
516+
} block_q4_k;
517+
518+
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
519+
uchar4 r;
520+
if (j < 4) {
521+
r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
522+
r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
523+
} else {
524+
r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
525+
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
526+
r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
527+
r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
528+
}
529+
return r;
530+
}
531+
532+
static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
533+
assert(k % QK_K == 0);
534+
const int nb = k / QK_K;
535+
536+
for (int i = 0; i < nb; i++) {
537+
538+
const float d = x[i].d;
539+
const float min = x[i].dmin;
540+
541+
device const uint8_t * q = x[i].qs;
542+
device const uint8_t * scales = x[i].scales;
543+
544+
int is = 0;
545+
for (int j = 0; j < QK_K; j += 64) {
546+
const uchar4 sc = get_scale_min_k4(is, scales);
547+
const float d1 = d * sc[0]; const float m1 = min * sc[1];
548+
const float d2 = d * sc[2]; const float m2 = min * sc[3];
549+
for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
550+
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
551+
q += 32; is += 2;
552+
}
553+
554+
}
555+
}
556+
557+
kernel void kernel_get_rows_q4_k(
558+
device const void * src0,
559+
device const int * src1,
560+
device float * dst,
561+
constant int64_t & ne00,
562+
constant uint64_t & nb01,
563+
constant uint64_t & nb1,
564+
uint tpig[[thread_position_in_grid]]) {
565+
const int i = tpig;
566+
const int r = ((device int32_t *) src1)[i];
567+
568+
dequantize_row_q4_k(
569+
(device const block_q4_k *) ((device char *) src0 + r*nb01),
570+
(device float *) ((device char *) dst + i*nb1), ne00);
571+
}
572+
573+
kernel void kernel_mul_mat_q4_k_f32(
574+
device const void * src0,
575+
device const float * src1,
576+
device float * dst,
577+
constant int64_t & ne00,
578+
constant int64_t & ne01,
579+
constant uint64_t & nb00,
580+
constant uint64_t & nb01,
581+
constant uint64_t & nb02,
582+
constant int64_t & ne10,
583+
constant int64_t & ne11,
584+
constant uint64_t & nb10,
585+
constant uint64_t & nb11,
586+
constant uint64_t & nb12,
587+
constant int64_t & ne0,
588+
constant int64_t & ne1,
589+
threadgroup float * sum [[threadgroup(0)]],
590+
uint2 tgpig[[threadgroup_position_in_grid]],
591+
uint2 tpig[[thread_position_in_grid]], // we don't use this for now
592+
uint2 tpitg[[thread_position_in_threadgroup]],
593+
uint2 tptg[[threads_per_threadgroup]]) {
594+
595+
const int nb = ne00/QK_K;
596+
597+
const int64_t r0 = tgpig.x;
598+
const int64_t r1 = tgpig.y;
599+
600+
device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
601+
device const float * yy = (device const float *) src1 + r1*ne10;
602+
603+
const uint nth = tptg.x*tptg.y;
604+
const uint ith = tptg.y*tpitg.x + tpitg.y;
605+
606+
const int tid = tpitg.y; // 0...16
607+
const int il = tid/4; // 0...3
608+
const int ir = tid%4; // 0...3
609+
const int n = 8;
610+
const int is = 2*il;
611+
612+
sum[ith] = 0.0f;
613+
614+
float sumf = 0;
615+
for (int i = tpitg.x; i < nb; i += tptg.x) {
616+
617+
device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
618+
device const float * y = yy + i*QK_K + 64*il + n*ir;
619+
device const uint8_t * scales = (x + i)->scales;
620+
621+
const float dall = (float)((x + i)->d);
622+
const float dmin = (float)((x + i)->dmin);
623+
624+
const uchar4 sc = get_scale_min_k4(is, scales);
625+
626+
float4 s = {0.f, 0.f, 0.f, 0.f};
627+
for (int l = 0; l < n; ++l) {
628+
s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0];
629+
s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32];
630+
}
631+
sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);
632+
633+
}
634+
sum[ith] = sumf;
635+
636+
//
637+
// Accumulate the sum from all threads in the threadgroup
638+
// This version is slightly faster than the commented out one below,
639+
// which I copy-pasted from ggerganov's q4_0 dot product for metal.
640+
//
641+
threadgroup_barrier(mem_flags::mem_threadgroup);
642+
if (ith%4 == 0) {
643+
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
644+
}
645+
threadgroup_barrier(mem_flags::mem_threadgroup);
646+
if (ith%16 == 0) {
647+
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
648+
}
649+
threadgroup_barrier(mem_flags::mem_threadgroup);
650+
if (ith == 0) {
651+
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
652+
dst[r1*ne0 + r0] = sum[0];
653+
}
654+
655+
//// accumulate the sum from all threads in the threadgroup
656+
//threadgroup_barrier(mem_flags::mem_threadgroup);
657+
//for (uint i = nth/2; i > 0; i /= 2) {
658+
// if (ith < i) {
659+
// sum[ith] += sum[ith + i];
660+
// }
661+
// threadgroup_barrier(mem_flags::mem_threadgroup);
662+
//}
663+
664+
//if (ith == 0) {
665+
// dst[r1*ne0 + r0] = sum[0];
666+
//}
667+
}

0 commit comments

Comments
 (0)