@@ -417,7 +417,8 @@ __device__ __forceinline__ void load_q_global_smem(
        uint32_t q, r;
        group_size.divmod(packed_offset + lane_idx / 8 + mma_q * 16 + j * 4, q, r);
        const uint32_t q_idx = q;
-        DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h;
+        DTypeQ* q_ptr =
+            q_ptr_base + q * q_stride_n + r * q_stride_h + (lane_idx % 8) * upcast_size<DTypeQ>();
#pragma unroll
        for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) {
          // load q fragment from gmem to smem
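This hunk folds the per-lane column offset into q_ptr here instead of pre-adding it to q_ptr_base. A minimal sketch of the intended addressing, assuming upcast_size<T>() is the number of elements covered by one 128-bit transfer (16 / sizeof(T)); the helper name is hypothetical:

// Sketch: each group of 8 lanes covers 8 consecutive 16-byte segments of one (token, head) row.
// For fp16, upcast_size<half>() == 8, so lanes 0..7 address elements 0..63 of the head dimension.
template <typename DTypeQ>
__device__ __forceinline__ DTypeQ* lane_q_ptr(DTypeQ* q_ptr_base, uint32_t token, uint32_t head,
                                              uint32_t q_stride_n, uint32_t q_stride_h,
                                              uint32_t lane_idx) {
  return q_ptr_base + token * q_stride_n + head * q_stride_h +
         (lane_idx % 8) * upcast_size<DTypeQ>();
}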
@@ -1095,59 +1096,83 @@ __device__ __forceinline__ void write_o_reg_gmem(
    typename KTraits::DTypeO* o_ptr_base, const uint32_t o_packed_idx_base,
    const uint32_t qo_upper_bound, const uint32_t o_stride_n, const uint32_t o_stride_h,
    const uint_fastdiv group_size) {
+  using DTypeO = typename KTraits::DTypeO;
  constexpr uint32_t UPCAST_HEAD_DIM_O = KTraits::UPCAST_HEAD_DIM_O;
  const uint32_t warp_idx_x = get_warp_idx_q<KTraits>();
  const uint32_t lane_idx = threadIdx.x;

-  if (get_warp_idx_kv<KTraits>() == 0) {
+  if constexpr (sizeof(DTypeO) == 4) {
#pragma unroll
    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
-      for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
-        uint32_t o_frag_f16[8 / 2];
-        vec_cast<typename KTraits::DTypeO, float>::cast<8>((typename KTraits::DTypeO*)o_frag_f16,
-                                                           o_frag[mma_q][mma_d]);
+      for (uint32_t j = 0; j < 2; ++j) {
+        uint32_t q, r;
+        group_size.divmod(o_packed_idx_base + lane_idx / 4 + mma_q * 16 + j * 8, q, r);
+        const uint32_t o_idx = q;
+#pragma unroll
+        for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
+          if (o_idx < qo_upper_bound) {
+            *reinterpret_cast<float2*>(o_ptr_base + q * o_stride_n + r * o_stride_h + mma_d * 16 +
+                                       (lane_idx % 4) * 2) =
+                *reinterpret_cast<float2*>(&o_frag[mma_q][mma_d][j * 2]);
+            *reinterpret_cast<float2*>(o_ptr_base + q * o_stride_n + r * o_stride_h + mma_d * 16 +
+                                       8 + (lane_idx % 4) * 2) =
+                *reinterpret_cast<float2*>(&o_frag[mma_q][mma_d][4 + j * 2]);
+          }
+        }
+      }
+    }
+  } else {
+    if (get_warp_idx_kv<KTraits>() == 0) {
+#pragma unroll
+      for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
+#pragma unroll
+        for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
+          uint32_t o_frag_f16[8 / 2];
+          vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_q][mma_d]);

#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
-        uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_HEAD_DIM_O>(
-            (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16,
-            mma_d * 2 + lane_idx / 16);
-        o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
+          uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_HEAD_DIM_O>(
+              (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16,
+              mma_d * 2 + lane_idx / 16);
+          o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
#else
-        uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_HEAD_DIM_O>(
-            (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2);
-        ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
-        ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_HEAD_DIM_O))[lane_idx % 4] =
-            o_frag_f16[1];
-        ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
-        ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) +
-                     8 * UPCAST_HEAD_DIM_O))[lane_idx % 4] = o_frag_f16[3];
+          uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_HEAD_DIM_O>(
+              (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2);
+          ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
+          ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_HEAD_DIM_O))[lane_idx % 4] =
+              o_frag_f16[1];
+          ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
+          ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) +
+                       8 * UPCAST_HEAD_DIM_O))[lane_idx % 4] = o_frag_f16[3];
#endif
+        }
      }
-    }

-    uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_HEAD_DIM_O>(
-        warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);
+      uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_HEAD_DIM_O>(
+          warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);

#pragma unroll
-    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
+      for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
-      for (uint32_t j = 0; j < 2 * 2; ++j) {
-        uint32_t q, r;
-        group_size.divmod(o_packed_idx_base + lane_idx / 8 + mma_q * 16 + j * 4, q, r);
-        const uint32_t o_idx = q;
-        typename KTraits::DTypeO* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h;
+        for (uint32_t j = 0; j < 2 * 2; ++j) {
+          uint32_t q, r;
+          group_size.divmod(o_packed_idx_base + lane_idx / 8 + mma_q * 16 + j * 4, q, r);
+          const uint32_t o_idx = q;
+          DTypeO* o_ptr =
+              o_ptr_base + q * o_stride_n + r * o_stride_h + (lane_idx % 8) * upcast_size<DTypeO>();
#pragma unroll
-        for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) {
-          if (o_idx < qo_upper_bound) {
-            o_smem->store_128b(o_smem_offset_w, o_ptr);
+          for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) {
+            if (o_idx < qo_upper_bound) {
+              o_smem->store_128b(o_smem_offset_w, o_ptr);
+            }
+            o_ptr += 8 * upcast_size<DTypeO>();
+            o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, mma_do);
          }
-          o_ptr += 8 * upcast_size<typename KTraits::DTypeO>();
-          o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, mma_do);
+          o_smem_offset_w =
+              o_smem->template advance_offset_by_row<4, UPCAST_HEAD_DIM_O>(o_smem_offset_w) -
+              2 * KTraits::NUM_MMA_D_VO;
        }
-        o_smem_offset_w =
-            o_smem->template advance_offset_by_row<4, UPCAST_HEAD_DIM_O>(o_smem_offset_w) -
-            2 * KTraits::NUM_MMA_D_VO;
      }
    }
  }
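The new sizeof(DTypeO) == 4 branch skips the shared-memory round trip and writes each fp32 accumulator fragment straight to global memory. A minimal sketch of the lane-to-element mapping this relies on, for one 16x16 tile of an m16n16 MMA accumulator; the function and parameter names are hypothetical and it assumes the standard m16n8 fragment layout:

// Sketch only: how one lane's 8 fp32 accumulator values land in a 16x16 output tile.
// Assumes tile rows are 8-byte aligned so paired float2 stores are legal.
__device__ void store_fp32_tile_sketch(float* tile_base, uint32_t row_stride,
                                       const float (&acc)[8], uint32_t lane_idx) {
#pragma unroll
  for (uint32_t j = 0; j < 2; ++j) {
    const uint32_t row = lane_idx / 4 + j * 8;  // rows 0-7 for j = 0, rows 8-15 for j = 1
    const uint32_t col = (lane_idx % 4) * 2;    // two consecutive fp32 values per store
    // left n8 half of the tile: acc[j*2], acc[j*2+1]
    *reinterpret_cast<float2*>(tile_base + row * row_stride + col) =
        make_float2(acc[j * 2], acc[j * 2 + 1]);
    // right n8 half sits 8 columns over: acc[4+j*2], acc[4+j*2+1]
    *reinterpret_cast<float2*>(tile_base + row * row_stride + 8 + col) =
        make_float2(acc[4 + j * 2], acc[4 + j * 2 + 1]);
  }
}

In the hunk above the same stores are issued once per mma_d tile, with o_ptr_base + q * o_stride_n + r * o_stride_h as the tile base and guarded by o_idx < qo_upper_bound.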
@@ -1229,7 +1254,6 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCache
  const uint_fastdiv& group_size = params.group_size;

  static_assert(sizeof(DTypeQ) == 2);
-  static_assert(sizeof(DTypeO) == 2);
  const uint32_t lane_idx = threadIdx.x, warp_idx = get_warp_idx<KTraits>();
  const uint32_t bx = blockIdx.x, chunk_idx = blockIdx.y, kv_head_idx = blockIdx.z;
  const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size;
@@ -1264,13 +1288,10 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCache
      (bx * NUM_WARPS_Q + get_warp_idx_q<KTraits>()) * NUM_MMA_Q * 16;
  smem_t<SWIZZLE_MODE_Q> qo_smem(smem_storage.q_smem);
  const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO;
-  DTypeQ* q_ptr_base =
-      q + (kv_head_idx * group_size) * q_stride_h + (lane_idx % 8) * upcast_size<DTypeQ>();
-  DTypeO* o_ptr_base =
-      partition_kv
-          ? o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h +
-                (lane_idx % 8) * upcast_size<DTypeO>()
-          : o + (kv_head_idx * group_size) * o_stride_h + (lane_idx % 8) * upcast_size<DTypeO>();
+  DTypeQ* q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h;
+  DTypeO* o_ptr_base = partition_kv
+                           ? o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h
+                           : o + (kv_head_idx * group_size) * o_stride_h;

  uint32_t q_smem_offset_r = qo_smem.get_permuted_offset<UPCAST_HEAD_DIM_Q>(
      get_warp_idx_q<KTraits>() * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16);
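With the lane-column term gone, q_ptr_base and o_ptr_base only encode the request offset and the first query head of this KV head's group; the per-token and per-head terms are recovered later from the packed row index. A small sketch of that split, using the uint_fastdiv::divmod form that appears in the hunks above (the helper name is hypothetical, and it assumes packed rows interleave the group_size query heads of each token):

// Sketch: packed_idx = token * group_size + head_in_group for grouped-query attention,
// so one fast divmod yields both the row (token) and head offsets used with the strides.
__device__ __forceinline__ void unpack_qo_index(const uint_fastdiv& group_size,
                                                uint32_t packed_idx, uint32_t& token,
                                                uint32_t& head_in_group) {
  group_size.divmod(packed_idx, token, head_in_group);
}
// Element address (output tensor, for example):
//   o_ptr_base + token * o_stride_n + head_in_group * o_stride_h + column_within_head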
@@ -1614,7 +1635,6 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV
  const uint_fastdiv& group_size = params.group_size;

  static_assert(sizeof(DTypeQ) == 2);
-  static_assert(sizeof(DTypeO) == 2);
  const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr);

  auto block = cg::this_thread_block();
@@ -1658,16 +1678,13 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV
  smem_t<SWIZZLE_MODE_Q> qo_smem(smem_storage.q_smem);
  const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO;

-  DTypeQ* q_ptr_base = q + q_indptr[request_idx] * q_stride_n +
-                       kv_head_idx * group_size * q_stride_h +
-                       (lane_idx % 8) * upcast_size<DTypeQ>();
+  DTypeQ* q_ptr_base =
+      q + q_indptr[request_idx] * q_stride_n + kv_head_idx * group_size * q_stride_h;

-  DTypeO* o_ptr_base =
-      partition_kv
-          ? o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n +
-                (kv_head_idx * group_size) * o_stride_h + (lane_idx % 8) * upcast_size<DTypeO>()
-          : o + o_indptr[request_idx] * o_stride_n + (kv_head_idx * group_size) * o_stride_h +
-                (lane_idx % 8) * upcast_size<DTypeO>();
+  DTypeO* o_ptr_base = partition_kv ? o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n +
+                                          (kv_head_idx * group_size) * o_stride_h
+                                    : o + o_indptr[request_idx] * o_stride_n +
+                                          (kv_head_idx * group_size) * o_stride_h;

  uint32_t q_smem_offset_r = qo_smem.get_permuted_offset<UPCAST_HEAD_DIM_Q>(
      get_warp_idx_q<KTraits>() * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16);
@@ -1901,7 +1918,6 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVC
  const uint_fastdiv& group_size = params.group_size;

  static_assert(sizeof(DTypeQ) == 2);
-  static_assert(sizeof(DTypeO) == 2);
  auto block = cg::this_thread_block();
  const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr);

@@ -1945,15 +1961,12 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVC
  const uint32_t q_stride_n = params.q_stride_n, q_stride_h = params.q_stride_h;
  smem_t<SWIZZLE_MODE_Q> qo_smem(smem_storage.q_smem);
  const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO;
-  DTypeQ* q_ptr_base = q + q_indptr[request_idx] * q_stride_n +
-                       (kv_head_idx * group_size) * q_stride_h +
-                       (lane_idx % 8) * upcast_size<DTypeQ>();
-  DTypeO* o_ptr_base =
-      partition_kv
-          ? o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n +
-                (kv_head_idx * group_size) * o_stride_h + (lane_idx % 8) * upcast_size<DTypeO>()
-          : o + o_indptr[request_idx] * o_stride_n + (kv_head_idx * group_size) * o_stride_h +
-                (lane_idx % 8) * upcast_size<DTypeO>();
+  DTypeQ* q_ptr_base =
+      q + q_indptr[request_idx] * q_stride_n + (kv_head_idx * group_size) * q_stride_h;
+  DTypeO* o_ptr_base = partition_kv ? o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n +
+                                          (kv_head_idx * group_size) * o_stride_h
+                                    : o + o_indptr[request_idx] * o_stride_n +
+                                          (kv_head_idx * group_size) * o_stride_h;

  uint32_t q_smem_offset_r = qo_smem.get_permuted_offset<UPCAST_HEAD_DIM_Q>(
      get_warp_idx_q<KTraits>() * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16);