Skip to content

Commit d805404

Browse files
committed
metal : fix shared memory calc + reduce smem + comments
1 parent 1e12961 commit d805404

File tree

2 files changed

+40
-14
lines changed

2 files changed

+40
-14
lines changed

ggml/src/ggml-metal.m

+35-12
Original file line numberDiff line numberDiff line change
@@ -3145,12 +3145,16 @@ static void ggml_metal_encode_node(
31453145
GGML_ASSERT(nqptg % 8 == 0);
31463146
GGML_ASSERT(ncpsg % 32 == 0);
31473147

3148+
// 16*32*(nsg)
3149+
// the shared memory needed for the simdgroups to load the KV cache
3150+
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
3151+
//
3152+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
3153+
31483154
int64_t nsgmax = 2;
31493155

31503156
while (true) {
3151-
// 16*32*nsgmax - the shared memory needed for the simdgroups to load the KV cache
3152-
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
3153-
const size_t smem = (nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg)) + 16*32*nsgmax)*(sizeof(float)/2);
3157+
const size_t smem = FATTN_SMEM(nsgmax);
31543158
if (smem > device.maxThreadgroupMemoryLength) {
31553159
break;
31563160
}
@@ -3161,13 +3165,12 @@ static void ggml_metal_encode_node(
31613165
// simdgroups per threadgroup (a.k.a. warps)
31623166
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
31633167

3164-
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 16*32*nsg)*(sizeof(float)/2);
3168+
const size_t smem = FATTN_SMEM(nsg);
31653169

31663170
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
31673171
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
3168-
3169-
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
3170-
3172+
[encoder setThreadgroupMemoryLength:smem atIndex:0];
3173+
#undef FATTN_SMEM
31713174
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
31723175
} else {
31733176
// half4x4 kernel
@@ -3178,21 +3181,41 @@ static void ggml_metal_encode_node(
31783181
GGML_ASSERT(nqptg % 1 == 0);
31793182
GGML_ASSERT(ncpsg % 32 == 0);
31803183

3184+
// ne00 + 2*ncpsg*(nsg)
3185+
// for each query, we load it as f16 in shared memory (ne00)
3186+
// and store the attention scores (nqptg x ncpsg) as f32
3187+
//
3188+
// 2*ne00*(nsg)
3189+
// each simdgroup has a full f32 head vector in shared mem to accumulate results
3190+
//
3191+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
3192+
3193+
int64_t nsgmax = 2;
3194+
3195+
while (true) {
3196+
const size_t smem = FATTN_SMEM(nsgmax);
3197+
if (smem > device.maxThreadgroupMemoryLength) {
3198+
break;
3199+
}
3200+
nsgmax *= 2;
3201+
}
3202+
nsgmax /= 2;
3203+
31813204
// simdgroups per threadgroup (a.k.a. warps)
3182-
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
3205+
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
31833206

31843207
int64_t nsg = 1;
31853208
while (nsg <= nsgt) {
31863209
nsg *= 2;
31873210
}
31883211
nsg /= 2;
31893212

3190-
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + 2*nsg*ne00)*(sizeof(float)/2);
3213+
const size_t smem = FATTN_SMEM(nsg);
31913214

3192-
//printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
3215+
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
31933216
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
3194-
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
3195-
3217+
[encoder setThreadgroupMemoryLength:smem atIndex:0];
3218+
#undef FATTN_SMEM
31963219
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
31973220
}
31983221
} break;

ggml/src/ggml-metal.metal

+5-2
Original file line numberDiff line numberDiff line change
@@ -2876,6 +2876,7 @@ kernel void kernel_flash_attn_ext(
28762876
for (short cc = 0; cc < C/8; ++cc) {
28772877
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
28782878

2879+
// this is compile-time check, so it does not have runtime overhead
28792880
if (is_same<block_q, half4x4>::value) {
28802881
// we can read directly from global memory
28812882
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
@@ -2891,6 +2892,7 @@ kernel void kernel_flash_attn_ext(
28912892
device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
28922893

28932894
if (D16%4 == 0) {
2895+
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
28942896
half4x4 tmp;
28952897
dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
28962898
skv4[4*ty + tx] = tmp;
@@ -3009,6 +3011,7 @@ kernel void kernel_flash_attn_ext(
30093011
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
30103012

30113013
if (D16%4 == 0) {
3014+
// no need for bound checks
30123015
half4x4 tmp;
30133016
dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
30143017
skv4[4*ty + tx] = tmp;
@@ -3233,14 +3236,14 @@ kernel void kernel_flash_attn_ext_vec(
32333236
const short D16 = D/16;
32343237
const short NW = N_SIMDWIDTH;
32353238
const short NW4 = NW/4;
3236-
const short SH = (C + Q); // shared memory per simdgroup in (half)
3239+
const short SH = C; // shared memory per simdgroup in (half)
32373240

32383241
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
32393242

32403243
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
32413244
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
32423245
threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
3243-
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
3246+
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
32443247
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in float4
32453248
threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
32463249

0 commit comments

Comments
 (0)