Distillation support for torchvision script #1310
Conversation
Changes look good, but we need to call `manager.update_loss` (or whatever the call is) to actually use the distillation loss.
Great catch, updated!
Left one comment, otherwise LGTM!
Looks great @rahul-tuli
Looks like loss_update returns the new loss to use! So close 😀
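For readers following along, here is a minimal sketch of the fix under discussion, assuming SparseML's `ScheduledModifierManager` API; the model, data, and loop below are illustrative placeholders, not the PR's exact code:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision.models import resnet18
from sparseml.pytorch.optim import ScheduledModifierManager

# placeholder model/data so the sketch is self-contained
model = resnet18(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
train_loader = DataLoader(
    TensorDataset(torch.randn(8, 3, 64, 64), torch.randint(0, 10, (8,))),
    batch_size=4,
)

manager = ScheduledModifierManager.from_yaml("distillation.yaml")
# wrap the optimizer so the recipe's modifiers are stepped during training
optimizer = manager.modify(model, optimizer, steps_per_epoch=len(train_loader))

for epoch in range(2):
    for images, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        # loss_update RETURNS the new loss; it must be reassigned,
        # or the distillation term is silently dropped (the bug flagged above)
        loss = manager.loss_update(
            loss, model, optimizer, epoch, steps_per_epoch=len(train_loader)
        )
        loss.backward()
        optimizer.step()

manager.finalize(model)
```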
🚀 LETS GOOOOOO
Co-authored-by: corey-nm <[email protected]>
* Add support for `self` distillation and `disable`
* Pull out model creation into a method
* Add support to distill with another model
* Add modifier loss update before backward pass
* bugfix, set loss
* Update src/sparseml/pytorch/torchvision/train.py

Co-authored-by: corey-nm <[email protected]>
The goal of this PR is to add distillation support to our pytorch/torchvision integration.

Test recipe: distillation.yaml. A sketch of what such a recipe might contain follows.
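The recipe file itself is not shown in the conversation; below is a minimal sketch of a SparseML distillation recipe, assuming the standard `EpochRangeModifier` and `DistillationModifier`. All values are illustrative, not the PR's actual recipe:

```yaml
# Hypothetical distillation.yaml -- field values are illustrative
training_modifiers:
  - !EpochRangeModifier
    start_epoch: 0.0
    end_epoch: 10.0

distillation_modifiers:
  - !DistillationModifier
    start_epoch: 0.0
    hardness: 0.5      # weight of the distillation term vs. the task loss
    temperature: 2.0   # softmax temperature for teacher/student logits
```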
Test commands (run manually):

Self-distillation (`--distill-teacher self`):
```bash
sparseml.image_classification.train \
  --recipe distillation.yaml --pretrained True --pretrained-dataset imagenette \
  --arch-key resnet50 --dataset-path /home/rahul/datasets/imagenette/imagenette-160 \
  --batch-size 128 --opt SGD --output-dir ./training-runs/image_classification-pretrained \
  --distill-teacher self
```
Regular training (no `--distill-teacher` specified):

```bash
sparseml.image_classification.train \
  --recipe distillation.yaml --pretrained True --pretrained-dataset imagenette \
  --arch-key resnet50 --dataset-path /home/rahul/datasets/imagenette/imagenette-160 \
  --batch-size 128 --opt SGD --output-dir ./training-runs/image_classification-pretrained
```
Distillation disabled (`--distill-teacher disable`):

```bash
sparseml.image_classification.train \
  --recipe distillation.yaml --pretrained True --pretrained-dataset imagenette \
  --arch-key resnet50 --dataset-path /home/rahul/datasets/imagenette/imagenette-160 \
  --batch-size 128 --opt SGD --output-dir ./training-runs/image_classification-pretrained \
  --distill-teacher disable
```
Distilling `mobilenet` using a `resnet50` teacher from SparseZoo:

```bash
sparseml.image_classification.train \
  --recipe distillation.yaml --pretrained True --pretrained-dataset imagenette \
  --arch-key mobilenet --dataset-path /home/rahul/datasets/imagenette/imagenette-160 \
  --batch-size 128 --opt SGD --output-dir ./training-runs/image_classification-pretrained \
  --distill-teacher zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/base-none \
  --pretrained-teacher-dataset imagenet --teacher-arch-key resnet50
```
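For reference, the dispatch on the three `--distill-teacher` values could look roughly like the sketch below. `resolve_teacher` and `create_model` are hypothetical names rather than the PR's actual helpers, and the `distillation_teacher` kwarg is an assumption about how the teacher reaches SparseML's distillation modifiers:

```python
# Hypothetical sketch of dispatching the --distill-teacher flag;
# function names are illustrative, not the PR's actual code.
def resolve_teacher(distill_teacher, student, create_model):
    if distill_teacher == "disable":
        return None  # run the recipe with distillation turned off
    if distill_teacher == "self":
        return student  # self-distillation: the student is its own teacher
    # otherwise treat the value as a checkpoint path or SparseZoo stub
    return create_model(checkpoint_path=distill_teacher)


teacher = resolve_teacher(args.distill_teacher, model, create_model)
if teacher is not None:
    # assumed kwarg: the teacher module is handed to the manager at
    # initialization so DistillationModifier can compute its loss term
    manager.initialize(model, distillation_teacher=teacher)
```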