Fix POD JIT bugs #971

Merged: 3 commits, Mar 25, 2025
csrc/pod.cu (1 change: 0 additions & 1 deletion)
@@ -17,7 +17,6 @@
#include <flashinfer/pos_enc.cuh>
#include <optional>

#include "aot_extension_utils.h"
#include "pod_config.inc"
#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"
csrc/pod_customize_config.jinja (5 changes: 2 additions & 3 deletions)
@@ -35,9 +35,8 @@ using DecodeParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>;

#define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \
USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, ...) \
-DISPATCH_mask_mode(mask_mode_p, MASK_MODE_P, [&] { \
-return DISPATCH_mask_mode(mask_mode_d, MASK_MODE_D, [&] { \
+DISPATCH_MASK_MODE(mask_mode_p, MASK_MODE_P, { \
+DISPATCH_MASK_MODE(mask_mode_d, MASK_MODE_D, { \
__VA_ARGS__(); \
-return true; \
}); \
});
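
The jinja change above swaps the lambda-style DISPATCH_mask_mode (defined in the AOT extension utilities this PR stops including) for the statement-style DISPATCH_MASK_MODE available to JIT builds. As a rough illustration of what a brace-block dispatcher of that shape looks like, here is a minimal, self-contained sketch; the simplified MaskMode enum, the macro body, and the run_pod_kernel helper are stand-ins, not the actual flashinfer definitions.

// dispatch_mask_mode_sketch.cc -- illustrative only, not the real flashinfer
// macro. A runtime MaskMode is mapped to a compile-time constant and the
// caller's brace block (captured by __VA_ARGS__) is expanded in each case, so
// no lambda or `return` plumbing is needed as with the AOT-style macro.
#include <cstdint>
#include <cstdio>

enum class MaskMode : uint8_t { kNone = 0, kCausal = 1 };

#define DISPATCH_MASK_MODE(expr, const_name, ...)        \
  switch (expr) {                                        \
    case MaskMode::kNone: {                              \
      constexpr MaskMode const_name = MaskMode::kNone;   \
      __VA_ARGS__                                        \
      break;                                             \
    }                                                    \
    case MaskMode::kCausal: {                            \
      constexpr MaskMode const_name = MaskMode::kCausal; \
      __VA_ARGS__                                        \
      break;                                             \
    }                                                    \
  }

// Stand-in for the templated attention kernel selected by the dispatch.
template <MaskMode MASK_MODE_P, MaskMode MASK_MODE_D>
void run_pod_kernel() {
  std::printf("prefill mask=%d decode mask=%d\n", (int)MASK_MODE_P, (int)MASK_MODE_D);
}

int main() {
  MaskMode mask_mode_p = MaskMode::kCausal;
  MaskMode mask_mode_d = MaskMode::kNone;
  // Nested dispatch, mirroring the pattern in pod_customize_config.jinja.
  DISPATCH_MASK_MODE(mask_mode_p, MASK_MODE_P, {
    DISPATCH_MASK_MODE(mask_mode_d, MASK_MODE_D, {
      run_pod_kernel<MASK_MODE_P, MASK_MODE_D>();
    });
  });
  return 0;
}

Because the caller's block expands as a plain statement, nested dispatches compose without the lambda and `return true;` scaffolding the old macro required, which is what the two-line simplification in the diff reflects.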
csrc/pod_kernel_inst.jinja (2 changes: 0 additions & 2 deletions)
@@ -10,8 +10,6 @@

#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"
#include "aot_default_additional_params.h"
#include "aot_extension_utils.h"

#include "pod_config.inc"

include/flashinfer/attention/pod.cuh (337 changes: 170 additions & 167 deletions)
@@ -205,6 +205,7 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params,
const uint_fastdiv group_size_fastdiv(group_size);
constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16;
constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16;

uint32_t cta_tile_q_p = 0;
int64_t unpacked_qo_len = qo_len * group_size;
if (unpacked_qo_len > 64 && HEAD_DIM_VO < 256) {
@@ -268,183 +269,185 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params,
NUM_MMA_Q_D * NUM_WARPS_Q_D) /
(2 * NUM_WARPS_KV_D);

// DISPATCH_CTA_TILE_Q(cta_tile_q_p, CTA_TILE_Q_P, {
constexpr size_t CTA_TILE_Q_P = 128;
constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P);
constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P);
constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P);

using DTypeQKAccum_P =
typename std::conditional<USE_FP16_QK_REDUCTION && std::is_same_v<DTypeQ_P, half>, half,
float>::type;

// we expect each sm execute two threadblocks
// TODO(Zihao): fix the following computation
const int num_ctas_per_sm_p =
max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1;
const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p;

constexpr uint32_t max_num_mma_kv_reg_p =
(HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama &&
!USE_FP16_QK_REDUCTION)
? 2
: (8 / NUM_MMA_Q_P);
// TODO(Zihao): fix the following computation
const uint32_t max_num_mma_kv_smem_p =
(max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) -
NUM_MMA_Q_P * NUM_WARPS_Q_P) /
(2 * NUM_WARPS_KV_P);

// control NUM_MMA_KV for maximum warp occupancy
DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, {
using KTraits_P = KernelTraits<MASK_MODE_P, CTA_TILE_Q_P, NUM_MMA_Q_P, NUM_MMA_KV_P,
NUM_MMA_D_QK, NUM_MMA_D_VO, NUM_WARPS_Q_P, NUM_WARPS_KV_P,
POS_ENCODING_MODE, DTypeQ_P, DTypeKV_P, DTypeO_P, DTypeQKAccum_P,
typename PrefillParams::IdType, PrefillAttentionVariant>;

if constexpr (KTraits_P::IsInvalid()) {
// Invalid configuration, skip
std::ostringstream err_msg;
err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_P
<< " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
<< " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P
<< " NUM_WARPS_KV=" << NUM_WARPS_KV_P
<< " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
" and report the issue to the developers.";
FLASHINFER_ERROR(err_msg.str());
} else {
// Decode stuff
// TODO: Is there a way to avoid this nested dispatch?
DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, {
using KTraits_D =
KernelTraits<MASK_MODE_D, CTA_TILE_Q_D, NUM_MMA_Q_D, NUM_MMA_KV_D, NUM_MMA_D_QK,
NUM_MMA_D_VO, NUM_WARPS_Q_D, NUM_WARPS_KV_D, POS_ENCODING_MODE, DTypeQ_D,
DTypeKV_D, DTypeO_D, DTypeQKAccum_D, typename DecodeParams::IdType,
DecodeAttentionVariant>;
if constexpr (KTraits_D::IsInvalid()) {
// Invalid configuration, skip
std::ostringstream err_msg;
err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_D
<< " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
<< " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D
<< " NUM_WARPS_KV=" << NUM_WARPS_KV_D
<< " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
" and report the issue to the developers.";
FLASHINFER_ERROR(err_msg.str());
} else {
// End decode stuff
constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE;
size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage);
size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage);

auto kernel =
PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, true, PrefillParams, DecodeParams>;
// Prefill: decide num_splits for split-kv
int num_blocks_per_sm = 0;
int num_sm = 0;
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm, kernel, num_threads_p, smem_size_p));
uint32_t max_num_kv_chunks =
(num_blocks_per_sm * num_sm) /
(num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q));
uint32_t num_chunks;
if (max_num_kv_chunks > 0) {
uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256);
num_chunks = ceil_div(kv_len, chunk_size);
DISPATCH_CTA_TILE_Q(cta_tile_q_p, CTA_TILE_Q_P, {
constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P);
constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P);
constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P);

using DTypeQKAccum_P =
typename std::conditional<USE_FP16_QK_REDUCTION && std::is_same_v<DTypeQ_P, half>, half,
float>::type;

// we expect each sm execute two threadblocks
// TODO(Zihao): fix the following computation
const int num_ctas_per_sm_p =
max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1;
const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p;

constexpr uint32_t max_num_mma_kv_reg_p =
(HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 &&
POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION)
? 2
: (8 / NUM_MMA_Q_P);
// TODO(Zihao): fix the following computation
const uint32_t max_num_mma_kv_smem_p =
(max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) -
NUM_MMA_Q_P * NUM_WARPS_Q_P) /
(2 * NUM_WARPS_KV_P);

// control NUM_MMA_KV for maximum warp occupancy
DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, {
using KTraits_P =
KernelTraits<MASK_MODE_P, CTA_TILE_Q_P, NUM_MMA_Q_P, NUM_MMA_KV_P, NUM_MMA_D_QK,
NUM_MMA_D_VO, NUM_WARPS_Q_P, NUM_WARPS_KV_P, POS_ENCODING_MODE, DTypeQ_P,
DTypeKV_P, DTypeO_P, DTypeQKAccum_P, typename PrefillParams::IdType,
PrefillAttentionVariant>;

if constexpr (KTraits_P::IsInvalid()) {
// Invalid configuration, skip
std::ostringstream err_msg;
err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_P
<< " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
<< " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P
<< " NUM_WARPS_KV=" << NUM_WARPS_KV_P
<< " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
" and report the issue to the developers.";
FLASHINFER_ERROR(err_msg.str());
} else {
// Decode stuff
// TODO: Is there a way to avoid this nested dispatch?
DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, {
using KTraits_D =
KernelTraits<MASK_MODE_D, CTA_TILE_Q_D, NUM_MMA_Q_D, NUM_MMA_KV_D, NUM_MMA_D_QK,
NUM_MMA_D_VO, NUM_WARPS_Q_D, NUM_WARPS_KV_D, POS_ENCODING_MODE, DTypeQ_D,
DTypeKV_D, DTypeO_D, DTypeQKAccum_D, typename DecodeParams::IdType,
DecodeAttentionVariant>;
if constexpr (KTraits_D::IsInvalid()) {
// Invalid configuration, skip
std::ostringstream err_msg;
err_msg
<< "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_D
<< " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
<< " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D
<< " NUM_WARPS_KV=" << NUM_WARPS_KV_D
<< " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
" and report the issue to the developers.";
FLASHINFER_ERROR(err_msg.str());
} else {
num_chunks = 0;
}
// End decode stuff
constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE;
size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage);
size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage);

// Setup new prefill params if (not) split
auto o_p = prefill_params.o;
auto lse_p = prefill_params.lse;
float* tmp_lse = (float*)(tmp_p + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO);
if (num_chunks <= 1 || tmp_p == nullptr) {
// Enough parallelism, do not split-kv
prefill_params.partition_kv = 0;
kernel = PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, false, PrefillParams,
DecodeParams>;
} else {
// Use cooperative groups to increase occupancy
prefill_params.partition_kv = num_chunks;
prefill_params.o = tmp_p;
prefill_params.lse = tmp_lse;
kernel =
auto kernel =
PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, true, PrefillParams, DecodeParams>;
}
// Prefill: decide num_splits for split-kv
int num_blocks_per_sm = 0;
int num_sm = 0;
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm, kernel, num_threads_p, smem_size_p));
uint32_t max_num_kv_chunks =
(num_blocks_per_sm * num_sm) /
(num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q));
uint32_t num_chunks;
if (max_num_kv_chunks > 0) {
uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256);
num_chunks = ceil_div(kv_len, chunk_size);
} else {
num_chunks = 0;
}

// Setup new decode params if (not) split
auto o_d = decode_params.o;
auto lse_d = decode_params.lse;
if (tmp_v == nullptr) {
// do not partition kv
decode_params.partition_kv = false;
} else {
decode_params.partition_kv = true;
decode_params.o = tmp_v;
decode_params.lse = tmp_s;
}
uint32_t xsize = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q);
int nblks_p(xsize * (prefill_params.partition_kv ? prefill_params.partition_kv : 1) *
num_kv_heads);
int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P);

int nblks_d(padded_batch_size_d * 1 * num_kv_heads);
int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D);

// ******* Select final combined sizes here ******* /
size_t smem_size = max(smem_size_p, smem_size_d);
int nblks = nblks_p + nblks_d;
int nthrs = max(nthrs_p, nthrs_d);

// printf("Smem: prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d,
// smem_size); printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d,
// nblks); printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d, nthrs);
// ************************************************ /

static int* tbAssign = nullptr;
if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2));
cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2));

// Setup kernel arguments
void* args[] = {(void*)&xsize, (void*)&prefill_params, (void*)&decode_params,
(void*)&tbAssign};
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Launch kernel
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));

// Post-kernel stuff for split-kv prefill
if (!(num_chunks <= 1 || tmp_p == nullptr)) {
if constexpr (PrefillAttentionVariant::use_softmax) {
FLASHINFER_CUDA_CALL(MergeStates(tmp_p, tmp_lse, o_p, lse_p, num_chunks, qo_len,
num_qo_heads, HEAD_DIM_VO, stream));
// Setup new prefill params if (not) split
auto o_p = prefill_params.o;
auto lse_p = prefill_params.lse;
float* tmp_lse = (float*)(tmp_p + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO);
if (num_chunks <= 1 || tmp_p == nullptr) {
// Enough parallelism, do not split-kv
prefill_params.partition_kv = 0;
kernel = PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, false, PrefillParams,
DecodeParams>;
} else {
FLASHINFER_CUDA_CALL(
AttentionSum(tmp_p, o_p, num_chunks, qo_len, num_qo_heads, HEAD_DIM_VO, stream));
// Use cooperative groups to increase occupancy
prefill_params.partition_kv = num_chunks;
prefill_params.o = tmp_p;
prefill_params.lse = tmp_lse;
kernel = PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, true, PrefillParams,
DecodeParams>;
}
}
// Post-kernel stuff for split-kv decode
if (tmp_v != nullptr) {
if constexpr (DecodeAttentionVariant::use_softmax) {
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
tmp_v, tmp_s, decode_params.merge_indptr, o_d, lse_d,
decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads,
HEAD_DIM_VO, stream));

// Setup new decode params if (not) split
auto o_d = decode_params.o;
auto lse_d = decode_params.lse;
if (tmp_v == nullptr) {
// do not partition kv
decode_params.partition_kv = false;
} else {
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
tmp_v, decode_params.merge_indptr, o_d, decode_params.max_total_num_rows,
decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream));
decode_params.partition_kv = true;
decode_params.o = tmp_v;
decode_params.lse = tmp_s;
}
uint32_t xsize = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q);
int nblks_p(xsize * (prefill_params.partition_kv ? prefill_params.partition_kv : 1) *
num_kv_heads);
int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P);

int nblks_d(padded_batch_size_d * 1 * num_kv_heads);
int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D);

// ******* Select final combined sizes here ******* /
size_t smem_size = max(smem_size_p, smem_size_d);
int nblks = nblks_p + nblks_d;
int nthrs = max(nthrs_p, nthrs_d);

// printf("Smem: prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d,
// smem_size); printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d,
// nblks); printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d,
// nthrs);
// ************************************************ /

static int* tbAssign = nullptr;
if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2));
cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2));

// Setup kernel arguments
void* args[] = {(void*)&xsize, (void*)&prefill_params, (void*)&decode_params,
(void*)&tbAssign};
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Launch kernel
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));

// Post-kernel stuff for split-kv prefill
if (!(num_chunks <= 1 || tmp_p == nullptr)) {
if constexpr (PrefillAttentionVariant::use_softmax) {
FLASHINFER_CUDA_CALL(MergeStates(tmp_p, tmp_lse, o_p, lse_p, num_chunks, qo_len,
num_qo_heads, HEAD_DIM_VO, stream));
} else {
FLASHINFER_CUDA_CALL(AttentionSum(tmp_p, o_p, num_chunks, qo_len, num_qo_heads,
HEAD_DIM_VO, stream));
}
}
// Post-kernel stuff for split-kv decode
if (tmp_v != nullptr) {
if constexpr (DecodeAttentionVariant::use_softmax) {
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
tmp_v, tmp_s, decode_params.merge_indptr, o_d, lse_d,
decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads,
HEAD_DIM_VO, stream));
} else {
FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
tmp_v, decode_params.merge_indptr, o_d, decode_params.max_total_num_rows,
decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream));
}
}
}
}
});
}
});
}
});
});
//});
return cudaSuccess;
}
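
For readers tracing the split-kv branch in the hunk above: the kernel decides how many KV chunks to carve the prefill into from device occupancy, and the sizing math is compact enough to restate on the host. The sketch below mirrors the max_num_kv_chunks / chunk_size / num_chunks computation in PODWithKVCacheTensorDispatched; the helper name, the example numbers, and the standalone packaging are illustrative assumptions, while the formulas and the 256-token floor come from the diff.

// split_kv_sizing_sketch.cc -- host-side restatement of the split-kv sizing
// logic above. In the real code, num_blocks_per_sm comes from
// cudaOccupancyMaxActiveBlocksPerMultiprocessor and num_sm from
// cudaDeviceGetAttribute(cudaDevAttrMultiProcessorCount).
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

uint32_t decide_num_kv_chunks(uint32_t num_blocks_per_sm, uint32_t num_sm,
                              uint32_t num_kv_heads, uint32_t qo_len,
                              uint32_t group_size, uint32_t cta_tile_q,
                              uint32_t kv_len) {
  // CTAs the prefill already needs without splitting KV.
  uint32_t ctas_without_split =
      num_kv_heads * ceil_div(qo_len * group_size, cta_tile_q);
  if (ctas_without_split == 0) return 0;
  // How many copies of that grid the GPU could keep resident at once.
  uint32_t max_num_kv_chunks = (num_blocks_per_sm * num_sm) / ctas_without_split;
  if (max_num_kv_chunks == 0) return 0;  // already saturated: do not split
  // Chunks are never smaller than 256 KV tokens; count how many then fit.
  uint32_t chunk_size = std::max<uint32_t>(ceil_div(kv_len, max_num_kv_chunks), 256);
  return ceil_div(kv_len, chunk_size);
}

int main() {
  // Illustrative numbers only: one query row, GQA group 4, 8 KV heads, a tile
  // of 16 query rows, 8192 KV tokens, 2 resident blocks/SM on a 108-SM GPU.
  uint32_t num_chunks = decide_num_kv_chunks(/*num_blocks_per_sm=*/2,
                                             /*num_sm=*/108,
                                             /*num_kv_heads=*/8,
                                             /*qo_len=*/1,
                                             /*group_size=*/4,
                                             /*cta_tile_q=*/16,
                                             /*kv_len=*/8192);
  // 216 resident blocks / 8 CTAs -> up to 27 chunks; chunk_size 304 -> 27 chunks.
  std::printf("num_chunks = %u\n", num_chunks);
  return 0;
}

The 256-token floor presumably keeps chunks large enough to amortize per-chunk overhead; when num_chunks ends up at 0 or 1 (or tmp_p is null), the code above launches the non-partitioned kernel instead.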

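Finally, the fused launch at the end of the hunk sizes one grid for both phases: the prefill and decode block counts are added, while threads per block and dynamic shared memory take the maximum of the two, since each block only needs the larger phase's resources. A small host-side sketch of that arithmetic; the numbers are purely illustrative, and in the real code they come from KernelTraits, padded_batch_size_d, and the params structs.

// pod_launch_geometry_sketch.cc -- how the fused POD launch combines the
// prefill and decode grids before the single cudaLaunchKernel call above.
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct LaunchDims {
  uint32_t nblks;
  uint32_t nthrs;
  size_t smem_size;
};

LaunchDims combine_pod_launch(uint32_t nblks_p, uint32_t nthrs_p, size_t smem_p,
                              uint32_t nblks_d, uint32_t nthrs_d, size_t smem_d) {
  // Block counts add (both phases share one grid); per-block resources are
  // sized for whichever phase needs more.
  return LaunchDims{nblks_p + nblks_d, std::max(nthrs_p, nthrs_d),
                    std::max(smem_p, smem_d)};
}

int main() {
  // e.g. prefill: 1 query tile x 27 KV chunks x 8 KV heads = 216 CTAs,
  // 128 threads, 96 KiB smem; decode: 64 padded batch x 8 KV heads = 512 CTAs,
  // 128 threads, 48 KiB smem.
  LaunchDims d = combine_pod_launch(216, 128, 96 * 1024, 512, 128, 48 * 1024);
  std::printf("nblks=%u nthrs=%u smem=%zu\n", d.nblks, d.nthrs, d.smem_size);
  return 0;
}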