refactor: make group_size a part of params #786

Merged (1 commit, Feb 5, 2025)
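
This PR moves the grouped-query-attention group size (num_qo_heads / num_kv_heads) out of the kernel argument lists and into the params structs, stored as a uint_fastdiv so the kernels keep their fast integer division; every kernel launch then passes a single __grid_constant__ params argument. A rough sketch of the shape of the change, under simplified assumptions (FastDiv and PrefillParams below are illustrative stand-ins, not FlashInfer's actual types):

```cuda
// Illustrative sketch only: FastDiv and PrefillParams are simplified stand-ins.
// flashinfer's uint_fastdiv actually precomputes a magic multiplier and shift so
// device code can divide without the hardware divider.
#include <cuda_runtime.h>
#include <cstdint>

struct FastDiv {
  uint32_t d;
  __host__ __device__ FastDiv() : d(1) {}
  __host__ __device__ explicit FastDiv(uint32_t divisor) : d(divisor) {}
  __host__ __device__ uint32_t div(uint32_t n) const { return n / d; }
};

struct PrefillParams {
  uint32_t num_qo_heads;
  uint32_t num_kv_heads;
  FastDiv group_size;  // new member, filled in wherever the rest of params is set
  // ... pointers, strides, lengths ...
};

// Before: __global__ void Kernel(const FastDiv group_size,
//                                const __grid_constant__ PrefillParams params);
// After: group_size travels inside params and the kernel pulls it out locally.
__global__ void Kernel(const __grid_constant__ PrefillParams params) {
  const FastDiv& group_size = params.group_size;
  (void)group_size;  // the real kernels use it to map packed query rows to (position, head)
}

void SetHeads(PrefillParams& params, uint32_t num_qo_heads, uint32_t num_kv_heads) {
  params.num_qo_heads = num_qo_heads;
  params.num_kv_heads = num_kv_heads;
  params.group_size = FastDiv(num_qo_heads / num_kv_heads);  // mirrors the setters in this diff
}
```

The matching launch-side simplification is sketched after the prefill.cuh diff at the bottom of this page.
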
2 changes: 2 additions & 0 deletions csrc/batch_prefill.cu
@@ -125,6 +125,7 @@ void BatchPrefillWithRaggedKVCacheRun(
params.kv_indptr = static_cast<IdType*>(kv_indptr.data_ptr());
params.num_qo_heads = num_qo_heads;
params.num_kv_heads = num_kv_heads;
+ params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads);
params.q_stride_n = q_stride_n;
params.q_stride_h = q_stride_h;
params.k_stride_n = k_stride_n;
@@ -260,6 +261,7 @@ void BatchPrefillWithPagedKVCacheRun(

params.lse = maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
params.num_qo_heads = num_qo_heads;
+ params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads);
params.q_stride_n = q_stride_n;
params.q_stride_h = q_stride_h;
params.window_left = window_left;
5 changes: 5 additions & 0 deletions csrc/batch_prefill_customize_config.jinja
@@ -4,6 +4,7 @@
#include <flashinfer/layout.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>
+ #include <flashinfer/fastdiv.cuh>

#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}
@@ -42,6 +43,8 @@ struct RaggedParams {
IdType* kv_indptr;
DTypeO* o;
float* lse;
+ uint_fastdiv group_size;

{{ additional_params_decl }}
uint32_t num_qo_heads;
uint32_t num_kv_heads;
@@ -85,6 +88,8 @@ struct PagedParams {
IdType* q_indptr;
DTypeO* o;
float* lse;
+ uint_fastdiv group_size;

{{ additional_params_decl }}
uint32_t num_qo_heads;
IdType q_stride_n;
2 changes: 2 additions & 0 deletions csrc/single_prefill.cu
@@ -17,6 +17,7 @@
#include <flashinfer/pos_enc.cuh>
#include <optional>

#include "flashinfer/fastdiv.cuh"
#include "pytorch_extension_utils.h"
#include "single_prefill_config.inc"

@@ -85,6 +86,7 @@ void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::
params.lse = maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
params.num_qo_heads = num_qo_heads;
params.num_kv_heads = num_kv_heads;
+ params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads);
params.qo_len = qo_len;
params.kv_len = kv_len;
params.q_stride_n = q_stride_n;
2 changes: 2 additions & 0 deletions csrc/single_prefill_customize_config.jinja
@@ -3,6 +3,7 @@
#include <flashinfer/layout.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>
+ #include <flashinfer/fastdiv.cuh>

#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}
@@ -39,6 +40,7 @@ struct Params {
DTypeKV* v;
DTypeO* o;
float* lse;
+ uint_fastdiv group_size;

{{ additional_params_decl }}

9 changes: 9 additions & 0 deletions include/flashinfer/attention/default_prefill_params.cuh
@@ -38,6 +38,7 @@ struct SinglePrefillParams {
DTypeO* o;
float* lse;
float* maybe_alibi_slopes;
+ uint_fastdiv group_size;
uint32_t qo_len;
uint32_t kv_len;
uint32_t num_qo_heads;
@@ -65,6 +66,7 @@ struct SinglePrefillParams {
o(nullptr),
lse(nullptr),
maybe_alibi_slopes(nullptr),
+ group_size(),
qo_len(0),
kv_len(0),
num_qo_heads(0),
@@ -97,6 +99,7 @@ struct SinglePrefillParams {
o(o),
lse(lse),
maybe_alibi_slopes(maybe_alibi_slopes),
+ group_size(num_qo_heads / num_kv_heads),
num_qo_heads(num_qo_heads),
num_kv_heads(num_kv_heads),
qo_len(qo_len),
@@ -143,6 +146,7 @@ struct BatchPrefillRaggedParams {
DTypeO* o;
float* lse;
float* maybe_alibi_slopes;
+ uint_fastdiv group_size;
uint32_t num_qo_heads;
uint32_t num_kv_heads;
uint32_t q_stride_n;
@@ -182,6 +186,7 @@ struct BatchPrefillRaggedParams {
o(nullptr),
lse(nullptr),
maybe_alibi_slopes(nullptr),
+ group_size(),
num_qo_heads(0),
num_kv_heads(0),
q_stride_n(0),
@@ -228,6 +233,7 @@ struct BatchPrefillRaggedParams {
o(o),
lse(lse),
maybe_alibi_slopes(maybe_alibi_slopes),
+ group_size(num_qo_heads / num_kv_heads),
num_qo_heads(num_qo_heads),
num_kv_heads(num_kv_heads),
q_stride_n(q_stride_n),
@@ -278,6 +284,7 @@ struct BatchPrefillPagedParams {
DTypeO* o;
float* lse;
float* maybe_alibi_slopes;
+ uint_fastdiv group_size;
uint32_t num_qo_heads;
IdType q_stride_n;
IdType q_stride_h;
@@ -309,6 +316,7 @@ struct BatchPrefillPagedParams {
o(nullptr),
lse(nullptr),
maybe_alibi_slopes(nullptr),
+ group_size(),
num_qo_heads(0),
q_stride_n(0),
q_stride_h(0),
@@ -345,6 +353,7 @@ struct BatchPrefillPagedParams {
o(o),
lse(lse),
maybe_alibi_slopes(maybe_alibi_slopes),
+ group_size(num_qo_heads / paged_kv.num_heads),
num_qo_heads(num_qo_heads),
q_stride_n(q_stride_n),
q_stride_h(q_stride_h),
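
For context on why the params carry group_size at all (an illustration, not code from this PR): with grouped-query attention, group_size consecutive query heads share one KV head, and the prefill kernels walk packed rows of length qo_len * group_size. Dividing a packed index by group_size recovers the query position and the remainder selects the head within the group; uint_fastdiv makes that divide/modulo cheap on device. A plain-integer sketch of the mapping, assuming the usual GQA head layout:

```cuda
#include <cstdint>
#include <cstdio>

// Illustrative only: splits a packed index in [0, qo_len * group_size) into
// (query position, qo head), assuming qo head h reads from kv head h / group_size.
void unpack_qo_index(uint32_t packed_idx, uint32_t group_size, uint32_t kv_head_idx,
                     uint32_t* q_idx, uint32_t* qo_head_idx) {
  *q_idx = packed_idx / group_size;                         // position along the query sequence
  uint32_t head_in_group = packed_idx % group_size;         // which of the grouped qo heads
  *qo_head_idx = kv_head_idx * group_size + head_in_group;  // global qo head index
}

int main() {
  uint32_t num_qo_heads = 32, num_kv_heads = 8;
  uint32_t group_size = num_qo_heads / num_kv_heads;  // 4, as computed by the setters above
  uint32_t q_idx, qo_head_idx;
  unpack_qo_index(/*packed_idx=*/13, group_size, /*kv_head_idx=*/2, &q_idx, &qo_head_idx);
  printf("q_idx=%u qo_head_idx=%u\n", q_idx, qo_head_idx);  // q_idx=3, qo_head_idx=9
  return 0;
}
```
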
24 changes: 12 additions & 12 deletions include/flashinfer/attention/prefill.cuh
@@ -1180,7 +1180,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
*/
template <typename KTraits, typename Params>
__global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCacheKernel(
- const uint_fastdiv group_size, const __grid_constant__ Params params) {
+ const __grid_constant__ Params params) {
using DTypeQ = typename Params::DTypeQ;
#if (__CUDA_ARCH__ < 800)
if constexpr (std::is_same_v<DTypeQ, nv_bfloat16>) {
@@ -1226,6 +1226,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCache
const uint32_t v_stride_n = params.v_stride_n;
const uint32_t v_stride_h = params.v_stride_h;
const int32_t maybe_window_left = params.window_left;
+ const uint_fastdiv& group_size = params.group_size;

static_assert(sizeof(DTypeQ) == 2);
static_assert(sizeof(DTypeO) == 2);
@@ -1449,7 +1450,6 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D
}

const uint32_t group_size = num_qo_heads / num_kv_heads;
- const uint_fastdiv group_size_fastdiv(group_size);
constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16;
constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16;
int64_t unpacked_qo_len = qo_len * group_size;
@@ -1525,7 +1525,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D
if (num_chunks <= 1 || tmp == nullptr) {
// Enough parallelism, do not split-kv
params.partition_kv = false;
- void* args[] = {(void*)&group_size_fastdiv, (void*)&params};
+ void* args[] = {(void*)&params};
dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, num_kv_heads);
dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV);
FLASHINFER_CUDA_CALL(
@@ -1538,7 +1538,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D
auto lse = params.lse;
params.o = tmp;
params.lse = tmp_lse;
- void* args[] = {(void*)&group_size_fastdiv, (void*)&params};
+ void* args[] = {(void*)&params};
dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), num_chunks, num_kv_heads);
dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV);
FLASHINFER_CUDA_CALL(
Expand All @@ -1559,7 +1559,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::D

template <typename KTraits, typename Params>
__global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKVCacheKernel(
- const uint_fastdiv group_size, const __grid_constant__ Params params) {
+ const __grid_constant__ Params params) {
using DTypeQ = typename Params::DTypeQ;
#if (__CUDA_ARCH__ < 800)
if constexpr (std::is_same_v<DTypeQ, nv_bfloat16>) {
@@ -1611,6 +1611,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV
const uint32_t v_stride_n = params.v_stride_n;
const uint32_t v_stride_h = params.v_stride_h;
const int32_t maybe_window_left = params.window_left;
+ const uint_fastdiv& group_size = params.group_size;

static_assert(sizeof(DTypeQ) == 2);
static_assert(sizeof(DTypeO) == 2);
@@ -1853,7 +1854,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV

template <typename KTraits, typename Params>
__global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVCacheKernel(
- const uint_fastdiv group_size, const __grid_constant__ Params params) {
+ const __grid_constant__ Params params) {
using DTypeQ = typename Params::DTypeQ;
#if (__CUDA_ARCH__ < 800)
if constexpr (std::is_same_v<DTypeQ, nv_bfloat16>) {
@@ -1897,6 +1898,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVC
const paged_kv_t<DTypeKV, IdType>& paged_kv = params.paged_kv;
const bool partition_kv = params.partition_kv;
const int32_t maybe_window_left = params.window_left;
+ const uint_fastdiv& group_size = params.group_size;

static_assert(sizeof(DTypeQ) == 2);
static_assert(sizeof(DTypeO) == 2);
@@ -2158,7 +2160,6 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para
const uint32_t padded_batch_size = params.padded_batch_size;
const uint32_t num_qo_heads = params.num_qo_heads;
const uint32_t num_kv_heads = params.num_kv_heads;
- const uint_fastdiv group_size_fastdiv(num_qo_heads / num_kv_heads);
constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q);
constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q);
constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q);
@@ -2220,7 +2221,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para
if (tmp_v == nullptr) {
// do not partition kv
params.partition_kv = false;
- void* args[] = {(void*)&group_size_fastdiv, (void*)&params};
+ void* args[] = {(void*)&params};
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
@@ -2230,7 +2231,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para
auto lse = params.lse;
params.o = tmp_v;
params.lse = tmp_s;
- void* args[] = {(void*)&group_size_fastdiv, (void*)&params};
+ void* args[] = {(void*)&params};
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
if constexpr (AttentionVariant::use_softmax) {
@@ -2259,7 +2260,6 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Param
const uint32_t padded_batch_size = params.padded_batch_size;
const uint32_t num_qo_heads = params.num_qo_heads;
const uint32_t num_kv_heads = params.paged_kv.num_heads;
- const uint_fastdiv group_size_fastdiv(num_qo_heads / num_kv_heads);
constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q);
constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q);
constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q);
@@ -2322,7 +2322,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Param
if (tmp_v == nullptr) {
// do not partition kv
params.partition_kv = false;
- void* args[] = {(void*)&group_size_fastdiv, (void*)&params};
+ void* args[] = {(void*)&params};
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
@@ -2331,7 +2331,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Param
auto lse = params.lse;
params.o = tmp_v;
params.lse = tmp_s;
- void* args[] = {(void*)&group_size_fastdiv, (void*)&params};
+ void* args[] = {(void*)&params};
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
if constexpr (AttentionVariant::use_softmax) {
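
With group_size folded into params, every launch site above collapses to a single-entry argument array, and the split-KV branches only have to rebuild params (redirecting o/lse to the tmp buffers) rather than a second argument list. A standalone sketch of that launch pattern with a toy kernel and params struct (not FlashInfer's real types; __grid_constant__ needs CUDA 11.7+ and an sm_70+ target):

```cuda
#include <cuda_runtime.h>
#include <cstdio>

struct ScaleParams {
  const float* in;
  float* out;
  unsigned int n;
  unsigned int group_size;  // a plain uint here; the real params store a uint_fastdiv
};

// The struct is the kernel's only argument, marked __grid_constant__ as in the diff above.
__global__ void ScaleKernel(const __grid_constant__ ScaleParams params) {
  unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < params.n) params.out[i] = params.in[i] * params.group_size;
}

int main() {
  const unsigned int n = 1024;
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  cudaMemset(in, 0, n * sizeof(float));

  ScaleParams params{in, out, n, /*group_size=*/4};
  void* args[] = {(void*)&params};  // single launch argument, mirroring this PR
  dim3 nblks((n + 255) / 256), nthrs(256);
  cudaLaunchKernel((void*)ScaleKernel, nblks, nthrs, args, /*sharedMem=*/0, /*stream=*/nullptr);
  cudaDeviceSynchronize();

  cudaFree(in);
  cudaFree(out);
  return 0;
}
```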