Skip to content

Commit 74a6d92

Browse files
ikawrakow and Kawrakow authored
Metal implementation for all k_quants (#1807)
* metal : improve q4_K 28.3 -> 26.0 ms/token by avoiding a branch in the calculation of the scales. * metal : small improvement for Q4_K * metal : still optimizing Q4_K This commit pushes it down to 25.3 ms / token. The crazy idea of using 6 bits for the scales is really costly on Metal: if I remove the bit fiddling necessary to make the block scales, time goes almost to the Q4_0 23 ms/token. Before pushing the k-quants upstream I had a Q4_K variant that had used 8-bit scales. It wasn't more accurate, used 0.125 bits more per weight, was running slightly slower on the CPU (due to the larger model size and being memory bound there), and the difference was entirely negligible under CUDA. So, I decided to publish the version with 6-bit scales. Perhaps I should re-consider and change to 8-bit scales? * metal : some more optimizations Q2_K: 25.4 ms/token Q6_K: 27.3 ms/token Q4_0: 22.8 ms/token Q4_1: 23.1 ms/token * metal : Q3_K support Something is not quite right yet. * metal : Q5_K support Initial version achieves 31.2 ms/token, 210 GB/s * metal : still not able to figure out why q3_K does not work * Minor * metal : yet another failed attempt to make q3_K work * metal : optimize Q5_K 31.2 ms -> 27.8 ms. 250 GB/s. * metal : q3_K still not working Adding a heavily commented q3_K metal kernel to explain my obviously faulty logic. Perhaps someone could spot the issue? * metal : q3_K finally working Not optimized at all. What was the issue? The scales are not 4-bytes aligned, and I was accessing them with a uint32_t pointer. When I tried that on CUDA, I got an error (illegal memory access) and added a memcpy to a local array of 3 uint32_t's. But on Metal it told me there is no memcpy, so I tried accessing directly. There is no error, just garbage results. At some point I did try accessing the scales with an uint16_t pointer (the scales are for sure 2-byte aligned), but was still getting garbage. I guess, there must have been another bug. 
Now access to scales is via a uint16_t pointer and, after starting from scratch from the C dequantize function, it finally works. * metal : Q3_K 1st optimization pass * metal : Q3_K second optimization pass - 29.6 ms/token * metal : Q3_K cleanup * metal : fixed accidentally broken Q2_K --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent e4caa8d commit 74a6d92

File tree

3 files changed

+463
-135
lines changed

3 files changed

+463
-135
lines changed

ggml-metal.m

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,18 @@
5252
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
5353
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
5454
GGML_METAL_DECL_KERNEL(get_rows_q2_k);
55+
GGML_METAL_DECL_KERNEL(get_rows_q3_k);
5556
GGML_METAL_DECL_KERNEL(get_rows_q4_k);
57+
GGML_METAL_DECL_KERNEL(get_rows_q5_k);
5658
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
5759
GGML_METAL_DECL_KERNEL(rms_norm);
5860
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
5961
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
6062
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
6163
GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
64+
GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
6265
GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
66+
GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
6367
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
6468
GGML_METAL_DECL_KERNEL(rope);
6569
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -153,14 +157,18 @@ @implementation GGMLMetalClass
153157
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
154158
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
155159
GGML_METAL_ADD_KERNEL(get_rows_q2_k);
160+
GGML_METAL_ADD_KERNEL(get_rows_q3_k);
156161
GGML_METAL_ADD_KERNEL(get_rows_q4_k);
162+
GGML_METAL_ADD_KERNEL(get_rows_q5_k);
157163
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
158164
GGML_METAL_ADD_KERNEL(rms_norm);
159165
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
160166
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
161167
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
162168
GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
169+
GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
163170
GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
171+
GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
164172
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
165173
GGML_METAL_ADD_KERNEL(rope);
166174
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -575,6 +583,15 @@ void ggml_metal_graph_compute(
575583
nth1 = 16;
576584
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
577585
} break;
586+
case GGML_TYPE_Q3_K:
587+
{
588+
GGML_ASSERT(ne02 == 1);
589+
GGML_ASSERT(ne12 == 1);
590+
591+
nth0 = 4;
592+
nth1 = 16;
593+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
594+
} break;
578595
case GGML_TYPE_Q4_K:
579596
{
580597
GGML_ASSERT(ne02 == 1);
@@ -584,6 +601,15 @@ void ggml_metal_graph_compute(
584601
nth1 = 16;
585602
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
586603
} break;
604+
case GGML_TYPE_Q5_K:
605+
{
606+
GGML_ASSERT(ne02 == 1);
607+
GGML_ASSERT(ne12 == 1);
608+
609+
nth0 = 4;
610+
nth1 = 16;
611+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
612+
} break;
587613
case GGML_TYPE_Q6_K:
588614
{
589615
GGML_ASSERT(ne02 == 1);
@@ -620,15 +646,14 @@ void ggml_metal_graph_compute(
620646
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
621647
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
622648
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
623-
} else if (src0t == GGML_TYPE_Q2_K) {
649+
}
650+
else if (src0t == GGML_TYPE_Q2_K ||
651+
src0t == GGML_TYPE_Q3_K ||
652+
src0t == GGML_TYPE_Q4_K ||
653+
src0t == GGML_TYPE_Q5_K ||
654+
src0t == GGML_TYPE_Q6_K) {
624655
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
625656
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
626-
} else if (src0t == GGML_TYPE_Q4_K) {
627-
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
628-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
629-
} else if (src0t == GGML_TYPE_Q6_K) {
630-
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
631-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
632657
} else {
633658
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
634659
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -646,7 +671,9 @@ void ggml_metal_graph_compute(
646671
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
647672
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
648673
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
674+
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
649675
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
676+
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
650677
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
651678
default: GGML_ASSERT(false && "not implemented");
652679
}

0 commit comments

Comments (0)