Skip to content

Commit 1f09cf2

Browse files
authored
Fixed use of LightningCLI in computer_vision_fine_tuning.py example (#9934)
1 parent 5e8829b commit 1f09cf2

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
541541
- Reset `val_dataloader` in `tuner/batch_size_scaling` ([#9857](https://github.com/PyTorchLightning/pytorch-lightning/pull/9857))
542542

543543

544+
- Fixed use of `LightningCLI` in computer_vision_fine_tuning.py example ([#9934](https://github.com/PyTorchLightning/pytorch-lightning/pull/9934))
545+
546+
544547
## [1.4.9] - 2021-09-30
545548

546549
- Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704))

pl_examples/domain_templates/computer_vision_fine_tuning.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
3535
Note:
3636
See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
37+
38+
To run:
39+
python computer_vision_fine_tuning.py fit
3740
"""
3841

3942
import logging
@@ -265,7 +268,7 @@ def configure_optimizers(self):
265268

266269
class MyLightningCLI(LightningCLI):
267270
def add_arguments_to_parser(self, parser):
268-
parser.add_class_arguments(MilestonesFinetuning, "finetuning")
271+
parser.add_lightning_class_args(MilestonesFinetuning, "finetuning")
269272
parser.link_arguments("data.batch_size", "model.batch_size")
270273
parser.link_arguments("finetuning.milestones", "model.milestones")
271274
parser.link_arguments("finetuning.train_bn", "model.train_bn")
@@ -277,11 +280,6 @@ def add_arguments_to_parser(self, parser):
277280
}
278281
)
279282

280-
def instantiate_trainer(self, *args):
281-
finetuning_callback = MilestonesFinetuning(**self._get(self.config_init, "finetuning"))
282-
self.trainer_defaults["callbacks"] = [finetuning_callback]
283-
return super().instantiate_trainer(*args)
284-
285283

286284
def cli_main():
287285
MyLightningCLI(TransferLearningModel, CatDogImageDataModule, seed_everything_default=1234)

0 commit comments

Comments
 (0)