We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a264d06 commit 5219a2bCopy full SHA for 5219a2b
torch/testing/_internal/common_cuda.py
@@ -105,10 +105,11 @@ def evaluate_platform_supports_fp8():
105
PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
106
107
def _platform_supports_mx_gemm():
108
- if TEST_CUDA:
109
- return SM100OrLater
110
- if TEST_WITH_ROCM:
111
- return torch.cuda.get_device_properties(torch.cuda.current_device(0)).name.startswith('gfx950')
+ if torch.cuda.is_available():
+ if torch.version.hip:
+ return 'gfx95' in torch.cuda.get_device_properties(0).gcnArchName
+ else:
112
+ return SM100OrLater
113
return False
114
115
PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: _platform_supports_mx_gemm())
0 commit comments