@@ -270,7 +270,7 @@ def main_worker(gpu, ngpus_per_node, args):
270
270
num_workers = args .workers , pin_memory = True , sampler = val_sampler )
271
271
272
272
if args .evaluate :
273
- validate (val_loader , model , criterion , args )
273
+ validate (val_loader , model , criterion , device , args )
274
274
return
275
275
276
276
for epoch in range (args .start_epoch , args .epochs ):
@@ -281,7 +281,7 @@ def main_worker(gpu, ngpus_per_node, args):
281
281
train (train_loader , model , criterion , optimizer , epoch , device , args )
282
282
283
283
# evaluate on validation set
284
- acc1 = validate (val_loader , model , criterion , args )
284
+ acc1 = validate (val_loader , model , criterion , device , args )
285
285
286
286
scheduler .step ()
287
287
@@ -347,21 +347,15 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):
347
347
progress .display (i + 1 )
348
348
349
349
350
- def validate (val_loader , model , criterion , args ):
350
+ def validate (val_loader , model , criterion , device , args ):
351
351
352
352
def run_validate (loader , base_progress = 0 ):
353
353
with torch .no_grad ():
354
354
end = time .time ()
355
355
for i , (images , target ) in enumerate (loader ):
356
356
i = base_progress + i
357
- if args .gpu is not None and torch .cuda .is_available ():
358
- images = images .cuda (args .gpu , non_blocking = True )
359
- if torch .backends .mps .is_available ():
360
- images = images .to ('mps' )
361
- target = target .to ('mps' )
362
- if torch .cuda .is_available ():
363
- target = target .cuda (args .gpu , non_blocking = True )
364
-
357
+ images = images .to (device , non_blocking = True )
358
+ target = target .to (device , non_blocking = True )
365
359
# compute output
366
360
output = model (images )
367
361
loss = criterion (output , target )
0 commit comments