
Commit b27a2cc

bugfix: fix the rope correctness issue introduced in #609 (#619)
As observed by @james-p-xu, #609 produces wrong results for some input shapes. This PR fixes the correctness issue and adds an optimization: dispatching to different parallelism modes for different input shapes. For large input shapes, the original implementation (re-using sin/cos values across heads) is better; for small input shapes, head parallelism is better. Some results:

```
Before #609 (no head-parallelism, re-use sin/cos value)
-----------------
batch_size: 1, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 0.762GB/s
batch_size: 1, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 22us, throughput: 0.919GB/s
batch_size: 1, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 95.699GB/s
batch_size: 1, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 28us, throughput: 95.244GB/s
batch_size: 1, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 31us, throughput: 670.254GB/s
batch_size: 1, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 31us, throughput: 667.253GB/s
---
batch_size: 19, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 14.490GB/s
batch_size: 19, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 14.466GB/s
batch_size: 19, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 37us, throughput: 1344.086GB/s
batch_size: 19, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 37us, throughput: 1344.902GB/s
batch_size: 19, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 148us, throughput: 2699.475GB/s
batch_size: 19, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 147us, throughput: 2701.897GB/s
---
batch_size: 99, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 74.322GB/s
batch_size: 99, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 74.568GB/s
batch_size: 99, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 110us, throughput: 2352.352GB/s
batch_size: 99, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 110us, throughput: 2365.580GB/s
batch_size: 99, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 718us, throughput: 2893.608GB/s
batch_size: 99, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 717us, throughput: 2894.859GB/s
---
batch_size: 128, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 27us, throughput: 95.373GB/s
batch_size: 128, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 27us, throughput: 95.810GB/s
batch_size: 128, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 130us, throughput: 2583.872GB/s
batch_size: 128, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 129us, throughput: 2595.944GB/s
batch_size: 128, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 923us, throughput: 2907.408GB/s
batch_size: 128, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 924us, throughput: 2905.533GB/s

Head parallelism only (no dispatch)
---------------------
batch_size: 1, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 6us, throughput: 3.321GB/s
batch_size: 1, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 6us, throughput: 3.391GB/s
batch_size: 1, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 7us, throughput: 358.862GB/s
batch_size: 1, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 7us, throughput: 362.361GB/s
batch_size: 1, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 15us, throughput: 1413.175GB/s
batch_size: 1, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 15us, throughput: 1437.332GB/s
---
batch_size: 19, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 6us, throughput: 60.526GB/s
batch_size: 19, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 6us, throughput: 60.127GB/s
batch_size: 19, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 26us, throughput: 1897.923GB/s
batch_size: 19, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 24us, throughput: 2050.075GB/s
batch_size: 19, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 164us, throughput: 2431.650GB/s
batch_size: 19, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 147us, throughput: 2709.333GB/s
---
batch_size: 99, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 7us, throughput: 284.641GB/s
batch_size: 99, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 7us, throughput: 302.815GB/s
batch_size: 99, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 109us, throughput: 2391.712GB/s
batch_size: 99, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 97us, throughput: 2671.150GB/s
batch_size: 99, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 860us, throughput: 2413.211GB/s
batch_size: 99, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 828us, throughput: 2508.817GB/s
---
batch_size: 128, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 7us, throughput: 349.795GB/s
batch_size: 128, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 7us, throughput: 376.624GB/s
batch_size: 128, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 139us, throughput: 2413.690GB/s
batch_size: 128, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 124us, throughput: 2705.994GB/s
batch_size: 128, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 1110us, throughput: 2417.480GB/s
batch_size: 128, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 1063us, throughput: 2525.976GB/s

This PR (shape dispatch)
---------------------
batch_size: 1, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 28us, throughput: 0.728GB/s
batch_size: 1, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 6us, throughput: 3.451GB/s
batch_size: 1, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 7us, throughput: 359.759GB/s
batch_size: 1, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 7us, throughput: 361.286GB/s
batch_size: 1, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 15us, throughput: 1426.267GB/s
batch_size: 1, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 15us, throughput: 1433.691GB/s
---
batch_size: 19, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 6us, throughput: 60.390GB/s
batch_size: 19, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 6us, throughput: 59.937GB/s
batch_size: 19, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 26us, throughput: 1892.575GB/s
batch_size: 19, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 24us, throughput: 2049.735GB/s
batch_size: 19, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 148us, throughput: 2698.780GB/s
batch_size: 19, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 147us, throughput: 2701.558GB/s
---
batch_size: 99, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 7us, throughput: 285.335GB/s
batch_size: 99, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 7us, throughput: 303.373GB/s
batch_size: 99, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 110us, throughput: 2351.126GB/s
batch_size: 99, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 110us, throughput: 2362.898GB/s
batch_size: 99, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 717us, throughput: 2893.713GB/s
batch_size: 99, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 717us, throughput: 2894.902GB/s
---
batch_size: 128, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 7us, throughput: 350.720GB/s
batch_size: 128, append_len: 1, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 7us, throughput: 376.690GB/s
batch_size: 128, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 130us, throughput: 2584.221GB/s
batch_size: 128, append_len: 128, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 129us, throughput: 2596.612GB/s
batch_size: 128, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: False, latency: 924us, throughput: 2906.480GB/s
batch_size: 128, append_len: 1024, num_qo_heads: 32, num_kv_heads: 8, head_dim: 128, use_cos_sin_cache: True, latency: 924us, throughput: 2905.134GB/s
```

cc @nandor @james-p-xu
1 parent eaf73fd commit b27a2cc
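To make the dispatch rule concrete, here is a minimal, self-contained sketch of the heuristic (not the library code: the kernels are empty placeholders with hypothetical names, and the shape constants are illustrative; only the occupancy and device-attribute queries mirror the diff below). If one wave of token-level blocks already fills the device, the original kernel that re-uses sin/cos across heads wins; otherwise the head-parallel kernel spreads heads across `blockIdx.y`.

```cuda
// Hedged sketch of the shape-dispatch heuristic; placeholder kernels only.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

__global__ void SequentialHeadsKernel() {}  // stand-in: each block loops over all heads
__global__ void HeadParallelKernel() {}     // stand-in: blockIdx.y selects a single head

int main() {
  int dev_id = 0, num_sms = 0, blocks_per_sm = 0;
  cudaGetDevice(&dev_id);
  cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, SequentialHeadsKernel,
                                                /*blockSize=*/128, /*dynamicSMemSize=*/0);
  const uint32_t num_ctas = blocks_per_sm * num_sms;  // CTAs in one full wave

  // Illustrative shape: a single decoded token (nnz = 1), 32 QO + 8 KV heads.
  const uint32_t nnz = 1, bdy = 8, num_qo_heads = 32, num_kv_heads = 8;
  const uint32_t nblks_x = (nnz + bdy - 1) / bdy;
  if (nblks_x >= num_ctas) {
    // Large input: token blocks alone saturate the GPU; re-use sin/cos across heads.
    SequentialHeadsKernel<<<dim3(nblks_x), dim3(16, bdy)>>>();
  } else {
    // Small input: add a head axis to the grid to expose more parallel work.
    HeadParallelKernel<<<dim3(nblks_x, num_qo_heads + num_kv_heads), dim3(16, bdy)>>>();
  }
  cudaDeviceSynchronize();
  printf("took the %s path\n", nblks_x >= num_ctas ? "sequential-heads" : "head-parallel");
  return 0;
}
```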

File tree

1 file changed: +218 −51 lines changed


include/flashinfer/pos_enc.cuh

Lines changed: 218 additions & 51 deletions
```diff
@@ -168,6 +168,56 @@ __device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_cos_sin_i
   return vec;
 }
 
+template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
+          typename IdType>
+__global__ void BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel(
+    DType* q, DType* k, DType* q_rope, DType* k_rope, float* __restrict__ cos_cache,
+    float* __restrict__ sin_cache, IdType* __restrict__ pos_ids, uint32_t nnz,
+    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, size_t q_stride_n,
+    size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n,
+    size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h) {
+  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
+  uint32_t by = blockIdx.y;
+  const uint32_t bdy = blockDim.y;
+
+  vec_t<float, vec_size> cos, sin;
+  if (bx * bdy + ty < nnz) {
+    const uint32_t idx = bx * bdy + ty;
+    const IdType pos = pos_ids[idx];
+
+    if (tx * vec_size < rotary_dim) {
+      cos.load(cos_cache + pos * rotary_dim + tx * vec_size);
+      sin.load(sin_cache + pos * rotary_dim + tx * vec_size);
+    }
+
+    if (by < num_qo_heads) {
+      uint32_t qo_head_idx = by;
+      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
+      DType* q_rope_ptr =
+          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
+      vec_t<float, vec_size> q_vec;
+      if constexpr (interleave) {
+        q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+      } else {
+        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+      }
+      q_vec.cast_store(q_rope_ptr + tx * vec_size);
+    } else {
+      uint32_t kv_head_idx = by - num_qo_heads;
+      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
+      DType* k_rope_ptr =
+          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
+      vec_t<float, vec_size> k_vec;
+      if constexpr (interleave) {
+        k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+      } else {
+        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+      }
+      k_vec.cast_store(k_rope_ptr + tx * vec_size);
+    }
+  }
+}
+
 template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
           typename IdType>
 __global__ void BatchQKApplyRotaryPosIdsCosSinCacheKernel(
```
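The new kernel reads `cos_cache`/`sin_cache` as row-major `[num_positions, rotary_dim]` tables indexed by `pos` (`cos_cache + pos * rotary_dim + tx * vec_size`). Below is a hypothetical host-side sketch of building such a cache; the frequency formula is taken from the non-interleaved branch of the pos-ids kernels further down, and `rope_theta = 1e4` is an assumed default, not something this diff fixes.

```cpp
// Hypothetical construction of a [num_positions, rotary_dim] cos cache matching
// the kernel's row-major indexing; the sin cache is built the same way with std::sin.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> build_cos_cache(uint32_t num_positions, uint32_t rotary_dim,
                                   float rope_theta = 1e4f) {
  std::vector<float> cache(size_t(num_positions) * rotary_dim);
  for (uint32_t pos = 0; pos < num_positions; ++pos) {
    for (uint32_t i = 0; i < rotary_dim; ++i) {
      // Non-interleaved layout: elements i and i + rotary_dim/2 share a frequency.
      float freq = std::pow(rope_theta, -2.0f * float(i % (rotary_dim / 2)) / float(rotary_dim));
      cache[size_t(pos) * rotary_dim + i] = std::cos(float(pos) * freq);
    }
  }
  return cache;
}
```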
```diff
@@ -221,69 +271,144 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheKernel(
 
 template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
           typename IdType>
-__global__ void BatchQKApplyRotaryPosIdsKernel(
+__global__ void BatchQKApplyRotaryPosIdsHeadParallelismKernel(
     DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ pos_ids, uint32_t nnz,
     uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, size_t q_stride_n,
     size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n,
     size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a,
     float smooth_b, float rope_rcp_scale, float rope_rcp_theta) {
   // NOTE: q and q_rope may be the same ptr, so do k and k_rope
   uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
+  uint32_t by = blockIdx.y;
+  const uint32_t bdy = blockDim.y;
+  vec_t<float, vec_size> freq;
+  if (tx * vec_size < rotary_dim) {
+#pragma unroll
+    for (uint32_t i = 0; i < vec_size; ++i) {
+      if constexpr (interleave) {
+        freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim));
+      } else {
+        freq[i] = __powf(rope_rcp_theta,
+                         float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim));
+      }
 
-  const uint32_t idx = bx * blockDim.y + ty;
-  const uint32_t pos_idx = idx / (num_qo_heads + num_kv_heads);
-  if (pos_idx >= nnz) {
-    return;
+      float smooth = freq[i] * smooth_a + smooth_b;
+      smooth = max(0.0f, min(1.0f, smooth));  // clamp to [0, 1]
+      freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
+    }
   }
 
-  const IdType pos = pos_ids[pos_idx];
-
   vec_t<float, vec_size> cos, sin;
+
+  if (bx * bdy + ty < nnz) {
+    const uint32_t idx = bx * bdy + ty;
+    const IdType pos = pos_ids[idx];
+
+    if (tx * vec_size < rotary_dim) {
+#pragma unroll
+      for (uint32_t i = 0; i < vec_size; ++i) {
+        float embed = float(pos) * freq[i];
+        __sincosf(embed, &sin[i], &cos[i]);
+      }
+    }
+
+    if (by < num_qo_heads) {
+      uint32_t qo_head_idx = by;
+      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
+      DType* q_rope_ptr =
+          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
+      vec_t<float, vec_size> q_vec;
+      if constexpr (interleave) {
+        q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+      } else {
+        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+      }
+      q_vec.cast_store(q_rope_ptr + tx * vec_size);
+    } else {
+      uint32_t kv_head_idx = by - num_qo_heads;
+      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
+      DType* k_rope_ptr =
+          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
+      vec_t<float, vec_size> k_vec;
+      if constexpr (interleave) {
+        k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+      } else {
+        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+      }
+      k_vec.cast_store(k_rope_ptr + tx * vec_size);
+    }
+  }
+}
+
+template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
+          typename IdType>
+__global__ void BatchQKApplyRotaryPosIdsKernel(
+    DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ pos_ids, uint32_t nnz,
+    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, size_t q_stride_n,
+    size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n,
+    size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a,
+    float smooth_b, float rope_rcp_scale, float rope_rcp_theta) {
+  // NOTE: q and q_rope may be the same ptr, so do k and k_rope
+  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
+  const uint32_t bdy = blockDim.y;
+  vec_t<float, vec_size> freq;
   if (tx * vec_size < rotary_dim) {
-#pragma unroll
+#pragma unroll
     for (uint32_t i = 0; i < vec_size; ++i) {
-      float freq;
       if constexpr (interleave) {
-        freq = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim));
+        freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim));
       } else {
-        freq = __powf(rope_rcp_theta,
-                      float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim));
+        freq[i] = __powf(rope_rcp_theta,
+                         float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim));
       }
 
-      float smooth = freq * smooth_a + smooth_b;
+      float smooth = freq[i] * smooth_a + smooth_b;
       smooth = max(0.0f, min(1.0f, smooth));  // clamp to [0, 1]
-      freq = (1 - smooth) * (freq * rope_rcp_scale) + smooth * freq;
-
-      const float embed = float(pos) * freq;
-      __sincosf(embed, &sin[i], &cos[i]);
+      freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
     }
   }
 
-  const uint32_t head_idx = idx % (num_qo_heads + num_kv_heads);
-  if (head_idx < num_qo_heads) {
-    const uint32_t qo_head_idx = head_idx;
-    DType* q_ptr = q + get_elem_offset_impl(pos_idx, qo_head_idx, 0, q_stride_n, q_stride_h);
-    DType* q_rope_ptr =
-        q_rope + get_elem_offset_impl(pos_idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
-    vec_t<float, vec_size> q_vec;
-    if constexpr (interleave) {
-      q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
-    } else {
-      q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+  vec_t<float, vec_size> cos, sin;
+
+  if (bx * bdy + ty < nnz) {
+    const uint32_t idx = bx * bdy + ty;
+    const IdType pos = pos_ids[idx];
+
+    if (tx * vec_size < rotary_dim) {
+#pragma unroll
+      for (uint32_t i = 0; i < vec_size; ++i) {
+        float embed = float(pos) * freq[i];
+        __sincosf(embed, &sin[i], &cos[i]);
+      }
     }
-    q_vec.cast_store(q_rope_ptr + tx * vec_size);
-  } else {
-    const uint32_t kv_head_idx = head_idx - num_qo_heads;
-    DType* k_ptr = k + get_elem_offset_impl(pos_idx, kv_head_idx, 0, k_stride_n, k_stride_h);
-    DType* k_rope_ptr =
-        k_rope + get_elem_offset_impl(pos_idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
-    vec_t<float, vec_size> k_vec;
-    if constexpr (interleave) {
-      k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
-    } else {
-      k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+
+#pragma unroll 1
+    for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
+      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
+      DType* q_rope_ptr =
+          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
+      vec_t<float, vec_size> q_vec;
+      if constexpr (interleave) {
+        q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+      } else {
+        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
+      }
+      q_vec.cast_store(q_rope_ptr + tx * vec_size);
+    }
+
+#pragma unroll 1
+    for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
+      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
+      DType* k_rope_ptr =
+          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
+      vec_t<float, vec_size> k_vec;
+      if constexpr (interleave) {
+        k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+      } else {
+        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
+      }
+      k_vec.cast_store(k_rope_ptr + tx * vec_size);
     }
-    k_vec.cast_store(k_rope_ptr + tx * vec_size);
   }
 }
 
```
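The rewritten kernels hoist the position-independent work: `freq[i]` depends only on the element index, so it is computed once up front, and only the `pos`-dependent `__sincosf` stays inside the `nnz`-guarded block; the old flattened (token, head) block indexing is replaced by an explicit bounds check plus per-head loops. For reference, here is a scalar restatement of the per-element frequency, read straight off the non-interleaved branch above (with `smooth_a = smooth_b = 0` the blend reduces to plain scaled RoPE); it is illustrative only, not library code.

```cpp
// Scalar form of the freq[i] computation in the kernels above; illustrative only.
#include <algorithm>
#include <cmath>
#include <cstdint>

float rope_freq(uint32_t elem_idx, uint32_t rotary_dim, float rope_rcp_scale,
                float rope_rcp_theta, float smooth_a, float smooth_b) {
  // theta^(-2 * (i mod d/2) / d): element pairs (i, i + d/2) share one frequency.
  float freq =
      std::pow(rope_rcp_theta, 2.0f * float(elem_idx % (rotary_dim / 2)) / float(rotary_dim));
  // Llama-3.1-style smoothing: clamp the blend factor to [0, 1], then mix the
  // position-interpolated (freq * rope_rcp_scale) and original frequencies.
  float smooth = std::clamp(freq * smooth_a + smooth_b, 0.0f, 1.0f);
  return (1.0f - smooth) * (freq * rope_rcp_scale) + smooth * freq;
}
```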
```diff
@@ -383,16 +508,18 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCache(
     uint32_t rotary_dim, uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n,
     size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n,
     size_t k_rope_stride_h, bool interleave, cudaStream_t stream = nullptr) {
+  int dev_id = 0;
+  int num_sms = 0;
+  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
+
   DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
     DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
       constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
       constexpr uint32_t bdx = HEAD_DIM / vec_size;
       uint32_t num_threads = std::max(128U, bdx);
       uint32_t bdy = num_threads / bdx;
-      dim3 nblks((nnz + bdy - 1) / bdy);
-      dim3 nthrs(bdx, bdy);
-      auto kernel = BatchQKApplyRotaryPosIdsCosSinCacheKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx,
-                                                              DType, IdType>;
+      uint32_t nblks_x = (nnz + bdy - 1) / bdy;
       void* args[] = {(void*)&q,
                       (void*)&k,
                       (void*)&q_rope,
```
```diff
@@ -412,7 +539,26 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCache(
                       (void*)&q_rope_stride_h,
                       (void*)&k_rope_stride_n,
                       (void*)&k_rope_stride_h};
-      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+      auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx,
+                                                                DType, IdType>;
+
+      int num_blocks_per_sm_0 = 0;
+      FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+          &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0));
+      uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms;
+
+      if ((nnz + bdy - 1) / bdy >= num_ctas_0) {
+        dim3 nblks(nblks_x);
+        dim3 nthrs(bdx, bdy);
+        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
+      } else {
+        dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
+        dim3 nthrs(bdx, bdy);
+        auto kernel_1 =
+            BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel<INTERLEAVE, HEAD_DIM, vec_size,
+                                                                     bdx, DType, IdType>;
+        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
+      }
     });
   });
 
```
```diff
@@ -430,17 +576,19 @@ cudaError_t BatchQKApplyRotaryPosIds(
   float rope_rcp_theta = 1.0f / rope_theta;
   float smooth_a = 0.f;
   float smooth_b = 0.f;
+  int dev_id = 0;
+  int num_sms = 0;
+  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
 
   DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
     DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
       constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
       constexpr uint32_t bdx = HEAD_DIM / vec_size;
       uint32_t num_threads = std::max(128U, bdx);
       uint32_t bdy = num_threads / bdx;
-      dim3 nblks((nnz + bdy - 1) / bdy);
-      dim3 nthrs(bdx, bdy);
-      auto kernel =
-          BatchQKApplyRotaryPosIdsKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
+      uint32_t nblks_x = (nnz + bdy - 1) / bdy;
+
       void* args[] = {(void*)&q,
                       (void*)&k,
                       (void*)&q_rope,
```
```diff
@@ -462,7 +610,26 @@ cudaError_t BatchQKApplyRotaryPosIds(
                       (void*)&smooth_b,
                       (void*)&rope_rcp_scale,
                       (void*)&rope_rcp_theta};
-      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+      auto kernel_0 =
+          BatchQKApplyRotaryPosIdsKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
+
+      int num_blocks_per_sm_0 = 0;
+      FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+          &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0));
+      uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms;
+      if (nblks_x >= num_ctas_0) {
+        dim3 nblks(nblks_x);
+        dim3 nthrs(bdx, bdy);
+
+        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
+      } else {
+        dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
+        dim3 nthrs(bdx, bdy);
+        auto kernel_1 = BatchQKApplyRotaryPosIdsHeadParallelismKernel<INTERLEAVE, HEAD_DIM,
+                                                                      vec_size, bdx, DType, IdType>;
+
+        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
+      }
     });
   });
 
```
```diff
@@ -606,7 +773,7 @@ cudaError_t BatchQKApplyLlama31RotaryPosIds(
       constexpr uint32_t bdx = HEAD_DIM / vec_size;
       uint32_t num_threads = std::max(128U, bdx);
       uint32_t bdy = num_threads / bdx;
-      dim3 nblks((nnz + bdy - 1) / bdy * (num_qo_heads + num_kv_heads));
+      dim3 nblks((nnz + bdy - 1) / bdy);
       dim3 nthrs(bdx, bdy);
       auto kernel =
           BatchQKApplyRotaryPosIdsKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
```

(The Llama 3.1 path launches `BatchQKApplyRotaryPosIdsKernel`, which after this PR iterates over all heads inside each block, so the `(num_qo_heads + num_kv_heads)` factor is dropped from the grid size.)
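As a worked example of the dispatch threshold, consider the benchmarked shapes under assumed numbers: fp16 inputs with `head_dim = 128` give `vec_size = 8`, `bdx = 16`, and 128 threads, hence `bdy = 8`; the wave size `num_ctas_0 = 512` (e.g. 128 SMs times 4 resident CTAs) is hypothetical and device-dependent.

```cpp
// Which launch branch do the benchmarked shapes take? Constants assumed as above.
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t num_ctas_0 = 512;  // assumed occupancy wave size, varies by GPU
  const uint32_t bdy = 8;           // assumed for fp16, head_dim 128
  const uint32_t shapes[][2] = {
      {1, 1}, {19, 128}, {19, 1024}, {128, 1024}};  // {batch_size, append_len}
  for (const auto& s : shapes) {
    const uint32_t nnz = s[0] * s[1];
    const uint32_t nblks_x = (nnz + bdy - 1) / bdy;
    printf("nnz=%7u nblks_x=%6u -> %s\n", nnz, nblks_x,
           nblks_x >= num_ctas_0 ? "kernel_0 (re-use sin/cos)" : "kernel_1 (head-parallel)");
  }
  return 0;
}
```

With these assumptions, the small shapes (single tokens, and batch 19 with append_len 128) take the head-parallel branch, consistent with the ~6-7us and 26us rows in the commit message, while batch 19 with append_len 1024 and larger fall back to `kernel_0`, consistent with the unchanged 148us and 924us rows.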
