Skip to content

Commit e16b9fa

Browse files
authored
metal : multi-simd softmax (#3710)
ggml-ci
1 parent ff8f9a8 commit e16b9fa

File tree

2 files changed

+108
-30
lines changed

2 files changed

+108
-30
lines changed

ggml-metal.m

+7-2
Original file line numberDiff line numberDiff line change
@@ -1001,20 +1001,25 @@ void ggml_metal_graph_compute(
10011001
} break;
10021002
case GGML_OP_SOFT_MAX:
10031003
{
1004-
const int nth = MIN(32, ne00);
1004+
int nth = 32; // SIMD width
10051005

10061006
if (ne00%4 == 0) {
10071007
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
10081008
} else {
1009+
do {
1010+
nth *= 2;
1011+
} while (nth <= ne00 && nth <= 1024);
1012+
nth /= 2;
10091013
[encoder setComputePipelineState:ctx->pipeline_soft_max];
10101014
}
10111015
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
10121016
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
10131017
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
10141018
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
10151019
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1020+
[encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
10161021

1017-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1022+
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
10181023
} break;
10191024
case GGML_OP_DIAG_MASK_INF:
10201025
{

ggml-metal.metal

+101-28
Original file line numberDiff line numberDiff line change
@@ -184,36 +184,73 @@ kernel void kernel_soft_max(
184184
constant int64_t & ne00,
185185
constant int64_t & ne01,
186186
constant int64_t & ne02,
187-
uint3 tgpig[[threadgroup_position_in_grid]],
188-
uint3 tpitg[[thread_position_in_threadgroup]],
189-
uint3 ntg[[threads_per_threadgroup]]) {
190-
const int64_t i03 = tgpig[2];
191-
const int64_t i02 = tgpig[1];
192-
const int64_t i01 = tgpig[0];
187+
threadgroup float * buf [[threadgroup(0)]],
188+
uint tgpig[[threadgroup_position_in_grid]],
189+
uint tpitg[[thread_position_in_threadgroup]],
190+
uint sgitg[[simdgroup_index_in_threadgroup]],
191+
uint tiisg[[thread_index_in_simdgroup]],
192+
uint ntg[[threads_per_threadgroup]]) {
193+
const int64_t i03 = (tgpig) / (ne02*ne01);
194+
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
195+
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
193196

194197
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
195198
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
196199

197200
// parallel max
198-
float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
199-
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
201+
float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
202+
203+
for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
200204
lmax = MAX(lmax, psrc0[i00]);
201205
}
202-
const float max = simd_max(lmax);
206+
207+
float max = simd_max(lmax);
208+
if (tiisg == 0) {
209+
buf[sgitg] = max;
210+
}
211+
212+
threadgroup_barrier(mem_flags::mem_threadgroup);
213+
214+
// broadcast, simd group number is ntg / 32
215+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
216+
if (tpitg < i) {
217+
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
218+
}
219+
}
220+
221+
threadgroup_barrier(mem_flags::mem_threadgroup);
222+
223+
max = buf[0];
203224

204225
// parallel sum
205226
float lsum = 0.0f;
206-
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
227+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
207228
const float exp_psrc0 = exp(psrc0[i00] - max);
208229
lsum += exp_psrc0;
209230
// Remember the result of exp here. exp is expensive, so we really do not
210-
// whish to compute it twice.
231+
// wish to compute it twice.
211232
pdst[i00] = exp_psrc0;
212233
}
213234

214-
const float sum = simd_sum(lsum);
235+
float sum = simd_sum(lsum);
236+
if (tiisg == 0) {
237+
buf[sgitg] = sum;
238+
}
239+
240+
threadgroup_barrier(mem_flags::mem_threadgroup);
241+
242+
// broadcast, simd group number is ntg / 32
243+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
244+
if (tpitg < i) {
245+
buf[tpitg] += buf[tpitg + i];
246+
}
247+
}
248+
249+
threadgroup_barrier(mem_flags::mem_threadgroup);
250+
251+
sum = buf[0];
215252

216-
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
253+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
217254
pdst[i00] /= sum;
218255
}
219256
}
@@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
224261
constant int64_t & ne00,
225262
constant int64_t & ne01,
226263
constant int64_t & ne02,
227-
uint3 tgpig[[threadgroup_position_in_grid]],
228-
uint3 tpitg[[thread_position_in_threadgroup]],
229-
uint3 ntg[[threads_per_threadgroup]]) {
230-
const int64_t i03 = tgpig[2];
231-
const int64_t i02 = tgpig[1];
232-
const int64_t i01 = tgpig[0];
264+
threadgroup float * buf [[threadgroup(0)]],
265+
uint tgpig[[threadgroup_position_in_grid]],
266+
uint tpitg[[thread_position_in_threadgroup]],
267+
uint sgitg[[simdgroup_index_in_threadgroup]],
268+
uint tiisg[[thread_index_in_simdgroup]],
269+
uint ntg[[threads_per_threadgroup]]) {
270+
const int64_t i03 = (tgpig) / (ne02*ne01);
271+
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
272+
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
233273

234274
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
235275
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
236276

237277
// parallel max
238-
float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
239-
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
278+
float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
279+
280+
for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
240281
lmax4 = fmax(lmax4, psrc4[i00]);
241282
}
242-
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
243283

244-
const float max = simd_max(lmax);
284+
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
285+
float max = simd_max(lmax);
286+
if (tiisg == 0) {
287+
buf[sgitg] = max;
288+
}
289+
290+
threadgroup_barrier(mem_flags::mem_threadgroup);
291+
292+
// broadcast, simd group number is ntg / 32
293+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
294+
if (tpitg < i) {
295+
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
296+
}
297+
}
298+
299+
threadgroup_barrier(mem_flags::mem_threadgroup);
300+
301+
max = buf[0];
245302

246303
// parallel sum
247304
float4 lsum4 = 0.0f;
248-
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
305+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
249306
const float4 exp_psrc4 = exp(psrc4[i00] - max);
250307
lsum4 += exp_psrc4;
251308
pdst4[i00] = exp_psrc4;
252309
}
253-
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
254310

255-
const float sum = simd_sum(lsum);
311+
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
312+
float sum = simd_sum(lsum);
313+
if (tiisg == 0) {
314+
buf[sgitg] = sum;
315+
}
316+
317+
threadgroup_barrier(mem_flags::mem_threadgroup);
318+
319+
// broadcast, simd group number is ntg / 32
320+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
321+
if (tpitg < i) {
322+
buf[tpitg] += buf[tpitg + i];
323+
}
324+
}
325+
326+
threadgroup_barrier(mem_flags::mem_threadgroup);
327+
328+
sum = buf[0];
256329

257-
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
330+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
258331
pdst4[i00] /= sum;
259332
}
260333
}
@@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf(
274347
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
275348
} else {
276349
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
277-
}
350+
}
278351
}
279352

280353
kernel void kernel_diag_mask_inf_8(

0 commit comments

Comments
 (0)