Skip to content

Commit ab11b38

Browse files
committed
Added gqa8 kernel to allow llama-2-70B on Metal
1 parent a113689 commit ab11b38

File tree

2 files changed

+62
-5
lines changed

2 files changed

+62
-5
lines changed

ggml-metal.m

+12-5
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
GGML_METAL_DECL_KERNEL(rms_norm);
6666
GGML_METAL_DECL_KERNEL(norm);
6767
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
68+
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_gqa8);
6869
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
6970
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
7071
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
@@ -182,6 +183,7 @@ @implementation GGMLMetalClass
182183
GGML_METAL_ADD_KERNEL(rms_norm);
183184
GGML_METAL_ADD_KERNEL(norm);
184185
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
186+
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_gqa8);
185187
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
186188
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
187189
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
@@ -718,7 +720,8 @@ void ggml_metal_graph_compute(
718720
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
719721

720722
GGML_ASSERT(ne00 == ne10);
721-
GGML_ASSERT(ne02 == ne12);
723+
int llama_2_70_gqa_step = ne02 == 8 && ne12 == 64;
724+
GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step);
722725

723726
if (ggml_is_contiguous(src0) &&
724727
ggml_is_contiguous(src1) &&
@@ -749,8 +752,8 @@ void ggml_metal_graph_compute(
749752
// we need to do ne02 multiplications
750753
// TODO: is there a way to do this in parallel - currently very slow ..
751754
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
752-
for (int64_t i02 = 0; i02 < ne02; ++i02) {
753-
size_t offs_src0_cur = offs_src0 + i02*nb02;
755+
for (int64_t i02 = 0; i02 < ne12; ++i02) {
756+
size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
754757
size_t offs_src1_cur = offs_src1 + i02*nb12;
755758
size_t offs_dst_cur = offs_dst + i02*nb2;
756759

@@ -772,11 +775,15 @@ void ggml_metal_graph_compute(
772775
switch (src0t) {
773776
case GGML_TYPE_F16:
774777
{
775-
GGML_ASSERT(ne02 == ne12);
778+
GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step);
776779

777780
nth0 = 64;
778781
nth1 = 1;
779-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
782+
if (llama_2_70_gqa_step) {
783+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_gqa8];
784+
} else {
785+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
786+
}
780787
} break;
781788
case GGML_TYPE_Q4_0:
782789
{

ggml-metal.metal

+50
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,56 @@ kernel void kernel_mul_mat_f16_f32(
552552
}
553553
}
554554

555+
// Variant of the f16 x f32 matrix multiplication kernel in which eight consecutive
// src1 matrices (index im) read the same src0 matrix (index im/8).
//
// Each threadgroup produces exactly one output element: the dot product, over ne00
// elements, of a half-precision row of src0 with a float row of src1.
//   tgpig.x -> r0: row of src0, and fastest-varying index of dst
//   tgpig.y -> r1: row of src1
//   tgpig.z -> im: matrix index of src1 and dst
//
// The host code in this commit selects this kernel only when ne02 == 8 && ne12 == 64.
//
// Arguments ne01, nb00, ne10, ne11, nb10 and tpig are not referenced in the body.
// NOTE(review): presumably kept so the argument layout matches kernel_mul_mat_f16_f32 - confirm.
kernel void kernel_mul_mat_f16_f32_gqa8(
556+
device const char * src0, // raw bytes; reinterpreted as half below
557+
device const char * src1, // raw bytes; reinterpreted as float below
558+
device float * dst,
559+
constant int64_t & ne00, // number of elements in each dot product
560+
constant int64_t & ne01,
561+
constant uint64_t & nb00,
562+
constant uint64_t & nb01, // byte stride between rows of src0
563+
constant uint64_t & nb02, // byte stride between matrices of src0
564+
constant int64_t & ne10,
565+
constant int64_t & ne11,
566+
constant uint64_t & nb10,
567+
constant uint64_t & nb11, // byte stride between rows of src1
568+
constant uint64_t & nb12, // byte stride between matrices of src1
569+
constant int64_t & ne0, // dst extent indexed by r0
570+
constant int64_t & ne1, // dst extent indexed by r1
571+
threadgroup float * sum [[threadgroup(0)]], // one partial sum per thread, indexed by tpitg.x
572+
uint3 tgpig[[threadgroup_position_in_grid]],
573+
uint3 tpig[[thread_position_in_grid]],
574+
uint3 tpitg[[thread_position_in_threadgroup]],
575+
uint3 tptg[[threads_per_threadgroup]]) {
576+
577+
const int64_t r0 = tgpig.x;
578+
const int64_t r1 = tgpig.y;
579+
const int64_t im = tgpig.z;
580+
581+
device const half * x = (device const half *) (src0 + r0*nb01 + im/8*nb02); // integer division: im = 8k .. 8k+7 all read src0 matrix k
582+
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
583+
584+
sum[tpitg.x] = 0.0f;
585+
586+
for (int i = tpitg.x; i < ne00; i += tptg.x) { // each thread handles elements tpitg.x, tpitg.x + tptg.x, ...
587+
sum[tpitg.x] += (float) x[i] * (float) y[i];
588+
}
589+
590+
// accumulate the sum from all threads in the threadgroup
// NOTE(review): the halving loop below folds every partial sum into sum[0] only when
// tptg.x is a power of two; the host sets nth0 = 64 on this path - confirm that is the threadgroup width.
// NOTE(review): assumes the host binds at least tptg.x floats of threadgroup memory for sum - confirm.
591+
threadgroup_barrier(mem_flags::mem_threadgroup);
592+
for (uint i = tptg.x/2; i > 0; i /= 2) {
593+
if (tpitg.x < i) {
594+
sum[tpitg.x] += sum[tpitg.x + i];
595+
}
596+
threadgroup_barrier(mem_flags::mem_threadgroup); // make this round's writes visible before the next halving step
597+
}
598+
599+
if (tpitg.x == 0) { // a single thread writes the reduced result
600+
dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
601+
}
602+
}
603+
604+
555605
kernel void kernel_alibi_f32(
556606
device const float * src0,
557607
device float * dst,

0 commit comments

Comments
 (0)