|
65 | 65 | GGML_METAL_DECL_KERNEL(rms_norm);
|
66 | 66 | GGML_METAL_DECL_KERNEL(norm);
|
67 | 67 | GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
68 |
| - GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_gqa8); |
69 | 68 | GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
70 | 69 | GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
71 | 70 | GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
|
@@ -183,7 +182,6 @@ @implementation GGMLMetalClass
|
183 | 182 | GGML_METAL_ADD_KERNEL(rms_norm);
|
184 | 183 | GGML_METAL_ADD_KERNEL(norm);
|
185 | 184 | GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
186 |
| - GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_gqa8); |
187 | 185 | GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
188 | 186 | GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
189 | 187 | GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
|
@@ -720,8 +718,7 @@ void ggml_metal_graph_compute(
|
720 | 718 | // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
|
721 | 719 |
|
722 | 720 | GGML_ASSERT(ne00 == ne10);
|
723 |
| - int llama_2_70_gqa_step = ne02 == 8 && ne12 == 64; |
724 |
| - GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step); |
| 721 | + // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere |
725 | 722 |
|
726 | 723 | if (ggml_is_contiguous(src0) &&
|
727 | 724 | ggml_is_contiguous(src1) &&
|
@@ -775,15 +772,9 @@ void ggml_metal_graph_compute(
|
775 | 772 | switch (src0t) {
|
776 | 773 | case GGML_TYPE_F16:
|
777 | 774 | {
|
778 |
| - GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step); |
779 |
| - |
780 | 775 | nth0 = 64;
|
781 | 776 | nth1 = 1;
|
782 |
| - if (llama_2_70_gqa_step) { |
783 |
| - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_gqa8]; |
784 |
| - } else { |
785 |
| - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; |
786 |
| - } |
| 777 | + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; |
787 | 778 | } break;
|
788 | 779 | case GGML_TYPE_Q4_0:
|
789 | 780 | {
|
@@ -860,16 +851,18 @@ void ggml_metal_graph_compute(
|
860 | 851 | [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
861 | 852 | [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
862 | 853 | [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
863 |
| - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5]; |
864 |
| - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6]; |
865 |
| - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7]; |
866 |
| - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8]; |
867 |
| - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9]; |
868 |
| - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; |
869 |
| - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; |
870 |
| - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; |
871 |
| - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; |
872 |
| - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; |
| 854 | + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; |
| 855 | + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; |
| 856 | + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; |
| 857 | + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; |
| 858 | + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; |
| 859 | + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; |
| 860 | + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; |
| 861 | + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; |
| 862 | + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; |
| 863 | + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; |
| 864 | + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; |
| 865 | + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; |
873 | 866 |
|
874 | 867 | if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
875 | 868 | src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
|
0 commit comments