File tree: 1 file changed (+6 -0 lines changed)
include/flashinfer/attention
1 file changed (+6 -0 lines changed)
Original file line number | Diff line number | Diff line change
@@ -893,6 +893,11 @@ __device__ __forceinline__ void write_o_reg_gmem(
893
893
for (uint32_t fy = 0 ; fy < num_frags_y; ++fy) {
894
894
uint32_t o_frag_f16[4 ];
895
895
vec_cast<DTypeOut, float , 8 >((DTypeOut*)o_frag_f16, o_frag[fx][fy]);
896
+ #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
897
+ uint32_t o_smem_offset_w = smem_t ::get_permuted_offset<channel_size_128b_out>(
898
+ (warp_idx_x * num_frags_x + fx) * 16 + lane_idx % 16 , fy * 2 + lane_idx / 16 );
899
+ o_smem->stmatrix_m8n8x4 (o_smem_offset_w, o_frag_f16);
900
+ #else
896
901
uint32_t o_smem_offset_w = smem_t ::get_permuted_offset<channel_size_128b_out>(
897
902
(warp_idx_x * num_frags_x + fx) * 16 + lane_idx / 4 , fy * 2 );
898
903
((uint32_t *)(o_smem->base + o_smem_offset_w))[lane_idx % 4 ] = o_frag_f16[0 ];
@@ -901,6 +906,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
901
906
((uint32_t *)(o_smem->base + (o_smem_offset_w ^ 0x1 )))[lane_idx % 4 ] = o_frag_f16[2 ];
902
907
((uint32_t *)(o_smem->base + (o_smem_offset_w ^ 0x1 ) +
903
908
8 * channel_size_128b_out))[lane_idx % 4 ] = o_frag_f16[3 ];
909
+ #endif
904
910
}
905
911
}
906
912
You can’t perform that action at this time.
0 commit comments