@@ -54,17 +54,8 @@ static __global__ void flash_attn_tile_ext_f16(
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    half slopeh = __float2half(1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
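For reference, the inline ALiBi branch removed above defines what the new get_alibi_slope() helper has to compute. A minimal sketch under that assumption follows; the actual helper is defined in the shared ggml-cuda common header and may differ in signature and qualifiers.

    // Hypothetical sketch of get_alibi_slope(), reconstructed from the inline
    // ALiBi code removed in the hunk above; the real helper may differ.
    static __device__ __forceinline__ float get_alibi_slope(
            const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1) {
        // No ALiBi requested -> neutral slope, matching the old default of __float2half(1.0f).
        if (max_bias <= 0.0f) {
            return 1.0f;
        }
        // Same per-head slope as before: base^exponent, with m0/m1 split at n_head_log2.
        const float base = h < n_head_log2 ? m0 : m1;
        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
        return powf(base, exph);
    }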
@@ -272,124 +263,50 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f16(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int nwarps = 8;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                (const char *) Q->data,
-                (const char *) K->data,
-                (const char *) V->data,
-                mask ? ((const char *) mask->data) : nullptr,
-                parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale, max_bias, m0, m1, n_head_log2,
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                Q->nb[1], Q->nb[2], Q->nb[3],
-                K->nb[1], K->nb[2], K->nb[3],
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-        );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case 64: {
+            constexpr int D      = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D      = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
     }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
 }
 
 void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
        return;
    }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
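The per-call boilerplate removed above (temporary-buffer allocation, grid/block setup, the <<<...>>> kernel launch, and the flash_attn_combine_results pass for parallel_blocks > 1) is presumably now handled by the shared launch_fattn() helper, which receives the kernel through a fattn_kernel_t pointer. A plausible shape for that pointer type, inferred only from the argument list of the launch removed above; the real typedef lives in the shared fattn header and may differ in parameter names, types, and qualifiers.

    // Hypothetical fattn_kernel_t, inferred from the removed kernel launch arguments.
    typedef void (* fattn_kernel_t)(
            const char * Q, const char * K, const char * V, const char * mask,
            float * dst, float2 * dst_meta,
            float scale, float max_bias, float m0, float m1, uint32_t n_head_log2,
            int ne00, int ne01, int ne02, int ne03,  // Q->ne[0..3]
            int ne10, int ne11, int ne12, int ne13,  // K->ne[0..3]
            int ne31, int nb31,                      // mask->ne[1], mask->nb[1]
            int nb01, int nb02, int nb03,            // Q->nb[1..3]
            int nb11, int nb12, int nb13,            // K->nb[1..3]
            int ne0, int ne1, int ne2, int ne3);     // KQV->ne[0..3]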