
[Refactor] Uniform PoDAttention API with Horizontal Fusion SMs Schedule #967

Open · wants to merge 12 commits into base: main
1 change: 1 addition & 0 deletions csrc/batch_prefill.cu
@@ -257,6 +257,7 @@ void BatchPrefillWithPagedKVCacheRun(
static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
params.paged_kv = paged_kv;
params.q_indptr = static_cast<IdType*>(qo_indptr.data_ptr());
params.q_lenptr = nullptr; // disable non-contiguous qo
params.o = static_cast<DTypeO*>(o.data_ptr());

params.lse = maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
6 changes: 6 additions & 0 deletions csrc/batch_prefill_customize_config.jinja
@@ -86,7 +86,10 @@ struct PagedParams {

DTypeQ* q;
paged_kv_t<DTypeKV, IdType> paged_kv;

IdType* q_indptr;
uint32_t* q_lenptr;

DTypeO* o;
float* lse;
uint_fastdiv group_size;
@@ -110,6 +113,9 @@ struct PagedParams {
bool partition_kv;

__host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
if(q_lenptr){
return q_lenptr[batch_idx];
}
return q_indptr[batch_idx + 1] - q_indptr[batch_idx];
}

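The two hunks above introduce an optional q_lenptr next to q_indptr: when it is null, per-request query lengths come from indptr differences as before; when it is set, the lengths are read directly, so the PoD path can describe a fused prefill/decode batch whose rows are not one contiguous ragged buffer. A minimal, self-contained sketch of that dual path (illustrative names that only mirror the PR's fields, not library code):

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch only: mirrors the get_qo_len() logic from the diff above.
struct QoLenView {
  const int32_t* q_indptr;   // prefix sums over a contiguous ragged layout
  const uint32_t* q_lenptr;  // explicit per-request lengths; nullptr disables it
  uint32_t get_qo_len(uint32_t batch_idx) const {
    if (q_lenptr) return q_lenptr[batch_idx];
    return static_cast<uint32_t>(q_indptr[batch_idx + 1] - q_indptr[batch_idx]);
  }
};

int main() {
  std::vector<int32_t> indptr = {0, 128, 129, 130};  // 1 prefill + 2 decode requests
  std::vector<uint32_t> lens = {128, 1, 1};          // same lengths, explicit form
  QoLenView via_indptr{indptr.data(), nullptr};      // batch-prefill path
  QoLenView via_lenptr{indptr.data(), lens.data()};  // PoD path
  for (uint32_t i = 0; i < 3; ++i)
    assert(via_indptr.get_qo_len(i) == via_lenptr.get_qo_len(i));
  return 0;
}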
318 changes: 131 additions & 187 deletions csrc/pod.cu

Large diffs are not rendered by default.

64 changes: 32 additions & 32 deletions csrc/pod_config.inc
@@ -1,8 +1,6 @@
#pragma once
#include <flashinfer/attention/default_prefill_params.cuh>
#include <flashinfer/attention/default_decode_params.cuh>
#include <flashinfer/attention/variants.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/math.cuh>
@@ -12,34 +10,36 @@
#include "aot_default_additional_params.h"
#include "aot_extension_utils.h"

using namespace flashinfer;
using IdType = int32_t;

#define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \
USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, ...) \
{ \
DISPATCH_mask_mode(mask_mode_p, MASK_MODE_P, [&] { \
return DISPATCH_mask_mode(mask_mode_d, MASK_MODE_D, [&] { \
return DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE( \
q_scalar_type, kv_scalar_type, DTypeQ, DTypeKV, [&] { \
using DTypeO = DTypeQ; \
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \
constexpr bool USE_FP16_QK_REDUCTION = false; \
return DISPATCH_head_dim(head_dim_qk, HEAD_DIM_QK, [&] { \
[[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \
return DISPATCH_BOOL(window_left_p > -1, USE_SLIDING_WINDOW_P, [&] { \
return DISPATCH_BOOL(window_left_d > -1, USE_SLIDING_WINDOW_D, [&] { \
return DISPATCH_BOOL(false, USE_LOGITS_SOFT_CAP, [&] { \
using IdType = int32_t; \
using PrefillParams = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;\
using DecodeParams = BatchPrefillPagedParams<DTypeQ, \
DTypeKV, DTypeO, IdType>; \
__VA_ARGS__(); \
return true; \
}); \
}); \
}); \
}); \
}); \
}); \
}); \
}
#define ADDITIONAL_FUNC_PARAMS BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS
#define ADDITIONAL_PARAMS_SETTER BATCH_PREFILL_ADDITIONAL_PARAMS_SETTER

#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, \
POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, \
USE_FP16_QK_REDUCTION, AttentionVariant, RaggedParams, PagedParams, ...) \
{ \
DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { \
return DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE( \
q_scalar_type, kv_scalar_type, DTypeQ, DTypeKV, [&] { \
using DTypeO = DTypeQ; \
using RaggedParams = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
using PagedParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \
constexpr bool USE_FP16_QK_REDUCTION = false; \
constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; \
return DISPATCH_head_dim(head_dim_qk, HEAD_DIM_QK, [&] { \
[[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \
return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \
return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \
using AttentionVariant = \
DefaultAttention</*use_custom_mask=*/use_custom_mask, USE_SLIDING_WINDOW, \
USE_LOGITS_SOFT_CAP, /*use_alibi_bias=*/false>; \
__VA_ARGS__(); \
return true; \
}); \
}); \
}); \
}); \
}); \
}
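The rewritten DISPATCH_context above turns runtime settings (mask mode, window_left, logits_soft_cap) into compile-time constants through nested dispatch lambdas. A simplified, standalone illustration of that DISPATCH_BOOL-style idiom (an illustration only, not the real flashinfer macro):

#include <iostream>
#include <type_traits>

// A runtime flag selects which compile-time constant the lambda is
// instantiated with, so its body can declare constexpr bools.
template <typename F>
bool dispatch_bool(bool flag, F&& f) {
  return flag ? f(std::true_type{}) : f(std::false_type{});
}

int main() {
  int window_left = 3;  // runtime sliding-window setting (illustrative value)
  dispatch_bool(window_left > -1, [&](auto tag) {
    constexpr bool USE_SLIDING_WINDOW = decltype(tag)::value;
    std::cout << "instantiated with USE_SLIDING_WINDOW=" << USE_SLIDING_WINDOW << "\n";
    return true;  // mirrors the return true; convention in the macro above
  });
  return 0;
}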
134 changes: 109 additions & 25 deletions csrc/pod_customize_config.jinja
@@ -5,10 +5,17 @@
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/fastdiv.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/variant_helper.cuh>
#include <flashinfer/attention/default_prefill_params.cuh>

#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}

#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, RaggedParams, PagedParams, ...) \
DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \
constexpr auto use_custom_mask = MASK_MODE == MaskMode::kCustom; \
using AttentionVariant = {{ variant_name }}; \
__VA_ARGS__(); \
})

using namespace flashinfer;

@@ -19,25 +26,102 @@ using IdType = {{ idtype }};
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }};
constexpr auto USE_LOGITS_SOFT_CAP_P = {{ use_logits_soft_cap_p }};
constexpr auto POS_ENCODING_MODE_P = {{ pos_encoding_mode_p }};
constexpr auto USE_SLIDING_WINDOW_P = {{ use_sliding_window_p }};

constexpr auto USE_LOGITS_SOFT_CAP_D = {{ use_logits_soft_cap_d }};
constexpr auto POS_ENCODING_MODE_D = {{ pos_encoding_mode_d }};
constexpr auto USE_SLIDING_WINDOW_D = {{ use_sliding_window_d }};

constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;
constexpr bool USE_LOGITS_SOFT_CAP = false;

using PrefillParams = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;
using DecodeParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>;

#define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \
USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, ...) \
DISPATCH_mask_mode(mask_mode_p, MASK_MODE_P, [&] { \
return DISPATCH_mask_mode(mask_mode_d, MASK_MODE_D, [&] { \
__VA_ARGS__(); \
return true; \
}); \
});
constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }};
constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }};


struct RaggedParams {
using DTypeQ = DTypeQ;
using DTypeKV = DTypeKV;
using DTypeO = DTypeO;
using IdType = IdType;

DTypeQ* q;
DTypeKV* k;
DTypeKV* v;
IdType* q_indptr;
IdType* kv_indptr;
DTypeO* o;
float* lse;
uint_fastdiv group_size;

{{ additional_params_decl }}
uint32_t num_qo_heads;
uint32_t num_kv_heads;
uint32_t q_stride_n;
uint32_t q_stride_h;
uint32_t k_stride_n;
uint32_t k_stride_h;
uint32_t v_stride_n;
uint32_t v_stride_h;
int32_t window_left;

IdType* request_indices;
IdType* qo_tile_indices;
IdType* kv_tile_indices;
IdType* merge_indptr;
IdType* o_indptr;
IdType* kv_chunk_size_ptr;
bool* block_valid_mask;
uint32_t max_total_num_rows;
uint32_t* total_num_rows;
uint32_t padded_batch_size;
bool partition_kv;

__host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
return q_indptr[batch_idx + 1] - q_indptr[batch_idx];
}

__host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const {
return kv_indptr[batch_idx + 1] - kv_indptr[batch_idx];
}
};

struct PagedParams {
using DTypeQ = DTypeQ;
using DTypeKV = DTypeKV;
using DTypeO = DTypeO;
using IdType = IdType;

DTypeQ* q;
paged_kv_t<DTypeKV, IdType> paged_kv;

IdType* q_indptr;
uint32_t* q_lenptr;

DTypeO* o;
float* lse;
uint_fastdiv group_size;

{{ additional_params_decl }}
uint32_t num_qo_heads;
IdType q_stride_n;
IdType q_stride_h;
int32_t window_left;

IdType* request_indices;
IdType* qo_tile_indices;
IdType* kv_tile_indices;
IdType* merge_indptr;
IdType* o_indptr;
bool* block_valid_mask;
IdType* kv_chunk_size_ptr;
uint32_t max_total_num_rows;
uint32_t* total_num_rows;
uint32_t padded_batch_size;
bool partition_kv;

__host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
if(q_lenptr){
return q_lenptr[batch_idx];
}
return q_indptr[batch_idx + 1] - q_indptr[batch_idx];
}

__host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const {
return paged_kv.get_length(batch_idx);
}
};

{{ variant_decl }}
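PagedParams::get_kv_len above delegates to paged_kv.get_length(batch_idx). As a rough sketch of what a paged KV length usually means for a page table (full pages plus the partially filled last page), under the assumption, not stated in this diff, that paged_kv_t tracks a page indptr and a last-page length:

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: not flashinfer's paged_kv_t implementation.
int main() {
  const uint32_t page_size = 16;
  // Sketch data: request 0 owns pages [0, 3), request 1 owns [3, 4).
  std::vector<int32_t> kv_page_indptr = {0, 3, 4};
  std::vector<int32_t> kv_last_page_len = {5, 16};  // valid tokens in each request's last page

  for (size_t i = 0; i + 1 < kv_page_indptr.size(); ++i) {
    uint32_t num_pages = kv_page_indptr[i + 1] - kv_page_indptr[i];
    uint32_t kv_len = (num_pages - 1) * page_size + kv_last_page_len[i];
    std::cout << "request " << i << ": kv_len = " << kv_len << "\n";  // prints 37, then 16
  }
  return 0;
}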
37 changes: 17 additions & 20 deletions csrc/pod_jit_pybind.cu
@@ -16,26 +16,23 @@
#include "pod_config.inc"
#include "pytorch_extension_utils.h"

void pod_with_kv_cache_tensor(
// Prefill params
at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p,
std::optional<at::Tensor> maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p,
int64_t window_left_p, std::optional<at::Tensor> maybe_custom_mask_p,
std::optional<at::Tensor> maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p,
double rope_rcp_scale_p, double rope_rcp_theta_p,
// Decode params
at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d,
at::Tensor plan_info_vec, at::Tensor q_d, at::Tensor paged_k_cache_d,
at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, at::Tensor paged_kv_indptr_d,
at::Tensor paged_kv_indices_d, at::Tensor paged_kv_last_page_len_d, at::Tensor o_d,
std::optional<at::Tensor> maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d,
int64_t window_left_d, std::optional<at::Tensor> maybe_custom_mask_d,
std::optional<at::Tensor> maybe_mask_indptr_d, std::optional<at::Tensor> maybe_alibi_slopes_d,
double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d,
// Shared params
int64_t cuda_stream);
at::Tensor PODWithPagedKVCachePlan(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_last_page_len, int64_t total_num_rows,
int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
int64_t head_dim_vo, bool causal, int64_t cuda_stream);

void PODWithPagedKVCacheRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
at::Tensor paged_v_cache, at::Tensor paged_kv_indices, at::Tensor o,
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS,
int64_t cuda_stream);

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
// Batch-request prefill attention with KV-Cache operator
m.def("pod_with_kv_cache_tensor", pod_with_kv_cache_tensor);
m.def("plan", PODWithPagedKVCachePlan);
m.def("paged_run", PODWithPagedKVCacheRun);
}
40 changes: 11 additions & 29 deletions csrc/pod_kernel_inst.jinja
@@ -1,34 +1,16 @@
#include <flashinfer/attention/default_prefill_params.cuh>
#include <flashinfer/attention/default_decode_params.cuh>
#include <flashinfer/attention/variants.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/pod.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/page.cuh>

#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"
#include "aot_default_additional_params.h"
#include "aot_extension_utils.h"

#include "pod_config.inc"

using namespace flashinfer;

namespace flashinfer {
constexpr auto use_custom_mask_p = {{ mask_mode_p }} == MaskMode::kCustom;
constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom;
// Not sure about the below declaration
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

template cudaError_t PODWithKVCacheTensorDispatched<
{{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE,
{{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, 16,
{{ mask_mode_d }}, {{ variant_name_p }},
{{ variant_name_d }}, PrefillParams, DecodeParams>(
PrefillParams prefill_params, {{ dtype_o }}* tmp,
DecodeParams decode_params, {{ dtype_o }}* tmp_v,
float *tmp_s, cudaStream_t stream);
};
constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom;

{% for cta_tile_q_p in [16, 64, 128] %}
{% for cta_tile_q_d in [16, 64, 128] %}
template cudaError_t PODWithPagedKVCacheDispatched<
/*CTA_TILE_Q_P=*/{{cta_tile_q_p}}, /*CTA_TILE_Q_D=*/{{cta_tile_q_d}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}},
{{ variant_name }}, PagedParams, PagedParams>(PagedParams prefill_params, PagedParams decode_params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream);
{% endfor %}
{% endfor %}

}; // namespace flashinfer