@@ -184,36 +184,73 @@ kernel void kernel_soft_max(
184
184
constant int64_t & ne00,
185
185
constant int64_t & ne01,
186
186
constant int64_t & ne02,
187
- uint3 tgpig[[threadgroup_position_in_grid]],
188
- uint3 tpitg[[thread_position_in_threadgroup]],
189
- uint3 ntg[[threads_per_threadgroup]]) {
190
- const int64_t i03 = tgpig[2 ];
191
- const int64_t i02 = tgpig[1 ];
192
- const int64_t i01 = tgpig[0 ];
187
+ threadgroup float * buf [[threadgroup(0 )]],
188
+ uint tgpig[[threadgroup_position_in_grid]],
189
+ uint tpitg[[thread_position_in_threadgroup]],
190
+ uint sgitg[[simdgroup_index_in_threadgroup]],
191
+ uint tiisg[[thread_index_in_simdgroup]],
192
+ uint ntg[[threads_per_threadgroup]]) {
193
+ const int64_t i03 = (tgpig) / (ne02*ne01);
194
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
195
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
193
196
194
197
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
195
198
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
196
199
197
200
// parallel max
198
- float lmax = tpitg[0 ] < ne00 ? psrc0[tpitg[0 ]] : -INFINITY;
199
- for (int i00 = tpitg[0 ] + ntg[0 ]; i00 < ne00; i00 += ntg[0 ]) {
201
+ float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
202
+
203
+ for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
200
204
lmax = MAX (lmax, psrc0[i00]);
201
205
}
202
- const float max = simd_max (lmax);
206
+
207
+ float max = simd_max (lmax);
208
+ if (tiisg == 0 ) {
209
+ buf[sgitg] = max;
210
+ }
211
+
212
+ threadgroup_barrier (mem_flags::mem_threadgroup);
213
+
214
+ // broadcast, simd group number is ntg / 32
215
+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
216
+ if (tpitg < i) {
217
+ buf[tpitg] = MAX (buf[tpitg], buf[tpitg + i]);
218
+ }
219
+ }
220
+
221
+ threadgroup_barrier (mem_flags::mem_threadgroup);
222
+
223
+ max = buf[0 ];
203
224
204
225
// parallel sum
205
226
float lsum = 0 .0f ;
206
- for (int i00 = tpitg[ 0 ] ; i00 < ne00; i00 += ntg[ 0 ] ) {
227
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
207
228
const float exp_psrc0 = exp (psrc0[i00] - max);
208
229
lsum += exp_psrc0;
209
230
// Remember the result of exp here. exp is expensive, so we really do not
210
- // whish to compute it twice.
231
+ // wish to compute it twice.
211
232
pdst[i00] = exp_psrc0;
212
233
}
213
234
214
- const float sum = simd_sum (lsum);
235
+ float sum = simd_sum (lsum);
236
+ if (tiisg == 0 ) {
237
+ buf[sgitg] = sum;
238
+ }
239
+
240
+ threadgroup_barrier (mem_flags::mem_threadgroup);
241
+
242
+ // broadcast, simd group number is ntg / 32
243
+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
244
+ if (tpitg < i) {
245
+ buf[tpitg] += buf[tpitg + i];
246
+ }
247
+ }
248
+
249
+ threadgroup_barrier (mem_flags::mem_threadgroup);
250
+
251
+ sum = buf[0 ];
215
252
216
- for (int i00 = tpitg[ 0 ] ; i00 < ne00; i00 += ntg[ 0 ] ) {
253
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
217
254
pdst[i00] /= sum;
218
255
}
219
256
}
@@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
224
261
constant int64_t & ne00,
225
262
constant int64_t & ne01,
226
263
constant int64_t & ne02,
227
- uint3 tgpig[[threadgroup_position_in_grid]],
228
- uint3 tpitg[[thread_position_in_threadgroup]],
229
- uint3 ntg[[threads_per_threadgroup]]) {
230
- const int64_t i03 = tgpig[2 ];
231
- const int64_t i02 = tgpig[1 ];
232
- const int64_t i01 = tgpig[0 ];
264
+ threadgroup float * buf [[threadgroup(0 )]],
265
+ uint tgpig[[threadgroup_position_in_grid]],
266
+ uint tpitg[[thread_position_in_threadgroup]],
267
+ uint sgitg[[simdgroup_index_in_threadgroup]],
268
+ uint tiisg[[thread_index_in_simdgroup]],
269
+ uint ntg[[threads_per_threadgroup]]) {
270
+ const int64_t i03 = (tgpig) / (ne02*ne01);
271
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
272
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
233
273
234
274
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
235
275
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
236
276
237
277
// parallel max
238
- float4 lmax4 = tpitg[0 ] < ne00/4 ? psrc4[tpitg[0 ]] : -INFINITY;
239
- for (int i00 = tpitg[0 ] + ntg[0 ]; i00 < ne00/4 ; i00 += ntg[0 ]) {
278
+ float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
279
+
280
+ for (int i00 = tpitg + ntg; i00 < ne00/4 ; i00 += ntg) {
240
281
lmax4 = fmax (lmax4, psrc4[i00]);
241
282
}
242
- float lmax = MAX (MAX (lmax4[0 ], lmax4[1 ]), MAX (lmax4[2 ], lmax4[3 ]));
243
283
244
- const float max = simd_max (lmax);
284
+ const float lmax = MAX (MAX (lmax4[0 ], lmax4[1 ]), MAX (lmax4[2 ], lmax4[3 ]));
285
+ float max = simd_max (lmax);
286
+ if (tiisg == 0 ) {
287
+ buf[sgitg] = max;
288
+ }
289
+
290
+ threadgroup_barrier (mem_flags::mem_threadgroup);
291
+
292
+ // broadcast, simd group number is ntg / 32
293
+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
294
+ if (tpitg < i) {
295
+ buf[tpitg] = MAX (buf[tpitg], buf[tpitg + i]);
296
+ }
297
+ }
298
+
299
+ threadgroup_barrier (mem_flags::mem_threadgroup);
300
+
301
+ max = buf[0 ];
245
302
246
303
// parallel sum
247
304
float4 lsum4 = 0 .0f ;
248
- for (int i00 = tpitg[ 0 ] ; i00 < ne00/4 ; i00 += ntg[ 0 ] ) {
305
+ for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
249
306
const float4 exp_psrc4 = exp (psrc4[i00] - max);
250
307
lsum4 += exp_psrc4;
251
308
pdst4[i00] = exp_psrc4;
252
309
}
253
- float lsum = lsum4[0 ] + lsum4[1 ] + lsum4[2 ] + lsum4[3 ];
254
310
255
- const float sum = simd_sum (lsum);
311
+ const float lsum = lsum4[0 ] + lsum4[1 ] + lsum4[2 ] + lsum4[3 ];
312
+ float sum = simd_sum (lsum);
313
+ if (tiisg == 0 ) {
314
+ buf[sgitg] = sum;
315
+ }
316
+
317
+ threadgroup_barrier (mem_flags::mem_threadgroup);
318
+
319
+ // broadcast, simd group number is ntg / 32
320
+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
321
+ if (tpitg < i) {
322
+ buf[tpitg] += buf[tpitg + i];
323
+ }
324
+ }
325
+
326
+ threadgroup_barrier (mem_flags::mem_threadgroup);
327
+
328
+ sum = buf[0 ];
256
329
257
- for (int i00 = tpitg[ 0 ] ; i00 < ne00/4 ; i00 += ntg[ 0 ] ) {
330
+ for (int i00 = tpitg; i00 < ne00/4 ; i00 += ntg) {
258
331
pdst4[i00] /= sum;
259
332
}
260
333
}
@@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf(
274
347
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
275
348
} else {
276
349
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
277
- }
350
+ }
278
351
}
279
352
280
353
kernel void kernel_diag_mask_inf_8 (
0 commit comments