Commit 478dbc9

CUDA: fix Pascal FA, deq. KV to FP16 for batch > 8
1 parent: 9b59641

7 files changed (+73 -27 lines)

ggml-cuda/fattn-common.cuh (+59 -15)

@@ -1,6 +1,7 @@
 #pragma once
 
 #include "common.cuh"
+#include "convert.cuh"
 #include "vecdotq.cuh"
 
 #include <cstdint>
@@ -53,7 +54,7 @@ typedef float (*vec_dot_KQ_f32_t)(
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -95,13 +96,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
     GGML_UNUSED(Q_v);
@@ -147,13 +148,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -202,13 +203,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
     GGML_UNUSED(Q_v);
@@ -261,13 +262,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -302,7 +303,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
@@ -620,7 +621,10 @@ static void on_no_fattn_vec_case(const int D) {
 }
 
 template <int D, int parallel_blocks>
-void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
+void launch_fattn(
+        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
+        const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -641,9 +645,49 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
 
+    ggml_cuda_pool_alloc<half> K_f16(pool);
+    ggml_cuda_pool_alloc<half> V_f16(pool);
     ggml_cuda_pool_alloc<float> dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
+    char * K_data = (char *) K->data;
+    size_t nb11 = K->nb[1];
+    size_t nb12 = K->nb[2];
+    size_t nb13 = K->nb[3];
+
+    char * V_data = (char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
+
+    if (need_f16_K && K->type != GGML_TYPE_F16) {
+        K_f16.alloc(ggml_nelements(K));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+        K_data = (char *) K_f16.ptr;
+
+        const size_t bs = ggml_blck_size(K->type);
+        const size_t ts = ggml_type_size(K->type);
+
+        nb11 = nb11*bs*sizeof(half)/ts;
+        nb12 = nb12*bs*sizeof(half)/ts;
+        nb13 = nb13*bs*sizeof(half)/ts;
+    }
+
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
+        V_f16.alloc(ggml_nelements(V));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+        V_data = (char *) V_f16.ptr;
+
+        const size_t bs = ggml_blck_size(V->type);
+        const size_t ts = ggml_type_size(V->type);
+
+        nb21 = nb21*bs*sizeof(half)/ts;
+        nb22 = nb22*bs*sizeof(half)/ts;
+        nb23 = nb23*bs*sizeof(half)/ts;
+    }
+
     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
         dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
@@ -667,17 +711,17 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
 
     fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
         (const char *) Q->data,
-        (const char *) K->data,
-        (const char *) V->data,
+        K_data,
+        V_data,
         mask ? ((const char *) mask->data) : nullptr,
         (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
         mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
-        K->nb[1], K->nb[2], K->nb[3],
-        V->nb[1], V->nb[2], V->nb[3],
+        nb11, nb12, nb13,
+        nb21, nb22, nb23,
         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
     );
     CUDA_CHECK(cudaGetLastError());
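
Note on the `>` to `>=` changes above: MIN_CC_DP4A names the minimum compute capability that provides the __dp4a intrinsic, and the consumer Pascal chips sit exactly at that minimum, so the old strict comparison compiled the NO_DEVICE_CODE fallback for them; this is the "fix Pascal FA" part of the commit title. A minimal standalone illustration, assuming MIN_CC_DP4A == 610 (compute capability 6.1), as defined in ggml-cuda/common.cuh:

    // Illustration only; MIN_CC_DP4A == 610 is an assumption about common.cuh.
    #define MIN_CC_DP4A 610

    static_assert(!(610 >  MIN_CC_DP4A), "old check: a CC 6.1 (Pascal) device fails the strict comparison");
    static_assert(  610 >= MIN_CC_DP4A,  "new check: a CC 6.1 (Pascal) device takes the dp4a path");

    int main() { return 0; }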

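The new code in launch_fattn above dequantizes K and/or V into temporary FP16 pool buffers when the kernel requests it (need_f16_K / need_f16_V) and the tensor is not already F16, then rescales the byte strides to the FP16 layout: a row stored as blocks of ts bytes holding bs elements each becomes a row of half values, so every stride is multiplied by bs*sizeof(half)/ts. A small host-side sketch of that arithmetic, using Q8_0-like parameters (block size 32, 34 bytes per block) as an assumed example:

    #include <cstddef>
    #include <cstdio>

    // Rescale the byte stride of a quantized tensor row to the stride of the
    // same row after dequantization to FP16 (sizeof(half) == 2 bytes).
    static std::size_t stride_fp16(std::size_t nb, std::size_t blck_size, std::size_t type_size) {
        return nb*blck_size*2/type_size;
    }

    int main() {
        // One row of 4096 elements stored as 128 blocks of 34 bytes each:
        const std::size_t nb11_quant = 128*34;  // 4352 bytes per row
        const std::size_t nb11_f16   = stride_fp16(nb11_quant, 32, 34);
        std::printf("%zu -> %zu bytes\n", nb11_quant, nb11_f16);  // 4352 -> 8192 = 4096*sizeof(half)
        return 0;
    }
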
ggml-cuda/fattn-tile-f16.cu (+2 -2)

@@ -278,13 +278,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int D = 64;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
             GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");

ggml-cuda/fattn-tile-f32.cu (+2 -2)

@@ -275,13 +275,13 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int D = 64;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
             GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");

ggml-cuda/fattn-vec-f16.cuh (+3 -1)

@@ -290,7 +290,9 @@ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
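
Reading the two constexpr flags added above (the f32 variant below is changed identically): the vec kernels only ask launch_fattn for an FP16 copy when they have no direct path for the stored type, i.e. K can stay in its original (possibly quantized) format only for D == 128, and V for D == 128 and D == 64 (an inference from the flags themselves, not stated elsewhere in the diff). A hypothetical standalone print-out of the flag values:

    #include <cstdio>

    // Same expressions as the flags above, lifted into variable templates.
    template <int D> constexpr bool need_f16_K = D != 128;
    template <int D> constexpr bool need_f16_V = D != 128 && D != 64;

    int main() {
        std::printf("D=64:  need_f16_K=%d need_f16_V=%d\n", need_f16_K<64>,  need_f16_V<64>);  // 1 0
        std::printf("D=128: need_f16_K=%d need_f16_V=%d\n", need_f16_K<128>, need_f16_V<128>); // 0 0
        std::printf("D=256: need_f16_K=%d need_f16_V=%d\n", need_f16_K<256>, need_f16_V<256>); // 1 1
        return 0;
    }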

ggml-cuda/fattn-vec-f32.cuh (+3 -1)

@@ -271,7 +271,9 @@ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>

ggml-cuda/fattn-wmma-f16.cuh (+3 -3)

@@ -438,18 +438,18 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     if (4*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 4;
         fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 2;
         fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     constexpr int parallel_blocks = 1;
     fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
 }
 
 #define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \

ggml-cuda/fattn.cu (+1 -3)

@@ -305,10 +305,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];
 
-    const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type);
-
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
-    if (cc >= CC_OFFSET_AMD || quantized_KV) {
+    if (cc >= CC_OFFSET_AMD) {
         if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
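
The deleted quantized_KV check above is what previously forced every quantized KV cache onto the vec kernels (the only ones that could read quantized data directly). With launch_fattn now able to dequantize K/V to FP16 on demand, the tile/WMMA kernels become reachable for quantized caches as well, which is presumably what "deq. KV to FP16 for batch > 8" in the commit title refers to. A hypothetical condensation of the new behaviour (invented names, not the repo's dispatch code):

    #include <cstdio>

    enum class Kernel { Vec, TileOrWmma };

    // Quantized KV no longer influences the kernel choice; it only determines
    // whether launch_fattn makes an FP16 copy for the chosen kernel.
    static Kernel pick_kernel(bool is_amd, bool small_batch) {
        if (is_amd)      return Kernel::Vec; // tile kernels perform poorly on AMD
        if (small_batch) return Kernel::Vec; // small batches keep using the vec kernels (per the commit title's "batch > 8")
        return Kernel::TileOrWmma;           // KV dequantized to FP16 inside launch_fattn if needed
    }

    int main() {
        // A large batch with a quantized KV cache now reaches the tile/WMMA path:
        std::printf("%d\n", pick_kernel(false, false) == Kernel::TileOrWmma); // 1
        return 0;
    }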
