@@ -191,6 +191,86 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(
   }
 }
 
+template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
+          typename IdType>
+__global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restrict__ k,
+                                         DType* __restrict__ q_rope, DType* __restrict__ k_rope,
+                                         IdType* __restrict__ indptr, IdType* __restrict__ offsets,
+                                         uint32_t batch_size, uint32_t num_qo_heads,
+                                         uint32_t num_kv_heads, size_t q_stride_n,
+                                         size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
+                                         float smooth_a, float smooth_b, float rope_rcp_scale,
+                                         float rope_rcp_theta) {
+  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
+  const uint32_t bdy = blockDim.y;
+  vec_t<float, vec_size> freq;
+#pragma unroll
+  for (uint32_t i = 0; i < vec_size; ++i) {
+    if constexpr (interleave) {
+      freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(head_dim));
+    } else {
+      freq[i] = __powf(rope_rcp_theta,
+                       float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
+    }
+
+    float smooth = freq[i] * smooth_a + smooth_b;
+    smooth = max(0.0f, min(1.0f, smooth));  // clamp to [0, 1]
+    freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
+  }
+
+  if (bx < batch_size * num_qo_heads) {
+    // apply rotary to q
+    const uint32_t batch_idx = bx / num_qo_heads;
+    const uint32_t qo_head_idx = bx % num_qo_heads;
+    const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx];
+    const uint32_t offset = offsets[batch_idx];
+#pragma unroll 2
+    for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) {
+      vec_t<float, vec_size> q_vec;
+      if (i * bdy + ty < seq_len) {
+        DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
+                                                q_stride_n, q_stride_h);
+        DType* q_rope_ptr =
+            q_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
+                                          /*q_stride_n=*/num_qo_heads * head_dim,
+                                          /*q_stride_h=*/head_dim);
+        if constexpr (interleave) {
+          q_vec =
+              vec_apply_llama_rope_interleave<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
+        } else {
+          q_vec = vec_apply_llama_rope<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
+        }
+        q_vec.cast_store(q_rope_ptr + tx * vec_size);
+      }
+    }
+  } else {
+    // apply rotary to k
+    uint32_t batch_idx = (bx - batch_size * num_qo_heads) / num_kv_heads;
+    uint32_t kv_head_idx = (bx - batch_size * num_qo_heads) % num_kv_heads;
+    const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx];
+    const uint32_t offset = offsets[batch_idx];
+#pragma unroll 2
+    for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) {
+      vec_t<float, vec_size> k_vec;
+      if (i * bdy + ty < seq_len) {
+        DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
+                                                k_stride_n, k_stride_h);
+        DType* k_rope_ptr =
+            k_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
+                                          /*kv_stride_n=*/num_kv_heads * head_dim,
+                                          /*kv_stride_h=*/head_dim);
+        if constexpr (interleave) {
+          k_vec =
+              vec_apply_llama_rope_interleave<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
+        } else {
+          k_vec = vec_apply_llama_rope<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
+        }
+        k_vec.cast_store(k_rope_ptr + tx * vec_size);
+      }
+    }
+  }
+}
+
 #define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
   if (interleave) {                                      \
     const bool INTERLEAVE = true;                        \
@@ -289,6 +369,100 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace(
   return cudaSuccess;
 }
 
+template <typename DType, typename IdType>
+cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k,
+                               DType* __restrict__ q_rope, DType* __restrict__ k_rope,
+                               IdType* __restrict__ indptr, IdType* __restrict__ offsets,
+                               uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
+                               uint32_t head_dim, size_t q_stride_n, size_t q_stride_h,
+                               size_t k_stride_n, size_t k_stride_h, bool interleave,
+                               float rope_scale, float rope_theta, cudaStream_t stream = nullptr) {
+  float rope_rcp_scale = 1.0f / rope_scale;
+  float rope_rcp_theta = 1.0f / rope_theta;
+  float smooth_a = 0.f;
+  float smooth_b = 0.f;
+
+  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
+      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
+      constexpr uint32_t bdx = HEAD_DIM / vec_size;
+      uint32_t num_threads = std::max(128U, bdx);
+      uint32_t bdy = num_threads / bdx;
+      dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
+      dim3 nthrs(bdx, bdy);
+      auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
+      void* args[] = {(void*)&q,
+                      (void*)&k,
+                      (void*)&q_rope,
+                      (void*)&k_rope,
+                      (void*)&indptr,
+                      (void*)&offsets,
+                      (void*)&batch_size,
+                      (void*)&num_qo_heads,
+                      (void*)&num_kv_heads,
+                      (void*)&q_stride_n,
+                      (void*)&q_stride_h,
+                      (void*)&k_stride_n,
+                      (void*)&k_stride_h,
+                      (void*)&smooth_a,
+                      (void*)&smooth_b,
+                      (void*)&rope_rcp_scale,
+                      (void*)&rope_rcp_theta};
+      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+    });
+  });
+
+  return cudaSuccess;
+}
+
+template <typename DType, typename IdType>
+cudaError_t BatchQKApplyLlama31Rotary(DType* __restrict__ q, DType* __restrict__ k,
+                                      DType* __restrict__ q_rope, DType* __restrict__ k_rope,
+                                      IdType* __restrict__ indptr, IdType* __restrict__ offsets,
+                                      uint32_t batch_size, uint32_t num_qo_heads,
+                                      uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n,
+                                      size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
+                                      bool interleave, float rope_scale, float rope_theta,
+                                      float low_freq_factor, float high_freq_factor,
+                                      float old_context_length, cudaStream_t stream = nullptr) {
+  float rope_rcp_scale = 1.0f / rope_scale;
+  float rope_rcp_theta = 1.0f / rope_theta;
+  float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor);
+  float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f);
+
+  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
+      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
+      constexpr uint32_t bdx = HEAD_DIM / vec_size;
+      uint32_t num_threads = std::max(128U, bdx);
+      uint32_t bdy = num_threads / bdx;
+      dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
+      dim3 nthrs(bdx, bdy);
+      auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
+      void* args[] = {(void*)&q,
+                      (void*)&k,
+                      (void*)&q_rope,
+                      (void*)&k_rope,
+                      (void*)&indptr,
+                      (void*)&offsets,
+                      (void*)&batch_size,
+                      (void*)&num_qo_heads,
+                      (void*)&num_kv_heads,
+                      (void*)&q_stride_n,
+                      (void*)&q_stride_h,
+                      (void*)&k_stride_n,
+                      (void*)&k_stride_h,
+                      (void*)&smooth_a,
+                      (void*)&smooth_b,
+                      (void*)&rope_rcp_scale,
+                      (void*)&rope_rcp_theta};
+      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+    });
+  });
+
+  return cudaSuccess;
+}
+
 }  // namespace flashinfer
 
 #endif  // FLASHINFER_POS_ENC_CUH_
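Not part of the diff above: a minimal host-side sketch of how the new non-in-place BatchQKApplyRotary entry point could be called. The include path, wrapper name ApplyRopeExample, buffer names, batch/head sizes, and the choice of half/int32_t are illustrative assumptions, not anything stated in this commit.

// Hypothetical usage sketch (assumed names and sizes; not part of this commit).
// Ragged batch of 2 requests, 32 query heads, 8 KV heads, head_dim = 128, fp16 data.
#include <cuda_fp16.h>

#include "flashinfer/pos_enc.cuh"  // assumed include path for this header

cudaError_t ApplyRopeExample(half* q, half* k, half* q_rope, half* k_rope,
                             int32_t* indptr,   // device array, [batch_size + 1], e.g. {0, 5, 12}
                             int32_t* offsets,  // device array, [batch_size], past KV length per request
                             cudaStream_t stream) {
  const uint32_t batch_size = 2, num_qo_heads = 32, num_kv_heads = 8, head_dim = 128;
  // q/k are assumed packed as [nnz, num_heads, head_dim]; the strides below describe that layout.
  return flashinfer::BatchQKApplyRotary(
      q, k, q_rope, k_rope, indptr, offsets, batch_size, num_qo_heads, num_kv_heads, head_dim,
      /*q_stride_n=*/num_qo_heads * head_dim, /*q_stride_h=*/head_dim,
      /*k_stride_n=*/num_kv_heads * head_dim, /*k_stride_h=*/head_dim,
      /*interleave=*/false, /*rope_scale=*/1.f, /*rope_theta=*/1e4f, stream);
}

Unlike the existing in-place path, q and k are left untouched here; the rotated values are written to q_rope/k_rope, which the kernel always addresses with a contiguous [nnz, num_heads, head_dim] layout (see the hard-coded output strides in BatchQKApplyRotaryKernel above).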