Skip to content

Commit 1161b12

Browse files
authored
bugfix: Fix inline RoPE in decode kernels (#847)
This PR fixes a bug in the decode kernel, which ignored the `q_rope_offset` provided in Params (it is now read from `params.maybe_q_rope_offset` when that member exists). The variable naming follows the existing example in the prefill kernel: https://github.com/flashinfer-ai/flashinfer/blob/ab6484effafd5133514890999ff23129ce3b8a9b/include/flashinfer/attention/prefill.cuh#L39
1 parent ab6484e commit 1161b12

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

include/flashinfer/attention/decode.cuh

+6-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333

3434
namespace flashinfer {
3535

36+
DEFINE_HAS_MEMBER(maybe_q_rope_offset)
37+
3638
namespace cg = cooperative_groups;
3739
using cp_async::PrefetchMode;
3840
using cp_async::SharedMemFillMode;
@@ -438,7 +440,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ Params
438440
const uint32_t q_stride_n = params.q_stride_n;
439441
const uint32_t q_stride_h = params.q_stride_h;
440442
if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) {
441-
const IdType* q_rope_offset = nullptr; // params.q_rope_offset;
443+
const IdType* q_rope_offset = nullptr;
444+
if constexpr (has_maybe_q_rope_offset_v<Params>) {
445+
q_rope_offset = params.maybe_q_rope_offset;
446+
}
442447
int32_t q_rope_offset_val = q_rope_offset == nullptr ? (kv_len - 1) : q_rope_offset[batch_idx];
443448
const float rope_rcp_scale = params.rope_rcp_scale;
444449
const float rope_rcp_theta = params.rope_rcp_theta;

0 commit comments

Comments
 (0)