@@ -3145,12 +3145,16 @@ static void ggml_metal_encode_node(
3145
3145
GGML_ASSERT (nqptg % 8 == 0 );
3146
3146
GGML_ASSERT (ncpsg % 32 == 0 );
3147
3147
3148
+ // 16*32*nsgmax
3149
+ // the shared memory needed for the simdgroups to load the KV cache
3150
+ // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
3151
+ //
3152
+ #define FATTN_SMEM (nsg ) (GGML_PAD((nqptg*(ne00 + 2 *(ncpsg + nqptg)*(nsg)) + 16 *32 *(nsg))*(sizeof (float )/2 ), 16 ))
3153
+
3148
3154
int64_t nsgmax = 2 ;
3149
3155
3150
3156
while (true ) {
3151
- // 16*32*nsgmax - the shared memory needed for the simdgroups to load the KV cache
3152
- // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
3153
- const size_t smem = (nqptg*(ne00 + 2 *nsgmax*(ncpsg + nqptg)) + 16 *32 *nsgmax)*(sizeof (float )/2 );
3157
+ const size_t smem = FATTN_SMEM (nsgmax);
3154
3158
if (smem > device.maxThreadgroupMemoryLength ) {
3155
3159
break ;
3156
3160
}
@@ -3161,13 +3165,12 @@ static void ggml_metal_encode_node(
3161
3165
// simdgroups per threadgroup (a.k.a. warps)
3162
3166
const int64_t nsg = ne01 <= nqptg ? MAX (4 , MIN (nsgmax, MIN (ne11/ncpsg, (int64_t ) pipeline.maxTotalThreadsPerThreadgroup /32 ))) : 4 ;
3163
3167
3164
- const size_t smem = (nqptg*(ne00 + 2 * nsg*(ncpsg + nqptg)) + 16 * 32 *nsg)*( sizeof ( float )/ 2 );
3168
+ const size_t smem = FATTN_SMEM ( nsg);
3165
3169
3166
3170
// printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
3167
3171
GGML_ASSERT (smem <= device.maxThreadgroupMemoryLength );
3168
-
3169
- [encoder setThreadgroupMemoryLength: GGML_PAD (smem, 16 ) atIndex: 0 ];
3170
-
3172
+ [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
3173
+ #undef FATTN_SMEM
3171
3174
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
3172
3175
} else {
3173
3176
// half4x4 kernel
@@ -3178,21 +3181,41 @@ static void ggml_metal_encode_node(
3178
3181
GGML_ASSERT (nqptg % 1 == 0 );
3179
3182
GGML_ASSERT (ncpsg % 32 == 0 );
3180
3183
3184
+ // ne00 + 2*ncpsg*(nsg)
3185
+ // for each query, we load it as f16 in shared memory (ne00)
3186
+ // and store the attention scores (nqptg x ncpsg) as f32
3187
+ //
3188
+ // 2*ne00*(nsg)
3189
+ // each simdgroup has a full f32 head vector in shared mem to accumulate results
3190
+ //
3191
+ #define FATTN_SMEM (nsg ) (GGML_PAD((nqptg*(ne00 + 2 *ncpsg*(nsg)) + 2 *ne00*(nsg))*(sizeof (float )/2 ), 16 ))
3192
+
3193
+ int64_t nsgmax = 2 ;
3194
+
3195
+ while (true ) {
3196
+ const size_t smem = FATTN_SMEM (nsgmax);
3197
+ if (smem > device.maxThreadgroupMemoryLength ) {
3198
+ break ;
3199
+ }
3200
+ nsgmax *= 2 ;
3201
+ }
3202
+ nsgmax /= 2 ;
3203
+
3181
3204
// simdgroups per threadgroup (a.k.a. warps)
3182
- const int64_t nsgt = MAX (2 , MIN (ne11/ncpsg, (int64_t ) pipeline.maxTotalThreadsPerThreadgroup /32 ));
3205
+ const int64_t nsgt = MAX (2 , MIN (nsgmax, MIN ( ne11/ncpsg, (int64_t ) pipeline.maxTotalThreadsPerThreadgroup /32 ) ));
3183
3206
3184
3207
int64_t nsg = 1 ;
3185
3208
while (nsg <= nsgt) {
3186
3209
nsg *= 2 ;
3187
3210
}
3188
3211
nsg /= 2 ;
3189
3212
3190
- const size_t smem = (nqptg*(ne00 + 2 * nsg*(ncpsg + nqptg)) + 2 *nsg*ne00)*( sizeof ( float )/ 2 );
3213
+ const size_t smem = FATTN_SMEM ( nsg);
3191
3214
3192
- // printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
3215
+ // printf("smem: %zu, max: %zu, nsg = %d \n", smem, device.maxThreadgroupMemoryLength, (int) nsg );
3193
3216
GGML_ASSERT (smem <= device.maxThreadgroupMemoryLength );
3194
- [encoder setThreadgroupMemoryLength: GGML_PAD ( smem, 16 ) atIndex: 0 ];
3195
-
3217
+ [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
3218
+ # undef FATTN_SMEM
3196
3219
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
3197
3220
}
3198
3221
} break ;
0 commit comments