Skip to content

Commit d0740df

Browse files
Fix error message on TORCH_CUDA_ARCH_LIST (#1239)
Co-authored-by: Yunfeng Bai <[email protected]>
1 parent de89472 commit d0740df

File tree

1 file changed

+25
-12
lines changed

1 file changed

+25
-12
lines changed

setup.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
ROOT_DIR = os.path.dirname(__file__)
1414

1515
# Supported NVIDIA GPU architectures.
16-
SUPPORTED_ARCHS = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]
16+
SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
1717

1818
# Compiler flags.
1919
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
@@ -49,19 +49,32 @@ def get_torch_arch_list() -> Set[str]:
4949
# and executed on the 8.6 or newer architectures. While the PTX code will
5050
# not give the best performance on the newer architectures, it provides
5151
# forward compatibility.
52-
valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS]
53-
arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
54-
if arch_list is None:
52+
env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
53+
if env_arch_list is None:
5554
return set()
5655

5756
# List are separated by ; or space.
58-
arch_list = arch_list.replace(" ", ";").split(";")
59-
for arch in arch_list:
60-
if arch not in valid_arch_strs:
61-
raise ValueError(
62-
f"Unsupported CUDA arch ({arch}). "
63-
f"Valid CUDA arch strings are: {valid_arch_strs}.")
64-
return set(arch_list)
57+
torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
58+
if not torch_arch_list:
59+
return set()
60+
61+
# Filter out the invalid architectures and print a warning.
62+
valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
63+
arch_list = torch_arch_list.intersection(valid_archs)
64+
# If none of the specified architectures are valid, raise an error.
65+
if not arch_list:
66+
raise RuntimeError(
67+
"None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
68+
f"variable ({env_arch_list}) is supported. "
69+
f"Supported CUDA architectures are: {valid_archs}.")
70+
invalid_arch_list = torch_arch_list - valid_archs
71+
if invalid_arch_list:
72+
warnings.warn(
73+
f"Unsupported CUDA architectures ({invalid_arch_list}) are "
74+
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
75+
f"({env_arch_list}). Supported CUDA architectures are: "
76+
f"{valid_archs}.")
77+
return arch_list
6578

6679

6780
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
@@ -81,7 +94,7 @@ def get_torch_arch_list() -> Set[str]:
8194
if not compute_capabilities:
8295
# If no GPU is specified nor available, add all supported architectures
8396
# based on the NVCC CUDA version.
84-
compute_capabilities = set(SUPPORTED_ARCHS)
97+
compute_capabilities = SUPPORTED_ARCHS.copy()
8598
if nvcc_cuda_version < Version("11.1"):
8699
compute_capabilities.remove("8.6")
87100
if nvcc_cuda_version < Version("11.8"):

0 commit comments

Comments
 (0)