@@ -252,10 +252,10 @@ kernel void kernel_relu(
252
252
}
253
253
254
254
// kernel_tanh: element-wise hyperbolic tangent.
// Each thread reads the element of src0 at its thread_position_in_grid index (tpig),
// applies precise::tanh to it, and stores the result at the same index in dst.
// In this diff the element type changes from float4 (the removed '-' lines) to scalar
// float (the added '+' lines), so each thread handles one float instead of one float4.
// NOTE(review): no bounds check on tpig is visible here; presumably the host dispatches
// exactly one thread per element - verify against the caller.
kernel void kernel_tanh (
255
- device const float4 * src0,
256
- device float4 * dst,
255
+ device const float * src0,
256
+ device float * dst,
257
257
uint tpig[[thread_position_in_grid]]) {
258
- device const float4 & x = src0[tpig];
258
+ device const float & x = src0[tpig];
259
259
// precise:: selects the precise math variant of tanh explicitly.
dst[tpig] = precise::tanh (x);
260
260
}
261
261
@@ -367,7 +367,7 @@ kernel void kernel_soft_max(
367
367
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
368
368
369
369
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
370
- device const float * pmask = src1 ? src1 + i01*ne00 : nullptr ;
370
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr ;
371
371
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
372
372
373
373
// parallel max
@@ -404,6 +404,7 @@ kernel void kernel_soft_max(
404
404
pdst[i00] = exp_psrc0;
405
405
}
406
406
407
+ threadgroup_barrier (mem_flags::mem_threadgroup);
407
408
float sum = simd_sum (lsum);
408
409
if (ntg > N_SIMDWIDTH) {
409
410
if (sgitg == 0 ) {
@@ -447,9 +448,9 @@ kernel void kernel_soft_max_4(
447
448
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
448
449
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
449
450
450
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
451
- device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr ;
452
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
451
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
452
+ device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr ;
453
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
453
454
454
455
// parallel max
455
456
float4 lmax4 = -INFINITY;
@@ -487,6 +488,7 @@ kernel void kernel_soft_max_4(
487
488
}
488
489
489
490
const float lsum = lsum4[0 ] + lsum4[1 ] + lsum4[2 ] + lsum4[3 ];
491
+ threadgroup_barrier (mem_flags::mem_threadgroup);
490
492
float sum = simd_sum (lsum);
491
493
if (ntg > N_SIMDWIDTH) {
492
494
if (sgitg == 0 ) {
@@ -693,6 +695,7 @@ kernel void kernel_group_norm(
693
695
tmp += src0[j];
694
696
}
695
697
698
+ threadgroup_barrier (mem_flags::mem_threadgroup);
696
699
tmp = simd_sum (tmp);
697
700
if (ntg > N_SIMDWIDTH) {
698
701
if (sgitg == 0 ) {
0 commit comments