GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_gqa8);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
@@ -182,6 +183,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_gqa8);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
@@ -718,7 +720,8 @@ void ggml_metal_graph_compute(
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224

GGML_ASSERT(ne00 == ne10);
- GGML_ASSERT(ne02 == ne12);
+ int llama_2_70_gqa_step = ne02 == 8 && ne12 == 64;
+ GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step);

if (ggml_is_contiguous(src0) &&
    ggml_is_contiguous(src1) &&
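For reference, the shape gate this hunk introduces is pinned to the LLaMA-2-70B head counts (8 KV heads broadcast across 64 query heads). Below is a minimal sketch of the more general broadcast check one could use instead; the helper name and the generalization beyond the 8/64 pair are assumptions, not part of this patch.

    #include <stdint.h>

    // Sketch only: ne02 = src0 (KV) heads, ne12 = src1 (query) heads.
    // The patch above hard-codes ne02 == 8 && ne12 == 64; this variant merely
    // requires the query-head count to be a whole multiple of the KV-head count.
    static int is_gqa_broadcast(int64_t ne02, int64_t ne12) {
        return ne02 > 0 && ne12 > ne02 && ne12 % ne02 == 0;
    }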
@@ -749,8 +752,8 @@ void ggml_metal_graph_compute(
// we need to do ne02 multiplications
// TODO: is there a way to do this in parallel - currently very slow ..
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
- for (int64_t i02 = 0; i02 < ne02; ++i02) {
-     size_t offs_src0_cur = offs_src0 + i02*nb02;
+ for (int64_t i02 = 0; i02 < ne12; ++i02) {
+     size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
      size_t offs_src1_cur = offs_src1 + i02*nb12;
      size_t offs_dst_cur  = offs_dst  + i02*nb2;
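The rewritten loop runs over the ne12 query heads and derives the src0 slice from i02/(ne12/ne02), so each group of ne12/ne02 consecutive query heads reads the same KV-head matrix while the src1 and dst offsets advance every iteration. A standalone sketch of that mapping, with illustrative variable names only:

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t ne02  = 8;              // KV heads (src0 slices)
        const int64_t ne12  = 64;             // query heads (src1 slices)
        const int64_t group = ne12 / ne02;    // query heads per KV head: 8

        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            const int64_t i02 = i12 / group;  // same arithmetic as i02/(ne12/ne02) above
            assert(i02 < ne02);
            printf("query head %2lld -> KV head %lld\n", (long long) i12, (long long) i02);
        }
        return 0;
    }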
@@ -772,11 +775,15 @@ void ggml_metal_graph_compute(
switch (src0t) {
    case GGML_TYPE_F16:
        {
-           GGML_ASSERT(ne02 == ne12);
+           GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step);

            nth0 = 64;
            nth1 = 1;
-           [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+           if (llama_2_70_gqa_step) {
+               [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_gqa8];
+           } else {
+               [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+           }
        } break;
    case GGML_TYPE_Q4_0:
        {
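Finally, the dispatch change routes the detected 8/64 shape to the new gqa8 pipeline and leaves every other F16 x F32 multiplication on the existing kernel. A plain-C restatement of that selection, with the two pipeline handles reduced to hypothetical placeholders rather than the real MTLComputePipelineState objects from ggml-metal.m:

    #include <stdint.h>

    // Sketch only: placeholder handles stand in for ctx->pipeline_mul_mat_f16_f32
    // and ctx->pipeline_mul_mat_f16_f32_gqa8 used in the patch above.
    typedef struct { const char *name; } pipeline_t;

    static const pipeline_t pipeline_f16_f32      = { "mul_mat_f16_f32"      };
    static const pipeline_t pipeline_f16_f32_gqa8 = { "mul_mat_f16_f32_gqa8" };

    static const pipeline_t *pick_f16_mul_mat_pipeline(int64_t ne02, int64_t ne12) {
        const int llama_2_70_gqa_step = (ne02 == 8 && ne12 == 64);
        return llama_2_70_gqa_step ? &pipeline_f16_f32_gqa8 : &pipeline_f16_f32;
    }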