@@ -269,6 +269,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -300,12 +306,14 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
@@ -585,6 +593,9 @@ @implementation GGMLMetalClass
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
+ GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+         (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+         (int) kernel->pipeline.threadExecutionWidth); \
[metal_function release]; \
if (error) { \
    GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
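The new GGML_LOG_INFO line reports two properties of each compiled pipeline that bound how its kernels can be dispatched. A minimal sketch of how these values are typically read and interpreted (not part of this commit, assuming pipeline is a valid id<MTLComputePipelineState>):

```
// th_max:   per-pipeline upper bound on threads per threadgroup
// th_width: width of one SIMD-group (commonly 32 on Apple GPUs)
const int th_max   = (int) pipeline.maxTotalThreadsPerThreadgroup;
const int th_width = (int) pipeline.threadExecutionWidth;
// threadgroup sizes are normally chosen as a multiple of th_width, capped at th_max
```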
@@ -777,6 +788,12 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && has_bfloat);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && has_bfloat);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && has_bfloat);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && has_bfloat);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && has_bfloat);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && has_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
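Unlike the F16 variants, which only require has_simdgroup_mm, the BF16 kernels are additionally gated on has_bfloat, so on devices without bfloat support no pipeline is created for them. A hypothetical guard, not taken from this commit, that a dispatch site could use under the assumption that a kernel skipped at registration leaves its pipeline as nil:

```
// hypothetical fallback check; assumes an unregistered kernel keeps a nil pipeline
id<MTLComputePipelineState> pipeline =
    ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline;
if (pipeline == nil) {
    GGML_LOG_ERROR("%s: BF16 flash attention is not supported on this device\n", __func__);
    // fall back to the F16 kernel or abort, depending on caller policy
}
```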
@@ -808,12 +825,14 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && has_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && has_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
@@ -1111,7 +1130,7 @@ static void ggml_metal_encode_node(
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
- const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);

const int64_t ne0 = dst ? dst->ne[0] : 0;
const int64_t ne1 = dst ? dst->ne[1] : 0;
@@ -3033,6 +3052,23 @@ static void ggml_metal_encode_node(
}
}
} break;
+ case GGML_TYPE_BF16:
+ {
+ switch (ne00) {
+ case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+ case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+ case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+ case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+ case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+ case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+ default:
+ {
+ GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ GGML_LOG_ERROR("add template specialization for this size\n");
+ GGML_ABORT("add template specialization for this size");
+ }
+ }
+ } break;
case GGML_TYPE_Q4_0:
{
switch (ne00) {
@@ -3133,6 +3169,7 @@ static void ggml_metal_encode_node(
{
switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
@@ -3150,6 +3187,7 @@ static void ggml_metal_encode_node(
{
switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
@@ -3194,18 +3232,15 @@ static void ggml_metal_encode_node(
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
- [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
- [encoder setBytes:&scale length:sizeof( float) atIndex:23];
- [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
- [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
+ [encoder setBytes:&scale length:sizeof( float) atIndex:20];
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
+ [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];

if (!use_vec_kernel) {
// half8x8 kernel
@@ -3216,11 +3251,14 @@ static void ggml_metal_encode_node(
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);

+ // 2*(2*ncpsg + nqptg)*(nsg)
+ // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+ //
// 16*32*(nsg)
// the shared memory needed for the simdgroups to load the KV cache
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
//
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))

int64_t nsgmax = 2;

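The new 2*(2*ncpsg + nqptg) term reserves room per row for both the soft_max values and the mask values plus the diagonal scaling matrix, as the added comment says. A small self-contained check of the formula, using illustrative values (nqptg = 8, ncpsg = 32, nsg = 4, ne00 = 128 are assumptions for the example, not taken from the commit):

```
#include <stdio.h>

// rounds x up to a multiple of n, mirroring ggml's GGML_PAD
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main(void) {
    const int ne00 = 128, nqptg = 8, ncpsg = 32, nsg = 4;  // illustrative values only
    const int smem = GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*nsg) + 16*32*nsg)*(int)(sizeof(float)/2), 16);
    printf("requested threadgroup memory: %d bytes\n", smem);  // 15360 bytes for these values
    return 0;
}
```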
@@ -3254,12 +3292,12 @@ static void ggml_metal_encode_node(

// ne00 + 2*ncpsg*(nsg)
// for each query, we load it as f16 in shared memory (ne00)
- // and store the attention scores (nqptg x ncpsg) as f32
+ // and store the soft_max values and the mask
//
- // 2*ne00*(nsg)
- // each simdgroup has a full f32 head vector in shared mem to accumulate results
+ // ne00*(nsg)
+ // each simdgroup has a full f16 head vector in shared mem to accumulate results
//
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))

int64_t nsgmax = 2;

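The vector-kernel formula drops one of the two per-simdgroup head vectors, since the accumulator is now kept as f16 rather than f32. The same kind of worked check, again with illustrative values (nqptg = 1, ncpsg = 32, nsg = 4, ne00 = 128, all assumptions for the example):

```
// rounds x up to a multiple of n, mirroring ggml's GGML_PAD
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

enum { ne00 = 128, nqptg = 1, ncpsg = 32, nsg = 4 };  // illustrative values only
// *2 stands in for sizeof(float)/2
// new formula: (1*(128 + 256) + 512)*2  = 1792 bytes
enum { SMEM_VEC_NEW = GGML_PAD((nqptg*(ne00 + 2*ncpsg*nsg) + ne00*nsg)*2, 16) };
// old formula: (1*(128 + 256) + 1024)*2 = 2816 bytes
enum { SMEM_VEC_OLD = GGML_PAD((nqptg*(ne00 + 2*ncpsg*nsg) + 2*ne00*nsg)*2, 16) };
```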