
Commit 841f27a

metal : optimize FA kernels (#10171)
* ggml : add ggml_flash_attn_ext_get_prec
* metal : use F16 precision in FA kernels ggml-ci
* metal : minor clean-up
* metal : compile-guard bf16 FA kernels ggml-ci
* build : remove obsolete compile flag [no ci]
* metal : prevent int overflows [no ci]
* cuda : disable BF16 FA ggml-ci
* metal : fix BF16 requirement for FA kernels ggml-ci
* make : clean-up [no ci]
1 parent d05b312 commit 841f27a

File tree

8 files changed, +504 -345 lines changed


examples/llama-bench/llama-bench.cpp

+3
@@ -256,6 +256,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
     if (s == "f16") {
         return GGML_TYPE_F16;
     }
+    if (s == "bf16") {
+        return GGML_TYPE_BF16;
+    }
     if (s == "q8_0") {
         return GGML_TYPE_Q8_0;
     }

ggml/include/ggml.h

+3
@@ -1746,6 +1746,9 @@ extern "C" {
             struct ggml_tensor * a,
             enum ggml_prec prec);
 
+    GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
+            const struct ggml_tensor * a);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
ggml/src/ggml-cuda.cu

+3
@@ -3159,6 +3159,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
 #endif
+            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
+                return false;
+            }
             if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
                 return true;
             }
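
In effect, the CUDA device now reports flash-attention as unsupported whenever K or V is stored as BF16, so the scheduler can fall back to another backend for that node. A minimal standalone sketch of the rule this hunk adds (the helper name is made up for illustration, not repository code):

    #include <stdbool.h>
    #include "ggml.h"

    // CUDA FA handles F16 and quantized K/V, but not BF16 (see the check above)
    static bool cuda_fa_supports_kv(enum ggml_type type_k, enum ggml_type type_v) {
        return type_k != GGML_TYPE_BF16 && type_v != GGML_TYPE_BF16;
    }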

ggml/src/ggml-cuda/fattn.cu

+5 -5
@@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
 
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (precision != GGML_PREC_DEFAULT) {
+    if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
             switch (Q->ne[0]) {
@@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
-        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
@@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
-        if (precision == GGML_PREC_DEFAULT) {
+        if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
         } else if(Q->ne[0] <= 128) {
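
The precision that ggml_flash_attn_ext_get_prec() reads back here is whatever was requested at graph-build time. A hypothetical helper (function name and parameter values are placeholders) showing the two halves of that contract:

    #include "ggml.h"

    // build an FA node and request F32 accumulation; the dispatch code above then
    // selects the F32 vec/WMMA path instead of the default F16 one
    static struct ggml_tensor * build_fa_f32(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,
            struct ggml_tensor  * k,
            struct ggml_tensor  * v,
            struct ggml_tensor  * mask,
            float                 scale) {
        struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, /*max_bias =*/ 0.0f, /*logit_softcap =*/ 0.0f);

        // stored in op_params[3]; backends read it via ggml_flash_attn_ext_get_prec()
        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

        return cur;
    }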

ggml/src/ggml-metal.m

+56 -18
@@ -269,6 +269,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -300,12 +306,14 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
@@ -585,6 +593,9 @@ @implementation GGMLMetalClass
         struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
         id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
         kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
+        GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+                (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                (int) kernel->pipeline.threadExecutionWidth); \
         [metal_function release]; \
         if (error) { \
             GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
@@ -777,6 +788,12 @@ @implementation GGMLMetalClass
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && has_bfloat);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && has_bfloat);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && has_bfloat);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && has_bfloat);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && has_bfloat);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && has_bfloat);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
@@ -808,12 +825,14 @@ @implementation GGMLMetalClass
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && has_bfloat);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && has_bfloat);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
@@ -1111,7 +1130,7 @@ static void ggml_metal_encode_node(
     const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
     const uint64_t nb21 = src2 ? src2->nb[1] : 0;
     const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+    const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
 
     const int64_t ne0 = dst ? dst->ne[0] : 0;
     const int64_t ne1 = dst ? dst->ne[1] : 0;
@@ -3033,6 +3052,23 @@ static void ggml_metal_encode_node(
                             }
                         }
                     } break;
+                case GGML_TYPE_BF16:
+                    {
+                        switch (ne00) {
+                            case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+                            case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+                            case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+                            case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+                            case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+                            case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+                            default:
+                                {
+                                    GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                    GGML_LOG_ERROR("add template specialization for this size\n");
+                                    GGML_ABORT("add template specialization for this size");
+                                }
+                        }
+                    } break;
                 case GGML_TYPE_Q4_0:
                     {
                         switch (ne00) {
@@ -3133,6 +3169,7 @@ static void ggml_metal_encode_node(
                     {
                         switch (src1->type) {
                             case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                            case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
                             case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
                             case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
                             case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
@@ -3150,6 +3187,7 @@ static void ggml_metal_encode_node(
                     {
                         switch (src1->type) {
                             case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                            case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
                             case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
                             case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
                             case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
@@ -3194,18 +3232,15 @@ static void ggml_metal_encode_node(
                 [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
                 [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
                 [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
-                [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
-                [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
-                [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
-                [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
-                [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
-                [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
-                [encoder setBytes:&scale length:sizeof( float) atIndex:23];
-                [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
-                [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
-                [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
-                [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
-                [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
+                [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
+                [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
+                [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
+                [encoder setBytes:&scale length:sizeof( float) atIndex:20];
+                [encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
+                [encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
+                [encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
+                [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
+                [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
 
                 if (!use_vec_kernel) {
                     // half8x8 kernel
@@ -3216,11 +3251,14 @@ static void ggml_metal_encode_node(
                     GGML_ASSERT(nqptg % 8 == 0);
                     GGML_ASSERT(ncpsg % 32 == 0);
 
+                    // 2*(2*ncpsg + nqptg)*(nsg)
+                    // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+                    //
                     // 16*32*(nsg)
                     // the shared memory needed for the simdgroups to load the KV cache
                     // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
 
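To get a feel for the numbers, the updated macro can be evaluated on its own. The sketch below is standalone: nqptg = 8 and ncpsg = 32 are consistent with the asserts above, while ne00 = 128 and the nsg values are illustrative assumptions; GGML_PAD is copied from ggml.h.

    #include <stdio.h>

    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
    #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))

    int main(void) {
        const size_t ne00  = 128; // head size
        const size_t nqptg = 8;   // queries processed per threadgroup
        const size_t ncpsg = 32;  // cache values processed per simdgroup

        for (size_t nsg = 1; nsg <= 4; nsg *= 2) {
            // e.g. nsg = 4: (8*(128 + 2*72*4) + 16*32*4)*2 = 15360 bytes
            printf("nsg = %zu -> %zu bytes of threadgroup memory\n", nsg, FATTN_SMEM(nsg));
        }
        return 0;
    }
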
@@ -3254,12 +3292,12 @@ static void ggml_metal_encode_node(
 
                     // ne00 + 2*ncpsg*(nsg)
                     // for each query, we load it as f16 in shared memory (ne00)
-                    // and store the attention scores (nqptg x ncpsg) as f32
+                    // and store the soft_max values and the mask
                     //
-                    // 2*ne00*(nsg)
-                    // each simdgroup has a full f32 head vector in shared mem to accumulate results
+                    // ne00*(nsg)
+                    // each simdgroup has a full f16 head vector in shared mem to accumulate results
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
 
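The saving from keeping the accumulators in f16 is the halving of the per-simdgroup accumulator term, from 2*ne00*(nsg) to ne00*(nsg) elements. A standalone comparison of the old and new formulas (ne00 = 128, nqptg = 1, ncpsg = 32, nsg = 4 are illustrative assumptions; GGML_PAD copied from ggml.h):

    #include <stdio.h>

    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    int main(void) {
        const size_t ne00 = 128, nqptg = 1, ncpsg = 32, nsg = 4;

        const size_t smem_old = GGML_PAD((nqptg*(ne00 + 2*ncpsg*nsg) + 2*ne00*nsg)*(sizeof(float)/2), 16);
        const size_t smem_new = GGML_PAD((nqptg*(ne00 + 2*ncpsg*nsg) +   ne00*nsg)*(sizeof(float)/2), 16);

        printf("vec kernel threadgroup memory: %zu -> %zu bytes\n", smem_old, smem_new); // 2816 -> 1792
        return 0;
    }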