Skip to content

Commit c6f20d1

Browse files
authored
perf: use stmatrix in epilogue for sm90+ (#380)
sm90+ GPUs can benefit from using the stmatrix instruction in the epilogue, where the output fragments are written to shared memory.
1 parent d68a408 commit c6f20d1

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

include/flashinfer/attention/prefill.cuh

+6
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,11 @@ __device__ __forceinline__ void write_o_reg_gmem(
893893
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
894894
uint32_t o_frag_f16[4];
895895
vec_cast<DTypeOut, float, 8>((DTypeOut*)o_frag_f16, o_frag[fx][fy]);
896+
#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
897+
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<channel_size_128b_out>(
898+
(warp_idx_x * num_frags_x + fx) * 16 + lane_idx % 16, fy * 2 + lane_idx / 16);
899+
o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
900+
#else
896901
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<channel_size_128b_out>(
897902
(warp_idx_x * num_frags_x + fx) * 16 + lane_idx / 4, fy * 2);
898903
((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
@@ -901,6 +906,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
901906
((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
902907
((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) +
903908
8 * channel_size_128b_out))[lane_idx % 4] = o_frag_f16[3];
909+
#endif
904910
}
905911
}
906912

0 commit comments

Comments
 (0)