1 file changed: +5 −6 lines changed

@@ -108,14 +108,16 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
+
     if dist_backend is None:
         # FIXME: verify that ROCm transform nccl to rccl
         dist_backends = {
             "xpu": "ccl",
             "hpu": "hccl",
             "cuda": "nccl",
         }
-        dist_backend = dist_backends.get(device, 'gloo')
+        dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'

     # TBD, support horovod?
@@ -155,18 +157,15 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True

-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'

     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
             _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
-
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'

     if device.startswith('cuda:'):
         torch.cuda.set_device(device)
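
Note (not part of the diff itself): the change parses the device string once at the top into device_type and an optional device_idx, so the backend lookup keys on the type, and in distributed mode any manually supplied index is dropped and the device is rebuilt as f'{device_type}:{local_rank}'. A minimal sketch of how that unpacking behaves, with hypothetical device strings:

# Minimal sketch, not part of the PR: illustrates how the new
# device_type / device_idx parsing behaves for typical inputs.
for device in ('cuda', 'cuda:1', 'xpu:0', 'cpu'):
    device_type, *device_idx = device.split(':', maxsplit=1)
    print(f'{device!r} -> type={device_type!r}, idx={device_idx}')
# 'cuda'   -> type='cuda', idx=[]
# 'cuda:1' -> type='cuda', idx=['1']
# 'xpu:0'  -> type='xpu',  idx=['0']
# 'cpu'    -> type='cpu',  idx=[]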