Skip to content

Commit 1914017

Browse files
committed
metal : soft max, tanh, supports_op fixes
1 parent b9a77fa commit 1914017

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

Diff for: src/ggml-metal.m

+8-3
Original file line numberDiff line numberDiff line change
@@ -855,8 +855,8 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
855855
case GGML_OP_DIAG_MASK_INF:
856856
case GGML_OP_GET_ROWS:
857857
{
858-
return op->ne[0] % 4 == 0;
859-
}
858+
return op->ne[3] == 1;
859+
} break;
860860
default:
861861
return false;
862862
}
@@ -931,7 +931,10 @@ void ggml_metal_graph_compute(
931931
} break;
932932
}
933933

934-
GGML_ASSERT(ggml_metal_supports_op(dst));
934+
if (!ggml_metal_supports_op(dst)) {
935+
GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
936+
GGML_ASSERT(!"unsupported op");
937+
}
935938

936939
const int64_t ne00 = src0 ? src0->ne[0] : 0;
937940
const int64_t ne01 = src0 ? src0->ne[1] : 0;
@@ -1326,6 +1329,8 @@ void ggml_metal_graph_compute(
13261329
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
13271330
if (id_src1) {
13281331
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1332+
} else {
1333+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
13291334
}
13301335
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
13311336
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];

Diff for: src/ggml-metal.metal

+10-7
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,10 @@ kernel void kernel_relu(
252252
}
253253

254254
kernel void kernel_tanh(
255-
device const float4 * src0,
256-
device float4 * dst,
255+
device const float * src0,
256+
device float * dst,
257257
uint tpig[[thread_position_in_grid]]) {
258-
device const float4 & x = src0[tpig];
258+
device const float & x = src0[tpig];
259259
dst[tpig] = precise::tanh(x);
260260
}
261261

@@ -367,7 +367,7 @@ kernel void kernel_soft_max(
367367
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
368368

369369
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
370-
device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
370+
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
371371
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
372372

373373
// parallel max
@@ -404,6 +404,7 @@ kernel void kernel_soft_max(
404404
pdst[i00] = exp_psrc0;
405405
}
406406

407+
threadgroup_barrier(mem_flags::mem_threadgroup);
407408
float sum = simd_sum(lsum);
408409
if (ntg > N_SIMDWIDTH) {
409410
if (sgitg == 0) {
@@ -447,9 +448,9 @@ kernel void kernel_soft_max_4(
447448
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
448449
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
449450

450-
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
451-
device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
452-
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
451+
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
452+
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
453+
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
453454

454455
// parallel max
455456
float4 lmax4 = -INFINITY;
@@ -487,6 +488,7 @@ kernel void kernel_soft_max_4(
487488
}
488489

489490
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
491+
threadgroup_barrier(mem_flags::mem_threadgroup);
490492
float sum = simd_sum(lsum);
491493
if (ntg > N_SIMDWIDTH) {
492494
if (sgitg == 0) {
@@ -693,6 +695,7 @@ kernel void kernel_group_norm(
693695
tmp += src0[j];
694696
}
695697

698+
threadgroup_barrier(mem_flags::mem_threadgroup);
696699
tmp = simd_sum(tmp);
697700
if (ntg > N_SIMDWIDTH) {
698701
if (sgitg == 0) {

0 commit comments

Comments
 (0)