Commit 0cf6725
CUDA: FA support for Deepseek (Ampere or newer) (#13306)
* CUDA: FA support for Deepseek (Ampere or newer)
* do loop unrolling via C++ template
1 parent 27ebfca commit 0cf6725

33 files changed: +826, -521 lines

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ if (CUDAToolkit_FOUND)
 
     set(CUDA_CXX_FLAGS "")
 
-    set(CUDA_FLAGS -use_fast_math)
+    set(CUDA_FLAGS -use_fast_math -extended-lambda)
 
     if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
         # Options are:
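The only functional change here is the new -extended-lambda flag, which lets nvcc accept __device__ (and __host__ __device__) annotations on lambdas. A minimal stand-alone illustration of what the flag enables; the file name and identifiers below are made up and unrelated to this commit's kernels:

// extended_lambda_demo.cu -- hypothetical example, compile with e.g.:
//   nvcc -use_fast_math -extended-lambda extended_lambda_demo.cu
#include <cstdio>

template <typename Func>
__global__ void apply_kernel(Func f, int * out) {
    *out = f((int) threadIdx.x);
}

int main() {
    int * d_out = nullptr;
    cudaMalloc(&d_out, sizeof(int));

    // A __device__-annotated lambda defined in host code is an "extended lambda";
    // nvcc rejects it unless -extended-lambda (or --expt-extended-lambda) is passed.
    auto square = [] __device__ (int i) { return i * i; };
    apply_kernel<<<1, 1>>>(square, d_out);

    int h_out = -1;
    cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
    printf("square(0) = %d\n", h_out);
    cudaFree(d_out);
    return 0;
}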

ggml/src/ggml-cuda/common.cuh

Lines changed: 19 additions & 0 deletions
@@ -296,6 +296,25 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+// The compiler is not always able to unroll loops if they contain continue expressions.
+// In such cases loop unrolling can still be achieved via recursion:
+template <int n>
+struct ggml_cuda_unroll {
+    template <typename Func, typename... Args>
+    __device__ void operator()(const Func & f, Args... args) const {
+        f(n - 1, args...);
+        ggml_cuda_unroll<n - 1>{}(f, args...);
+    }
+};
+
+template <>
+struct ggml_cuda_unroll<1> {
+    template <typename Func, typename... Args>
+    __device__ void operator()(const Func & f, Args... args) const {
+        f(0, args...);
+    }
+};
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
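A hypothetical usage sketch for the new helper (not a kernel from this commit; the kernel, identifiers, and loop body are invented for illustration). An early return in the lambda plays the role of continue, and the template recursion guarantees the "loop" is fully unrolled:

// Hypothetical illustration of ggml_cuda_unroll; assumes common.cuh is on the include path.
#include "common.cuh"

template <int nstages>
static __global__ void unroll_example(float * __restrict__ dst, const float * __restrict__ src, const int skip) {
    float acc = 0.0f;

    // Same iteration space as:
    //   for (int i = 0; i < nstages; ++i) { if (i == skip) continue; acc += ...; }
    // but expressed as template recursion (indices are visited in descending order,
    // which does not matter for this sum), so no unrolling heuristics are involved.
    auto body = [&](const int i) {
        if (i == skip) {
            return; // takes the role of `continue`
        }
        acc += src[blockIdx.x*nstages + i];
    };
    ggml_cuda_unroll<nstages>{}(body);

    if (threadIdx.x == 0) {
        dst[blockIdx.x] = acc;
    }
}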

ggml/src/ggml-cuda/cp-async.cuh

Lines changed: 11 additions & 0 deletions
@@ -2,6 +2,17 @@
 
 #include "common.cuh"
 
+
+static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
+#ifdef CP_ASYNC_AVAILABLE
+    return __cvta_generic_to_shared(generic_ptr);
+#else
+    GGML_UNUSED(generic_ptr);
+    NO_DEVICE_CODE;
+    return 0;
+#endif // CP_ASYNC_AVAILABLE
+}
+
 // Copies data from global to shared memory, cg == cache global.
 // Both the src and dst pointers must be aligned to 16 bit.
 // Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
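A hypothetical sketch of how the new wrapper is meant to be used (not code from this commit): convert the generic shared-memory pointer once, then hand the 32-bit address to the asynchronous copy helpers. The cp_async_cg_16<preload> and cp_async_wait_all() calls refer to helpers assumed from the rest of cp-async.cuh; the tile layout and identifiers are invented:

// Hypothetical sketch: stage one row of K (D half values) into shared memory with cp.async.
#include "cp-async.cuh"

template <int D>
static __device__ void stage_K_row(half2 * __restrict__ tile, const half2 * __restrict__ K_row) {
    // cp.async addresses shared memory with 32 bit offsets, so the generic pointer
    // (assumed 16-byte aligned) is converted once up front via the new wrapper:
    const unsigned int tile_32 = ggml_cuda_cvta_generic_to_shared(tile);

    constexpr int h2_per_cpy = 16/sizeof(half2); // 16 bytes = 4 half2 = 8 half values per copy

    // The threads of the block issue strided 16-byte asynchronous copies.
    for (int k0 = threadIdx.x*h2_per_cpy; k0 < D/2; k0 += blockDim.x*h2_per_cpy) {
        cp_async_cg_16<0>(tile_32 + k0*sizeof(half2), K_row + k0);
    }
    cp_async_wait_all(); // wait until the copies have landed in shared memory
}

On devices without cp.async (CP_ASYNC_AVAILABLE undefined, i.e. pre-Ampere), the wrapper compiles to NO_DEVICE_CODE, so callers are expected to be compiled only for Ampere or newer.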

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 13 additions & 13 deletions
@@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
+template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
@@ -665,13 +665,13 @@ static void on_no_fattn_vec_case(const int D) {
         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {
-        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
         fprintf(stderr, "Only f16 is supported.\n");
         GGML_ABORT("fatal error");
     }
 }
 
-template <int D, int ncols1, int ncols2, int KQ_stride>
+template <int DV, int ncols1, int ncols2>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
         const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
@@ -691,7 +691,7 @@ void launch_fattn(
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
@@ -754,10 +754,13 @@
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(warp_size, nwarps, 1);
+    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
     dim3 blocks_num;
     if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int max_blocks = 2*nsm;
+        const int max_blocks = max_blocks_per_sm*nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
         const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
@@ -769,14 +772,11 @@
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
         GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
         const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
-        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
-        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
-
         // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
         parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
 
@@ -853,19 +853,19 @@
 
     if (stream_k) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
-            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 block_dim_combine(DV, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
+            flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
-        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 block_dim_combine(DV, 1, 1);
         const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
-        flash_attn_combine_results<D>
+        flash_attn_combine_results<DV>
             <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
             (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
     }
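One practical effect of the launch_fattn changes above: the stream-k branch now sizes its grid from the occupancy query (max_blocks_per_sm*nsm) instead of the previous hard-coded 2*nsm, which keeps the wave count realistic when a kernel only fits one block per SM. A stand-alone sketch of the wave/efficiency arithmetic with assumed numbers (108 SMs, 1 block per SM reported by the occupancy query, 300 tiles); none of these values come from an actual run:

// Stand-alone sketch of the tile/wave arithmetic in launch_fattn; all numbers are assumptions.
#include <cstdio>

int main() {
    const int nsm               = 108; // assumed SM count
    const int max_blocks_per_sm = 1;   // assumed result of cudaOccupancyMaxActiveBlocksPerMultiprocessor
    const int ntiles_total      = 300; // assumed total number of tiles for the given shapes

    const int max_blocks               = max_blocks_per_sm*nsm;                          // 108 (the old 2*nsm would give 216)
    const int tiles_nwaves             = (ntiles_total + max_blocks - 1) / max_blocks;   // 3 waves
    const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); // 92

    printf("max_blocks=%d tiles_nwaves=%d tiles_efficiency_percent=%d\n",
           max_blocks, tiles_nwaves, tiles_efficiency_percent);
    return 0;
}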
