@@ -133,6 +133,7 @@ __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, T* smem_ptr) {
133
133
template <typename T, MMAMode mma_mode = MMAMode::kInplaceUpdate >
134
134
__device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32 (float * C, uint32_t * A,
135
135
uint32_t * B) {
136
+ static_assert (sizeof (T) == 1 , " DType must be 8bit floating data type" );
136
137
#if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
137
138
if constexpr (mma_mode == MMAMode::kInit ) {
138
139
if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
@@ -216,7 +217,7 @@ __device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uin
216
217
}
217
218
}
218
219
#else
219
- static_assert ( false , " fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+" );
220
+ # error "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"
220
221
#endif
221
222
}
222
223
@@ -387,8 +388,45 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u
387
388
#endif
388
389
}
389
390
390
- // template <typename DType>
391
- // __device__ __forceinline__ void
391
/*!
 * \brief Use mma instructions to compute rowsum.
 *
 * Issues one warp-wide m16n8k32 fp8 `mma.sync` whose B operand is filled with the
 * value 1.0, so every output element is the sum over k of one row of the A
 * fragment, added to the accumulator that was passed in.
 *
 * \tparam DType 1-byte element type. `__nv_fp8_e4m3` selects the e4m3 instruction;
 *   every other 1-byte type takes the e5m2 path.
 * \param d Two fp32 accumulators owned by the calling thread. They are fed in as
 *   accumulator slots 0 and 2 of the C fragment and overwritten with slots 0 and 2
 *   of the D fragment (rows `groupID` and `groupID + 8` in the PTX fragment
 *   layout). Slots 1 and 3 are fed 0 and their results are discarded.
 * \param s The calling thread's A fragment. It is reinterpreted as `uint32_t*` and
 *   four 32-bit words (16 fp8 values) are read, so it must be 4-byte aligned.
 * \note `mma.sync.aligned` must be executed by all threads of the warp.
 */
template <typename DType>
__device__ __forceinline__ void rowsum_f8f8f32(float* d, DType* s) {
  static_assert(sizeof(DType) == 1, "DType must be 8bit floating data type");
  uint32_t* s_u32 = (uint32_t*)(s);
#if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED)
  if constexpr (std::is_same<DType, __nv_fp8_e4m3>::value) {
    // Four packed bytes 0x38: 1.0 encoded as e4m3 (decimal 943208504).
    constexpr uint32_t kPackedOnesE4M3 = 0x38383838u;
    asm volatile(
        "{\n"
        ".reg .f32 ph;\n"  // placeholder register for the two unused outputs
        "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
        "{%0,  ph,  %1,  ph},"
        "{%2,  %3,  %4,  %5},"
        "{%6,  %7},"
        "{%8,  0.,  %9,  0.};\n"
        "}\n"
        : "=f"(d[0]), "=f"(d[1])
        : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(kPackedOnesE4M3),
          "r"(kPackedOnesE4M3), "f"(d[0]), "f"(d[1]));
  } else {  // e5m2
    // Four packed bytes 0x3C: 1.0 encoded as e5m2 (decimal 1010580540).
    constexpr uint32_t kPackedOnesE5M2 = 0x3C3C3C3Cu;
    // The shape must be m16n8k32, matching the e4m3 branch: the operand list
    // below carries four A registers and two B registers, which is the k32
    // fragment size (m16n8k16 would take two A registers and one B register).
    asm volatile(
        "{\n"
        ".reg .f32 ph;\n"  // placeholder register for the two unused outputs
        "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 "
        "{%0,  ph,  %1,  ph},"
        "{%2,  %3,  %4,  %5},"
        "{%6,  %7},"
        "{%8,  0.,  %9,  0.};\n"
        "}\n"
        : "=f"(d[0]), "=f"(d[1])
        : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(kPackedOnesE5M2),
          "r"(kPackedOnesE5M2), "f"(d[0]), "f"(d[1]));
  }
#else
  // Evaluated by the preprocessor, so it fires whenever the macro is undefined,
  // whether or not this template is ever instantiated.
#error "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"
#endif
}
392
430
393
431
/* !
394
432
* \brief Use mma instructions to compute rowsum.
0 commit comments