@@ -718,7 +718,8 @@ void ggml_metal_graph_compute(
718
718
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
719
719
720
720
GGML_ASSERT (ne00 == ne10);
721
- GGML_ASSERT (ne02 == ne12);
721
+ // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
722
+ GGML_ASSERT (ne03 == ne13);
722
723
723
724
if (ggml_is_contiguous (src0) &&
724
725
ggml_is_contiguous (src1) &&
@@ -746,11 +747,11 @@ void ggml_metal_graph_compute(
746
747
initWithDevice: ctx->device transposeLeft: false transposeRight: true
747
748
resultRows: ne11 resultColumns: ne01 interiorColumns: ne00 alpha: 1.0 beta: 0.0 ];
748
749
749
- // we need to do ne02 multiplications
750
+ // we need to do ne12 multiplications
750
751
// TODO: is there a way to do this in parallel - currently very slow ..
751
752
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
752
- for (int64_t i02 = 0 ; i02 < ne02 ; ++i02) {
753
- size_t offs_src0_cur = offs_src0 + i02*nb02;
753
+ for (int64_t i02 = 0 ; i02 < ne12 ; ++i02) {
754
+ size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02) *nb02; // gqa not used for now
754
755
size_t offs_src1_cur = offs_src1 + i02*nb12;
755
756
size_t offs_dst_cur = offs_dst + i02*nb2;
756
757
@@ -772,8 +773,6 @@ void ggml_metal_graph_compute(
772
773
switch (src0t) {
773
774
case GGML_TYPE_F16:
774
775
{
775
- GGML_ASSERT (ne02 == ne12);
776
-
777
776
nth0 = 64 ;
778
777
nth1 = 1 ;
779
778
[encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32];
@@ -853,16 +852,18 @@ void ggml_metal_graph_compute(
853
852
[encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
854
853
[encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
855
854
[encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
856
- [encoder setBytes: &nb00 length: sizeof (nb00) atIndex: 5 ];
857
- [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 6 ];
858
- [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 7 ];
859
- [encoder setBytes: &ne10 length: sizeof (ne10) atIndex: 8 ];
860
- [encoder setBytes: &ne11 length: sizeof (ne11) atIndex: 9 ];
861
- [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 10 ];
862
- [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 11 ];
863
- [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 12 ];
864
- [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 13 ];
865
- [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 14 ];
855
+ [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
856
+ [encoder setBytes: &nb00 length: sizeof (nb00) atIndex: 6 ];
857
+ [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 7 ];
858
+ [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 8 ];
859
+ [encoder setBytes: &ne10 length: sizeof (ne10) atIndex: 9 ];
860
+ [encoder setBytes: &ne11 length: sizeof (ne11) atIndex: 10 ];
861
+ [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 11 ];
862
+ [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 12 ];
863
+ [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 13 ];
864
+ [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 14 ];
865
+ [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 15 ];
866
+ [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 16 ];
866
867
867
868
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
868
869
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
0 commit comments