Skip to content

Commit d28b07c

Browse files
committed
Extend kernel_mul_mat_f16_f32 to handle gqa broadcast
1 parent fee39ec commit d28b07c

File tree

2 files changed

+17
-71
lines changed

2 files changed

+17
-71
lines changed

ggml-metal.m

+14-21
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@
6565
GGML_METAL_DECL_KERNEL(rms_norm);
6666
GGML_METAL_DECL_KERNEL(norm);
6767
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
68-
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_gqa8);
6968
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
7069
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
7170
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
@@ -183,7 +182,6 @@ @implementation GGMLMetalClass
183182
GGML_METAL_ADD_KERNEL(rms_norm);
184183
GGML_METAL_ADD_KERNEL(norm);
185184
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
186-
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_gqa8);
187185
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
188186
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
189187
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
@@ -720,8 +718,7 @@ void ggml_metal_graph_compute(
720718
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
721719

722720
GGML_ASSERT(ne00 == ne10);
723-
int llama_2_70_gqa_step = ne02 == 8 && ne12 == 64;
724-
GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step);
721+
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
725722

726723
if (ggml_is_contiguous(src0) &&
727724
ggml_is_contiguous(src1) &&
@@ -775,15 +772,9 @@ void ggml_metal_graph_compute(
775772
switch (src0t) {
776773
case GGML_TYPE_F16:
777774
{
778-
GGML_ASSERT(ne02 == ne12 || llama_2_70_gqa_step);
779-
780775
nth0 = 64;
781776
nth1 = 1;
782-
if (llama_2_70_gqa_step) {
783-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_gqa8];
784-
} else {
785-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
786-
}
777+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
787778
} break;
788779
case GGML_TYPE_Q4_0:
789780
{
@@ -860,16 +851,18 @@ void ggml_metal_graph_compute(
860851
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
861852
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
862853
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
863-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
864-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
865-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
866-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
867-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
868-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
869-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
870-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
871-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
872-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
854+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
855+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
856+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
857+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
858+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
859+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
860+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
861+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
862+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
863+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
864+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
865+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
873866

874867
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
875868
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {

ggml-metal.metal

+3-50
Original file line numberDiff line numberDiff line change
@@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32(
509509
device float * dst,
510510
constant int64_t & ne00,
511511
constant int64_t & ne01,
512+
constant int64_t & ne02,
512513
constant uint64_t & nb00,
513514
constant uint64_t & nb01,
514515
constant uint64_t & nb02,
515516
constant int64_t & ne10,
516517
constant int64_t & ne11,
518+
constant int64_t & ne12,
517519
constant uint64_t & nb10,
518520
constant uint64_t & nb11,
519521
constant uint64_t & nb12,
@@ -529,56 +531,7 @@ kernel void kernel_mul_mat_f16_f32(
529531
const int64_t r1 = tgpig.y;
530532
const int64_t im = tgpig.z;
531533

532-
device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02);
533-
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
534-
535-
sum[tpitg.x] = 0.0f;
536-
537-
for (int i = tpitg.x; i < ne00; i += tptg.x) {
538-
sum[tpitg.x] += (float) x[i] * (float) y[i];
539-
}
540-
541-
// accumulate the sum from all threads in the threadgroup
542-
threadgroup_barrier(mem_flags::mem_threadgroup);
543-
for (uint i = tptg.x/2; i > 0; i /= 2) {
544-
if (tpitg.x < i) {
545-
sum[tpitg.x] += sum[tpitg.x + i];
546-
}
547-
threadgroup_barrier(mem_flags::mem_threadgroup);
548-
}
549-
550-
if (tpitg.x == 0) {
551-
dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
552-
}
553-
}
554-
555-
kernel void kernel_mul_mat_f16_f32_gqa8(
556-
device const char * src0,
557-
device const char * src1,
558-
device float * dst,
559-
constant int64_t & ne00,
560-
constant int64_t & ne01,
561-
constant uint64_t & nb00,
562-
constant uint64_t & nb01,
563-
constant uint64_t & nb02,
564-
constant int64_t & ne10,
565-
constant int64_t & ne11,
566-
constant uint64_t & nb10,
567-
constant uint64_t & nb11,
568-
constant uint64_t & nb12,
569-
constant int64_t & ne0,
570-
constant int64_t & ne1,
571-
threadgroup float * sum [[threadgroup(0)]],
572-
uint3 tgpig[[threadgroup_position_in_grid]],
573-
uint3 tpig[[thread_position_in_grid]],
574-
uint3 tpitg[[thread_position_in_threadgroup]],
575-
uint3 tptg[[threads_per_threadgroup]]) {
576-
577-
const int64_t r0 = tgpig.x;
578-
const int64_t r1 = tgpig.y;
579-
const int64_t im = tgpig.z;
580-
581-
device const half * x = (device const half *) (src0 + r0*nb01 + im/8*nb02);
534+
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
582535
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
583536

584537
sum[tpitg.x] = 0.0f;

0 commit comments

Comments
 (0)