Commit 74ffba1

feat: non-inplace rope operators (#405)
As requested in #403, this PR implements non-inplace rope operators.
1 parent 2496f5b commit 74ffba1
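
For orientation, here is a minimal sketch of how the new operators are intended to be used from Python, next to the existing in-place variants. Tensor shapes and argument names follow the checks and bindings added below in python/csrc/rope.cu and python/csrc/flashinfer_ops.cu; the keyword defaults of the Python wrappers are assumptions, since python/flashinfer/rope.py is not part of the excerpt shown here.

import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
seq_lens = [7, 12]           # ragged batch of two sequences
nnz = sum(seq_lens)          # total number of tokens across the batch

# q: (nnz, H_Q, D), k: (nnz, H_K, D) -- ragged layout, see python/csrc/rope.cu below
q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
# indptr: (B + 1,) prefix sums of sequence lengths; offsets: (B,) per-sequence position offsets
indptr = torch.tensor([0, 7, 19], dtype=torch.int32, device="cuda")
offsets = torch.tensor([0, 0], dtype=torch.int32, device="cuda")

# Existing operator: rotates q and k in place.
# flashinfer.apply_rope_inplace(q, k, indptr, offsets)

# New operator: leaves q and k untouched and returns rotated copies.
# (interleave / rope_scale / rope_theta values here are illustrative, not library defaults.)
q_rope, k_rope = flashinfer.apply_rope(
    q, k, indptr, offsets, interleave=False, rope_scale=1.0, rope_theta=1e4
)
assert q_rope.shape == q.shape and k_rope.shape == k.shape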

10 files changed: +555 -17 lines changed

docs/api/python/rope.rst (+2)

@@ -12,3 +12,5 @@ Kernels for applying rotary embeddings.
 
    apply_rope_inplace
    apply_llama31_rope_inplace
+   apply_rope
+   apply_llama31_rope

include/flashinfer/attention/prefill.cuh (+2 -2)

@@ -1440,7 +1440,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
       write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(
           o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len,
           /*o_stride_n=*/
-          partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
+          partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
           /*o_stride_h=*/head_dim, group_size);

       // write lse

@@ -1732,7 +1732,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
       write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(
           o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len,
           /*o_stride_n=*/
-          partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
+          partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
           /*o_stride_h=*/head_dim, group_size);

       // write lse

include/flashinfer/pos_enc.cuh (+174)

@@ -191,6 +191,86 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(
   }
 }

+template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
+          typename IdType>
+__global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restrict__ k,
+                                         DType* __restrict__ q_rope, DType* __restrict__ k_rope,
+                                         IdType* __restrict__ indptr, IdType* __restrict__ offsets,
+                                         uint32_t batch_size, uint32_t num_qo_heads,
+                                         uint32_t num_kv_heads, size_t q_stride_n,
+                                         size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
+                                         float smooth_a, float smooth_b, float rope_rcp_scale,
+                                         float rope_rcp_theta) {
+  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
+  const uint32_t bdy = blockDim.y;
+  vec_t<float, vec_size> freq;
+#pragma unroll
+  for (uint32_t i = 0; i < vec_size; ++i) {
+    if constexpr (interleave) {
+      freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(head_dim));
+    } else {
+      freq[i] = __powf(rope_rcp_theta,
+                       float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
+    }
+
+    float smooth = freq[i] * smooth_a + smooth_b;
+    smooth = max(0.0f, min(1.0f, smooth));  // clamp to [0, 1]
+    freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
+  }
+
+  if (bx < batch_size * num_qo_heads) {
+    // apply rotary to q
+    const uint32_t batch_idx = bx / num_qo_heads;
+    const uint32_t qo_head_idx = bx % num_qo_heads;
+    const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx];
+    const uint32_t offset = offsets[batch_idx];
+#pragma unroll 2
+    for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) {
+      vec_t<float, vec_size> q_vec;
+      if (i * bdy + ty < seq_len) {
+        DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
+                                                q_stride_n, q_stride_h);
+        DType* q_rope_ptr =
+            q_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
+                                          /*q_stride_n=*/num_qo_heads * head_dim,
+                                          /*q_stride_h=*/head_dim);
+        if constexpr (interleave) {
+          q_vec =
+              vec_apply_llama_rope_interleave<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
+        } else {
+          q_vec = vec_apply_llama_rope<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
+        }
+        q_vec.cast_store(q_rope_ptr + tx * vec_size);
+      }
+    }
+  } else {
+    // apply rotary to k
+    uint32_t batch_idx = (bx - batch_size * num_qo_heads) / num_kv_heads;
+    uint32_t kv_head_idx = (bx - batch_size * num_qo_heads) % num_kv_heads;
+    const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx];
+    const uint32_t offset = offsets[batch_idx];
+#pragma unroll 2
+    for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) {
+      vec_t<float, vec_size> k_vec;
+      if (i * bdy + ty < seq_len) {
+        DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
+                                                k_stride_n, k_stride_h);
+        DType* k_rope_ptr =
+            k_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
+                                          /*kv_stride_n=*/num_kv_heads * head_dim,
+                                          /*kv_stride_h=*/head_dim);
+        if constexpr (interleave) {
+          k_vec =
+              vec_apply_llama_rope_interleave<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
+        } else {
+          k_vec = vec_apply_llama_rope<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
+        }
+        k_vec.cast_store(k_rope_ptr + tx * vec_size);
+      }
+    }
+  }
+}
+
 #define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
   if (interleave) {                                      \
     const bool INTERLEAVE = true;                        \

@@ -289,6 +369,100 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace(
   return cudaSuccess;
 }

+template <typename DType, typename IdType>
+cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k,
+                               DType* __restrict__ q_rope, DType* __restrict__ k_rope,
+                               IdType* __restrict__ indptr, IdType* __restrict__ offsets,
+                               uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
+                               uint32_t head_dim, size_t q_stride_n, size_t q_stride_h,
+                               size_t k_stride_n, size_t k_stride_h, bool interleave,
+                               float rope_scale, float rope_theta, cudaStream_t stream = nullptr) {
+  float rope_rcp_scale = 1.0f / rope_scale;
+  float rope_rcp_theta = 1.0f / rope_theta;
+  float smooth_a = 0.f;
+  float smooth_b = 0.f;
+
+  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
+      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
+      constexpr uint32_t bdx = HEAD_DIM / vec_size;
+      uint32_t num_threads = std::max(128U, bdx);
+      uint32_t bdy = num_threads / bdx;
+      dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
+      dim3 nthrs(bdx, bdy);
+      auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
+      void* args[] = {(void*)&q,
+                      (void*)&k,
+                      (void*)&q_rope,
+                      (void*)&k_rope,
+                      (void*)&indptr,
+                      (void*)&offsets,
+                      (void*)&batch_size,
+                      (void*)&num_qo_heads,
+                      (void*)&num_kv_heads,
+                      (void*)&q_stride_n,
+                      (void*)&q_stride_h,
+                      (void*)&k_stride_n,
+                      (void*)&k_stride_h,
+                      (void*)&smooth_a,
+                      (void*)&smooth_b,
+                      (void*)&rope_rcp_scale,
+                      (void*)&rope_rcp_theta};
+      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+    });
+  });
+
+  return cudaSuccess;
+}
+
+template <typename DType, typename IdType>
+cudaError_t BatchQKApplyLlama31Rotary(DType* __restrict__ q, DType* __restrict__ k,
+                                      DType* __restrict__ q_rope, DType* __restrict__ k_rope,
+                                      IdType* __restrict__ indptr, IdType* __restrict__ offsets,
+                                      uint32_t batch_size, uint32_t num_qo_heads,
+                                      uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n,
+                                      size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
+                                      bool interleave, float rope_scale, float rope_theta,
+                                      float low_freq_factor, float high_freq_factor,
+                                      float old_context_length, cudaStream_t stream = nullptr) {
+  float rope_rcp_scale = 1.0f / rope_scale;
+  float rope_rcp_theta = 1.0f / rope_theta;
+  float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor);
+  float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f);
+
+  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
+      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
+      constexpr uint32_t bdx = HEAD_DIM / vec_size;
+      uint32_t num_threads = std::max(128U, bdx);
+      uint32_t bdy = num_threads / bdx;
+      dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
+      dim3 nthrs(bdx, bdy);
+      auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
+      void* args[] = {(void*)&q,
+                      (void*)&k,
+                      (void*)&q_rope,
+                      (void*)&k_rope,
+                      (void*)&indptr,
+                      (void*)&offsets,
+                      (void*)&batch_size,
+                      (void*)&num_qo_heads,
+                      (void*)&num_kv_heads,
+                      (void*)&q_stride_n,
+                      (void*)&q_stride_h,
+                      (void*)&k_stride_n,
+                      (void*)&k_stride_h,
+                      (void*)&smooth_a,
+                      (void*)&smooth_b,
+                      (void*)&rope_rcp_scale,
+                      (void*)&rope_rcp_theta};
+      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+    });
+  });
+
+  return cudaSuccess;
+}
+
 }  // namespace flashinfer

 #endif  // FLASHINFER_POS_ENC_CUH_
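
To make the per-dimension math at the top of BatchQKApplyRotaryKernel easier to follow, the following NumPy sketch reproduces the same frequency schedule, including the Llama 3.1 smoothing controlled by smooth_a and smooth_b. It is purely illustrative; the function name and signature are made up here and mirror the kernel, not any public API.

import numpy as np

def rope_freqs(head_dim, rope_scale, rope_theta, interleave=False,
               low_freq_factor=None, high_freq_factor=None, old_context_length=None):
    # Mirrors the freq[] computation in BatchQKApplyRotaryKernel (illustration only).
    rope_rcp_scale = 1.0 / rope_scale
    rope_rcp_theta = 1.0 / rope_theta

    idx = np.arange(head_dim)
    if interleave:
        exponent = 2.0 * (idx // 2) / head_dim
    else:
        exponent = 2.0 * (idx % (head_dim // 2)) / head_dim
    freq = rope_rcp_theta ** exponent

    if low_freq_factor is None:
        # Plain RoPE path: BatchQKApplyRotary passes smooth_a = smooth_b = 0,
        # so the blend below reduces to freq * rope_rcp_scale.
        smooth_a, smooth_b = 0.0, 0.0
    else:
        # Llama 3.1 path: matches smooth_a / smooth_b set up in BatchQKApplyLlama31Rotary.
        smooth_a = old_context_length / (2 * np.pi * (high_freq_factor - low_freq_factor))
        smooth_b = -1.0 / (high_freq_factor / low_freq_factor - 1.0)

    smooth = np.clip(freq * smooth_a + smooth_b, 0.0, 1.0)
    return (1.0 - smooth) * (freq * rope_rcp_scale) + smooth * freq

With smooth == 0 a dimension uses the scaled frequency (long-context interpolation); with smooth == 1 it keeps the original frequency. This is exactly the per-element blend the kernel computes before rotating q and k.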

python/csrc/flashinfer_ops.cu (+2)

@@ -45,6 +45,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
   m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
         "Apply Llama 3.1 style RoPE in-place");
+  m.def("apply_rope", &apply_rope, "Apply RoPE");
+  m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
   m.def("packbits", &packbits, "GPU packbits operator");
   m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
   py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,

python/csrc/flashinfer_ops.h (+10)

@@ -83,6 +83,16 @@ void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor
                                 float rope_theta, float low_freq_factor, float high_freq_factor,
                                 float old_context_length);

+std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
+                                      torch::Tensor offsets, bool interleave, float rope_scale,
+                                      float rope_theta);
+
+std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
+                                              torch::Tensor indptr, torch::Tensor offsets,
+                                              bool interleave, float rope_scale, float rope_theta,
+                                              float low_freq_factor, float high_freq_factor,
+                                              float old_context_length);
+
 torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);

 torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,

python/csrc/rope.cu (+99 -1)

@@ -102,4 +102,102 @@ void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor
                 std::string(cudaGetErrorString(status)));
     return true;
   });
-}
+}
+
+std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
+                                      torch::Tensor offsets, bool interleave, float rope_scale,
+                                      float rope_theta) {
+  CHECK_CUDA(q);  // not necessarily contiguous
+  CHECK_CUDA(k);  // not necessarily contiguous
+  CHECK_INPUT(indptr);
+  CHECK_INPUT(offsets);
+
+  auto device = q.device();
+  CHECK_EQ(k.device(), device);
+  CHECK_DIM(3, q);        // q: (nnz, H_Q, D)
+  CHECK_DIM(3, k);        // k: (nnz, H_K, D)
+  CHECK_DIM(1, indptr);   // indptr: (B + 1)
+  CHECK_DIM(1, offsets);  // offsets: (B)
+  CHECK_EQ(q.size(0), k.size(0));
+  CHECK_EQ(q.size(2), k.size(2));
+  unsigned int num_qo_heads = q.size(1);
+  unsigned int num_kv_heads = k.size(1);
+  unsigned int head_dim = q.size(2);
+  unsigned int batch_size = offsets.size(0);
+  CHECK_EQ(indptr.size(0), batch_size + 1);
+  size_t q_stride_n = q.stride(0);
+  size_t q_stride_h = q.stride(1);
+  size_t k_stride_n = k.stride(0);
+  size_t k_stride_h = k.stride(1);
+  indptr = indptr.to(torch::kInt32);
+  offsets = offsets.to(torch::kInt32);
+  // NOTE(Zihao): empty_like does not copy strides, so it's okay to use it here.
+  auto q_rope = torch::empty_like(q);
+  auto k_rope = torch::empty_like(k);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
+    cudaError_t status = BatchQKApplyRotary(
+        static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
+        static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()),
+        static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()),
+        batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n,
+        k_stride_h, interleave, rope_scale, rope_theta, torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotary failed with error code " +
+                                           std::string(cudaGetErrorString(status)));
+    return true;
+  });
+
+  return {q_rope, k_rope};
+}
+
+std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
+                                              torch::Tensor indptr, torch::Tensor offsets,
+                                              bool interleave, float rope_scale, float rope_theta,
+                                              float low_freq_factor, float high_freq_factor,
+                                              float old_context_length) {
+  CHECK_CUDA(q);  // not necessarily contiguous
+  CHECK_CUDA(k);  // not necessarily contiguous
+  CHECK_INPUT(indptr);
+  CHECK_INPUT(offsets);
+
+  auto device = q.device();
+  CHECK_EQ(k.device(), device);
+  CHECK_DIM(3, q);        // q: (nnz, H_Q, D)
+  CHECK_DIM(3, k);        // k: (nnz, H_K, D)
+  CHECK_DIM(1, indptr);   // indptr: (B + 1)
+  CHECK_DIM(1, offsets);  // offsets: (B)
+  CHECK_EQ(q.size(0), k.size(0));
+  CHECK_EQ(q.size(2), k.size(2));
+  unsigned int num_qo_heads = q.size(1);
+  unsigned int num_kv_heads = k.size(1);
+  unsigned int head_dim = q.size(2);
+  unsigned int batch_size = offsets.size(0);
+  CHECK_EQ(indptr.size(0), batch_size + 1);
+  size_t q_stride_n = q.stride(0);
+  size_t q_stride_h = q.stride(1);
+  size_t k_stride_n = k.stride(0);
+  size_t k_stride_h = k.stride(1);
+  indptr = indptr.to(torch::kInt32);
+  offsets = offsets.to(torch::kInt32);
+
+  // NOTE(Zihao): empty_like does not copy strides, so it's okay to use it here.
+  auto q_rope = torch::empty_like(q);
+  auto k_rope = torch::empty_like(k);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
+    cudaError_t status = BatchQKApplyLlama31Rotary(
+        static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
+        static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()),
+        static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()),
+        batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n,
+        k_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor,
+        old_context_length, torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31Rotary failed with error code " +
+                                           std::string(cudaGetErrorString(status)));
+    return true;
+  });
+
+  return {q_rope, k_rope};
+}
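
A detail worth calling out in the two functions above: q and k only go through CHECK_CUDA, not CHECK_INPUT, and their strides are forwarded to the kernel, so non-contiguous views are accepted. Below is a hypothetical Python sketch of why that is convenient, taking q and k as zero-copy slices of a fused qkv projection; it assumes the Python wrapper forwards strides unchanged, as the C++ binding does.

import torch
import flashinfer

nnz, num_qo_heads, num_kv_heads, head_dim = 19, 32, 8, 128
num_heads_total = num_qo_heads + 2 * num_kv_heads

# Fused QKV projection output, viewed as (nnz, H_total, D).
qkv = torch.randn(nnz, num_heads_total, head_dim, dtype=torch.float16, device="cuda")
q = qkv[:, :num_qo_heads, :]                              # non-contiguous view
k = qkv[:, num_qo_heads:num_qo_heads + num_kv_heads, :]   # non-contiguous view

indptr = torch.tensor([0, 7, 19], dtype=torch.int32, device="cuda")
offsets = torch.tensor([0, 0], dtype=torch.int32, device="cuda")

# No .contiguous() copy is needed before applying RoPE.
q_rope, k_rope = flashinfer.apply_rope(
    q, k, indptr, offsets, interleave=False, rope_scale=1.0, rope_theta=1e4
)
# The outputs are newly allocated; the kernel writes them in a packed (nnz, H, D) layout.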

python/csrc/single_prefill.cu (+1 -1)

@@ -71,7 +71,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
   TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
   const LogitsPostHook logits_post_hook =
       logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone;
-
+
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
     return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
       return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {

python/flashinfer/__init__.py (+6 -1)

@@ -44,7 +44,12 @@
     chain_speculative_sampling,
 )
 from .norm import rmsnorm
-from .rope import apply_rope_inplace, apply_llama31_rope_inplace
+from .rope import (
+    apply_rope_inplace,
+    apply_llama31_rope_inplace,
+    apply_rope,
+    apply_llama31_rope,
+)
 from .group_gemm import SegmentGEMMWrapper
 from .quantization import packbits, segment_packbits
5055
