Skip to content

Commit 892b5a7

Browse files
committed
fix: add SM90 guard for FP8 Blockscale GEMM
Signed-off-by: Zihua Wu <[email protected]>
1 parent fadb1a8 commit 892b5a7

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

cpp/include/tensorrt_llm/deep_gemm/compiler.cuh

+9
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,15 @@ public:
307307
uint32_t const block_k, uint32_t const num_groups, uint32_t const num_stages, uint32_t const num_tma_multicast,
308308
deep_gemm::GemmType const gemm_type)
309309
{
310+
int sm_version = tensorrt_llm::common::getSMVersion();
311+
if (sm_version != 90)
312+
{
313+
TLLM_THROW(
314+
"DeepGEMM only supports Hopper (SM90) architectures, but current device compute "
315+
"capability is %d.",
316+
sm_version);
317+
}
318+
310319
// Build signature - simplified, no MD5 calculation
311320
std::string name = "gemm_" + std::to_string(shape_n) + "_" + std::to_string(shape_k) + "_"
312321
+ std::to_string(block_m) + "_" + std::to_string(block_n) + "_" + std::to_string(block_k) + "_"

cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu

-8
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,6 @@ template <typename ElementA, typename ElementB, typename ElementD>
2525
CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::CutlassFp8BlockScaleGemmRunner()
2626
{
2727
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
28-
int sm = tensorrt_llm::common::getSMVersion();
29-
if (sm != 90)
30-
{
31-
TLLM_THROW(
32-
"FP8 Blockscale GEMM kernels are only supported on SM90 architectures, but current device compute "
33-
"capability is %d.",
34-
sm);
35-
}
3628
}
3729

3830
template <typename ElementA, typename ElementB, typename ElementD>

0 commit comments

Comments
 (0)