1 file changed, +11 -7 lines changed

@@ -108,9 +108,16 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
+
     if dist_backend is None:
-        # FIXME sane defaults for other device backends?
-        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
+        # FIXME: verify that ROCm transforms nccl to rccl
+        dist_backends = {
+            "xpu": "ccl",
+            "hpu": "hccl",
+            "cuda": "nccl",
+        }
+        dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'

     # TBD, support horovod?
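
For reference, a minimal runnable sketch of the backend selection this hunk introduces, taking only the device string as input. The standalone helper name resolve_dist_backend is hypothetical; the mapping mirrors the diff, and the ROCm note restates the diff's own FIXME:

def resolve_dist_backend(device: str, dist_backend=None):
    # Split 'cuda:1' into type 'cuda' and index ['1']; the index list may be empty.
    device_type, *device_idx = device.split(':', maxsplit=1)
    if dist_backend is None:
        dist_backends = {
            "xpu": "ccl",    # Intel oneCCL
            "hpu": "hccl",   # Habana HCCL
            "cuda": "nccl",  # NVIDIA NCCL (the diff's FIXME asks to verify ROCm maps nccl to rccl)
        }
        dist_backend = dist_backends.get(device_type, 'gloo')
    return device_type, device_idx, dist_backend

assert resolve_dist_backend('cuda:1') == ('cuda', ['1'], 'nccl')
assert resolve_dist_backend('xpu') == ('xpu', [], 'ccl')
assert resolve_dist_backend('cpu') == ('cpu', [], 'gloo')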
@@ -150,18 +157,15 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True

-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'

     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
             _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
-
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'

     if device.startswith('cuda:'):
         torch.cuda.set_device(device)
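
Similarly, a hedged sketch of how the distributed branch now resolves the final device string. The helper name and signature are hypothetical; local_rank is assumed to have been resolved from the process environment, and device_type/device_idx come from the split moved to the top of the function in the first hunk:

import logging

_logger = logging.getLogger(__name__)

def resolve_device(device: str, distributed: bool, local_rank: int) -> str:
    device_type, *device_idx = device.split(':', maxsplit=1)
    if distributed and device != 'cpu':
        # Ignore a manually specified device index in distributed mode and
        # override it with the resolved local rank.
        if device_idx:
            _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
        device = f'{device_type}:{local_rank}'
    return device

assert resolve_device('cuda:3', distributed=True, local_rank=0) == 'cuda:0'
assert resolve_device('cuda', distributed=False, local_rank=5) == 'cuda'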