Skip to content

Commit 1b8bf59

Browse files
committed
For mx fp8, A and B need not be kFloat8_e8m0fnu type
Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent 5219a2b commit 1b8bf59

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

aten/src/ATen/native/cuda/Blas.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -1242,12 +1242,14 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
12421242
"hipblaslt rowwise _scaled_mm only supports BFloat16 output");
12431243
}
12441244
else if (scaling_choice == ScalingType::BlockWise) {
1245-
TORCH_CHECK(mat1.scalar_type() == at::kFloat8_e8m0fnu &&
1246-
mat2.scalar_type() == at::kFloat8_e8m0fnu,
1247-
"Block-wise scaling requires both matrices to be Float8_e8m0fnu type");
1245+
//todo
1246+
//TORCH_CHECK(mat1.scalar_type() == at::kFloat8_e8m0fnu &&
1247+
// mat2.scalar_type() == at::kFloat8_e8m0fnu,
1248+
// "Block-wise scaling requires both matrices to be Float8_e8m0fnu type");
12481249

12491250
#if ROCM_VERSION >= 60500
1250-
TORCH_CHECK(at::cuda::tunable::IsGfx950Device(),
1251+
//todo
1252+
TORCH_CHECK(!at::cuda::tunable::IsGfx950Device(),
12511253
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
12521254

12531255
TORCH_CHECK(mat1.size(0) % 32 == 0 && mat1.size(1) % 32 == 0 &&

0 commit comments

Comments
 (0)