Skip to content

Commit 7aeee64

Browse files
Ther-LFMu Huai
authored and
Mu Huai
committed
[Bugfix] Fix cutlass dispatch for fp8/int8 to properly invoke M<=16 c… (vllm-project#16751)
Signed-off-by: Ther-LF <[email protected]> Signed-off-by: Mu Huai <[email protected]>
1 parent 0e42fd6 commit 7aeee64

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
336336

337337
uint32_t const m = a.size(0);
338338
uint32_t const mp2 =
339-
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
339+
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
340340

341341
if (mp2 <= 16) {
342342
// M in [1, 16]

csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
321321

322322
uint32_t const m = a.size(0);
323323
uint32_t const mp2 =
324-
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
324+
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
325325

326326
if (mp2 <= 16) {
327327
// M in [1, 16]

0 commit comments

Comments
 (0)