-
Notifications
You must be signed in to change notification settings - Fork 272
/
Copy pathbatch_decode_customize_config.jinja
61 lines (49 loc) · 1.71 KB
/
batch_decode_customize_config.jinja
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#pragma once
#include <flashinfer/page.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/attention/variant_helper.cuh>
// Extra launcher parameters and the statements that copy them into Params,
// both rendered in by the template generator.
// NOTE(review): roles inferred from the macro names — confirm against the launcher.
#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}

// DISPATCH_context: opens a block scope, aliases the name passed in the
// `AttentionVariant` slot to the template-selected variant type, and then calls
// the trailing callable with no arguments. Every other named slot is unused in
// the expansion; this configuration fixes those values at file scope instead.
// The callable travels through `...`/__VA_ARGS__, so a lambda whose body holds
// top-level commas (e.g. template argument lists) is reassembled intact before
// it is invoked.
#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType,                     \
                         HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE,          \
                         USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP,              \
                         AttentionVariant, Params, ...)                        \
  {                                                                            \
    using AttentionVariant = {{ variant_name }};                               \
    __VA_ARGS__();                                                             \
  }
using namespace flashinfer;

// Element and index types for this decode configuration, rendered by the
// template generator.
typedef {{ dtype_q }} DTypeQ;
typedef {{ dtype_kv }} DTypeKV;
typedef {{ dtype_o }} DTypeO;
typedef {{ idtype }} IdType;

// Compile-time configuration switches, rendered by the template generator.
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
// NOTE(review): presumably rendered as a flashinfer enumerator, which would rely
// on the using-directive above — TODO confirm against the generator.
constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }};
constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }};
// Parameter bundle for the customizable batch-decode attention path.
// NOTE(review): presumably populated on the host (ADDITIONAL_PARAMS_SETTER
// suggests so) and read on device via the __host__ __device__ accessors below —
// TODO confirm. Member order is intentionally unchanged so any positional or
// aggregate initialization elsewhere keeps working.
struct Params {
  // Member aliases re-exporting the file-scope types. The right-hand sides are
  // qualified with `::` on purpose: the unqualified form `using DTypeQ = DTypeQ;`
  // uses the name `DTypeQ` for the global alias and then redeclares it as a
  // member, so the name changes meaning inside the class. C++ makes that
  // ill-formed (no diagnostic required) and GCC can report it as
  // "changes meaning". The qualified form names exactly the same types.
  using DTypeQ = ::DTypeQ;
  using DTypeKV = ::DTypeKV;
  using DTypeO = ::DTypeO;
  using IdType = ::IdType;

  DTypeQ* q;                             // query elements
  paged_kv_t<DTypeKV, IdType> paged_kv;  // paged KV cache; get_kv_len() reads lengths from it
  DTypeO* o;                             // output elements
  float* lse;  // NOTE(review): presumably per-(request, head) log-sum-exp — TODO confirm
  {{ additional_params_decl }}
  uint32_t padded_batch_size;
  uint32_t num_qo_heads;
  IdType q_stride_n;  // presumably element stride between consecutive entries of q — TODO confirm
  IdType q_stride_h;  // presumably element stride between heads of q — TODO confirm
  int32_t window_left;
  // Split-KV scheduling metadata.
  // NOTE(review): presumably produced by a planning step — TODO confirm.
  IdType* request_indices;
  IdType* kv_tile_indices;
  IdType* o_indptr;
  IdType* kv_chunk_size_ptr;
  bool* block_valid_mask;
  bool partition_kv;

  // Query length of request `batch_idx`: always 1 in this decode configuration,
  // so the index is not consulted (left unnamed to avoid unused-parameter warnings).
  __host__ __device__ __forceinline__ int32_t get_qo_len(int32_t /*batch_idx*/) const {
    return 1;
  }

  // KV length of request `batch_idx`, as reported by the paged KV cache.
  __host__ __device__ __forceinline__ int32_t get_kv_len(int32_t batch_idx) const {
    return paged_kv.get_length(batch_idx);
  }
};
// Attention-variant declaration supplied to the template generator; the name
// DISPATCH_context aliases as AttentionVariant is expected to be declared here.
// NOTE(review): expectation inferred from variant_name/variant_decl pairing — confirm.
{{ variant_decl }}