Commit e57625e

Tweak dist_backend to use device_type (before possible :)
1 parent 6ca9257 commit e57625e

1 file changed (+5, -6 lines)

Diff for: timm/utils/distributed.py

@@ -108,14 +108,16 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
+
     if dist_backend is None:
         # FIXME: verify that ROCm transform nccl to rccl
         dist_backends = {
             "xpu": "ccl",
             "hpu": "hccl",
             "cuda": "nccl",
         }
-        dist_backend = dist_backends.get(device, 'gloo')
+        dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
 
     # TBD, support horovod?
@@ -155,18 +157,15 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True
 
-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
             _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
-
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'
 
     if device.startswith('cuda:'):
         torch.cuda.set_device(device)

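The effect of the change, as a minimal standalone sketch (not part of the commit): the old lookup keyed the backend table on the full device string, so a device with an index such as 'cuda:1' missed the table and fell back to 'gloo'. Splitting off the part before a possible ':' first makes the lookup hit 'nccl' as intended, and keeps the parsed device_idx available for the later distributed-mode override.

    # Minimal sketch, not part of the commit: backend lookup before vs. after the change.
    dist_backends = {"xpu": "ccl", "hpu": "hccl", "cuda": "nccl"}

    device = "cuda:1"
    device_type, *device_idx = device.split(':', maxsplit=1)

    old_backend = dist_backends.get(device, 'gloo')       # 'gloo' -- full string 'cuda:1' is not a key
    new_backend = dist_backends.get(device_type, 'gloo')  # 'nccl' -- keyed on 'cuda', the part before ':'

    print(device_type, device_idx, old_backend, new_backend)  # cuda ['1'] gloo nccl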