49
49
GGML_METAL_DECL_KERNEL (diag_mask_inf);
50
50
GGML_METAL_DECL_KERNEL (get_rows_f16);
51
51
GGML_METAL_DECL_KERNEL (get_rows_q4_0);
52
+ GGML_METAL_DECL_KERNEL (get_rows_q2_k);
53
+ GGML_METAL_DECL_KERNEL (get_rows_q4_k);
54
+ GGML_METAL_DECL_KERNEL (get_rows_q6_k);
52
55
GGML_METAL_DECL_KERNEL (rms_norm);
53
56
GGML_METAL_DECL_KERNEL (mul_mat_f16_f32);
54
57
GGML_METAL_DECL_KERNEL (mul_mat_q4_0_f32);
58
+ GGML_METAL_DECL_KERNEL (mul_mat_q2_k_f32);
59
+ GGML_METAL_DECL_KERNEL (mul_mat_q4_k_f32);
60
+ GGML_METAL_DECL_KERNEL (mul_mat_q6_k_f32);
55
61
GGML_METAL_DECL_KERNEL (rope);
56
62
GGML_METAL_DECL_KERNEL (cpy_f32_f16);
57
63
GGML_METAL_DECL_KERNEL (cpy_f32_f32);
133
139
GGML_METAL_ADD_KERNEL (diag_mask_inf);
134
140
GGML_METAL_ADD_KERNEL (get_rows_f16);
135
141
GGML_METAL_ADD_KERNEL (get_rows_q4_0);
142
+ GGML_METAL_ADD_KERNEL (get_rows_q2_k);
143
+ GGML_METAL_ADD_KERNEL (get_rows_q4_k);
144
+ GGML_METAL_ADD_KERNEL (get_rows_q6_k);
136
145
GGML_METAL_ADD_KERNEL (rms_norm);
137
146
GGML_METAL_ADD_KERNEL (mul_mat_f16_f32);
138
147
GGML_METAL_ADD_KERNEL (mul_mat_q4_0_f32);
148
+ GGML_METAL_ADD_KERNEL (mul_mat_q2_k_f32);
149
+ GGML_METAL_ADD_KERNEL (mul_mat_q4_k_f32);
150
+ GGML_METAL_ADD_KERNEL (mul_mat_q6_k_f32);
139
151
GGML_METAL_ADD_KERNEL (rope);
140
152
GGML_METAL_ADD_KERNEL (cpy_f32_f16);
141
153
GGML_METAL_ADD_KERNEL (cpy_f32_f32);
@@ -517,7 +529,38 @@ void ggml_metal_graph_compute(
517
529
nth1 = 4 ;
518
530
[encoder setComputePipelineState: ctx->pipeline_mul_mat_q4_0_f32];
519
531
} break ;
520
- default : GGML_ASSERT (false && " not implemented" );
532
+ case GGML_TYPE_Q2_K:
533
+ {
534
+ GGML_ASSERT (ne02 == 1 );
535
+ GGML_ASSERT (ne12 == 1 );
536
+
537
+ nth0 = 4 ;
538
+ nth1 = 16 ;
539
+ [encoder setComputePipelineState: ctx->pipeline_mul_mat_q2_k_f32];
540
+ } break ;
541
+ case GGML_TYPE_Q4_K:
542
+ {
543
+ GGML_ASSERT (ne02 == 1 );
544
+ GGML_ASSERT (ne12 == 1 );
545
+
546
+ nth0 = 4 ;
547
+ nth1 = 16 ;
548
+ [encoder setComputePipelineState: ctx->pipeline_mul_mat_q4_k_f32];
549
+ } break ;
550
+ case GGML_TYPE_Q6_K:
551
+ {
552
+ GGML_ASSERT (ne02 == 1 );
553
+ GGML_ASSERT (ne12 == 1 );
554
+
555
+ nth0 = 4 ;
556
+ nth1 = 16 ;
557
+ [encoder setComputePipelineState: ctx->pipeline_mul_mat_q6_k_f32];
558
+ } break ;
559
+ default :
560
+ {
561
+ fprintf (stderr, " Asserting on type %d \n " ,(int )src0t);
562
+ GGML_ASSERT (false && " not implemented" );
563
+ }
521
564
};
522
565
523
566
@@ -540,6 +583,15 @@ void ggml_metal_graph_compute(
540
583
if (src0t == GGML_TYPE_Q4_0) {
541
584
[encoder setThreadgroupMemoryLength: nth0*nth1*sizeof (float ) atIndex: 0 ];
542
585
[encoder dispatchThreadgroups: MTLSizeMake (ne01, ne11, 1 ) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
586
+ } else if (src0t == GGML_TYPE_Q2_K) {
587
+ [encoder setThreadgroupMemoryLength: nth0*nth1*sizeof (float ) atIndex: 0 ];
588
+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
589
+ } else if (src0t == GGML_TYPE_Q4_K) {
590
+ [encoder setThreadgroupMemoryLength: nth0*nth1*sizeof (float ) atIndex: 0 ];
591
+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne11, 1 ) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
592
+ } else if (src0t == GGML_TYPE_Q6_K) {
593
+ [encoder setThreadgroupMemoryLength: nth0*nth1*sizeof (float ) atIndex: 0 ];
594
+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne11, 1 ) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
543
595
} else {
544
596
[encoder setThreadgroupMemoryLength: nth0*sizeof (float ) atIndex: 0 ];
545
597
[encoder dispatchThreadgroups: MTLSizeMake (ne01, ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
@@ -555,6 +607,9 @@ void ggml_metal_graph_compute(
555
607
switch (src0->type ) {
556
608
case GGML_TYPE_F16: [encoder setComputePipelineState: ctx->pipeline_get_rows_f16]; break ;
557
609
case GGML_TYPE_Q4_0: [encoder setComputePipelineState: ctx->pipeline_get_rows_q4_0]; break ;
610
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState: ctx->pipeline_get_rows_q2_k]; break ;
611
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState: ctx->pipeline_get_rows_q4_k]; break ;
612
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_get_rows_q6_k]; break ;
558
613
default : GGML_ASSERT (false && " not implemented" );
559
614
}
560
615
0 commit comments