Commit 41a4f56

# perf: dynamic split-k for MLA (#863)
PR #804 didn't implement split-k, which can degrade performance when concurrency is not large enough. This PR fixes that issue. We implemented the v2 scheduler and the write-through optimization described in [our paper](https://arxiv.org/pdf/2501.01005) (Section 3.3 and Appendix D.2) for load balancing. In an early PR (#72), we turned off `cudaLaunchCooperativeKernel` and `grid.sync()` because we were not sure whether they were compatible with CUDAGraph. This PR adds them back for grid synchronization, saving some kernel launch overhead.

## Benchmark

On H100 SXM5 80GB (3352 GB/s), with this PR:

```
Config: batch_size=1, seq_len=1024, num_heads=16
Memory bandwidth: 22.33 GB/s
Config: batch_size=16, seq_len=1024, num_heads=16
Memory bandwidth: 330.72 GB/s
Config: batch_size=32, seq_len=1024, num_heads=16
Memory bandwidth: 638.73 GB/s
Config: batch_size=64, seq_len=1024, num_heads=16
Memory bandwidth: 1188.90 GB/s
Config: batch_size=1, seq_len=2048, num_heads=16
Memory bandwidth: 40.74 GB/s
Config: batch_size=16, seq_len=2048, num_heads=16
Memory bandwidth: 592.77 GB/s
Config: batch_size=32, seq_len=2048, num_heads=16
Memory bandwidth: 1112.83 GB/s
Config: batch_size=64, seq_len=2048, num_heads=16
Memory bandwidth: 1506.01 GB/s
Config: batch_size=1, seq_len=4096, num_heads=16
Memory bandwidth: 72.53 GB/s
Config: batch_size=16, seq_len=4096, num_heads=16
Memory bandwidth: 1007.80 GB/s
Config: batch_size=32, seq_len=4096, num_heads=16
Memory bandwidth: 1438.99 GB/s
Config: batch_size=64, seq_len=4096, num_heads=16
Memory bandwidth: 1730.62 GB/s
Config: batch_size=1, seq_len=8192, num_heads=16
Memory bandwidth: 120.74 GB/s
Config: batch_size=16, seq_len=8192, num_heads=16
Memory bandwidth: 1340.86 GB/s
Config: batch_size=32, seq_len=8192, num_heads=16
Memory bandwidth: 1689.36 GB/s
Config: batch_size=64, seq_len=8192, num_heads=16
Memory bandwidth: 1901.26 GB/s
Config: batch_size=1, seq_len=16384, num_heads=16
Memory bandwidth: 177.94 GB/s
Config: batch_size=16, seq_len=16384, num_heads=16
Memory bandwidth: 1619.51 GB/s
Config: batch_size=32, seq_len=16384, num_heads=16
Memory bandwidth: 1876.50 GB/s
Config: batch_size=64, seq_len=16384, num_heads=16
Memory bandwidth: 2010.58 GB/s
Config: batch_size=1, seq_len=32768, num_heads=16
Memory bandwidth: 231.70 GB/s
Config: batch_size=16, seq_len=32768, num_heads=16
Memory bandwidth: 1835.16 GB/s
Config: batch_size=32, seq_len=32768, num_heads=16
Memory bandwidth: 1997.24 GB/s
Config: batch_size=64, seq_len=32768, num_heads=16
Memory bandwidth: 2067.99 GB/s
```

Before this PR:

```
Config: batch_size=1, seq_len=1024, num_heads=16
Memory bandwidth: 15.46 GB/s
Config: batch_size=16, seq_len=1024, num_heads=16
Memory bandwidth: 238.49 GB/s
Config: batch_size=32, seq_len=1024, num_heads=16
Memory bandwidth: 472.44 GB/s
Config: batch_size=64, seq_len=1024, num_heads=16
Memory bandwidth: 929.12 GB/s
Config: batch_size=1, seq_len=2048, num_heads=16
Memory bandwidth: 15.47 GB/s
Config: batch_size=16, seq_len=2048, num_heads=16
Memory bandwidth: 250.71 GB/s
Config: batch_size=32, seq_len=2048, num_heads=16
Memory bandwidth: 500.21 GB/s
Config: batch_size=64, seq_len=2048, num_heads=16
Memory bandwidth: 996.37 GB/s
Config: batch_size=1, seq_len=4096, num_heads=16
Memory bandwidth: 16.36 GB/s
Config: batch_size=16, seq_len=4096, num_heads=16
Memory bandwidth: 257.59 GB/s
Config: batch_size=32, seq_len=4096, num_heads=16
Memory bandwidth: 515.88 GB/s
Config: batch_size=64, seq_len=4096, num_heads=16
Memory bandwidth: 1035.55 GB/s
Config: batch_size=1, seq_len=8192, num_heads=16
Memory bandwidth: 16.37 GB/s
Config: batch_size=16, seq_len=8192, num_heads=16
Memory bandwidth: 261.47 GB/s
Config: batch_size=32, seq_len=8192, num_heads=16
Memory bandwidth: 524.76 GB/s
Config: batch_size=64, seq_len=8192, num_heads=16
Memory bandwidth: 1054.54 GB/s
Config: batch_size=1, seq_len=16384, num_heads=16
Memory bandwidth: 16.50 GB/s
Config: batch_size=16, seq_len=16384, num_heads=16
Memory bandwidth: 263.69 GB/s
Config: batch_size=32, seq_len=16384, num_heads=16
Memory bandwidth: 528.89 GB/s
Config: batch_size=64, seq_len=16384, num_heads=16
Memory bandwidth: 1064.87 GB/s
Config: batch_size=1, seq_len=32768, num_heads=16
Memory bandwidth: 16.45 GB/s
Config: batch_size=16, seq_len=32768, num_heads=16
Memory bandwidth: 264.66 GB/s
Config: batch_size=32, seq_len=32768, num_heads=16
Memory bandwidth: 530.87 GB/s
Config: batch_size=64, seq_len=32768, num_heads=16
Memory bandwidth: 1070.93 GB/s
```
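As background for the `grid.sync()` change above: a cooperative launch guarantees that all blocks of the grid are resident on the device at once, which is what makes a grid-wide barrier between the attention stage and the merge stage legal. Below is a minimal sketch of that launch-side pattern using the standard CUDA runtime API; the kernel, block size, and `launch` wrapper are illustrative placeholders, not FlashInfer's actual code.

```cpp
#include <cooperative_groups.h>
#include <cuda_runtime.h>

namespace cg = cooperative_groups;

// Two-stage kernel: every block must finish stage 1 (attention over its KV
// chunks) before any block starts stage 2 (merging partial outputs).
__global__ void two_stage_kernel(float* partial, float* out) {
  // ... stage 1: compute this block's split-k partial results ...
  cg::this_grid().sync();  // grid-wide barrier, valid only for cooperative launches
  // ... stage 2: merge partial results into `out` ...
}

cudaError_t launch(float* partial, float* out, cudaStream_t stream) {
  int supported = 0;
  cudaDeviceGetAttribute(&supported, cudaDevAttrCooperativeLaunch, /*device=*/0);
  if (!supported) return cudaErrorNotSupported;

  // The whole grid must be co-resident: cap it at blocks-per-SM times SM count.
  int num_sms = 0, blocks_per_sm = 0;
  cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, /*device=*/0);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, two_stage_kernel,
                                                /*blockSize=*/128, /*dynamicSMemSize=*/0);
  dim3 nblks(num_sms * blocks_per_sm), nthrs(128);
  void* args[] = {&partial, &out};
  return cudaLaunchCooperativeKernel((void*)two_stage_kernel, nblks, nthrs, args,
                                     /*sharedMem=*/0, stream);
}
```

Fusing both stages into one cooperative launch is what removes the second kernel launch that a separate merge kernel would otherwise cost.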
1 parent 127ff22 commit 41a4f56

File tree: 6 files changed, +266 −76 lines


benchmarks/bench_deepseek_mla.py (+3 −6)

```diff
@@ -74,9 +74,6 @@ def bench_deepseek_mla_decode(batch_size, seq_len, num_heads):
 
 
 if __name__ == "__main__":
-    bench_deepseek_mla_decode(768, 1024, 16)
-    bench_deepseek_mla_decode(768, 1024, 32)
-    bench_deepseek_mla_decode(768, 1024, 64)
-    bench_deepseek_mla_decode(768, 2048, 16)
-    bench_deepseek_mla_decode(768, 2048, 32)
-    bench_deepseek_mla_decode(768, 2048, 64)
+    for seq_len in [1024, 2048, 4096, 8192, 16384, 32768]:
+        for batch_size in [1, 16, 32, 64]:
+            bench_deepseek_mla_decode(batch_size, seq_len, 16)
```
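For context on the bandwidth figures above: each decode step reads the compressed KV cache roughly once, so achieved bandwidth is approximately bytes of KV traversed divided by step latency. A back-of-the-envelope sketch, assuming DeepSeek MLA's 512-dim CKV plus 64-dim KPE per token stored in fp16; this is not the benchmark's actual measurement code, and `latency_ms` stands in for a measured latency.

```cpp
#include <cstdio>

// Approximate achieved bandwidth for one MLA decode step, in GB/s.
double mla_decode_bandwidth_gb_s(long batch_size, long seq_len, double latency_ms) {
  const long bytes_per_token = (512 + 64) * 2;  // CKV (512) + KPE (64) dims, fp16
  const double total_bytes = static_cast<double>(batch_size) * seq_len * bytes_per_token;
  return total_bytes / (latency_ms * 1e-3) / 1e9;
}

int main() {
  // batch_size=64, seq_len=32768 at ~1.17 ms gives roughly the 2068 GB/s reported
  // above, i.e. about 62% of the H100's 3352 GB/s peak.
  std::printf("%.1f GB/s\n", mla_decode_bandwidth_gb_s(64, 32768, 1.17));
  return 0;
}
```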

csrc/batch_mla_run.cu (+8)

```diff
@@ -71,6 +71,8 @@ void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int
 
   params.q_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_indptr_offset);
   params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset);
+  params.partial_indptr =
+      GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.partial_indptr_offset);
   params.kv_indices = static_cast<IdType*>(kv_indices.data_ptr());
   params.q_len = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.q_len_offset);
   params.kv_len = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset);
@@ -79,6 +81,12 @@ void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int
   params.kv_end = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_end_offset);
   params.work_indptr =
       GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
+  params.merge_packed_offset_start = GetPtrFromBaseOffset<IdType>(
+      int_buffer_ptr, plan_info.merge_packed_offset_start_offset);
+  params.merge_packed_offset_end =
+      GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_packed_offset_end_offset);
+  params.merge_indptr =
+      GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_indptr_offset);
   params.final_o = static_cast<DTypeO*>(o.data_ptr());
   params.final_lse =
       maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr;
```
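`GetPtrFromBaseOffset` resolves each scheduler array out of a single preallocated workspace buffer using byte offsets computed at plan time. A plausible sketch of such a helper follows; this is an assumption about its shape based on the call sites above, and the actual definition lives elsewhere in the FlashInfer tree.

```cpp
#include <cstdint>

// Sketch: address a typed array inside one workspace allocation by the byte
// offset that the plan phase recorded for it.
template <typename T>
T* GetPtrFromBaseOffset(void* base_ptr, int64_t byte_offset) {
  return reinterpret_cast<T*>(static_cast<char*>(base_ptr) + byte_offset);
}
```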

include/flashinfer/attention/cascade.cuh (+2 −5)

```diff
@@ -16,16 +16,13 @@
 #ifndef FLASHINFER_CASCADE_CUH_
 #define FLASHINFER_CASCADE_CUH_
 
-#include <cooperative_groups.h>
-
 #include "../cp_async.cuh"
 #include "../math.cuh"
 #include "../utils.cuh"
 #include "state.cuh"
 
 namespace flashinfer {
 
-namespace cg = cooperative_groups;
 using cp_async::PrefetchMode;
 using cp_async::SharedMemFillMode;
 
@@ -323,8 +320,8 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
 }
 
 /*!
- * \brief The CUDA kernel to merge self-attention states of multiple index sets, the number of index
- * sets at each position might vary.
+ * \brief The CUDA kernel to merge self-attention states of multiple index sets, the number of
+ * index sets at each position might vary.
  *
  * For CUDA graph support, the kernel can be built with a maximum sequence length and executed
  * using a truncated, dynamic sequence length passed through `seq_len_ptr`.
```

include/flashinfer/attention/mla_fa2.cuh (+107 −19)

```diff
@@ -15,6 +15,8 @@
  */
 #ifndef FLASHINFER_MLA_FA2_CUH_
 #define FLASHINFER_MLA_FA2_CUH_
+#include <cooperative_groups.h>
+
 #include <cstdint>
 #include <sstream>
 
@@ -90,7 +92,6 @@ struct KernelTraits {
   static constexpr uint32_t UPCAST_STRIDE_KPE = HEAD_DIM_KPE / upcast_size<DTypeKV_>();
   static constexpr uint32_t UPCAST_STRIDE_FINAL_O = HEAD_DIM_CKV / upcast_size<DTypeO_>();
   static constexpr uint32_t UPCAST_STRIDE_P = CTA_TILE_KV / upcast_size<DTypeKV_>();
-  static constexpr uint32_t UPCAST_STRIDE_PARTIAL_O = HEAD_DIM_CKV / upcast_size<float>();
 
   using DTypeQ = DTypeQ_;
   using DTypeKV = DTypeKV_;
@@ -618,6 +619,52 @@ __device__ __forceinline__ void finalize_m_(typename KTraits::AttentionVariant v
   }
 }
 
+template <typename KTraits>
+__device__ void DevicePersistentMergeStates(typename KTraits::IdType* merge_packed_offset_start,
+                                            typename KTraits::IdType* merge_packed_offset_end,
+                                            typename KTraits::IdType* merge_indptr,
+                                            float* partial_o, float* partial_lse,
+                                            typename KTraits::DTypeO* final_o, float* final_lse,
+                                            const uint32_t o_stride_n, const uint32_t o_stride_h,
+                                            const uint32_t cluster_tile_q,
+                                            const uint_fastdiv& num_heads) {
+  constexpr uint32_t VEC_SIZE = 4;  // partial o has data type float
+  constexpr uint32_t NUM_THRS_PER_ROW = KTraits::HEAD_DIM_CKV / VEC_SIZE;
+  constexpr uint32_t ROWS_PER_ITERATION = (KTraits::NUM_THREADS) / NUM_THRS_PER_ROW;
+  const uint32_t cluster_id = blockIdx.y;
+  const uint32_t thread_id = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
+  const uint32_t offset_start = merge_packed_offset_start[cluster_id];
+  const uint32_t offset_end = merge_packed_offset_end[cluster_id];
+  const uint32_t partial_offset_start = merge_indptr[cluster_id];
+  const uint32_t partial_offset_end = merge_indptr[cluster_id + 1];
+  const uint32_t stride = offset_end - offset_start;
+#pragma unroll 1
+  for (uint32_t local_packed_offset =
+           blockIdx.x * ROWS_PER_ITERATION + thread_id / NUM_THRS_PER_ROW;
+       local_packed_offset < stride; local_packed_offset += gridDim.x * ROWS_PER_ITERATION) {
+    uint32_t final_packed_offset = offset_start + local_packed_offset;
+    uint32_t q, r;
+    num_heads.divmod(final_packed_offset, q, r);
+    state_t<VEC_SIZE> st;
+#pragma unroll 2
+    for (uint32_t partial_packed_offset = partial_offset_start + local_packed_offset;
+         partial_packed_offset < partial_offset_end; partial_packed_offset += stride) {
+      vec_t<float, VEC_SIZE> o_partial;
+      float lse_partial;
+      o_partial.load(partial_o + partial_packed_offset * KTraits::HEAD_DIM_CKV +
+                     (thread_id % NUM_THRS_PER_ROW) * VEC_SIZE);
+      lse_partial = partial_lse[partial_packed_offset];
+      st.merge(o_partial, lse_partial, 1);
+    }
+    st.normalize();
+    st.o.cast_store(final_o + (q * o_stride_n + r * o_stride_h +
+                               (thread_id % NUM_THRS_PER_ROW) * VEC_SIZE));
+    if (final_lse) {
+      final_lse[q * num_heads + r] = st.get_lse();
+    }
+  }
+}
+
 template <typename KTraits>
 __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_storage,
                                         typename KTraits::DTypeO* final_o, float* final_lse,
@@ -631,12 +678,40 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
   constexpr uint32_t HEAD_DIM_CKV = KTraits::HEAD_DIM_CKV;
   constexpr uint32_t UPCAST_STRIDE_FINAL_O = KTraits::UPCAST_STRIDE_FINAL_O;
   const uint32_t lane_idx = threadIdx.x, warpgroup_idx = threadIdx.z, warp_idx_in_wg = threadIdx.y;
+  smem_t<KTraits::SWIZZLE_MODE_O> o_smem(smem_storage->o_smem);
+
+  if (partial_o != nullptr) {
+    // write to partial_o
+#pragma unroll
+    for (uint32_t j = 0; j < 2; ++j) {
+      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
+      if (lane_idx % 4 == 0 && q_idx < q_len) {
+        partial_lse[(blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4] =
+            math::ptx_log2(d[j]) + float(m[j]);
+      }
+    }
 
-  if (false) {
-    // TOOD(Zihao): write to partial
+#pragma unroll
+    for (uint32_t j = 0; j < 2; ++j) {
+      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
+#pragma unroll
+      for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
+        if (q_idx < q_len) {
+          *reinterpret_cast<float2*>(
+              partial_o +
+              ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
+              warpgroup_idx * (HEAD_DIM_CKV / 2) + mma_d * 16 + (lane_idx % 4) * 2) =
+              *reinterpret_cast<float2*>(&o_frag[mma_d][j * 2]);
+          *reinterpret_cast<float2*>(
+              partial_o +
+              ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
+              warpgroup_idx * (HEAD_DIM_CKV / 2) + mma_d * 16 + 8 + (lane_idx % 4) * 2) =
+              *reinterpret_cast<float2*>(&o_frag[mma_d][4 + j * 2]);
+        }
+      }
+    }
   } else {
     // write to final_o
-    smem_t<KTraits::SWIZZLE_MODE_O> o_smem(smem_storage->o_smem);
 
     if (final_lse) {
 #pragma unroll
@@ -748,13 +823,14 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
   const uint32_t kpe_stride_n = params.kpe_stride_n;
   const uint32_t o_stride_n = params.o_stride_n;
   const uint32_t o_stride_h = params.o_stride_h;
-  const uint32_t cluster_tile_q = blockDim.x * KTraits::CTA_TILE_Q;
+  const uint32_t cluster_tile_q = gridDim.x * KTraits::CTA_TILE_Q;
 
 #pragma unroll 1
   for (IdType work_idx = work_indptr[blockIdx.y]; work_idx < work_indptr[blockIdx.y + 1];
        ++work_idx) {
     const uint32_t q_indptr = params.q_indptr[work_idx];
     const uint32_t kv_indptr = params.kv_indptr[work_idx];
+    const int32_t partial_indptr = params.partial_indptr[work_idx];
     const uint32_t q_len = params.q_len[work_idx];
     const uint32_t kv_len = params.kv_len[work_idx];
     const uint32_t packed_qo_start = params.q_start[work_idx];
@@ -778,14 +854,14 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
         (CAUSAL ? min(kv_end, kv_len - q_len + (packed_qo_start + cluster_tile_q) / num_heads)
                 : kv_end),
         CTA_TILE_KV) -
-        1;
+        1 - (kv_start / CTA_TILE_KV);
 
     int mask_tile_idx =
-        (CAUSAL ? min(kv_end, kv_len - q_len + packed_qo_start / num_heads) : kv_end) / CTA_TILE_KV;
+        (CAUSAL ? min(kv_end, kv_len - q_len + packed_qo_start / num_heads) : kv_end) /
+            CTA_TILE_KV -
+        (kv_start / CTA_TILE_KV);
 
-    int start_tile_idx = kv_start / CTA_TILE_KV;  // ceil_div(kv_start, CTA_TILE_KV);
     uint32_t block_iter_base = kv_indptr * block_size + kv_start;
-
     // last kv tile
     __syncthreads();
     uint32_t kv_bound = kv_indptr + (kv_len + block_size - 1) / block_size;  // ceil_div
@@ -796,7 +872,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
     cp_async::commit_group();
 #pragma unroll
     for (int stage_idx = 1; stage_idx < NUM_STAGES; ++stage_idx) {
-      if (kv_tile_idx - stage_idx >= start_tile_idx) {
+      if (kv_tile_idx - stage_idx >= 0) {
         load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
                          kpe_stride_n, kpe_stride_page, kv_bound,
                          block_iter_base + (kv_tile_idx - stage_idx) * CTA_TILE_KV, block_size,
@@ -807,7 +883,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
 
     // loop with mask
 #pragma unroll 1
-    for (; kv_tile_idx >= mask_tile_idx && kv_tile_idx > start_tile_idx; --kv_tile_idx) {
+    for (; kv_tile_idx >= mask_tile_idx && kv_tile_idx > 0; --kv_tile_idx) {
       cp_async::wait_group<NUM_STAGES - 1>();
       __syncthreads();
 
@@ -825,7 +901,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
       // compute sfm * v
      compute_mla_pv<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, s_frag, d, o_frag);
 
-      if (kv_tile_idx - NUM_STAGES >= start_tile_idx) {
+      if (kv_tile_idx - NUM_STAGES >= 0) {
        __syncthreads();
        load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
                         kpe_stride_n, kpe_stride_page, kv_bound,
@@ -837,7 +913,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
 
     // loop without mask
 #pragma unroll 1
-    for (; kv_tile_idx + 1 > start_tile_idx + NUM_STAGES; --kv_tile_idx) {
+    for (; kv_tile_idx + 1 > NUM_STAGES; --kv_tile_idx) {
       cp_async::wait_group<NUM_STAGES - 1>();
       __syncthreads();
 
@@ -862,7 +938,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
 
     // last tiles
 #pragma unroll
-    for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
+    for (; kv_tile_idx >= 0; --kv_tile_idx) {
       // compute mla qk
       compute_mla_qk<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, s_frag);
 
@@ -884,11 +960,22 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
 
     finalize_m_<KTraits>(variant, m);
 
-    write_o<KTraits>(&smem_storage, final_o + q_indptr * o_stride_n,
-                     final_lse ? final_lse + q_indptr * num_heads : nullptr, partial_o, partial_lse,
-                     o_frag, m, d, o_stride_n, o_stride_h, qo_upperbound, qo_packed_idx_base,
-                     num_heads);
+    write_o<KTraits>(
+        &smem_storage, final_o + q_indptr * o_stride_n,
+        final_lse ? final_lse + q_indptr * num_heads : nullptr,
+        (partial_indptr == -1) ? nullptr : partial_o + partial_indptr * KTraits::HEAD_DIM_CKV,
+        (partial_indptr == -1) ? nullptr : partial_lse + partial_indptr, o_frag, m, d, o_stride_n,
+        o_stride_h, qo_upperbound, qo_packed_idx_base, num_heads);
   }
+
+  auto grid = cg::this_grid();
+  grid.sync();
+
+  // the second stage, merge partial outputs
+  DevicePersistentMergeStates<KTraits>(params.merge_packed_offset_start,
+                                       params.merge_packed_offset_end, params.merge_indptr,
+                                       partial_o, partial_lse, final_o, final_lse, o_stride_n,
+                                       o_stride_h, cluster_tile_q, num_heads);
 }
 
 #define DISPATCH_SMEM_CONFIG(smem_limit_per_sm, NUM_STAGES, CTA_TILE_KV, QK_SHARD, ...) \
@@ -948,7 +1035,8 @@ cudaError_t BatchMLAPagedAttention(Params params, uint32_t num_blks_x, uint32_t
 
     FLASHINFER_CUDA_CALL(
         cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+    FLASHINFER_CUDA_CALL(
+        cudaLaunchCooperativeKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });

  return cudaSuccess;
```
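The second stage added above (`DevicePersistentMergeStates`) reduces the split-k partials for each output row by log-sum-exp (LSE) weighted averaging, which is what `st.merge` and `st.normalize` perform through `state_t` (in base 2 via `math::ptx_log2`). Below is a minimal host-side reference of the same merge identity, assuming each partial output row is already normalized over its own KV chunk and using natural log; it is a sketch of the math, not the kernel's code.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// One split-k partial: an output row plus the log-sum-exp of the attention
// scores over the KV chunk that produced it.
struct Partial {
  std::vector<float> o;  // length HEAD_DIM_CKV
  float lse;
};

// Merge all partials for one (query, head) row into the final output row.
std::vector<float> merge_partials(const std::vector<Partial>& parts, float* lse_out) {
  const size_t head_dim = parts[0].o.size();
  float m = parts[0].lse;
  for (const Partial& p : parts) m = std::max(m, p.lse);  // subtract max for stability
  std::vector<float> acc(head_dim, 0.f);
  float denom = 0.f;
  for (const Partial& p : parts) {
    const float w = std::exp(p.lse - m);  // softmax weight of this chunk
    for (size_t i = 0; i < head_dim; ++i) acc[i] += w * p.o[i];
    denom += w;
  }
  for (size_t i = 0; i < head_dim; ++i) acc[i] /= denom;  // st.normalize()
  if (lse_out) *lse_out = m + std::log(denom);            // merged LSE
  return acc;
}
```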

include/flashinfer/attention/mla_params.cuh (+4)

```diff
@@ -39,6 +39,10 @@ struct MLAParams {
 
   IdType* q_indptr;
   IdType* kv_indptr;
+  IdType* partial_indptr;
+  IdType* merge_packed_offset_start;
+  IdType* merge_packed_offset_end;
+  IdType* merge_indptr;
   IdType* kv_indices;
   IdType* q_len;
   IdType* kv_len;
```
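The four new fields follow the CSR-style `indptr` convention used throughout these kernels: `partial_indptr[w]` tells work item `w` where in the partial buffer to write its split-k rows (-1 means the request was not split and `write_o` targets `final_o` directly), while `merge_packed_offset_start/end` and `merge_indptr[c]..merge_indptr[c+1]` give cluster `c` its slice of final rows and of partial rows to reduce. A toy layout with invented numbers, only to show the shape of the bookkeeping:

```cpp
// 3 work items, 2 clusters; all values are made up for illustration.
// Work items 0 and 1 split one long request's KV; work item 2 is unsplit.
int partial_indptr[3] = {0, 64, -1};  // row offsets into partial_o (-1: no split)
// Cluster 0 merges final packed rows [0, 64) from partial rows [0, 128)
// (two partials per row, gathered with stride 64); cluster 1 has nothing to merge.
int merge_packed_offset_start[2] = {0, 64};
int merge_packed_offset_end[2] = {64, 64};
int merge_indptr[3] = {0, 128, 128};
```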
