Skip to content

Commit 01dc509

Browse files
committed
Merge branch 'master' into concedo_experimental
# Conflicts:
#	.devops/full.Dockerfile
#	.devops/main.Dockerfile
#	CMakeLists.txt
2 parents 0833845 + 72ff528 commit 01dc509

File tree

4 files changed

+572
-19
lines changed

4 files changed

+572
-19
lines changed

ggml-metal.m

+56 -1
Original file line number | Diff line number | Diff line change
@@ -49,9 +49,15 @@
4949
GGML_METAL_DECL_KERNEL(diag_mask_inf);
5050
GGML_METAL_DECL_KERNEL(get_rows_f16);
5151
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
52+
GGML_METAL_DECL_KERNEL(get_rows_q2_k);
53+
GGML_METAL_DECL_KERNEL(get_rows_q4_k);
54+
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
5255
GGML_METAL_DECL_KERNEL(rms_norm);
5356
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
5457
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
58+
GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
59+
GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
60+
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
5561
GGML_METAL_DECL_KERNEL(rope);
5662
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
5763
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -133,9 +139,15 @@
133139
GGML_METAL_ADD_KERNEL(diag_mask_inf);
134140
GGML_METAL_ADD_KERNEL(get_rows_f16);
135141
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
142+
GGML_METAL_ADD_KERNEL(get_rows_q2_k);
143+
GGML_METAL_ADD_KERNEL(get_rows_q4_k);
144+
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
136145
GGML_METAL_ADD_KERNEL(rms_norm);
137146
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
138147
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
148+
GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
149+
GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
150+
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
139151
GGML_METAL_ADD_KERNEL(rope);
140152
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
141153
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -517,7 +529,38 @@ void ggml_metal_graph_compute(
517529
nth1 = 4;
518530
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
519531
} break;
520-
default: GGML_ASSERT(false && "not implemented");
532+
case GGML_TYPE_Q2_K:
533+
{
534+
GGML_ASSERT(ne02 == 1);
535+
GGML_ASSERT(ne12 == 1);
536+
537+
nth0 = 4;
538+
nth1 = 16;
539+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
540+
} break;
541+
case GGML_TYPE_Q4_K:
542+
{
543+
GGML_ASSERT(ne02 == 1);
544+
GGML_ASSERT(ne12 == 1);
545+
546+
nth0 = 4;
547+
nth1 = 16;
548+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
549+
} break;
550+
case GGML_TYPE_Q6_K:
551+
{
552+
GGML_ASSERT(ne02 == 1);
553+
GGML_ASSERT(ne12 == 1);
554+
555+
nth0 = 4;
556+
nth1 = 16;
557+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
558+
} break;
559+
default:
560+
{
561+
fprintf(stderr, "Asserting on type %d\n",(int)src0t);
562+
GGML_ASSERT(false && "not implemented");
563+
}
521564
};
522565

523566

@@ -540,6 +583,15 @@ void ggml_metal_graph_compute(
540583
if (src0t == GGML_TYPE_Q4_0) {
541584
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
542585
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
586+
} else if (src0t == GGML_TYPE_Q2_K) {
587+
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
588+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
589+
} else if (src0t == GGML_TYPE_Q4_K) {
590+
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
591+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
592+
} else if (src0t == GGML_TYPE_Q6_K) {
593+
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
594+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
543595
} else {
544596
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
545597
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -555,6 +607,9 @@ void ggml_metal_graph_compute(
555607
switch (src0->type) {
556608
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
557609
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
610+
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
611+
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
612+
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
558613
default: GGML_ASSERT(false && "not implemented");
559614
}
560615

0 commit comments

Comments
 (0)