@@ -3356,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec(
3356
3356
const short D4 = D/4 ;
3357
3357
const short D16 = D/16 ;
3358
3358
const short NW = N_SIMDWIDTH;
3359
- const short NW4 = NW/4 ;
3359
+ const short NL = NW/4 ;
3360
3360
const short SH = 2 *C; // shared memory per simdgroup
3361
3361
3362
3362
const short T = D + nsg*SH; // shared memory size per query in (half)
@@ -3370,7 +3370,7 @@ kernel void kernel_flash_attn_ext_vec(
3370
3370
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
3371
3371
3372
3372
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
3373
- o4x4_t lo[D16/NW4 ];
3373
+ o4x4_t lo[D16/NL ];
3374
3374
3375
3375
// load heads from Q to shared memory
3376
3376
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
@@ -3384,7 +3384,7 @@ kernel void kernel_flash_attn_ext_vec(
3384
3384
}
3385
3385
3386
3386
// zero out lo
3387
- for (short i = 0 ; i < D16/NW4; i += NW4 ) {
3387
+ for (short i = 0 ; i < D16/NL; ++i ) {
3388
3388
lo[i] = (o4x4_t ) 0 .0f ;
3389
3389
}
3390
3390
@@ -3400,8 +3400,8 @@ kernel void kernel_flash_attn_ext_vec(
3400
3400
half M = -__FLT16_MAX__/2 ;
3401
3401
3402
3402
// thread indices inside the simdgroup
3403
- const short tx = tiisg%8 ;
3404
- const short ty = tiisg/8 ;
3403
+ const short tx = tiisg%NL ;
3404
+ const short ty = tiisg/NL ;
3405
3405
3406
3406
// broadcast kv
3407
3407
// const short rk2 = ne02/ne12;
@@ -3411,10 +3411,10 @@ kernel void kernel_flash_attn_ext_vec(
3411
3411
const short ikv3 = iq3/(ne03/ne_12_3);
3412
3412
3413
3413
// load the queries from shared memory into local memory
3414
- q4x4_t mq[D16/NW4 ];
3414
+ q4x4_t mq[D16/NL ];
3415
3415
3416
- for (short ii = 0 ; ii < D16; ii += NW4 ) {
3417
- mq[ii/NW4 ] = sq4x4[ii + tx];
3416
+ for (short ii = 0 ; ii < D16; ii += NL ) {
3417
+ mq[ii/NL ] = sq4x4[ii + tx];
3418
3418
}
3419
3419
3420
3420
const bool has_mask = mask != q;
@@ -3455,17 +3455,17 @@ kernel void kernel_flash_attn_ext_vec(
3455
3455
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4 *cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3456
3456
3457
3457
#pragma unroll
3458
- for (short ii = 0 ; ii < D16; ii += NW4 ) {
3458
+ for (short ii = 0 ; ii < D16; ii += NL ) {
3459
3459
const short i = ii + tx;
3460
3460
3461
3461
k4x4_t mk;
3462
3462
deq_k (pk + i/nl_k, i%nl_k, mk);
3463
3463
3464
3464
mqk +=
3465
- dot (mq[ii/NW4 ][0 ], mk[0 ]) +
3466
- dot (mq[ii/NW4 ][1 ], mk[1 ]) +
3467
- dot (mq[ii/NW4 ][2 ], mk[2 ]) +
3468
- dot (mq[ii/NW4 ][3 ], mk[3 ]);
3465
+ dot (mq[ii/NL ][0 ], mk[0 ]) +
3466
+ dot (mq[ii/NL ][1 ], mk[1 ]) +
3467
+ dot (mq[ii/NL ][2 ], mk[2 ]) +
3468
+ dot (mq[ii/NL ][3 ], mk[3 ]);
3469
3469
}
3470
3470
3471
3471
// simdgroup reduce
@@ -3513,8 +3513,8 @@ kernel void kernel_flash_attn_ext_vec(
3513
3513
3514
3514
// O = diag(ms)*O
3515
3515
#pragma unroll
3516
- for (short ii = 0 ; ii < D16; ii += NW4 ) {
3517
- lo[ii/NW4 ] *= ms;
3516
+ for (short ii = 0 ; ii < D16; ii += NL ) {
3517
+ lo[ii/NL ] *= ms;
3518
3518
}
3519
3519
}
3520
3520
@@ -3529,13 +3529,13 @@ kernel void kernel_flash_attn_ext_vec(
3529
3529
const s4x4_t ms (ss[4 *cc + ty]);
3530
3530
3531
3531
#pragma unroll
3532
- for (short ii = 0 ; ii < D16; ii += NW4 ) {
3532
+ for (short ii = 0 ; ii < D16; ii += NL ) {
3533
3533
const short i = ii + tx;
3534
3534
3535
3535
v4x4_t mv;
3536
3536
deq_v (pv4 + i/nl_v, i%nl_v, mv);
3537
3537
3538
- lo[ii/NW4 ] += mv*ms;
3538
+ lo[ii/NL ] += mv*ms;
3539
3539
}
3540
3540
}
3541
3541
}
@@ -3557,23 +3557,37 @@ kernel void kernel_flash_attn_ext_vec(
3557
3557
// [ 5, 13, 21, 29] -> [ 5]
3558
3558
// [ 6, 14, 22, 30] -> [ 6]
3559
3559
// [ 7, 15, 23, 31] -> [ 7]
3560
- for (short ii = 0 ; ii < D16; ii += NW4) {
3561
- lo[ii/NW4][0 ] += simd_shuffle_down (lo[ii/NW4][0 ], 16 );
3562
- lo[ii/NW4][0 ] += simd_shuffle_down (lo[ii/NW4][0 ], 8 );
3563
-
3564
- lo[ii/NW4][1 ] += simd_shuffle_down (lo[ii/NW4][1 ], 16 );
3565
- lo[ii/NW4][1 ] += simd_shuffle_down (lo[ii/NW4][1 ], 8 );
3566
-
3567
- lo[ii/NW4][2 ] += simd_shuffle_down (lo[ii/NW4][2 ], 16 );
3568
- lo[ii/NW4][2 ] += simd_shuffle_down (lo[ii/NW4][2 ], 8 );
3569
-
3570
- lo[ii/NW4][3 ] += simd_shuffle_down (lo[ii/NW4][3 ], 16 );
3571
- lo[ii/NW4][3 ] += simd_shuffle_down (lo[ii/NW4][3 ], 8 );
3560
+ for (short ii = 0 ; ii < D16; ii += NL) {
3561
+ lo[ii/NL][0 ] += simd_shuffle_down (lo[ii/NL][0 ], 16 );
3562
+ lo[ii/NL][0 ] += simd_shuffle_down (lo[ii/NL][0 ], 8 );
3563
+ // lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
3564
+ // lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
3565
+ // lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
3566
+
3567
+ lo[ii/NL][1 ] += simd_shuffle_down (lo[ii/NL][1 ], 16 );
3568
+ lo[ii/NL][1 ] += simd_shuffle_down (lo[ii/NL][1 ], 8 );
3569
+ // lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
3570
+ // lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
3571
+ // lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
3572
+
3573
+ lo[ii/NL][2 ] += simd_shuffle_down (lo[ii/NL][2 ], 16 );
3574
+ lo[ii/NL][2 ] += simd_shuffle_down (lo[ii/NL][2 ], 8 );
3575
+ // lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
3576
+ // lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
3577
+ // lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
3578
+
3579
+ lo[ii/NL][3 ] += simd_shuffle_down (lo[ii/NL][3 ], 16 );
3580
+ lo[ii/NL][3 ] += simd_shuffle_down (lo[ii/NL][3 ], 8 );
3581
+ // lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
3582
+ // lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
3583
+ // lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
3572
3584
}
3573
3585
3586
+ threadgroup_barrier (mem_flags::mem_threadgroup);
3587
+
3574
3588
// store results to shared memory
3575
- for (short i = tiisg; i < D16; i += NW4 ) {
3576
- sr4x4[i] = lo[i/NW4 ];
3589
+ for (short i = tiisg; i < D16; i += NL ) {
3590
+ sr4x4[i] = lo[i/NL ];
3577
3591
}
3578
3592
3579
3593
threadgroup_barrier (mem_flags::mem_threadgroup);
0 commit comments