Skip to content

Commit 1873ff5

Browse files
mbosc and Cebtenzzre authored
metal : add gqa8 kernel to allow llama-2-70B on metal (#2459)
* Added gqa8 kernel to allow llama-2-70B on metal * Update ggml-metal.m Co-authored-by: Cebtenzzre <[email protected]> * Extend kernel_mul_mat_f16_f32 to handle gqa broadcast * Added ne03==ne13 assertion --------- Co-authored-by: Cebtenzzre <[email protected]>
1 parent 49e7cb5 commit 1873ff5

File tree

2 files changed

+21
-17
lines changed

2 files changed

+21
-17
lines changed

ggml-metal.m

+17-16
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,8 @@ void ggml_metal_graph_compute(
718718
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
719719

720720
GGML_ASSERT(ne00 == ne10);
721-
GGML_ASSERT(ne02 == ne12);
721+
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
722+
GGML_ASSERT(ne03 == ne13);
722723

723724
if (ggml_is_contiguous(src0) &&
724725
ggml_is_contiguous(src1) &&
@@ -746,11 +747,11 @@ void ggml_metal_graph_compute(
746747
initWithDevice:ctx->device transposeLeft:false transposeRight:true
747748
resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
748749

749-
// we need to do ne02 multiplications
750+
// we need to do ne12 multiplications
750751
// TODO: is there a way to do this in parallel - currently very slow ..
751752
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
752-
for (int64_t i02 = 0; i02 < ne02; ++i02) {
753-
size_t offs_src0_cur = offs_src0 + i02*nb02;
753+
for (int64_t i02 = 0; i02 < ne12; ++i02) {
754+
size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
754755
size_t offs_src1_cur = offs_src1 + i02*nb12;
755756
size_t offs_dst_cur = offs_dst + i02*nb2;
756757

@@ -772,8 +773,6 @@ void ggml_metal_graph_compute(
772773
switch (src0t) {
773774
case GGML_TYPE_F16:
774775
{
775-
GGML_ASSERT(ne02 == ne12);
776-
777776
nth0 = 64;
778777
nth1 = 1;
779778
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
@@ -853,16 +852,18 @@ void ggml_metal_graph_compute(
853852
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
854853
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
855854
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
856-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
857-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
858-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
859-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
860-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
861-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
862-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
863-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
864-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
865-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
855+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
856+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
857+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
858+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
859+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
860+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
861+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
862+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
863+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
864+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
865+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
866+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
866867

867868
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
868869
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {

ggml-metal.metal

+4-1
Original file line numberDiff line numberDiff line change
@@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32(
509509
device float * dst,
510510
constant int64_t & ne00,
511511
constant int64_t & ne01,
512+
constant int64_t & ne02,
512513
constant uint64_t & nb00,
513514
constant uint64_t & nb01,
514515
constant uint64_t & nb02,
515516
constant int64_t & ne10,
516517
constant int64_t & ne11,
518+
constant int64_t & ne12,
517519
constant uint64_t & nb10,
518520
constant uint64_t & nb11,
519521
constant uint64_t & nb12,
@@ -529,7 +531,7 @@ kernel void kernel_mul_mat_f16_f32(
529531
const int64_t r1 = tgpig.y;
530532
const int64_t im = tgpig.z;
531533

532-
device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02);
534+
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
533535
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
534536

535537
sum[tpitg.x] = 0.0f;
@@ -552,6 +554,7 @@ kernel void kernel_mul_mat_f16_f32(
552554
}
553555
}
554556

557+
555558
kernel void kernel_alibi_f32(
556559
device const float * src0,
557560
device float * dst,

0 commit comments

Comments (0)