 */
 #ifndef FLASHINFER_MLA_FA2_CUH_
 #define FLASHINFER_MLA_FA2_CUH_
+#include <cooperative_groups.h>
+
 #include <cstdint>
 #include <sstream>
 
@@ -90,7 +92,6 @@ struct KernelTraits {
   static constexpr uint32_t UPCAST_STRIDE_KPE = HEAD_DIM_KPE / upcast_size<DTypeKV_>();
   static constexpr uint32_t UPCAST_STRIDE_FINAL_O = HEAD_DIM_CKV / upcast_size<DTypeO_>();
   static constexpr uint32_t UPCAST_STRIDE_P = CTA_TILE_KV / upcast_size<DTypeKV_>();
-  static constexpr uint32_t UPCAST_STRIDE_PARTIAL_O = HEAD_DIM_CKV / upcast_size<float>();
 
   using DTypeQ = DTypeQ_;
   using DTypeKV = DTypeKV_;
@@ -618,6 +619,52 @@ __device__ __forceinline__ void finalize_m_(typename KTraits::AttentionVariant v
   }
 }
 
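+// Second-stage merge: after the grid-wide sync, all CTAs cooperatively fold the
+// per-chunk partial outputs and LSEs back into the final output. Each blockIdx.y
+// owns a contiguous range of packed (query, head) rows; the partial entries that
+// belong to one output row are spaced `stride` apart in partial_o, and the inner
+// loop walks them with exactly that stride.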
+template <typename KTraits>
+__device__ void DevicePersistentMergeStates(typename KTraits::IdType* merge_packed_offset_start,
+                                            typename KTraits::IdType* merge_packed_offset_end,
+                                            typename KTraits::IdType* merge_indptr,
+                                            float* partial_o, float* partial_lse,
+                                            typename KTraits::DTypeO* final_o, float* final_lse,
+                                            const uint32_t o_stride_n, const uint32_t o_stride_h,
+                                            const uint32_t cluster_tile_q,
+                                            const uint_fastdiv& num_heads) {
+  constexpr uint32_t VEC_SIZE = 4;  // partial o has data type float
+  constexpr uint32_t NUM_THRS_PER_ROW = KTraits::HEAD_DIM_CKV / VEC_SIZE;
+  constexpr uint32_t ROWS_PER_ITERATION = (KTraits::NUM_THREADS) / NUM_THRS_PER_ROW;
+  const uint32_t cluster_id = blockIdx.y;
+  const uint32_t thread_id = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
+  const uint32_t offset_start = merge_packed_offset_start[cluster_id];
+  const uint32_t offset_end = merge_packed_offset_end[cluster_id];
+  const uint32_t partial_offset_start = merge_indptr[cluster_id];
+  const uint32_t partial_offset_end = merge_indptr[cluster_id + 1];
+  const uint32_t stride = offset_end - offset_start;
+#pragma unroll 1
+  for (uint32_t local_packed_offset =
+           blockIdx.x * ROWS_PER_ITERATION + thread_id / NUM_THRS_PER_ROW;
+       local_packed_offset < stride; local_packed_offset += gridDim.x * ROWS_PER_ITERATION) {
+    uint32_t final_packed_offset = offset_start + local_packed_offset;
+    uint32_t q, r;
+    num_heads.divmod(final_packed_offset, q, r);
+    state_t<VEC_SIZE> st;
+#pragma unroll 2
+    for (uint32_t partial_packed_offset = partial_offset_start + local_packed_offset;
+         partial_packed_offset < partial_offset_end; partial_packed_offset += stride) {
+      vec_t<float, VEC_SIZE> o_partial;
+      float lse_partial;
+      o_partial.load(partial_o + partial_packed_offset * KTraits::HEAD_DIM_CKV +
+                     (thread_id % NUM_THRS_PER_ROW) * VEC_SIZE);
+      lse_partial = partial_lse[partial_packed_offset];
+      st.merge(o_partial, lse_partial, 1);
+    }
+    st.normalize();
+    st.o.cast_store(final_o +
+                    (q * o_stride_n + r * o_stride_h + (thread_id % NUM_THRS_PER_ROW) * VEC_SIZE));
+    if (final_lse) {
+      final_lse[q * num_heads + r] = st.get_lse();
+    }
+  }
+}
+
 template <typename KTraits>
 __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_storage,
                                         typename KTraits::DTypeO* final_o, float* final_lse,
@@ -631,12 +678,40 @@ __device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_st
   constexpr uint32_t HEAD_DIM_CKV = KTraits::HEAD_DIM_CKV;
   constexpr uint32_t UPCAST_STRIDE_FINAL_O = KTraits::UPCAST_STRIDE_FINAL_O;
   const uint32_t lane_idx = threadIdx.x, warpgroup_idx = threadIdx.z, warp_idx_in_wg = threadIdx.y;
+  smem_t<KTraits::SWIZZLE_MODE_O> o_smem(smem_storage->o_smem);
+
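+  // Partial path: this work tile covers only a slice of the KV sequence, so the
+  // output fragments and per-row LSE are spilled to the partial_o / partial_lse
+  // workspace and combined later by DevicePersistentMergeStates.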
+  if (partial_o != nullptr) {
+    // write to partial_o
+#pragma unroll
+    for (uint32_t j = 0; j < 2; ++j) {
+      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
+      if (lane_idx % 4 == 0 && q_idx < q_len) {
+        partial_lse[(blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4] =
+            math::ptx_log2(d[j]) + float(m[j]);
+      }
+    }
 
-  if (false) {
-    // TOOD(Zihao): write to partial
+#pragma unroll
+    for (uint32_t j = 0; j < 2; ++j) {
+      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
+#pragma unroll
+      for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
+        if (q_idx < q_len) {
+          *reinterpret_cast<float2*>(
+              partial_o +
+              ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
+              warpgroup_idx * (HEAD_DIM_CKV / 2) + mma_d * 16 + (lane_idx % 4) * 2) =
+              *reinterpret_cast<float2*>(&o_frag[mma_d][j * 2]);
+          *reinterpret_cast<float2*>(
+              partial_o +
+              ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4) * HEAD_DIM_CKV +
+              warpgroup_idx * (HEAD_DIM_CKV / 2) + mma_d * 16 + 8 + (lane_idx % 4) * 2) =
+              *reinterpret_cast<float2*>(&o_frag[mma_d][4 + j * 2]);
+        }
+      }
+    }
   } else {
     // write to final_o
-    smem_t<KTraits::SWIZZLE_MODE_O> o_smem(smem_storage->o_smem);
 
     if (final_lse) {
 #pragma unroll
@@ -748,13 +823,14 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
   const uint32_t kpe_stride_n = params.kpe_stride_n;
   const uint32_t o_stride_n = params.o_stride_n;
   const uint32_t o_stride_h = params.o_stride_h;
-  const uint32_t cluster_tile_q = blockDim.x * KTraits::CTA_TILE_Q;
+  const uint32_t cluster_tile_q = gridDim.x * KTraits::CTA_TILE_Q;
 
 #pragma unroll 1
   for (IdType work_idx = work_indptr[blockIdx.y]; work_idx < work_indptr[blockIdx.y + 1];
        ++work_idx) {
     const uint32_t q_indptr = params.q_indptr[work_idx];
     const uint32_t kv_indptr = params.kv_indptr[work_idx];
+    const int32_t partial_indptr = params.partial_indptr[work_idx];
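+    // partial_indptr == -1 marks a work tile that owns its full KV range and can
+    // write final_o directly; otherwise it points into the partial workspace.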
     const uint32_t q_len = params.q_len[work_idx];
     const uint32_t kv_len = params.kv_len[work_idx];
     const uint32_t packed_qo_start = params.q_start[work_idx];
@@ -778,14 +854,14 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
             (CAUSAL ? min(kv_end, kv_len - q_len + (packed_qo_start + cluster_tile_q) / num_heads)
                     : kv_end),
             CTA_TILE_KV) -
-        1;
+        1 - (kv_start / CTA_TILE_KV);
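+    // kv_tile_idx and mask_tile_idx are now counted relative to kv_start, so the
+    // pipeline loops below simply run down to tile 0.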
 
     int mask_tile_idx =
-        (CAUSAL ? min(kv_end, kv_len - q_len + packed_qo_start / num_heads) : kv_end) / CTA_TILE_KV;
+        (CAUSAL ? min(kv_end, kv_len - q_len + packed_qo_start / num_heads) : kv_end) /
+            CTA_TILE_KV -
+        (kv_start / CTA_TILE_KV);
 
-    int start_tile_idx = kv_start / CTA_TILE_KV;  // ceil_div(kv_start, CTA_TILE_KV);
     uint32_t block_iter_base = kv_indptr * block_size + kv_start;
-
     // last kv tile
     __syncthreads();
     uint32_t kv_bound = kv_indptr + (kv_len + block_size - 1) / block_size;  // ceil_div
@@ -796,7 +872,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
     cp_async::commit_group();
 #pragma unroll
     for (int stage_idx = 1; stage_idx < NUM_STAGES; ++stage_idx) {
-      if (kv_tile_idx - stage_idx >= start_tile_idx) {
+      if (kv_tile_idx - stage_idx >= 0) {
         load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
                          kpe_stride_n, kpe_stride_page, kv_bound,
                          block_iter_base + (kv_tile_idx - stage_idx) * CTA_TILE_KV, block_size,
@@ -807,7 +883,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
 
     // loop with mask
 #pragma unroll 1
-    for (; kv_tile_idx >= mask_tile_idx && kv_tile_idx > start_tile_idx; --kv_tile_idx) {
+    for (; kv_tile_idx >= mask_tile_idx && kv_tile_idx > 0; --kv_tile_idx) {
       cp_async::wait_group<NUM_STAGES - 1>();
       __syncthreads();
 
@@ -825,7 +901,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
       // compute sfm * v
       compute_mla_pv<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, s_frag, d, o_frag);
 
-      if (kv_tile_idx - NUM_STAGES >= start_tile_idx) {
+      if (kv_tile_idx - NUM_STAGES >= 0) {
         __syncthreads();
         load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
                          kpe_stride_n, kpe_stride_page, kv_bound,
@@ -837,7 +913,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
 
     // loop without mask
 #pragma unroll 1
-    for (; kv_tile_idx + 1 > start_tile_idx + NUM_STAGES; --kv_tile_idx) {
+    for (; kv_tile_idx + 1 > NUM_STAGES; --kv_tile_idx) {
       cp_async::wait_group<NUM_STAGES - 1>();
       __syncthreads();
 
@@ -862,7 +938,7 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
 
     // last tiles
 #pragma unroll
-    for (; kv_tile_idx >= start_tile_idx; --kv_tile_idx) {
+    for (; kv_tile_idx >= 0; --kv_tile_idx) {
       // compute mla qk
       compute_mla_qk<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, s_frag);
 
@@ -884,11 +960,22 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKe
 
     finalize_m_<KTraits>(variant, m);
 
-    write_o<KTraits>(&smem_storage, final_o + q_indptr * o_stride_n,
-                     final_lse ? final_lse + q_indptr * num_heads : nullptr, partial_o, partial_lse,
-                     o_frag, m, d, o_stride_n, o_stride_h, qo_upperbound, qo_packed_idx_base,
-                     num_heads);
+    write_o<KTraits>(
+        &smem_storage, final_o + q_indptr * o_stride_n,
+        final_lse ? final_lse + q_indptr * num_heads : nullptr,
+        (partial_indptr == -1) ? nullptr : partial_o + partial_indptr * KTraits::HEAD_DIM_CKV,
+        (partial_indptr == -1) ? nullptr : partial_lse + partial_indptr, o_frag, m, d, o_stride_n,
+        o_stride_h, qo_upperbound, qo_packed_idx_base, num_heads);
   }
+
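+  // Grid-wide barrier: every CTA must have finished writing its partial outputs
+  // before any CTA starts the merge pass below. This is why the kernel is launched
+  // with cudaLaunchCooperativeKernel.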
+  auto grid = cg::this_grid();
+  grid.sync();
+
+  // the second stage, merge partial outputs
+  DevicePersistentMergeStates<KTraits>(params.merge_packed_offset_start,
+                                       params.merge_packed_offset_end, params.merge_indptr,
+                                       partial_o, partial_lse, final_o, final_lse, o_stride_n,
+                                       o_stride_h, cluster_tile_q, num_heads);
 }
 
 #define DISPATCH_SMEM_CONFIG(smem_limit_per_sm, NUM_STAGES, CTA_TILE_KV, QK_SHARD, ...) \
@@ -948,7 +1035,8 @@ cudaError_t BatchMLAPagedAttention(Params params, uint32_t num_blks_x, uint32_t
 
     FLASHINFER_CUDA_CALL(
         cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
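+    // grid.sync() inside the kernel requires a cooperative launch: all CTAs of the
+    // grid must be resident on the device at the same time.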
+    FLASHINFER_CUDA_CALL(
+        cudaLaunchCooperativeKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
   });
 
   return cudaSuccess;