Commit b217a6f

feat:support any num_heads for get_alibi_slope (#200)
While using flashinfer, I found that some models have a number of heads that is not a power of 2. Following **flashinfer/python/tests/alibi_reference.py**, this commit modifies the corresponding part of the C++ code.
1 parent a22aeb6 commit b217a6f

1 file changed

+3
-2
lines changed

Diff for: include/flashinfer/pos_enc.cuh

@@ -56,8 +56,9 @@ inline std::string PosEncodingModeToString(const PosEncodingMode& pos_encoding_m
 }
 
 __device__ __forceinline__ float get_alibi_slope(uint32_t head_idx, uint32_t num_heads) {
-  // NOTE(Zihao): here we assume that num_heads is a power of 2
-  return math::ptx_exp2(-8. * float(head_idx + 1) / float(num_heads));
+  int n = math::ptx_exp2((int)math::ptx_log2(num_heads));
+  return head_idx < n ? math::ptx_exp2(-8. * float(head_idx + 1) / float(n))
+                      : math::ptx_exp2(-4. * float((head_idx + 1 - n) * 2 - 1) / float(n));
 }
 
 /*!
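
The new slope rule in the diff above can be checked on the host with a small sketch. This is only an illustration under stated assumptions: `alibi_slope_reference` is a hypothetical helper name, not part of flashinfer, and it uses `std::exp2` plus an integer loop in place of the device-side `math::ptx_exp2` and `math::ptx_log2`. As in the diff, `n` is the largest power of two not exceeding `num_heads`; heads below `n` use the original power-of-two formula, and the remaining heads take the odd-indexed exponents of a `2n`-head schedule.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Host-side sketch of the slope rule from the diff (hypothetical helper,
// not part of flashinfer). Returns the ALiBi slope for one head.
inline float alibi_slope_reference(uint32_t head_idx, uint32_t num_heads) {
  assert(num_heads > 0 && head_idx < num_heads);

  // n = largest power of two <= num_heads, found with integer arithmetic
  // (assumption: this stands in for exp2((int)log2(num_heads)) on device).
  uint32_t n = 1;
  while (n <= num_heads / 2) {
    n *= 2;
  }

  if (head_idx < n) {
    // First n heads: slope = 2^(-8 * (head_idx + 1) / n).
    return std::exp2(-8.0f * float(head_idx + 1) / float(n));
  }
  // Remaining heads: slope = 2^(-4 * (2 * (head_idx - n) + 1) / n),
  // i.e. every other slope of a 2n-head schedule.
  return std::exp2(-4.0f * float((head_idx + 1 - n) * 2 - 1) / float(n));
}
```

For example, with `num_heads = 12` the sketch picks `n = 8`: heads 0 to 7 get slopes 2^-1 through 2^-8, and head 8 gets 2^-0.5. The integer loop is a design choice of this sketch so that `n` does not depend on floating-point rounding; the committed device code computes `n` with the PTX math wrappers instead.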
