Skip to content

Commit 695ad75

Browse files
committed
metal : improve clarity (minor) (#10171)
1 parent 841f27a commit 695ad75

File tree

1 file changed

+45
-31
lines changed

1 file changed

+45
-31
lines changed

ggml/src/ggml-metal.metal

+45-31
Original file line number | Diff line number | Diff line change
@@ -3356,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec(
33563356
const short D4 = D/4;
33573357
const short D16 = D/16;
33583358
const short NW = N_SIMDWIDTH;
3359-
const short NW4 = NW/4;
3359+
const short NL = NW/4;
33603360
const short SH = 2*C; // shared memory per simdgroup
33613361

33623362
const short T = D + nsg*SH; // shared memory size per query in (half)
@@ -3370,7 +3370,7 @@ kernel void kernel_flash_attn_ext_vec(
33703370
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
33713371

33723372
// store the result for all queries in local memory in 4x4 matrices (the O matrix from the paper)
3373-
o4x4_t lo[D16/NW4];
3373+
o4x4_t lo[D16/NL];
33743374

33753375
// load heads from Q to shared memory
33763376
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
@@ -3384,7 +3384,7 @@ kernel void kernel_flash_attn_ext_vec(
33843384
}
33853385

33863386
// zero out lo
3387-
for (short i = 0; i < D16/NW4; i += NW4) {
3387+
for (short i = 0; i < D16/NL; ++i) {
33883388
lo[i] = (o4x4_t) 0.0f;
33893389
}
33903390

@@ -3400,8 +3400,8 @@ kernel void kernel_flash_attn_ext_vec(
34003400
half M = -__FLT16_MAX__/2;
34013401

34023402
// thread indices inside the simdgroup
3403-
const short tx = tiisg%8;
3404-
const short ty = tiisg/8;
3403+
const short tx = tiisg%NL;
3404+
const short ty = tiisg/NL;
34053405

34063406
// broadcast kv
34073407
//const short rk2 = ne02/ne12;
@@ -3411,10 +3411,10 @@ kernel void kernel_flash_attn_ext_vec(
34113411
const short ikv3 = iq3/(ne03/ne_12_3);
34123412

34133413
// load the queries from shared memory into local memory
3414-
q4x4_t mq[D16/NW4];
3414+
q4x4_t mq[D16/NL];
34153415

3416-
for (short ii = 0; ii < D16; ii += NW4) {
3417-
mq[ii/NW4] = sq4x4[ii + tx];
3416+
for (short ii = 0; ii < D16; ii += NL) {
3417+
mq[ii/NL] = sq4x4[ii + tx];
34183418
}
34193419

34203420
const bool has_mask = mask != q;
@@ -3455,17 +3455,17 @@ kernel void kernel_flash_attn_ext_vec(
34553455
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
34563456

34573457
#pragma unroll
3458-
for (short ii = 0; ii < D16; ii += NW4) {
3458+
for (short ii = 0; ii < D16; ii += NL) {
34593459
const short i = ii + tx;
34603460

34613461
k4x4_t mk;
34623462
deq_k(pk + i/nl_k, i%nl_k, mk);
34633463

34643464
mqk +=
3465-
dot(mq[ii/NW4][0], mk[0]) +
3466-
dot(mq[ii/NW4][1], mk[1]) +
3467-
dot(mq[ii/NW4][2], mk[2]) +
3468-
dot(mq[ii/NW4][3], mk[3]);
3465+
dot(mq[ii/NL][0], mk[0]) +
3466+
dot(mq[ii/NL][1], mk[1]) +
3467+
dot(mq[ii/NL][2], mk[2]) +
3468+
dot(mq[ii/NL][3], mk[3]);
34693469
}
34703470

34713471
// simdgroup reduce
@@ -3513,8 +3513,8 @@ kernel void kernel_flash_attn_ext_vec(
35133513

35143514
// O = diag(ms)*O
35153515
#pragma unroll
3516-
for (short ii = 0; ii < D16; ii += NW4) {
3517-
lo[ii/NW4] *= ms;
3516+
for (short ii = 0; ii < D16; ii += NL) {
3517+
lo[ii/NL] *= ms;
35183518
}
35193519
}
35203520

@@ -3529,13 +3529,13 @@ kernel void kernel_flash_attn_ext_vec(
35293529
const s4x4_t ms(ss[4*cc + ty]);
35303530

35313531
#pragma unroll
3532-
for (short ii = 0; ii < D16; ii += NW4) {
3532+
for (short ii = 0; ii < D16; ii += NL) {
35333533
const short i = ii + tx;
35343534

35353535
v4x4_t mv;
35363536
deq_v(pv4 + i/nl_v, i%nl_v, mv);
35373537

3538-
lo[ii/NW4] += mv*ms;
3538+
lo[ii/NL] += mv*ms;
35393539
}
35403540
}
35413541
}
@@ -3557,23 +3557,37 @@ kernel void kernel_flash_attn_ext_vec(
35573557
// [ 5, 13, 21, 29] -> [ 5]
35583558
// [ 6, 14, 22, 30] -> [ 6]
35593559
// [ 7, 15, 23, 31] -> [ 7]
3560-
for (short ii = 0; ii < D16; ii += NW4) {
3561-
lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 16);
3562-
lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 8);
3563-
3564-
lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 16);
3565-
lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 8);
3566-
3567-
lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 16);
3568-
lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 8);
3569-
3570-
lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 16);
3571-
lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 8);
3560+
for (short ii = 0; ii < D16; ii += NL) {
3561+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
3562+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
3563+
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
3564+
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
3565+
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
3566+
3567+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
3568+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
3569+
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
3570+
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
3571+
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
3572+
3573+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
3574+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
3575+
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
3576+
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
3577+
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
3578+
3579+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
3580+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
3581+
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
3582+
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
3583+
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
35723584
}
35733585

3586+
threadgroup_barrier(mem_flags::mem_threadgroup);
3587+
35743588
// store results to shared memory
3575-
for (short i = tiisg; i < D16; i += NW4) {
3576-
sr4x4[i] = lo[i/NW4];
3589+
for (short i = tiisg; i < D16; i += NL) {
3590+
sr4x4[i] = lo[i/NL];
35773591
}
35783592

35793593
threadgroup_barrier(mem_flags::mem_threadgroup);

0 commit comments

Comments
 (0)