Skip to content

Commit 5af935c

Browse files
authored
feat: mma rowsum for fp8 (#180)
Both e4m3 and e5m2.
1 parent d305798 commit 5af935c

File tree

1 file changed

+41
-3
lines changed

1 file changed

+41
-3
lines changed

Diff for: include/flashinfer/mma.cuh

+41-3
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, T* smem_ptr) {
133133
template <typename T, MMAMode mma_mode = MMAMode::kInplaceUpdate>
134134
__device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uint32_t* A,
135135
uint32_t* B) {
136+
static_assert(sizeof(T) == 1, "DType must be 8bit floating data type");
136137
#if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
137138
if constexpr (mma_mode == MMAMode::kInit) {
138139
if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
@@ -216,7 +217,7 @@ __device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uin
216217
}
217218
}
218219
#else
219-
static_assert(false, "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+");
220+
#error "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"
220221
#endif
221222
}
222223

@@ -387,8 +388,45 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u
387388
#endif
388389
}
389390

390-
// template <typename DType>
391-
// __device__ __forceinline__ void
391+
/*!
392+
* \brief Use mma instructions to compute rowsum.
393+
*/
394+
template <typename DType>
395+
__device__ __forceinline__ void rowsum_f8f8f32(float* d, DType* s) {
396+
static_assert(sizeof(DType) == 1, "DType must be 8bit floating data type");
397+
uint32_t* s_u32 = (uint32_t*)(s);
398+
#if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
399+
if constexpr (std::is_same<DType, __nv_fp8_e4m3>::value) {
400+
asm volatile(
401+
"{\n"
402+
".reg .f32 ph;\n"
403+
"mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
404+
"{%0, ph, %1, ph},"
405+
"{%2, %3, %4, %5},"
406+
"{%6, %7},"
407+
"{%8, 0., %9, 0.};\n"
408+
"}\n"
409+
: "=f"(d[0]), "=f"(d[1])
410+
: "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(943208504),
411+
"r"(943208504), "f"(d[0]), "f"(d[1]));
412+
} else { // e5m2
413+
asm volatile(
414+
"{\n"
415+
".reg .f32 ph;\n"
416+
"mma.sync.aligned.m16n8k16.row.col.f32.e5m2.e5m2.f32 "
417+
"{%0, ph, %1, ph},"
418+
"{%2, %3, %4, %5},"
419+
"{%6, %7},"
420+
"{%8, 0., %9, 0.};\n"
421+
"}\n"
422+
: "=f"(d[0]), "=f"(d[1])
423+
: "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(1010580540),
424+
"r"(1010580540), "f"(d[0]), "f"(d[1]));
425+
}
426+
#else
427+
#error "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"
428+
#endif
429+
}
392430

393431
/*!
394432
* \brief Use mma instructions to compute rowsum.

0 commit comments

Comments
 (0)