Commit 133d99c

CUDA: deduplicate FlashAttention code (#7352)
1 parent cb42c29 commit 133d99c

8 files changed: 315 additions, 653 deletions

ggml-cuda/common.cuh

Lines changed: 11 additions & 0 deletions
@@ -477,6 +477,17 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
 
+static __device__ __forceinline__ float get_alibi_slope(
+    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
+) {
+    if (max_bias <= 0.0f) {
+        return 1.0f;
+    }
+    const float base = h < n_head_log2 ? m0 : m1;
+    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+    return powf(base, exph);
+}
 
 //////////////////////
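
The new get_alibi_slope helper factors out the ALiBi slope computation that each FlashAttention kernel previously carried inline: head h gets slope m0^(h+1) when h < n_head_log2 and m1^(2*(h - n_head_log2) + 1) otherwise, with max_bias <= 0 disabling ALiBi (slope 1). For reference, a minimal host-side sketch that mirrors the device helper together with the m0/m1 setup performed by launch_fattn below; it is not part of the commit, and alibi_slope_ref plus the sample n_head/max_bias values are assumptions for illustration only.

// Host-side reference sketch (hypothetical, not in the commit): mirrors
// get_alibi_slope() and the m0/m1 computation done by launch_fattn().
#include <cmath>
#include <cstdint>
#include <cstdio>

static float alibi_slope_ref(float max_bias, uint32_t h, uint32_t n_head_log2, float m0, float m1) {
    if (max_bias <= 0.0f) {
        return 1.0f; // ALiBi disabled -> neutral slope
    }
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
    return powf(base, exph);
}

int main() {
    const uint32_t n_head      = 32;   // example head count (assumption)
    const float    max_bias    = 8.0f; // example ALiBi max bias (assumption)
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    // Same m0/m1 setup as launch_fattn in ggml-cuda/fattn-common.cuh.
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (uint32_t h = 0; h < n_head; ++h) {
        printf("head %2u: slope %.6f\n", (unsigned) h, alibi_slope_ref(max_bias, h, n_head_log2, m0, m1));
    }
    return 0;
}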

ggml-cuda/fattn-common.cuh

Lines changed: 115 additions & 0 deletions
@@ -1,7 +1,44 @@
+#include "common.cuh"
+
+#include <cstdint>
+
 #define FATTN_KQ_STRIDE 256
 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
 
+typedef void (* fattn_kernel_t)(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3);
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -45,3 +82,81 @@ static __global__ void flash_attn_combine_results(
 
     dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
 }
+
+template <int D, int parallel_blocks>
+void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    const ggml_tensor * mask = dst->src[3];
+
+    ggml_tensor * KQV = dst;
+
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(K->type == GGML_TYPE_F16);
+    GGML_ASSERT(V->type == GGML_TYPE_F16);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t main_stream = ctx.stream();
+
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+    if (parallel_blocks > 1) {
+        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+    }
+
+    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
+    const int  shmem = 0;
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+        (const char *) Q->data,
+        (const char *) K->data,
+        (const char *) V->data,
+        mask ? ((const char *) mask->data) : nullptr,
+        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        scale, max_bias, m0, m1, n_head_log2,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+        Q->nb[1], Q->nb[2], Q->nb[3],
+        K->nb[1], K->nb[2], K->nb[3],
+        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+    );
+    CUDA_CHECK(cudaGetLastError());
+
+    if ((parallel_blocks) == 1) {
+        return;
+    }
+
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int  shmem_combine = 0;
+
+    flash_attn_combine_results<D, parallel_blocks>
+        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    CUDA_CHECK(cudaGetLastError());
+}
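
With fattn_kernel_t and launch_fattn<D, parallel_blocks> in the shared header, a kernel variant only needs to instantiate its __global__ template and hand it to the launcher; the type checks, ALiBi parameters (scale, max_bias, m0, m1, n_head_log2), grid setup, and the combine pass for parallel_blocks > 1 now live in one place. A minimal sketch of the expected calling pattern, assuming a hypothetical kernel flash_attn_ext_f16_example that matches the fattn_kernel_t signature; the D/cols_per_block/nwarps/parallel_blocks values are illustrative, not prescribed by the commit.

// Hypothetical host-side dispatch (not in the commit), mirroring how
// ggml-cuda/fattn-tile-f16.cu hooks into the shared launcher below.
static void ggml_cuda_flash_attn_ext_example(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    constexpr int D               = 128; // head size this instantiation handles
    constexpr int cols_per_block  = 32;  // query columns processed per CUDA block
    constexpr int parallel_blocks = 1;   // >1 splits the KV dimension and triggers the combine pass
    constexpr int nwarps          = 8;   // warps per block, must match the kernel's assumptions

    // flash_attn_ext_f16_example is a placeholder __global__ kernel with the
    // fattn_kernel_t signature declared above.
    fattn_kernel_t fattn_kernel = flash_attn_ext_f16_example<D, cols_per_block, nwarps, parallel_blocks>;
    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
}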

ggml-cuda/fattn-tile-f16.cu

Lines changed: 26 additions & 109 deletions
@@ -54,17 +54,8 @@ static __global__ void flash_attn_tile_ext_f16(
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    half slopeh = __float2half(1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
@@ -272,124 +263,50 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f16(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int nwarps = 8;
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-            (const char *) Q->data,
-            (const char *) K->data,
-            (const char *) V->data,
-            mask ? ((const char *) mask->data) : nullptr,
-            parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-            scale, max_bias, m0, m1, n_head_log2,
-            Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            Q->nb[1], Q->nb[2], Q->nb[3],
-            K->nb[1], K->nb[2], K->nb[3],
-            KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-        );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case 64: {
+            constexpr int D      = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int D      = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
     }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
 }
 
 void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
        return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
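
The dispatcher keeps the previous tuning: up to 16 query columns use cols_per_block = 16 with parallel_blocks = 4, up to 32 columns use 32/4, and anything larger uses 32/1 so the combine pass is skipped. Below is a small sketch of the grid arithmetic that launch_fattn derives from these choices; the example tensor shape is an assumption, not taken from the commit.

// Illustrative arithmetic only (not in the commit): how launch_fattn sizes its grid.
#include <cstdio>

int main() {
    // Assumed example: single-token decode with 32 heads, batch size 1.
    const int n_queries       = 1;   // Q->ne[1]
    const int n_head          = 32;  // Q->ne[2]
    const int n_batch         = 1;   // Q->ne[3]
    const int cols_per_block  = 16;  // chosen because n_queries <= 16
    const int parallel_blocks = 4;

    // Same formula as the blocks_num computation in launch_fattn.
    const int blocks_x = parallel_blocks*((n_queries + cols_per_block - 1) / cols_per_block);
    printf("grid = (%d, %d, %d), block = (WARP_SIZE, nwarps=8, 1)\n", blocks_x, n_head, n_batch);
    // -> grid = (4, 32, 1); the 4 x-blocks each process a slice of the KV cache and
    //    their partial results are reduced afterwards by flash_attn_combine_results.
    return 0;
}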
