Commit 32388d0

feat: support f32 attention output in FA2 template (#799)
For bf16 kernels, we need to use f32 as the intermediate data type for the split-k partial output to avoid numerical errors. This PR adds support for f32 output in the CUDA templates and is the first step towards addressing the bf16 kernel numerical issues. The remaining tasks are:

1. Do not pre-apply `sm_scale` to the query; apply `sm_scale` to the logits instead.
2. Change the default dtype of the split-k partial output to float32.
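
To see why the partial-output precision matters, consider the split-k merge step: each KV chunk produces a partial output and a log-sum-exp (LSE) value, and the chunks are combined by a softmax-weighted average. The sketch below is not FlashInfer's merge kernel, just a minimal scalar illustration with hypothetical names; the point is that rounding each partial output to bf16 before this step bakes a rounding error into every term of the weighted sum, while f32 partials keep the merge accurate.

```cpp
#include <cmath>

// Merge two split-k partial attention outputs (o1, lse1) and (o2, lse2) for a
// single element of the head dimension. Illustrative only; FlashInfer merges
// whole vectors of partial outputs in its own merge kernel.
inline float merge_partial_outputs(float o1, float lse1, float o2, float lse2) {
  // Renormalize both partials against the larger LSE, then take the
  // softmax-weighted average.
  float m = fmaxf(lse1, lse2);
  float w1 = expf(lse1 - m);
  float w2 = expf(lse2 - m);
  // If o1 and o2 had already been rounded to bf16 (~8 mantissa bits), that
  // error would propagate through the weighted sum and grow with the number
  // of KV chunks; keeping the partials in f32 until this point avoids it.
  return (o1 * w1 + o2 * w2) / (w1 + w2);
}
```
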
1 parent 824ce40 commit 32388d0

2 files changed: +78 -65 lines

include/flashinfer/attention/default_prefill_params.cuh (+2 -2)
```diff
@@ -114,8 +114,8 @@ struct SinglePrefillParams {
         window_left(window_left),
         logits_soft_cap(logits_soft_cap),
         sm_scale(sm_scale),
-        rope_rcp_scale(-std::log2f(rope_scale)),
-        rope_rcp_theta(-std::log2f(rope_theta)),
+        rope_rcp_scale(1. / rope_scale),
+        rope_rcp_theta(1. / rope_theta),
         partition_kv(false) {}

   __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
```
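
As the member names suggest, `rope_rcp_scale` and `rope_rcp_theta` are meant to be reciprocals that downstream RoPE code multiplies by; the old `-std::log2f(...)` initializers did not compute reciprocals at all. A minimal sketch of how such reciprocals are typically consumed in a RoPE rotation (illustrative only, not the exact FlashInfer kernel code):

```cpp
// Rotate one (even, odd) element pair at position `pos` and frequency index `i`,
// using the reciprocal parameters directly:
// rope_rcp_scale = 1 / rope_scale, rope_rcp_theta = 1 / rope_theta.
__device__ __forceinline__ void apply_rope_pair(float& x_even, float& x_odd,
                                                uint32_t pos, uint32_t i,
                                                uint32_t head_dim,
                                                float rope_rcp_scale,
                                                float rope_rcp_theta) {
  // freq = theta^(-2i/d); position interpolation scales pos by 1 / rope_scale.
  float freq = __powf(rope_rcp_theta, float(2 * i) / float(head_dim));
  float angle = float(pos) * rope_rcp_scale * freq;
  float c, s;
  __sincosf(angle, &s, &c);
  float e = x_even, o = x_odd;
  x_even = e * c - o * s;
  x_odd = e * s + o * c;
}
```
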

include/flashinfer/attention/prefill.cuh (+76 -63)
```diff
@@ -417,7 +417,8 @@ __device__ __forceinline__ void load_q_global_smem(
       uint32_t q, r;
       group_size.divmod(packed_offset + lane_idx / 8 + mma_q * 16 + j * 4, q, r);
       const uint32_t q_idx = q;
-      DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h;
+      DTypeQ* q_ptr =
+          q_ptr_base + q * q_stride_n + r * q_stride_h + (lane_idx % 8) * upcast_size<DTypeQ>();
 #pragma unroll
       for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) {
         // load q fragment from gmem to smem
```
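
The `group_size.divmod(...)` call above decodes a packed query index: for grouped-query attention, the queries of all `group_size` query heads that share one KV head are interleaved along a single packed dimension, and the quotient/remainder select the token position and the head within the group. A plain-division sketch of the same decomposition (the kernel uses the `uint_fastdiv` helper instead):

```cpp
#include <cstdint>

// Decode a packed GQA query index: packed = q * group_size + r, where q is the
// token position (stride q_stride_n) and r is the query head within the KV-head
// group (stride q_stride_h). Sketch only; the kernel uses uint_fastdiv::divmod.
inline void unpack_qo_index(uint32_t packed, uint32_t group_size, uint32_t& q,
                            uint32_t& r) {
  q = packed / group_size;  // token position
  r = packed % group_size;  // query head within the KV-head group
}
```
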
```diff
@@ -1095,59 +1096,83 @@ __device__ __forceinline__ void write_o_reg_gmem(
     typename KTraits::DTypeO* o_ptr_base, const uint32_t o_packed_idx_base,
     const uint32_t qo_upper_bound, const uint32_t o_stride_n, const uint32_t o_stride_h,
     const uint_fastdiv group_size) {
+  using DTypeO = typename KTraits::DTypeO;
   constexpr uint32_t UPCAST_HEAD_DIM_O = KTraits::UPCAST_HEAD_DIM_O;
   const uint32_t warp_idx_x = get_warp_idx_q<KTraits>();
   const uint32_t lane_idx = threadIdx.x;

-  if (get_warp_idx_kv<KTraits>() == 0) {
+  if constexpr (sizeof(DTypeO) == 4) {
 #pragma unroll
     for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
 #pragma unroll
-      for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
-        uint32_t o_frag_f16[8 / 2];
-        vec_cast<typename KTraits::DTypeO, float>::cast<8>((typename KTraits::DTypeO*)o_frag_f16,
-                                                           o_frag[mma_q][mma_d]);
+      for (uint32_t j = 0; j < 2; ++j) {
+        uint32_t q, r;
+        group_size.divmod(o_packed_idx_base + lane_idx / 4 + mma_q * 16 + j * 8, q, r);
+        const uint32_t o_idx = q;
+#pragma unroll
+        for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
+          if (o_idx < qo_upper_bound) {
+            *reinterpret_cast<float2*>(o_ptr_base + q * o_stride_n + r * o_stride_h + mma_d * 16 +
+                                       (lane_idx % 4) * 2) =
+                *reinterpret_cast<float2*>(&o_frag[mma_q][mma_d][j * 2]);
+            *reinterpret_cast<float2*>(o_ptr_base + q * o_stride_n + r * o_stride_h + mma_d * 16 +
+                                       8 + (lane_idx % 4) * 2) =
+                *reinterpret_cast<float2*>(&o_frag[mma_q][mma_d][4 + j * 2]);
+          }
+        }
+      }
+    }
+  } else {
+    if (get_warp_idx_kv<KTraits>() == 0) {
+#pragma unroll
+      for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
+#pragma unroll
+        for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
+          uint32_t o_frag_f16[8 / 2];
+          vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_q][mma_d]);

 #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
-        uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_HEAD_DIM_O>(
-            (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16,
-            mma_d * 2 + lane_idx / 16);
-        o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
+          uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_HEAD_DIM_O>(
+              (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16,
+              mma_d * 2 + lane_idx / 16);
+          o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
 #else
-        uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_HEAD_DIM_O>(
-            (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2);
-        ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
-        ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_HEAD_DIM_O))[lane_idx % 4] =
-            o_frag_f16[1];
-        ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
-        ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) +
-                     8 * UPCAST_HEAD_DIM_O))[lane_idx % 4] = o_frag_f16[3];
+          uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_HEAD_DIM_O>(
+              (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2);
+          ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
+          ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_HEAD_DIM_O))[lane_idx % 4] =
+              o_frag_f16[1];
+          ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
+          ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) +
+                       8 * UPCAST_HEAD_DIM_O))[lane_idx % 4] = o_frag_f16[3];
 #endif
+        }
       }
-    }

-    uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_HEAD_DIM_O>(
-        warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);
+      uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_HEAD_DIM_O>(
+          warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);

 #pragma unroll
-    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
+      for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
 #pragma unroll
-      for (uint32_t j = 0; j < 2 * 2; ++j) {
-        uint32_t q, r;
-        group_size.divmod(o_packed_idx_base + lane_idx / 8 + mma_q * 16 + j * 4, q, r);
-        const uint32_t o_idx = q;
-        typename KTraits::DTypeO* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h;
+        for (uint32_t j = 0; j < 2 * 2; ++j) {
+          uint32_t q, r;
+          group_size.divmod(o_packed_idx_base + lane_idx / 8 + mma_q * 16 + j * 4, q, r);
+          const uint32_t o_idx = q;
+          DTypeO* o_ptr =
+              o_ptr_base + q * o_stride_n + r * o_stride_h + (lane_idx % 8) * upcast_size<DTypeO>();
 #pragma unroll
-        for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) {
-          if (o_idx < qo_upper_bound) {
-            o_smem->store_128b(o_smem_offset_w, o_ptr);
+          for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) {
+            if (o_idx < qo_upper_bound) {
+              o_smem->store_128b(o_smem_offset_w, o_ptr);
+            }
+            o_ptr += 8 * upcast_size<DTypeO>();
+            o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, mma_do);
           }
-          o_ptr += 8 * upcast_size<typename KTraits::DTypeO>();
-          o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, mma_do);
+          o_smem_offset_w =
+              o_smem->template advance_offset_by_row<4, UPCAST_HEAD_DIM_O>(o_smem_offset_w) -
+              2 * KTraits::NUM_MMA_D_VO;
         }
-        o_smem_offset_w =
-            o_smem->template advance_offset_by_row<4, UPCAST_HEAD_DIM_O>(o_smem_offset_w) -
-            2 * KTraits::NUM_MMA_D_VO;
       }
     }
   }
```
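
The new `if constexpr (sizeof(DTypeO) == 4)` branch is the core of the f32 output support: each thread writes its MMA accumulator fragment straight to global memory as `float2` pairs, skipping the 16-bit down-cast and the shared-memory staging used by the f16/bf16 path. A simplified standalone sketch of that store pattern for one 16x16 output tile (the index mapping follows the `m16n8k16` accumulator layout; the names and the flat `stride` are illustrative):

```cpp
// Simplified store of one 16x16 f32 output tile held in a pair of m16n8k16
// accumulators (o_frag[8]); `lane` is the lane id within the warp and `out`
// points at the top-left element of the tile with row stride `stride`.
__device__ void store_f32_tile(float* out, uint32_t stride, uint32_t lane,
                               const float (&o_frag)[8]) {
  uint32_t row0 = lane / 4;        // each thread owns rows lane/4 and lane/4 + 8
  uint32_t col0 = (lane % 4) * 2;  // ... and two consecutive columns per n8 half
#pragma unroll
  for (uint32_t j = 0; j < 2; ++j) {
    uint32_t row = row0 + j * 8;
    // First n8 half of the tile (columns 0..7).
    *reinterpret_cast<float2*>(out + row * stride + col0) =
        *reinterpret_cast<const float2*>(&o_frag[j * 2]);
    // Second n8 half of the tile (columns 8..15).
    *reinterpret_cast<float2*>(out + row * stride + 8 + col0) =
        *reinterpret_cast<const float2*>(&o_frag[4 + j * 2]);
  }
}
```
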
```diff
@@ -1229,7 +1254,6 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCache
   const uint_fastdiv& group_size = params.group_size;

   static_assert(sizeof(DTypeQ) == 2);
-  static_assert(sizeof(DTypeO) == 2);
   const uint32_t lane_idx = threadIdx.x, warp_idx = get_warp_idx<KTraits>();
   const uint32_t bx = blockIdx.x, chunk_idx = blockIdx.y, kv_head_idx = blockIdx.z;
   const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size;
```
```diff
@@ -1264,13 +1288,10 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCache
       (bx * NUM_WARPS_Q + get_warp_idx_q<KTraits>()) * NUM_MMA_Q * 16;
   smem_t<SWIZZLE_MODE_Q> qo_smem(smem_storage.q_smem);
   const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO;
-  DTypeQ* q_ptr_base =
-      q + (kv_head_idx * group_size) * q_stride_h + (lane_idx % 8) * upcast_size<DTypeQ>();
-  DTypeO* o_ptr_base =
-      partition_kv
-          ? o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h +
-                (lane_idx % 8) * upcast_size<DTypeO>()
-          : o + (kv_head_idx * group_size) * o_stride_h + (lane_idx % 8) * upcast_size<DTypeO>();
+  DTypeQ* q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h;
+  DTypeO* o_ptr_base = partition_kv
+                           ? o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h
+                           : o + (kv_head_idx * group_size) * o_stride_h;

   uint32_t q_smem_offset_r = qo_smem.get_permuted_offset<UPCAST_HEAD_DIM_Q>(
       get_warp_idx_q<KTraits>() * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16);
```
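
Note that the `(lane_idx % 8) * upcast_size<DTypeQ>()` and `(lane_idx % 8) * upcast_size<DTypeO>()` terms have moved out of these shared base pointers and into `load_q_global_smem` / `write_o_reg_gmem`. With `DTypeO` now allowed to be `float`, the f32 and f16 output paths address memory with different per-lane layouts, so the lane offset is applied next to the typed access instead of being folded into a type-agnostic base pointer. A small sketch of the idea, assuming `upcast_size<T>()` is the element count of one 16-byte vector (hypothetical standalone helper, not the FlashInfer definition):

```cpp
#include <cstdint>
#include <cuda_fp16.h>

// Hypothetical stand-in for upcast_size<T>(): how many elements of T fit in one
// 128-bit (16-byte) vector, the granularity of the 128-bit smem/gmem copies.
template <typename T>
constexpr uint32_t elems_per_128b() {
  return 16 / sizeof(T);
}

// The same lane therefore lands at a different *element* offset depending on
// the element type of the access.
static_assert(elems_per_128b<__half>() == 8, "lane k covers elements 8*k .. 8*k+7");
static_assert(elems_per_128b<float>() == 4, "lane k covers elements 4*k .. 4*k+3");
```
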
```diff
@@ -1614,7 +1635,6 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV
   const uint_fastdiv& group_size = params.group_size;

   static_assert(sizeof(DTypeQ) == 2);
-  static_assert(sizeof(DTypeO) == 2);
   const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr);

   auto block = cg::this_thread_block();
```
```diff
@@ -1658,16 +1678,13 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKV
   smem_t<SWIZZLE_MODE_Q> qo_smem(smem_storage.q_smem);
   const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO;

-  DTypeQ* q_ptr_base = q + q_indptr[request_idx] * q_stride_n +
-                       kv_head_idx * group_size * q_stride_h +
-                       (lane_idx % 8) * upcast_size<DTypeQ>();
+  DTypeQ* q_ptr_base =
+      q + q_indptr[request_idx] * q_stride_n + kv_head_idx * group_size * q_stride_h;

-  DTypeO* o_ptr_base =
-      partition_kv
-          ? o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n +
-                (kv_head_idx * group_size) * o_stride_h + (lane_idx % 8) * upcast_size<DTypeO>()
-          : o + o_indptr[request_idx] * o_stride_n + (kv_head_idx * group_size) * o_stride_h +
-                (lane_idx % 8) * upcast_size<DTypeO>();
+  DTypeO* o_ptr_base = partition_kv ? o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n +
+                                          (kv_head_idx * group_size) * o_stride_h
+                                    : o + o_indptr[request_idx] * o_stride_n +
+                                          (kv_head_idx * group_size) * o_stride_h;

   uint32_t q_smem_offset_r = qo_smem.get_permuted_offset<UPCAST_HEAD_DIM_Q>(
       get_warp_idx_q<KTraits>() * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16);
```
```diff
@@ -1901,7 +1918,6 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVC
   const uint_fastdiv& group_size = params.group_size;

   static_assert(sizeof(DTypeQ) == 2);
-  static_assert(sizeof(DTypeO) == 2);
   auto block = cg::this_thread_block();
   const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr);

```
```diff
@@ -1945,15 +1961,12 @@ __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVC
   const uint32_t q_stride_n = params.q_stride_n, q_stride_h = params.q_stride_h;
   smem_t<SWIZZLE_MODE_Q> qo_smem(smem_storage.q_smem);
   const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO;
-  DTypeQ* q_ptr_base = q + q_indptr[request_idx] * q_stride_n +
-                       (kv_head_idx * group_size) * q_stride_h +
-                       (lane_idx % 8) * upcast_size<DTypeQ>();
-  DTypeO* o_ptr_base =
-      partition_kv
-          ? o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n +
-                (kv_head_idx * group_size) * o_stride_h + (lane_idx % 8) * upcast_size<DTypeO>()
-          : o + o_indptr[request_idx] * o_stride_n + (kv_head_idx * group_size) * o_stride_h +
-                (lane_idx % 8) * upcast_size<DTypeO>();
+  DTypeQ* q_ptr_base =
+      q + q_indptr[request_idx] * q_stride_n + (kv_head_idx * group_size) * q_stride_h;
+  DTypeO* o_ptr_base = partition_kv ? o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n +
+                                          (kv_head_idx * group_size) * o_stride_h
+                                    : o + o_indptr[request_idx] * o_stride_n +
+                                          (kv_head_idx * group_size) * o_stride_h;

   uint32_t q_smem_offset_r = qo_smem.get_permuted_offset<UPCAST_HEAD_DIM_Q>(
       get_warp_idx_q<KTraits>() * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16);
```
