@@ -168,6 +168,56 @@ __device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_cos_sin_i
   return vec;
 }
 
+template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
+          typename IdType>
+__global__ void BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel(
+    DType* q, DType* k, DType* q_rope, DType* k_rope, float* __restrict__ cos_cache,
+    float* __restrict__ sin_cache, IdType* __restrict__ pos_ids, uint32_t nnz,
+    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, size_t q_stride_n,
+    size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n,
+    size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h) {
+  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
+  uint32_t by = blockIdx.y;
+  const uint32_t bdy = blockDim.y;
+
+  vec_t<float, vec_size> cos, sin;
+  if (bx * bdy + ty < nnz) {
+    const uint32_t idx = bx * bdy + ty;
+    const IdType pos = pos_ids[idx];
+
+    if (tx * vec_size < rotary_dim) {
+      cos.load(cos_cache + pos * rotary_dim + tx * vec_size);
+      sin.load(sin_cache + pos * rotary_dim + tx * vec_size);
+    }
+
+    if (by < num_qo_heads) {
+      uint32_t qo_head_idx = by;
+      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
+      DType* q_rope_ptr =
+          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
+      vec_t<float, vec_size> q_vec;
+      if constexpr (interleave) {
+        q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+      } else {
+        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+      }
+      q_vec.cast_store(q_rope_ptr + tx * vec_size);
+    } else {
+      uint32_t kv_head_idx = by - num_qo_heads;
+      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
+      DType* k_rope_ptr =
+          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
+      vec_t<float, vec_size> k_vec;
+      if constexpr (interleave) {
+        k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+      } else {
+        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+      }
+      k_vec.cast_store(k_rope_ptr + tx * vec_size);
+    }
+  }
+}
+
 template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
           typename IdType>
 __global__ void BatchQKApplyRotaryPosIdsCosSinCacheKernel(
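A note on the hunk above (our summary, not text from the patch): the new head-parallel kernel keeps the token layout of the existing cos/sin-cache kernel on blockIdx.x/threadIdx.y, but additionally spreads heads across blockIdx.y; blocks with by < num_qo_heads rotate one query head, the rest rotate one key head. A minimal host-side sketch of the launch geometry it expects, assuming fp16 inputs, head_dim = 128, 32 query heads, and 8 KV heads (illustrative shapes, not from the patch):

    #include <cstdint>
    #include <cuda_fp16.h>     // half
    #include <cuda_runtime.h>  // dim3

    // Grid/block shape the head-parallel launch would use for these shapes.
    void sketch_launch_shape(uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads,
                             dim3* nblks, dim3* nthrs) {
      constexpr uint32_t head_dim = 128;
      constexpr uint32_t vec_size = 8;               // max(16 / sizeof(half), head_dim / 32) = max(8, 4)
      constexpr uint32_t bdx = head_dim / vec_size;  // 16 threads span one head_dim
      const uint32_t num_threads = 128;              // std::max(128U, bdx)
      const uint32_t bdy = num_threads / bdx;        // 8 tokens per block
      *nthrs = dim3(bdx, bdy);
      // With nnz = 4 decode tokens: 1 block along x, but 32 + 8 = 40 blocks
      // along y, so 40 CTAs instead of the 1 the token-only grid would launch.
      *nblks = dim3((nnz + bdy - 1) / bdy, num_qo_heads + num_kv_heads);
    }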
@@ -221,69 +271,144 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheKernel(
 
 template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
           typename IdType>
-__global__ void BatchQKApplyRotaryPosIdsKernel(
+__global__ void BatchQKApplyRotaryPosIdsHeadParallelismKernel(
     DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ pos_ids, uint32_t nnz,
     uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, size_t q_stride_n,
     size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n,
     size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a,
     float smooth_b, float rope_rcp_scale, float rope_rcp_theta) {
   // NOTE: q and q_rope may be the same ptr, so do k and k_rope
   uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
+  uint32_t by = blockIdx.y;
+  const uint32_t bdy = blockDim.y;
+  vec_t<float, vec_size> freq;
+  if (tx * vec_size < rotary_dim) {
+#pragma unroll
+    for (uint32_t i = 0; i < vec_size; ++i) {
+      if constexpr (interleave) {
+        freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim));
+      } else {
+        freq[i] = __powf(rope_rcp_theta,
+                         float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim));
+      }
 
-  const uint32_t idx = bx * blockDim.y + ty;
-  const uint32_t pos_idx = idx / (num_qo_heads + num_kv_heads);
-  if (pos_idx >= nnz) {
-    return;
+      float smooth = freq[i] * smooth_a + smooth_b;
+      smooth = max(0.0f, min(1.0f, smooth));  // clamp to [0, 1]
+      freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
+    }
   }
 
-  const IdType pos = pos_ids[pos_idx];
-
   vec_t<float, vec_size> cos, sin;
+
+  if (bx * bdy + ty < nnz) {
+    const uint32_t idx = bx * bdy + ty;
+    const IdType pos = pos_ids[idx];
+
+    if (tx * vec_size < rotary_dim) {
+#pragma unroll
+      for (uint32_t i = 0; i < vec_size; ++i) {
+        float embed = float(pos) * freq[i];
+        __sincosf(embed, &sin[i], &cos[i]);
+      }
+    }
+
+    if (by < num_qo_heads) {
+      uint32_t qo_head_idx = by;
+      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
+      DType* q_rope_ptr =
+          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
+      vec_t<float, vec_size> q_vec;
+      if constexpr (interleave) {
+        q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+      } else {
+        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+      }
+      q_vec.cast_store(q_rope_ptr + tx * vec_size);
+    } else {
+      uint32_t kv_head_idx = by - num_qo_heads;
+      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
+      DType* k_rope_ptr =
+          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
+      vec_t<float, vec_size> k_vec;
+      if constexpr (interleave) {
+        k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+      } else {
+        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+      }
+      k_vec.cast_store(k_rope_ptr + tx * vec_size);
+    }
+  }
+}
+
+template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
+          typename IdType>
+__global__ void BatchQKApplyRotaryPosIdsKernel(
+    DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ pos_ids, uint32_t nnz,
+    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, size_t q_stride_n,
+    size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n,
+    size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a,
+    float smooth_b, float rope_rcp_scale, float rope_rcp_theta) {
+  // NOTE: q and q_rope may be the same ptr, so do k and k_rope
+  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
+  const uint32_t bdy = blockDim.y;
+  vec_t<float, vec_size> freq;
   if (tx * vec_size < rotary_dim) {
-#pragma unroll
+#pragma unroll
     for (uint32_t i = 0; i < vec_size; ++i) {
-      float freq;
       if constexpr (interleave) {
-        freq = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim));
+        freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim));
       } else {
-        freq = __powf(rope_rcp_theta,
-                      float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim));
+        freq[i] = __powf(rope_rcp_theta,
+                         float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim));
       }
 
-      float smooth = freq * smooth_a + smooth_b;
+      float smooth = freq[i] * smooth_a + smooth_b;
       smooth = max(0.0f, min(1.0f, smooth));  // clamp to [0, 1]
-      freq = (1 - smooth) * (freq * rope_rcp_scale) + smooth * freq;
-
-      const float embed = float(pos) * freq;
-      __sincosf(embed, &sin[i], &cos[i]);
+      freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
     }
   }
 
-  const uint32_t head_idx = idx % (num_qo_heads + num_kv_heads);
-  if (head_idx < num_qo_heads) {
-    const uint32_t qo_head_idx = head_idx;
-    DType* q_ptr = q + get_elem_offset_impl(pos_idx, qo_head_idx, 0, q_stride_n, q_stride_h);
-    DType* q_rope_ptr =
-        q_rope + get_elem_offset_impl(pos_idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
-    vec_t<float, vec_size> q_vec;
-    if constexpr (interleave) {
-      q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
-    } else {
-      q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+  vec_t<float, vec_size> cos, sin;
+
+  if (bx * bdy + ty < nnz) {
+    const uint32_t idx = bx * bdy + ty;
+    const IdType pos = pos_ids[idx];
+
+    if (tx * vec_size < rotary_dim) {
+#pragma unroll
+      for (uint32_t i = 0; i < vec_size; ++i) {
+        float embed = float(pos) * freq[i];
+        __sincosf(embed, &sin[i], &cos[i]);
+      }
     }
-    q_vec.cast_store(q_rope_ptr + tx * vec_size);
-  } else {
-    const uint32_t kv_head_idx = head_idx - num_qo_heads;
-    DType* k_ptr = k + get_elem_offset_impl(pos_idx, kv_head_idx, 0, k_stride_n, k_stride_h);
-    DType* k_rope_ptr =
-        k_rope + get_elem_offset_impl(pos_idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
-    vec_t<float, vec_size> k_vec;
-    if constexpr (interleave) {
-      k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
-    } else {
-      k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+
+#pragma unroll 1
+    for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
+      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
+      DType* q_rope_ptr =
+          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
+      vec_t<float, vec_size> q_vec;
+      if constexpr (interleave) {
+        q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+      } else {
+        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+      }
+      q_vec.cast_store(q_rope_ptr + tx * vec_size);
+    }
+
+#pragma unroll 1
+    for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
+      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
+      DType* k_rope_ptr =
+          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
+      vec_t<float, vec_size> k_vec;
+      if constexpr (interleave) {
+        k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+      } else {
+        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+      }
+      k_vec.cast_store(k_rope_ptr + tx * vec_size);
     }
-    k_vec.cast_store(k_rope_ptr + tx * vec_size);
   }
 }
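For reference (our reading of the hunk above, not part of the patch): both rewritten kernels hoist the inverse-frequency computation out of the per-token work. The old kernel mapped one (token, head) pair to each block slot and redid the `__powf` and `__sincosf` chain for every pair; now each thread caches `freq[i]` once, computes cos/sin once per token, and the token-parallel variant loops over all heads from that. The `#pragma unroll 1` on the head loops tells the compiler to keep them rolled, which bounds code size and register pressure for large head counts. Since `rope_rcp_theta = 1/theta`, `__powf(rope_rcp_theta, x)` is `theta^-x`; a host-side sketch of the cached frequency, with the smoothing blend written out (function name is ours):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Mirrors the freq[i] the kernels cache for rotary dimension j.
    // rope_rcp_scale = 1/scale, so freq * rope_rcp_scale is the
    // position-interpolated (scaled) frequency; smooth blends it with the
    // unscaled one, clamped to [0, 1] (Llama-3.1-style smoothing).
    float reference_freq(uint32_t j, uint32_t rotary_dim, bool interleave,
                         float rope_rcp_theta, float rope_rcp_scale,
                         float smooth_a, float smooth_b) {
      // interleave: dims (0,1), (2,3), ... pair up; else dims (j, j + d/2) do.
      uint32_t k = interleave ? j / 2 : j % (rotary_dim / 2);
      float freq = std::pow(rope_rcp_theta, float(2 * k) / float(rotary_dim));
      float smooth = std::clamp(freq * smooth_a + smooth_b, 0.0f, 1.0f);
      return (1.0f - smooth) * (freq * rope_rcp_scale) + smooth * freq;
    }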
@@ -383,16 +508,18 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCache(
     uint32_t rotary_dim, uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n,
     size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n,
     size_t k_rope_stride_h, bool interleave, cudaStream_t stream = nullptr) {
+  int dev_id = 0;
+  int num_sms = 0;
+  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
+
   DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
     DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
       constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
       constexpr uint32_t bdx = HEAD_DIM / vec_size;
       uint32_t num_threads = std::max(128U, bdx);
       uint32_t bdy = num_threads / bdx;
-      dim3 nblks((nnz + bdy - 1) / bdy);
-      dim3 nthrs(bdx, bdy);
-      auto kernel = BatchQKApplyRotaryPosIdsCosSinCacheKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx,
-                                                              DType, IdType>;
+      uint32_t nblks_x = (nnz + bdy - 1) / bdy;
       void* args[] = {(void*)&q,
                       (void*)&k,
                       (void*)&q_rope,
@@ -412,7 +539,26 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCache(
                       (void*)&q_rope_stride_h,
                       (void*)&k_rope_stride_n,
                       (void*)&k_rope_stride_h};
-      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+      auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx,
+                                                                DType, IdType>;
+
+      int num_blocks_per_sm_0 = 0;
+      FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+          &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0));
+      uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms;
+
+      if ((nnz + bdy - 1) / bdy >= num_ctas_0) {
+        dim3 nblks(nblks_x);
+        dim3 nthrs(bdx, bdy);
+        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
+      } else {
+        dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
+        dim3 nthrs(bdx, bdy);
+        auto kernel_1 =
+            BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel<INTERLEAVE, HEAD_DIM, vec_size,
+                                                                     bdx, DType, IdType>;
+        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
+      }
     });
   });
 
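Why the two-kernel dispatch (our reading, not text from the patch): `cudaOccupancyMaxActiveBlocksPerMultiprocessor` reports how many blocks of `kernel_0` each SM can host, so `num_ctas_0` is the block count that saturates the device. If the token-parallel grid alone reaches it (`nblks_x >= num_ctas_0`), the original kernel is kept; otherwise the head-parallel variant launches with `gridDim.y = num_qo_heads + num_kv_heads`. Worked numbers, purely illustrative: with 8 resident blocks per SM on an 80-SM part, `num_ctas_0 = 640`; a prefill of `nnz = 8192` tokens at `bdy = 8` gives `nblks_x = 1024 >= 640` and stays token-parallel, while a decode step of `nnz = 32` gives `nblks_x = 4` and switches to a `(4, num_heads)` grid. Both kernels take the same parameter list, which is what lets the one `args[]` array feed either `cudaLaunchKernel` call.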
@@ -430,17 +576,19 @@ cudaError_t BatchQKApplyRotaryPosIds(
   float rope_rcp_theta = 1.0f / rope_theta;
   float smooth_a = 0.f;
   float smooth_b = 0.f;
+  int dev_id = 0;
+  int num_sms = 0;
+  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
 
   DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
     DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
       constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
       constexpr uint32_t bdx = HEAD_DIM / vec_size;
       uint32_t num_threads = std::max(128U, bdx);
       uint32_t bdy = num_threads / bdx;
-      dim3 nblks((nnz + bdy - 1) / bdy);
-      dim3 nthrs(bdx, bdy);
-      auto kernel =
-          BatchQKApplyRotaryPosIdsKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
+      uint32_t nblks_x = (nnz + bdy - 1) / bdy;
+
       void* args[] = {(void*)&q,
                       (void*)&k,
                       (void*)&q_rope,
@@ -462,7 +610,26 @@ cudaError_t BatchQKApplyRotaryPosIds(
                       (void*)&smooth_b,
                       (void*)&rope_rcp_scale,
                       (void*)&rope_rcp_theta};
-      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+      auto kernel_0 =
+          BatchQKApplyRotaryPosIdsKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
+
+      int num_blocks_per_sm_0 = 0;
+      FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+          &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0));
+      uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms;
+      if (nblks_x >= num_ctas_0) {
+        dim3 nblks(nblks_x);
+        dim3 nthrs(bdx, bdy);
+
+        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
+      } else {
+        dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
+        dim3 nthrs(bdx, bdy);
+        auto kernel_1 = BatchQKApplyRotaryPosIdsHeadParallelismKernel<INTERLEAVE, HEAD_DIM,
+                                                                      vec_size, bdx, DType, IdType>;
+
+        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
+      }
     });
   });
 
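The same occupancy-gated dispatch is applied here to the on-the-fly path: `BatchQKApplyRotaryPosIdsKernel` when the token grid already saturates the device, `BatchQKApplyRotaryPosIdsHeadParallelismKernel` otherwise. Note (our observation) that with `smooth_a` and `smooth_b` zeroed, the smoothing blend in both kernels collapses to `freq * rope_rcp_scale`, i.e. plain linearly scaled RoPE frequencies.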
@@ -606,7 +773,7 @@ cudaError_t BatchQKApplyLlama31RotaryPosIds(
       constexpr uint32_t bdx = HEAD_DIM / vec_size;
       uint32_t num_threads = std::max(128U, bdx);
       uint32_t bdy = num_threads / bdx;
-      dim3 nblks((nnz + bdy - 1) / bdy * (num_qo_heads + num_kv_heads));
+      dim3 nblks((nnz + bdy - 1) / bdy);
       dim3 nthrs(bdx, bdy);
       auto kernel =
           BatchQKApplyRotaryPosIdsKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
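This last hunk follows from the kernel rework (our summary): `BatchQKApplyRotaryPosIdsKernel` no longer assigns one (token, head) pair per block slot but covers every head inside its per-token head loops, so the Llama-3.1 wrapper must stop multiplying the grid by `num_qo_heads + num_kv_heads`. With the reworked kernel, the extra blocks of the old grid would merely fail the `bx * bdy + ty < nnz` bound and exit, wasting launch capacity.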