13
13
ROOT_DIR = os .path .dirname (__file__ )
14
14
15
15
# Supported NVIDIA GPU architectures.
16
- SUPPORTED_ARCHS = [ "7.0" , "7.5" , "8.0" , "8.6" , "8.9" , "9.0" ]
16
+ SUPPORTED_ARCHS = { "7.0" , "7.5" , "8.0" , "8.6" , "8.9" , "9.0" }
17
17
18
18
# Compiler flags.
19
19
CXX_FLAGS = ["-g" , "-O2" , "-std=c++17" ]
@@ -49,19 +49,32 @@ def get_torch_arch_list() -> Set[str]:
49
49
# and executed on the 8.6 or newer architectures. While the PTX code will
50
50
# not give the best performance on the newer architectures, it provides
51
51
# forward compatibility.
52
- valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS ]
53
- arch_list = os .environ .get ("TORCH_CUDA_ARCH_LIST" , None )
54
- if arch_list is None :
52
+ env_arch_list = os .environ .get ("TORCH_CUDA_ARCH_LIST" , None )
53
+ if env_arch_list is None :
55
54
return set ()
56
55
57
56
# Entries in the list are separated by ';' or a space.
58
- arch_list = arch_list .replace (" " , ";" ).split (";" )
59
- for arch in arch_list :
60
- if arch not in valid_arch_strs :
61
- raise ValueError (
62
- f"Unsupported CUDA arch ({ arch } ). "
63
- f"Valid CUDA arch strings are: { valid_arch_strs } ." )
64
- return set (arch_list )
57
+ torch_arch_list = set (env_arch_list .replace (" " , ";" ).split (";" ))
58
+ if not torch_arch_list :
59
+ return set ()
60
+
61
+ # Filter out the invalid architectures and print a warning.
62
+ valid_archs = SUPPORTED_ARCHS .union ({s + "+PTX" for s in SUPPORTED_ARCHS })
63
+ arch_list = torch_arch_list .intersection (valid_archs )
64
+ # If none of the specified architectures are valid, raise an error.
65
+ if not arch_list :
66
+ raise RuntimeError (
67
+ "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
68
+ f"variable ({ env_arch_list } ) is supported. "
69
+ f"Supported CUDA architectures are: { valid_archs } ." )
70
+ invalid_arch_list = torch_arch_list - valid_archs
71
+ if invalid_arch_list :
72
+ warnings .warn (
73
+ f"Unsupported CUDA architectures ({ invalid_arch_list } ) are "
74
+ "excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
75
+ f"({ env_arch_list } ). Supported CUDA architectures are: "
76
+ f"{ valid_archs } ." )
77
+ return arch_list
65
78
66
79
67
80
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
@@ -81,7 +94,7 @@ def get_torch_arch_list() -> Set[str]:
81
94
if not compute_capabilities :
82
95
# If no GPU is specified or available, add all supported architectures
83
96
# based on the NVCC CUDA version.
84
- compute_capabilities = set ( SUPPORTED_ARCHS )
97
+ compute_capabilities = SUPPORTED_ARCHS . copy ( )
85
98
if nvcc_cuda_version < Version ("11.1" ):
86
99
compute_capabilities .remove ("8.6" )
87
100
if nvcc_cuda_version < Version ("11.8" ):
0 commit comments