@@ -137,6 +137,7 @@ def main_worker(gpu, ngpus_per_node, args):
             # For multiprocessing distributed training, rank needs to be the
             # global rank among all the processes
             args.rank = args.rank * ngpus_per_node + gpu
+            os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                 world_size=args.world_size, rank=args.rank)
     # create model
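The added line pins each spawned worker to one physical GPU by setting CUDA_VISIBLE_DEVICES to its global rank before the process group is created. A minimal standalone sketch of that effect (the pin_process_to_gpu helper is illustrative, not from the patch; the variable must be set before the process makes its first CUDA call):

import os
import torch

def pin_process_to_gpu(rank):
    # Restrict this process to a single physical GPU; must happen before any
    # CUDA context is created, otherwise the setting is ignored.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)

if __name__ == '__main__':
    pin_process_to_gpu(0)
    if torch.cuda.is_available():
        print(torch.cuda.device_count())    # 1: only the pinned GPU is visible
        print(torch.cuda.current_device())  # 0: local index of that GPU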
@@ -154,20 +155,14 @@ def main_worker(gpu, ngpus_per_node, args):
         # should always set the single device scope, otherwise,
         # DistributedDataParallel will use all available devices.
         if torch.cuda.is_available():
+            model.cuda()
             if args.gpu is not None:
-                torch.cuda.set_device(args.gpu)
-                model.cuda(args.gpu)
                 # When using a single GPU per process and per
                 # DistributedDataParallel, we need to divide the batch size
                 # ourselves based on the total number of GPUs of the current node.
                 args.batch_size = int(args.batch_size / ngpus_per_node)
                 args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
-                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
-            else:
-                model.cuda()
-                # DistributedDataParallel will divide and allocate batch_size to all
-                # available GPUs if device_ids are not set
-                model = torch.nn.parallel.DistributedDataParallel(model)
+            model = torch.nn.parallel.DistributedDataParallel(model)
     elif args.gpu is not None and torch.cuda.is_available():
         torch.cuda.set_device(args.gpu)
         model = model.cuda(args.gpu)
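With only one GPU visible per process, model.cuda() lands on that GPU and DistributedDataParallel no longer needs an explicit device_ids list. A rough sketch of the wrapping step (the wrap_model helper is illustrative, not from the patch, and assumes the process group is already initialized):

import torch
import torch.nn as nn

def wrap_model(model: nn.Module) -> nn.Module:
    # Assumes dist.init_process_group(...) has already run and, in the CUDA
    # case, that CUDA_VISIBLE_DEVICES exposes exactly one GPU to this process.
    if torch.cuda.is_available():
        model.cuda()  # cuda:0 here is this process's own physical GPU
    # For a single-device (or CPU) module, device_ids can be omitted and DDP
    # uses the device the module already lives on.
    return torch.nn.parallel.DistributedDataParallel(model)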
@@ -183,7 +178,7 @@ def main_worker(gpu, ngpus_per_node, args):
         model = torch.nn.DataParallel(model).cuda()
 
     if torch.cuda.is_available():
-        if args.gpu:
+        if args.gpu and not args.distributed:
             device = torch.device('cuda:{}'.format(args.gpu))
         else:
             device = torch.device("cuda")
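The extra `not args.distributed` check keeps distributed workers on a plain "cuda" device, since each worker already sees only its own GPU; an explicit cuda:N index is only meaningful in the single-process case. A condensed sketch of the selection logic (the select_device helper is illustrative; the MPS/CPU fallbacks mirror the surrounding file):

import torch

def select_device(gpu, distributed):
    if torch.cuda.is_available():
        if gpu and not distributed:
            # Single-process run with an explicit --gpu index.
            return torch.device('cuda:{}'.format(gpu))
        # Distributed worker: "cuda" already resolves to its one visible GPU.
        return torch.device('cuda')
    elif torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')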
@@ -205,17 +200,11 @@ def main_worker(gpu, ngpus_per_node, args):
     if args.resume:
         if os.path.isfile(args.resume):
             print("=> loading checkpoint '{}'".format(args.resume))
-            if args.gpu is None:
-                checkpoint = torch.load(args.resume)
-            elif torch.cuda.is_available():
-                # Map model to be loaded to specified single gpu.
-                loc = 'cuda:{}'.format(args.gpu)
-                checkpoint = torch.load(args.resume, map_location=loc)
+            checkpoint = torch.load(args.resume, map_location=device)
             args.start_epoch = checkpoint['epoch']
             best_acc1 = checkpoint['best_acc1']
-            if args.gpu is not None:
-                # best_acc1 may be from a checkpoint from a different GPU
-                best_acc1 = best_acc1.to(args.gpu)
+            # best_acc1 may be from a checkpoint from a different GPU
+            best_acc1 = best_acc1.to(device=device)
             model.load_state_dict(checkpoint['state_dict'])
             optimizer.load_state_dict(checkpoint['optimizer'])
             scheduler.load_state_dict(checkpoint['scheduler'])
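Loading with map_location=device moves every tensor in the checkpoint straight onto the worker's own device, replacing the per-GPU branching removed above. A small sketch of the device-agnostic resume pattern (the load_checkpoint helper is illustrative; the checkpoint keys are the ones used in the patch):

import torch

def load_checkpoint(path, model, optimizer, scheduler, device):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    # best_acc1 may have been saved on a different device.
    best_acc1 = checkpoint['best_acc1'].to(device=device)
    return checkpoint['epoch'], best_acc1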
@@ -270,7 +259,7 @@ def main_worker(gpu, ngpus_per_node, args):
         num_workers=args.workers, pin_memory=True, sampler=val_sampler)
 
     if args.evaluate:
-        validate(val_loader, model, criterion, args)
+        validate(val_loader, model, criterion, device, args)
         return
 
     for epoch in range(args.start_epoch, args.epochs):
@@ -281,7 +270,7 @@ def main_worker(gpu, ngpus_per_node, args):
         train(train_loader, model, criterion, optimizer, epoch, device, args)
 
         # evaluate on validation set
-        acc1 = validate(val_loader, model, criterion, args)
+        acc1 = validate(val_loader, model, criterion, device, args)
 
         scheduler.step()
 
@@ -302,11 +291,11 @@ def main_worker(gpu, ngpus_per_node, args):
 
 
 def train(train_loader, model, criterion, optimizer, epoch, device, args):
-    batch_time = AverageMeter('Time', ':6.3f')
-    data_time = AverageMeter('Data', ':6.3f')
-    losses = AverageMeter('Loss', ':.4e')
-    top1 = AverageMeter('Acc@1', ':6.2f')
-    top5 = AverageMeter('Acc@5', ':6.2f')
+    batch_time = AverageMeter('Time', device, ':6.3f')
+    data_time = AverageMeter('Data', device, ':6.3f')
+    losses = AverageMeter('Loss', device, ':.4e')
+    top1 = AverageMeter('Acc@1', device, ':6.2f')
+    top5 = AverageMeter('Acc@5', device, ':6.2f')
     progress = ProgressMeter(
         len(train_loader),
         [batch_time, data_time, losses, top1, top5],
@@ -347,20 +336,15 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):
             progress.display(i + 1)
 
 
-def validate(val_loader, model, criterion, args):
+def validate(val_loader, model, criterion, device, args):
 
     def run_validate(loader, base_progress=0):
         with torch.no_grad():
             end = time.time()
             for i, (images, target) in enumerate(loader):
                 i = base_progress + i
-                if args.gpu is not None and torch.cuda.is_available():
-                    images = images.cuda(args.gpu, non_blocking=True)
-                if torch.backends.mps.is_available():
-                    images = images.to('mps')
-                    target = target.to('mps')
-                if torch.cuda.is_available():
-                    target = target.cuda(args.gpu, non_blocking=True)
+                images = images.to(device, non_blocking=True)
+                target = target.to(device, non_blocking=True)
 
                 # compute output
                 output = model(images)
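Replacing the per-backend branches with a single .to(device, non_blocking=True) leaves one transfer path for CUDA, MPS, and CPU. A minimal sketch (the move_batch helper is illustrative); non_blocking only overlaps the copy with compute when the loader supplies pinned host memory and the target is a CUDA device, and is otherwise harmless:

import torch

def move_batch(images, target, device):
    # One code path for every backend; pair with pin_memory=True in the
    # DataLoader so the CUDA copies can run asynchronously.
    images = images.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)
    return images, target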
@@ -379,10 +363,10 @@ def run_validate(loader, base_progress=0):
                 if i % args.print_freq == 0:
                     progress.display(i + 1)
 
-    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
-    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
-    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
-    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
+    batch_time = AverageMeter('Time', device, ':6.3f', Summary.NONE)
+    losses = AverageMeter('Loss', device, ':.4e', Summary.NONE)
+    top1 = AverageMeter('Acc@1', device, ':6.2f', Summary.AVERAGE)
+    top5 = AverageMeter('Acc@5', device, ':6.2f', Summary.AVERAGE)
     progress = ProgressMeter(
         len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
         [batch_time, losses, top1, top5],
@@ -422,8 +406,9 @@ class Summary(Enum):
 
 class AverageMeter(object):
     """Computes and stores the average and current value"""
-    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
+    def __init__(self, name, device, fmt=':f', summary_type=Summary.AVERAGE):
         self.name = name
+        self.device = device
         self.fmt = fmt
         self.summary_type = summary_type
         self.reset()
@@ -441,13 +426,7 @@ def update(self, val, n=1):
         self.avg = self.sum / self.count
 
     def all_reduce(self):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
-        else:
-            device = torch.device("cpu")
-        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
+        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=self.device)
         dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
         self.sum, self.count = total.tolist()
         self.avg = self.sum / self.count
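Storing the device on the meter lets all_reduce build its [sum, count] tensor on the right backend directly (NCCL, for example, expects CUDA tensors). A standalone sketch of the same aggregation (the global_average helper is illustrative, not from the patch, and assumes an initialized process group):

import torch
import torch.distributed as dist

def global_average(local_sum, local_count, device):
    # Pack the local statistics, sum them across all ranks, then recompute
    # the average from the global totals.
    total = torch.tensor([local_sum, local_count], dtype=torch.float32, device=device)
    dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
    s, c = total.tolist()
    return s / c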