Commit 27fd2f3

Merge pull request #2181 from huggingface/Delaunay-dist-backend
Delaunay dist backend flag
2 parents cd0e7b1 + e57625e commit 27fd2f3

File tree

1 file changed: +11 -7 lines changed

Diff for: timm/utils/distributed.py (+11 -7)
@@ -108,9 +108,16 @@ def init_distributed_device_so(
     world_size = 1
     global_rank = 0
     local_rank = 0
+    device_type, *device_idx = device.split(':', maxsplit=1)
+
     if dist_backend is None:
-        # FIXME sane defaults for other device backends?
-        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
+        # FIXME: verify that ROCm transform nccl to rccl
+        dist_backends = {
+            "xpu": "ccl",
+            "hpu": "hccl",
+            "cuda": "nccl",
+        }
+        dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'

     # TBD, support horovod?
@@ -150,18 +157,15 @@ def init_distributed_device_so(
         global_rank = torch.distributed.get_rank()
         distributed = True

-    if 'cuda' in device:
+    if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'

     if distributed and device != 'cpu':
-        device, *device_idx = device.split(':', maxsplit=1)
-
         # Ignore manually specified device index in distributed mode and
         # override with resolved local rank, fewer headaches in most setups.
         if device_idx:
             _logger.warning(f'device index {device_idx[0]} removed from specified ({device}).')
-
-        device = f'{device}:{local_rank}'
+        device = f'{device_type}:{local_rank}'

     if device.startswith('cuda:'):
         torch.cuda.set_device(device)
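
For reference, the backend-selection logic this change introduces can be read as a small standalone helper. The sketch below is not part of timm; pick_dist_backend is a hypothetical name used only to illustrate how the new device_type split and the backend mapping behave, assuming device strings of the form 'cuda', 'cuda:0', 'xpu', 'hpu', or 'cpu'.

# Minimal standalone sketch (hypothetical helper, not timm API) of the
# backend selection added in this commit.
def pick_dist_backend(device, dist_backend=None):
    # 'cuda:0' -> ('cuda', ['0']); bare 'cuda' -> ('cuda', [])
    device_type, *device_idx = device.split(':', maxsplit=1)
    if dist_backend is None:
        # Native torch.distributed backend per accelerator type;
        # anything unrecognized falls back to 'gloo'.
        dist_backends = {
            'xpu': 'ccl',
            'hpu': 'hccl',
            'cuda': 'nccl',
        }
        dist_backend = dist_backends.get(device_type, 'gloo')
    return device_type, device_idx, dist_backend

print(pick_dist_backend('cuda:1'))  # ('cuda', ['1'], 'nccl')
print(pick_dist_backend('xpu'))     # ('xpu', [], 'ccl')
print(pick_dist_backend('cpu'))     # ('cpu', [], 'gloo')

Compared with the old `'nccl' if 'cuda' in device else 'gloo'` expression, the mapping also covers Intel XPU and Habana HPU devices, and splitting the device string once up front lets the later distributed branch reuse device_type instead of re-parsing device.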
