File tree 1 file changed +6
-1
lines changed
include/flashinfer/attention
1 file changed +6
-1
lines changed Original file line number Diff line number Diff line change 33
33
34
34
namespace flashinfer {
35
35
36
+ DEFINE_HAS_MEMBER (maybe_q_rope_offset)
37
+
36
38
namespace cg = cooperative_groups;
37
39
using cp_async::PrefetchMode;
38
40
using cp_async::SharedMemFillMode;
@@ -438,7 +440,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ Params
438
440
const uint32_t q_stride_n = params.q_stride_n ;
439
441
const uint32_t q_stride_h = params.q_stride_h ;
440
442
if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama ) {
441
- const IdType* q_rope_offset = nullptr ; // params.q_rope_offset;
443
+ const IdType* q_rope_offset = nullptr ;
444
+ if constexpr (has_maybe_q_rope_offset_v<Params>) {
445
+ q_rope_offset = params.maybe_q_rope_offset ;
446
+ }
442
447
int32_t q_rope_offset_val = q_rope_offset == nullptr ? (kv_len - 1 ) : q_rope_offset[batch_idx];
443
448
const float rope_rcp_scale = params.rope_rcp_scale ;
444
449
const float rope_rcp_theta = params.rope_rcp_theta ;
You can’t perform that action at this time.
0 commit comments