Skip to content

Commit bca1b66

Browse files
mauvilsarohitgr7
authored andcommitted
Fixed use of LightningCLI in computer_vision_fine_tuning.py example (Lightning-AI#9934)
1 parent 99af0c5 commit bca1b66

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
544544
- Reset `val_dataloader` in `tuner/batch_size_scaling` ([#9857](https://github.com/PyTorchLightning/pytorch-lightning/pull/9857))
545545

546546

547+
- Fixed use of `LightningCLI` in computer_vision_fine_tuning.py example ([#9934](https://github.com/PyTorchLightning/pytorch-lightning/pull/9934))
548+
549+
547550
## [1.4.9] - 2021-09-30
548551

549552
- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))

pl_examples/domain_templates/computer_vision_fine_tuning.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
3535
Note:
3636
See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
37+
38+
To run:
39+
python computer_vision_fine_tuning.py fit
3740
"""
3841

3942
import logging
@@ -265,7 +268,7 @@ def configure_optimizers(self):
265268

266269
class MyLightningCLI(LightningCLI):
267270
def add_arguments_to_parser(self, parser):
268-
parser.add_class_arguments(MilestonesFinetuning, "finetuning")
271+
parser.add_lightning_class_args(MilestonesFinetuning, "finetuning")
269272
parser.link_arguments("data.batch_size", "model.batch_size")
270273
parser.link_arguments("finetuning.milestones", "model.milestones")
271274
parser.link_arguments("finetuning.train_bn", "model.train_bn")
@@ -277,11 +280,6 @@ def add_arguments_to_parser(self, parser):
277280
}
278281
)
279282

280-
def instantiate_trainer(self, *args):
281-
finetuning_callback = MilestonesFinetuning(**self._get(self.config_init, "finetuning"))
282-
self.trainer_defaults["callbacks"] = [finetuning_callback]
283-
return super().instantiate_trainer(*args)
284-
285283

286284
def cli_main():
287285
MyLightningCLI(TransferLearningModel, CatDogImageDataModule, seed_everything_default=1234)

0 commit comments

Comments
 (0)