 import warnings
 from functools import update_wrapper
 from types import SimpleNamespace
-from typing import Callable
+from typing import Callable, Optional

 import torch
 import torch.utils.data
@@ -63,6 +63,7 @@ def train_one_epoch(
     epoch: int,
     args,
     log_metrics_fn: Callable[[str, utils.MetricLogger, int, int], None],
+    manager=None,
     model_ema=None,
     scaler=None,
 ) -> utils.MetricLogger:
@@ -91,13 +92,24 @@ def train_one_epoch(
         start_time = time.time()
         image, target = image.to(device), target.to(device)
         with torch.cuda.amp.autocast(enabled=scaler is not None):
-            output = model(image)
+            outputs = output = model(image)
             if isinstance(output, tuple):
                 # NOTE: sparseml models return two things (logits & probs)
                 output = output[0]
             loss = criterion(output, target)

         if steps_accumulated % accum_steps == 0:
+            if manager is not None:
+                loss = manager.loss_update(
+                    loss=loss,
+                    module=model,
+                    optimizer=optimizer,
+                    epoch=epoch,
+                    steps_per_epoch=len(data_loader) / accum_steps,
+                    student_outputs=outputs,
+                    student_inputs=image,
+                )
+
             # first: do training to consume gradients
             if scaler is not None:
                 scaler.scale(loss).backward()
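The manager.loss_update(...) hook above is the point where recipe modifiers (for example, a distillation modifier) can rewrite the training loss from the student's outputs and inputs before the backward pass. As a rough illustration of what such an update typically computes, here is a minimal soft-target distillation sketch; the distillation_loss helper and its temperature/hardness weighting are illustrative assumptions, not SparseML's implementation.

# Illustrative sketch only -- not SparseML's internals. Blends the plain task
# loss with a KL term on temperature-softened logits, as classic knowledge
# distillation does.
import torch.nn.functional as F


def distillation_loss(task_loss, student_logits, teacher_logits,
                      temperature=4.0, hardness=0.5):
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # scale by T^2 so gradient magnitudes stay comparable across temperatures
    kd = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * temperature ** 2
    return hardness * kd + (1.0 - hardness) * task_loss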
@@ -348,27 +360,28 @@ def collate_fn(batch):
     )

     _LOGGER.info("Creating model")
-    if args.arch_key in ModelRegistry.available_keys():
-        with torch_distributed_zero_first(args.rank if args.distributed else None):
-            model = ModelRegistry.create(
-                key=args.arch_key,
-                pretrained=args.pretrained,
-                pretrained_path=args.checkpoint_path,
-                pretrained_dataset=args.pretrained_dataset,
-                num_classes=num_classes,
-            )
-    elif args.arch_key in torchvision.models.__dict__:
-        # fall back to torchvision
-        model = torchvision.models.__dict__[args.arch_key](
-            pretrained=args.pretrained, num_classes=num_classes
-        )
-        if args.checkpoint_path is not None:
-            load_model(args.checkpoint_path, model, strict=True)
-    else:
-        raise ValueError(
-            f"Unable to find {args.arch_key} in ModelRegistry or in torchvision.models"
+    local_rank = args.rank if args.distributed else None
+    model = _create_model(
+        arch_key=args.arch_key,
+        local_rank=local_rank,
+        pretrained=args.pretrained,
+        checkpoint_path=args.checkpoint_path,
+        pretrained_dataset=args.pretrained_dataset,
+        device=device,
+        num_classes=num_classes,
+    )
+
+    if args.distill_teacher not in ["self", "disable", None]:
+        _LOGGER.info("Instantiating teacher")
+        args.distill_teacher = _create_model(
+            arch_key=args.teacher_arch_key,
+            local_rank=local_rank,
+            pretrained=True,  # teacher is always pretrained
+            pretrained_dataset=args.pretrained_teacher_dataset,
+            checkpoint_path=args.distill_teacher,
+            device=device,
+            num_classes=num_classes,
         )
-    model.to(device)

     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
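Note the in-place rewrite here: when --distill-teacher names an actual model (a checkpoint path or SparseZoo stub), args.distill_teacher is replaced with the instantiated teacher torch.nn.Module, while the sentinels "self", "disable", and None pass through unchanged; whatever ends up in args.distill_teacher is what gets handed to the manager in the initialize call further down.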
@@ -542,7 +555,12 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
     )

     if manager is not None:
-        manager.initialize(model, epoch=args.start_epoch, loggers=logger)
+        manager.initialize(
+            model,
+            epoch=args.start_epoch,
+            loggers=logger,
+            distillation_teacher=args.distill_teacher,
+        )
         optimizer = manager.modify(
             model, optimizer, steps_per_epoch=steps_per_epoch, epoch=args.start_epoch
         )
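Together with the loss_update call in train_one_epoch, the manager now owns the whole distillation flow: it receives the teacher at initialize time, wraps the optimizer via modify, and adjusts the loss each step. A condensed sketch of that lifecycle is below, assuming the manager is SparseML's ScheduledModifierManager built from a recipe; from_yaml, finalize, and the placeholder objects in the sketch do not appear in this diff and are assumptions.

# Condensed lifecycle sketch under the assumptions stated above. model,
# teacher_model, optimizer, criterion, data_loader, num_epochs, and
# steps_per_epoch stand in for objects this script builds elsewhere.
from sparseml.pytorch.optim import ScheduledModifierManager

manager = ScheduledModifierManager.from_yaml("recipe.yaml")  # hypothetical recipe path
manager.initialize(model, epoch=0, distillation_teacher=teacher_model)
optimizer = manager.modify(model, optimizer, steps_per_epoch=steps_per_epoch, epoch=0)

for epoch in range(num_epochs):
    for image, target in data_loader:
        output = model(image)
        loss = criterion(output, target)
        # recipe modifiers (e.g. distillation) may rewrite the loss here
        loss = manager.loss_update(
            loss=loss,
            module=model,
            optimizer=optimizer,
            epoch=epoch,
            steps_per_epoch=steps_per_epoch,
            student_outputs=output,
            student_inputs=image,
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

manager.finalize(model)  # assumed cleanup step; not part of this diff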
@@ -578,6 +596,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
             epoch,
             args,
             log_metrics,
+            manager=manager,
             model_ema=model_ema,
             scaler=scaler,
         )
@@ -650,6 +669,39 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
     _LOGGER.info(f"Training time {total_time_str}")


+def _create_model(
+    arch_key: Optional[str] = None,
+    local_rank=None,
+    pretrained: Optional[bool] = False,
+    checkpoint_path: Optional[str] = None,
+    pretrained_dataset: Optional[str] = None,
+    device=None,
+    num_classes=None,
+):
+    if arch_key in ModelRegistry.available_keys():
+        with torch_distributed_zero_first(local_rank):
+            model = ModelRegistry.create(
+                key=arch_key,
+                pretrained=pretrained,
+                pretrained_path=checkpoint_path,
+                pretrained_dataset=pretrained_dataset,
+                num_classes=num_classes,
+            )
+    elif arch_key in torchvision.models.__dict__:
+        # fall back to torchvision
+        model = torchvision.models.__dict__[arch_key](
+            pretrained=pretrained, num_classes=num_classes
+        )
+        if checkpoint_path is not None:
+            load_model(checkpoint_path, model, strict=True)
+    else:
+        raise ValueError(
+            f"Unable to find {arch_key} in ModelRegistry or in torchvision.models"
+        )
+    model.to(device)
+    return model
+
+
 def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):
     lr_scheduler = None

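Factoring the registry-or-torchvision fallback into _create_model is what lets one code path build both the student and the teacher; the teacher simply pins pretrained=True and takes its checkpoint path from --distill-teacher. A hypothetical standalone call would look like the following, where every argument value is a placeholder:

# Hypothetical direct use of the helper; all argument values are placeholders.
teacher = _create_model(
    arch_key="resnet50",
    local_rank=None,
    pretrained=True,  # teachers are always loaded pretrained
    checkpoint_path="./teacher.pth",
    pretrained_dataset=None,
    device="cuda",
    num_classes=1000,
)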
@@ -1026,6 +1078,34 @@ def new_func(*args, **kwargs):
     help="Save the best validation result after the given "
     "epoch completes until the end of training",
 )
+@click.option(
+    "--distill-teacher",
+    default=None,
+    type=str,
+    help="Teacher model for distillation (a trained image classification model);"
+    " can be set to 'self' for self-distillation and 'disable' to switch off"
+    " distillation, and can also take a SparseZoo stub",
+)
+@click.option(
+    "--pretrained-teacher-dataset",
+    default=None,
+    type=str,
+    help=(
+        "The dataset to load pretrained weights for the teacher. "
+        "Loads the default dataset for the architecture if set to None. "
+        "Examples: `imagenet`, `cifar10`, etc."
+    ),
+)
+@click.option(
+    "--teacher-arch-key",
+    default=None,
+    type=str,
+    help=(
+        "The architecture key for the teacher image classification model; "
+        "examples: `resnet50`, `mobilenet`. "
+        "Note: will be read from the checkpoint if not specified"
+    ),
+)
 @click.pass_context
 def cli(ctx, **kwargs):
     """
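The three new options only collect strings; all of the teacher wiring happens in main() as shown above. Assuming this file is importable as a module named train (the packaged path may differ), click's test runner gives a quick way to confirm the flags are registered:

# Hypothetical smoke test for the new flags; the "train" import path is an
# assumption about how this script is packaged.
from click.testing import CliRunner

from train import cli

runner = CliRunner()
result = runner.invoke(cli, ["--help"])
assert result.exit_code == 0
for flag in ("--distill-teacher", "--pretrained-teacher-dataset", "--teacher-arch-key"):
    assert flag in result.output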