Commit 061db55

bugfix: fix potential issues of FA3 template loading nans for PageAttention (#945)
Ref #941
1 parent 39dc66d · commit 061db55
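
Why this fixes NaN loads: the old bound was page-granular. `kv_bound = kv_indptr + ceil_div(kv_len, block_size)` counts page-table slots, so the guard `q < kv_bound` passes for every token slot of the last page, including the tail slots past `kv_len` whenever `kv_len` is not a multiple of `block_size`. Those tail slots hold whatever the page pool last contained, which can be NaN, and a loaded NaN can propagate through the FA3 pipeline since 0 · NaN is still NaN. The new bound is token-granular. A worked example with assumed values, and assuming (as the new bound's formula suggests) that the packed index space for a request starts at `kv_indptr * block_size`:

```cuda
// Illustrative numbers only: block_size = 64, the request's pages start at
// kv_indptr = 10 in the page table, and kv_len = 100 tokens are valid.
// Token slot 100 (the first *invalid* slot, 0-indexed) has packed index:
//   packed_block_iter = kv_indptr * block_size + 100 = 740
//   q = 740 / 64 = 11 (page-table slot), r = 740 % 64 = 36 (offset in page)
//
// Old predicate (page granularity):
//   kv_bound = kv_indptr + ceil_div(kv_len, block_size) = 10 + 2 = 12
//   q < kv_bound  ->  11 < 12  ->  true: the stale slot is loaded (possibly NaN)
//
// New predicate (token granularity):
//   packed_kv_bound = kv_indptr * block_size + kv_len = 740
//   packed_block_iter < packed_kv_bound  ->  740 < 740  ->  false: zero-filled
```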

1 file changed (+17 −15 lines)

include/flashinfer/attention/mla_hopper.cuh (+17 −15)
```diff
@@ -176,8 +176,8 @@ __device__ __forceinline__ void load_kv(
     typename KTraits::SharedStorage* smem_storage, typename KTraits::DTypeKV* ckv,
     typename KTraits::DTypeKV* kpe, typename KTraits::IdType* indices, const uint32_t ckv_stride_n,
     const uint32_t ckv_stride_page, const uint32_t kpe_stride_n, const uint32_t kpe_stride_page,
-    const uint32_t kv_bound, const uint32_t packed_block_iter_base, const uint_fastdiv& block_size,
-    const uint32_t stage_idx) {
+    const uint32_t packed_kv_bound, const uint32_t packed_block_iter_base,
+    const uint_fastdiv& block_size, const uint32_t stage_idx) {
   using DTypeKV = typename KTraits::DTypeKV;
   constexpr uint32_t UPCAST_STRIDE_CKV = KTraits::UPCAST_STRIDE_CKV;
   constexpr uint32_t UPCAST_STRIDE_KPE = KTraits::UPCAST_STRIDE_KPE;
@@ -195,13 +195,15 @@ __device__ __forceinline__ void load_kv(
 #pragma unroll
     for (uint32_t j = 0; j < 2; ++j) {
       uint32_t q, r;
+      uint32_t packed_block_iter =
+          packed_block_iter_base + lane_idx / 8 + (j + mma_kv * 2) * 16 + warp_idx_in_wg * 4;
+      block_size.divmod(packed_block_iter, q, r);
 
-      block_size.divmod(
-          packed_block_iter_base + lane_idx / 8 + (j + mma_kv * 2) * 16 + warp_idx_in_wg * 4, q, r);
-
-      DTypeKV* ckv_ptr = ckv + (q < kv_bound ? indices[q] : 0) * ckv_stride_page +
+      DTypeKV* ckv_ptr = ckv +
+                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
                          r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
-      DTypeKV* kpe_ptr = kpe + (q < kv_bound ? indices[q] : 0) * kpe_stride_page +
+      DTypeKV* kpe_ptr = kpe +
+                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
                          r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
       uint32_t ckv_smem_offset_w = get_swizzle_offset<KTraits::SWIZZLE_MODE_CKV, UPCAST_STRIDE_CKV>(
           32 * mma_kv + j * 16 + warp_idx_in_wg * 4 + lane_idx / 8, 8 * 0 + lane_idx % 8);
```
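
In the rewritten loop body the per-lane packed index is hoisted into `packed_block_iter`, so the same value feeds both `uint_fastdiv::divmod` and the new bound check. A minimal sketch of the roles of `q` and `r` (an illustrative helper, not library code; plain `/` and `%` stand in for the fast-division member, which uses a precomputed reciprocal):

```cuda
#include <cstdint>

// Sketch only: split a packed KV index into page-table slot q and in-page
// offset r, and compute the exact per-token bound test.
__host__ __device__ inline bool split_and_test(uint32_t packed_block_iter,
                                               uint32_t block_size,
                                               uint32_t packed_kv_bound,
                                               uint32_t& q, uint32_t& r) {
  q = packed_block_iter / block_size;  // which indices[] entry (page slot)
  r = packed_block_iter % block_size;  // which token inside that page
  // Every token of the last page shares the same q, so a per-q test cannot
  // exclude the page's tail; only the undivided packed index can.
  return packed_block_iter < packed_kv_bound;
}
```

The guarded index `(packed_block_iter < packed_kv_bound ? indices[q] : 0)` keeps the address computation in-bounds for masked lanes; the load itself is suppressed separately by the predicate on `load_128b_async`.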
```diff
@@ -211,8 +213,8 @@ __device__ __forceinline__ void load_kv(
 #pragma unroll
       for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 4; ++mma_d) {
         if constexpr (predicate) {
-          ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(ckv_smem_offset_w, ckv_ptr,
-                                                                 q < kv_bound);
+          ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(
+              ckv_smem_offset_w, ckv_ptr, packed_block_iter < packed_kv_bound);
         } else {
           ckv_smem.load_128b_async(ckv_smem_offset_w, ckv_ptr);
         }
@@ -223,8 +225,8 @@ __device__ __forceinline__ void load_kv(
 #pragma unroll
       for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_KPE / 4; ++mma_d) {
         if constexpr (predicate) {
-          kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(kpe_smem_offset_w, kpe_ptr,
-                                                                 q < kv_bound);
+          kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(
+              kpe_smem_offset_w, kpe_ptr, packed_block_iter < packed_kv_bound);
         } else {
           kpe_smem.load_128b_async(kpe_smem_offset_w, kpe_ptr);
         }
```
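
With `SharedMemFillMode::kFillZero`, a false predicate turns the 16-byte async copy into a zero fill of shared memory rather than a global read. A hedged sketch of that mechanism (hypothetical helper name, not the library's code; the real implementation lives in FlashInfer's cp_async utilities), relying on the documented `cp.async` semantics that bytes beyond `src-size` are zero-filled:

```cuda
#include <cstdint>

// Sketch, assuming sm_80+: cp.async copies `cp-size` (16) bytes into shared
// memory but reads only `src-size` bytes from global memory, zero-filling the
// rest. Passing src-size = 0 for out-of-bounds lanes is what turns a read of
// stale (possibly NaN) page memory into zeros in shared memory.
__device__ __forceinline__ void cp_async_128b_zfill(void* smem_dst, const void* gmem_src,
                                                    bool in_bounds) {
  uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
  const int src_size = in_bounds ? 16 : 0;  // bytes actually read from global
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n" ::"r"(smem_addr),
               "l"(gmem_src), "r"(src_size));
}
```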
```diff
@@ -673,7 +675,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
     const uint32_t qo_upperbound =
         min(q_len, ceil_div(qo_packed_idx_base + KTraits::CTA_TILE_Q, num_heads));
 
-    uint32_t kv_bound = kv_indptr + (kv_len + block_size - 1) / block_size;
+    uint32_t packed_kv_bound = kv_indptr * block_size + kv_len;
     int kv_tile_idx =
         ceil_div(
             (CAUSAL ? min(kv_end, kv_len - q_len + (packed_qo_start + cluster_tile_q) / num_heads)
```
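
The bound computation itself also simplifies: a single multiply-add replaces the round-up division. More importantly, `packed_kv_bound` now lives in the same packed index space as `packed_block_iter` (this assumes `block_iter_base` at the call sites below equals `kv_indptr * block_size`, which is what makes the per-token comparison exact), whereas the old `kv_bound` was exact only when `kv_len` was a multiple of `block_size`.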
```diff
@@ -687,7 +689,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
 
     PROFILER_EVENT_START(variant, ProfileEventType::kIssueLoadKV);
     load_kv<true, KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                           kpe_stride_n, kpe_stride_page, kv_bound,
+                           kpe_stride_n, kpe_stride_page, packed_kv_bound,
                            block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
                            smem_pipe_write_kv.index());
 
@@ -715,7 +717,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
       pipeline_kv.producer_acquire(smem_pipe_write_kv);
       PROFILER_EVENT_START(variant, ProfileEventType::kIssueLoadKV);
       load_kv<false, KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                              kpe_stride_n, kpe_stride_page, kv_bound,
+                              kpe_stride_n, kpe_stride_page, packed_kv_bound,
                               block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
                               smem_pipe_write_kv.index());
       PROFILER_EVENT_END(variant, ProfileEventType::kIssueLoadKV);
@@ -734,7 +736,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPageAttentionHop
       pipeline_kv.producer_acquire(smem_pipe_write_kv);
       PROFILER_EVENT_START(variant, ProfileEventType::kIssueLoadKV);
       load_kv<false, KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
-                              kpe_stride_n, kpe_stride_page, kv_bound,
+                              kpe_stride_n, kpe_stride_page, packed_kv_bound,
                               block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
                               smem_pipe_write_kv.index());
       PROFILER_EVENT_END(variant, ProfileEventType::kIssueLoadKV);
```
