Skip to content

Commit e6e07ec

Browse files
apakbin authored and pytorchmergebot committed
[ROCm] code cleanup of architecture checks (pytorch#150473)
This PR replaces several calls to `at::cuda::getCurrentDeviceProperties()->gcnArchName` and `at::cuda::getDeviceProperties(device_index)->gcnArchName` when checking whether the GPU architecture is in a certain list.

Pull Request resolved: pytorch#150473
Approved by: https://github.com/jeffdaily, https://github.com/cyyever
1 parent 9e10601 commit e6e07ec

File tree

9 files changed

+21
-54
lines changed

9 files changed

+21
-54
lines changed

aten/src/ATen/Context.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ at::BlasBackend Context::blasPreferredBackend() {
340340
#endif
341341
};
342342
for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
343-
if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
343+
if (!detail::getCUDAHooks().isGPUArch(archs, index)) {
344344
return false;
345345
}
346346
}
@@ -366,7 +366,7 @@ at::BlasBackend Context::blasPreferredBackend() {
366366
#endif
367367
};
368368
for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
369-
if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
369+
if (!detail::getCUDAHooks().isGPUArch(archs, index)) {
370370
TORCH_WARN_ONCE(
371371
"Attempting to use hipBLASLt on an unsupported architecture! "
372372
"Overriding blas backend to hipblas");
@@ -419,7 +419,7 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
419419
"gfx90a", "gfx942"
420420
};
421421
for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
422-
if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
422+
if (!detail::getCUDAHooks().isGPUArch(archs, index)) {
423423
TORCH_WARN_ONCE(
424424
"Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
425425
return true;

aten/src/ATen/cuda/CUDABlas.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -1085,9 +1085,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
10851085
}
10861086
#if defined(USE_ROCM) && !defined(_MSC_VER)
10871087
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
1088-
auto dprops = at::cuda::getCurrentDeviceProperties();
1089-
c10::string_view arch(dprops->gcnArchName);
1090-
if (arch == "gfx1100") { //no CK GEMM version for gfx1100
1088+
if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100
10911089
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
10921090
} else{
10931091
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));

aten/src/ATen/cuda/CublasHandlePool.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,7 @@ size_t parseChosenWorkspaceSize() {
124124
val = getenv("ROCBLAS_WORKSPACE_CONFIG");
125125
}
126126
/* 32MiB default, 128MiB for MI300 */
127-
cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
128-
std::string device_arch = properties->gcnArchName;
129-
const bool gfx94 = device_arch.find("gfx94") != std::string::npos;
127+
const bool gfx94 = at::detail::getCUDAHooks().isGPUArch({"gfx94"});
130128
const size_t default_size = gfx94 ? 1024 * 128 * 1024 : 1024 * 32 * 1024;
131129
#else
132130
/* :4096:2:16:8 default, 32MiB for Hopper */

aten/src/ATen/cuda/detail/CUDAHooks.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -448,8 +448,14 @@ DeviceIndex CUDAHooks::getCurrentDevice() const {
448448
}
449449

450450
#ifdef USE_ROCM
451-
bool CUDAHooks::isGPUArch(DeviceIndex device_index, const std::vector<std::string>& archs) const {
452-
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(device_index);
451+
bool CUDAHooks::isGPUArch(const std::vector<std::string>& archs, DeviceIndex device_index) const {
452+
hipDeviceProp_t* prop;
453+
if (device_index == -1){
454+
prop = at::cuda::getCurrentDeviceProperties();
455+
} else {
456+
prop = at::cuda::getDeviceProperties(device_index);
457+
}
458+
453459
std::string device_arch = prop->gcnArchName;
454460
for (std::string arch : archs) {
455461
size_t substring = device_arch.find(arch);

aten/src/ATen/cuda/detail/CUDAHooks.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
5757
DeviceIndex getCurrentDevice() const override;
5858

5959
#ifdef USE_ROCM
60-
bool isGPUArch(DeviceIndex device_index, const std::vector<std::string>& archs) const override;
60+
bool isGPUArch(const std::vector<std::string>& archs, DeviceIndex device_index = -1) const override;
6161
#endif
6262
void deviceSynchronize(DeviceIndex device_index) const override;
6363
};

aten/src/ATen/detail/CUDAHooksInterface.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
196196
}
197197

198198
#ifdef USE_ROCM
199-
virtual bool isGPUArch(DeviceIndex /*device_index*/, const std::vector<std::string>& /*archs*/) const {
199+
virtual bool isGPUArch(const std::vector<std::string>& /*archs*/, DeviceIndex = -1 /*device_index*/) const {
200200
TORCH_CHECK(false, "Cannot check GPU arch without ATen_cuda library. ", CUDA_HELP);
201201
}
202202
#endif

aten/src/ATen/native/cuda/Blas.cpp

+4-28
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,6 @@ static bool getDisableAddmmCudaLt() {
265265

266266
#ifdef USE_ROCM
267267
static bool isSupportedHipLtROCmArch(int index) {
268-
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
269-
std::string device_arch = prop->gcnArchName;
270268
static const std::vector<std::string> archs = {
271269
"gfx90a", "gfx942",
272270
#if ROCM_VERSION >= 60300
@@ -276,13 +274,7 @@ static bool isSupportedHipLtROCmArch(int index) {
276274
"gfx950"
277275
#endif
278276
};
279-
for (std::string arch : archs) {
280-
size_t substring = device_arch.find(arch);
281-
if (substring != std::string::npos) {
282-
return true;
283-
}
284-
}
285-
return false;
277+
return at::detail::getCUDAHooks().isGPUArch(archs, index);
286278
}
287279
#endif
288280

@@ -939,9 +931,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
939931
}
940932

941933
static bool _scaled_mm_allowed_device() {
942-
auto dprops = at::cuda::getCurrentDeviceProperties();
943934
#ifdef USE_ROCM
944-
std::string device_arch = dprops->gcnArchName;
945935
static const std::vector<std::string> archs = {
946936
"gfx942",
947937
#if ROCM_VERSION >= 60300
@@ -951,30 +941,16 @@ static bool _scaled_mm_allowed_device() {
951941
"gfx950"
952942
#endif
953943
};
954-
for (std::string arch : archs) {
955-
size_t substring = device_arch.find(arch);
956-
if (substring != std::string::npos) {
957-
return true;
958-
}
959-
}
960-
return false;
944+
return at::detail::getCUDAHooks().isGPUArch(archs);
961945
#else
946+
auto dprops = at::cuda::getCurrentDeviceProperties();
962947
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
963948
#endif
964949
}
965950

966951
#ifdef USE_ROCM
967952
static bool _scaled_mm_is_fnuz() {
968-
auto dprops = at::cuda::getCurrentDeviceProperties();
969-
std::string device_arch = dprops->gcnArchName;
970-
static const std::vector<std::string> archs = {"gfx942"};
971-
for (std::string arch : archs) {
972-
size_t substring = device_arch.find(arch);
973-
if (substring != std::string::npos) {
974-
return true;
975-
}
976-
}
977-
return false;
953+
return at::detail::getCUDAHooks().isGPUArch({"gfx942"});
978954
}
979955
#endif
980956

aten/src/ATen/native/cuda/int4mm.cu

+1-10
Original file line numberDiff line numberDiff line change
@@ -135,16 +135,7 @@ template<typename T, uint32_t Rank>
135135
using VecT = T __attribute__((ext_vector_type(Rank)));
136136

137137
static bool isCDNA2orLater(int index) {
138-
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
139-
std::string device_arch = prop->gcnArchName;
140-
static const std::vector<std::string> archs = {"gfx90a", "gfx942"};
141-
for (std::string arch : archs) {
142-
size_t substring = device_arch.find(arch);
143-
if (substring != std::string::npos) {
144-
return true;
145-
}
146-
}
147-
return false;
138+
return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942"}, index);
148139
}
149140

150141
#else

aten/src/ATen/native/hip/ck_gemm_half.hip

+1-3
Original file line numberDiff line numberDiff line change
@@ -598,9 +598,7 @@ void dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
598598

599599
template <>
600600
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
601-
auto dprops = at::cuda::getCurrentDeviceProperties();
602-
c10::string_view arch(dprops->gcnArchName);
603-
if (arch == "gfx1100") {
601+
if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) {
604602
dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGS(at::Half));
605603
} else{
606604
dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half));

0 commit comments

Comments (0)