/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "aot_default_additional_params.h"
#include "pytorch_extension_utils.h"
//========== activation ==========
void silu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl);
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl);
void gelu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl);
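// A minimal usage sketch (hypothetical helper, not part of the generated
// bindings): these fused kernels are assumed to split the last dimension of
// `input` into two halves, apply the activation to the first half, and
// multiply element-wise by the second half, so `out` carries half of
// `input`'s last dimension.
inline at::Tensor silu_and_mul_example(at::Tensor& input) {
  auto sizes = input.sizes().vec();
  sizes.back() /= 2;  // fused output keeps half of the last dimension
  at::Tensor out = at::empty(sizes, input.options());
  silu_and_mul(out, input, /*enable_pdl=*/false);
  return out;
}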
//========== cascade ==========
void merge_state(at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b,
at::Tensor v_merged, at::Tensor s_merged);
void merge_state_in_place(at::Tensor v, at::Tensor s, at::Tensor v_other, at::Tensor s_other,
std::optional<at::Tensor> mask);
void merge_states(at::Tensor v, at::Tensor s, at::Tensor v_merged, at::Tensor s_merged);
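// Each attention "state" here is a pair (v, s): a partial attention output
// `v` and its log-sum-exp `s`. Merging recovers attention over the union of
// the two key sets; schematically (a sketch of the math, not the kernel):
//   s_merged = log(exp(s_a) + exp(s_b))
//   v_merged = exp(s_a - s_merged) * v_a + exp(s_b - s_merged) * v_b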
//========== decode ==========
void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp,
at::Tensor o, int64_t layout,
int64_t window_left SINGLE_DECODE_ADDITIONAL_FUNC_PARAMS);
at::Tensor BatchDecodeWithPagedKVCachePlan(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph,
int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo,
at::Tensor empty_q_data, at::Tensor empty_kv_data);
void BatchDecodeWithPagedKVCacheRun(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
at::Tensor q, at::Tensor paged_k_cache,
at::Tensor paged_v_cache, at::Tensor paged_kv_indptr,
at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
at::Tensor o, std::optional<at::Tensor> maybe_lse,
int64_t kv_layout_code,
int64_t window_left BATCH_DECODE_ADDITIONAL_FUNC_PARAMS);
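// The plan/run split is the CUDA-graph-friendly pattern used throughout:
// `...Plan` runs once per batch shape, inspecting `indptr` and the head/page
// geometry to write a schedule into the workspace buffers and returning
// `plan_info_vec`; `...Run` then executes with fixed launch parameters and
// can be replayed under CUDA graphs as long as the planned shapes hold.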
//========== gemm ==========
void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
at::Tensor workspace_buffer, int64_t cublas_handle);
void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at::Tensor x_ptr,
at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_ld, at::Tensor w_ld,
at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major);
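// `bmm_fp8` performs batched FP8 matmul with per-tensor dequantization
// scales, presumably through cuBLASLt given that a raw `cublas_handle` is
// passed as an integer; `CutlassSegmentGEMM` instead launches one grouped
// CUTLASS GEMM over per-problem pointer and leading-dimension arrays, which
// lets variable-sized segments run without padding. (A reading of the
// signatures, not a statement of the kernels' internals.)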
//========== norm ==========
void rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
bool enable_pdl);
void gemma_rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, double eps,
bool enable_pdl);
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight,
double eps, bool enable_pdl);
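// RMSNorm normalizes along the last dimension,
//   out = input / sqrt(mean(input^2) + eps) * weight,
// and the fused `*_add_*` variants first add `residual` into `input` in
// place (the Gemma variants are understood to scale by 1 + weight instead).
// A minimal usage sketch (hypothetical helper, not part of the bindings):
inline at::Tensor rmsnorm_example(at::Tensor& input, at::Tensor& weight) {
  at::Tensor out = at::empty_like(input);
  rmsnorm(out, input, weight, /*eps=*/1e-6, /*enable_pdl=*/false);
  return out;
}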
//========== page ==========
void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::Tensor batch_indices,
at::Tensor positions, at::Tensor paged_k_cache, at::Tensor paged_v_cache,
at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len,
int64_t layout);
void append_paged_mla_kv_cache(at::Tensor append_ckv, at::Tensor append_kpe,
at::Tensor batch_indices, at::Tensor positions, at::Tensor ckv_cache,
at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor kv_indptr,
at::Tensor kv_last_page_len);
void block_sparse_indices_to_vector_sparse_offsets(
at::Tensor block_sparse_indices, at::Tensor block_sparse_indptr,
at::Tensor vector_sparse_offsets, at::Tensor vector_sparse_indptr, at::Tensor kv_len_arr,
int64_t stride_block, int64_t stride_n, int64_t batch_size, int64_t block_size);
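// Paged-cache addressing convention (as implied by the parameters):
// `kv_indptr` delimits each request's page list inside `kv_indices`, and
// `kv_last_page_len` gives the number of occupied slots in a request's
// final page, so (batch_indices[i], positions[i]) resolves to one physical
// page slot per appended token. The MLA variant stores the compressed KV
// (`ckv`) and rotary key (`kpe`) caches separately.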
//========== prefill ==========
void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp,
at::Tensor o, std::optional<at::Tensor> maybe_lse,
int64_t mask_mode_code, int64_t layout,
int64_t window_left SINGLE_PREFILL_ADDITIONAL_FUNC_PARAMS);
at::Tensor BatchPrefillWithKVCachePlan(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
int64_t head_dim_vo, bool causal);
void BatchPrefillWithRaggedKVCacheRun(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
at::Tensor q, at::Tensor k, at::Tensor v,
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code,
int64_t layout,
int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);
void BatchPrefillWithPagedKVCacheRun(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr,
at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
at::Tensor o, std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS);
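// One plan serves both run variants: BatchPrefillWithKVCachePlan partitions
// the work described by (qo_indptr, kv_indptr) across thread blocks, and the
// resulting plan_info_vec feeds either the ragged run (contiguous k/v
// tensors) or the paged run (page-table-indexed caches).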
//========== pod-attention ==========
void pod_with_kv_cache_tensor(
// Prefill params
at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p,
std::optional<at::Tensor> maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p,
int64_t window_left_p, std::optional<at::Tensor> maybe_custom_mask_p,
std::optional<at::Tensor> maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p,
double rope_rcp_scale_p, double rope_rcp_theta_p,
// Decode params
at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d,
at::Tensor plan_info_vec, at::Tensor q_d, at::Tensor paged_k_cache_d,
at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, at::Tensor paged_kv_indptr_d,
at::Tensor paged_kv_indices_d, at::Tensor paged_kv_last_page_len_d, at::Tensor o_d,
std::optional<at::Tensor> maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d,
int64_t window_left, std::optional<at::Tensor> maybe_custom_mask_d,
std::optional<at::Tensor> maybe_mask_indptr_d, std::optional<at::Tensor> maybe_alibi_slopes_d,
double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d);
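// POD (prefill-on-decode) attention fuses a single prefill request and a
// batched decode pass into one launch, which is why the `_p` and `_d`
// parameter groups above mirror the standalone single-prefill and
// paged-decode entry points.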
//========== quantization ==========
void packbits(at::Tensor x, const std::string& bitorder, at::Tensor y);
void segment_packbits(at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr,
const std::string& bitorder, at::Tensor y);
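// `packbits` condenses a boolean tensor into uint8 words, eight elements per
// byte, with `bitorder` ("little"/"big") fixing the bit position of the
// first element, mirroring numpy.packbits; `segment_packbits` applies the
// same packing independently per variable-length segment, with
// `input_indptr`/`output_indptr` delimiting each segment before and after
// packing.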
//========== rope ==========
void apply_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, at::Tensor indptr,
at::Tensor offsets, int64_t rotary_dim, bool interleave, double rope_scale,
double rope_theta);
void apply_llama31_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
at::Tensor indptr, at::Tensor offsets, int64_t rotary_dim, bool interleave,
double rope_scale, double rope_theta, double low_freq_factor,
double high_freq_factor, double old_context_length);
void apply_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
at::Tensor pos_ids, int64_t rotary_dim, bool interleave, double rope_scale,
double rope_theta);
void apply_llama31_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
at::Tensor pos_ids, int64_t rotary_dim, bool interleave,
double rope_scale, double rope_theta, double low_freq_factor,
double high_freq_factor, double old_context_length);
void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope,
at::Tensor k_rope, at::Tensor cos_sin_cache,
at::Tensor pos_ids, bool interleave);
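// The RoPE family differs only in how positions are obtained: `apply_rope`
// derives them from per-request `indptr`/`offsets`, the `_pos_ids` variants
// take explicit per-token positions, and the `_cos_sin_cache` variant skips
// on-the-fly trigonometry by indexing precomputed cos/sin rows with
// `pos_ids`. `interleave` selects between rotating adjacent element pairs
// (GPT-J style) and rotating the two halves of the head dimension
// (GPT-NeoX style); the llama31 variants add the frequency-dependent
// rescaling Llama 3.1 applies via low/high_freq_factor and
// old_context_length.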
//========== sampling ==========
void sampling_from_probs(at::Tensor probs, at::Tensor output,
std::optional<at::Tensor> maybe_indices, bool deterministic,
std::optional<at::Generator> gen);
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
std::optional<at::Tensor> maybe_indices,
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
bool deterministic, std::optional<at::Generator> gen);
void top_k_sampling_from_probs(at::Tensor probs, at::Tensor output,
std::optional<at::Tensor> maybe_indices,
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
bool deterministic, std::optional<at::Generator> gen);
void min_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
std::optional<at::Tensor> maybe_indices,
std::optional<at::Tensor> maybe_min_p_arr, double min_p_val,
bool deterministic, std::optional<at::Generator> gen);
void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor output,
std::optional<at::Tensor> maybe_indices,
std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val,
bool deterministic, std::optional<at::Generator> gen);
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val);
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids,
at::Tensor target_probs, at::Tensor output_token_ids,
at::Tensor output_accepted_token_num,
at::Tensor output_emitted_draft_token_num, bool deterministic,
std::optional<at::Generator> gen);
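// Common sampler conventions (as read from the signatures): one token id is
// drawn per row of `probs`; `maybe_indices` maps rows back to ragged request
// offsets; each `maybe_*_arr` tensor overrides the scalar threshold per row;
// and `deterministic` forces a reproducible reduction order.
// `chain_speculative_sampling` implements speculative decoding's
// accept/reject pass: draft tokens are accepted while they agree with the
// target distribution, and the first rejection is resampled from the
// adjusted (target minus draft) residual distribution.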
//========== Torch Library ==========
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
// activation
// Fused SiLU and Mul
m.def("silu_and_mul", silu_and_mul);
// Fused GeLU Tanh and Mul
m.def("gelu_tanh_and_mul", gelu_tanh_and_mul);
// Fused GeLU and Mul
m.def("gelu_and_mul", gelu_and_mul);
// cascade
// Merge two self-attention states
m.def("merge_state", merge_state);
  // Merge another self-attention state in place
m.def("merge_state_in_place", merge_state_in_place);
// "Merge multiple self-attention states"
m.def("merge_states", merge_states);
// decode
// "Single-request decode with KV-Cache operator"
m.def("single_decode_with_kv_cache", single_decode_with_kv_cache);
m.def("batch_decode_with_paged_kv_cache_plan", BatchDecodeWithPagedKVCachePlan);
m.def("batch_decode_with_paged_kv_cache_run", BatchDecodeWithPagedKVCacheRun);
// gemm
// BMM FP8
m.def("bmm_fp8", bmm_fp8);
// Cutlass Segment GEMM operator
m.def("cutlass_segment_gemm", CutlassSegmentGEMM);
// norm
// Root mean square normalization
m.def("rmsnorm", rmsnorm);
// Fused add root mean square normalization
m.def("fused_add_rmsnorm", fused_add_rmsnorm);
  // Gemma root mean square normalization
  m.def("gemma_rmsnorm", gemma_rmsnorm);
  // Gemma fused add root mean square normalization
m.def("gemma_fused_add_rmsnorm", gemma_fused_add_rmsnorm);
// page
// Append paged KV-Cache operator
m.def("append_paged_kv_cache", append_paged_kv_cache);
// Append paged MLA KV-Cache operator
m.def("append_paged_mla_kv_cache", append_paged_mla_kv_cache);
// Precompute block sparse offsets
m.def("block_sparse_indices_to_vector_sparse_offsets",
block_sparse_indices_to_vector_sparse_offsets);
// prefill
// Single-request prefill attention with KV-Cache operator
m.def("single_prefill_with_kv_cache", single_prefill_with_kv_cache);
m.def("batch_prefill_with_kv_cache_plan", BatchPrefillWithKVCachePlan);
m.def("batch_prefill_with_ragged_kv_cache_run", BatchPrefillWithRaggedKVCacheRun);
m.def("batch_prefill_with_paged_kv_cache_run", BatchPrefillWithPagedKVCacheRun);
// pod-attention
// Temporarily disabled because we don't generate the implementation yet.
// m.def("pod_with_kv_cache_tensor", pod_with_kv_cache_tensor);
// quantization
// GPU packbits operator
m.def("packbits", packbits);
// GPU segment packbits operator
m.def("segment_packbits", segment_packbits);
// rope
// "Apply RoPE"
m.def("apply_rope", apply_rope);
// "Apply Llama 3.1 style RoPE"
m.def("apply_llama31_rope", apply_llama31_rope);
// "Apply RoPE with positional ids"
m.def("apply_rope_pos_ids", apply_rope_pos_ids);
// "Apply Llama 3.1 style RoPE with positional ids"
m.def("apply_llama31_rope_pos_ids", apply_llama31_rope_pos_ids);
// "Apply RoPE with positional ids and cosine/sine cache"
m.def("apply_rope_pos_ids_cos_sin_cache", apply_rope_pos_ids_cos_sin_cache);
// sampling
// Sample from probabilities
m.def("sampling_from_probs", sampling_from_probs);
// Top-k sampling from probabilities
m.def("top_k_sampling_from_probs", top_k_sampling_from_probs);
// Min-p sampling from probabilities
m.def("min_p_sampling_from_probs", min_p_sampling_from_probs);
// Top-p sampling from probabilities
m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);
// Top-k and top-p sampling from probabilities
m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
// Renormalize probabilities by top-k mask
m.def("top_k_renorm_probs", top_k_renorm_probs);
// Renormalize probabilities by top-p mask
m.def("top_p_renorm_probs", top_p_renorm_probs);
// Mask logits by top-k mask
m.def("top_k_mask_logits", top_k_mask_logits);
// Speculative sampling from sequence of probabilities
m.def("chain_speculative_sampling", chain_speculative_sampling);
}