
Commit d10a469

fix device handling
See #969. Setting CUDA_VISIBLE_DEVICES is the recommended way to handle DDP device IDs now (https://pytorch.org/docs/stable/generated/torch.cuda.set_device.html).
1 parent 1bef748 commit d10a469
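
As context for the diff below, here is a minimal, hedged sketch (not code from this commit) of the per-process pattern the commit message describes: each spawned worker restricts itself to a single GPU by setting CUDA_VISIBLE_DEVICES before it touches CUDA, so DistributedDataParallel can be constructed without device_ids. The worker function name, world size, and rendezvous address below are illustrative assumptions, not part of imagenet/main.py.

# Minimal sketch of per-process CUDA_VISIBLE_DEVICES handling for DDP.
# Assumptions (not from the commit): single node, one process per GPU,
# NCCL rendezvous over localhost, a toy nn.Linear model.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def demo_worker(rank, world_size):
    # Must run before any CUDA work in this process: only GPU `rank` stays
    # visible, so every subsequent .cuda() call lands on it (seen as cuda:0).
    os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
    dist.init_process_group(backend='nccl',
                            init_method='tcp://127.0.0.1:23456',
                            world_size=world_size, rank=rank)
    model = torch.nn.Linear(8, 8).cuda()
    # No device_ids needed: DDP picks up the single visible device from the module.
    model = torch.nn.parallel.DistributedDataParallel(model)
    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(demo_worker, args=(world_size,), nprocs=world_size)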

File tree

1 file changed, +24 -45 lines

Diff for: imagenet/main.py

@@ -137,6 +137,7 @@ def main_worker(gpu, ngpus_per_node, args):
             # For multiprocessing distributed training, rank needs to be the
             # global rank among all the processes
             args.rank = args.rank * ngpus_per_node + gpu
+            os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                 world_size=args.world_size, rank=args.rank)
     # create model
@@ -154,20 +155,14 @@ def main_worker(gpu, ngpus_per_node, args):
         # should always set the single device scope, otherwise,
         # DistributedDataParallel will use all available devices.
         if torch.cuda.is_available():
+            model.cuda()
             if args.gpu is not None:
-                torch.cuda.set_device(args.gpu)
-                model.cuda(args.gpu)
                 # When using a single GPU per process and per
                 # DistributedDataParallel, we need to divide the batch size
                 # ourselves based on the total number of GPUs of the current node.
                 args.batch_size = int(args.batch_size / ngpus_per_node)
                 args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
-                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
-            else:
-                model.cuda()
-                # DistributedDataParallel will divide and allocate batch_size to all
-                # available GPUs if device_ids are not set
-                model = torch.nn.parallel.DistributedDataParallel(model)
+            model = torch.nn.parallel.DistributedDataParallel(model)
     elif args.gpu is not None and torch.cuda.is_available():
         torch.cuda.set_device(args.gpu)
         model = model.cuda(args.gpu)
@@ -183,7 +178,7 @@ def main_worker(gpu, ngpus_per_node, args):
         model = torch.nn.DataParallel(model).cuda()
 
     if torch.cuda.is_available():
-        if args.gpu:
+        if args.gpu and not args.distributed:
             device = torch.device('cuda:{}'.format(args.gpu))
         else:
             device = torch.device("cuda")
@@ -205,17 +200,11 @@ def main_worker(gpu, ngpus_per_node, args):
     if args.resume:
         if os.path.isfile(args.resume):
             print("=> loading checkpoint '{}'".format(args.resume))
-            if args.gpu is None:
-                checkpoint = torch.load(args.resume)
-            elif torch.cuda.is_available():
-                # Map model to be loaded to specified single gpu.
-                loc = 'cuda:{}'.format(args.gpu)
-                checkpoint = torch.load(args.resume, map_location=loc)
+            checkpoint = torch.load(args.resume, map_location=device)
             args.start_epoch = checkpoint['epoch']
             best_acc1 = checkpoint['best_acc1']
-            if args.gpu is not None:
-                # best_acc1 may be from a checkpoint from a different GPU
-                best_acc1 = best_acc1.to(args.gpu)
+            # best_acc1 may be from a checkpoint from a different GPU
+            best_acc1 = best_acc1.to(device=device)
             model.load_state_dict(checkpoint['state_dict'])
             optimizer.load_state_dict(checkpoint['optimizer'])
             scheduler.load_state_dict(checkpoint['scheduler'])
@@ -270,7 +259,7 @@ def main_worker(gpu, ngpus_per_node, args):
         num_workers=args.workers, pin_memory=True, sampler=val_sampler)
 
     if args.evaluate:
-        validate(val_loader, model, criterion, args)
+        validate(val_loader, model, criterion, device, args)
         return
 
     for epoch in range(args.start_epoch, args.epochs):
@@ -281,7 +270,7 @@ def main_worker(gpu, ngpus_per_node, args):
         train(train_loader, model, criterion, optimizer, epoch, device, args)
 
         # evaluate on validation set
-        acc1 = validate(val_loader, model, criterion, args)
+        acc1 = validate(val_loader, model, criterion, device, args)
 
         scheduler.step()
 
@@ -302,11 +291,11 @@ def main_worker(gpu, ngpus_per_node, args):
 
 
 def train(train_loader, model, criterion, optimizer, epoch, device, args):
-    batch_time = AverageMeter('Time', ':6.3f')
-    data_time = AverageMeter('Data', ':6.3f')
-    losses = AverageMeter('Loss', ':.4e')
-    top1 = AverageMeter('Acc@1', ':6.2f')
-    top5 = AverageMeter('Acc@5', ':6.2f')
+    batch_time = AverageMeter('Time', device, ':6.3f')
+    data_time = AverageMeter('Data', device, ':6.3f')
+    losses = AverageMeter('Loss', device, ':.4e')
+    top1 = AverageMeter('Acc@1', device, ':6.2f')
+    top5 = AverageMeter('Acc@5', device, ':6.2f')
     progress = ProgressMeter(
         len(train_loader),
         [batch_time, data_time, losses, top1, top5],
@@ -347,20 +336,15 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):
             progress.display(i + 1)
 
 
-def validate(val_loader, model, criterion, args):
+def validate(val_loader, model, criterion, device, args):
 
     def run_validate(loader, base_progress=0):
         with torch.no_grad():
             end = time.time()
             for i, (images, target) in enumerate(loader):
                 i = base_progress + i
-                if args.gpu is not None and torch.cuda.is_available():
-                    images = images.cuda(args.gpu, non_blocking=True)
-                if torch.backends.mps.is_available():
-                    images = images.to('mps')
-                    target = target.to('mps')
-                if torch.cuda.is_available():
-                    target = target.cuda(args.gpu, non_blocking=True)
+                images = images.to(device, non_blocking=True)
+                target = target.to(device, non_blocking=True)
 
                 # compute output
                 output = model(images)
@@ -379,10 +363,10 @@ def run_validate(loader, base_progress=0):
                 if i % args.print_freq == 0:
                     progress.display(i + 1)
 
-    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
-    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
-    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
-    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
+    batch_time = AverageMeter('Time', device, ':6.3f', Summary.NONE)
+    losses = AverageMeter('Loss', device, ':.4e', Summary.NONE)
+    top1 = AverageMeter('Acc@1', device, ':6.2f', Summary.AVERAGE)
+    top5 = AverageMeter('Acc@5', device, ':6.2f', Summary.AVERAGE)
     progress = ProgressMeter(
         len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
         [batch_time, losses, top1, top5],
@@ -422,8 +406,9 @@ class Summary(Enum):
 
 class AverageMeter(object):
     """Computes and stores the average and current value"""
-    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
+    def __init__(self, name, device, fmt=':f', summary_type=Summary.AVERAGE):
         self.name = name
+        self.device = device
         self.fmt = fmt
         self.summary_type = summary_type
         self.reset()
@@ -441,13 +426,7 @@ def update(self, val, n=1):
         self.avg = self.sum / self.count
 
     def all_reduce(self):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
-        else:
-            device = torch.device("cpu")
-        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
+        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=self.device)
         dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
         self.sum, self.count = total.tolist()
         self.avg = self.sum / self.count
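
For reference, a self-contained sketch of the device-aware reduction the diff moves AverageMeter to: the meter keeps the device it was constructed with, and all_reduce sums (sum, count) across ranks on that device with a single collective. The trimmed class below is illustrative only; the real AverageMeter also tracks the current value, a format string, and a Summary type, and the sketch assumes the process group is already initialized.

# Trimmed, illustrative version of a device-aware meter (assumption:
# torch.distributed is already initialized, e.g. by the DDP setup shown above).
import torch
import torch.distributed as dist


class SimpleMeter:
    def __init__(self, name, device):
        self.name = name
        self.device = device   # remembered so all_reduce needs no device probing
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    def all_reduce(self):
        # Pack sum and count into one tensor so a single collective suffices.
        total = torch.tensor([self.sum, self.count],
                             dtype=torch.float32, device=self.device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        self.avg = self.sum / self.count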
