This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 6be8dd8

Merge branch 'main' into feature/torchvision-distillation-support

2 parents 735956e + da3ffb0

File tree (1 file changed: +47 −15 lines)
  • src/sparseml/pytorch/torchvision/train.py

src/sparseml/pytorch/torchvision/train.py

Lines changed: 47 additions & 15 deletions
@@ -23,7 +23,7 @@
 import warnings
 from functools import update_wrapper
 from types import SimpleNamespace
-from typing import Optional
+from typing import Callable, Optional
 
 import torch
 import torch.utils.data
@@ -62,23 +62,32 @@ def train_one_epoch(
     device: torch.device,
     epoch: int,
     args,
+    log_metrics_fn: Callable[[str, utils.MetricLogger, int, int], None],
     manager=None,
     model_ema=None,
     scaler=None,
 ) -> utils.MetricLogger:
+    accum_steps = args.gradient_accum_steps
+
     model.train()
     metric_logger = utils.MetricLogger(_LOGGER, delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
-    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
+    metric_logger.add_meter(
+        "imgs_per_sec", utils.SmoothedValue(window_size=10, fmt="{value}")
+    )
+    metric_logger.add_meter("loss", utils.SmoothedValue(window_size=accum_steps))
+    metric_logger.add_meter("acc1", utils.SmoothedValue(window_size=accum_steps))
+    metric_logger.add_meter("acc5", utils.SmoothedValue(window_size=accum_steps))
 
     steps_accumulated = 0
+    num_optim_steps = 0
 
     # initial zero grad for gradient accumulation
     optimizer.zero_grad()
 
     header = f"Epoch: [{epoch}]"
-    for i, (image, target) in enumerate(
-        metric_logger.log_every(data_loader, args.print_freq, header)
+    for (image, target) in metric_logger.log_every(
+        data_loader, args.logging_steps * accum_steps, header
     ):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
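
Note: the rewritten loop header ties the console-logging cadence to optimizer steps rather than raw micro-batches: log_every now fires every logging_steps * accum_steps micro-batches, i.e. once per logging_steps weight updates under gradient accumulation. A minimal sketch with hypothetical values (the real ones come from the CLI flags):

    # Hypothetical values; the real ones come from --logging-steps and
    # the accumulation setting read from args.gradient_accum_steps.
    logging_steps = 10
    accum_steps = 4
    micro_batches_per_log = logging_steps * accum_steps  # 40 micro-batches
    # 40 micro-batches at 4 accumulation steps each = 10 optimizer.step() calls,
    # so one console log corresponds to exactly logging_steps weight updates.
    assert micro_batches_per_log // accum_steps == logging_steps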
@@ -89,7 +98,7 @@ def train_one_epoch(
             output = output[0]
         loss = criterion(output, target)
 
-        if steps_accumulated % args.gradient_accum_steps == 0:
+        if steps_accumulated % accum_steps == 0:
             if manager is not None:
                 loss = manager.loss_update(
                     loss=loss,
@@ -119,9 +128,10 @@ def train_one_epoch(
 
             # zero grad here to start accumulating next set of gradients
             optimizer.zero_grad()
+            num_optim_steps += 1
         steps_accumulated += 1
 
-        if model_ema and i % args.model_ema_steps == 0:
+        if model_ema and num_optim_steps % args.model_ema_steps == 0:
             model_ema.update_parameters(model)
             if epoch < args.lr_warmup_epochs:
                 # Reset ema buffer to keep copying weights during warmup period
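
Note: gating the EMA update on num_optim_steps instead of the micro-batch index i makes --model-ema-steps count completed weight updates. A minimal sketch (hypothetical sizes) contrasting the two triggers:

    # Hypothetical: 8 micro-batches, accumulation of 4, --model-ema-steps 2.
    accum_steps, model_ema_steps = 4, 2
    num_optim_steps = 0
    for i in range(8):
        if (i + 1) % accum_steps == 0:
            num_optim_steps += 1  # an optimizer.step() completed
        old_trigger = i % model_ema_steps == 0                # fires on i = 0, 2, 4, 6
        new_trigger = num_optim_steps % model_ema_steps == 0  # tracks weight updates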
@@ -132,7 +142,12 @@ def train_one_epoch(
         metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
         metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
         metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
-        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
+        metric_logger.meters["imgs_per_sec"].update(
+            batch_size / (time.time() - start_time)
+        )
+
+        if num_optim_steps % args.logging_steps == 0:
+            log_metrics_fn("Train", metric_logger, epoch, num_optim_steps)
     return metric_logger
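
Note: mid-epoch metrics now flow through the injected callback once per logging_steps optimizer steps. A minimal sketch of a compatible callback matching the Callable[[str, utils.MetricLogger, int, int], None] annotation; the body here is hypothetical (the real implementation is the log_metrics closure further down in this diff):

    def log_metrics_fn(tag, metrics, epoch, epoch_step):
        # Forward each smoothed meter to whatever scalar sink is configured.
        for name, meter in metrics.meters.items():
            print(f"{tag}/{name}: {meter.global_avg} (epoch={epoch}, step={epoch_step})")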

@@ -504,10 +519,17 @@ def collate_fn(batch):
                 criterion,
                 data_loader_test,
                 device,
+                print_freq=args.logging_steps,
                 log_suffix="EMA",
             )
         else:
-            evaluate(model, criterion, data_loader_test, device)
+            evaluate(
+                model,
+                criterion,
+                data_loader_test,
+                device,
+                print_freq=args.logging_steps,
+            )
         return
 
     if utils.is_main_process():
@@ -523,10 +545,13 @@ def collate_fn(batch):
     else:
         logger = LoggerManager(log_python=False)
 
-    def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
+    steps_per_epoch = len(data_loader) / args.gradient_accum_steps
+
+    def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: int):
+        step = int(epoch * steps_per_epoch + epoch_step)
         for metric_name, smoothed_value in metrics.meters.items():
             logger.log_scalar(
-                f"{tag}/{metric_name}", smoothed_value.global_avg, step=epoch
+                f"{tag}/{metric_name}", smoothed_value.global_avg, step=step
             )
 
     if manager is not None:
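
Note: the scalar logger previously collapsed every log within an epoch onto the single value step=epoch; the new computation spreads them along a global optimizer-step axis. A worked example with hypothetical sizes:

    steps_per_epoch = 1000 / 4  # 1000 micro-batches per epoch, accumulation of 4
    epoch, epoch_step = 3, 120  # mid-epoch log at the 120th weight update
    step = int(epoch * steps_per_epoch + epoch_step)
    assert step == 870          # 3 * 250 + 120; monotonic across epochs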
@@ -537,7 +562,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
             distillation_teacher=args.distill_teacher,
         )
         optimizer = manager.modify(
-            model, optimizer, len(data_loader), epoch=args.start_epoch
+            model, optimizer, steps_per_epoch=steps_per_epoch, epoch=args.start_epoch
         )
 
     lr_scheduler = _get_lr_scheduler(
@@ -570,17 +595,18 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
             device,
             epoch,
             args,
+            log_metrics,
             manager=manager,
             model_ema=model_ema,
             scaler=scaler,
         )
-        log_metrics("Train", train_metrics, epoch)
+        log_metrics("Train", train_metrics, epoch, steps_per_epoch)
 
         if lr_scheduler:
            lr_scheduler.step()
 
        eval_metrics = evaluate(model, criterion, data_loader_test, device)
-        log_metrics("Test", eval_metrics, epoch)
+        log_metrics("Test", eval_metrics, epoch, steps_per_epoch)
 
         top1_acc = eval_metrics.acc1.global_avg
         if model_ema:
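
Note: the end-of-epoch calls pass steps_per_epoch as epoch_step, which pins each epoch-level summary at the boundary of the next epoch on the global-step axis:

    steps_per_epoch = 250.0  # hypothetical
    for epoch in range(5):
        assert int(epoch * steps_per_epoch + steps_per_epoch) == int(
            (epoch + 1) * steps_per_epoch
        )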
@@ -591,7 +617,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
                 device,
                 log_suffix="EMA",
             )
-            log_metrics("Test/EMA", ema_eval_metrics, epoch)
+            log_metrics("Test/EMA", ema_eval_metrics, epoch, steps_per_epoch)
 
         is_new_best = epoch >= args.save_best_after and top1_acc > best_top1_acc
         if is_new_best:
@@ -916,7 +942,13 @@ def new_func(*args, **kwargs):
     type=float,
     help="minimum lr of lr schedule",
 )
-@click.option("--print-freq", default=10, type=int, help="print frequency")
+@click.option("--print-freq", default=None, type=int, help="DEPRECATED. Does nothing.")
+@click.option(
+    "--logging-steps",
+    default=10,
+    type=int,
+    help="Frequency in number of batch updates for logging/printing",
+)
 @click.option("--output-dir", default=".", type=str, help="path to save outputs")
 @click.option("--resume", default=None, type=str, help="path of checkpoint")
 @click.option(
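
Note: --print-freq is retained only so existing commands keep parsing; --logging-steps takes over its role and also drives the mid-epoch scalar logging added above. A hypothetical invocation (dataset and model flags omitted):

    python src/sparseml/pytorch/torchvision/train.py --logging-steps 50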
