@@ -269,6 +269,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -300,12 +306,14 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
@@ -585,6 +593,9 @@ @implementation GGMLMetalClass
         struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
         id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_" #name]; \
         kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
+        GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_" #name, (void *) kernel->pipeline, \
+                (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                (int) kernel->pipeline.threadExecutionWidth); \
         [metal_function release]; \
         if (error) { \
             GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
@@ -777,6 +788,12 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,  flash_attn_ext_f16_h112,  has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,  flash_attn_ext_f16_h128,  has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,  flash_attn_ext_f16_h256,  has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,  flash_attn_ext_bf16_h64,  has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,  flash_attn_ext_bf16_h80,  has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,  flash_attn_ext_bf16_h96,  has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && has_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,  flash_attn_ext_q4_0_h64,  has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,  flash_attn_ext_q4_0_h80,  has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,  flash_attn_ext_q4_0_h96,  has_simdgroup_mm);
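
The BF16 variants compile only when the device reports both simdgroup matrix support and native bfloat support. The `has_bfloat` flag is computed earlier during backend init, outside this diff; a minimal sketch of such a capability check, where the exact GPU-family cutoff is an assumption of this sketch and not a quote of the repository code:

    // minimal sketch, assuming bfloat is tied to Metal 3 class / newer Apple
    // GPU families; the real flag is set elsewhere in ggml-metal.m
    BOOL has_bfloat = [device supportsFamily:MTLGPUFamilyMetal3]
                   || [device supportsFamily:MTLGPUFamilyApple6];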
@@ -808,12 +825,14 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,     flash_attn_ext_q8_0_h128,     has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,     flash_attn_ext_q8_0_h256,     has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,  flash_attn_ext_vec_f16_h128,  has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && has_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,  flash_attn_ext_vec_f16_h256,  has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && has_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
@@ -1111,7 +1130,7 @@ static void ggml_metal_encode_node(
     const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
     const uint64_t nb21 = src2 ? src2->nb[1] : 0;
     const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+    const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
 
     const int64_t ne0 = dst ? dst->ne[0] : 0;
     const int64_t ne1 = dst ? dst->ne[1] : 0;
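
`nb23` is still computed for readability, but once the encoder stops forwarding the V-cache strides (the -3194 hunk below) it has no remaining reader in this function, so it is cast to void to silence `-Wunused-variable`. `GGML_UNUSED` is the usual void-cast macro from ggml.h:

    #define GGML_UNUSED(x) (void)(x)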
@@ -3033,6 +3052,23 @@ static void ggml_metal_encode_node(
                                         }
                                 }
                             } break;
+                        case GGML_TYPE_BF16:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+                                    default:
+                                        {
+                                            GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                            GGML_LOG_ERROR("add template specialization for this size\n");
+                                            GGML_ABORT("add template specialization for this size");
+                                        }
+                                }
+                            } break;
                         case GGML_TYPE_Q4_0:
                             {
                                 switch (ne00) {
@@ -3133,6 +3169,7 @@ static void ggml_metal_encode_node(
                         {
                             switch (src1->type) {
                                 case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128 ].pipeline; break;
+                                case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
                                 case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
                                 case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
                                 case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
@@ -3150,6 +3187,7 @@ static void ggml_metal_encode_node(
                         {
                             switch (src1->type) {
                                 case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256 ].pipeline; break;
+                                case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
                                 case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
                                 case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
                                 case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
@@ -3194,18 +3232,15 @@ static void ggml_metal_encode_node(
         [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14];
         [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15];
         [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16];
-        [encoder setBytes:&nb21          length:sizeof(uint64_t)      atIndex:17];
-        [encoder setBytes:&nb22          length:sizeof(uint64_t)      atIndex:18];
-        [encoder setBytes:&nb23          length:sizeof(uint64_t)      atIndex:19];
-        [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:20];
-        [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:21];
-        [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:22];
-        [encoder setBytes:&scale         length:sizeof(   float)      atIndex:23];
-        [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:24];
-        [encoder setBytes:&m0            length:sizeof(m0)            atIndex:25];
-        [encoder setBytes:&m1            length:sizeof(m1)            atIndex:26];
-        [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:27];
-        [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
+        [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:17];
+        [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:18];
+        [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:19];
+        [encoder setBytes:&scale         length:sizeof(   float)      atIndex:20];
+        [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:21];
+        [encoder setBytes:&m0            length:sizeof(m0)            atIndex:22];
+        [encoder setBytes:&m1            length:sizeof(m1)            atIndex:23];
+        [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:24];
+        [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
 
         if (!use_vec_kernel) {
             // half8x8 kernel
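
Dropping the three V-stride arguments shifts every later argument down by three slots, so the companion kernel signatures in ggml-metal.metal (not part of this excerpt) must renumber their buffer indices in lockstep. The remap, read directly off the hunk above:

    // argument index remap (old -> new); nb21/nb22/nb23 at 17..19 are dropped
    //   nb31: 20 -> 17    scale:    23 -> 20    m1:            26 -> 23
    //   ne1:  21 -> 18    max_bias: 24 -> 21    n_head_log2:   27 -> 24
    //   ne2:  22 -> 19    m0:       25 -> 22    logit_softcap: 28 -> 25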
@@ -3216,11 +3251,14 @@ static void ggml_metal_encode_node(
             GGML_ASSERT(nqptg % 8  == 0);
             GGML_ASSERT(ncpsg % 32 == 0);
 
+            // 2*(2*ncpsg + nqptg)*(nsg)
+            // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+            //
             // 16*32*(nsg)
             // the shared memory needed for the simdgroups to load the KV cache
             // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
             //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
 
 
             int64_t nsgmax = 2;
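
To sanity-check the enlarged budget, here is a worked instance in plain C, assuming nqptg = 8 and ncpsg = 32 (the smallest values the asserts above permit) with ne00 = 128 and nsg = 4; GGML_PAD is reproduced from ggml.h, and sizeof(float)/2 converts element counts into bytes of half precision:

    #include <stdio.h>

    // round x up to a multiple of n (as defined in ggml.h)
    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    int main(void) {
        const int ne00 = 128, nsg = 4, nqptg = 8, ncpsg = 32; // example sizes only

        // new formula: 2*(2*ncpsg + nqptg)*(nsg) extra half elements per query row
        const int smem = GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*nsg) + 16*32*nsg)*(int)(sizeof(float)/2), 16);

        printf("FATTN_SMEM = %d bytes\n", smem); // prints 15360 for these sizes
        return 0;
    }

The old formula with the same sizes yields 11264 bytes, so the extra mask storage costs 4 KiB of threadgroup memory here.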
@@ -3254,12 +3292,12 @@ static void ggml_metal_encode_node(
 
             // ne00 + 2*ncpsg*(nsg)
             // for each query, we load it as f16 in shared memory (ne00)
-            // and store the attention scores (nqptg x ncpsg) as f32
+            // and store the soft_max values and the mask
             //
-            // 2*ne00*(nsg)
-            // each simdgroup has a full f32 head vector in shared mem to accumulate results
+            // ne00*(nsg)
+            // each simdgroup has a full f16 head vector in shared mem to accumulate results
             //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
 
 
             int64_t nsgmax = 2;
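
The vec path goes the other way: the accumulator drops from f32 to f16, removing the factor of two on the ne00*(nsg) term. A worked comparison, assuming nqptg = 1 and ncpsg = 32 for the vec kernels (an assumption here, since those constants are set outside this excerpt), with ne00 = 128 and nsg = 4:

    #include <stdio.h>

    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) // from ggml.h

    int main(void) {
        const int ne00 = 128, nsg = 4, nqptg = 1, ncpsg = 32; // example sizes only

        const int smem_old = GGML_PAD((nqptg*(ne00 + 2*ncpsg*nsg) + 2*ne00*nsg)*(int)(sizeof(float)/2), 16); // f32 accumulator
        const int smem_new = GGML_PAD((nqptg*(ne00 + 2*ncpsg*nsg) +   ne00*nsg)*(int)(sizeof(float)/2), 16); // f16 accumulator

        printf("old = %d B, new = %d B\n", smem_old, smem_new); // 2816 B vs 1792 B
        return 0;
    }

With these sizes the f16 accumulator saves exactly ne00*nsg half elements, i.e. 1 KiB of threadgroup memory per threadgroup.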