Skip to content

Commit ea1d0cb

Browse files
authored
Unique the symbol of maybe_q_rope_offset_v. (#855)
Having building error with multiple definitions of `has_maybe_q_rope_offset_v` with building command: ``` TORCH_CUDA_ARCH_LIST="9.0a" FLASHINFER_ENABLE_AOT=1 pip install -e . -v ``` The reason is that we both defined same template struct in `prefill.cuh` and `decode.cuh`: https://github.com/flashinfer-ai/flashinfer/blob/7bee1cfb6d0322c58ee864f0b592d1952e8c758c/include/flashinfer/attention/prefill.cuh#L39 https://github.com/flashinfer-ai/flashinfer/blob/7bee1cfb6d0322c58ee864f0b592d1952e8c758c/include/flashinfer/attention/decode.cuh#L36 So if the struct is just for condition check, add a prefix of `decode` could avoid this, if no other uses of this struct.
1 parent 7bee1cf commit ea1d0cb

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

include/flashinfer/attention/decode.cuh

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
namespace flashinfer {
3535

36-
DEFINE_HAS_MEMBER(maybe_q_rope_offset)
36+
DEFINE_HAS_MEMBER(decode_maybe_q_rope_offset)
3737

3838
namespace cg = cooperative_groups;
3939
using cp_async::PrefetchMode;
@@ -441,7 +441,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ Params
441441
const uint32_t q_stride_h = params.q_stride_h;
442442
if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) {
443443
const IdType* q_rope_offset = nullptr;
444-
if constexpr (has_maybe_q_rope_offset_v<Params>) {
444+
if constexpr (has_decode_maybe_q_rope_offset_v<Params>) {
445445
q_rope_offset = params.maybe_q_rope_offset;
446446
}
447447
int32_t q_rope_offset_val = q_rope_offset == nullptr ? (kv_len - 1) : q_rope_offset[batch_idx];

0 commit comments

Comments
 (0)