Skip to content

Commit 5219a2b

Browse files
committed
Refine _platform_supports_mx_gemm check
Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent a264d06 commit 5219a2b

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

torch/testing/_internal/common_cuda.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,11 @@ def evaluate_platform_supports_fp8():
105105
PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
106106

107107
def _platform_supports_mx_gemm():
108-
if TEST_CUDA:
109-
return SM100OrLater
110-
if TEST_WITH_ROCM:
111-
return torch.cuda.get_device_properties(torch.cuda.current_device(0)).name.startswith('gfx950')
108+
if torch.cuda.is_available():
109+
if torch.version.hip:
110+
return 'gfx95' in torch.cuda.get_device_properties(0).gcnArchName
111+
else:
112+
return SM100OrLater
112113
return False
113114

114115
PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: _platform_supports_mx_gemm())

0 commit comments

Comments
 (0)