Skip to content

Commit ad89b50

Browse files
authored
fixed evaluate to match train function
1 parent cdef4d4 commit ad89b50

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

Diff for: imagenet/main.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def main_worker(gpu, ngpus_per_node, args):
270270
num_workers=args.workers, pin_memory=True, sampler=val_sampler)
271271

272272
if args.evaluate:
273-
validate(val_loader, model, criterion, args)
273+
validate(val_loader, model, criterion, device, args)
274274
return
275275

276276
for epoch in range(args.start_epoch, args.epochs):
@@ -281,7 +281,7 @@ def main_worker(gpu, ngpus_per_node, args):
281281
train(train_loader, model, criterion, optimizer, epoch, device, args)
282282

283283
# evaluate on validation set
284-
acc1 = validate(val_loader, model, criterion, args)
284+
acc1 = validate(val_loader, model, criterion, device, args)
285285

286286
scheduler.step()
287287

@@ -347,21 +347,15 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):
347347
progress.display(i + 1)
348348

349349

350-
def validate(val_loader, model, criterion, args):
350+
def validate(val_loader, model, criterion, device, args):
351351

352352
def run_validate(loader, base_progress=0):
353353
with torch.no_grad():
354354
end = time.time()
355355
for i, (images, target) in enumerate(loader):
356356
i = base_progress + i
357-
if args.gpu is not None and torch.cuda.is_available():
358-
images = images.cuda(args.gpu, non_blocking=True)
359-
if torch.backends.mps.is_available():
360-
images = images.to('mps')
361-
target = target.to('mps')
362-
if torch.cuda.is_available():
363-
target = target.cuda(args.gpu, non_blocking=True)
364-
357+
images = images.to(device, non_blocking=True)
358+
target = target.to(device, non_blocking=True)
365359
# compute output
366360
output = model(images)
367361
loss = criterion(output, target)

0 commit comments

Comments
 (0)