This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit adb30a0

rahul-tuli and corey-nm authored

Distillation support for torchvision script (#1310)

* Add support for `self` distillation and `disable`
* Pull out model creation into a method
* Add support to distill with another model
* Add modifier loss update before backward pass
* bugfix, set loss
* Update src/sparseml/pytorch/torchvision/train.py

Co-authored-by: corey-nm <[email protected]>

1 parent da3ffb0 · commit adb30a0

File tree

  • src/sparseml/pytorch/torchvision

1 file changed: +103 −23 lines changed

src/sparseml/pytorch/torchvision/train.py

Lines changed: 103 additions & 23 deletions
@@ -23,7 +23,7 @@
 import warnings
 from functools import update_wrapper
 from types import SimpleNamespace
-from typing import Callable
+from typing import Callable, Optional
 
 import torch
 import torch.utils.data
@@ -63,6 +63,7 @@ def train_one_epoch(
     epoch: int,
     args,
     log_metrics_fn: Callable[[str, utils.MetricLogger, int, int], None],
+    manager=None,
     model_ema=None,
     scaler=None,
 ) -> utils.MetricLogger:
@@ -91,13 +92,24 @@ def train_one_epoch(
         start_time = time.time()
         image, target = image.to(device), target.to(device)
         with torch.cuda.amp.autocast(enabled=scaler is not None):
-            output = model(image)
+            outputs = output = model(image)
             if isinstance(output, tuple):
                 # NOTE: sparseml models return two things (logits & probs)
                 output = output[0]
             loss = criterion(output, target)
 
         if steps_accumulated % accum_steps == 0:
+            if manager is not None:
+                loss = manager.loss_update(
+                    loss=loss,
+                    module=model,
+                    optimizer=optimizer,
+                    epoch=epoch,
+                    steps_per_epoch=len(data_loader) / accum_steps,
+                    student_outputs=outputs,
+                    student_inputs=image,
+                )
+
             # first: do training to consume gradients
             if scaler is not None:
                 scaler.scale(loss).backward()
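
For readers skimming the hook above: manager.loss_update gives the recipe's modifiers a chance to rewrite the training loss before the backward pass (per the commit message, "Add modifier loss update before backward pass"). In a distillation run this typically means blending the criterion loss with a soft-target term computed from teacher and student logits. The snippet below is a minimal conceptual sketch of that blend, not SparseML's implementation; the distillation_loss, temperature, and hardness names are illustrative assumptions.

import torch.nn.functional as F

def distillation_loss(task_loss, student_logits, teacher_logits, temperature=2.0, hardness=0.5):
    # Soften both distributions with the temperature, then measure the gap with KL divergence.
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    kd_term = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * temperature**2
    # Blend the original task loss with the distillation term.
    return (1.0 - hardness) * task_loss + hardness * kd_term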
@@ -348,27 +360,28 @@ def collate_fn(batch):
     )
 
     _LOGGER.info("Creating model")
-    if args.arch_key in ModelRegistry.available_keys():
-        with torch_distributed_zero_first(args.rank if args.distributed else None):
-            model = ModelRegistry.create(
-                key=args.arch_key,
-                pretrained=args.pretrained,
-                pretrained_path=args.checkpoint_path,
-                pretrained_dataset=args.pretrained_dataset,
-                num_classes=num_classes,
-            )
-    elif args.arch_key in torchvision.models.__dict__:
-        # fall back to torchvision
-        model = torchvision.models.__dict__[args.arch_key](
-            pretrained=args.pretrained, num_classes=num_classes
-        )
-        if args.checkpoint_path is not None:
-            load_model(args.checkpoint_path, model, strict=True)
-    else:
-        raise ValueError(
-            f"Unable to find {args.arch_key} in ModelRegistry or in torchvision.models"
+    local_rank = args.rank if args.distributed else None
+    model = _create_model(
+        arch_key=args.arch_key,
+        local_rank=local_rank,
+        pretrained=args.pretrained,
+        checkpoint_path=args.checkpoint_path,
+        pretrained_dataset=args.pretrained_dataset,
+        device=device,
+        num_classes=num_classes,
+    )
+
+    if args.distill_teacher not in ["self", "disable", None]:
+        _LOGGER.info("Instantiating teacher")
+        args.distill_teacher = _create_model(
+            arch_key=args.teacher_arch_key,
+            local_rank=local_rank,
+            pretrained=True,  # teacher is always pretrained
+            pretrained_dataset=args.pretrained_teacher_dataset,
+            checkpoint_path=args.distill_teacher,
+            device=device,
+            num_classes=num_classes,
         )
-    model.to(device)
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -542,7 +555,12 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
     )
 
     if manager is not None:
-        manager.initialize(model, epoch=args.start_epoch, loggers=logger)
+        manager.initialize(
+            model,
+            epoch=args.start_epoch,
+            loggers=logger,
+            distillation_teacher=args.distill_teacher,
+        )
         optimizer = manager.modify(
             model, optimizer, steps_per_epoch=steps_per_epoch, epoch=args.start_epoch
         )
@@ -578,6 +596,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
             epoch,
             args,
             log_metrics,
+            manager=manager,
             model_ema=model_ema,
             scaler=scaler,
         )
@@ -650,6 +669,39 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
     _LOGGER.info(f"Training time {total_time_str}")
 
 
+def _create_model(
+    arch_key: Optional[str] = None,
+    local_rank=None,
+    pretrained: Optional[bool] = False,
+    checkpoint_path: Optional[str] = None,
+    pretrained_dataset: Optional[str] = None,
+    device=None,
+    num_classes=None,
+):
+    if arch_key in ModelRegistry.available_keys():
+        with torch_distributed_zero_first(local_rank):
+            model = ModelRegistry.create(
+                key=arch_key,
+                pretrained=pretrained,
+                pretrained_path=checkpoint_path,
+                pretrained_dataset=pretrained_dataset,
+                num_classes=num_classes,
+            )
+    elif arch_key in torchvision.models.__dict__:
+        # fall back to torchvision
+        model = torchvision.models.__dict__[arch_key](
+            pretrained=pretrained, num_classes=num_classes
+        )
+        if checkpoint_path is not None:
+            load_model(checkpoint_path, model, strict=True)
+    else:
+        raise ValueError(
+            f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
+        )
+    model.to(device)
+    return model
+
+
 def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):
     lr_scheduler = None
 
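As a usage note, the new _create_model helper can be called directly to build either the student or the teacher. A hypothetical call follows; all argument values are illustrative and not taken from the commit:

teacher = _create_model(
    arch_key="resnet50",              # resolved via ModelRegistry first, then torchvision.models
    local_rank=None,                  # mirrors the non-distributed path in the training script
    pretrained=True,                  # the script always builds teachers with pretrained=True
    checkpoint_path="./teacher.pth",  # optional; used as pretrained_path or passed to load_model
    pretrained_dataset="imagenet",
    device="cuda",
    num_classes=1000,
)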
@@ -1026,6 +1078,34 @@ def new_func(*args, **kwargs):
     help="Save the best validation result after the given "
     "epoch completes until the end of training",
 )
+@click.option(
+    "--distill-teacher",
+    default=None,
+    type=str,
+    help="Teacher model for distillation (a trained image classification model)"
+    " can be set to 'self' for self-distillation and 'disable' to switch-off"
+    " distillation, additionally can also take in a SparseZoo stub",
+)
+@click.option(
+    "--pretrained-teacher-dataset",
+    default=None,
+    type=str,
+    help=(
+        "The dataset to load pretrained weights for the teacher"
+        "Load the default dataset for the architecture if set to None. "
+        "examples:`imagenet`, `cifar10`, etc..."
+    ),
+)
+@click.option(
+    "--teacher-arch-key",
+    default=None,
+    type=str,
+    help=(
+        "The architecture key for teacher image classification model; "
+        "example: `resnet50`, `mobilenet`. "
+        "Note: Will be read from the checkpoint if not specified"
+    ),
+)
 @click.pass_context
 def cli(ctx, **kwargs):
     """
