@@ -265,8 +265,6 @@ static bool getDisableAddmmCudaLt() {
265
265
266
266
#ifdef USE_ROCM
267
267
static bool isSupportedHipLtROCmArch (int index) {
268
- hipDeviceProp_t* prop = at::cuda::getDeviceProperties (index );
269
- std::string device_arch = prop->gcnArchName ;
270
268
static const std::vector<std::string> archs = {
271
269
" gfx90a" , " gfx942" ,
272
270
#if ROCM_VERSION >= 60300
@@ -276,13 +274,7 @@ static bool isSupportedHipLtROCmArch(int index) {
276
274
" gfx950"
277
275
#endif
278
276
};
279
- for (std::string arch : archs) {
280
- size_t substring = device_arch.find (arch);
281
- if (substring != std::string::npos) {
282
- return true ;
283
- }
284
- }
285
- return false ;
277
+ return at::detail::getCUDAHooks ().isGPUArch (archs, index );
286
278
}
287
279
#endif
288
280
@@ -939,9 +931,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
939
931
}
940
932
941
933
static bool _scaled_mm_allowed_device () {
942
- auto dprops = at::cuda::getCurrentDeviceProperties ();
943
934
#ifdef USE_ROCM
944
- std::string device_arch = dprops->gcnArchName ;
945
935
static const std::vector<std::string> archs = {
946
936
" gfx942" ,
947
937
#if ROCM_VERSION >= 60300
@@ -951,30 +941,16 @@ static bool _scaled_mm_allowed_device() {
951
941
" gfx950"
952
942
#endif
953
943
};
954
- for (std::string arch : archs) {
955
- size_t substring = device_arch.find (arch);
956
- if (substring != std::string::npos) {
957
- return true ;
958
- }
959
- }
960
- return false ;
944
+ return at::detail::getCUDAHooks ().isGPUArch (archs);
961
945
#else
946
+ auto dprops = at::cuda::getCurrentDeviceProperties ();
962
947
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9 );
963
948
#endif
964
949
}
965
950
966
951
#ifdef USE_ROCM
967
952
static bool _scaled_mm_is_fnuz () {
968
- auto dprops = at::cuda::getCurrentDeviceProperties ();
969
- std::string device_arch = dprops->gcnArchName ;
970
- static const std::vector<std::string> archs = {" gfx942" };
971
- for (std::string arch : archs) {
972
- size_t substring = device_arch.find (arch);
973
- if (substring != std::string::npos) {
974
- return true ;
975
- }
976
- }
977
- return false ;
953
+ return at::detail::getCUDAHooks ().isGPUArch ({" gfx942" });
978
954
}
979
955
#endif
980
956
0 commit comments